# Chat_QA_chain_self.py

import os
from operator import itemgetter
from typing import List, Optional, Tuple

from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import (
    RunnableBranch,
    RunnableLambda,
    RunnableParallel,
    RunnablePassthrough,
)

from embedding.embedding import get_embedding
from graph.graph_retrieval import connect, structured_retriever, text_structured_retriever
from llm.llm import LLM
from qa_chain.get_vectordb import get_vectordb

DEFAULT_DB_PATH = os.path.join("..", "knowledge_db")
DEFAULT_PERSIST_PATH = os.path.join("..", "vector_db", "chroma")

class Chat_QA_chain_self:
    """
    Question-answering chain with chat history.

    - temperature: sampling temperature, controls generation randomness
    - top_k: number of most similar documents to retrieve
    - chat_history: history as a list of (human, ai) tuples, defaults to empty
    - file_path: path to the source files used to build the database
    - persist_path: persistence path of the vector database
    - embedding: name of the embedding model to use

    The LLM itself is fixed to 'qwen' in the constructor.
    """

    def __init__(self, temperature: float = 0.0, top_k: int = 2,
                 chat_history: Optional[List[Tuple[str, str]]] = None,
                 file_path: Optional[str] = None, persist_path: Optional[str] = None,
                 embedding: str = "m3e"):
        self.temperature = temperature
        self.top_k = top_k
        # Default to None rather than []: a mutable default argument would be
        # shared across every instance of the class.
        self.chat_history = chat_history if chat_history is not None else []
        self.file_path = file_path
        self.persist_path = persist_path
        self.embedding = get_embedding(embedding)
        self.llm_instance = LLM(model_name='qwen')
        self.llm = self.llm_instance.get_llm()
        self.vectordb = get_vectordb(self.file_path, self.persist_path, self.embedding)
        self.graph = connect()

    def clear_chat_history(self):
        """Clear the chat history."""
        self.chat_history = []
        # print("Chat history has been cleared.")

    def add_to_chat_history(self, human_message: str, ai_message: str):
        """
        Append one exchange to the chat history.
        :param human_message: the human user's message
        :param ai_message: the AI's reply
        """
        self.chat_history.append((human_message, ai_message))

    def get_chat_history(self):
        """
        Return the full chat history.
        :return: list of (human, ai) tuples
        """
        return self.chat_history

    def _format_chat_history(self, chat_history: List[Tuple[str, str]]) -> List:
        """Convert (human, ai) tuples into LangChain message objects."""
        buffer = []
        for human, ai in chat_history:
            buffer.append(HumanMessage(content=human))
            buffer.append(AIMessage(content=ai))
        return buffer

    # Knowledge-graph (structured) retrieval + RAG
    def retriever(self, question: str):
        structured_data = structured_retriever(self.llm, self.graph, question)
        unstructured_data = self.rag_retriever(question)
        final_data = f"""Unstructured data:{unstructured_data}\n
Structured data:{structured_data}
"""
        return final_data

    # Unstructured-text knowledge graph + RAG
    def text_retriever(self, question: str):
        structured_data = text_structured_retriever(self.llm, self.graph, question)
        unstructured_data = self.rag_retriever(question)
        final_data = f"""Structured data:{structured_data}\n
Unstructured data:{unstructured_data}\n
"""
        print(f"final_data:{final_data}")
        return final_data

    # Plain RAG, no graph
    def rag_retriever(self, question: str):
        # Fetch the documents most relevant to the question. MMR (maximal
        # marginal relevance) is used instead of plain similarity search to
        # reduce redundancy among the top-k results.
        retriever = self.vectordb.as_retriever(search_type="mmr", search_kwargs={"k": self.top_k})
        docs = retriever.get_relevant_documents(question)
        # Join the document contents into a single string.
        final_data = "\n".join([doc.page_content for doc in docs])
        return final_data

    def _build_chain(self, retrieve_context):
        """Assemble an LCEL chain around the given context-retrieval function.

        The three public build_* methods below differ only in which retriever
        supplies the context, so the shared plumbing lives here.
        """
        llm = self.llm
        # Condense a chat history and a follow-up question into a standalone question.
        _template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""  # noqa: E501
        CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
        _search_query = RunnableBranch(
            # If the input includes chat_history, condense it together with the
            # follow-up question into a standalone question.
            (
                RunnableLambda(lambda x: bool(x.get("chat_history"))).with_config(
                    run_name="HasChatHistoryCheck"
                ),
                RunnablePassthrough.assign(
                    chat_history=lambda x: self._format_chat_history(x["chat_history"])
                )
                | CONDENSE_QUESTION_PROMPT
                | llm
                | StrOutputParser(),
            ),
            # Otherwise there is no chat history, so just pass the question through.
            RunnableLambda(lambda x: x["question"]),
        )
        template = """Answer the question based only on the following context:
{context}
Question: {question}
Use natural language and be concise.
Answer:"""
        prompt = ChatPromptTemplate.from_template(template)
        chain = (
            RunnableParallel(
                {
                    "context": _search_query | retrieve_context,
                    # Extract just the question string; a bare RunnablePassthrough()
                    # would forward the entire input dict into the prompt.
                    "question": itemgetter("question"),
                }
            )
            | prompt
            | llm
            | StrOutputParser()
        )
        return chain

    # Knowledge-graph (structured) retrieval + RAG
    def build_chain(self):
        return self._build_chain(self.retriever)

    # Plain RAG
    def build_rag_chain(self):
        return self._build_chain(self.rag_retriever)

    # Unstructured-text knowledge graph + RAG
    def build_text_chain(self):
        return self._build_chain(self.text_retriever)
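
# A minimal usage sketch, assuming the embedding/, graph/, llm/ and qa_chain/
# modules from this repository are importable and the vector store has already
# been built at the default paths. The question strings are hypothetical
# examples, not part of the original code.
if __name__ == "__main__":
    qa = Chat_QA_chain_self(file_path=DEFAULT_DB_PATH, persist_path=DEFAULT_PERSIST_PATH)
    chain = qa.build_rag_chain()

    # First turn: the history is empty, so the RunnableBranch passes the
    # question straight through to retrieval.
    question = "What topics does the knowledge base cover?"
    answer = chain.invoke({"question": question, "chat_history": qa.get_chat_history()})
    qa.add_to_chat_history(question, answer)

    # Follow-up turn: the non-empty history triggers the condense step, which
    # rewrites the follow-up into a standalone question before retrieval.
    follow_up = "Can you give more detail on the first one?"
    answer = chain.invoke({"question": follow_up, "chat_history": qa.get_chat_history()})
    print(answer)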