| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164 |
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- """
- 测试 graph_all 函数
- """
- import json
- import sys
- from pathlib import Path
# Fix Windows console encoding so the Chinese report text prints correctly.
if sys.platform == "win32":
    import io

    def _force_utf8(stream):
        """Return *stream* switched to UTF-8 output.

        ``reconfigure()`` (Python 3.7+) changes the encoding in place and keeps
        the stream's buffering mode. Wrapping ``stream.buffer`` in a fresh
        ``TextIOWrapper`` defaults to ``line_buffering=False``, which delays
        console output; the fallback therefore asks for line buffering explicitly.
        """
        if hasattr(stream, "reconfigure"):
            stream.reconfigure(encoding="utf-8")
            return stream
        return io.TextIOWrapper(stream.buffer, encoding="utf-8", line_buffering=True)

    sys.stdout = _force_utf8(sys.stdout)
    sys.stderr = _force_utf8(sys.stderr)
# Project root = the directory two levels above this file. It is put first on
# the import path so `app.*` resolves when the script is run directly.
# resolve() makes the path absolute: with a relative __file__ (e.g. running
# `python <script>.py` from the script's own directory on Python < 3.9),
# Path("x.py").parent.parent would be "." — the script directory, not the root.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
# Set the Neo4j connection environment variables from the "production" config
# entry, unless NEO4J_URI is already present in the environment.
# NOTE(review): this points a debugging script at the production database by
# default — confirm that is intended.
import os

from app.config.config import config

if "NEO4J_URI" not in os.environ:
    prod_config = config.get("production")
    if prod_config:
        for _env_key in ("NEO4J_URI", "NEO4J_USER", "NEO4J_PASSWORD"):
            _env_value = getattr(prod_config, _env_key, None)
            # os.environ only accepts str values; assigning None raises
            # TypeError, so unset config values are skipped instead.
            if _env_value is not None:
                os.environ[_env_key] = str(_env_value)

# Imported only after the environment is populated — presumably the data
# interface reads NEO4J_* when it is imported (TODO confirm).
from app.core.data_interface import interface
def _print_banner(title):
    """Print *title* between two 60-character ``=`` rules."""
    print("=" * 60)
    print(title)
    print("=" * 60)


def _count_by(items, key):
    """Count ``item.get(key, "Unknown")`` over *items*, in first-seen order."""
    from collections import Counter

    return Counter(item.get(key, "Unknown") for item in items)


def _print_counts(title, counts):
    """Print *title*, one `` name: count`` line per entry, then a blank line."""
    print(title)
    for name, count in counts.items():
        print(f" {name}: {count}")
    print()


def _print_node_details(nodes):
    """Print an ``ID / Type / Name (name_en)`` line for every node."""
    _print_banner("节点详情:")
    for node in nodes:
        node_id = node.get("id")
        node_type = node.get("node_type", "Unknown")
        name_zh = node.get("name_zh", node.get("name", "N/A"))
        name_en = node.get("name_en", "N/A")
        print(f" ID: {node_id}, Type: {node_type}, Name: {name_zh} ({name_en})")
    print()


def _print_line_details(lines):
    """Print a ``from -[text]-> to (rel_id: id)`` line for every relationship."""
    _print_banner("关系详情:")
    for rel in lines:
        rel_type = rel.get("text", "Unknown")
        print(f" {rel.get('from')} -[{rel_type}]-> {rel.get('to')} (rel_id: {rel.get('id')})")
    print()


