graph_operations.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364
  1. """
  2. Graph Database Core Operations
  3. 提供图数据库的基本操作功能
  4. """
  5. from neo4j import GraphDatabase
  6. from flask import current_app
  7. from app.services.neo4j_driver import Neo4jDriver
  8. import json
  9. import logging
  10. logger = logging.getLogger(__name__)
  11. class MyEncoder(json.JSONEncoder):
  12. """Neo4j数据序列化的自定义JSON编码器"""
  13. def default(self, obj):
  14. if isinstance(obj, (int, float, str, bool, list, dict, tuple, type(None))):
  15. return super(MyEncoder, self).default(obj)
  16. return str(obj)
  17. class GraphOperations:
  18. def __init__(self):
  19. self.driver = Neo4jDriver()
  20. def get_connection(self):
  21. return self.driver.connect()
  22. def close(self):
  23. self.driver.close()
  24. def connect_graph():
  25. """
  26. 连接到Neo4j图数据库
  27. Returns:
  28. Neo4j driver实例,如果连接失败则返回None
  29. """
  30. try:
  31. # 从Config获取Neo4j连接参数
  32. uri = current_app.config.get('NEO4J_URI')
  33. user = current_app.config.get('NEO4J_USER')
  34. password = current_app.config.get('NEO4J_PASSWORD')
  35. encrypted = current_app.config.get('NEO4J_ENCRYPTED')
  36. # 创建Neo4j驱动
  37. driver = GraphDatabase.driver(
  38. uri=uri,
  39. auth=(user, password),
  40. encrypted=encrypted
  41. )
  42. # 验证连接
  43. driver.verify_connectivity()
  44. return driver
  45. except Exception as e:
  46. # 处理连接错误
  47. logger.error(f"Error connecting to Neo4j database: {str(e)}")
  48. return None
  49. def create_or_get_node(label, **properties):
  50. """
  51. 创建具有给定标签和属性的新节点或获取现有节点
  52. 如果具有相同id的节点存在,则更新属性
  53. Args:
  54. label (str): Neo4j节点标签
  55. **properties: 作为关键字参数的节点属性
  56. Returns:
  57. 节点id
  58. """
  59. try:
  60. with connect_graph().session() as session:
  61. # 移除 id_list 属性
  62. if 'id_list' in properties:
  63. properties.pop('id_list')
  64. # 检查是否提供了id
  65. if 'id' in properties:
  66. node_id = properties['id']
  67. # 检查节点是否存在
  68. query = f"""
  69. MATCH (n:{label}) WHERE id(n) = $node_id
  70. RETURN n
  71. """
  72. result = session.run(query, node_id=node_id).single()
  73. if result:
  74. # 节点存在,更新属性
  75. props_string = ", ".join([f"n.{key} = ${key}" for key in properties if key != 'id'])
  76. if props_string:
  77. update_query = f"""
  78. MATCH (n:{label}) WHERE id(n) = $node_id
  79. SET {props_string}
  80. RETURN id(n) as node_id
  81. """
  82. result = session.run(update_query, node_id=node_id, **properties).single()
  83. return result["node_id"]
  84. return node_id
  85. # 如果到这里,则创建新节点
  86. props_keys = ", ".join([f"{key}: ${key}" for key in properties])
  87. create_query = f"""
  88. CREATE (n:{label} {{{props_keys}}})
  89. RETURN id(n) as node_id
  90. """
  91. result = session.run(create_query, **properties).single()
  92. return result["node_id"]
  93. except Exception as e:
  94. logger.error(f"Error in create_or_get_node: {str(e)}")
  95. raise e
  96. def create_relationship(start_node, end_node, relationship_type, properties=None):
  97. """
  98. 创建两个节点之间的关系
  99. Args:
  100. start_node: 起始节点
  101. end_node: 结束节点
  102. relationship_type: 关系类型
  103. properties: 关系属性
  104. Returns:
  105. 创建的关系对象
  106. """
  107. if not hasattr(start_node, 'id') or not hasattr(end_node, 'id'):
  108. raise ValueError("Invalid node objects provided")
  109. if properties is None:
  110. properties = {}
  111. query = """
  112. MATCH (start), (end)
  113. WHERE id(start) = $start_id AND id(end) = $end_id
  114. MERGE (start)-[r:%s]->(end)
  115. SET r += $properties
  116. RETURN r
  117. """ % relationship_type
  118. with connect_graph().session() as session:
  119. result = session.run(query,
  120. start_id=start_node.id,
  121. end_id=end_node.id,
  122. properties=properties)
  123. return result.single()["r"]
  124. def get_subgraph(node_ids, rel_types=None, max_depth=1):
  125. """
  126. 获取以指定节点为起点的子图
  127. Args:
  128. node_ids: 节点ID列表
  129. rel_types: 关系类型列表(可选)
  130. max_depth: 最大深度,默认为1
  131. Returns:
  132. 包含节点和关系的字典
  133. """
  134. try:
  135. # 处理节点ID列表
  136. node_ids_str = ', '.join([str(nid) for nid in node_ids])
  137. # 处理关系类型过滤
  138. rel_filter = ''
  139. if rel_types:
  140. rel_types_str = '|'.join(rel_types)
  141. rel_filter = f":{rel_types_str}"
  142. # 构建Cypher语句
  143. cypher = f"""
  144. MATCH path = (n)-[r{rel_filter}*0..{max_depth}]-(m)
  145. WHERE id(n) IN [{node_ids_str}]
  146. RETURN path
  147. """
  148. # 执行查询
  149. with connect_graph().session() as session:
  150. result = session.run(cypher)
  151. # 处理结果为图谱数据
  152. nodes = {}
  153. relationships = {}
  154. for record in result:
  155. path = record["path"]
  156. # 处理节点
  157. for node in path.nodes:
  158. if node.id not in nodes:
  159. node_dict = dict(node)
  160. node_dict['id'] = node.id
  161. node_dict['labels'] = list(node.labels)
  162. nodes[node.id] = node_dict
  163. # 处理关系
  164. for rel in path.relationships:
  165. if rel.id not in relationships:
  166. rel_dict = dict(rel)
  167. rel_dict['id'] = rel.id
  168. rel_dict['type'] = rel.type
  169. rel_dict['source'] = rel.start_node.id
  170. rel_dict['target'] = rel.end_node.id
  171. relationships[rel.id] = rel_dict
  172. # 转换为列表形式
  173. graph_data = {
  174. 'nodes': list(nodes.values()),
  175. 'relationships': list(relationships.values())
  176. }
  177. return graph_data
  178. except Exception as e:
  179. logger.error(f"Error getting subgraph: {str(e)}")
  180. raise e
  181. def execute_cypher_query(cypher, params=None):
  182. """
  183. 执行Cypher查询并返回结果
  184. Args:
  185. cypher: Cypher查询语句
  186. params: 查询参数(可选)
  187. Returns:
  188. 查询结果的列表
  189. """
  190. if params is None:
  191. params = {}
  192. try:
  193. with connect_graph().session() as session:
  194. result = session.run(cypher, **params)
  195. # 处理查询结果
  196. data = []
  197. for record in result:
  198. record_dict = {}
  199. for key, value in record.items():
  200. # 节点处理
  201. if hasattr(value, 'id') and hasattr(value, 'labels') and hasattr(value, 'items'):
  202. node_dict = dict(value)
  203. node_dict['_id'] = value.id
  204. node_dict['_labels'] = list(value.labels)
  205. record_dict[key] = node_dict
  206. # 关系处理
  207. elif hasattr(value, 'id') and hasattr(value, 'type') and hasattr(value, 'start_node'):
  208. rel_dict = dict(value)
  209. rel_dict['_id'] = value.id
  210. rel_dict['_type'] = value.type
  211. rel_dict['_start_node_id'] = value.start_node.id
  212. rel_dict['_end_node_id'] = value.end_node.id
  213. record_dict[key] = rel_dict
  214. # 路径处理
  215. elif hasattr(value, 'start_node') and hasattr(value, 'end_node') and hasattr(value, 'nodes'):
  216. path_dict = {
  217. 'nodes': [dict(node) for node in value.nodes],
  218. 'relationships': [dict(rel) for rel in value.relationships]
  219. }
  220. record_dict[key] = path_dict
  221. # 其他类型直接转换
  222. else:
  223. record_dict[key] = value
  224. data.append(record_dict)
  225. return data
  226. except Exception as e:
  227. logger.error(f"Error executing Cypher query: {str(e)}")
  228. raise e
  229. def get_node(label, **properties):
  230. """
  231. 查询具有给定标签和属性的节点
  232. Args:
  233. label (str): Neo4j节点标签
  234. **properties: 作为关键字参数的节点属性
  235. Returns:
  236. 节点对象,如果不存在则返回None
  237. """
  238. try:
  239. with connect_graph().session() as session:
  240. # 构建查询条件
  241. conditions = []
  242. params = {}
  243. # 处理ID参数
  244. if 'id' in properties:
  245. conditions.append("id(n) = $node_id")
  246. params['node_id'] = properties['id']
  247. # 移除id属性,避免在后续属性匹配中重复
  248. properties_copy = properties.copy()
  249. properties_copy.pop('id')
  250. properties = properties_copy
  251. # 处理其他属性
  252. for key, value in properties.items():
  253. conditions.append(f"n.{key} = ${key}")
  254. params[key] = value
  255. # 构建查询语句
  256. where_clause = " AND ".join(conditions) if conditions else "TRUE"
  257. query = f"""
  258. MATCH (n:{label})
  259. WHERE {where_clause}
  260. RETURN id(n) as node_id
  261. LIMIT 1
  262. """
  263. # 执行查询
  264. result = session.run(query, **params).single()
  265. return result["node_id"] if result else None
  266. except Exception as e:
  267. logger.error(f"Error in get_node: {str(e)}")
  268. return None
  269. def relationship_exists(start_node, rel_type, end_node, **properties):
  270. """
  271. 检查两个节点之间是否存在指定类型和属性的关系
  272. Args:
  273. start_node: 起始节点或节点ID
  274. rel_type: 关系类型
  275. end_node: 结束节点或节点ID
  276. **properties: 关系的属性
  277. Returns:
  278. bool: 是否存在关系
  279. """
  280. try:
  281. with connect_graph().session() as session:
  282. # 确定节点ID
  283. start_id = start_node.id if hasattr(start_node, 'id') else start_node
  284. end_id = end_node.id if hasattr(end_node, 'id') else end_node
  285. # 构建查询语句
  286. query = """
  287. MATCH (a)-[r:%s]->(b)
  288. WHERE id(a) = $start_id AND id(b) = $end_id
  289. """ % rel_type
  290. # 添加属性条件
  291. if properties:
  292. conditions = []
  293. for key, value in properties.items():
  294. conditions.append(f"r.{key} = ${key}")
  295. query += " AND " + " AND ".join(conditions)
  296. query += "\nRETURN count(r) > 0 as exists"
  297. # 执行查询
  298. params = {
  299. 'start_id': int(start_id),
  300. 'end_id': int(end_id),
  301. **properties
  302. }
  303. result = session.run(query, **params).single()
  304. return result and result["exists"]
  305. except Exception as e:
  306. logger.error(f"Error in relationship_exists: {str(e)}")
  307. return False