vanna_llm_factory.py 3.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. """
  2. Vanna LLM 工厂文件,支持多种LLM提供商和向量数据库
  3. """
  4. import app_config, os
  5. from core.embedding_function import get_embedding_function
  6. from common.vanna_combinations import get_vanna_class, print_available_combinations
  7. from core.logging import get_vanna_logger
  8. # Initialize the module-level logger used by the factory function below
  9. logger = get_vanna_logger("VannaFactory")
  10. def create_vanna_instance(config_module=None):
  11. """
  12. 工厂函数:创建并初始化一个Vanna实例
  13. 支持API和Ollama两种LLM提供商,以及ChromaDB和PgVector两种向量数据库
  14. Args:
  15. config_module: 配置模块,默认为None时使用 app_config
  16. Returns:
  17. 初始化后的Vanna实例
  18. """
  19. if config_module is None:
  20. config_module = app_config
  21. try:
  22. from common.utils import (
  23. get_current_llm_config,
  24. get_current_vector_db_config,
  25. get_current_model_info,
  26. is_using_ollama_llm,
  27. print_current_config
  28. )
  29. except ImportError:
  30. raise ImportError("无法导入 common.utils,请确保该文件存在")
  31. # 打印当前配置信息
  32. print_current_config()
  33. # 获取当前配置
  34. llm_config = get_current_llm_config()
  35. vector_db_config = get_current_vector_db_config()
  36. model_info = get_current_model_info()
  37. # 获取对应的Vanna组合类
  38. try:
  39. if is_using_ollama_llm():
  40. llm_type = "ollama"
  41. else:
  42. llm_type = model_info["llm_model"].lower()
  43. vector_db_type = model_info["vector_db"].lower()
  44. cls = get_vanna_class(llm_type, vector_db_type)
  45. logger.info(f"创建{llm_type.upper()}+{vector_db_type.upper()}实例")
  46. except ValueError as e:
  47. logger.error(f"{e}")
  48. logger.info("可用的组合:")
  49. print_available_combinations()
  50. raise
  51. # 准备配置
  52. config = llm_config.copy()
  53. # 配置向量数据库
  54. if model_info["vector_db"] == "chromadb":
  55. config["path"] = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) # 返回项目根目录
  56. logger.info(f"已配置使用ChromaDB,路径:{config['path']}")
  57. elif model_info["vector_db"] == "pgvector":
  58. # 构建PostgreSQL连接字符串
  59. connection_string = f"postgresql://{vector_db_config['user']}:{vector_db_config['password']}@{vector_db_config['host']}:{vector_db_config['port']}/{vector_db_config['dbname']}"
  60. config["connection_string"] = connection_string
  61. logger.info(f"已配置使用PgVector,连接字符串: {connection_string}")
  62. # 配置embedding函数
  63. embedding_function = get_embedding_function()
  64. config["embedding_function"] = embedding_function
  65. logger.info(f"已配置使用{model_info['embedding_type'].upper()}嵌入模型: {model_info['embedding_model']}")
  66. # 创建实例
  67. vn = cls(config=config)
  68. # 连接到业务数据库
  69. vn.connect_to_postgres(**config_module.APP_DB_CONFIG)
  70. logger.info(f"已连接到业务数据库: "
  71. f"{config_module.APP_DB_CONFIG['host']}:"
  72. f"{config_module.APP_DB_CONFIG['port']}/"
  73. f"{config_module.APP_DB_CONFIG['dbname']}")
  74. return vn