# Chat_QA_chain_self.py

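"""Chat_QA_chain_self: a conversational question-answering chain that
combines vector-store RAG retrieval with knowledge-graph retrieval and
keeps a running chat history."""
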
from typing import List, Optional, Tuple

from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import (
    RunnableBranch,
    RunnableLambda,
    RunnableParallel,
    RunnablePassthrough,
)
from langchain.prompts import ChatPromptTemplate, PromptTemplate

from embedding.embedding import get_embedding
from qa_chain.get_vectordb import get_vectordb
from graph.graph_retrieval import connect, structured_retriever, text_structured_retriever
from llm.llm import LLM

class Chat_QA_chain_self:
    """
    Question-answering chain with chat history.
    - temperature: sampling temperature; controls generation randomness
    - top_k: number of most similar documents returned by retrieval
    - chat_history: chat history, given as a list of (human, ai) tuples; defaults to an empty list
    - file_path: path of the source files used to build the database
    - persist_path: persistence path of the vector database
    - embedding: name of the embedding model to use
    (the LLM is currently fixed to 'qwen' via LLM(model_name='qwen'))
    """

    def __init__(self, temperature: float = 0.0, top_k: int = 2,
                 chat_history: Optional[List[Tuple[str, str]]] = None,
                 file_path: str = None, persist_path: str = None, embedding: str = "m3e"):
        self.temperature = temperature
        self.top_k = top_k
        # Default to None instead of a mutable [] so separate instances
        # never share the same history list.
        self.chat_history = chat_history if chat_history is not None else []
        self.file_path = file_path
        self.persist_path = persist_path
        self.embedding = get_embedding(embedding)
        self.llm_instance = LLM(model_name='qwen')
        self.llm = self.llm_instance.get_llm()
        self.vectordb = get_vectordb(self.file_path, self.persist_path, self.embedding)
        self.graph = connect()
    def clear_chat_history(self):
        """
        Clear the chat history.
        """
        self.chat_history = []

    def add_to_chat_history(self, human_message: str, ai_message: str):
        """
        Append one exchange to the chat history.
        :param human_message: the human user's message
        :param ai_message: the AI's reply
        """
        self.chat_history.append((human_message, ai_message))

    def get_chat_history(self):
        """
        Return the full chat history.
        :return: list of (human_message, ai_message) tuples
        """
        return self.chat_history

    def _format_chat_history(self, chat_history: List[Tuple[str, str]]) -> List:
        # Convert (human, ai) tuples into alternating LangChain message objects.
        buffer = []
        for human, ai in chat_history:
            buffer.append(HumanMessage(content=human))
            buffer.append(AIMessage(content=ai))
        return buffer
    # Knowledge-graph retrieval combined with RAG
    def retriever(self, question: str):
        structured_data = structured_retriever(self.llm, self.graph, question)
        unstructured_data = self.rag_retriever(question)
        final_data = (
            f"Unstructured data: {unstructured_data}\n"
            f"Structured data: {structured_data}\n"
        )
        return final_data

    # Graph retrieval over unstructured text, combined with RAG
    def text_retriever(self, question: str):
        structured_data = text_structured_retriever(self.llm, self.graph, question)
        unstructured_data = self.rag_retriever(question)
        final_data = (
            f"Structured data: {structured_data}\n"
            f"Unstructured data: {unstructured_data}\n"
        )
        print(f"final_data: {final_data}")
        return final_data

    # Plain RAG
    def rag_retriever(self, question: str):
        # Fetch the documents most relevant to the question via maximal
        # marginal relevance (MMR) search, then join their contents into
        # a single context string.
        retriever = self.vectordb.as_retriever(search_type='mmr', search_kwargs={'k': self.top_k})
        docs = retriever.get_relevant_documents(question)
        return "\n".join(doc.page_content for doc in docs)
    def _build_chain_with(self, retrieve):
        """
        Shared chain builder: condense the chat history and follow-up
        question into a standalone question, retrieve context with the
        given retriever, then answer from that context only. The three
        public build_* methods differ only in the retriever they use.
        """
        llm = self.llm
        # Condense a chat history and follow-up question into a standalone question
        _template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""  # noqa: E501
        CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
        _search_query = RunnableBranch(
            # If the input includes chat_history, condense it with the follow-up question
            (
                RunnableLambda(lambda x: bool(x.get("chat_history"))).with_config(
                    run_name="HasChatHistoryCheck"
                ),
                RunnablePassthrough.assign(
                    chat_history=lambda x: self._format_chat_history(x["chat_history"])
                )
                | CONDENSE_QUESTION_PROMPT
                | llm
                | StrOutputParser(),
            ),
            # Otherwise there is no chat history, so pass the question through unchanged
            RunnableLambda(lambda x: x["question"]),
        )
        template = """Answer the question based only on the following context:
{context}
Question: {question}
Use natural language and be concise.
Answer:"""
        prompt = ChatPromptTemplate.from_template(template)
        chain = (
            RunnableParallel(
                {
                    "context": _search_query | retrieve,
                    "question": RunnablePassthrough(),
                }
            )
            | prompt
            | llm
            | StrOutputParser()
        )
        return chain

    # Knowledge graph + RAG
    def build_chain(self):
        return self._build_chain_with(self.retriever)

    # Plain RAG
    def build_rag_chain(self):
        return self._build_chain_with(self.rag_retriever)

    # Unstructured-text graph + RAG
    def build_text_chain(self):
        return self._build_chain_with(self.text_retriever)
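

# --- Usage sketch (illustrative) ---
# A minimal sketch of how the class is wired together, assuming the
# project-local modules (embedding, qa_chain, graph, llm), the 'qwen'
# LLM, the vector store, and the graph connection are all configured.
# The paths and the question below are hypothetical placeholders.
if __name__ == "__main__":
    qa = Chat_QA_chain_self(
        temperature=0.0,
        top_k=2,
        file_path="data/docs",         # hypothetical corpus path
        persist_path="data/vectordb",  # hypothetical persistence path
        embedding="m3e",
    )
    chain = qa.build_rag_chain()  # plain RAG; build_chain() adds graph retrieval
    question = "What is this project about?"
    answer = chain.invoke({"question": question, "chat_history": qa.get_chat_history()})
    qa.add_to_chat_history(question, answer)
    print(answer)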