def _verify_expectations(domain_id, nodes, lines, expected_target_id):
    """Print [OK]/[FAIL]/[WARN] checks for the start -INPUT-> DataFlow -OUTPUT-> target shape."""
    _print_banner("验证预期结果:")

    # Node "id" values and relationship "from"/"to" values may differ in type
    # (int vs str), so every id is normalised to str before comparing.
    # setdefault keeps the first node seen for each id.
    nodes_by_id = {}
    for node in nodes:
        nodes_by_id.setdefault(str(node.get("id")), node)

    def outgoing(src_id, rel_text):
        """Return the relationships of type *rel_text* whose "from" is *src_id*."""
        src = str(src_id)
        return [rel for rel in lines if str(rel.get("from")) == src and rel.get("text") == rel_text]

    # 1. The start node itself must be in the result.
    start_node = nodes_by_id.get(str(domain_id))
    if start_node is not None:
        print(f"[OK] 起始节点 {domain_id} 存在: {start_node.get('name_zh', 'N/A')}")
    else:
        print(f"[FAIL] 起始节点 {domain_id} 不存在")

    # 2. The start node should have outgoing INPUT relationships.
    input_lines = outgoing(domain_id, "INPUT")
    if input_lines:
        print(f"[OK] 找到 {len(input_lines)} 个 INPUT 关系从节点 {domain_id} 出发")
        for rel in input_lines:
            df_id = rel.get("to")
            df_node = nodes_by_id.get(str(df_id))
            if df_node is not None:
                # Show the node's actual type rather than assuming DataFlow.
                print(f" -> {df_node.get('node_type', 'DataFlow')} {df_id}: {df_node.get('name_zh', 'N/A')}")
    else:
        print(f"[FAIL] 未找到从节点 {domain_id} 出发的 INPUT 关系")

    # 3. Every DataFlow node should have outgoing OUTPUT relationships.
    for df_node in nodes:
        if df_node.get("node_type") != "DataFlow":
            continue
        df_id = df_node.get("id")
        output_lines = outgoing(df_id, "OUTPUT")
        if not output_lines:
            print(f"[WARN] DataFlow {df_id} 没有 OUTPUT 关系(但可能应该在数据库中存在)")
            continue
        print(f"[OK] DataFlow {df_id} 有 {len(output_lines)} 个 OUTPUT 关系:")
        for rel in output_lines:
            target_id = rel.get("to")
            target_node = nodes_by_id.get(str(target_id))
            if target_node is not None:
                # Show the node's actual type rather than assuming BusinessDomain.
                print(f" -> {target_node.get('node_type', 'BusinessDomain')} {target_id}: {target_node.get('name_zh', 'N/A')}")

    # 4. Optional: a specific expected target node must be present.
    if expected_target_id is not None:
        expected_node = nodes_by_id.get(str(expected_target_id))
        if expected_node is not None:
            print(f"[OK] 找到预期目标节点 {expected_target_id}: {expected_node.get('name_zh', 'N/A')}")
        else:
            print(f"[FAIL] 未找到预期目标节点 {expected_target_id}")
    print()
    print("=" * 60)


def _save_result(result, domain_id):
    """Dump *result* as UTF-8 JSON to ``<PROJECT_ROOT>/logs/graph_all_test_<domain_id>.json``."""
    output_file = PROJECT_ROOT / "logs" / f"graph_all_test_{domain_id}.json"
    output_file.parent.mkdir(parents=True, exist_ok=True)
    with open(output_file, "w", encoding="utf-8") as f:
        # default=str: debug dump only — anything not JSON-serialisable is
        # written as its str() form.
        json.dump(result, f, ensure_ascii=False, indent=2, default=str)
    print(f"完整结果已保存到: {output_file}")


def test_graph_all(domain_id: int, include_meta: bool = True, expected_target_id: "int | None" = 2272):
    """Call ``interface.graph_all`` from *domain_id* and print a diagnostic report.

    The report contains:
      * node and relationship counts grouped by type;
      * every node and relationship;
      * a set of [OK]/[FAIL]/[WARN] checks.
    The raw result is then saved to
    ``<PROJECT_ROOT>/logs/graph_all_test_<domain_id>.json``. Checks only print;
    nothing is raised, so the outcome must be read from the output.

    Args:
        domain_id: id of the start node passed to ``graph_all``.
        include_meta: forwarded unchanged to ``graph_all``.
        expected_target_id: node id that must appear in the result. The default
            2272 is the target this script originally checked for (it runs with
            start node 2213). Pass ``None`` to skip this check.

    Returns:
        The unmodified result of ``interface.graph_all``. Presumably a dict with
        "nodes" and "lines" lists (TODO confirm against the interface).

    NOTE(review): named like a pytest test but takes a required ``domain_id``;
    if this file is ever collected by pytest it will fail on the missing fixture.
    """
    _print_banner("测试 graph_all 函数")
    print(f"起始节点ID: {domain_id}")
    print(f"包含元数据: {include_meta}")
    print()

    result = interface.graph_all(domain_id, include_meta)

    # "or []" also covers keys that are present but set to None.
    nodes = result.get("nodes") or []
    lines = result.get("lines") or []
    print(f"找到节点数: {len(nodes)}")
    print(f"找到关系数: {len(lines)}")
    print()

    _print_counts("节点类型统计:", _count_by(nodes, "node_type"))
    _print_counts("关系类型统计:", _count_by(lines, "text"))

    _print_node_details(nodes)
    _print_line_details(lines)
    _verify_expectations(domain_id, nodes, lines, expected_target_id)
    _save_result(result, domain_id)
    return result
if __name__ == "__main__":
    # Script entry point: run the graph_all report starting from node 2213,
    # with metadata included.
    test_graph_all(domain_id=2213, include_meta=True)
|