""" Graph Database Core Operations 提供图数据库的基本操作功能 """ from neo4j import GraphDatabase from flask import current_app from app.services.neo4j_driver import Neo4jDriver import json import logging logger = logging.getLogger(__name__) class MyEncoder(json.JSONEncoder): """Neo4j数据序列化的自定义JSON编码器""" def default(self, obj): if isinstance(obj, (int, float, str, bool, list, dict, tuple, type(None))): return super(MyEncoder, self).default(obj) return str(obj) class GraphOperations: def __init__(self): self.driver = Neo4jDriver() def get_connection(self): return self.driver.connect() def close(self): self.driver.close() def connect_graph(): """ 连接到Neo4j图数据库 Returns: Neo4j driver实例,如果连接失败则返回None """ try: # 从Config获取Neo4j连接参数 uri = current_app.config.get('NEO4J_URI') user = current_app.config.get('NEO4J_USER') password = current_app.config.get('NEO4J_PASSWORD') encrypted = current_app.config.get('NEO4J_ENCRYPTED') # 创建Neo4j驱动 driver = GraphDatabase.driver( uri=uri, auth=(user, password), encrypted=encrypted ) # 验证连接 driver.verify_connectivity() return driver except Exception as e: # 处理连接错误 logger.error(f"Error connecting to Neo4j database: {str(e)}") return None def create_or_get_node(label, **properties): """ 创建具有给定标签和属性的新节点或获取现有节点 如果具有相同id的节点存在,则更新属性 Args: label (str): Neo4j节点标签 **properties: 作为关键字参数的节点属性 Returns: 节点id """ try: with connect_graph().session() as session: # 检查是否提供了id if 'id' in properties: node_id = properties['id'] # 检查节点是否存在 query = f""" MATCH (n:{label}) WHERE id(n) = $node_id RETURN n """ result = session.run(query, node_id=node_id).single() if result: # 节点存在,更新属性 props_string = ", ".join([f"n.{key} = ${key}" for key in properties if key != 'id']) if props_string: update_query = f""" MATCH (n:{label}) WHERE id(n) = $node_id SET {props_string} RETURN id(n) as node_id """ result = session.run(update_query, node_id=node_id, **properties).single() return result["node_id"] return node_id # 如果到这里,则创建新节点 props_keys = ", ".join([f"{key}: ${key}" for key in properties]) create_query = f""" CREATE (n:{label} {{{props_keys}}}) RETURN id(n) as node_id """ result = session.run(create_query, **properties).single() return result["node_id"] except Exception as e: logger.error(f"Error in create_or_get_node: {str(e)}") raise e def create_relationship(start_node_id, end_node_id, rel_type, **properties): """ 在两个节点之间创建关系 Args: start_node_id: 起始节点ID end_node_id: 结束节点ID rel_type: 关系类型 **properties: 关系的属性 Returns: 关系的ID """ try: # 构建属性部分 properties_str = ', '.join([f"{k}: ${k}" for k in properties.keys()]) properties_part = f" {{{properties_str}}}" if properties else "" # 构建Cypher语句 cypher = f""" MATCH (a), (b) WHERE id(a) = $start_node_id AND id(b) = $end_node_id CREATE (a)-[r:{rel_type}{properties_part}]->(b) RETURN id(r) as rel_id """ # 执行创建 with connect_graph().session() as session: params = { 'start_node_id': int(start_node_id), 'end_node_id': int(end_node_id), **properties } result = session.run(cypher, **params).single() if result: return result["rel_id"] else: logger.error("Failed to create relationship") return None except Exception as e: logger.error(f"Error creating relationship: {str(e)}") raise e def get_subgraph(node_ids, rel_types=None, max_depth=1): """ 获取以指定节点为起点的子图 Args: node_ids: 节点ID列表 rel_types: 关系类型列表(可选) max_depth: 最大深度,默认为1 Returns: 包含节点和关系的字典 """ try: # 处理节点ID列表 node_ids_str = ', '.join([str(nid) for nid in node_ids]) # 处理关系类型过滤 rel_filter = '' if rel_types: rel_types_str = '|'.join(rel_types) rel_filter = f":{rel_types_str}" # 构建Cypher语句 cypher = f""" MATCH path = (n)-[r{rel_filter}*0..{max_depth}]-(m) WHERE id(n) IN [{node_ids_str}] RETURN path """ # 执行查询 with connect_graph().session() as session: result = session.run(cypher) # 处理结果为图谱数据 nodes = {} relationships = {} for record in result: path = record["path"] # 处理节点 for node in path.nodes: if node.id not in nodes: node_dict = dict(node) node_dict['id'] = node.id node_dict['labels'] = list(node.labels) nodes[node.id] = node_dict # 处理关系 for rel in path.relationships: if rel.id not in relationships: rel_dict = dict(rel) rel_dict['id'] = rel.id rel_dict['type'] = rel.type rel_dict['source'] = rel.start_node.id rel_dict['target'] = rel.end_node.id relationships[rel.id] = rel_dict # 转换为列表形式 graph_data = { 'nodes': list(nodes.values()), 'relationships': list(relationships.values()) } return graph_data except Exception as e: logger.error(f"Error getting subgraph: {str(e)}") raise e def execute_cypher_query(cypher, params=None): """ 执行Cypher查询并返回结果 Args: cypher: Cypher查询语句 params: 查询参数(可选) Returns: 查询结果的列表 """ if params is None: params = {} try: with connect_graph().session() as session: result = session.run(cypher, **params) # 处理查询结果 data = [] for record in result: record_dict = {} for key, value in record.items(): # 节点处理 if hasattr(value, 'id') and hasattr(value, 'labels') and hasattr(value, 'items'): node_dict = dict(value) node_dict['_id'] = value.id node_dict['_labels'] = list(value.labels) record_dict[key] = node_dict # 关系处理 elif hasattr(value, 'id') and hasattr(value, 'type') and hasattr(value, 'start_node'): rel_dict = dict(value) rel_dict['_id'] = value.id rel_dict['_type'] = value.type rel_dict['_start_node_id'] = value.start_node.id rel_dict['_end_node_id'] = value.end_node.id record_dict[key] = rel_dict # 路径处理 elif hasattr(value, 'start_node') and hasattr(value, 'end_node') and hasattr(value, 'nodes'): path_dict = { 'nodes': [dict(node) for node in value.nodes], 'relationships': [dict(rel) for rel in value.relationships] } record_dict[key] = path_dict # 其他类型直接转换 else: record_dict[key] = value data.append(record_dict) return data except Exception as e: logger.error(f"Error executing Cypher query: {str(e)}") raise e