|
@@ -421,31 +421,80 @@ def get_model_dependency_graph(table_names: list) -> dict:
|
|
|
driver = GraphDatabase.driver(uri, auth=auth)
|
|
|
try:
|
|
|
with driver.session() as session:
|
|
|
- for table_name in table_names:
|
|
|
- # 修改查询,移除对节点类型的限制,但保留对表名集合的过滤
|
|
|
- query = """
|
|
|
- MATCH (t {en_name: $table_name})-[:DERIVED_FROM]->(up)
|
|
|
- WHERE up.en_name IN $all_tables
|
|
|
- RETURN up.en_name AS upstream
|
|
|
- """
|
|
|
- logger.info(f"执行Neo4j查询: 查找 {table_name} 在当前批次中的上游依赖")
|
|
|
- result = session.run(query, table_name=table_name, all_tables=table_names)
|
|
|
- deps = [record['upstream'] for record in result if 'upstream' in record]
|
|
|
- logger.info(f"表 {table_name} 的上游依赖(当前批次内): {deps}")
|
|
|
+ # 使用一次性查询获取所有表之间的依赖关系
|
|
|
+ # 注意:这里查询的是 A-[:DERIVED_FROM]->B 关系,表示A依赖B
|
|
|
+
|
|
|
+ # 记录原始查询参数用于调试
|
|
|
+ logger.info(f"查询参数 table_names: {table_names}, 类型: {type(table_names)}")
|
|
|
+
|
|
|
+ # 第一层查询 - 更明确的查询形式
|
|
|
+ query = """
|
|
|
+ MATCH (source)-[r:DERIVED_FROM]->(target)
|
|
|
+ WHERE source.en_name IN $table_names AND target.en_name IN $table_names
|
|
|
+ RETURN source.en_name AS source, target.en_name AS target, r.script_name AS script_name
|
|
|
+ """
|
|
|
+ logger.info(f"执行Neo4j查询: 查找所有表之间的依赖关系")
|
|
|
+ result = session.run(query, table_names=table_names)
|
|
|
+
|
|
|
+ # 转换结果为列表,确保结果被消费
|
|
|
+ result_records = list(result)
|
|
|
+ logger.info(f"第一层查询返回记录数: {len(result_records)}")
|
|
|
+
|
|
|
+ # 处理依赖关系
|
|
|
+ found_deps = 0
|
|
|
+ # 初始化依赖字典
|
|
|
+ dependency_dict = {name: [] for name in table_names}
|
|
|
+
|
|
|
+ # 这里是问题所在 - 需要正确处理记录
|
|
|
+ for record in result_records:
|
|
|
+ # 直接将记录转换为字典,避免访问问题
|
|
|
+ record_dict = dict(record)
|
|
|
|
|
|
- # 同时查询所有上游依赖(不限于当前批次),用于日志记录
|
|
|
- all_deps_query = """
|
|
|
- MATCH (t {en_name: $table_name})-[:DERIVED_FROM]->(up)
|
|
|
- RETURN up.en_name AS upstream
|
|
|
- """
|
|
|
- all_deps_result = session.run(all_deps_query, table_name=table_name)
|
|
|
- all_deps = [record['upstream'] for record in all_deps_result if 'upstream' in record]
|
|
|
- logger.info(f"表 {table_name} 的所有上游依赖: {all_deps}")
|
|
|
+ # 从字典中获取值
|
|
|
+ source = record_dict.get('source')
|
|
|
+ target = record_dict.get('target')
|
|
|
+ script_name = record_dict.get('script_name', 'unknown_script')
|
|
|
+
|
|
|
+ # 确保字段存在且有值
|
|
|
+ if source and target:
|
|
|
+ logger.info(f"发现依赖关系: {source} -[:DERIVED_FROM]-> {target}, 脚本: {script_name}")
|
|
|
+
|
|
|
+ # 添加依赖关系到字典
|
|
|
+ if source in dependency_dict:
|
|
|
+ dependency_dict[source].append(target)
|
|
|
+ found_deps += 1
|
|
|
+
|
|
|
+ # 添加边到图 - 把被依赖方指向依赖方,表示执行顺序(被依赖方先执行)
|
|
|
+ G.add_edge(target, source)
|
|
|
+ logger.info(f"添加执行顺序边: {target} -> {source} (因为{source}依赖{target})")
|
|
|
+
|
|
|
+ logger.info(f"总共发现 {found_deps} 个依赖关系")
|
|
|
+
|
|
|
+ # 如果没有找到依赖关系,尝试检查所有可能的表对关系
|
|
|
+ if found_deps == 0:
|
|
|
+ logger.warning("仍未找到依赖关系,尝试检查所有表对之间的关系")
|
|
|
+ logger.info("第三层查询: 开始表对之间的循环检查")
|
|
|
+ logger.info(f"要检查的表对数量: {len(table_names) * (len(table_names) - 1)}")
|
|
|
|
|
|
- # 添加依赖边
|
|
|
- for dep in deps:
|
|
|
- logger.info(f"添加依赖边: {dep} -> {table_name}")
|
|
|
- G.add_edge(dep, table_name)
|
|
|
+ pair_count = 0
|
|
|
+ for source_table in table_names:
|
|
|
+ for target_table in table_names:
|
|
|
+ if source_table != target_table:
|
|
|
+ pair_count += 1
|
|
|
+ logger.info(f"检查表对[{pair_count}]: {source_table} -> {target_table}")
|
|
|
+
|
|
|
+ check_result = check_table_relationship(source_table, target_table)
|
|
|
+
|
|
|
+ # 检查forward方向的关系
|
|
|
+ if 'forward' in check_result and check_result['forward']['exists']:
|
|
|
+ script_name = check_result['forward'].get('script_name', 'unknown_script')
|
|
|
+ logger.info(f"表对检查发现关系: {source_table} -[:DERIVED_FROM]-> {target_table}, 脚本: {script_name}")
|
|
|
+
|
|
|
+ dependency_dict[source_table].append(target_table)
|
|
|
+ G.add_edge(target_table, source_table)
|
|
|
+ found_deps += 1
|
|
|
+
|
|
|
+ logger.info(f"表对检查后找到 {found_deps} 个依赖关系")
|
|
|
finally:
|
|
|
driver.close()
|
|
|
|
|
@@ -457,15 +506,14 @@ def get_model_dependency_graph(table_names: list) -> dict:
|
|
|
except Exception as e:
|
|
|
logger.error(f"检查循环依赖失败: {str(e)}")
|
|
|
|
|
|
- # 转换为字典格式返回
|
|
|
- dependency_dict = {}
|
|
|
+ # 将图转换为字典格式
|
|
|
+ final_dependency_dict = {}
|
|
|
for table_name in table_names:
|
|
|
- predecessors = list(G.predecessors(table_name))
|
|
|
- dependency_dict[table_name] = predecessors
|
|
|
- logger.info(f"最终依赖关系 - 表 {table_name} 依赖于: {predecessors}")
|
|
|
+ final_dependency_dict[table_name] = dependency_dict.get(table_name, [])
|
|
|
+ logger.info(f"最终依赖关系 - 表 {table_name} 依赖于: {final_dependency_dict[table_name]}")
|
|
|
|
|
|
- logger.info(f"完整依赖图: {dependency_dict}")
|
|
|
- return dependency_dict
|
|
|
+ logger.info(f"完整依赖图: {final_dependency_dict}")
|
|
|
+ return final_dependency_dict
|
|
|
|
|
|
|
|
|
def generate_optimized_execution_order(table_names: list) -> list:
|
|
@@ -630,4 +678,212 @@ def check_table_relationship(table1, table2):
|
|
|
finally:
|
|
|
driver.close()
|
|
|
|
|
|
- return relationship_info
|
|
|
+ return relationship_info
|
|
|
+
|
|
|
+def build_model_dependency_dag(table_names, model_tables):
|
|
|
+ """
|
|
|
+ 基于表名列表构建模型依赖DAG,返回优化后的执行顺序和依赖关系图
|
|
|
+
|
|
|
+ 参数:
|
|
|
+ table_names: 表名列表
|
|
|
+ model_tables: 表配置列表
|
|
|
+
|
|
|
+ 返回:
|
|
|
+ tuple: (优化后的表执行顺序, 依赖关系图)
|
|
|
+ """
|
|
|
+ # 使用优化函数生成执行顺序,可以处理循环依赖
|
|
|
+ optimized_table_order = generate_optimized_execution_order(table_names)
|
|
|
+ logger.info(f"生成优化执行顺序, 共 {len(optimized_table_order)} 个表")
|
|
|
+
|
|
|
+ # 获取依赖图
|
|
|
+ dependency_graph = get_model_dependency_graph(table_names)
|
|
|
+ logger.info(f"构建了 {len(dependency_graph)} 个表的依赖关系图")
|
|
|
+
|
|
|
+ return optimized_table_order, dependency_graph
|
|
|
+
|
|
|
+
|
|
|
+def create_task_dict(optimized_table_order, model_tables, dag, execution_type, **task_options):
|
|
|
+ """
|
|
|
+ 根据优化后的表执行顺序创建任务字典
|
|
|
+
|
|
|
+ 参数:
|
|
|
+ optimized_table_order: 优化后的表执行顺序
|
|
|
+ model_tables: 表配置列表
|
|
|
+ dag: Airflow DAG对象
|
|
|
+ execution_type: 执行类型(daily, monthly等)
|
|
|
+ task_options: 任务创建的额外选项
|
|
|
+
|
|
|
+ 返回:
|
|
|
+ dict: 任务字典 {表名: 任务对象}
|
|
|
+ """
|
|
|
+ from airflow.operators.python import PythonOperator
|
|
|
+
|
|
|
+ task_dict = {}
|
|
|
+ for table_name in optimized_table_order:
|
|
|
+ # 获取表的配置信息
|
|
|
+ table_config = next((t for t in model_tables if t['table_name'] == table_name), None)
|
|
|
+ if table_config:
|
|
|
+ try:
|
|
|
+ # 构建基础参数
|
|
|
+ task_params = {
|
|
|
+ "task_id": f"process_{execution_type}_{table_name}",
|
|
|
+ "python_callable": run_model_script,
|
|
|
+ "op_kwargs": {"table_name": table_name, "execution_mode": table_config['execution_mode']},
|
|
|
+ "dag": dag
|
|
|
+ }
|
|
|
+
|
|
|
+ # 添加额外选项
|
|
|
+ if task_options:
|
|
|
+ # 如果有表特定的选项,使用它们
|
|
|
+ if table_name in task_options:
|
|
|
+ task_params.update(task_options[table_name])
|
|
|
+ # 如果有全局选项,使用它们
|
|
|
+ elif 'default' in task_options:
|
|
|
+ task_params.update(task_options['default'])
|
|
|
+
|
|
|
+ task = PythonOperator(**task_params)
|
|
|
+ task_dict[table_name] = task
|
|
|
+ logger.info(f"创建模型处理任务: {task_params['task_id']}")
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"创建任务 process_{execution_type}_{table_name} 时出错: {str(e)}")
|
|
|
+ raise
|
|
|
+ return task_dict
|
|
|
+
|
|
|
+
|
|
|
+def build_task_dependencies(task_dict, dependency_graph):
|
|
|
+ """
|
|
|
+ 根据依赖图设置任务间的依赖关系
|
|
|
+
|
|
|
+ 参数:
|
|
|
+ task_dict: 任务字典
|
|
|
+ dependency_graph: 依赖关系图
|
|
|
+
|
|
|
+ 返回:
|
|
|
+ tuple: (tasks_with_upstream, tasks_with_downstream, dependency_count)
|
|
|
+ """
|
|
|
+ tasks_with_upstream = set() # 用于跟踪已经有上游任务的节点
|
|
|
+ dependency_count = 0
|
|
|
+
|
|
|
+ for target, upstream_list in dependency_graph.items():
|
|
|
+ if target in task_dict:
|
|
|
+ for upstream in upstream_list:
|
|
|
+ if upstream in task_dict:
|
|
|
+ logger.info(f"建立任务依赖: {upstream} >> {target}")
|
|
|
+ task_dict[upstream] >> task_dict[target]
|
|
|
+ tasks_with_upstream.add(target) # 记录此任务已有上游
|
|
|
+ dependency_count += 1
|
|
|
+
|
|
|
+ # 找出有下游任务的节点
|
|
|
+ tasks_with_downstream = set()
|
|
|
+ for target, upstream_list in dependency_graph.items():
|
|
|
+ if target in task_dict: # 目标任务在当前DAG中
|
|
|
+ for upstream in upstream_list:
|
|
|
+ if upstream in task_dict: # 上游任务也在当前DAG中
|
|
|
+ tasks_with_downstream.add(upstream) # 这个上游任务有下游
|
|
|
+
|
|
|
+ logger.info(f"总共建立了 {dependency_count} 个任务之间的依赖关系")
|
|
|
+ logger.info(f"已有上游任务的节点: {tasks_with_upstream}")
|
|
|
+
|
|
|
+ return tasks_with_upstream, tasks_with_downstream, dependency_count
|
|
|
+
|
|
|
+
|
|
|
+def connect_start_and_end_tasks(task_dict, tasks_with_upstream, tasks_with_downstream,
|
|
|
+ wait_task, completed_task, dag_type):
|
|
|
+ """
|
|
|
+ 连接开始节点到等待任务,末端节点到完成标记
|
|
|
+
|
|
|
+ 参数:
|
|
|
+ task_dict: 任务字典
|
|
|
+ tasks_with_upstream: 有上游任务的节点集合
|
|
|
+ tasks_with_downstream: 有下游任务的节点集合
|
|
|
+ wait_task: 等待任务
|
|
|
+ completed_task: 完成标记任务
|
|
|
+ dag_type: DAG类型名称(用于日志)
|
|
|
+
|
|
|
+ 返回:
|
|
|
+ tuple: (start_tasks, end_tasks)
|
|
|
+ """
|
|
|
+ # 连接开始节点
|
|
|
+ start_tasks = []
|
|
|
+ for table_name, task in task_dict.items():
|
|
|
+ if table_name not in tasks_with_upstream:
|
|
|
+ start_tasks.append(table_name)
|
|
|
+ logger.info(f"任务 {table_name} 没有上游任务,应该连接到{dag_type}等待任务")
|
|
|
+
|
|
|
+ logger.info(f"需要连接到{dag_type}等待任务的任务: {start_tasks}")
|
|
|
+
|
|
|
+ for task_name in start_tasks:
|
|
|
+ wait_task >> task_dict[task_name]
|
|
|
+ logger.info(f"连接 {wait_task.task_id} >> {task_name}")
|
|
|
+
|
|
|
+ # 连接末端节点
|
|
|
+ end_tasks = []
|
|
|
+ for table_name, task in task_dict.items():
|
|
|
+ if table_name not in tasks_with_downstream:
|
|
|
+ end_tasks.append(table_name)
|
|
|
+ logger.info(f"任务 {table_name} 没有下游任务,是末端任务")
|
|
|
+
|
|
|
+ logger.info(f"需要连接到{dag_type}完成标记的末端任务: {end_tasks}")
|
|
|
+
|
|
|
+ for end_task in end_tasks:
|
|
|
+ task_dict[end_task] >> completed_task
|
|
|
+ logger.info(f"连接 {end_task} >> {completed_task.task_id}")
|
|
|
+
|
|
|
+ # 处理特殊情况
|
|
|
+ logger.info("处理特殊情况")
|
|
|
+ if not start_tasks:
|
|
|
+ logger.warning(f"没有找到开始任务,将{dag_type}等待任务直接连接到完成标记")
|
|
|
+ wait_task >> completed_task
|
|
|
+
|
|
|
+ if not end_tasks:
|
|
|
+ logger.warning(f"没有找到末端任务,将所有任务连接到{dag_type}完成标记")
|
|
|
+ for table_name, task in task_dict.items():
|
|
|
+ task >> completed_task
|
|
|
+ logger.info(f"直接连接任务到完成标记: {table_name} >> {completed_task.task_id}")
|
|
|
+
|
|
|
+ return start_tasks, end_tasks
|
|
|
+
|
|
|
+
|
|
|
+def process_model_tables(enabled_tables, dag_type, wait_task, completed_task, dag, **task_options):
|
|
|
+ """
|
|
|
+ 处理模型表并构建DAG
|
|
|
+
|
|
|
+ 参数:
|
|
|
+ enabled_tables: 已启用的表列表
|
|
|
+ dag_type: DAG类型 (daily, monthly等)
|
|
|
+ wait_task: 等待任务
|
|
|
+ completed_task: 完成标记任务
|
|
|
+ dag: Airflow DAG对象
|
|
|
+ task_options: 创建任务的额外选项
|
|
|
+ """
|
|
|
+ model_tables = [t for t in enabled_tables if is_data_model_table(t['table_name'])]
|
|
|
+ logger.info(f"获取到 {len(model_tables)} 个启用的 {dag_type} 模型表")
|
|
|
+
|
|
|
+ if not model_tables:
|
|
|
+ # 如果没有模型表需要处理,直接将等待任务与完成标记相连接
|
|
|
+ logger.info(f"没有找到需要处理的{dag_type}模型表,DAG将直接标记为完成")
|
|
|
+ wait_task >> completed_task
|
|
|
+ return
|
|
|
+
|
|
|
+ # 获取表名列表
|
|
|
+ table_names = [t['table_name'] for t in model_tables]
|
|
|
+
|
|
|
+ try:
|
|
|
+ # 构建模型依赖DAG
|
|
|
+ optimized_table_order, dependency_graph = build_model_dependency_dag(table_names, model_tables)
|
|
|
+
|
|
|
+ # 创建任务字典
|
|
|
+ task_dict = create_task_dict(optimized_table_order, model_tables, dag, dag_type, **task_options)
|
|
|
+
|
|
|
+ # 建立任务依赖关系
|
|
|
+ tasks_with_upstream, tasks_with_downstream, _ = build_task_dependencies(task_dict, dependency_graph)
|
|
|
+
|
|
|
+ # 连接开始节点和末端节点
|
|
|
+ connect_start_and_end_tasks(task_dict, tasks_with_upstream, tasks_with_downstream,
|
|
|
+ wait_task, completed_task, dag_type)
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"处理{dag_type}模型表时出错: {str(e)}")
|
|
|
+ # 出错时也要确保完成标记被触发
|
|
|
+ wait_task >> completed_task
|
|
|
+ raise
|