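"""OllamaChat: an LLM backend for Vanna that talks to a local Ollama server.

Implements the VannaBase chat-message interface over Ollama's /api/chat
endpoint, and adds helpers for SQL generation with validation, question
generation from SQL, direct chat, connection testing, and config-gated
question rewriting.
"""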
import requests
from vanna.base import VannaBase
from typing import Dict, List, Optional

# Configuration flags
from app_config import REWRITE_QUESTION_ENABLED

class OllamaChat(VannaBase):
    def __init__(self, config=None):
        print("...OllamaChat init...")
        config = config or {}  # tolerate a missing config dict
        VannaBase.__init__(self, config=config)

        print("Received config parameters:")
        for key, value in self.config.items():
            print(f"  {key}: {value}")

        # Defaults, overridable via config
        self.temperature = 0.7
        self.base_url = config.get("base_url", "http://localhost:11434")
        self.model = config.get("model", "qwen2.5:7b")
        self.timeout = config.get("timeout", 60)

        if "temperature" in config:
            print(f"temperature changed to: {config['temperature']}")
            self.temperature = config["temperature"]

    def system_message(self, message: str) -> Dict[str, str]:
        print(f"system_content: {message}")
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> Dict[str, str]:
        print(f"\nuser_content: {message}")
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> Dict[str, str]:
        print(f"assistant_content: {message}")
        return {"role": "assistant", "content": message}

    def submit_prompt(self, prompt: List[Dict[str, str]], **kwargs) -> str:
        if prompt is None:
            raise ValueError("Prompt is None")
        if len(prompt) == 0:
            raise ValueError("Prompt is empty")

        # Rough token estimate: ~4 characters per token
        num_tokens = 0
        for message in prompt:
            num_tokens += len(message["content"]) // 4

        # Resolve the model to use, in order of precedence
        model = kwargs.get("model") or kwargs.get("engine") or self.config.get("model") or self.model

        print(f"\nUsing Ollama model {model} for {num_tokens} tokens (approx)")

        # Build the Ollama /api/chat request
        url = f"{self.base_url}/api/chat"
        payload = {
            "model": model,
            "messages": prompt,
            "stream": False,
            "options": {
                "temperature": self.temperature
            }
        }

        try:
            response = requests.post(
                url,
                json=payload,
                timeout=self.timeout,
                headers={"Content-Type": "application/json"}
            )
            response.raise_for_status()

            result = response.json()
            return result["message"]["content"]

        except requests.exceptions.RequestException as e:
            print(f"Ollama API request failed: {e}")
            raise RuntimeError(f"Ollama API call failed: {str(e)}") from e

    def generate_sql(self, question: str, **kwargs) -> Optional[str]:
        """Override generate_sql to add error handling around the base implementation."""
        try:
            print(f"[DEBUG] Generating SQL for question: {question}")
            sql = super().generate_sql(question, **kwargs)

            if not sql or sql.strip() == "":
                print("[WARNING] Generated SQL is empty")
                return None

            # Detect refusals or error messages returned instead of SQL
            # (the Chinese entries match refusals from Chinese-language models)
            sql_lower = sql.lower().strip()
            error_indicators = [
                "insufficient context", "无法生成", "sorry", "cannot", "不能",
                "no relevant", "no suitable", "unable to", "无法", "抱歉"
            ]

            for indicator in error_indicators:
                if indicator in sql_lower:
                    print(f"[WARNING] LLM returned an error message instead of SQL: {sql}")
                    return None

            # Sanity check: the response should contain at least one SQL keyword
            sql_keywords = ["select", "insert", "update", "delete", "with", "from", "where"]
            if not any(keyword in sql_lower for keyword in sql_keywords):
                print(f"[WARNING] Response does not look like valid SQL: {sql}")
                return None

            print(f"[SUCCESS] Generated SQL: {sql}")
            return sql

        except Exception as e:
            print(f"[ERROR] Exception while generating SQL: {str(e)}")
            return None

    def generate_question(self, sql: str, **kwargs) -> str:
        """Generate a Chinese natural-language question from a SQL statement."""
        prompt = [
            self.system_message(
                "Infer the user's business question from the SQL statement below. "
                "Return only a clear natural-language question, with no explanation, "
                "no SQL content, and no table names. The question must be written in "
                "Chinese and end with a question mark."
            ),
            self.user_message(sql)
        ]
        response = self.submit_prompt(prompt, **kwargs)
        return response

    def chat_with_llm(self, question: str, **kwargs) -> str:
        """Talk to the LLM directly, without SQL generation."""
        try:
            prompt = [
                self.system_message(
                    "You are a friendly AI assistant. If the user asks a database-related "
                    "question, suggest that they rephrase it so it can be answered with a "
                    "SQL query. For any other question, do your best to give a helpful answer."
                ),
                self.user_message(question)
            ]
            response = self.submit_prompt(prompt, **kwargs)
            return response
        except Exception as e:
            print(f"[ERROR] LLM chat failed: {str(e)}")
            return "Sorry, I cannot answer your question right now. Please try again later."

    def test_connection(self, test_prompt="你好") -> dict:
        """Test connectivity to the Ollama server with a simple prompt."""
        result = {
            "success": False,
            "model": self.model,
            "base_url": self.base_url,
            "message": "",
        }

        try:
            print(f"Testing Ollama connection - model: {self.model}")
            print(f"Ollama server URL: {self.base_url}")

            # Try a minimal chat round-trip
            prompt = [self.user_message(test_prompt)]
            response = self.submit_prompt(prompt)

            result["success"] = True
            result["message"] = f"Ollama connection test succeeded; response: {response[:50]}..."

            return result

        except Exception as e:
            result["message"] = f"Ollama connection test failed: {str(e)}"
            return result

    def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str:
        """
        Merge a follow-up question with the previous one, gated by configuration.

        Args:
            last_question (str): The previous question, if any.
            new_question (str): The new question.
            **kwargs: Extra arguments forwarded to submit_prompt.

        Returns:
            str: The merged question if merging is enabled and the questions are
                related; otherwise the new question unchanged.
        """
        # If merging is disabled, or there is no previous question, return the new question as-is
        if not REWRITE_QUESTION_ENABLED or last_question is None:
            reason = "disabled" if not REWRITE_QUESTION_ENABLED else "no previous question"
            print(f"[DEBUG] Question merging skipped ({reason}); returning the new question")
            return new_question

        print("[DEBUG] Question merging enabled; attempting to merge questions")
        print(f"[DEBUG] Previous question: {last_question}")
        print(f"[DEBUG] New question: {new_question}")

        try:
            prompt = [
                self.system_message(
                    "Your goal is to merge a series of related questions into a single question. "
                    "If the second question is unrelated to and fully independent of the first, "
                    "return the second question. Return only the new merged question, without any "
                    "extra explanation. The merged question should, in principle, be answerable "
                    "by a single SQL statement. Answer in Chinese."
                ),
                self.user_message(f"First question: {last_question}\nSecond question: {new_question}")
            ]

            rewritten_question = self.submit_prompt(prompt=prompt, **kwargs)
            print(f"[DEBUG] Merged question: {rewritten_question}")
            return rewritten_question

        except Exception as e:
            print(f"[ERROR] Question merging failed: {str(e)}")
            # Fall back to the new question if merging fails
            return new_question
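

# --- Usage sketch (illustrative only) ---
# VannaBase is abstract, so OllamaChat is normally combined with a vector-store
# mixin in the usual Vanna pattern. The ChromaDB import below, the running
# Ollama server at localhost:11434, and the pulled qwen2.5:7b model are all
# assumptions for this sketch, not requirements of the class itself.
#
# from vanna.chromadb import ChromaDB_VectorStore
#
# class MyVanna(ChromaDB_VectorStore, OllamaChat):
#     def __init__(self, config=None):
#         ChromaDB_VectorStore.__init__(self, config=config)
#         OllamaChat.__init__(self, config=config)
#
# vn = MyVanna(config={
#     "base_url": "http://localhost:11434",
#     "model": "qwen2.5:7b",
#     "temperature": 0.3,
#     "timeout": 120,
# })
# print(vn.test_connection())
# print(vn.generate_question("SELECT city, COUNT(*) FROM orders GROUP BY city"))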