@@ -4,166 +4,231 @@ from vanna.base import VannaBase

class QianWenAI_Chat(VannaBase):
-    def __init__(self, client=None, config=None):
-        print("...QianWenAI_Chat init...")
-        VannaBase.__init__(self, config=config)
-
-        print("The config parameters passed in are as follows:")
-        for key, value in self.config.items():
-            print(f" {key}: {value}")
-
-        # default parameters - can be overridden using config
-        self.temperature = 0.7
-
-        if "temperature" in config:
-            print(f"temperature is changed to: {config['temperature']}")
-            self.temperature = config["temperature"]
-
-        if "api_type" in config:
-            raise Exception(
-                "Passing api_type is now deprecated. Please pass an OpenAI client instead."
-            )
-
-        if "api_base" in config:
-            raise Exception(
-                "Passing api_base is now deprecated. Please pass an OpenAI client instead."
-            )
-
-        if "api_version" in config:
-            raise Exception(
-                "Passing api_version is now deprecated. Please pass an OpenAI client instead."
-            )
-
-        if client is not None:
-            self.client = client
-            return
-
-        if config is None and client is None:
-            self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
-            return
-
-        if "api_key" in config:
-            if "base_url" not in config:
-                self.client = OpenAI(api_key=config["api_key"],
-                                     base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")
-            else:
-                self.client = OpenAI(api_key=config["api_key"],
-                                     base_url=config["base_url"])
+    def __init__(self, client=None, config=None):
+        print("...QianWenAI_Chat init...")
+        VannaBase.__init__(self, config=config)
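+        # Example of a config this constructor accepts (hypothetical values):
+        #   {"api_key": "sk-...", "model": "qwen-plus", "temperature": 0.5,
+        #    "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1"}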
+
+        print("The config parameters passed in are as follows:")
+        for key, value in self.config.items():
+            print(f" {key}: {value}")
+
+        # default parameters - can be overridden using config
+        self.temperature = 0.7
+
+        if "temperature" in config:
+            print(f"temperature is changed to: {config['temperature']}")
+            self.temperature = config["temperature"]
+
+        if "api_type" in config:
+            raise Exception(
+                "Passing api_type is now deprecated. Please pass an OpenAI client instead."
+            )
+
+        if "api_base" in config:
+            raise Exception(
+                "Passing api_base is now deprecated. Please pass an OpenAI client instead."
+            )
+
+        if "api_version" in config:
+            raise Exception(
+                "Passing api_version is now deprecated. Please pass an OpenAI client instead."
+            )
+
+        if client is not None:
+            self.client = client
+            return
+
+        if config is None and client is None:
+            self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
+            return
+
+        if "api_key" in config:
+            if "base_url" not in config:
+                self.client = OpenAI(api_key=config["api_key"],
+                                     base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")
+            else:
+                self.client = OpenAI(api_key=config["api_key"],
+                                     base_url=config["base_url"])
-    def system_message(self, message: str) -> any:
-        print(f"system_content: {message}")
-        return {"role": "system", "content": message}
-
-    def user_message(self, message: str) -> any:
-        print(f"\nuser_content: {message}")
-        return {"role": "user", "content": message}
-
-    def assistant_message(self, message: str) -> any:
-        print(f"assistant_content: {message}")
-        return {"role": "assistant", "content": message}
-
-    def submit_prompt(self, prompt, **kwargs) -> str:
-        if prompt is None:
-            raise Exception("Prompt is None")
-
-        if len(prompt) == 0:
-            raise Exception("Prompt is empty")
-
-        # Count the number of tokens in the message log
-        # Use 4 as an approximation for the number of characters per token
-        num_tokens = 0
-        for message in prompt:
-            num_tokens += len(message["content"]) / 4
-
-        # Read the enable_thinking setting from kwargs or config
-        # Prefer the value passed via kwargs; otherwise read it from config; default is False
-        enable_thinking = kwargs.get("enable_thinking", self.config.get("enable_thinking", False))
-
-        # Common request parameters
-        common_params = {
-            "messages": prompt,
-            "stop": None,
-            "temperature": self.temperature,
-        }
-
-        # If thinking is enabled, use streaming, but do not pass enable_thinking directly
-        if enable_thinking:
-            common_params["stream"] = True
-            # The Qwen API does not accept enable_thinking as a parameter; it may need to be
-            # passed via a header or some other mechanism
-            # It may also just be that thinking is enabled by default when stream=True
-
-        model = None
-        # Determine which model to use
-        if kwargs.get("model", None) is not None:
-            model = kwargs.get("model", None)
-            common_params["model"] = model
-        elif kwargs.get("engine", None) is not None:
-            engine = kwargs.get("engine", None)
-            common_params["engine"] = engine
-            model = engine
-        elif self.config is not None and "engine" in self.config:
-            common_params["engine"] = self.config["engine"]
-            model = self.config["engine"]
-        elif self.config is not None and "model" in self.config:
-            common_params["model"] = self.config["model"]
-            model = self.config["model"]
-        else:
-            if num_tokens > 3500:
-                model = "qwen-long"
-            else:
-                model = "qwen-plus"
-            common_params["model"] = model
-
-        print(f"\nUsing model {model} for {num_tokens} tokens (approx)")
-
-        if enable_thinking:
-            # Streaming mode
-            print("Using streaming mode with thinking enabled")
-
-            # Check whether enable_thinking needs to be passed via headers
-            response_stream = self.client.chat.completions.create(**common_params)
-
-            # Collect the streamed response
-            collected_thinking = []
-            collected_content = []
-
-            for chunk in response_stream:
-                # Handle the thinking portion
-                if hasattr(chunk, 'thinking') and chunk.thinking:
-                    collected_thinking.append(chunk.thinking)
+    def system_message(self, message: str) -> any:
+        print(f"system_content: {message}")
+        return {"role": "system", "content": message}
+
+    def user_message(self, message: str) -> any:
+        print(f"\nuser_content: {message}")
+        return {"role": "user", "content": message}
+
+    def assistant_message(self, message: str) -> any:
+        print(f"assistant_content: {message}")
+        return {"role": "assistant", "content": message}
+
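+    # A prompt is a list of the message dicts built above, e.g. (illustrative only):
+    #   [self.system_message("You write SQL."), self.user_message("How many orders?")]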
+    def submit_prompt(self, prompt, **kwargs) -> str:
+        if prompt is None:
+            raise Exception("Prompt is None")
+
+        if len(prompt) == 0:
+            raise Exception("Prompt is empty")
+
+        # Count the number of tokens in the message log
+        # Use 4 as an approximation for the number of characters per token
+        num_tokens = 0
+        for message in prompt:
+            num_tokens += len(message["content"]) / 4
+
+        # Read the enable_thinking setting from kwargs or config
+        # Prefer the value passed via kwargs; otherwise read it from config; default is False
+        enable_thinking = kwargs.get("enable_thinking", self.config.get("enable_thinking", False))
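+        # e.g. submit_prompt(prompt, enable_thinking=True) takes precedence over
+        # config={"enable_thinking": True} (illustrative)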
-                # Handle the content portion
-                if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content:
-                    collected_content.append(chunk.choices[0].delta.content)
-
-            # Display handling for the thinking output can go here, e.g. saving it to a log
-            if collected_thinking:
-                print("Model thinking process:", "".join(collected_thinking))
-
-            # Return the full content
-            return "".join(collected_content)
-        else:
-            # Non-streaming mode
-            print("Using non-streaming mode")
-            response = self.client.chat.completions.create(**common_params)
-
-            # Find the first response from the chatbot that has text in it (some responses may not have text)
-            for choice in response.choices:
-                if "text" in choice:
-                    return choice.text
-
-            # If no response with text is found, return the first response's content (which may be empty)
-            return response.choices[0].message.content
-
-    # Fixes the issue that questions generated from a SQL statement come back in English.
-    def generate_question(self, sql: str, **kwargs) -> str:
-        # The prompt/logic can be customized here
-        prompt = [
-            self.system_message(
-                "请你根据下方SQL语句推测用户的业务提问,只返回清晰的自然语言问题,不要包含任何解释或SQL内容,也不要出现表名,问题要使用中文,并以问号结尾。"
-            ),
-            self.user_message(sql)
-        ]
-        response = self.submit_prompt(prompt, **kwargs)
-        # The response can also be post-processed here
-        return response
+        # Common request parameters
+        common_params = {
+            "messages": prompt,
+            "stop": None,
+            "temperature": self.temperature,
+        }
+
+        # If thinking is enabled, use streaming, but do not pass enable_thinking directly
+        if enable_thinking:
+            common_params["stream"] = True
+            # The Qwen API does not accept enable_thinking as a parameter; it may need to be
+            # passed via a header or some other mechanism
+            # It may also just be that thinking is enabled by default when stream=True
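+            # One unverified possibility is the OpenAI SDK's extra_body passthrough, e.g.:
+            #   self.client.chat.completions.create(**common_params, extra_body={"enable_thinking": True})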
+
+        model = None
+        # Determine which model to use
+        if kwargs.get("model", None) is not None:
+            model = kwargs.get("model", None)
+            common_params["model"] = model
+        elif kwargs.get("engine", None) is not None:
+            engine = kwargs.get("engine", None)
+            common_params["engine"] = engine
+            model = engine
+        elif self.config is not None and "engine" in self.config:
+            common_params["engine"] = self.config["engine"]
+            model = self.config["engine"]
+        elif self.config is not None and "model" in self.config:
+            common_params["model"] = self.config["model"]
+            model = self.config["model"]
+        else:
+            if num_tokens > 3500:
+                model = "qwen-long"
+            else:
+                model = "qwen-plus"
+            common_params["model"] = model
+
+        print(f"\nUsing model {model} for {num_tokens} tokens (approx)")
+
+        if enable_thinking:
+            # Streaming mode
+            print("Using streaming mode with thinking enabled")
+
+            # Check whether enable_thinking needs to be passed via headers
+            response_stream = self.client.chat.completions.create(**common_params)
+
+            # Collect the streamed response
+            collected_thinking = []
+            collected_content = []
+
+            for chunk in response_stream:
+                # Handle the thinking portion
+                if hasattr(chunk, 'thinking') and chunk.thinking:
+                    collected_thinking.append(chunk.thinking)
+
+                # Handle the content portion
+                if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content:
+                    collected_content.append(chunk.choices[0].delta.content)
+
+            # Display handling for the thinking output can go here, e.g. saving it to a log
+            if collected_thinking:
+                print("Model thinking process:", "".join(collected_thinking))
+
+            # Return the full content
+            return "".join(collected_content)
+        else:
+            # Non-streaming mode
+            print("Using non-streaming mode")
+            response = self.client.chat.completions.create(**common_params)
+
+            # Find the first response from the chatbot that has text in it (some responses may not have text)
+            for choice in response.choices:
+                if "text" in choice:
+                    return choice.text
+
+            # If no response with text is found, return the first response's content (which may be empty)
+            return response.choices[0].message.content
+
+    # Override generate_sql to add exception handling
+    def generate_sql(self, question: str, **kwargs) -> str:
+        """
+        Override the parent class's generate_sql and add exception handling.
+        """
+        try:
+            print(f"[DEBUG] Attempting to generate SQL for question: {question}")
+            # Call the parent class's generate_sql
+            sql = super().generate_sql(question, **kwargs)
+
+            if not sql or sql.strip() == "":
+                print("[WARNING] Generated SQL is empty")
+                return None
+
+            # Check whether the returned content is valid SQL or an error message
+            sql_lower = sql.lower().strip()
+
+            # Check for error-indicator phrases
+            error_indicators = [
+                "insufficient context", "无法生成", "sorry", "cannot", "不能",
+                "no relevant", "no suitable", "unable to", "无法", "抱歉",
+                "i don't have", "i cannot", "没有相关", "找不到", "不存在"
+            ]
+
+            for indicator in error_indicators:
+                if indicator in sql_lower:
+                    print(f"[WARNING] LLM returned an error message instead of SQL: {sql}")
+                    return None
+
+            # Basic sanity check that the text looks like SQL (contains at least one SQL keyword)
+            sql_keywords = ["select", "insert", "update", "delete", "with", "from", "where"]
+            if not any(keyword in sql_lower for keyword in sql_keywords):
+                print(f"[WARNING] Returned content does not look like valid SQL: {sql}")
+                return None
+
+            print(f"[SUCCESS] Generated SQL: {sql}")
+            return sql
+
+        except Exception as e:
+            print(f"[ERROR] Exception during SQL generation: {str(e)}")
+            print(f"[ERROR] Exception type: {type(e).__name__}")
+            # Import traceback locally to capture the full error details
+            import traceback
+            print(f"[ERROR] Full traceback: {traceback.format_exc()}")
+            # Return None instead of raising
+            return None
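+    # Callers should treat a None return as "no SQL generated" and can fall back to
+    # something like chat_with_llm(question) below (illustrative guidance)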
+
+    # Fixes the issue that questions generated from a SQL statement come back in English.
+    def generate_question(self, sql: str, **kwargs) -> str:
+        # The prompt/logic can be customized here
+        prompt = [
+            self.system_message(
+                "请你根据下方SQL语句推测用户的业务提问,只返回清晰的自然语言问题,不要包含任何解释或SQL内容,也不要出现表名,问题要使用中文,并以问号结尾。"
+            ),
+            self.user_message(sql)
+        ]
+        response = self.submit_prompt(prompt, **kwargs)
+        # The response can also be post-processed here
+        return response
+
+    # New: a method for chatting with the LLM directly
+    def chat_with_llm(self, question: str, **kwargs) -> str:
+        """
+        Chat with the LLM directly, without SQL generation.
+        """
+        try:
+            prompt = [
+                self.system_message(
+                    "你是一个友好的AI助手。如果用户询问的是数据库相关问题,请建议他们重新表述问题以便进行SQL查询。对于其他问题,请尽力提供有帮助的回答。"
+                ),
+                self.user_message(question)
+            ]
+            response = self.submit_prompt(prompt, **kwargs)
+            return response
+        except Exception as e:
+            print(f"[ERROR] LLM chat failed: {str(e)}")
+            return "抱歉,我暂时无法回答您的问题。请稍后再试。"