"""
Graph Database Core Operations

Provides basic operations for the graph database.
"""

from contextlib import contextmanager
import json
import logging

from neo4j import GraphDatabase
from flask import current_app

from app.services.neo4j_driver import Neo4jDriver

logger = logging.getLogger(__name__)


class MyEncoder(json.JSONEncoder):
    """Custom JSON encoder used when serializing Neo4j data."""

    def default(self, obj):
        """
        Return a JSON-serializable stand-in for ``obj``.

        The ``json`` module only calls ``default()`` for objects it cannot
        serialize natively, so the first branch is not expected to be hit
        during normal encoding; it defers to the base class, whose
        ``default()`` raises ``TypeError``. Everything else is stringified.
        """
        if isinstance(obj, (int, float, str, bool, list, dict, tuple, type(None))):
            return super().default(obj)
        return str(obj)


class GraphOperations:
    """Thin wrapper around the project's ``Neo4jDriver`` service."""

    def __init__(self):
        self.driver = Neo4jDriver()

    def get_connection(self):
        """Return whatever ``Neo4jDriver.connect()`` returns."""
        return self.driver.connect()

    def close(self):
        """Close the wrapped ``Neo4jDriver``."""
        self.driver.close()


def connect_graph():
    """
    Connect to the Neo4j graph database.

    Connection settings are read from the Flask application config
    (``NEO4J_URI``, ``NEO4J_USER``, ``NEO4J_PASSWORD``, ``NEO4J_ENCRYPTED``).

    Returns:
        A verified Neo4j driver instance, or None if the connection failed.
        The caller owns the driver and is responsible for closing it.
    """
    driver = None
    try:
        config = current_app.config
        driver = GraphDatabase.driver(
            uri=config.get('NEO4J_URI'),
            auth=(config.get('NEO4J_USER'), config.get('NEO4J_PASSWORD')),
            encrypted=config.get('NEO4J_ENCRYPTED'),
        )
        driver.verify_connectivity()
        return driver
    except Exception as e:
        logger.error("Error connecting to Neo4j database: %s", e)
        # A driver that was created but failed verification must still be
        # closed, otherwise it is leaked.
        if driver is not None:
            try:
                driver.close()
            except Exception:
                logger.debug("Failed to close unverified Neo4j driver", exc_info=True)
        return None


@contextmanager
def _graph_session():
    """
    Yield a session on a freshly connected driver and always clean up.

    ``connect_graph()`` builds a new driver on every call, so the driver is
    closed here once the session is finished.

    Raises:
        ConnectionError: if ``connect_graph()`` could not connect.
    """
    driver = connect_graph()
    if driver is None:
        raise ConnectionError("Unable to connect to Neo4j database")
    try:
        with driver.session() as session:
            yield session
    finally:
        driver.close()


def create_or_get_node(label, **properties):
    """
    Create a node with the given label and properties, or update an existing one.

    If ``id`` is supplied and a node with that internal Neo4j id and label
    exists, its remaining properties are updated. Otherwise a new node is
    created from all supplied properties (including ``id``, if given, which
    is then stored as an ordinary property). An ``id_list`` property is
    always discarded.

    Args:
        label (str): Neo4j node label. It is interpolated into the query
            because Cypher cannot parameterize labels, so it must come from
            trusted code.
        **properties: Node properties as keyword arguments.

    Returns:
        The internal Neo4j id of the matched or created node.

    Raises:
        Exception: any connection or query error is logged and re-raised.
    """
    try:
        properties.pop('id_list', None)

        with _graph_session() as session:
            if 'id' in properties:
                node_id = properties['id']
                existing = session.run(
                    f"MATCH (n:{label}) WHERE id(n) = $node_id RETURN n",
                    {'node_id': node_id},
                ).single()

                if existing:
                    updates = {
                        key: value
                        for key, value in properties.items()
                        if key != 'id'
                    }
                    if not updates:
                        return node_id
                    # Properties travel as one map parameter so their names
                    # can never collide with $node_id or driver arguments.
                    updated = session.run(
                        f"MATCH (n:{label}) WHERE id(n) = $node_id "
                        "SET n += $props RETURN id(n) AS node_id",
                        {'node_id': node_id, 'props': updates},
                    ).single()
                    return updated["node_id"]

            # No id given, or no node with that id: create a new node.
            created = session.run(
                f"CREATE (n:{label} $props) RETURN id(n) AS node_id",
                {'props': properties},
            ).single()
            return created["node_id"]
    except Exception as e:
        logger.error("Error in create_or_get_node: %s", e)
        raise


