graph_operations.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421
  1. """
  2. Graph Database Core Operations
  3. 提供图数据库的基本操作功能
  4. """
  5. from neo4j import GraphDatabase
  6. from flask import current_app
  7. from app.services.neo4j_driver import Neo4jDriver
  8. import json
  9. import logging
  10. from datetime import datetime
  11. logger = logging.getLogger(__name__)
  12. class MyEncoder(json.JSONEncoder):
  13. """Neo4j数据序列化的自定义JSON编码器"""
  14. def default(self, obj):
  15. if isinstance(obj, (int, float, str, bool, list, dict, tuple, type(None))):
  16. return super(MyEncoder, self).default(obj)
  17. # 处理DateTime对象
  18. if hasattr(obj, 'isoformat'):
  19. return obj.isoformat()
  20. return str(obj)
  21. class GraphOperations:
  22. def __init__(self):
  23. self.driver = Neo4jDriver()
  24. def get_connection(self):
  25. return self.driver.connect()
  26. def close(self):
  27. self.driver.close()
  28. def connect_graph():
  29. """
  30. 连接到Neo4j图数据库
  31. Returns:
  32. Neo4j driver实例
  33. Raises:
  34. ConnectionError: 如果无法连接到Neo4j数据库
  35. ValueError: 如果配置参数缺失
  36. """
  37. try:
  38. # 从Config获取Neo4j连接参数
  39. uri = current_app.config.get('NEO4J_URI')
  40. user = current_app.config.get('NEO4J_USER')
  41. password = current_app.config.get('NEO4J_PASSWORD')
  42. encrypted = current_app.config.get('NEO4J_ENCRYPTED')
  43. # 检查必需的配置参数
  44. if not uri:
  45. raise ValueError("Neo4j URI配置缺失,请检查NEO4J_URI配置")
  46. if not user:
  47. raise ValueError("Neo4j用户配置缺失,请检查NEO4J_USER配置")
  48. if password is None:
  49. raise ValueError("Neo4j密码配置缺失,请检查NEO4J_PASSWORD配置")
  50. # 创建Neo4j驱动
  51. driver = GraphDatabase.driver(
  52. uri=uri,
  53. auth=(user, password),
  54. encrypted=encrypted
  55. )
  56. # 验证连接
  57. driver.verify_connectivity()
  58. return driver
  59. except Exception as e:
  60. # 处理连接错误,抛出异常而不是返回None
  61. error_msg = f"无法连接到Neo4j图数据库: {str(e)}"
  62. logger.error(error_msg)
  63. raise ConnectionError(error_msg) from e
  64. def create_or_get_node(label, **properties):
  65. """
  66. 创建具有给定标签和属性的新节点或获取现有节点
  67. 如果具有相同id的节点存在,则更新属性
  68. Args:
  69. label (str): Neo4j节点标签
  70. **properties: 作为关键字参数的节点属性
  71. Returns:
  72. 节点id
  73. """
  74. try:
  75. with connect_graph().session() as session:
  76. # 移除 id_list 属性
  77. if 'id_list' in properties:
  78. properties.pop('id_list')
  79. # 检查是否提供了id
  80. if 'id' in properties:
  81. node_id = properties['id']
  82. # 检查节点是否存在
  83. query = f"""
  84. MATCH (n:{label}) WHERE id(n) = $node_id
  85. RETURN n
  86. """
  87. result = session.run(query, node_id=node_id).single()
  88. if result:
  89. # 节点存在,更新属性
  90. props_string = ", ".join([f"n.{key} = ${key}" for key in properties if key != 'id'])
  91. if props_string:
  92. update_query = f"""
  93. MATCH (n:{label}) WHERE id(n) = $node_id
  94. SET {props_string}
  95. RETURN id(n) as node_id
  96. """
  97. result = session.run(update_query, node_id=node_id, **properties).single()
  98. return result["node_id"]
  99. return node_id
  100. # 如果到这里,则创建新节点
  101. props_keys = ", ".join([f"{key}: ${key}" for key in properties])
  102. create_query = f"""
  103. CREATE (n:{label} {{{props_keys}}})
  104. RETURN id(n) as node_id
  105. """
  106. result = session.run(create_query, **properties).single()
  107. return result["node_id"]
  108. except Exception as e:
  109. logger.error(f"Error in create_or_get_node: {str(e)}")
  110. raise e
  111. def create_relationship(start_node, end_node, relationship_type, properties=None):
  112. """
  113. 创建两个节点之间的关系
  114. Args:
  115. start_node: 起始节点
  116. end_node: 结束节点
  117. relationship_type: 关系类型
  118. properties: 关系属性
  119. Returns:
  120. 创建的关系对象
  121. """
  122. if not hasattr(start_node, 'id') or not hasattr(end_node, 'id'):
  123. raise ValueError("Invalid node objects provided")
  124. if properties is None:
  125. properties = {}
  126. query = """
  127. MATCH (start), (end)
  128. WHERE id(start) = $start_id AND id(end) = $end_id
  129. MERGE (start)-[r:%s]->(end)
  130. SET r += $properties
  131. RETURN r
  132. """ % relationship_type
  133. with connect_graph().session() as session:
  134. result = session.run(query,
  135. start_id=start_node.id,
  136. end_id=end_node.id,
  137. properties=properties)
  138. return result.single()["r"]
  139. def get_subgraph(node_ids, rel_types=None, max_depth=1):
  140. """
  141. 获取以指定节点为起点的子图
  142. Args:
  143. node_ids: 节点ID列表
  144. rel_types: 关系类型列表(可选)
  145. max_depth: 最大深度,默认为1
  146. Returns:
  147. 包含节点和关系的字典
  148. """
  149. try:
  150. # 处理节点ID列表
  151. node_ids_str = ', '.join([str(nid) for nid in node_ids])
  152. # 处理关系类型过滤
  153. rel_filter = ''
  154. if rel_types:
  155. rel_types_str = '|'.join(rel_types)
  156. rel_filter = f":{rel_types_str}"
  157. # 构建Cypher语句
  158. cypher = f"""
  159. MATCH path = (n)-[r{rel_filter}*0..{max_depth}]-(m)
  160. WHERE id(n) IN [{node_ids_str}]
  161. RETURN path
  162. """
  163. # 执行查询
  164. with connect_graph().session() as session:
  165. result = session.run(cypher)
  166. # 处理结果为图谱数据
  167. nodes = {}
  168. relationships = {}
  169. for record in result:
  170. path = record["path"]
  171. # 处理节点
  172. for node in path.nodes:
  173. if node.id not in nodes:
  174. node_dict = dict(node)
  175. node_dict['id'] = node.id
  176. node_dict['labels'] = list(node.labels)
  177. nodes[node.id] = node_dict
  178. # 处理关系
  179. for rel in path.relationships:
  180. if rel.id not in relationships:
  181. rel_dict = dict(rel)
  182. rel_dict['id'] = rel.id
  183. rel_dict['type'] = rel.type
  184. rel_dict['source'] = rel.start_node.id
  185. rel_dict['target'] = rel.end_node.id
  186. relationships[rel.id] = rel_dict
  187. # 转换为列表形式
  188. graph_data = {
  189. 'nodes': list(nodes.values()),
  190. 'relationships': list(relationships.values())
  191. }
  192. return graph_data
  193. except Exception as e:
  194. logger.error(f"Error getting subgraph: {str(e)}")
  195. raise e
  196. def execute_cypher_query(cypher, params=None):
  197. """
  198. 执行Cypher查询并返回结果
  199. Args:
  200. cypher: Cypher查询语句
  201. params: 查询参数(可选)
  202. Returns:
  203. 查询结果的列表
  204. """
  205. if params is None:
  206. params = {}
  207. def convert_value(value):
  208. """转换Neo4j返回的值为JSON可序列化的格式"""
  209. # 处理DateTime对象
  210. if hasattr(value, 'isoformat'):
  211. return value.isoformat()
  212. # 处理Date对象
  213. elif hasattr(value, 'year') and hasattr(value, 'month') and hasattr(value, 'day'):
  214. return str(value)
  215. # 处理Time对象
  216. elif hasattr(value, 'hour') and hasattr(value, 'minute') and hasattr(value, 'second'):
  217. return str(value)
  218. # 处理其他对象
  219. else:
  220. return value
  221. try:
  222. with connect_graph().session() as session:
  223. result = session.run(cypher, **params)
  224. # 处理查询结果
  225. data = []
  226. for record in result:
  227. record_dict = {}
  228. for key, value in record.items():
  229. # 节点处理
  230. if hasattr(value, 'id') and hasattr(value, 'labels') and hasattr(value, 'items'):
  231. node_dict = {}
  232. for prop_key, prop_value in dict(value).items():
  233. node_dict[prop_key] = convert_value(prop_value)
  234. node_dict['_id'] = value.id
  235. node_dict['_labels'] = list(value.labels)
  236. record_dict[key] = node_dict
  237. # 关系处理
  238. elif hasattr(value, 'id') and hasattr(value, 'type') and hasattr(value, 'start_node'):
  239. rel_dict = {}
  240. for prop_key, prop_value in dict(value).items():
  241. rel_dict[prop_key] = convert_value(prop_value)
  242. rel_dict['_id'] = value.id
  243. rel_dict['_type'] = value.type
  244. rel_dict['_start_node_id'] = value.start_node.id
  245. rel_dict['_end_node_id'] = value.end_node.id
  246. record_dict[key] = rel_dict
  247. # 路径处理
  248. elif hasattr(value, 'start_node') and hasattr(value, 'end_node') and hasattr(value, 'nodes'):
  249. path_dict = {
  250. 'nodes': [],
  251. 'relationships': []
  252. }
  253. # 处理路径中的节点
  254. for node in value.nodes:
  255. node_dict = {}
  256. for prop_key, prop_value in dict(node).items():
  257. node_dict[prop_key] = convert_value(prop_value)
  258. path_dict['nodes'].append(node_dict)
  259. # 处理路径中的关系
  260. for rel in value.relationships:
  261. rel_dict = {}
  262. for prop_key, prop_value in dict(rel).items():
  263. rel_dict[prop_key] = convert_value(prop_value)
  264. path_dict['relationships'].append(rel_dict)
  265. record_dict[key] = path_dict
  266. # 其他类型直接转换
  267. else:
  268. record_dict[key] = convert_value(value)
  269. data.append(record_dict)
  270. return data
  271. except Exception as e:
  272. logger.error(f"Error executing Cypher query: {str(e)}")
  273. raise e
  274. def get_node(label, **properties):
  275. """
  276. 查询具有给定标签和属性的节点
  277. Args:
  278. label (str): Neo4j节点标签
  279. **properties: 作为关键字参数的节点属性
  280. Returns:
  281. 节点对象,如果不存在则返回None
  282. """
  283. try:
  284. with connect_graph().session() as session:
  285. # 构建查询条件
  286. conditions = []
  287. params = {}
  288. # 处理ID参数
  289. if 'id' in properties:
  290. conditions.append("id(n) = $node_id")
  291. params['node_id'] = properties['id']
  292. # 移除id属性,避免在后续属性匹配中重复
  293. properties_copy = properties.copy()
  294. properties_copy.pop('id')
  295. properties = properties_copy
  296. # 处理其他属性
  297. for key, value in properties.items():
  298. conditions.append(f"n.{key} = ${key}")
  299. params[key] = value
  300. # 构建查询语句
  301. where_clause = " AND ".join(conditions) if conditions else "TRUE"
  302. query = f"""
  303. MATCH (n:{label})
  304. WHERE {where_clause}
  305. RETURN id(n) as node_id
  306. LIMIT 1
  307. """
  308. # 执行查询
  309. result = session.run(query, **params).single()
  310. return result["node_id"] if result else None
  311. except Exception as e:
  312. logger.error(f"Error in get_node: {str(e)}")
  313. return None
  314. def relationship_exists(start_node_id, rel_type, end_node_id, **properties):
  315. """
  316. 检查两个节点之间是否存在指定类型和属性的关系
  317. Args:
  318. start_node_id: 起始节点ID (必须是整数ID)
  319. rel_type: 关系类型
  320. end_node_id: 结束节点ID (必须是整数ID)
  321. **properties: 关系的属性
  322. Returns:
  323. bool: 是否存在关系
  324. """
  325. try:
  326. with connect_graph().session() as session:
  327. # 确保输入的是有效的节点ID
  328. if not isinstance(start_node_id, (int, str)) or not isinstance(end_node_id, (int, str)):
  329. logger.warning(f"无效的节点ID类型: start_node_id={type(start_node_id)}, end_node_id={type(end_node_id)}")
  330. return False
  331. # 转换为整数
  332. try:
  333. start_id = int(start_node_id)
  334. end_id = int(end_node_id)
  335. except (ValueError, TypeError):
  336. logger.warning(f"无法转换节点ID为整数: start_node_id={start_node_id}, end_node_id={end_node_id}")
  337. return False
  338. # 构建查询语句
  339. query = """
  340. MATCH (a)-[r:%s]->(b)
  341. WHERE id(a) = $start_id AND id(b) = $end_id
  342. """ % rel_type
  343. # 添加属性条件
  344. if properties:
  345. conditions = []
  346. for key, value in properties.items():
  347. conditions.append(f"r.{key} = ${key}")
  348. query += " AND " + " AND ".join(conditions)
  349. query += "\nRETURN count(r) > 0 as exists"
  350. # 执行查询
  351. params = {
  352. 'start_id': start_id,
  353. 'end_id': end_id,
  354. **properties
  355. }
  356. result = session.run(query, **params).single()
  357. return result and result["exists"]
  358. except Exception as e:
  359. logger.error(f"Error in relationship_exists: {str(e)}")
  360. return False