Browse Source

第一次提交全部代码

yulongyan_citu 10 months ago
parent
commit
47e351071e
2 changed files with 64 additions and 33 deletions
  1. 1 1
      app.py
  2. 63 32
      database/create_db.py

+ 1 - 1
app.py

@@ -30,7 +30,7 @@ app.add_middleware(
 )
 
 # 初始化你的Chat_QA_chain_self实例
-chat_qa_chain = Chat_QA_chain_self(file_path="./knowledge_db", persist_path="./vector_db/chroma")
+chat_qa_chain = Chat_QA_chain_self(file_path="./knowledge_db", persist_path="./vector_db/chroma",embedding='m3e')
 
 @app.post('/clear/history')
 async def clear_history():

+ 63 - 32
database/create_db.py

@@ -54,50 +54,81 @@ def create_db_info(files=DEFAULT_DB_PATH, embeddings="openai", persist_directory
     return ""
 
 
-def create_db(files=DEFAULT_DB_PATH, persist_directory=DEFAULT_PERSIST_PATH, embeddings="m3e"):
-    """
-    该函数用于加载 PDF 文件,切分文档,生成文档的嵌入向量,创建向量数据库。
-
-    参数:
-    file: 存放文件的路径。
-    embeddings: 用于生产 Embedding 的模型
-
-    返回:
-    vectordb: 创建的数据库。
-    """
-    if files is None:
-        return "can't load empty file"
-
-    # directory_path = files
-    # file_list = get_files(directory_path)
+# def create_db(file_path=DEFAULT_DB_PATH, persist_directory=DEFAULT_PERSIST_PATH, embeddings="m3e"):
+#     """
+#     该函数用于加载 PDF 文件,切分文档,生成文档的嵌入向量,创建向量数据库。
+#
+#     参数:
+#     file: 存放文件的路径。
+#     embeddings: 用于生产 Embedding 的模型
+#
+#     返回:
+#     vectordb: 创建的数据库。
+#     """
+#     if files is None:
+#         return "can't load empty file"
+#
+#     # directory_path = files
+#     # file_list = get_files(directory_path)
+#
+#     # 加载文件
+#     loaders = []
+#     # loaders = file_loader(file_list, loaders)
+#     # # 加载文档
+#     # all_docs = []
+#     # for loader in loaders:
+#     #     docs = loader.load()
+#     #     all_docs.extend(docs)
+#
+#     all_docs = load_txt_from_dir(files)
+#
+#     # 切分文档
+#     text_splitter = RecursiveCharacterTextSplitter(
+#         chunk_size=500, chunk_overlap=150)
+#     split_docs = text_splitter.split_documents(all_docs)
+#
+#     if isinstance(embeddings, str):
+#         embeddings = get_embedding(embedding=embeddings)
+#
+#     # 定义持久化路径
+#     persist_directory = DEFAULT_PERSIST_PATH
+#
+#     # 加载数据库
+#     vectordb = Chroma.from_documents(
+#         documents=split_docs,
+#         embedding=embeddings,
+#         persist_directory=persist_directory  # 允许我们将persist_directory目录保存到磁盘上
+#     )
+#     vectordb.persist()
+#     return vectordb
+
+
+def ensure_directories_exist(file_path, persist_directory):
+    if not os.path.exists(file_path):
+        os.makedirs(file_path, exist_ok=True)
+
+    if not os.path.exists(persist_directory):
+        os.makedirs(persist_directory, exist_ok=True)
+
+# 创建向量数据库
+def create_db(file_path="./knowledge_db", persist_directory="./vector_db/chroma", embeddings="m3e"):
+    ensure_directories_exist(file_path, persist_directory)
 
     # 加载文件
-    loaders = []
-    # loaders = file_loader(file_list, loaders)
-    # # 加载文档
-    # all_docs = []
-    # for loader in loaders:
-    #     docs = loader.load()
-    #     all_docs.extend(docs)
-
-    all_docs = load_txt_from_dir(files)
+    all_docs = load_txt_from_dir(file_path)
 
     # 切分文档
-    text_splitter = RecursiveCharacterTextSplitter(
-        chunk_size=500, chunk_overlap=150)
+    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=150)
     split_docs = text_splitter.split_documents(all_docs)
 
     if isinstance(embeddings, str):
         embeddings = get_embedding(embedding=embeddings)
 
-    # 定义持久化路径
-    persist_directory = DEFAULT_PERSIST_PATH
-
-    # 加载数据库
+    # 创建向量数据库
     vectordb = Chroma.from_documents(
         documents=split_docs,
         embedding=embeddings,
-        persist_directory=persist_directory  # 允许我们将persist_directory目录保存到磁盘上
+        persist_directory=persist_directory
     )
     vectordb.persist()
     return vectordb