ollama_chat.py

import requests

from .base_llm_chat import BaseLLMChat


class OllamaChat(BaseLLMChat):
    """Ollama AI chat implementation."""

    def __init__(self, config=None):
        print("...OllamaChat init...")
        super().__init__(config=config)

        # Ollama-specific configuration parameters (fall back to defaults if no config is given)
        config = config or {}
        self.base_url = config.get("base_url", "http://localhost:11434")
        self.model = config.get("model", "qwen2.5:7b")
        self.timeout = config.get("timeout", 60)

    def submit_prompt(self, prompt, **kwargs) -> str:
        if prompt is None:
            raise Exception("Prompt is None")

        if len(prompt) == 0:
            raise Exception("Prompt is empty")

        # Rough token-count estimate (~4 characters per token)
        num_tokens = 0
        for message in prompt:
            num_tokens += len(message["content"]) / 4

        # Resolve which model to use
        model = kwargs.get("model") or kwargs.get("engine") or (self.config or {}).get("model") or self.model

        print(f"\nUsing Ollama model {model} for {num_tokens} tokens (approx)")

        # Build the Ollama /api/chat request
        url = f"{self.base_url}/api/chat"
        payload = {
            "model": model,
            "messages": prompt,
            "stream": False,
            "options": {
                "temperature": self.temperature  # inherited from BaseLLMChat
            }
        }

        try:
            response = requests.post(
                url,
                json=payload,
                timeout=self.timeout,
                headers={"Content-Type": "application/json"}
            )
            response.raise_for_status()

            result = response.json()
            return result["message"]["content"]

        except requests.exceptions.RequestException as e:
            print(f"Ollama API request failed: {e}")
            raise Exception(f"Ollama API call failed: {str(e)}")

    def test_connection(self, test_prompt="Hello") -> dict:
        """Test the connection to the Ollama server."""
        result = {
            "success": False,
            "model": self.model,
            "base_url": self.base_url,
            "message": "",
        }

        try:
            print(f"Testing Ollama connection - model: {self.model}")
            print(f"Ollama server URL: {self.base_url}")

            # Send a simple one-turn chat as the connectivity check
            prompt = [self.user_message(test_prompt)]
            response = self.submit_prompt(prompt)

            result["success"] = True
            result["message"] = f"Ollama connection test succeeded, response: {response[:50]}..."
            return result

        except Exception as e:
            result["message"] = f"Ollama connection test failed: {str(e)}"
            return result
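

# Usage sketch: a minimal example of how this class might be exercised,
# assuming an Ollama server is reachable at http://localhost:11434 and that
# BaseLLMChat accepts a plain dict config and provides user_message(), which
# is assumed to build an OpenAI-style {"role": "user", "content": ...} dict.
if __name__ == "__main__":
    chat = OllamaChat(config={
        "base_url": "http://localhost:11434",
        "model": "qwen2.5:7b",
        "timeout": 60,
    })

    # test_connection() never raises; it reports success/failure in the returned dict.
    status = chat.test_connection()
    print(status)

    if status["success"]:
        # Send a one-turn chat request through the same path test_connection() uses.
        answer = chat.submit_prompt([chat.user_message("Explain what Ollama is in one sentence.")])
        print(answer)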