ollama_chat.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
import json
from typing import Any, Dict, List, Optional

import requests

from vanna.base import VannaBase
  5. class OllamaChat(VannaBase):
  6. def __init__(self, config=None):
  7. print("...OllamaChat init...")
  8. VannaBase.__init__(self, config=config)
  9. print("传入的 config 参数如下:")
  10. for key, value in self.config.items():
  11. print(f" {key}: {value}")
  12. # 默认参数
  13. self.temperature = 0.7
  14. self.base_url = config.get("base_url", "http://localhost:11434")
  15. self.model = config.get("model", "qwen2.5:7b")
  16. self.timeout = config.get("timeout", 60)
  17. if "temperature" in config:
  18. print(f"temperature is changed to: {config['temperature']}")
  19. self.temperature = config["temperature"]
  20. def system_message(self, message: str) -> any:
  21. print(f"system_content: {message}")
  22. return {"role": "system", "content": message}
  23. def user_message(self, message: str) -> any:
  24. print(f"\nuser_content: {message}")
  25. return {"role": "user", "content": message}
  26. def assistant_message(self, message: str) -> any:
  27. print(f"assistant_content: {message}")
  28. return {"role": "assistant", "content": message}
  29. def submit_prompt(self, prompt, **kwargs) -> str:
  30. if prompt is None:
  31. raise Exception("Prompt is None")
  32. if len(prompt) == 0:
  33. raise Exception("Prompt is empty")
  34. # 计算token数量估计
  35. num_tokens = 0
  36. for message in prompt:
  37. num_tokens += len(message["content"]) / 4
  38. # 确定使用的模型
  39. model = kwargs.get("model") or kwargs.get("engine") or self.config.get("model") or self.model
  40. print(f"\nUsing Ollama model {model} for {num_tokens} tokens (approx)")
  41. # 准备Ollama API请求
  42. url = f"{self.base_url}/api/chat"
  43. payload = {
  44. "model": model,
  45. "messages": prompt,
  46. "stream": False,
  47. "options": {
  48. "temperature": self.temperature
  49. }
  50. }
  51. try:
  52. response = requests.post(
  53. url,
  54. json=payload,
  55. timeout=self.timeout,
  56. headers={"Content-Type": "application/json"}
  57. )
  58. response.raise_for_status()
  59. result = response.json()
  60. return result["message"]["content"]
  61. except requests.exceptions.RequestException as e:
  62. print(f"Ollama API请求失败: {e}")
  63. raise Exception(f"Ollama API调用失败: {str(e)}")
  64. def generate_sql(self, question: str, **kwargs) -> str:
  65. """重写generate_sql方法,增加异常处理"""
  66. try:
  67. print(f"[DEBUG] 尝试为问题生成SQL: {question}")
  68. sql = super().generate_sql(question, **kwargs)
  69. if not sql or sql.strip() == "":
  70. print(f"[WARNING] 生成的SQL为空")
  71. return None
  72. # 检查返回内容是否为有效SQL
  73. sql_lower = sql.lower().strip()
  74. error_indicators = [
  75. "insufficient context", "无法生成", "sorry", "cannot", "不能",
  76. "no relevant", "no suitable", "unable to", "无法", "抱歉"
  77. ]
  78. for indicator in error_indicators:
  79. if indicator in sql_lower:
  80. print(f"[WARNING] LLM返回错误信息而非SQL: {sql}")
  81. return None
  82. sql_keywords = ["select", "insert", "update", "delete", "with", "from", "where"]
  83. if not any(keyword in sql_lower for keyword in sql_keywords):
  84. print(f"[WARNING] 返回内容不像有效SQL: {sql}")
  85. return None
  86. print(f"[SUCCESS] 成功生成SQL: {sql}")
  87. return sql
  88. except Exception as e:
  89. print(f"[ERROR] SQL生成过程中出现异常: {str(e)}")
  90. return None
  91. def generate_question(self, sql: str, **kwargs) -> str:
  92. """根据SQL生成中文问题"""
  93. prompt = [
  94. self.system_message(
  95. "请你根据下方SQL语句推测用户的业务提问,只返回清晰的自然语言问题,不要包含任何解释或SQL内容,也不要出现表名,问题要使用中文,并以问号结尾。"
  96. ),
  97. self.user_message(sql)
  98. ]
  99. response = self.submit_prompt(prompt, **kwargs)
  100. return response
  101. def chat_with_llm(self, question: str, **kwargs) -> str:
  102. """直接与LLM对话"""
  103. try:
  104. prompt = [
  105. self.system_message(
  106. "你是一个友好的AI助手。如果用户询问的是数据库相关问题,请建议他们重新表述问题以便进行SQL查询。对于其他问题,请尽力提供有帮助的回答。"
  107. ),
  108. self.user_message(question)
  109. ]
  110. response = self.submit_prompt(prompt, **kwargs)
  111. return response
  112. except Exception as e:
  113. print(f"[ERROR] LLM对话失败: {str(e)}")
  114. return f"抱歉,我暂时无法回答您的问题。请稍后再试。"
  115. def test_connection(self, test_prompt="你好") -> dict:
  116. """测试Ollama连接"""
  117. result = {
  118. "success": False,
  119. "model": self.model,
  120. "base_url": self.base_url,
  121. "message": "",
  122. }
  123. try:
  124. print(f"测试Ollama连接 - 模型: {self.model}")
  125. print(f"Ollama服务地址: {self.base_url}")
  126. # 测试简单对话
  127. prompt = [self.user_message(test_prompt)]
  128. response = self.submit_prompt(prompt)
  129. result["success"] = True
  130. result["message"] = f"Ollama连接测试成功,响应: {response[:50]}..."
  131. return result
  132. except Exception as e:
  133. result["message"] = f"Ollama连接测试失败: {str(e)}"
  134. return result