test_vanna_combinations.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. #!/usr/bin/env python3
  2. """
  3. 测试统一的Vanna组合类文件
  4. 验证common/vanna_combinations.py中的功能
  5. """
  6. def test_import_combinations():
  7. """测试导入组合类"""
  8. print("=== 测试导入组合类 ===")
  9. try:
  10. from common.vanna_combinations import (
  11. Vanna_Qwen_ChromaDB,
  12. Vanna_DeepSeek_ChromaDB,
  13. Vanna_Qwen_PGVector,
  14. Vanna_DeepSeek_PGVector,
  15. Vanna_Ollama_ChromaDB,
  16. Vanna_Ollama_PGVector,
  17. get_vanna_class,
  18. list_available_combinations,
  19. print_available_combinations
  20. )
  21. print("✅ 成功导入所有组合类和工具函数")
  22. return True
  23. except ImportError as e:
  24. print(f"❌ 导入失败: {e}")
  25. return False
  26. def test_get_vanna_class():
  27. """测试get_vanna_class函数"""
  28. print("\n=== 测试get_vanna_class函数 ===")
  29. try:
  30. from common.vanna_combinations import get_vanna_class
  31. # 测试有效组合
  32. test_cases = [
  33. ("qwen", "chromadb"),
  34. ("deepseek", "chromadb"),
  35. ("qwen", "pgvector"),
  36. ("deepseek", "pgvector"),
  37. ("ollama", "chromadb"),
  38. ("ollama", "pgvector"),
  39. ]
  40. for llm_type, vector_db_type in test_cases:
  41. try:
  42. cls = get_vanna_class(llm_type, vector_db_type)
  43. print(f"✅ {llm_type} + {vector_db_type} -> {cls.__name__}")
  44. except Exception as e:
  45. print(f"⚠️ {llm_type} + {vector_db_type} -> 错误: {e}")
  46. # 测试无效组合
  47. print("\n测试无效组合:")
  48. try:
  49. get_vanna_class("invalid_llm", "chromadb")
  50. print("❌ 应该抛出异常但没有")
  51. return False
  52. except ValueError:
  53. print("✅ 正确处理无效LLM类型")
  54. try:
  55. get_vanna_class("qwen", "invalid_db")
  56. print("❌ 应该抛出异常但没有")
  57. return False
  58. except ValueError:
  59. print("✅ 正确处理无效向量数据库类型")
  60. return True
  61. except Exception as e:
  62. print(f"❌ 测试失败: {e}")
  63. return False
  64. def test_list_available_combinations():
  65. """测试列出可用组合"""
  66. print("\n=== 测试列出可用组合 ===")
  67. try:
  68. from common.vanna_combinations import list_available_combinations, print_available_combinations
  69. # 获取可用组合
  70. combinations = list_available_combinations()
  71. print(f"可用组合数据结构: {combinations}")
  72. # 打印可用组合
  73. print("\n打印可用组合:")
  74. print_available_combinations()
  75. return True
  76. except Exception as e:
  77. print(f"❌ 测试失败: {e}")
  78. import traceback
  79. traceback.print_exc()
  80. return False
  81. def test_class_instantiation():
  82. """测试类实例化(不需要实际服务)"""
  83. print("\n=== 测试类实例化 ===")
  84. try:
  85. from common.vanna_combinations import get_vanna_class
  86. # 测试ChromaDB组合(通常可用)
  87. test_cases = [
  88. ("qwen", "chromadb"),
  89. ("deepseek", "chromadb"),
  90. ]
  91. for llm_type, vector_db_type in test_cases:
  92. try:
  93. cls = get_vanna_class(llm_type, vector_db_type)
  94. # 尝试创建实例(使用空配置)
  95. instance = cls(config={})
  96. print(f"✅ 成功创建 {cls.__name__} 实例")
  97. # 检查实例类型
  98. print(f" 实例类型: {type(instance)}")
  99. print(f" MRO: {[c.__name__ for c in type(instance).__mro__[:3]]}")
  100. except Exception as e:
  101. print(f"⚠️ 创建 {llm_type}+{vector_db_type} 实例失败: {e}")
  102. return True
  103. except Exception as e:
  104. print(f"❌ 测试失败: {e}")
  105. import traceback
  106. traceback.print_exc()
  107. return False
  108. def test_factory_integration():
  109. """测试与工厂函数的集成"""
  110. print("\n=== 测试与工厂函数的集成 ===")
  111. try:
  112. import app_config
  113. from common.utils import print_current_config
  114. # 保存原始配置
  115. original_llm_type = app_config.LLM_MODEL_TYPE
  116. original_embedding_type = app_config.EMBEDDING_MODEL_TYPE
  117. original_vector_db = app_config.VECTOR_DB_NAME
  118. try:
  119. # 测试不同配置
  120. test_configs = [
  121. ("api", "api", "qwen", "chromadb"),
  122. ("api", "api", "deepseek", "chromadb"),
  123. ("ollama", "ollama", None, "chromadb"),
  124. ]
  125. for llm_type, emb_type, llm_name, vector_db in test_configs:
  126. print(f"\n--- 测试配置: LLM={llm_type}, EMB={emb_type}, MODEL={llm_name}, DB={vector_db} ---")
  127. # 设置配置
  128. app_config.LLM_MODEL_TYPE = llm_type
  129. app_config.EMBEDDING_MODEL_TYPE = emb_type
  130. if llm_name:
  131. app_config.LLM_MODEL_NAME = llm_name
  132. app_config.VECTOR_DB_NAME = vector_db
  133. # 打印当前配置
  134. print_current_config()
  135. # 测试工厂函数(不实际创建实例,只测试类选择)
  136. try:
  137. from vanna_llm_factory import create_vanna_instance
  138. from common.utils import get_current_model_info, is_using_ollama_llm
  139. from common.vanna_combinations import get_vanna_class
  140. model_info = get_current_model_info()
  141. if is_using_ollama_llm():
  142. selected_llm_type = "ollama"
  143. else:
  144. selected_llm_type = model_info["llm_model"].lower()
  145. selected_vector_db = model_info["vector_db"].lower()
  146. cls = get_vanna_class(selected_llm_type, selected_vector_db)
  147. print(f"✅ 工厂函数会选择: {cls.__name__}")
  148. except Exception as e:
  149. print(f"⚠️ 工厂函数测试失败: {e}")
  150. return True
  151. finally:
  152. # 恢复原始配置
  153. app_config.LLM_MODEL_TYPE = original_llm_type
  154. app_config.EMBEDDING_MODEL_TYPE = original_embedding_type
  155. app_config.VECTOR_DB_NAME = original_vector_db
  156. except Exception as e:
  157. print(f"❌ 测试失败: {e}")
  158. import traceback
  159. traceback.print_exc()
  160. return False
  161. def main():
  162. """主测试函数"""
  163. print("开始测试统一的Vanna组合类...")
  164. print("=" * 60)
  165. results = []
  166. # 运行所有测试
  167. results.append(("导入组合类", test_import_combinations()))
  168. results.append(("get_vanna_class函数", test_get_vanna_class()))
  169. results.append(("列出可用组合", test_list_available_combinations()))
  170. results.append(("类实例化", test_class_instantiation()))
  171. results.append(("工厂函数集成", test_factory_integration()))
  172. # 总结
  173. print(f"\n{'='*60}")
  174. print("测试结果总结:")
  175. print("=" * 60)
  176. for test_name, success in results:
  177. status = "✅ 通过" if success else "❌ 失败"
  178. print(f"{test_name}: {status}")
  179. total_passed = sum(1 for _, success in results if success)
  180. print(f"\n总计: {total_passed}/{len(results)} 个测试通过")
  181. if total_passed == len(results):
  182. print("🎉 所有测试都通过了!统一组合类文件工作正常。")
  183. else:
  184. print("⚠️ 部分测试失败,请检查相关依赖和配置。")
  185. if __name__ == "__main__":
  186. main()