@@ -1,58 +1,29 @@
 import os
-from openai import OpenAI
+from abc import ABC, abstractmethod
+from typing import List, Dict, Any, Optional
 from vanna.base import VannaBase
 # Import configuration parameters
 from app_config import REWRITE_QUESTION_ENABLED


-class QianWenAI_Chat(VannaBase):
-    def __init__(self, client=None, config=None):
-        print("...QianWenAI_Chat init...")
+class BaseLLMChat(VannaBase, ABC):
+    """Custom LLM chat base class that holds the shared methods."""
+
+    def __init__(self, config=None):
         VannaBase.__init__(self, config=config)
-
+
         print("The config parameters passed in are:")
         for key, value in self.config.items():
             print(f"  {key}: {value}")
-
-        # default parameters - can be overrided using config
+
+        # Default parameters
         self.temperature = 0.7
-
+
         if "temperature" in config:
             print(f"temperature is changed to: {config['temperature']}")
             self.temperature = config["temperature"]
-
-        if "api_type" in config:
-            raise Exception(
-                "Passing api_type is now deprecated. Please pass an OpenAI client instead."
-            )
-
-        if "api_base" in config:
-            raise Exception(
-                "Passing api_base is now deprecated. Please pass an OpenAI client instead."
-            )
-
-        if "api_version" in config:
-            raise Exception(
-                "Passing api_version is now deprecated. Please pass an OpenAI client instead."
-            )
-
-        if client is not None:
-            self.client = client
-            return
-
-        if config is None and client is None:
-            self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
-            return
-
-        if "api_key" in config:
-            if "base_url" not in config:
-                self.client = OpenAI(api_key=config["api_key"],
-                                     base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")
-            else:
-                self.client = OpenAI(api_key=config["api_key"],
-                                     base_url=config["base_url"])

-        # New: load the error-SQL prompt configuration
+        # Load the error-SQL prompt configuration
         self.enable_error_sql_prompt = self._load_error_sql_prompt_config()

     def _load_error_sql_prompt_config(self) -> bool:
@@ -65,8 +36,22 @@ class QianWenAI_Chat(VannaBase):
         except (ImportError, AttributeError) as e:
             print(f"[WARNING] Could not load the error-SQL prompt config: {e}; using default False")
             return False
-
-    # Use Chinese aliases when generating SQL - implemented directly on top of the VannaBase source
+
+    def system_message(self, message: str) -> dict:
+        """Build a system message in the chat format."""
+        print(f"system_content: {message}")
+        return {"role": "system", "content": message}
+
+    def user_message(self, message: str) -> dict:
+        """Build a user message in the chat format."""
+        print(f"\nuser_content: {message}")
+        return {"role": "user", "content": message}
+
+    def assistant_message(self, message: str) -> dict:
+        """Build an assistant message in the chat format."""
+        print(f"assistant_content: {message}")
+        return {"role": "assistant", "content": message}
+
     def get_sql_prompt(self, initial_prompt: str, question: str, question_sql_list: list, ddl_list: list, doc_list: list, **kwargs):
         """
         Based on the VannaBase source, with a Chinese-alias instruction added at point 7
@@ -160,7 +145,6 @@ class QianWenAI_Chat(VannaBase):

         return message_log

-    # Use Chinese labels when generating charts
     def generate_plotly_code(self, question: str = None, sql: str = None, df_metadata: str = None, **kwargs) -> str:
         """
         Overrides the parent method, adding explicit Chinese chart instructions
@@ -222,22 +206,45 @@ class QianWenAI_Chat(VannaBase):
             self.user_message(user_msg),
         ]

-        # Call the parent submit_prompt method and clean up the result
-        plotly_code = self.submit_prompt(message_log, kwargs=kwargs)
+        # Call submit_prompt and clean up the result
+        plotly_code = self.submit_prompt(message_log, **kwargs)

         return self._sanitize_plotly_code(self._extract_python_code(plotly_code))
-
-    def system_message(self, message: str) -> any:
-        print(f"system_content: {message}")
-        return {"role": "system", "content": message}

-    def user_message(self, message: str) -> any:
-        print(f"\nuser_content: {message}")
-        return {"role": "user", "content": message}
+    def _extract_python_code(self, response: str) -> str:
+        """Extract Python code from an LLM response."""
+        if not response:
+            return ""
+
+        # Look for code blocks
+        import re
+
+        # Match fenced blocks opened with ```python or a bare ```
+        code_pattern = r'```(?:python)?\s*(.*?)```'
+        matches = re.findall(code_pattern, response, re.DOTALL)
+
+        if matches:
+            return matches[0].strip()
+
+        # If no code block was found, return the raw response
+        return response.strip()

-    def assistant_message(self, message: str) -> any:
-        print(f"assistant_content: {message}")
-        return {"role": "assistant", "content": message}
+    def _sanitize_plotly_code(self, code: str) -> str:
+        """Clean up and validate the Plotly code."""
+        if not code:
+            return ""
+
+        # Basic code cleanup
+        lines = code.split('\n')
+        cleaned_lines = []
+
+        for line in lines:
+            # Drop blank lines and comment-only lines, preserving indentation
+            stripped = line.strip()
+            if stripped and not stripped.startswith('#'):
+                cleaned_lines.append(line.rstrip())
+
+        return '\n'.join(cleaned_lines)

     def should_generate_chart(self, df) -> bool:
         """
@@ -257,127 +264,6 @@ class QianWenAI_Chat(VannaBase):

         return False

-    # def get_plotly_figure(self, plotly_code: str, df, dark_mode: bool = True):
-    #     """
-    #     Override the parent method so that the Flask app also uses our custom chart-generation logic
-    #     This method is called by VannaFlaskApp instead of generate_plotly_code
-    #     """
-    #     print(f"[DEBUG] get_plotly_figure called, plotly_code length: {len(plotly_code) if plotly_code else 0}")
-
-    #     # If no plotly_code was provided, try to generate one
-    #     if not plotly_code or plotly_code.strip() == "":
-    #         print(f"[DEBUG] plotly_code is empty, trying to generate a default chart")
-    #         # Generate a simple default chart
-    #         df_metadata = f"DataFrame shape: {df.shape}\nColumns: {list(df.columns)}\nDtypes:\n{df.dtypes}"
-    #         plotly_code = self.generate_plotly_code(
-    #             question="Data visualization",
-    #             sql=None,
-    #             df_metadata=df_metadata
-    #         )
-
-    #     # Call the parent method to execute the plotly code
-    #     try:
-    #         return super().get_plotly_figure(plotly_code=plotly_code, df=df, dark_mode=dark_mode)
-    #     except Exception as e:
-    #         print(f"[ERROR] Failed to execute the plotly code: {e}")
-    #         print(f"[ERROR] plotly_code: {plotly_code}")
-    #         # If execution fails, return None or build a simple fallback chart
-    #         return None
-
-    def submit_prompt(self, prompt, **kwargs) -> str:
-        if prompt is None:
-            raise Exception("Prompt is None")
-
-        if len(prompt) == 0:
-            raise Exception("Prompt is empty")
-
-        # Count the number of tokens in the message log
-        # Use 4 as an approximation for the number of characters per token
-        num_tokens = 0
-        for message in prompt:
-            num_tokens += len(message["content"]) / 4
-
-        # Read the enable_thinking setting from kwargs or the config
-        # A value passed via kwargs wins; otherwise fall back to the config, defaulting to False
-        enable_thinking = kwargs.get("enable_thinking", self.config.get("enable_thinking", False))
-
-        # Common parameters
-        common_params = {
-            "messages": prompt,
-            "stop": None,
-            "temperature": self.temperature,
-        }
-
-        # If thinking is enabled, use streaming, but do not pass enable_thinking through directly
-        if enable_thinking:
-            common_params["stream"] = True
-            # The Qianwen API does not accept enable_thinking as a parameter; it may have to be passed via a header,
-            # or it may simply be the default thinking behavior when stream=True
-
-        model = None
-        # Determine which model to use
-        if kwargs.get("model", None) is not None:
-            model = kwargs.get("model", None)
-            common_params["model"] = model
-        elif kwargs.get("engine", None) is not None:
-            engine = kwargs.get("engine", None)
-            common_params["engine"] = engine
-            model = engine
-        elif self.config is not None and "engine" in self.config:
-            common_params["engine"] = self.config["engine"]
-            model = self.config["engine"]
-        elif self.config is not None and "model" in self.config:
-            common_params["model"] = self.config["model"]
-            model = self.config["model"]
-        else:
-            if num_tokens > 3500:
-                model = "qwen-long"
-            else:
-                model = "qwen-plus"
-            common_params["model"] = model
-
-        print(f"\nUsing model {model} for {num_tokens} tokens (approx)")
-
-        if enable_thinking:
-            # Streaming mode
-            print("Using streaming mode with thinking enabled")
-
-            # Check whether enable_thinking needs to be passed via headers
-            response_stream = self.client.chat.completions.create(**common_params)
-
-            # Collect the streamed response
-            collected_thinking = []
-            collected_content = []
-
-            for chunk in response_stream:
-                # Handle the thinking part
-                if hasattr(chunk, 'thinking') and chunk.thinking:
-                    collected_thinking.append(chunk.thinking)
-
-                # Handle the content part
-                if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content:
-                    collected_content.append(chunk.choices[0].delta.content)
-
-            # The thinking output could be surfaced here, e.g. saved to a log
-            if collected_thinking:
-                print("Model thinking process:", "".join(collected_thinking))
-
-            # Return the full content
-            return "".join(collected_content)
-        else:
-            # Non-streaming mode
-            print("Using non-streaming mode")
-            response = self.client.chat.completions.create(**common_params)
-
-            # Find the first response from the chatbot that has text in it (some responses may not have text)
-            for choice in response.choices:
-                if "text" in choice:
-                    return choice.text
-
-            # If no response with text is found, return the first response's content (which may be empty)
-            return response.choices[0].message.content
-
-    # Override generate_sql to add error handling
     def generate_sql(self, question: str, **kwargs) -> str:
         """
         Overrides the parent generate_sql method to add error handling
@@ -391,6 +277,9 @@ class QianWenAI_Chat(VannaBase):
             print(f"[WARNING] The generated SQL is empty")
             return None

+        # Replace "\_" with "_" to fix escaped special characters
+        sql = sql.replace("\\_", "_")
+
         # Check whether the returned content is valid SQL or an error message
         sql_lower = sql.lower().strip()

@@ -424,9 +313,8 @@ class QianWenAI_Chat(VannaBase):
             # Return None instead of raising an exception
             return None

-    # Added so that questions generated from SQL do not come back in English.
     def generate_question(self, sql: str, **kwargs) -> str:
-        # Custom prompt logic can go here
+        """Generate a Chinese question from a SQL statement."""
         prompt = [
             self.system_message(
                 "请你根据下方SQL语句推测用户的业务提问,只返回清晰的自然语言问题,不要包含任何解释或SQL内容,也不要出现表名,问题要使用中文,并以问号结尾。"
@@ -434,10 +322,8 @@ class QianWenAI_Chat(VannaBase):
             self.user_message(sql)
         ]
         response = self.submit_prompt(prompt, **kwargs)
-        # Post-processing of the response could be done here
         return response
-
-    # New: method for chatting with the LLM directly
+
     def chat_with_llm(self, question: str, **kwargs) -> str:
         """
         Talk to the LLM directly, without any SQL generation
@@ -493,4 +379,18 @@ class QianWenAI_Chat(VannaBase):
         except Exception as e:
             print(f"[ERROR] Failed to merge questions: {str(e)}")
             # If merging fails, return the new question
-            return new_question
+            return new_question
+
+    @abstractmethod
+    def submit_prompt(self, prompt, **kwargs) -> str:
+        """
+        Core submission method that subclasses must implement
+
+        Args:
+            prompt: list of messages
+            **kwargs: additional parameters
+
+        Returns:
+            str: the LLM's response
+        """
+        pass
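
A minimal sketch of what a concrete subclass might look like, assuming an OpenAI-compatible client wired up the way the removed QianWenAI_Chat code was; the class name and config keys below are illustrative, not part of this diff:

    import os
    from openai import OpenAI

    class ExampleQianWenChat(BaseLLMChat):
        """Hypothetical subclass binding BaseLLMChat to an OpenAI-compatible endpoint."""

        def __init__(self, config=None):
            config = config or {}  # the base class iterates over config, so default to an empty dict
            super().__init__(config=config)
            # Assumed config keys: api_key, base_url, model
            self.client = OpenAI(
                api_key=config.get("api_key", os.getenv("OPENAI_API_KEY")),
                base_url=config.get("base_url", "https://dashscope.aliyuncs.com/compatible-mode/v1"),
            )
            self.model = config.get("model", "qwen-plus")

        def submit_prompt(self, prompt, **kwargs) -> str:
            # prompt is the message list built by system_message()/user_message()
            response = self.client.chat.completions.create(
                model=kwargs.get("model", self.model),
                messages=prompt,
                temperature=self.temperature,
            )
            return response.choices[0].message.content

With submit_prompt as the only abstract method, generate_sql, generate_plotly_code, generate_question, and chat_with_llm on the base class all route through this single provider binding.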