# Chat_QA_chain_self.py

from typing import List, Optional, Tuple

from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import (
    RunnableBranch,
    RunnableLambda,
    RunnableParallel,
    RunnablePassthrough,
)

from embedding.embedding import get_embedding
from qa_chain.get_vectordb import get_vectordb
from graph.graph_retrieval import connect, structured_retriever, text_structured_retriever
from llm.llm import LLM


class Chat_QA_chain_self:
    """
    Question-answering chain with chat history.
    - temperature: sampling temperature, controls generation randomness
    - top_k: number of most-similar documents returned by retrieval
    - chat_history: chat history, a list of (human, ai) tuples; defaults to an empty list
    - file_path: path of the source files used to build the vector store
    - persist_path: persistence path of the vector database
    - embedding: name of the embedding model to use
    """

    def __init__(self, temperature: float = 0.0, top_k: int = 2,
                 chat_history: Optional[List[Tuple[str, str]]] = None,
                 file_path: str = None, persist_path: str = None, embedding: str = "m3e"):
        self.temperature = temperature
        self.top_k = top_k
        # Avoid the mutable-default pitfall: each instance gets its own history list.
        self.chat_history = chat_history if chat_history is not None else []
        self.file_path = file_path
        self.persist_path = persist_path
        self.embedding = get_embedding(embedding)
        self.llm_instance = LLM(model_name='qwen')
        self.llm = self.llm_instance.get_llm()
        self.vectordb = get_vectordb(self.file_path, self.persist_path, self.embedding)
        self.graph = connect()

    def clear_chat_history(self):
        """
        Clear the chat history.
        :return:
        """
        self.chat_history = []

    def add_to_chat_history(self, human_message: str, ai_message: str):
        """
        Append one exchange to the chat history.
        :param human_message: the human user's message
        :param ai_message: the AI's reply
        :return:
        """
        self.chat_history.append((human_message, ai_message))

    def get_chat_history(self):
        """
        Return the full chat history.
        :return: the list of (human, ai) chat records
        """
        return self.chat_history

    def _format_chat_history(self, chat_history: List[Tuple[str, str]]) -> List:
        """Convert (human, ai) tuples into alternating HumanMessage/AIMessage objects."""
        buffer = []
        for human, ai in chat_history:
            buffer.append(HumanMessage(content=human))
            buffer.append(AIMessage(content=ai))
        return buffer

    def retriever(self, question: str):
        """Combined retrieval: structured facts from the knowledge graph plus
        unstructured chunks from the vector store."""
        structured_data = structured_retriever(self.llm, self.graph, question)
        unstructured_data = self.rag_retriever(question)
        final_data = f"""Unstructured data:{unstructured_data}\n
Structured data:{structured_data}
"""
        return final_data

    # Knowledge graph built from unstructured text + RAG
    def text_retriever(self, question: str):
        """Like ``retriever``, but queries the graph derived from unstructured text."""
        structured_data = text_structured_retriever(self.llm, self.graph, question)
        unstructured_data = self.rag_retriever(question)
        final_data = f"""Structured data:{structured_data}\n
Unstructured data:{unstructured_data}\n
"""
        print(f"final_data:{final_data}")
        return final_data

    # Plain RAG
    def rag_retriever(self, question: str):
        """Fetch the top-k chunks via maximal marginal relevance (MMR) and join
        their contents into a single context string."""
        retriever = self.vectordb.as_retriever(search_type='mmr', search_kwargs={'k': self.top_k})
        docs = retriever.get_relevant_documents(question)
        final_data = "\n".join(doc.page_content for doc in docs)
        return final_data
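
    # NOTE (hedged): in recent langchain-core releases, ``get_relevant_documents``
    # is deprecated in favour of ``retriever.invoke(question)``; a minimal sketch,
    # assuming such a version is installed:
    #
    #     docs = retriever.invoke(question)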

    def _build_chain_from_retriever(self, retriever_fn):
        """Shared LCEL pipeline: condense the follow-up question against the chat
        history into a standalone question, fetch context with ``retriever_fn``,
        then answer from that context only."""
        llm = self.llm

        # Condense a chat history and follow-up question into a standalone question
        _template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""  # noqa: E501
        CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)

        _search_query = RunnableBranch(
            # If input includes chat_history, we condense it with the follow-up question
            (
                RunnableLambda(lambda x: bool(x.get("chat_history"))).with_config(
                    run_name="HasChatHistoryCheck"
                ),  # Condense follow-up question and chat into a standalone question
                RunnablePassthrough.assign(
                    chat_history=lambda x: self._format_chat_history(x["chat_history"])
                )
                | CONDENSE_QUESTION_PROMPT
                | llm
                | StrOutputParser(),
            ),
            # Else, we have no chat history, so just pass through the question
            RunnableLambda(lambda x: x["question"]),
        )

        template = """Answer the question based only on the following context:
{context}

Question: {question}
Use natural language and be concise.
Answer:"""
        prompt = ChatPromptTemplate.from_template(template)

        chain = (
            RunnableParallel(
                {
                    "context": _search_query | retriever_fn,
                    "question": RunnablePassthrough(),
                }
            )
            | prompt
            | llm
            | StrOutputParser()
        )
        return chain

    def build_chain(self):
        """Knowledge graph (structured) + vector store (unstructured) retrieval."""
        return self._build_chain_from_retriever(self.retriever)
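
    # Usage sketch (hedged): the input keys mirror what ``_search_query`` above
    # consumes; ``qa`` is a hypothetical Chat_QA_chain_self instance.
    #
    #     chain = qa.build_chain()
    #     answer = chain.invoke({
    #         "question": "When was she born?",
    #         "chat_history": qa.get_chat_history(),
    #     })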

    def build_rag_chain(self):
        """Plain vector-store RAG retrieval only."""
        return self._build_chain_from_retriever(self.rag_retriever)

    # Unstructured text graph + RAG
    def build_text_chain(self):
        """Retrieval over the graph built from unstructured text, plus RAG."""
        return self._build_chain_from_retriever(self.text_retriever)
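

# Minimal end-to-end sketch (assumptions: the paths and question below are
# hypothetical, and a Neo4j instance reachable via graph_retrieval.connect()
# plus the qwen LLM backend must be available):
if __name__ == "__main__":
    qa = Chat_QA_chain_self(
        temperature=0.0,
        top_k=2,
        file_path="./data/corpus",        # hypothetical source-file path
        persist_path="./vector_db",       # hypothetical persistence directory
        embedding="m3e",
    )
    chain = qa.build_rag_chain()
    question = "What is this corpus about?"  # hypothetical question
    answer = chain.invoke({"question": question, "chat_history": qa.get_chat_history()})
    qa.add_to_chat_history(question, answer)
    print(answer)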