|
@@ -1,6 +1,8 @@
|
|
|
import os
|
|
|
from abc import ABC, abstractmethod
|
|
|
-from typing import List, Dict, Any, Optional
|
|
|
+from typing import List, Dict, Any, Optional, Union, Tuple
|
|
|
+import pandas as pd
|
|
|
+import plotly.graph_objs
|
|
|
from vanna.base import VannaBase
|
|
|
# 导入配置参数
|
|
|
from app_config import REWRITE_QUESTION_ENABLED, DISPLAY_SUMMARY_THINKING
|
|
@@ -11,6 +13,9 @@ class BaseLLMChat(VannaBase, ABC):
|
|
|
|
|
|
def __init__(self, config=None):
|
|
|
VannaBase.__init__(self, config=config)
|
|
|
+
|
|
|
+ # 存储LLM解释性文本
|
|
|
+ self.last_llm_explanation = None
|
|
|
|
|
|
print("传入的 config 参数如下:")
|
|
|
for key, value in self.config.items():
|
|
@@ -266,15 +271,19 @@ class BaseLLMChat(VannaBase, ABC):
|
|
|
|
|
|
def generate_sql(self, question: str, **kwargs) -> str:
|
|
|
"""
|
|
|
- 重写父类的 generate_sql 方法,增加异常处理
|
|
|
+ 重写父类的 generate_sql 方法,增加异常处理和解释性文本保存
|
|
|
"""
|
|
|
try:
|
|
|
+ # 清空上次的解释性文本
|
|
|
+ self.last_llm_explanation = None
|
|
|
+
|
|
|
print(f"[DEBUG] 尝试为问题生成SQL: {question}")
|
|
|
# 调用父类的 generate_sql
|
|
|
sql = super().generate_sql(question, **kwargs)
|
|
|
|
|
|
if not sql or sql.strip() == "":
|
|
|
print(f"[WARNING] 生成的SQL为空")
|
|
|
+ self.last_llm_explanation = "无法生成SQL查询,可能是问题描述不够清晰或缺少必要的数据表信息。"
|
|
|
return None
|
|
|
|
|
|
# 替换 "\_" 为 "_",解决特殊字符转义问题
|
|
@@ -293,15 +302,21 @@ class BaseLLMChat(VannaBase, ABC):
|
|
|
for indicator in error_indicators:
|
|
|
if indicator in sql_lower:
|
|
|
print(f"[WARNING] LLM返回错误信息而非SQL: {sql}")
|
|
|
+ # 保存LLM的解释性文本
|
|
|
+ self.last_llm_explanation = sql
|
|
|
return None
|
|
|
|
|
|
# 简单检查是否像SQL语句(至少包含一些SQL关键词)
|
|
|
sql_keywords = ["select", "insert", "update", "delete", "with", "from", "where"]
|
|
|
if not any(keyword in sql_lower for keyword in sql_keywords):
|
|
|
print(f"[WARNING] 返回内容不像有效SQL: {sql}")
|
|
|
+ # 保存LLM的解释性文本
|
|
|
+ self.last_llm_explanation = sql
|
|
|
return None
|
|
|
|
|
|
print(f"[SUCCESS] 成功生成SQL:\n {sql}")
|
|
|
+ # 清空解释性文本
|
|
|
+ self.last_llm_explanation = None
|
|
|
return sql
|
|
|
|
|
|
except Exception as e:
|
|
@@ -310,7 +325,7 @@ class BaseLLMChat(VannaBase, ABC):
|
|
|
# 导入traceback以获取详细错误信息
|
|
|
import traceback
|
|
|
print(f"[ERROR] 详细错误信息: {traceback.format_exc()}")
|
|
|
- # 返回 None 而不是抛出异常
|
|
|
+ self.last_llm_explanation = f"SQL生成过程中出现异常: {str(e)}"
|
|
|
return None
|
|
|
|
|
|
def generate_question(self, sql: str, **kwargs) -> str:
|
|
@@ -473,6 +488,130 @@ class BaseLLMChat(VannaBase, ABC):
|
|
|
cleaned_text = cleaned_text.strip()
|
|
|
|
|
|
return cleaned_text
|
|
|
+
|
|
|
+
|
|
|
+ def ask(
|
|
|
+ self,
|
|
|
+ question: Union[str, None] = None,
|
|
|
+ print_results: bool = True,
|
|
|
+ auto_train: bool = True,
|
|
|
+ visualize: bool = True,
|
|
|
+ allow_llm_to_see_data: bool = False,
|
|
|
+ ) -> Union[
|
|
|
+ Tuple[
|
|
|
+ Union[str, None],
|
|
|
+ Union[pd.DataFrame, None],
|
|
|
+ Union[plotly.graph_objs.Figure, None],
|
|
|
+ ],
|
|
|
+ None,
|
|
|
+ ]:
|
|
|
+ """
|
|
|
+ 重载父类的ask方法,处理LLM解释性文本
|
|
|
+ 当generate_sql无法生成SQL时,保存解释性文本供API层使用
|
|
|
+ """
|
|
|
+ if question is None:
|
|
|
+ question = input("Enter a question: ")
|
|
|
+
|
|
|
+ # 清空上次的解释性文本
|
|
|
+ self.last_llm_explanation = None
|
|
|
+
|
|
|
+ try:
|
|
|
+ sql = self.generate_sql(question=question, allow_llm_to_see_data=allow_llm_to_see_data)
|
|
|
+ except Exception as e:
|
|
|
+ print(e)
|
|
|
+ self.last_llm_explanation = str(e)
|
|
|
+ if print_results:
|
|
|
+ return None
|
|
|
+ else:
|
|
|
+ return None, None, None
|
|
|
+
|
|
|
+ # 如果SQL为空,说明有解释性文本,按照正常流程返回None
|
|
|
+ # API层会检查 last_llm_explanation 来获取解释
|
|
|
+ if sql is None:
|
|
|
+ print(f"[INFO] 无法生成SQL,解释: {self.last_llm_explanation}")
|
|
|
+ if print_results:
|
|
|
+ return None
|
|
|
+ else:
|
|
|
+ return None, None, None
|
|
|
+
|
|
|
+ # 以下是正常的SQL执行流程(保持VannaBase原有逻辑)
|
|
|
+ if print_results:
|
|
|
+ print(sql)
|
|
|
+
|
|
|
+ if self.run_sql_is_set is False:
|
|
|
+ print("If you want to run the SQL query, connect to a database first.")
|
|
|
+ if print_results:
|
|
|
+ return None
|
|
|
+ else:
|
|
|
+ return sql, None, None
|
|
|
+
|
|
|
+ try:
|
|
|
+ df = self.run_sql(sql)
|
|
|
+
|
|
|
+ if df is None:
|
|
|
+ print("The SQL query returned no results.")
|
|
|
+ if print_results:
|
|
|
+ return None
|
|
|
+ else:
|
|
|
+ return sql, None, None
|
|
|
+
|
|
|
+ if print_results:
|
|
|
+ # 显示结果表格
|
|
|
+ if len(df) > 10:
|
|
|
+ print(df.head(10).to_string())
|
|
|
+ print(f"... ({len(df)} rows)")
|
|
|
+ else:
|
|
|
+ print(df.to_string())
|
|
|
+
|
|
|
+ # 如果启用了自动训练,添加问题-SQL对到训练集
|
|
|
+ if auto_train:
|
|
|
+ try:
|
|
|
+ self.add_question_sql(question=question, sql=sql)
|
|
|
+ except Exception as e:
|
|
|
+ print(f"Could not add question and sql to training data: {e}")
|
|
|
+
|
|
|
+ if visualize:
|
|
|
+ try:
|
|
|
+ # 检查是否应该生成图表
|
|
|
+ if self.should_generate_chart(df):
|
|
|
+ plotly_code = self.generate_plotly_code(
|
|
|
+ question=question,
|
|
|
+ sql=sql,
|
|
|
+ df=df,
|
|
|
+ chart_instructions=""
|
|
|
+ )
|
|
|
+ if plotly_code is not None and plotly_code.strip() != "":
|
|
|
+ fig = self.get_plotly_figure(
|
|
|
+ plotly_code=plotly_code,
|
|
|
+ df=df,
|
|
|
+ dark_mode=False
|
|
|
+ )
|
|
|
+ if fig is not None:
|
|
|
+ if print_results:
|
|
|
+ print("Chart generated (use fig.show() to display)")
|
|
|
+ return sql, df, fig
|
|
|
+ else:
|
|
|
+ print("Could not generate chart")
|
|
|
+ return sql, df, None
|
|
|
+ else:
|
|
|
+ print("No chart generated")
|
|
|
+ return sql, df, None
|
|
|
+ else:
|
|
|
+ print("Not generating chart for this data")
|
|
|
+ return sql, df, None
|
|
|
+ except Exception as e:
|
|
|
+ print(f"Couldn't generate chart: {e}")
|
|
|
+ return sql, df, None
|
|
|
+ else:
|
|
|
+ return sql, df, None
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ print("Couldn't run sql: ", e)
|
|
|
+ if print_results:
|
|
|
+ return None
|
|
|
+ else:
|
|
|
+ return sql, None, None
|
|
|
+
|
|
|
|
|
|
@abstractmethod
|
|
|
def submit_prompt(self, prompt, **kwargs) -> str:
|
|
@@ -486,4 +625,5 @@ class BaseLLMChat(VannaBase, ABC):
|
|
|
Returns:
|
|
|
str: LLM的响应
|
|
|
"""
|
|
|
- pass
|
|
|
+ pass
|
|
|
+
|