def create_relationship(start_node, end_node, relationship_type, properties=None):
    """
    Create (MERGE) a relationship between two nodes and set its properties.

    Args:
        start_node: Start node object; must expose an ``id`` attribute.
        end_node: End node object; must expose an ``id`` attribute.
        relationship_type: Relationship type. It is interpolated into the
            query because Cypher cannot parameterize relationship types, so
            it must come from trusted code.
        properties: Optional dict of relationship properties, merged into
            the relationship with ``SET r += ...``.

    Returns:
        The relationship object returned by the query.

    Raises:
        ValueError: if either node lacks an ``id`` attribute.
        ConnectionError: if the database connection cannot be established.
        TypeError: if the query yields no record (for example when either
            id does not match a node), because the empty result is
            subscripted.
    """
    if not hasattr(start_node, 'id') or not hasattr(end_node, 'id'):
        raise ValueError("Invalid node objects provided")

    if properties is None:
        properties = {}

    query = """
    MATCH (start), (end)
    WHERE id(start) = $start_id AND id(end) = $end_id
    MERGE (start)-[r:%s]->(end)
    SET r += $properties
    RETURN r
    """ % relationship_type

    params = {
        'start_id': start_node.id,
        'end_id': end_node.id,
        'properties': properties,
    }
    with _graph_session() as session:
        return session.run(query, params).single()["r"]


def get_subgraph(node_ids, rel_types=None, max_depth=1):
    """
    Fetch the subgraph reachable from the given start nodes.

    Args:
        node_ids: Iterable of internal Neo4j node ids (values are coerced
            with ``int``).
        rel_types: Optional list of relationship types to follow. They are
            interpolated into the query and must come from trusted code.
        max_depth: Maximum traversal depth (non-negative), default 1.

    Returns:
        dict: ``{'nodes': [...], 'relationships': [...]}``. Each node dict
        holds the node properties plus ``id`` and ``labels``; each
        relationship dict holds its properties plus ``id``, ``type``,
        ``source`` and ``target``.

    Raises:
        Exception: any validation, connection or query error is logged and
            re-raised.
    """
    try:
        # Ids go in as a query parameter; depth cannot be parameterized in a
        # variable-length pattern, so it is forced to an int before
        # interpolation.
        start_ids = [int(nid) for nid in node_ids]
        depth = int(max_depth)
        if depth < 0:
            raise ValueError("max_depth must be a non-negative integer")

        rel_filter = ''
        if rel_types:
            rel_filter = ":" + '|'.join(rel_types)

        cypher = f"""
        MATCH path = (n)-[r{rel_filter}*0..{depth}]-(m)
        WHERE id(n) IN $node_ids
        RETURN path
        """

        nodes = {}
        relationships = {}
        with _graph_session() as session:
            for record in session.run(cypher, {'node_ids': start_ids}):
                path = record["path"]

                # Deduplicate nodes and relationships across paths by id.
                for node in path.nodes:
                    if node.id not in nodes:
                        node_dict = dict(node)
                        node_dict['id'] = node.id
                        node_dict['labels'] = list(node.labels)
                        nodes[node.id] = node_dict

                for rel in path.relationships:
                    if rel.id not in relationships:
                        rel_dict = dict(rel)
                        rel_dict['id'] = rel.id
                        rel_dict['type'] = rel.type
                        rel_dict['source'] = rel.start_node.id
                        rel_dict['target'] = rel.end_node.id
                        relationships[rel.id] = rel_dict

        return {
            'nodes': list(nodes.values()),
            'relationships': list(relationships.values()),
        }
    except Exception as e:
        logger.error("Error getting subgraph: %s", e)
        raise


def _serialize_value(value):
    """Convert one result value to plain data; graph types are duck-typed."""
    # Node: has id, labels and mapping-style items().
    if hasattr(value, 'id') and hasattr(value, 'labels') and hasattr(value, 'items'):
        node_dict = dict(value)
        node_dict['_id'] = value.id
        node_dict['_labels'] = list(value.labels)
        return node_dict

    # Relationship: has id, type and a start node.
    if hasattr(value, 'id') and hasattr(value, 'type') and hasattr(value, 'start_node'):
        rel_dict = dict(value)
        rel_dict['_id'] = value.id
        rel_dict['_type'] = value.type
        rel_dict['_start_node_id'] = value.start_node.id
        rel_dict['_end_node_id'] = value.end_node.id
        return rel_dict

    # Path: only the property maps of its nodes and relationships are kept.
    if hasattr(value, 'start_node') and hasattr(value, 'end_node') and hasattr(value, 'nodes'):
        return {
            'nodes': [dict(node) for node in value.nodes],
            'relationships': [dict(rel) for rel in value.relationships],
        }

    # Anything else is passed through unchanged.
    return value


