ollama_chat.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
import json
from typing import Any, Dict, List, Optional

import requests
from vanna.base import VannaBase

# Import configuration parameters
from app_config import REWRITE_QUESTION_ENABLED
  7. class OllamaChat(VannaBase):
  8. def __init__(self, config=None):
  9. print("...OllamaChat init...")
  10. VannaBase.__init__(self, config=config)
  11. print("传入的 config 参数如下:")
  12. for key, value in self.config.items():
  13. print(f" {key}: {value}")
  14. # 默认参数
  15. self.temperature = 0.7
  16. self.base_url = config.get("base_url", "http://localhost:11434")
  17. self.model = config.get("model", "qwen2.5:7b")
  18. self.timeout = config.get("timeout", 60)
  19. if "temperature" in config:
  20. print(f"temperature is changed to: {config['temperature']}")
  21. self.temperature = config["temperature"]
  22. def system_message(self, message: str) -> any:
  23. print(f"system_content: {message}")
  24. return {"role": "system", "content": message}
  25. def user_message(self, message: str) -> any:
  26. print(f"\nuser_content: {message}")
  27. return {"role": "user", "content": message}
  28. def assistant_message(self, message: str) -> any:
  29. print(f"assistant_content: {message}")
  30. return {"role": "assistant", "content": message}
  31. def submit_prompt(self, prompt, **kwargs) -> str:
  32. if prompt is None:
  33. raise Exception("Prompt is None")
  34. if len(prompt) == 0:
  35. raise Exception("Prompt is empty")
  36. # 计算token数量估计
  37. num_tokens = 0
  38. for message in prompt:
  39. num_tokens += len(message["content"]) / 4
  40. # 确定使用的模型
  41. model = kwargs.get("model") or kwargs.get("engine") or self.config.get("model") or self.model
  42. print(f"\nUsing Ollama model {model} for {num_tokens} tokens (approx)")
  43. # 准备Ollama API请求
  44. url = f"{self.base_url}/api/chat"
  45. payload = {
  46. "model": model,
  47. "messages": prompt,
  48. "stream": False,
  49. "options": {
  50. "temperature": self.temperature
  51. }
  52. }
  53. try:
  54. response = requests.post(
  55. url,
  56. json=payload,
  57. timeout=self.timeout,
  58. headers={"Content-Type": "application/json"}
  59. )
  60. response.raise_for_status()
  61. result = response.json()
  62. return result["message"]["content"]
  63. except requests.exceptions.RequestException as e:
  64. print(f"Ollama API请求失败: {e}")
  65. raise Exception(f"Ollama API调用失败: {str(e)}")
  66. def generate_sql(self, question: str, **kwargs) -> str:
  67. """重写generate_sql方法,增加异常处理"""
  68. try:
  69. print(f"[DEBUG] 尝试为问题生成SQL: {question}")
  70. sql = super().generate_sql(question, **kwargs)
  71. if not sql or sql.strip() == "":
  72. print(f"[WARNING] 生成的SQL为空")
  73. return None
  74. # 检查返回内容是否为有效SQL
  75. sql_lower = sql.lower().strip()
  76. error_indicators = [
  77. "insufficient context", "无法生成", "sorry", "cannot", "不能",
  78. "no relevant", "no suitable", "unable to", "无法", "抱歉"
  79. ]
  80. for indicator in error_indicators:
  81. if indicator in sql_lower:
  82. print(f"[WARNING] LLM返回错误信息而非SQL: {sql}")
  83. return None
  84. sql_keywords = ["select", "insert", "update", "delete", "with", "from", "where"]
  85. if not any(keyword in sql_lower for keyword in sql_keywords):
  86. print(f"[WARNING] 返回内容不像有效SQL: {sql}")
  87. return None
  88. print(f"[SUCCESS] 成功生成SQL: {sql}")
  89. return sql
  90. except Exception as e:
  91. print(f"[ERROR] SQL生成过程中出现异常: {str(e)}")
  92. return None
  93. def generate_question(self, sql: str, **kwargs) -> str:
  94. """根据SQL生成中文问题"""
  95. prompt = [
  96. self.system_message(
  97. "请你根据下方SQL语句推测用户的业务提问,只返回清晰的自然语言问题,不要包含任何解释或SQL内容,也不要出现表名,问题要使用中文,并以问号结尾。"
  98. ),
  99. self.user_message(sql)
  100. ]
  101. response = self.submit_prompt(prompt, **kwargs)
  102. return response
  103. def chat_with_llm(self, question: str, **kwargs) -> str:
  104. """直接与LLM对话"""
  105. try:
  106. prompt = [
  107. self.system_message(
  108. "你是一个友好的AI助手。如果用户询问的是数据库相关问题,请建议他们重新表述问题以便进行SQL查询。对于其他问题,请尽力提供有帮助的回答。"
  109. ),
  110. self.user_message(question)
  111. ]
  112. response = self.submit_prompt(prompt, **kwargs)
  113. return response
  114. except Exception as e:
  115. print(f"[ERROR] LLM对话失败: {str(e)}")
  116. return f"抱歉,我暂时无法回答您的问题。请稍后再试。"
  117. def test_connection(self, test_prompt="你好") -> dict:
  118. """测试Ollama连接"""
  119. result = {
  120. "success": False,
  121. "model": self.model,
  122. "base_url": self.base_url,
  123. "message": "",
  124. }
  125. try:
  126. print(f"测试Ollama连接 - 模型: {self.model}")
  127. print(f"Ollama服务地址: {self.base_url}")
  128. # 测试简单对话
  129. prompt = [self.user_message(test_prompt)]
  130. response = self.submit_prompt(prompt)
  131. result["success"] = True
  132. result["message"] = f"Ollama连接测试成功,响应: {response[:50]}..."
  133. return result
  134. except Exception as e:
  135. result["message"] = f"Ollama连接测试失败: {str(e)}"
  136. return result
  137. def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str:
  138. """
  139. 重写问题合并方法,通过配置参数控制是否启用合并功能
  140. Args:
  141. last_question (str): 上一个问题
  142. new_question (str): 新问题
  143. **kwargs: 其他参数
  144. Returns:
  145. str: 如果启用合并且问题相关则返回合并后的问题,否则返回新问题
  146. """
  147. # 如果未启用合并功能或没有上一个问题,直接返回新问题
  148. if not REWRITE_QUESTION_ENABLED or last_question is None:
  149. print(f"[DEBUG] 问题合并功能{'未启用' if not REWRITE_QUESTION_ENABLED else '上一个问题为空'},直接返回新问题")
  150. return new_question
  151. print(f"[DEBUG] 启用问题合并功能,尝试合并问题")
  152. print(f"[DEBUG] 上一个问题: {last_question}")
  153. print(f"[DEBUG] 新问题: {new_question}")
  154. try:
  155. prompt = [
  156. self.system_message(
  157. "你的目标是将一系列相关的问题合并成一个单一的问题。如果第二个问题与第一个问题无关且完全独立,则返回第二个问题。"
  158. "只返回新的合并问题,不要添加任何额外的解释。该问题理论上应该能够用一个SQL语句来回答。"
  159. "请用中文回答。"
  160. ),
  161. self.user_message(f"第一个问题: {last_question}\n第二个问题: {new_question}")
  162. ]
  163. rewritten_question = self.submit_prompt(prompt=prompt, **kwargs)
  164. print(f"[DEBUG] 合并后的问题: {rewritten_question}")
  165. return rewritten_question
  166. except Exception as e:
  167. print(f"[ERROR] 问题合并失败: {str(e)}")
  168. # 如果合并失败,返回新问题
  169. return new_question