create_db.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. import sys
  2. import os
  3. from embedding.embedding import get_embedding
  4. from langchain_community.document_loaders import PyMuPDFLoader,Docx2txtLoader,TextLoader,DirectoryLoader
  5. from langchain_chroma.vectorstores import Chroma
  6. from langchain.text_splitter import RecursiveCharacterTextSplitter
  7. sys.path.append(os.path.dirname(os.path.dirname(__file__)))
  8. # 首先实现基本配置
  9. # DEFAULT_DB_PATH = r"../knowledge_db"
  10. # DEFAULT_PERSIST_PATH = '../vector_db/chroma'
  11. # 目录下保存不同的文件
  12. # def get_files(dir_path):
  13. # file_list = []
  14. # for filepath, dirnames, filenames in os.walk(dir_path):
  15. # for filename in filenames:
  16. # file_list.append(os.path.join(filepath, filename))
  17. # return file_list
  18. def get_files(dir_path):
  19. file_list = []
  20. for filepath, dirnames, filenames in os.walk(dir_path):
  21. for filename in filenames:
  22. file_list.append(os.path.join(filepath, filename))
  23. return file_list
  24. def file_loader(file_list, loaders):
  25. for file in file_list:
  26. file_type = file.split('.')[-1]
  27. if file_type == 'pdf':
  28. loaders.append(PyMuPDFLoader(file))
  29. elif file_type == 'txt':
  30. l = DirectoryLoader(file, glob="**/*.txt", loader_cls=TextLoader)
  31. # loaders.append(TextLoader(file))
  32. loaders.append(l)
  33. elif file_type == 'docx':
  34. loaders.append(Docx2txtLoader(file))
  35. return loaders
  36. def load_txt_from_dir(directory_path):
  37. data = []
  38. for filename in os.listdir(directory_path):
  39. if filename.endswith(".txt"):
  40. # print(filename)
  41. loader = TextLoader(f'{directory_path}/{filename}',encoding='utf8')
  42. data.extend(loader.load())
  43. return data
  44. # def create_db_info(files=DEFAULT_DB_PATH, embeddings="openai", persist_directory=DEFAULT_PERSIST_PATH):
  45. # if embeddings in ['openai', 'm3e', 'zhipuai']:
  46. # vectordb = create_db(files, persist_directory, embeddings)
  47. # return ""
  48. # def create_db(file_path=DEFAULT_DB_PATH, persist_directory=DEFAULT_PERSIST_PATH, embeddings="m3e"):
  49. # """
  50. # 该函数用于加载 PDF 文件,切分文档,生成文档的嵌入向量,创建向量数据库。
  51. #
  52. # 参数:
  53. # file: 存放文件的路径。
  54. # embeddings: 用于生产 Embedding 的模型
  55. #
  56. # 返回:
  57. # vectordb: 创建的数据库。
  58. # """
  59. # if files is None:
  60. # return "can't load empty file"
  61. #
  62. # # directory_path = files
  63. # # file_list = get_files(directory_path)
  64. #
  65. # # 加载文件
  66. # loaders = []
  67. # # loaders = file_loader(file_list, loaders)
  68. # # # 加载文档
  69. # # all_docs = []
  70. # # for loader in loaders:
  71. # # docs = loader.load()
  72. # # all_docs.extend(docs)
  73. #
  74. # all_docs = load_txt_from_dir(files)
  75. #
  76. # # 切分文档
  77. # text_splitter = RecursiveCharacterTextSplitter(
  78. # chunk_size=500, chunk_overlap=150)
  79. # split_docs = text_splitter.split_documents(all_docs)
  80. #
  81. # if isinstance(embeddings, str):
  82. # embeddings = get_embedding(embedding=embeddings)
  83. #
  84. # # 定义持久化路径
  85. # persist_directory = DEFAULT_PERSIST_PATH
  86. #
  87. # # 加载数据库
  88. # vectordb = Chroma.from_documents(
  89. # documents=split_docs,
  90. # embedding=embeddings,
  91. # persist_directory=persist_directory # 允许我们将persist_directory目录保存到磁盘上
  92. # )
  93. # vectordb.persist()
  94. # return vectordb
  95. def ensure_directories_exist(file_path, persist_directory):
  96. if not os.path.exists(file_path):
  97. os.makedirs(file_path, exist_ok=True)
  98. if not os.path.exists(persist_directory):
  99. os.makedirs(persist_directory, exist_ok=True)
  100. # 创建向量数据库
  101. def create_db(file_path, persist_directory, embeddings="m3e"):
  102. ensure_directories_exist(file_path, persist_directory)
  103. # 加载文件
  104. all_docs = load_txt_from_dir(file_path)
  105. # 切分文档
  106. text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=150)
  107. split_docs = text_splitter.split_documents(all_docs)
  108. if isinstance(embeddings, str):
  109. embeddings = get_embedding(embedding=embeddings)
  110. # 创建向量数据库
  111. vectordb = Chroma.from_documents(
  112. documents=split_docs,
  113. embedding=embeddings,
  114. persist_directory=persist_directory
  115. )
  116. vectordb.persist()
  117. return vectordb
  118. def presit_knowledge_db(vectordb):
  119. """
  120. 该函数用于持久化向量数据库。
  121. 参数:
  122. vectordb: 要持久化的向量数据库。
  123. """
  124. vectordb.persist()
  125. def load_knowledge_db(path, embeddings):
  126. """
  127. 该函数用于加载向量数据库。
  128. 参数:
  129. path: 要加载的向量数据库路径。
  130. embeddings: 向量数据库使用的 embedding 模型。
  131. 返回:
  132. vectordb: 加载的数据库。
  133. """
  134. vectordb = Chroma(
  135. persist_directory=path,
  136. embedding_function=embeddings
  137. )
  138. return vectordb
  139. if __name__ == "__main__":
  140. create_db(embeddings="m3e")