def execute_cypher_query(cypher, params=None):
    """
    Execute a Cypher query and return its results as plain data.

    Args:
        cypher: Cypher query text.
        params: Optional mapping of query parameters.

    Returns:
        list[dict]: One dict per record, keyed by the returned column names.
        Nodes become property dicts with ``_id`` and ``_labels``;
        relationships become property dicts with ``_id``, ``_type``,
        ``_start_node_id`` and ``_end_node_id``; paths become
        ``{'nodes': [...], 'relationships': [...]}`` of property dicts; all
        other values are returned unchanged.

    Raises:
        Exception: any connection or query error is logged and re-raised.
    """
    if params is None:
        params = {}

    try:
        with _graph_session() as session:
            # Parameters are passed as one dict so their names cannot clash
            # with the keyword arguments of session.run() itself.
            result = session.run(cypher, dict(params))
            return [
                {key: _serialize_value(value) for key, value in record.items()}
                for record in result
            ]
    except Exception as e:
        logger.error("Error executing Cypher query: %s", e)
        raise


def get_node(label, **properties):
    """
    Look up a node by label and properties.

    Args:
        label (str): Neo4j node label. It is interpolated into the query
            and must come from trusted code.
        **properties: Properties to match. ``id`` is matched against the
            internal Neo4j node id; every other key is matched against the
            node property of the same name. Property names are interpolated
            into the query and must also come from trusted code.

    Returns:
        The internal Neo4j id of the first matching node, or None if no node
        matches or an error occurs (errors are logged, not raised).
    """
    try:
        conditions = []
        params = {}

        # Work on a copy so the id entry can be split off from the rest.
        props = dict(properties)
        if 'id' in props:
            conditions.append("id(n) = $node_id")
            params['node_id'] = props.pop('id')

        # Property values live under $props, so a property that happens to
        # be called "node_id" cannot overwrite the id parameter.
        conditions.extend(f"n.{key} = $props.{key}" for key in props)
        params['props'] = props

        where_clause = " AND ".join(conditions) if conditions else "TRUE"
        query = f"""
        MATCH (n:{label})
        WHERE {where_clause}
        RETURN id(n) as node_id
        LIMIT 1
        """

        with _graph_session() as session:
            result = session.run(query, params).single()
            return result["node_id"] if result else None
    except Exception as e:
        logger.error("Error in get_node: %s", e)
        return None


def relationship_exists(start_node, rel_type, end_node, **properties):
    """
    Check whether a relationship of the given type and properties exists.

    Args:
        start_node: Start node object (with an ``id`` attribute) or node id.
        rel_type: Relationship type. It is interpolated into the query and
            must come from trusted code.
        end_node: End node object (with an ``id`` attribute) or node id.
        **properties: Relationship properties that must all match. Property
            names are interpolated into the query and must come from
            trusted code.

    Returns:
        bool: True if a matching relationship exists; False if not, or if an
        error occurs (errors are logged, not raised).
    """
    try:
        start_id = start_node.id if hasattr(start_node, 'id') else start_node
        end_id = end_node.id if hasattr(end_node, 'id') else end_node

        conditions = ["id(a) = $start_id", "id(b) = $end_id"]
        # Property values live under $props, so a property that happens to
        # be called "start_id" or "end_id" cannot overwrite the node ids.
        conditions.extend(f"r.{key} = $props.{key}" for key in properties)
        where_clause = " AND ".join(conditions)

        query = f"""
        MATCH (a)-[r:{rel_type}]->(b)
        WHERE {where_clause}
        RETURN count(r) > 0 as exists
        """

        params = {
            'start_id': int(start_id),
            'end_id': int(end_id),
            'props': properties,
        }
        with _graph_session() as session:
            result = session.run(query, params).single()
            return bool(result and result["exists"])
    except Exception as e:
        logger.error("Error in relationship_exists: %s", e)
        return False