graph_operations.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474
"""
Graph Database Core Operations

Provides the basic operations for working with the graph database.
"""
  5. import json
  6. import logging
  7. from flask import current_app
  8. from neo4j import GraphDatabase
  9. from app.services.neo4j_driver import Neo4jDriver
  10. logger = logging.getLogger(__name__)
  11. class MyEncoder(json.JSONEncoder):
  12. """Neo4j数据序列化的自定义JSON编码器"""
  13. def default(self, obj):
  14. if isinstance(obj, (int, float, str, bool, list, dict, tuple, type(None))):
  15. return super(MyEncoder, self).default(obj)
  16. # 处理DateTime对象
  17. if hasattr(obj, "isoformat"):
  18. return obj.isoformat()
  19. return str(obj)
  20. class GraphOperations:
  21. def __init__(self):
  22. self.driver = Neo4jDriver()
  23. def get_connection(self):
  24. return self.driver.connect()
  25. def close(self):
  26. self.driver.close()
  27. def connect_graph():
  28. """
  29. 连接到Neo4j图数据库
  30. Returns:
  31. Neo4j driver实例
  32. Raises:
  33. ConnectionError: 如果无法连接到Neo4j数据库
  34. ValueError: 如果配置参数缺失
  35. """
  36. try:
  37. # 从Config获取Neo4j连接参数
  38. uri = current_app.config.get("NEO4J_URI")
  39. user = current_app.config.get("NEO4J_USER")
  40. password = current_app.config.get("NEO4J_PASSWORD")
  41. encrypted = current_app.config.get("NEO4J_ENCRYPTED")
  42. # 检查必需的配置参数
  43. if not uri:
  44. raise ValueError("Neo4j URI配置缺失,请检查NEO4J_URI配置")
  45. if not user:
  46. raise ValueError("Neo4j用户配置缺失,请检查NEO4J_USER配置")
  47. if password is None:
  48. raise ValueError("Neo4j密码配置缺失,请检查NEO4J_PASSWORD配置")
  49. # 创建Neo4j驱动
  50. driver = GraphDatabase.driver(
  51. uri=uri, auth=(user, password), encrypted=bool(encrypted)
  52. )
  53. # 验证连接
  54. driver.verify_connectivity()
  55. return driver
  56. except Exception as e:
  57. # 处理连接错误,抛出异常而不是返回None
  58. error_msg = f"无法连接到Neo4j图数据库: {str(e)}"
  59. logger.error(error_msg)
  60. raise ConnectionError(error_msg) from e
  61. def create_or_get_node(label, **properties):
  62. """
  63. 创建具有给定标签和属性的新节点或获取现有节点
  64. 如果具有相同id的节点存在,则更新属性
  65. Args:
  66. label (str): Neo4j节点标签
  67. **properties: 作为关键字参数的节点属性
  68. Returns:
  69. 节点id
  70. """
  71. try:
  72. with connect_graph().session() as session:
  73. # 移除 id_list 属性
  74. if "id_list" in properties:
  75. properties.pop("id_list")
  76. # 检查是否提供了id
  77. if "id" in properties:
  78. node_id = properties["id"]
  79. # 检查节点是否存在
  80. query = f"""
  81. MATCH (n:{label}) WHERE id(n) = $node_id
  82. RETURN n
  83. """
  84. result = session.run(
  85. query, # type: ignore[arg-type]
  86. node_id=int(node_id),
  87. ).single()
  88. if result:
  89. # 节点存在,更新属性
  90. props_string = ", ".join(
  91. [f"n.{key} = ${key}" for key in properties if key != "id"]
  92. )
  93. if props_string:
  94. update_query = f"""
  95. MATCH (n:{label}) WHERE id(n) = $node_id
  96. SET {props_string}
  97. RETURN id(n) as node_id
  98. """
  99. result = session.run(
  100. update_query, # type: ignore[arg-type]
  101. node_id=node_id,
  102. **properties,
  103. ).single()
  104. if result:
  105. return result["node_id"]
  106. return node_id
  107. # 如果到这里,则创建新节点
  108. props_keys = ", ".join([f"{key}: ${key}" for key in properties])
  109. create_query = f"""
  110. CREATE (n:{label} {{{props_keys}}})
  111. RETURN id(n) as node_id
  112. """
  113. result = session.run(
  114. create_query, # type: ignore[arg-type]
  115. **properties,
  116. ).single()
  117. if result:
  118. return result["node_id"]
  119. return None
  120. except Exception as e:
  121. logger.error(f"Error in create_or_get_node: {str(e)}")
  122. raise e
  123. def create_relationship(start_node, end_node, relationship_type, properties=None):
  124. """
  125. 创建两个节点之间的关系
  126. Args:
  127. start_node: 起始节点
  128. end_node: 结束节点
  129. relationship_type: 关系类型
  130. properties: 关系属性
  131. Returns:
  132. 创建的关系对象
  133. """
  134. if not hasattr(start_node, "id") or not hasattr(end_node, "id"):
  135. raise ValueError("Invalid node objects provided")
  136. if properties is None:
  137. properties = {}
  138. query = (
  139. """
  140. MATCH (start), (end)
  141. WHERE id(start) = $start_id AND id(end) = $end_id
  142. MERGE (start)-[r:%s]->(end)
  143. SET r += $properties
  144. RETURN r
  145. """
  146. % relationship_type
  147. )
  148. with connect_graph().session() as session:
  149. result = session.run(
  150. query, # type: ignore[arg-type]
  151. start_id=start_node.id,
  152. end_id=end_node.id,
  153. properties=properties,
  154. )
  155. single_result = result.single()
  156. return single_result["r"] if single_result else None
  157. def get_subgraph(node_ids, rel_types=None, max_depth=1):
  158. """
  159. 获取以指定节点为起点的子图
  160. Args:
  161. node_ids: 节点ID列表
  162. rel_types: 关系类型列表(可选)
  163. max_depth: 最大深度,默认为1
  164. Returns:
  165. 包含节点和关系的字典
  166. """
  167. try:
  168. # 处理节点ID列表
  169. node_ids_str = ", ".join([str(nid) for nid in node_ids])
  170. # 处理关系类型过滤
  171. rel_filter = ""
  172. if rel_types:
  173. rel_types_str = "|".join(rel_types)
  174. rel_filter = f":{rel_types_str}"
  175. # 构建Cypher语句
  176. cypher = f"""
  177. MATCH path = (n)-[r{rel_filter}*0..{max_depth}]-(m)
  178. WHERE id(n) IN [{node_ids_str}]
  179. RETURN path
  180. """
  181. # 执行查询
  182. with connect_graph().session() as session:
  183. result = session.run(cypher) # type: ignore[arg-type]
  184. # 处理结果为图谱数据
  185. nodes = {}
  186. relationships = {}
  187. for record in result:
  188. path = record["path"]
  189. # 处理节点
  190. for node in path.nodes:
  191. if node.id not in nodes:
  192. node_dict = dict(node)
  193. node_dict["id"] = node.id
  194. node_dict["labels"] = list(node.labels)
  195. nodes[node.id] = node_dict
  196. # 处理关系
  197. for rel in path.relationships:
  198. if rel.id not in relationships:
  199. rel_dict = dict(rel)
  200. rel_dict["id"] = rel.id
  201. rel_dict["type"] = rel.type
  202. rel_dict["source"] = rel.start_node.id
  203. rel_dict["target"] = rel.end_node.id
  204. relationships[rel.id] = rel_dict
  205. # 转换为列表形式
  206. graph_data = {
  207. "nodes": list(nodes.values()),
  208. "relationships": list(relationships.values()),
  209. }
  210. return graph_data
  211. except Exception as e:
  212. logger.error(f"Error getting subgraph: {str(e)}")
  213. raise e
  214. def execute_cypher_query(cypher, params=None):
  215. """
  216. 执行Cypher查询并返回结果
  217. Args:
  218. cypher: Cypher查询语句
  219. params: 查询参数(可选)
  220. Returns:
  221. 查询结果的列表
  222. """
  223. if params is None:
  224. params = {}
  225. def convert_value(value):
  226. """转换Neo4j返回的值为JSON可序列化的格式"""
  227. # 处理DateTime对象
  228. if hasattr(value, "isoformat"):
  229. return value.isoformat()
  230. # 处理Date对象
  231. elif (
  232. hasattr(value, "year") and hasattr(value, "month") and hasattr(value, "day")
  233. ):
  234. return str(value)
  235. # 处理Time对象
  236. elif (
  237. hasattr(value, "hour")
  238. and hasattr(value, "minute")
  239. and hasattr(value, "second")
  240. ):
  241. return str(value)
  242. # 处理其他对象
  243. else:
  244. return value
  245. try:
  246. with connect_graph().session() as session:
  247. result = session.run(cypher, **params)
  248. # 处理查询结果
  249. data = []
  250. for record in result:
  251. record_dict = {}
  252. for key, value in record.items():
  253. # 节点处理
  254. if (
  255. hasattr(value, "id")
  256. and hasattr(value, "labels")
  257. and hasattr(value, "items")
  258. ):
  259. node_dict = {}
  260. for prop_key, prop_value in dict(value).items():
  261. node_dict[prop_key] = convert_value(prop_value)
  262. node_dict["_id"] = value.id
  263. node_dict["_labels"] = list(value.labels)
  264. record_dict[key] = node_dict
  265. # 关系处理
  266. elif (
  267. hasattr(value, "id")
  268. and hasattr(value, "type")
  269. and hasattr(value, "start_node")
  270. ):
  271. rel_dict = {}
  272. for prop_key, prop_value in dict(value).items():
  273. rel_dict[prop_key] = convert_value(prop_value)
  274. rel_dict["_id"] = value.id
  275. rel_dict["_type"] = value.type
  276. rel_dict["_start_node_id"] = value.start_node.id
  277. rel_dict["_end_node_id"] = value.end_node.id
  278. record_dict[key] = rel_dict
  279. # 路径处理
  280. elif (
  281. hasattr(value, "start_node")
  282. and hasattr(value, "end_node")
  283. and hasattr(value, "nodes")
  284. ):
  285. path_dict = {"nodes": [], "relationships": []}
  286. # 处理路径中的节点
  287. for node in value.nodes:
  288. node_dict = {}
  289. for prop_key, prop_value in dict(node).items():
  290. node_dict[prop_key] = convert_value(prop_value)
  291. path_dict["nodes"].append(node_dict)
  292. # 处理路径中的关系
  293. for rel in value.relationships:
  294. rel_dict = {}
  295. for prop_key, prop_value in dict(rel).items():
  296. rel_dict[prop_key] = convert_value(prop_value)
  297. path_dict["relationships"].append(rel_dict)
  298. record_dict[key] = path_dict
  299. # 其他类型直接转换
  300. else:
  301. record_dict[key] = convert_value(value)
  302. data.append(record_dict)
  303. return data
  304. except Exception as e:
  305. logger.error(f"Error executing Cypher query: {str(e)}")
  306. raise e
  307. def get_node(label, **properties):
  308. """
  309. 查询具有给定标签和属性的节点
  310. Args:
  311. label (str): Neo4j节点标签
  312. **properties: 作为关键字参数的节点属性
  313. Returns:
  314. 节点对象,如果不存在则返回None
  315. """
  316. try:
  317. with connect_graph().session() as session:
  318. # 构建查询条件
  319. conditions = []
  320. params = {}
  321. # 处理ID参数
  322. if "id" in properties:
  323. conditions.append("id(n) = $node_id")
  324. params["node_id"] = properties["id"]
  325. # 移除id属性,避免在后续属性匹配中重复
  326. properties_copy = properties.copy()
  327. properties_copy.pop("id")
  328. properties = properties_copy
  329. # 处理其他属性
  330. for key, value in properties.items():
  331. conditions.append(f"n.{key} = ${key}")
  332. params[key] = value
  333. # 构建查询语句
  334. where_clause = " AND ".join(conditions) if conditions else "TRUE"
  335. query = f"""
  336. MATCH (n:{label})
  337. WHERE {where_clause}
  338. RETURN id(n) as node_id
  339. LIMIT 1
  340. """
  341. # 执行查询
  342. result = session.run(
  343. query, # type: ignore[arg-type]
  344. **params,
  345. ).single()
  346. return result["node_id"] if result else None
  347. except Exception as e:
  348. logger.error(f"Error in get_node: {str(e)}")
  349. return None
  350. def relationship_exists(start_node_id, rel_type, end_node_id, **properties):
  351. """
  352. 检查两个节点之间是否存在指定类型和属性的关系
  353. Args:
  354. start_node_id: 起始节点ID (必须是整数ID)
  355. rel_type: 关系类型
  356. end_node_id: 结束节点ID (必须是整数ID)
  357. **properties: 关系的属性
  358. Returns:
  359. bool: 是否存在关系
  360. """
  361. try:
  362. with connect_graph().session() as session:
  363. # 确保输入的是有效的节点ID
  364. if not isinstance(start_node_id, (int, str)) or not isinstance(
  365. end_node_id, (int, str)
  366. ):
  367. logger.warning(
  368. f"无效的节点ID类型: start_node_id={type(start_node_id)}, end_node_id={type(end_node_id)}"
  369. )
  370. return False
  371. # 转换为整数
  372. try:
  373. start_id = int(start_node_id)
  374. end_id = int(end_node_id)
  375. except (ValueError, TypeError):
  376. logger.warning(
  377. f"无法转换节点ID为整数: start_node_id={start_node_id}, end_node_id={end_node_id}"
  378. )
  379. return False
  380. # 构建查询语句
  381. query = (
  382. """
  383. MATCH (a)-[r:%s]->(b)
  384. WHERE id(a) = $start_id AND id(b) = $end_id
  385. """
  386. % rel_type
  387. )
  388. # 添加属性条件
  389. if properties:
  390. conditions = []
  391. for key, value in properties.items():
  392. conditions.append(f"r.{key} = ${key}")
  393. query += " AND " + " AND ".join(conditions)
  394. query += "\nRETURN count(r) > 0 as exists"
  395. # 执行查询
  396. params = {"start_id": start_id, "end_id": end_id, **properties}
  397. result = session.run(query, **params).single()
  398. return result and result["exists"]
  399. except Exception as e:
  400. logger.error(f"Error in relationship_exists: {str(e)}")
  401. return False