graph_retrieval.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. #!/usr/bin/python
  2. # -*- coding: <utf-8> -*-
  3. import json
  4. import os
  5. from langchain_community.vectorstores.neo4j_vector import remove_lucene_chars
  6. from langchain.prompts import (
  7. PromptTemplate,
  8. )
  9. from typing import List
  10. from langchain.output_parsers import ResponseSchema,StructuredOutputParser
  11. from langchain_community.graphs import Neo4jGraph
  12. from langchain.schema import AIMessage
  13. def connect():
  14. os.environ["NEO4J_URI"] = "bolt://192.168.3.91:27687"
  15. os.environ["NEO4J_USERNAME"] = "neo4j"
  16. os.environ["NEO4J_PASSWORD"] = "123456"
  17. graph = Neo4jGraph()
  18. return graph
  19. def extract_question_info(question:str,llm)->List[str]:
  20. # 定义要接收的响应模式
  21. response_schemas = [
  22. ResponseSchema(name="entity", description="All the person, organization, or business entities that"""
  23. "appear in the text")
  24. ]
  25. # 创建输出解析器
  26. output_parser = StructuredOutputParser.from_response_schemas(response_schemas)
  27. # 获取格式指示
  28. format_instructions = output_parser.get_format_instructions()
  29. # 根据模板创建提示,同时在提示中加入输出解析器的说明
  30. prompt_template = PromptTemplate(
  31. template="Answer the user query.\n{format_instructions}\n{query}\n",
  32. input_variables=["query"],
  33. partial_variables={"format_instructions": format_instructions},
  34. )
  35. # 根据提示准备模型的输入
  36. inputData = prompt_template.format(query=question)
  37. # 获取模型的输出
  38. output = llm.invoke(inputData)
  39. # 去掉 JSON 内容前后的 ```json 和 ``` 标记
  40. if isinstance(output, AIMessage):
  41. # 从 AIMessage 对象中提取内容
  42. json_content = output.content.strip('```json').strip('```').strip()
  43. else:
  44. raise TypeError("Expected an AIMessage object")
  45. # 解析 JSON 内容
  46. data = json.loads(json_content)
  47. # 获取 names 列表
  48. names = data.get('entity',[])
  49. # 用户问题的实体输出
  50. # print(names)
  51. if isinstance(names, str):
  52. names = [names]
  53. return names
  54. def generate_full_text_query(input: str) -> str:
  55. """
  56. Generate a full-text search query for a given input string.
  57. This function constructs a query string suitable for a full-text search.
  58. It processes the input string by splitting it into words and appending a
  59. similarity threshold (~2 changed characters) to each word, then combines
  60. them using the AND operator. Useful for mapping entities from user questions
  61. to database values, and allows for some misspelings.
  62. """
  63. full_text_query = ""
  64. words = [el for el in remove_lucene_chars(input).split() if el]
  65. for word in words[:-1]:
  66. full_text_query += f" {word}~2 AND"
  67. full_text_query += f" {words[-1]}~2"
  68. return full_text_query.strip()
  69. # Fulltext index query
  70. def structured_retriever(llm,graph,question: str) -> str:
  71. """
  72. Collects the neighborhood of entities mentioned
  73. in the question
  74. """
  75. result = ""
  76. # 前面提取到的实体
  77. names = extract_question_info(question,llm)
  78. for entity in names:
  79. # 图谱中匹配到的节点限制返回相似度不得低于0.5
  80. # query = generate_full_text_query(entity)
  81. # print(f"Query:{query}")
  82. response = graph.query(
  83. """CALL db.index.fulltext.queryNodes('dataops', $query, {limit:2})
  84. YIELD node, score
  85. WHERE score >= 0.5
  86. // score 判断
  87. CALL {
  88. WITH node
  89. MATCH (node)-[r]->(neighbor)
  90. RETURN node.name + ' - ' + type(r) + ' -> ' + neighbor.name AS output
  91. UNION ALL
  92. WITH node
  93. MATCH (node)<-[r]-(neighbor)
  94. RETURN neighbor.name + ' - ' + type(r) + ' -> ' + node.name AS output
  95. }
  96. RETURN output LIMIT 50
  97. """,
  98. {"query": entity},
  99. )
  100. result += "\n".join([el['output'] for el in response])
  101. return result
  102. # 非结构化的全文索引
  103. def text_structured_retriever(llm,graph,question: str) -> str:
  104. """
  105. Collects the neighborhood of entities mentioned
  106. in the question
  107. """
  108. result = ""
  109. # 前面提取到的实体
  110. names = extract_question_info(question,llm)
  111. for entity in names:
  112. # 图谱中匹配到的节点限制返回相似度不得低于0.5
  113. # query = generate_full_text_query(entity)
  114. # print(f"Query:{query}")
  115. response = graph.query(
  116. """CALL db.index.fulltext.queryNodes('unstructure', $query, {limit:4})
  117. YIELD node, score
  118. WHERE score >= 0.2
  119. // score 判断
  120. CALL {
  121. WITH node
  122. MATCH (node)-[r]->(neighbor)
  123. RETURN node.name + ' - ' + type(r) + ' -> ' + neighbor.name AS output
  124. UNION ALL
  125. WITH node
  126. MATCH (node)<-[r]-(neighbor)
  127. RETURN neighbor.name + ' - ' + type(r) + ' -> ' + node.name AS output
  128. }
  129. RETURN output LIMIT 50
  130. """,
  131. {"query": entity},
  132. )
  133. result += "\n".join([el['output'] for el in response])
  134. return result