vanna_llm_factory.py 4.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
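"""
Vanna LLM factory: builds Vanna instances that pair an LLM backend
(DeepSeek or Qwen) with a vector store (ChromaDB or PGVector), driven by a
simplified configuration module.
"""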
  4. import app_config, os
  5. from vanna.chromadb import ChromaDB_VectorStore # 从 Vanna 系统获取
  6. from customqianwen.Custom_QianwenAI_chat import QianWenAI_Chat
  7. from customdeepseek.custom_deepseek_chat import DeepSeekChat
  8. from embedding_function import get_embedding_function
  9. from custompgvector import PG_VectorStore
  10. class Vanna_Qwen_ChromaDB(ChromaDB_VectorStore, QianWenAI_Chat):
  11. def __init__(self, config=None):
  12. ChromaDB_VectorStore.__init__(self, config=config)
  13. QianWenAI_Chat.__init__(self, config=config)
  14. class Vanna_DeepSeek_ChromaDB(ChromaDB_VectorStore, DeepSeekChat):
  15. def __init__(self, config=None):
  16. ChromaDB_VectorStore.__init__(self, config=config)
  17. DeepSeekChat.__init__(self, config=config)
  18. class Vanna_Qwen_PGVector(PG_VectorStore, QianWenAI_Chat):
  19. def __init__(self, config=None):
  20. PG_VectorStore.__init__(self, config=config)
  21. QianWenAI_Chat.__init__(self, config=config)
  22. class Vanna_DeepSeek_PGVector(PG_VectorStore, DeepSeekChat):
  23. def __init__(self, config=None):
  24. PG_VectorStore.__init__(self, config=config)
  25. DeepSeekChat.__init__(self, config=config)
  26. # 组合映射表
  27. LLM_VECTOR_DB_MAP = {
  28. ('deepseek', 'chromadb'): Vanna_DeepSeek_ChromaDB,
  29. ('deepseek', 'pgvector'): Vanna_DeepSeek_PGVector,
  30. ('qwen', 'chromadb'): Vanna_Qwen_ChromaDB,
  31. ('qwen', 'pgvector'): Vanna_Qwen_PGVector,
  32. }
  33. def create_vanna_instance(config_module=None):
  34. """
  35. 工厂函数:创建并初始化一个Vanna实例 (LLM 和 ChromaDB 特定版本)
  36. Args:
  37. config_module: 配置模块,默认为None时使用 app_config
  38. Returns:
  39. 初始化后的Vanna实例
  40. """
  41. if config_module is None:
  42. config_module = app_config
  43. llm_model_name = config_module.LLM_MODEL_NAME.lower()
  44. vector_db_name = config_module.VECTOR_DB_NAME.lower()
  45. if (llm_model_name, vector_db_name) not in LLM_VECTOR_DB_MAP:
  46. raise ValueError(f"不支持的模型类型: {llm_model_name} 或 向量数据库类型: {vector_db_name}")
  47. config = {}
  48. if llm_model_name == "deepseek":
  49. config = config_module.DEEPSEEK_CONFIG.copy()
  50. print(f"创建DeepSeek模型实例,使用模型: {config.get('model', 'deepseek-chat')}")
  51. elif llm_model_name == "qwen":
  52. config = config_module.QWEN_CONFIG.copy()
  53. print(f"创建Qwen模型实例,使用模型: {config.get('model', 'qwen-plus-latest')}")
  54. else:
  55. raise ValueError(f"不支持的模型类型: {llm_model_name}")
  56. if vector_db_name == "chromadb":
  57. config["path"] = os.path.dirname(os.path.abspath(__file__))
  58. print(f"已配置使用ChromaDB作为向量数据库,路径:{config['path']}")
  59. elif vector_db_name == "pgvector":
  60. # 构建PostgreSQL连接字符串
  61. pg_config = config_module.PGVECTOR_CONFIG
  62. connection_string = f"postgresql://{pg_config['user']}:{pg_config['password']}@{pg_config['host']}:{pg_config['port']}/{pg_config['dbname']}"
  63. config["connection_string"] = connection_string
  64. print(f"已配置使用PgVector作为向量数据库,连接字符串: {connection_string}")
  65. else:
  66. raise ValueError(f"不支持的向量数据库类型: {vector_db_name}")
  67. embedding_function = get_embedding_function()
  68. config["embedding_function"] = embedding_function
  69. print(f"已配置使用 EMBEDDING_CONFIG 中的嵌入模型: {config_module.EMBEDDING_CONFIG['model_name']}, 维度: {config_module.EMBEDDING_CONFIG['embedding_dimension']}")
  70. key = (llm_model_name, vector_db_name)
  71. cls = LLM_VECTOR_DB_MAP.get(key)
  72. if cls is None:
  73. raise ValueError(f"不支持的组合: 模型类型={llm_model_name}, 向量数据库类型={vector_db_name}")
  74. vn = cls(config=config)
  75. vn.connect_to_postgres(**config_module.APP_DB_CONFIG)
  76. print(f"连接到PostgreSQL业务数据库: "
  77. f"{config_module.APP_DB_CONFIG['host']}:"
  78. f"{config_module.APP_DB_CONFIG['port']}/"
  79. f"{config_module.APP_DB_CONFIG['dbname']}")
  80. return vn