graph_operations.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408
"""
Graph Database Core Operations

Provides the basic operations for working with the graph database.
"""
  5. from neo4j import GraphDatabase
  6. from flask import current_app
  7. from app.services.neo4j_driver import Neo4jDriver
  8. import json
  9. import logging
  10. from datetime import datetime
# Module-level logger named after this module (via __name__); the functions
# in this file use it to report errors and warnings.
logger = logging.getLogger(__name__)
  12. class MyEncoder(json.JSONEncoder):
  13. """Neo4j数据序列化的自定义JSON编码器"""
  14. def default(self, obj):
  15. if isinstance(obj, (int, float, str, bool, list, dict, tuple, type(None))):
  16. return super(MyEncoder, self).default(obj)
  17. # 处理DateTime对象
  18. if hasattr(obj, 'isoformat'):
  19. return obj.isoformat()
  20. return str(obj)
  21. class GraphOperations:
  22. def __init__(self):
  23. self.driver = Neo4jDriver()
  24. def get_connection(self):
  25. return self.driver.connect()
  26. def close(self):
  27. self.driver.close()
  28. def connect_graph():
  29. """
  30. 连接到Neo4j图数据库
  31. Returns:
  32. Neo4j driver实例,如果连接失败则返回None
  33. """
  34. try:
  35. # 从Config获取Neo4j连接参数
  36. uri = current_app.config.get('NEO4J_URI')
  37. user = current_app.config.get('NEO4J_USER')
  38. password = current_app.config.get('NEO4J_PASSWORD')
  39. encrypted = current_app.config.get('NEO4J_ENCRYPTED')
  40. # 创建Neo4j驱动
  41. driver = GraphDatabase.driver(
  42. uri=uri,
  43. auth=(user, password),
  44. encrypted=encrypted
  45. )
  46. # 验证连接
  47. driver.verify_connectivity()
  48. return driver
  49. except Exception as e:
  50. # 处理连接错误
  51. logger.error(f"Error connecting to Neo4j database: {str(e)}")
  52. return None
  53. def create_or_get_node(label, **properties):
  54. """
  55. 创建具有给定标签和属性的新节点或获取现有节点
  56. 如果具有相同id的节点存在,则更新属性
  57. Args:
  58. label (str): Neo4j节点标签
  59. **properties: 作为关键字参数的节点属性
  60. Returns:
  61. 节点id
  62. """
  63. try:
  64. with connect_graph().session() as session:
  65. # 移除 id_list 属性
  66. if 'id_list' in properties:
  67. properties.pop('id_list')
  68. # 检查是否提供了id
  69. if 'id' in properties:
  70. node_id = properties['id']
  71. # 检查节点是否存在
  72. query = f"""
  73. MATCH (n:{label}) WHERE id(n) = $node_id
  74. RETURN n
  75. """
  76. result = session.run(query, node_id=node_id).single()
  77. if result:
  78. # 节点存在,更新属性
  79. props_string = ", ".join([f"n.{key} = ${key}" for key in properties if key != 'id'])
  80. if props_string:
  81. update_query = f"""
  82. MATCH (n:{label}) WHERE id(n) = $node_id
  83. SET {props_string}
  84. RETURN id(n) as node_id
  85. """
  86. result = session.run(update_query, node_id=node_id, **properties).single()
  87. return result["node_id"]
  88. return node_id
  89. # 如果到这里,则创建新节点
  90. props_keys = ", ".join([f"{key}: ${key}" for key in properties])
  91. create_query = f"""
  92. CREATE (n:{label} {{{props_keys}}})
  93. RETURN id(n) as node_id
  94. """
  95. result = session.run(create_query, **properties).single()
  96. return result["node_id"]
  97. except Exception as e:
  98. logger.error(f"Error in create_or_get_node: {str(e)}")
  99. raise e
  100. def create_relationship(start_node, end_node, relationship_type, properties=None):
  101. """
  102. 创建两个节点之间的关系
  103. Args:
  104. start_node: 起始节点
  105. end_node: 结束节点
  106. relationship_type: 关系类型
  107. properties: 关系属性
  108. Returns:
  109. 创建的关系对象
  110. """
  111. if not hasattr(start_node, 'id') or not hasattr(end_node, 'id'):
  112. raise ValueError("Invalid node objects provided")
  113. if properties is None:
  114. properties = {}
  115. query = """
  116. MATCH (start), (end)
  117. WHERE id(start) = $start_id AND id(end) = $end_id
  118. MERGE (start)-[r:%s]->(end)
  119. SET r += $properties
  120. RETURN r
  121. """ % relationship_type
  122. with connect_graph().session() as session:
  123. result = session.run(query,
  124. start_id=start_node.id,
  125. end_id=end_node.id,
  126. properties=properties)
  127. return result.single()["r"]
  128. def get_subgraph(node_ids, rel_types=None, max_depth=1):
  129. """
  130. 获取以指定节点为起点的子图
  131. Args:
  132. node_ids: 节点ID列表
  133. rel_types: 关系类型列表(可选)
  134. max_depth: 最大深度,默认为1
  135. Returns:
  136. 包含节点和关系的字典
  137. """
  138. try:
  139. # 处理节点ID列表
  140. node_ids_str = ', '.join([str(nid) for nid in node_ids])
  141. # 处理关系类型过滤
  142. rel_filter = ''
  143. if rel_types:
  144. rel_types_str = '|'.join(rel_types)
  145. rel_filter = f":{rel_types_str}"
  146. # 构建Cypher语句
  147. cypher = f"""
  148. MATCH path = (n)-[r{rel_filter}*0..{max_depth}]-(m)
  149. WHERE id(n) IN [{node_ids_str}]
  150. RETURN path
  151. """
  152. # 执行查询
  153. with connect_graph().session() as session:
  154. result = session.run(cypher)
  155. # 处理结果为图谱数据
  156. nodes = {}
  157. relationships = {}
  158. for record in result:
  159. path = record["path"]
  160. # 处理节点
  161. for node in path.nodes:
  162. if node.id not in nodes:
  163. node_dict = dict(node)
  164. node_dict['id'] = node.id
  165. node_dict['labels'] = list(node.labels)
  166. nodes[node.id] = node_dict
  167. # 处理关系
  168. for rel in path.relationships:
  169. if rel.id not in relationships:
  170. rel_dict = dict(rel)
  171. rel_dict['id'] = rel.id
  172. rel_dict['type'] = rel.type
  173. rel_dict['source'] = rel.start_node.id
  174. rel_dict['target'] = rel.end_node.id
  175. relationships[rel.id] = rel_dict
  176. # 转换为列表形式
  177. graph_data = {
  178. 'nodes': list(nodes.values()),
  179. 'relationships': list(relationships.values())
  180. }
  181. return graph_data
  182. except Exception as e:
  183. logger.error(f"Error getting subgraph: {str(e)}")
  184. raise e
  185. def execute_cypher_query(cypher, params=None):
  186. """
  187. 执行Cypher查询并返回结果
  188. Args:
  189. cypher: Cypher查询语句
  190. params: 查询参数(可选)
  191. Returns:
  192. 查询结果的列表
  193. """
  194. if params is None:
  195. params = {}
  196. def convert_value(value):
  197. """转换Neo4j返回的值为JSON可序列化的格式"""
  198. # 处理DateTime对象
  199. if hasattr(value, 'isoformat'):
  200. return value.isoformat()
  201. # 处理Date对象
  202. elif hasattr(value, 'year') and hasattr(value, 'month') and hasattr(value, 'day'):
  203. return str(value)
  204. # 处理Time对象
  205. elif hasattr(value, 'hour') and hasattr(value, 'minute') and hasattr(value, 'second'):
  206. return str(value)
  207. # 处理其他对象
  208. else:
  209. return value
  210. try:
  211. with connect_graph().session() as session:
  212. result = session.run(cypher, **params)
  213. # 处理查询结果
  214. data = []
  215. for record in result:
  216. record_dict = {}
  217. for key, value in record.items():
  218. # 节点处理
  219. if hasattr(value, 'id') and hasattr(value, 'labels') and hasattr(value, 'items'):
  220. node_dict = {}
  221. for prop_key, prop_value in dict(value).items():
  222. node_dict[prop_key] = convert_value(prop_value)
  223. node_dict['_id'] = value.id
  224. node_dict['_labels'] = list(value.labels)
  225. record_dict[key] = node_dict
  226. # 关系处理
  227. elif hasattr(value, 'id') and hasattr(value, 'type') and hasattr(value, 'start_node'):
  228. rel_dict = {}
  229. for prop_key, prop_value in dict(value).items():
  230. rel_dict[prop_key] = convert_value(prop_value)
  231. rel_dict['_id'] = value.id
  232. rel_dict['_type'] = value.type
  233. rel_dict['_start_node_id'] = value.start_node.id
  234. rel_dict['_end_node_id'] = value.end_node.id
  235. record_dict[key] = rel_dict
  236. # 路径处理
  237. elif hasattr(value, 'start_node') and hasattr(value, 'end_node') and hasattr(value, 'nodes'):
  238. path_dict = {
  239. 'nodes': [],
  240. 'relationships': []
  241. }
  242. # 处理路径中的节点
  243. for node in value.nodes:
  244. node_dict = {}
  245. for prop_key, prop_value in dict(node).items():
  246. node_dict[prop_key] = convert_value(prop_value)
  247. path_dict['nodes'].append(node_dict)
  248. # 处理路径中的关系
  249. for rel in value.relationships:
  250. rel_dict = {}
  251. for prop_key, prop_value in dict(rel).items():
  252. rel_dict[prop_key] = convert_value(prop_value)
  253. path_dict['relationships'].append(rel_dict)
  254. record_dict[key] = path_dict
  255. # 其他类型直接转换
  256. else:
  257. record_dict[key] = convert_value(value)
  258. data.append(record_dict)
  259. return data
  260. except Exception as e:
  261. logger.error(f"Error executing Cypher query: {str(e)}")
  262. raise e
  263. def get_node(label, **properties):
  264. """
  265. 查询具有给定标签和属性的节点
  266. Args:
  267. label (str): Neo4j节点标签
  268. **properties: 作为关键字参数的节点属性
  269. Returns:
  270. 节点对象,如果不存在则返回None
  271. """
  272. try:
  273. with connect_graph().session() as session:
  274. # 构建查询条件
  275. conditions = []
  276. params = {}
  277. # 处理ID参数
  278. if 'id' in properties:
  279. conditions.append("id(n) = $node_id")
  280. params['node_id'] = properties['id']
  281. # 移除id属性,避免在后续属性匹配中重复
  282. properties_copy = properties.copy()
  283. properties_copy.pop('id')
  284. properties = properties_copy
  285. # 处理其他属性
  286. for key, value in properties.items():
  287. conditions.append(f"n.{key} = ${key}")
  288. params[key] = value
  289. # 构建查询语句
  290. where_clause = " AND ".join(conditions) if conditions else "TRUE"
  291. query = f"""
  292. MATCH (n:{label})
  293. WHERE {where_clause}
  294. RETURN id(n) as node_id
  295. LIMIT 1
  296. """
  297. # 执行查询
  298. result = session.run(query, **params).single()
  299. return result["node_id"] if result else None
  300. except Exception as e:
  301. logger.error(f"Error in get_node: {str(e)}")
  302. return None
  303. def relationship_exists(start_node_id, rel_type, end_node_id, **properties):
  304. """
  305. 检查两个节点之间是否存在指定类型和属性的关系
  306. Args:
  307. start_node_id: 起始节点ID (必须是整数ID)
  308. rel_type: 关系类型
  309. end_node_id: 结束节点ID (必须是整数ID)
  310. **properties: 关系的属性
  311. Returns:
  312. bool: 是否存在关系
  313. """
  314. try:
  315. with connect_graph().session() as session:
  316. # 确保输入的是有效的节点ID
  317. if not isinstance(start_node_id, (int, str)) or not isinstance(end_node_id, (int, str)):
  318. logger.warning(f"无效的节点ID类型: start_node_id={type(start_node_id)}, end_node_id={type(end_node_id)}")
  319. return False
  320. # 转换为整数
  321. try:
  322. start_id = int(start_node_id)
  323. end_id = int(end_node_id)
  324. except (ValueError, TypeError):
  325. logger.warning(f"无法转换节点ID为整数: start_node_id={start_node_id}, end_node_id={end_node_id}")
  326. return False
  327. # 构建查询语句
  328. query = """
  329. MATCH (a)-[r:%s]->(b)
  330. WHERE id(a) = $start_id AND id(b) = $end_id
  331. """ % rel_type
  332. # 添加属性条件
  333. if properties:
  334. conditions = []
  335. for key, value in properties.items():
  336. conditions.append(f"r.{key} = ${key}")
  337. query += " AND " + " AND ".join(conditions)
  338. query += "\nRETURN count(r) > 0 as exists"
  339. # 执行查询
  340. params = {
  341. 'start_id': start_id,
  342. 'end_id': end_id,
  343. **properties
  344. }
  345. result = session.run(query, **params).single()
  346. return result and result["exists"]
  347. except Exception as e:
  348. logger.error(f"Error in relationship_exists: {str(e)}")
  349. return False