graph_retrieval.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. #!/usr/bin/python
# -*- coding: utf-8 -*-
  3. import json
  4. import os
  5. from langchain_community.vectorstores.neo4j_vector import remove_lucene_chars
  6. from langchain.prompts import (
  7. PromptTemplate,
  8. )
  9. from typing import List
  10. from langchain.output_parsers import ResponseSchema,StructuredOutputParser
  11. from langchain_community.graphs import Neo4jGraph
  12. from langchain.schema import AIMessage
  13. def connect():
  14. # os.environ["NEO4J_URI"] = "bolt://172.16.48.8:7687"
  15. # os.environ["NEO4J_USERNAME"] = "neo4j"
  16. # os.environ["NEO4J_PASSWORD"] = "!@#qwe123^&*"
  17. os.environ["NEO4J_URI"] = "bolt://192.168.3.91:27687"
  18. os.environ["NEO4J_USERNAME"] = "neo4j"
  19. os.environ["NEO4J_PASSWORD"] = "citu2099@@CCA."
  20. graph = Neo4jGraph()
  21. return graph
  22. def extract_question_info(question:str,llm)->List[str]:
  23. # 定义要接收的响应模式
  24. response_schemas = [
  25. ResponseSchema(name="entity", description="All the person, organization, or business entities that"""
  26. "appear in the text")
  27. ]
  28. # 创建输出解析器
  29. output_parser = StructuredOutputParser.from_response_schemas(response_schemas)
  30. # 获取格式指示
  31. format_instructions = output_parser.get_format_instructions()
  32. # 根据模板创建提示,同时在提示中加入输出解析器的说明
  33. prompt_template = PromptTemplate(
  34. template="Answer the user query.\n{format_instructions}\n{query}\n",
  35. input_variables=["query"],
  36. partial_variables={"format_instructions": format_instructions},
  37. )
  38. # 根据提示准备模型的输入
  39. inputData = prompt_template.format(query=question)
  40. # 获取模型的输出
  41. output = llm.invoke(inputData)
  42. # 去掉 JSON 内容前后的 ```json 和 ``` 标记
  43. if isinstance(output, AIMessage):
  44. # 从 AIMessage 对象中提取内容
  45. json_content = output.content.strip('```json').strip('```').strip()
  46. else:
  47. raise TypeError("Expected an AIMessage object")
  48. # 解析 JSON 内容
  49. data = json.loads(json_content)
  50. # 获取 names 列表
  51. names = data.get('entity',[])
  52. # 用户问题的实体输出
  53. # print(names)
  54. if isinstance(names, str):
  55. names = [names]
  56. return names
  57. def generate_full_text_query(input: str) -> str:
  58. """
  59. Generate a full-text search query for a given input string.
  60. This function constructs a query string suitable for a full-text search.
  61. It processes the input string by splitting it into words and appending a
  62. similarity threshold (~2 changed characters) to each word, then combines
  63. them using the AND operator. Useful for mapping entities from user questions
  64. to database values, and allows for some misspelings.
  65. """
  66. full_text_query = ""
  67. words = [el for el in remove_lucene_chars(input).split() if el]
  68. for word in words[:-1]:
  69. full_text_query += f" {word}~2 AND"
  70. full_text_query += f" {words[-1]}~2"
  71. return full_text_query.strip()
  72. # Fulltext index query
  73. def structured_retriever(llm,graph,question: str) -> str:
  74. """
  75. Collects the neighborhood of entities mentioned
  76. in the question
  77. """
  78. result = ""
  79. # 前面提取到的实体
  80. names = extract_question_info(question,llm)
  81. for entity in names:
  82. # 图谱中匹配到的节点限制返回相似度不得低于0.5
  83. # query = generate_full_text_query(entity)
  84. # print(f"Query:{query}")
  85. response = graph.query(
  86. """CALL db.index.fulltext.queryNodes('dataops', $query, {limit:2})
  87. YIELD node, score
  88. WHERE score >= 0.5
  89. // score 判断
  90. CALL {
  91. WITH node
  92. MATCH (node)-[r]->(neighbor)
  93. RETURN node.name + ' - ' + type(r) + ' -> ' + neighbor.name AS output
  94. UNION ALL
  95. WITH node
  96. MATCH (node)<-[r]-(neighbor)
  97. RETURN neighbor.name + ' - ' + type(r) + ' -> ' + node.name AS output
  98. }
  99. RETURN output LIMIT 50
  100. """,
  101. {"query": entity},
  102. )
  103. result += "\n".join([el['output'] for el in response])
  104. return result
# Full-text index query over the unstructured-data index ('unstructure')
  106. def text_structured_retriever(llm,graph,question: str) -> str:
  107. """
  108. Collects the neighborhood of entities mentioned
  109. in the question
  110. """
  111. result = ""
  112. # 前面提取到的实体
  113. names = extract_question_info(question,llm)
  114. for entity in names:
  115. # 图谱中匹配到的节点限制返回相似度不得低于0.5
  116. # query = generate_full_text_query(entity)
  117. # print(f"Query:{query}")
  118. response = graph.query(
  119. """CALL db.index.fulltext.queryNodes('unstructure', $query, {limit:4})
  120. YIELD node, score
  121. WHERE score >= 0.2
  122. // score 判断
  123. CALL {
  124. WITH node
  125. MATCH (node)-[r]->(neighbor)
  126. RETURN node.name + ' - ' + type(r) + ' -> ' + neighbor.name AS output
  127. UNION ALL
  128. WITH node
  129. MATCH (node)<-[r]-(neighbor)
  130. RETURN neighbor.name + ' - ' + type(r) + ' -> ' + node.name AS output
  131. }
  132. RETURN output LIMIT 50
  133. """,
  134. {"query": entity},
  135. )
  136. result += "\n".join([el['output'] for el in response])
  137. return result