#!/usr/bin/env python3
"""
Test script for the Ollama integration.

Verifies that the Ollama LLM and embedding functions work correctly.
"""
  6. def test_ollama_llm():
  7. """测试Ollama LLM功能"""
  8. print("=== 测试Ollama LLM ===")
  9. try:
  10. from customollama.ollama_chat import OllamaChat
  11. # 测试配置
  12. config = {
  13. "base_url": "http://localhost:11434",
  14. "model": "qwen2.5:7b",
  15. "temperature": 0.7,
  16. "timeout": 60
  17. }
  18. # 创建实例
  19. ollama_chat = OllamaChat(config=config)
  20. # 测试连接
  21. print("测试Ollama连接...")
  22. test_result = ollama_chat.test_connection()
  23. if test_result["success"]:
  24. print(f"✅ Ollama LLM连接成功: {test_result['message']}")
  25. else:
  26. print(f"❌ Ollama LLM连接失败: {test_result['message']}")
  27. return False
  28. # 测试简单对话
  29. print("\n测试简单对话...")
  30. response = ollama_chat.chat_with_llm("你好,请简单介绍一下你自己")
  31. print(f"LLM响应: {response}")
  32. return True
  33. except Exception as e:
  34. print(f"❌ Ollama LLM测试失败: {e}")
  35. import traceback
  36. traceback.print_exc()
  37. return False
  38. def test_ollama_embedding():
  39. """测试Ollama Embedding功能"""
  40. print("\n=== 测试Ollama Embedding ===")
  41. try:
  42. from customollama.ollama_embedding import OllamaEmbeddingFunction
  43. # 创建实例
  44. embedding_func = OllamaEmbeddingFunction(
  45. model_name="nomic-embed-text",
  46. base_url="http://localhost:11434",
  47. embedding_dimension=768
  48. )
  49. # 测试连接
  50. print("测试Ollama Embedding连接...")
  51. test_result = embedding_func.test_connection()
  52. if test_result["success"]:
  53. print(f"✅ Ollama Embedding连接成功: {test_result['message']}")
  54. else:
  55. print(f"❌ Ollama Embedding连接失败: {test_result['message']}")
  56. return False
  57. # 测试生成embedding
  58. print("\n测试生成embedding...")
  59. test_texts = ["这是一个测试文本", "另一个测试文本"]
  60. embeddings = embedding_func(test_texts)
  61. print(f"生成了 {len(embeddings)} 个embedding向量")
  62. for i, emb in enumerate(embeddings):
  63. print(f"文本 {i+1} 的embedding维度: {len(emb)}")
  64. return True
  65. except Exception as e:
  66. print(f"❌ Ollama Embedding测试失败: {e}")
  67. import traceback
  68. traceback.print_exc()
  69. return False
  70. def test_ollama_with_config():
  71. """测试使用配置文件的Ollama功能"""
  72. print("\n=== 测试配置文件中的Ollama设置 ===")
  73. try:
  74. import app_config
  75. from common.utils import print_current_config, is_using_ollama_llm, is_using_ollama_embedding
  76. # 保存原始配置
  77. original_llm_type = app_config.LLM_MODEL_TYPE
  78. original_embedding_type = app_config.EMBEDDING_MODEL_TYPE
  79. try:
  80. # 设置为Ollama模式
  81. app_config.LLM_MODEL_TYPE = "ollama"
  82. app_config.EMBEDDING_MODEL_TYPE = "ollama"
  83. print("当前配置:")
  84. print_current_config()
  85. print(f"\n使用Ollama LLM: {is_using_ollama_llm()}")
  86. print(f"使用Ollama Embedding: {is_using_ollama_embedding()}")
  87. # 测试embedding函数
  88. print("\n测试通过配置获取embedding函数...")
  89. from embedding_function import get_embedding_function
  90. embedding_func = get_embedding_function()
  91. print(f"成功创建embedding函数: {type(embedding_func).__name__}")
  92. # 测试工厂函数(如果Ollama服务可用的话)
  93. print("\n测试工厂函数...")
  94. try:
  95. from vanna_llm_factory import create_vanna_instance
  96. vn = create_vanna_instance()
  97. print(f"✅ 成功创建Vanna实例: {type(vn).__name__}")
  98. return True
  99. except Exception as e:
  100. print(f"⚠️ 工厂函数测试失败(可能是Ollama服务未启动): {e}")
  101. return True # 这不算失败,只是服务未启动
  102. finally:
  103. # 恢复原始配置
  104. app_config.LLM_MODEL_TYPE = original_llm_type
  105. app_config.EMBEDDING_MODEL_TYPE = original_embedding_type
  106. except Exception as e:
  107. print(f"❌ 配置测试失败: {e}")
  108. import traceback
  109. traceback.print_exc()
  110. return False
  111. def test_mixed_configurations():
  112. """测试混合配置(API + Ollama)"""
  113. print("\n=== 测试混合配置 ===")
  114. try:
  115. import app_config
  116. from common.utils import print_current_config
  117. # 保存原始配置
  118. original_llm_type = app_config.LLM_MODEL_TYPE
  119. original_embedding_type = app_config.EMBEDDING_MODEL_TYPE
  120. try:
  121. # 测试配置1:API LLM + Ollama Embedding
  122. print("\n--- 测试: API LLM + Ollama Embedding ---")
  123. app_config.LLM_MODEL_TYPE = "api"
  124. app_config.EMBEDDING_MODEL_TYPE = "ollama"
  125. print_current_config()
  126. from embedding_function import get_embedding_function
  127. embedding_func = get_embedding_function()
  128. print(f"Embedding函数类型: {type(embedding_func).__name__}")
  129. # 测试配置2:Ollama LLM + API Embedding
  130. print("\n--- 测试: Ollama LLM + API Embedding ---")
  131. app_config.LLM_MODEL_TYPE = "ollama"
  132. app_config.EMBEDDING_MODEL_TYPE = "api"
  133. print_current_config()
  134. embedding_func = get_embedding_function()
  135. print(f"Embedding函数类型: {type(embedding_func).__name__}")
  136. print("✅ 混合配置测试通过")
  137. return True
  138. finally:
  139. # 恢复原始配置
  140. app_config.LLM_MODEL_TYPE = original_llm_type
  141. app_config.EMBEDDING_MODEL_TYPE = original_embedding_type
  142. except Exception as e:
  143. print(f"❌ 混合配置测试失败: {e}")
  144. import traceback
  145. traceback.print_exc()
  146. return False
  147. def main():
  148. """主测试函数"""
  149. print("开始测试Ollama集成功能...")
  150. print("注意: 这些测试需要Ollama服务运行在 http://localhost:11434")
  151. print("=" * 60)
  152. results = []
  153. # 测试配置和工具函数(不需要Ollama服务)
  154. results.append(("配置文件测试", test_ollama_with_config()))
  155. results.append(("混合配置测试", test_mixed_configurations()))
  156. # 测试实际的Ollama功能(需要Ollama服务)
  157. print(f"\n{'='*60}")
  158. print("以下测试需要Ollama服务运行,如果失败可能是服务未启动")
  159. print("=" * 60)
  160. results.append(("Ollama LLM", test_ollama_llm()))
  161. results.append(("Ollama Embedding", test_ollama_embedding()))
  162. # 总结
  163. print(f"\n{'='*60}")
  164. print("测试结果总结:")
  165. print("=" * 60)
  166. for test_name, success in results:
  167. status = "✅ 通过" if success else "❌ 失败"
  168. print(f"{test_name}: {status}")
  169. total_passed = sum(1 for _, success in results if success)
  170. print(f"\n总计: {total_passed}/{len(results)} 个测试通过")
  171. if total_passed == len(results):
  172. print("🎉 所有测试都通过了!Ollama集成功能正常。")
  173. else:
  174. print("⚠️ 部分测试失败,请检查Ollama服务是否正常运行。")
  175. if __name__ == "__main__":
  176. main()