#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Manual check script for ``interface.graph_all``.

Runs ``graph_all`` from a start node, prints node/relationship statistics and
details, runs a few sanity checks on the returned graph, and dumps the raw
result to ``<PROJECT_ROOT>/logs/graph_all_test_<domain_id>.json`` for debugging.

Usage::

    python <this script> [domain_id]      # domain_id defaults to 2213
"""
import io
import json
import os
import sys
from collections import Counter
from pathlib import Path
from typing import Optional

# Fix console encoding on Windows so the Chinese output below prints correctly.
# ``reconfigure`` keeps the existing stream object (and its buffering); fall
# back to wrapping the raw buffer for streams that do not support it.
if sys.platform == "win32":
    for _name in ("stdout", "stderr"):
        _stream = getattr(sys, _name)
        if hasattr(_stream, "reconfigure"):
            _stream.reconfigure(encoding="utf-8")
        else:
            setattr(sys, _name, io.TextIOWrapper(_stream.buffer, encoding="utf-8"))

# Make the project root importable (this script lives one directory below it).
PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(PROJECT_ROOT))

from app.config.config import config  # noqa: E402  (needs PROJECT_ROOT on sys.path)

# Default to the "production" Neo4j settings when NEO4J_URI is not already set.
# This runs before ``interface`` is imported, presumably because it reads these
# variables at import time. None values are skipped: os.environ only accepts
# strings, and assigning None would raise TypeError.
if "NEO4J_URI" not in os.environ:
    prod_config = config.get("production")
    if prod_config:
        for _key in ("NEO4J_URI", "NEO4J_USER", "NEO4J_PASSWORD"):
            _value = getattr(prod_config, _key)
            if _value is not None:
                os.environ[_key] = str(_value)

from app.core.data_interface import interface  # noqa: E402


def _print_banner(title: str) -> None:
    """Print ``title`` between two 60-character '=' rules."""
    print("=" * 60)
    print(title)
    print("=" * 60)


def _print_counts(title: str, counts: Counter) -> None:
    """Print a heading followed by one ``key: count`` row per entry, then a blank line."""
    print(title)
    for key, count in counts.items():
        print(f" {key}: {count}")
    print()


def _verify_expectations(
    domain_id: int,
    nodes: list,
    lines: list,
    expected_target_id: Optional[int],
) -> None:
    """Print [OK]/[FAIL]/[WARN] sanity checks on the graph returned for ``domain_id``.

    IDs are compared as strings on both sides, so the checks behave the same
    whether the payload stores node IDs and relationship endpoints as ints or
    as strings.
    """
    _print_banner("验证预期结果:")

    # Index nodes by stringified ID once; setdefault keeps the first occurrence,
    # which matches what a linear "first match" scan would return.
    nodes_by_id = {}
    for node in nodes:
        nodes_by_id.setdefault(str(node.get("id")), node)

    def _outgoing(source_id, rel_type):
        """Return relationships of ``rel_type`` whose ``from`` endpoint is ``source_id``."""
        source = str(source_id)
        return [
            rel for rel in lines
            if str(rel.get("from")) == source and rel.get("text") == rel_type
        ]

    # The start node itself must be part of the result.
    start_node = nodes_by_id.get(str(domain_id))
    if start_node is not None:
        print(f"[OK] 起始节点 {domain_id} 存在: {start_node.get('name_zh', 'N/A')}")
    else:
        print(f"[FAIL] 起始节点 {domain_id} 不存在")

    # There should be INPUT relationships leaving the start node.
    input_lines = _outgoing(domain_id, "INPUT")
    if input_lines:
        print(f"[OK] 找到 {len(input_lines)} 个 INPUT 关系从节点 {domain_id} 出发")
        for rel in input_lines:
            df_id = rel.get("to")
            df_node = nodes_by_id.get(str(df_id))
            if df_node is not None:
                print(f" -> DataFlow {df_id}: {df_node.get('name_zh', 'N/A')}")
    else:
        print(f"[FAIL] 未找到从节点 {domain_id} 出发的 INPUT 关系")

    # Every DataFlow node in the result is expected to have OUTPUT relationships.
    for df_node in (n for n in nodes if n.get("node_type") == "DataFlow"):
        df_id = df_node.get("id")
        output_lines = _outgoing(df_id, "OUTPUT")
        if output_lines:
            print(f"[OK] DataFlow {df_id} 有 {len(output_lines)} 个 OUTPUT 关系:")
            for rel in output_lines:
                target_bd_id = rel.get("to")
                target_node = nodes_by_id.get(str(target_bd_id))
                if target_node is not None:
                    print(f" -> BusinessDomain {target_bd_id}: {target_node.get('name_zh', 'N/A')}")
        else:
            print(f"[WARN] DataFlow {df_id} 没有 OUTPUT 关系(但可能应该在数据库中存在)")

    # Optionally check that a specific downstream node was reached.
    if expected_target_id is not None:
        target_node = nodes_by_id.get(str(expected_target_id))
        if target_node is not None:
            print(f"[OK] 找到预期目标节点 {expected_target_id}: {target_node.get('name_zh', 'N/A')}")
        else:
            print(f"[FAIL] 未找到预期目标节点 {expected_target_id}")


def test_graph_all(
    domain_id: int,
    include_meta: bool = True,
    *,
    expected_target_id: Optional[int] = 2272,
):
    """Run ``interface.graph_all`` for one start node and report on the result.

    Prints node/relationship counts grouped by type, every node and
    relationship, and the outcome of the sanity checks. It then writes the raw
    result to ``PROJECT_ROOT/logs/graph_all_test_<domain_id>.json``, creating
    the directory if needed.

    Args:
        domain_id: ID of the start node passed to ``graph_all``.
        include_meta: Forwarded to ``graph_all`` unchanged.
        expected_target_id: Node ID expected to appear in the result. The
            default, 2272, is the expected target for start node 2213. Pass
            ``None`` to skip this check.

    Returns:
        Whatever ``interface.graph_all`` returned. It is read here via
        ``.get("nodes")`` / ``.get("lines")``, so it is assumed to be a dict.
    """
    _print_banner("测试 graph_all 函数")
    print(f"起始节点ID: {domain_id}")
    print(f"包含元数据: {include_meta}")
    print()

    result = interface.graph_all(domain_id, include_meta)

    nodes = result.get("nodes", [])
    lines = result.get("lines", [])
    print(f"找到节点数: {len(nodes)}")
    print(f"找到关系数: {len(lines)}")
    print()

    # Counter preserves first-seen order, so rows print in encounter order.
    _print_counts("节点类型统计:", Counter(node.get("node_type", "Unknown") for node in nodes))
    _print_counts("关系类型统计:", Counter(rel.get("text", "Unknown") for rel in lines))

    _print_banner("节点详情:")
    for node in nodes:
        node_id = node.get("id")
        node_type = node.get("node_type", "Unknown")
        name_zh = node.get("name_zh", node.get("name", "N/A"))
        name_en = node.get("name_en", "N/A")
        print(f" ID: {node_id}, Type: {node_type}, Name: {name_zh} ({name_en})")
    print()

    _print_banner("关系详情:")
    for rel in lines:
        rel_type = rel.get("text", "Unknown")
        print(f" {rel.get('from')} -[{rel_type}]-> {rel.get('to')} (rel_id: {rel.get('id')})")
    print()

    _verify_expectations(domain_id, nodes, lines, expected_target_id)

    print()
    print("=" * 60)

    # Dump the full result for debugging; default=str stringifies any values
    # json cannot serialize natively.
    output_file = PROJECT_ROOT / "logs" / f"graph_all_test_{domain_id}.json"
    output_file.parent.mkdir(parents=True, exist_ok=True)
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(result, f, ensure_ascii=False, indent=2, default=str)
    print(f"完整结果已保存到: {output_file}")

    return result


# NOTE: this is a manual script, not a pytest test. ``domain_id`` is not a
# fixture, so tell pytest not to collect this function if the file is ever
# placed on a collected path.
test_graph_all.__test__ = False


if __name__ == "__main__":
    # Start node defaults to 2213; another ID may be given as the first CLI argument.
    domain_id = int(sys.argv[1]) if len(sys.argv) > 1 else 2213
    test_graph_all(domain_id, include_meta=True)