get_vectordb.py 1.3 KB

12345678910111213141516171819202122232425262728293031323334
  1. import os
  2. from create_db import create_db, load_knowledge_db
  3. from embedding.embedding import get_embedding
# Default paths used as the default arguments of get_vectordb(). Both are
# relative (they start with ".."), so they resolve against the process's
# current working directory, not against this file's location.
DEFAULT_DB_PATH = os.path.join("..", "knowledge_db")  # default `file_path`
DEFAULT_PERSIST_PATH = os.path.join("..", "vector_db", "chroma")  # default `persist_path`
  7. def get_vectordb(file_path=DEFAULT_DB_PATH, persist_path=DEFAULT_PERSIST_PATH, embedding="m3e"):
  8. """
  9. 返回向量数据库对象
  10. 输入参数:
  11. question:
  12. llm:
  13. vectordb:向量数据库(必要参数),一个对象
  14. embedding:qwen
  15. """
  16. embedding = get_embedding(embedding=embedding)
  17. if os.path.exists(persist_path): # 持久化目录存在
  18. contents = os.listdir(persist_path)
  19. if len(contents) == 0: # 但是下面为空
  20. # print("目录为空")
  21. create_db(file_path, persist_path, embedding)
  22. # presit_knowledge_db(vectordb)
  23. vectordb = load_knowledge_db(persist_path, embedding)
  24. else:
  25. # print("目录不为空")
  26. vectordb = load_knowledge_db(persist_path, embedding)
  27. else: # 目录不存在,从头开始创建向量数据库
  28. create_db(file_path, persist_path, embedding)
  29. # presit_knowledge_db(vectordb)
  30. vectordb = load_knowledge_db(persist_path, embedding)
  31. return vectordb