import sys
import os

# Make the project root importable before pulling in project-local modules
# (the original appended to sys.path after the import, which would fail).
sys.path.append(os.path.dirname(os.path.dirname(__file__)))

from embedding.embedding import get_embedding
from langchain_community.document_loaders import PyMuPDFLoader, Docx2txtLoader, TextLoader
from langchain_chroma.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Basic configuration
DEFAULT_DB_PATH = "knowledge_db"
DEFAULT_PERSIST_PATH = "vector_db/chroma"
# The knowledge directory may contain several kinds of files; walk it
# recursively and collect every file path.
def get_files(dir_path):
    file_list = []
    for filepath, dirnames, filenames in os.walk(dir_path):
        for filename in filenames:
            file_list.append(os.path.join(filepath, filename))
    return file_list
def file_loader(file_list, loaders):
    # Pick a loader per file based on its extension.
    for file in file_list:
        file_type = file.split('.')[-1]
        if file_type == 'pdf':
            loaders.append(PyMuPDFLoader(file))
        elif file_type == 'txt':
            # `file` is a single path here, so use TextLoader directly rather than
            # DirectoryLoader (which expects a directory).
            loaders.append(TextLoader(file, encoding='utf8'))
        elif file_type == 'docx':
            loaders.append(Docx2txtLoader(file))
    return loaders
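
# Hedged sketch (not part of the original flow, helper name is illustrative):
# file_loader is not called anywhere below, which relies on load_txt_from_dir
# instead. If the corpus mixes pdf/txt/docx files, the loaders it returns could
# be consumed like this to build the document list.
def load_docs_with_loaders(dir_path):
    all_docs = []
    for loader in file_loader(get_files(dir_path), []):
        all_docs.extend(loader.load())
    return all_docs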
def load_txt_from_dir(directory_path):
    data = []
    for filename in os.listdir(directory_path):
        if filename.endswith(".txt"):
            loader = TextLoader(os.path.join(directory_path, filename), encoding='utf8')
            data.extend(loader.load())
    return data
def create_db_info(files=DEFAULT_DB_PATH, embeddings="openai", persist_directory=DEFAULT_PERSIST_PATH):
    # Only build the database for embedding backends this project supports.
    if embeddings in ['openai', 'm3e', 'zhipuai']:
        create_db(files, persist_directory, embeddings)
    return ""
def ensure_directories_exist(file_path, persist_directory):
    # exist_ok=True makes explicit existence checks unnecessary.
    os.makedirs(file_path, exist_ok=True)
    os.makedirs(persist_directory, exist_ok=True)
# Build the vector database.
def create_db(file_path="./knowledge_db", persist_directory="./vector_db/chroma", embeddings="m3e"):
    """
    Load the .txt files under file_path, split them into chunks, embed the chunks,
    and build a Chroma vector database.

    Args:
        file_path: directory holding the source files.
        persist_directory: directory the database is persisted to.
        embeddings: name of the embedding model (or an embedding object).

    Returns:
        vectordb: the created vector database.
    """
    ensure_directories_exist(file_path, persist_directory)
    # Load the files.
    all_docs = load_txt_from_dir(file_path)
    # Split the documents into overlapping chunks.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=150)
    split_docs = text_splitter.split_documents(all_docs)
    if isinstance(embeddings, str):
        embeddings = get_embedding(embedding=embeddings)
    # Create the vector database; passing persist_directory lets Chroma save it to disk.
    vectordb = Chroma.from_documents(
        documents=split_docs,
        embedding=embeddings,
        persist_directory=persist_directory
    )
    # Recent langchain_chroma versions persist automatically and no longer expose
    # persist(); call it only when it is available.
    if hasattr(vectordb, "persist"):
        vectordb.persist()
    return vectordb
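
# Hedged convenience sketch (not in the original code, helper name is illustrative):
# the Chroma store returned by create_db can be wrapped as a LangChain retriever for
# RAG pipelines; as_retriever is part of the standard VectorStore interface.
def get_retriever(file_path="./knowledge_db", persist_directory="./vector_db/chroma",
                  embeddings="m3e", k=3):
    vectordb = create_db(file_path, persist_directory, embeddings)
    return vectordb.as_retriever(search_kwargs={"k": k})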
def presit_knowledge_db(vectordb):
    """
    Persist the vector database to disk.

    Args:
        vectordb: the vector database to persist.
    """
    # Recent langchain_chroma versions persist automatically; guard the call.
    if hasattr(vectordb, "persist"):
        vectordb.persist()
def load_knowledge_db(path, embeddings):
    """
    Load a persisted vector database.

    Args:
        path: path of the vector database to load.
        embeddings: the embedding model the database was built with.

    Returns:
        vectordb: the loaded vector database.
    """
    vectordb = Chroma(
        persist_directory=path,
        embedding_function=embeddings
    )
    return vectordb
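
# Hedged usage sketch (function name and sample query are illustrative): reload the
# persisted store and run a similarity search. This assumes the embedding returned by
# get_embedding matches the one the database was built with.
def example_query(question="What is a vector database?", k=3):
    embeddings = get_embedding(embedding="m3e")
    vectordb = load_knowledge_db(DEFAULT_PERSIST_PATH, embeddings)
    return vectordb.similarity_search(question, k=k)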
if __name__ == "__main__":
    create_db(embeddings="m3e")