test_graph_all.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. """
  4. 测试 graph_all 函数
  5. """
import json
import sys
from collections import Counter
from pathlib import Path
  9. # 修复 Windows 控制台编码问题
  10. if sys.platform == "win32":
  11. import io
  12. sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8")
  13. sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding="utf-8")
  14. # 添加项目根目录到Python路径
  15. PROJECT_ROOT = Path(__file__).parent.parent
  16. sys.path.insert(0, str(PROJECT_ROOT))
  17. # 设置环境变量以使用 production 配置(或从配置文件读取)
  18. import os
  19. from app.config.config import config
  20. # 默认使用 production 配置,如果环境变量未设置
  21. if "NEO4J_URI" not in os.environ:
  22. prod_config = config.get("production")
  23. if prod_config:
  24. os.environ["NEO4J_URI"] = prod_config.NEO4J_URI
  25. os.environ["NEO4J_USER"] = prod_config.NEO4J_USER
  26. os.environ["NEO4J_PASSWORD"] = prod_config.NEO4J_PASSWORD
  27. from app.core.data_interface import interface
  28. def test_graph_all(domain_id: int, include_meta: bool = True):
  29. """测试 graph_all 函数"""
  30. print("=" * 60)
  31. print(f"测试 graph_all 函数")
  32. print("=" * 60)
  33. print(f"起始节点ID: {domain_id}")
  34. print(f"包含元数据: {include_meta}")
  35. print()
  36. # 调用函数
  37. result = interface.graph_all(domain_id, include_meta)
  38. # 输出结果
  39. nodes = result.get("nodes", [])
  40. lines = result.get("lines", [])
  41. print(f"找到节点数: {len(nodes)}")
  42. print(f"找到关系数: {len(lines)}")
  43. print()
  44. # 按节点类型分组统计
  45. node_types = {}
  46. for node in nodes:
  47. node_type = node.get("node_type", "Unknown")
  48. node_types[node_type] = node_types.get(node_type, 0) + 1
  49. print("节点类型统计:")
  50. for node_type, count in node_types.items():
  51. print(f" {node_type}: {count}")
  52. print()
  53. # 按关系类型分组统计
  54. rel_types = {}
  55. for line in lines:
  56. rel_type = line.get("text", "Unknown")
  57. rel_types[rel_type] = rel_types.get(rel_type, 0) + 1
  58. print("关系类型统计:")
  59. for rel_type, count in rel_types.items():
  60. print(f" {rel_type}: {count}")
  61. print()
  62. # 显示所有节点详情
  63. print("=" * 60)
  64. print("节点详情:")
  65. print("=" * 60)
  66. for node in nodes:
  67. node_id = node.get("id")
  68. node_type = node.get("node_type", "Unknown")
  69. name_zh = node.get("name_zh", node.get("name", "N/A"))
  70. name_en = node.get("name_en", "N/A")
  71. print(f" ID: {node_id}, Type: {node_type}, Name: {name_zh} ({name_en})")
  72. print()
  73. # 显示所有关系详情
  74. print("=" * 60)
  75. print("关系详情:")
  76. print("=" * 60)
  77. for line in lines:
  78. rel_id = line.get("id")
  79. from_node = line.get("from")
  80. to_node = line.get("to")
  81. rel_type = line.get("text", "Unknown")
  82. print(f" {from_node} -[{rel_type}]-> {to_node} (rel_id: {rel_id})")
  83. print()
  84. # 验证预期结果
  85. print("=" * 60)
  86. print("验证预期结果:")
  87. print("=" * 60)
  88. # 检查起始节点是否存在
  89. start_node = next((n for n in nodes if n.get("id") == domain_id), None)
  90. if start_node:
  91. print(f"[OK] 起始节点 {domain_id} 存在: {start_node.get('name_zh', 'N/A')}")
  92. else:
  93. print(f"[FAIL] 起始节点 {domain_id} 不存在")
  94. # 检查是否有 INPUT 关系从起始节点出发
  95. input_lines = [l for l in lines if l.get("from") == str(domain_id) and l.get("text") == "INPUT"]
  96. if input_lines:
  97. print(f"[OK] 找到 {len(input_lines)} 个 INPUT 关系从节点 {domain_id} 出发")
  98. for line in input_lines:
  99. df_id = line.get("to")
  100. df_node = next((n for n in nodes if str(n.get("id")) == df_id), None)
  101. if df_node:
  102. print(f" -> DataFlow {df_id}: {df_node.get('name_zh', 'N/A')}")
  103. else:
  104. print(f"[FAIL] 未找到从节点 {domain_id} 出发的 INPUT 关系")
  105. # 检查 DataFlow 节点是否有 OUTPUT 关系
  106. dataflow_nodes = [n for n in nodes if n.get("node_type") == "DataFlow"]
  107. for df_node in dataflow_nodes:
  108. df_id = df_node.get("id")
  109. output_lines = [l for l in lines if l.get("from") == str(df_id) and l.get("text") == "OUTPUT"]
  110. if output_lines:
  111. print(f"[OK] DataFlow {df_id} 有 {len(output_lines)} 个 OUTPUT 关系:")
  112. for line in output_lines:
  113. target_bd_id = line.get("to")
  114. target_node = next((n for n in nodes if str(n.get("id")) == target_bd_id), None)
  115. if target_node:
  116. print(f" -> BusinessDomain {target_bd_id}: {target_node.get('name_zh', 'N/A')}")
  117. else:
  118. print(f"[WARN] DataFlow {df_id} 没有 OUTPUT 关系(但可能应该在数据库中存在)")
  119. # 检查预期目标节点 2272
  120. target_node_2272 = next((n for n in nodes if n.get("id") == 2272), None)
  121. if target_node_2272:
  122. print(f"[OK] 找到预期目标节点 2272: {target_node_2272.get('name_zh', 'N/A')}")
  123. else:
  124. print(f"[FAIL] 未找到预期目标节点 2272")
  125. print()
  126. print("=" * 60)
  127. # 保存完整结果到 JSON 文件(用于调试)
  128. output_file = PROJECT_ROOT / "logs" / f"graph_all_test_{domain_id}.json"
  129. output_file.parent.mkdir(parents=True, exist_ok=True)
  130. with open(output_file, "w", encoding="utf-8") as f:
  131. json.dump(result, f, ensure_ascii=False, indent=2, default=str)
  132. print(f"完整结果已保存到: {output_file}")
  133. return result
  134. if __name__ == "__main__":
  135. # 测试节点 2213
  136. domain_id = 2213
  137. test_graph_all(domain_id, include_meta=True)