yulongyan_citu vor 10 Monaten
Ursprung
Commit
776d8a95ea
7 geänderte Dateien mit 27 neuen und 207 gelöschten Zeilen
  1. app.py (+5 −17)
  2. database/chormadb.py (+1 −1)
  3. database/create_db.py (+7 −7)
  4. graph/graph_retrieval.py (+6 −6)
  5. main.py (+0 −16)
  6. qa_chain/Chat_QA_chain_self.py (+3 −155)
  7. qa_chain/get_vectordb.py (+5 −5)

+ 5 - 17
app.py

@@ -29,8 +29,12 @@ app.add_middleware(
     # max_age=1000
 )
 
+# 定义默认路径
+DEFAULT_DB_PATH = os.path.join(".", "knowledge_db")
+DEFAULT_PERSIST_PATH = os.path.join(".", "vector_db", "chroma")
+
 # 初始化你的Chat_QA_chain_self实例
-chat_qa_chain = Chat_QA_chain_self(file_path="./knowledge_db", persist_path="./vector_db/chroma",embedding='m3e')
+chat_qa_chain = Chat_QA_chain_self(file_path=DEFAULT_DB_PATH, persist_path=DEFAULT_PERSIST_PATH,embedding='m3e')
 
 @app.post('/clear/history')
 async def clear_history():
@@ -59,22 +63,6 @@ async def view_history():
         res = failed({}, {"error": f"{e}"})
         return JSONResponse(res)
 
-# @app.post('/ask')
-# async def ask(request: Request):
-#     try:
-#         data = await request.json()
-#         question = data.get('question')
-#         chat_history = data.get('chat_history', [])
-#         response = chat_qa_chain.build_chain(question,chat_history)
-#         # response = chain.invoke({"question": question, "chat_history": chat_history})
-#         # 添加新的聊天记录到历史记录中
-#         chat_qa_chain.add_to_chat_history(question, response)
-#         res = success({"response": response}, "success")
-#         return JSONResponse(res)
-#     except Exception as e:
-#         res = failed({}, {"error": f"{e}"})
-#         return JSONResponse(res)
-
 @app.post('/ask')
 async def ask(request: Request):
     try:

+ 1 - 1
database/chormadb.py

@@ -1,7 +1,7 @@
 import chromadb
 from embedding.embedding import get_embedding
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-from create_db import load_txt_from_dir
+from database.create_db import load_txt_from_dir
 
 DEFAULT_DB_PATH = r"../knowledge_db"
 DEFAULT_PERSIST_PATH = '../vector_db/chroma'

+ 7 - 7
create_db.py → database/create_db.py

@@ -6,8 +6,8 @@ from langchain_chroma.vectorstores import Chroma
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
 # 首先实现基本配置
-DEFAULT_DB_PATH = r"knowledge_db"
-DEFAULT_PERSIST_PATH = 'vector_db/chroma'
+# DEFAULT_DB_PATH = r"../knowledge_db"
+# DEFAULT_PERSIST_PATH = '../vector_db/chroma'
 
 # 目录下保存不同的文件
 # def get_files(dir_path):
@@ -48,10 +48,10 @@ def load_txt_from_dir(directory_path):
             data.extend(loader.load())
     return data
 
-def create_db_info(files=DEFAULT_DB_PATH, embeddings="openai", persist_directory=DEFAULT_PERSIST_PATH):
-    if embeddings in ['openai', 'm3e', 'zhipuai']:
-        vectordb = create_db(files, persist_directory, embeddings)
-    return ""
+# def create_db_info(files=DEFAULT_DB_PATH, embeddings="openai", persist_directory=DEFAULT_PERSIST_PATH):
+#     if embeddings in ['openai', 'm3e', 'zhipuai']:
+#         vectordb = create_db(files, persist_directory, embeddings)
+#     return ""
 
 
 # def create_db(file_path=DEFAULT_DB_PATH, persist_directory=DEFAULT_PERSIST_PATH, embeddings="m3e"):
@@ -111,7 +111,7 @@ def ensure_directories_exist(file_path, persist_directory):
         os.makedirs(persist_directory, exist_ok=True)
 
 # 创建向量数据库
-def create_db(file_path="./knowledge_db", persist_directory="./vector_db/chroma", embeddings="m3e"):
+def create_db(file_path, persist_directory, embeddings="m3e"):
     ensure_directories_exist(file_path, persist_directory)
 
     # 加载文件

+ 6 - 6
graph/graph_retrieval.py

@@ -13,13 +13,13 @@ from langchain.schema import AIMessage
 
 
 def connect():
-    os.environ["NEO4J_URI"] = "bolt://172.16.48.8:7687"
-    os.environ["NEO4J_USERNAME"] = "neo4j"
-    os.environ["NEO4J_PASSWORD"] = "!@#qwe123^&*"
-
-    # os.environ["NEO4J_URI"] = "bolt://192.168.3.91:27687"
+    # os.environ["NEO4J_URI"] = "bolt://172.16.48.8:7687"
     # os.environ["NEO4J_USERNAME"] = "neo4j"
