graph_operations.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. """
  2. Graph Database Core Operations
  3. 提供图数据库的基本操作功能
  4. """
  5. from neo4j import GraphDatabase
  6. from app.config.config import Config
  7. import json
  8. import logging
  9. logger = logging.getLogger(__name__)
  10. class MyEncoder(json.JSONEncoder):
  11. """Neo4j数据序列化的自定义JSON编码器"""
  12. def default(self, obj):
  13. if isinstance(obj, (int, float, str, bool, list, dict, tuple, type(None))):
  14. return super(MyEncoder, self).default(obj)
  15. return str(obj)
  16. def connect_graph():
  17. """
  18. 连接到Neo4j图数据库
  19. Returns:
  20. Neo4j driver实例,如果连接失败则返回None
  21. """
  22. try:
  23. # 从Config获取Neo4j连接参数
  24. uri = Config.NEO4J_URI
  25. user = Config.NEO4J_USER
  26. password = Config.NEO4J_PASSWORD
  27. encrypted = Config.NEO4J_ENCRYPTED
  28. # 创建Neo4j驱动
  29. driver = GraphDatabase.driver(
  30. uri=uri,
  31. auth=(user, password),
  32. encrypted=encrypted
  33. )
  34. # 验证连接
  35. driver.verify_connectivity()
  36. return driver
  37. except Exception as e:
  38. # 处理连接错误
  39. logger.error(f"Error connecting to Neo4j database: {str(e)}")
  40. return None
  41. def create_or_get_node(label, **properties):
  42. """
  43. 创建具有给定标签和属性的新节点或获取现有节点
  44. 如果具有相同id的节点存在,则更新属性
  45. Args:
  46. label (str): Neo4j节点标签
  47. **properties: 作为关键字参数的节点属性
  48. Returns:
  49. 节点id
  50. """
  51. try:
  52. with connect_graph().session() as session:
  53. # 检查是否提供了id
  54. if 'id' in properties:
  55. node_id = properties['id']
  56. # 检查节点是否存在
  57. query = f"""
  58. MATCH (n:{label}) WHERE id(n) = $node_id
  59. RETURN n
  60. """
  61. result = session.run(query, node_id=node_id).single()
  62. if result:
  63. # 节点存在,更新属性
  64. props_string = ", ".join([f"n.{key} = ${key}" for key in properties if key != 'id'])
  65. if props_string:
  66. update_query = f"""
  67. MATCH (n:{label}) WHERE id(n) = $node_id
  68. SET {props_string}
  69. RETURN id(n) as node_id
  70. """
  71. result = session.run(update_query, node_id=node_id, **properties).single()
  72. return result["node_id"]
  73. return node_id
  74. # 如果到这里,则创建新节点
  75. props_keys = ", ".join([f"{key}: ${key}" for key in properties])
  76. create_query = f"""
  77. CREATE (n:{label} {{{props_keys}}})
  78. RETURN id(n) as node_id
  79. """
  80. result = session.run(create_query, **properties).single()
  81. return result["node_id"]
  82. except Exception as e:
  83. logger.error(f"Error in create_or_get_node: {str(e)}")
  84. raise e
  85. def create_relationship(start_node_id, end_node_id, rel_type, **properties):
  86. """
  87. 在两个节点之间创建关系
  88. Args:
  89. start_node_id: 起始节点ID
  90. end_node_id: 结束节点ID
  91. rel_type: 关系类型
  92. **properties: 关系的属性
  93. Returns:
  94. 关系的ID
  95. """
  96. try:
  97. # 构建属性部分
  98. properties_str = ', '.join([f"{k}: ${k}" for k in properties.keys()])
  99. properties_part = f" {{{properties_str}}}" if properties else ""
  100. # 构建Cypher语句
  101. cypher = f"""
  102. MATCH (a), (b)
  103. WHERE id(a) = $start_node_id AND id(b) = $end_node_id
  104. CREATE (a)-[r:{rel_type}{properties_part}]->(b)
  105. RETURN id(r) as rel_id
  106. """
  107. # 执行创建
  108. with connect_graph().session() as session:
  109. params = {
  110. 'start_node_id': int(start_node_id),
  111. 'end_node_id': int(end_node_id),
  112. **properties
  113. }
  114. result = session.run(cypher, **params).single()
  115. if result:
  116. return result["rel_id"]
  117. else:
  118. logger.error("Failed to create relationship")
  119. return None
  120. except Exception as e:
  121. logger.error(f"Error creating relationship: {str(e)}")
  122. raise e
  123. def get_subgraph(node_ids, rel_types=None, max_depth=1):
  124. """
  125. 获取以指定节点为起点的子图
  126. Args:
  127. node_ids: 节点ID列表
  128. rel_types: 关系类型列表(可选)
  129. max_depth: 最大深度,默认为1
  130. Returns:
  131. 包含节点和关系的字典
  132. """
  133. try:
  134. # 处理节点ID列表
  135. node_ids_str = ', '.join([str(nid) for nid in node_ids])
  136. # 处理关系类型过滤
  137. rel_filter = ''
  138. if rel_types:
  139. rel_types_str = '|'.join(rel_types)
  140. rel_filter = f":{rel_types_str}"
  141. # 构建Cypher语句
  142. cypher = f"""
  143. MATCH path = (n)-[r{rel_filter}*0..{max_depth}]-(m)
  144. WHERE id(n) IN [{node_ids_str}]
  145. RETURN path
  146. """
  147. # 执行查询
  148. with connect_graph().session() as session:
  149. result = session.run(cypher)
  150. # 处理结果为图谱数据
  151. nodes = {}
  152. relationships = {}
  153. for record in result:
  154. path = record["path"]
  155. # 处理节点
  156. for node in path.nodes:
  157. if node.id not in nodes:
  158. node_dict = dict(node)
  159. node_dict['id'] = node.id
  160. node_dict['labels'] = list(node.labels)
  161. nodes[node.id] = node_dict
  162. # 处理关系
  163. for rel in path.relationships:
  164. if rel.id not in relationships:
  165. rel_dict = dict(rel)
  166. rel_dict['id'] = rel.id
  167. rel_dict['type'] = rel.type
  168. rel_dict['source'] = rel.start_node.id
  169. rel_dict['target'] = rel.end_node.id
  170. relationships[rel.id] = rel_dict
  171. # 转换为列表形式
  172. graph_data = {
  173. 'nodes': list(nodes.values()),
  174. 'relationships': list(relationships.values())
  175. }
  176. return graph_data
  177. except Exception as e:
  178. logger.error(f"Error getting subgraph: {str(e)}")
  179. raise e
  180. def execute_cypher_query(cypher, params=None):
  181. """
  182. 执行Cypher查询并返回结果
  183. Args:
  184. cypher: Cypher查询语句
  185. params: 查询参数(可选)
  186. Returns:
  187. 查询结果的列表
  188. """
  189. if params is None:
  190. params = {}
  191. try:
  192. with connect_graph().session() as session:
  193. result = session.run(cypher, **params)
  194. # 处理查询结果
  195. data = []
  196. for record in result:
  197. record_dict = {}
  198. for key, value in record.items():
  199. # 节点处理
  200. if hasattr(value, 'id') and hasattr(value, 'labels') and hasattr(value, 'items'):
  201. node_dict = dict(value)
  202. node_dict['_id'] = value.id
  203. node_dict['_labels'] = list(value.labels)
  204. record_dict[key] = node_dict
  205. # 关系处理
  206. elif hasattr(value, 'id') and hasattr(value, 'type') and hasattr(value, 'start_node'):
  207. rel_dict = dict(value)
  208. rel_dict['_id'] = value.id
  209. rel_dict['_type'] = value.type
  210. rel_dict['_start_node_id'] = value.start_node.id
  211. rel_dict['_end_node_id'] = value.end_node.id
  212. record_dict[key] = rel_dict
  213. # 路径处理
  214. elif hasattr(value, 'start_node') and hasattr(value, 'end_node') and hasattr(value, 'nodes'):
  215. path_dict = {
  216. 'nodes': [dict(node) for node in value.nodes],
  217. 'relationships': [dict(rel) for rel in value.relationships]
  218. }
  219. record_dict[key] = path_dict
  220. # 其他类型直接转换
  221. else:
  222. record_dict[key] = value
  223. data.append(record_dict)
  224. return data
  225. except Exception as e:
  226. logger.error(f"Error executing Cypher query: {str(e)}")
  227. raise e