vanna_llm_factory.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. """
  2. Vanna LLM 工厂文件,支持多种LLM提供商和向量数据库
  3. """
  4. import app_config, os
  5. from embedding_function import get_embedding_function
  6. from common.vanna_combinations import get_vanna_class, print_available_combinations
  7. def create_vanna_instance(config_module=None):
  8. """
  9. 工厂函数:创建并初始化一个Vanna实例
  10. 支持API和Ollama两种LLM提供商,以及ChromaDB和PgVector两种向量数据库
  11. Args:
  12. config_module: 配置模块,默认为None时使用 app_config
  13. Returns:
  14. 初始化后的Vanna实例
  15. """
  16. if config_module is None:
  17. config_module = app_config
  18. try:
  19. from common.utils import (
  20. get_current_llm_config,
  21. get_current_vector_db_config,
  22. get_current_model_info,
  23. is_using_ollama_llm,
  24. print_current_config
  25. )
  26. except ImportError:
  27. raise ImportError("无法导入 common.utils,请确保该文件存在")
  28. # 打印当前配置信息
  29. print_current_config()
  30. # 获取当前配置
  31. llm_config = get_current_llm_config()
  32. vector_db_config = get_current_vector_db_config()
  33. model_info = get_current_model_info()
  34. # 获取对应的Vanna组合类
  35. try:
  36. if is_using_ollama_llm():
  37. llm_type = "ollama"
  38. else:
  39. llm_type = model_info["llm_model"].lower()
  40. vector_db_type = model_info["vector_db"].lower()
  41. cls = get_vanna_class(llm_type, vector_db_type)
  42. print(f"创建{llm_type.upper()}+{vector_db_type.upper()}实例")
  43. except ValueError as e:
  44. print(f"错误: {e}")
  45. print("\n可用的组合:")
  46. print_available_combinations()
  47. raise
  48. # 准备配置
  49. config = llm_config.copy()
  50. # 配置向量数据库
  51. if model_info["vector_db"] == "chromadb":
  52. config["path"] = os.path.dirname(os.path.abspath(__file__))
  53. print(f"已配置使用ChromaDB,路径:{config['path']}")
  54. elif model_info["vector_db"] == "pgvector":
  55. # 构建PostgreSQL连接字符串
  56. connection_string = f"postgresql://{vector_db_config['user']}:{vector_db_config['password']}@{vector_db_config['host']}:{vector_db_config['port']}/{vector_db_config['dbname']}"
  57. config["connection_string"] = connection_string
  58. print(f"已配置使用PgVector,连接字符串: {connection_string}")
  59. # 配置embedding函数
  60. embedding_function = get_embedding_function()
  61. config["embedding_function"] = embedding_function
  62. print(f"已配置使用{model_info['embedding_type'].upper()}嵌入模型: {model_info['embedding_model']}")
  63. # 创建实例
  64. vn = cls(config=config)
  65. # 连接到业务数据库
  66. vn.connect_to_postgres(**config_module.APP_DB_CONFIG)
  67. print(f"已连接到业务数据库: "
  68. f"{config_module.APP_DB_CONFIG['host']}:"
  69. f"{config_module.APP_DB_CONFIG['port']}/"
  70. f"{config_module.APP_DB_CONFIG['dbname']}")
  71. return vn