-    # os.environ["NEO4J_PASSWORD"] = "citu2099@@CCA."
+    # os.environ["NEO4J_PASSWORD"] = "!@#qwe123^&*"
+
+    os.environ["NEO4J_URI"] = "bolt://192.168.3.91:27687"
+    os.environ["NEO4J_USERNAME"] = "neo4j"
+    os.environ["NEO4J_PASSWORD"] = "citu2099@@CCA."
     graph = Neo4jGraph()
     return graph
 

+ 0 - 16
main.py

@@ -1,16 +0,0 @@
-# This is a sample Python script.
-
-# Press Shift+F10 to execute it or replace it with your code.
-# Press Double Shift to search everywhere for classes, files, tool windows, actions, and settings.
-
-
-def print_hi(name):
-    # Use a breakpoint in the code line below to debug your script.
-    print(f'Hi, {name}')  # Press Ctrl+F8 to toggle the breakpoint.
-
-
-# Press the green button in the gutter to run the script.
-if __name__ == '__main__':
-    print_hi('PyCharm')
-
-# See PyCharm help at https://www.jetbrains.com/help/pycharm/

+ 3 - 155
qa_chain/Chat_QA_chain_self.py

@@ -1,149 +1,3 @@
-# from langchain_core.runnables import (
-#     RunnableBranch,
-#     RunnableLambda,
-# )
-# from langchain_core.output_parsers import StrOutputParser
-# from langchain_core.runnables import RunnableParallel, RunnablePassthrough
-# from langchain.prompts import (
-#         ChatPromptTemplate,
-# )
-# from typing import List, Tuple
-# from langchain.prompts import PromptTemplate
-# from langchain_core.messages import AIMessage, HumanMessage
-# from qa_chain.get_vectordb import get_vectordb
-# from graph.graph_retrieval import connect, structured_retriever
-# from llm.llm import deepseek_llm
-# # from llm.llm import qwen_llm
-#
-#
-# class Chat_QA_chain_self:
-#     """
-#     带历史记录的问答链
-#     - model:调用的模型名称
-#     - temperature:温度系数,控制生成的随机性
-#     - top_k:返回检索的前k个相似文档
-#     - chat_history:历史记录,输入一个列表,默认是一个空列表
-#     - file_path:建库文件所在路径
-#     - persist_path:向量数据库持久化路径
-#     - embeddings:使用的embedding模型
-#     """
-#
-#     def __init__(self, temperature: float = 0.0, top_k: int = 4, chat_history: List[Tuple[str, str]] = [],
-#                  file_path: str = None, persist_path: str = None, embedding: str = "m3e"):
-#         self.temperature = temperature
-#         self.top_k = top_k
-#         self.chat_history = chat_history
-#         self.file_path = file_path
-#         self.persist_path = persist_path
-#         self.embedding = embedding
-#         self.llm = deepseek_llm
-#         self.vectordb = get_vectordb(self.file_path, self.persist_path, self.embedding)
-#         self.graph = connect()
-#
-#     def clear_chat_history(self):
-#         """
-#         清空历史记录
-#         :return:
-#         """
-#         self.chat_history = []
-#         # print("Chat history has been cleared.")
-#
-#     def add_to_chat_history(self, human_message: str, ai_message: str):
-#         """
-#         添加一条聊天记录到历史记录中
-#         :param human_message: 人类用户的消息
-#         :param ai_message: AI的回复消息
-#         :return:
-#         """
-#         self.chat_history.append((human_message, ai_message))
-#
-#     def get_chat_history(self):
-#         """
-#         获取所有的聊天历史记录
-#         :return: 聊天历史记录列表
-#         """
-#         return self.chat_history
-#
-#     # 原来的函数
-#     # def _format_chat_history(self, chat_history: List[Tuple[str, str]]) -> List:
-#     #     buffer = []
-#     #     for human, ai in chat_history:
-#     #         buffer.append(HumanMessage(content=human))
-#     #         buffer.append(AIMessage(content=ai))
-#     #     buffer.append(chat_history)
-#     #     return buffer
-#
-#     def _format_chat_history(self, chat_history: List[Tuple[str, str]]) -> List:
-#         buffer = []
-#         for human, ai in chat_history:
-#             buffer.append(HumanMessage(content=human))
-#             buffer.append(AIMessage(content=ai))
-#         return buffer
-#
-#     def retriever(self, question: str):
-#         # print(f"Search query: {question}")
-#         structured_data = structured_retriever(self.llm, self.graph, question)
-#         unstructured_data = self.vectordb.as_retriever(search_type="similarity",
-#                                                        search_kwargs={'k': self.top_k})  # 默认similarity,k=4
-#         final_data = f"""Unstructured data:{unstructured_data}\n
-#                          Structured data:{structured_data}
-#                         """
-#         # final_data = f"""Unstructured data:{unstructured_data}\n"""
-#         # print(f"unstructured_data:{unstructured_data}")
-#         return final_data
-#
-#     # # def build_chain(self, question: str):
-#     def build_chain(self):
-#         llm = self.llm
-#
-#         # Condense a chat history and follow-up question into a standalone question
-#         _template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question,
-#         in its original language.
-#         Chat History:
-#         {chat_history}
-#         Follow Up Input: {question}
-#         Standalone question:"""  # noqa: E501
-#         CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
-#
-#         _search_query = RunnableBranch(
-#             # If input includes chat_history, we condense it with the follow-up question
-#             (
-#                 RunnableLambda(lambda x: bool(x.get("chat_history"))).with_config(
-#                     run_name="HasChatHistoryCheck"
-#                 ),  # Condense follow-up question and chat into a standalone_question
-#                 RunnablePassthrough.assign(
-#                     chat_history=lambda x: self._format_chat_history(x["chat_history"])
-#                 )
-#                 | CONDENSE_QUESTION_PROMPT
-#                 | llm
-#                 | StrOutputParser(),
-#             ),
-#             # Else, we have no chat history, so just pass through the question
-#             RunnableLambda(lambda x: x["question"]),
-#         )
-#
-#         template = """Answer the question based only on the following context:
-#         {context}
-#
-#         Question: {question}
-#         Use natural language and be concise.
-#         Answer:"""
-#         prompt = ChatPromptTemplate.from_template(template)
-#
-#         chain = (
-#             RunnableParallel(
-#                 {
-#                     "context": _search_query | self.retriever,
-#                     "question": RunnablePassthrough(),
-#                 }
-#             )
-#             | prompt
-#             | llm
-#             | StrOutputParser()
-#         )
-#         return chain
-
-
 from langchain_core.runnables import (
     RunnableBranch,
     RunnableLambda,
@@ -161,7 +15,10 @@ from embedding.embedding import get_embedding
 from qa_chain.get_vectordb import get_vectordb
 from graph.graph_retrieval import connect, structured_retriever, text_structured_retriever
 from llm.llm import LLM
+import os
 
+DEFAULT_DB_PATH = os.path.join("..", "knowledge_db")
+DEFAULT_PERSIST_PATH = os.path.join("..", "vector_db", "chroma")
 
 class Chat_QA_chain_self:
     """
@@ -220,29 +77,20 @@ class Chat_QA_chain_self:
         return buffer
 
     def retriever(self, question: str):
-        # print(f"Search query: {question}")
         structured_data = structured_retriever(self.llm, self.graph, question)
-        # unstructured_data = self.vectordb.as_retriever(search_type="similarity",
-        #                                                search_kwargs={'k': self.top_k})  # 默认similarity,k=4
         unstructured_data = self.rag_retriever(question)
         final_data = f"""Unstructured data:{unstructured_data}\n
                          Structured data:{structured_data}
                         """
-        # final_data = f"""Unstructured data:{unstructured_data}\n"""
-        # print(f"unstructured_data:{unstructured_data}")
         return final_data
 
     # 非结构化文本图谱+rag
     def text_retriever(self, question: str):
-        # print(f"Search query: {question}")
         structured_data = text_structured_retriever(self.llm, self.graph, question)
-        # unstructured_data = self.vectordb.as_retriever(search_type="similarity",
-        #                                                search_kwargs={'k': self.top_k})  # 默认similarity,k=4
         unstructured_data = self.rag_retriever(question)
         final_data = f"""Structured data:{structured_data}\n
                          Unstructured data:{unstructured_data}\n
                         """
-        # final_data = f"""Unstructured data:{unstructured_data}\n"""
         print(f"final_data:{final_data}")
         return final_data
 

+ 5 - 5
qa_chain/get_vectordb.py

@@ -1,12 +1,12 @@
 import os
-from create_db import create_db, load_knowledge_db
+from database.create_db import create_db, load_knowledge_db
 from embedding.embedding import get_embedding
 
 # 定义默认路径
-DEFAULT_DB_PATH = os.path.join("..", "knowledge_db")
-DEFAULT_PERSIST_PATH = os.path.join("..", "vector_db", "chroma")
+# DEFAULT_DB_PATH = os.path.join("..", "knowledge_db")
+# DEFAULT_PERSIST_PATH = os.path.join("..", "vector_db", "chroma")
 
-def get_vectordb(file_path=DEFAULT_DB_PATH, persist_path=DEFAULT_PERSIST_PATH, embedding="m3e"):
+def get_vectordb(file_path, persist_path, embedding="m3e"):
     """
     返回向量数据库对象
     输入参数:
@@ -19,7 +19,7 @@ def get_vectordb(file_path=DEFAULT_DB_PATH, persist_path=DEFAULT_PERSIST_PATH, e
     if os.path.exists(persist_path):  # 持久化目录存在
         contents = os.listdir(persist_path)
         if len(contents) == 0:  # 但是下面为空
-            # print("目录为空")
+            print("目录为空")
             create_db(file_path, persist_path, embedding)
             # presit_knowledge_db(vectordb)
             vectordb = load_knowledge_db(persist_path, embedding)