base_llm_chat.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636
  1. import os
  2. from abc import ABC, abstractmethod
  3. from typing import List, Dict, Any, Optional, Union, Tuple
  4. import pandas as pd
  5. import plotly.graph_objs
  6. from vanna.base import VannaBase
  7. # 导入配置参数
  8. from app_config import REWRITE_QUESTION_ENABLED, DISPLAY_SUMMARY_THINKING
  9. class BaseLLMChat(VannaBase, ABC):
  10. """自定义LLM聊天基类,包含公共方法"""
  11. def __init__(self, config=None):
  12. VannaBase.__init__(self, config=config)
  13. # 存储LLM解释性文本
  14. self.last_llm_explanation = None
  15. print("传入的 config 参数如下:")
  16. for key, value in self.config.items():
  17. print(f" {key}: {value}")
  18. # 默认参数
  19. self.temperature = 0.7
  20. if "temperature" in config:
  21. print(f"temperature is changed to: {config['temperature']}")
  22. self.temperature = config["temperature"]
  23. # 加载错误SQL提示配置
  24. self.enable_error_sql_prompt = self._load_error_sql_prompt_config()
  25. def _load_error_sql_prompt_config(self) -> bool:
  26. """从app_config.py加载错误SQL提示配置"""
  27. try:
  28. import app_config
  29. enable_error_sql = getattr(app_config, 'ENABLE_ERROR_SQL_PROMPT', False)
  30. print(f"[DEBUG] 错误SQL提示配置: ENABLE_ERROR_SQL_PROMPT = {enable_error_sql}")
  31. return enable_error_sql
  32. except (ImportError, AttributeError) as e:
  33. print(f"[WARNING] 无法加载错误SQL提示配置: {e},使用默认值 False")
  34. return False
  35. def system_message(self, message: str) -> dict:
  36. """创建系统消息格式"""
  37. print(f"system_content: {message}")
  38. return {"role": "system", "content": message}
  39. def user_message(self, message: str) -> dict:
  40. """创建用户消息格式"""
  41. print(f"\nuser_content: {message}")
  42. return {"role": "user", "content": message}
  43. def assistant_message(self, message: str) -> dict:
  44. """创建助手消息格式"""
  45. print(f"assistant_content: {message}")
  46. return {"role": "assistant", "content": message}
  47. def get_sql_prompt(self, initial_prompt: str, question: str, question_sql_list: list, ddl_list: list, doc_list: list, **kwargs):
  48. """
  49. 基于VannaBase源码实现,在第7点添加中文别名指令
  50. """
  51. print(f"[DEBUG] 开始生成SQL提示词,问题: {question}")
  52. if initial_prompt is None:
  53. initial_prompt = f"You are a {self.dialect} expert. " + \
  54. "Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions."
  55. # 提取DDL内容(适配新的字典格式)
  56. ddl_content_list = []
  57. if ddl_list:
  58. for item in ddl_list:
  59. if isinstance(item, dict) and "content" in item:
  60. ddl_content_list.append(item["content"])
  61. elif isinstance(item, str):
  62. ddl_content_list.append(item)
  63. initial_prompt = self.add_ddl_to_prompt(
  64. initial_prompt, ddl_content_list, max_tokens=self.max_tokens
  65. )
  66. # 提取文档内容(适配新的字典格式)
  67. doc_content_list = []
  68. if doc_list:
  69. for item in doc_list:
  70. if isinstance(item, dict) and "content" in item:
  71. doc_content_list.append(item["content"])
  72. elif isinstance(item, str):
  73. doc_content_list.append(item)
  74. if self.static_documentation != "":
  75. doc_content_list.append(self.static_documentation)
  76. initial_prompt = self.add_documentation_to_prompt(
  77. initial_prompt, doc_content_list, max_tokens=self.max_tokens
  78. )
  79. # 新增:添加错误SQL示例作为负面示例(放在Response Guidelines之前)
  80. if self.enable_error_sql_prompt:
  81. try:
  82. error_sql_list = self.get_related_error_sql(question, **kwargs)
  83. if error_sql_list:
  84. print(f"[DEBUG] 找到 {len(error_sql_list)} 个相关的错误SQL示例")
  85. # 构建格式化的负面提示内容
  86. negative_prompt_content = "===Negative Examples\n"
  87. negative_prompt_content += "下面是错误的SQL示例,请分析这些错误SQL的问题所在,并在生成新SQL时避免类似错误:\n\n"
  88. for i, error_example in enumerate(error_sql_list, 1):
  89. if "question" in error_example and "sql" in error_example:
  90. similarity = error_example.get('similarity', 'N/A')
  91. print(f"[DEBUG] 错误SQL示例 {i}: 相似度={similarity}")
  92. negative_prompt_content += f"问题: {error_example['question']}\n"
  93. negative_prompt_content += f"错误的SQL: {error_example['sql']}\n\n"
  94. # 将负面提示添加到初始提示中
  95. initial_prompt += negative_prompt_content
  96. else:
  97. print("[DEBUG] 未找到相关的错误SQL示例")
  98. except Exception as e:
  99. print(f"[WARNING] 获取错误SQL示例失败: {e}")
  100. initial_prompt += (
  101. "===Response Guidelines \n"
  102. "1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n"
  103. "2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \n"
  104. "3. If the provided context is insufficient, please explain why it can't be generated. \n"
  105. "4. Please use the most relevant table(s). \n"
  106. "5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n"
  107. f"6. Ensure that the output SQL is {self.dialect}-compliant and executable, and free of syntax errors. \n"
  108. "7. 在生成 SQL 查询时,如果出现 ORDER BY 子句,请遵循以下规则:\n"
  109. " - 对所有的排序字段(如聚合字段 SUM()、普通列等),请在 ORDER BY 中显式添加 NULLS LAST。\n"
  110. " - 不论是否使用 LIMIT,只要排序字段存在,都必须添加 NULLS LAST,以防止 NULL 排在结果顶部。\n"
  111. " - 示例参考:\n"
  112. " - ORDER BY total DESC NULLS LAST\n"
  113. " - ORDER BY zf_order DESC NULLS LAST\n"
  114. " - ORDER BY SUM(c.customer_count) DESC NULLS LAST \n"
  115. "8. 【重要】请在SQL查询中为所有SELECT的列都使用中文别名:\n"
  116. " - 每个列都必须使用 AS 中文别名 的格式,没有例外\n"
  117. " - 包括原始字段名也要添加中文别名,例如:SELECT gender AS 性别, card_category AS 卡片类型\n"
  118. " - 计算字段也要有中文别名,例如:SELECT COUNT(*) AS 持卡人数\n"
  119. " - 中文别名要准确反映字段的业务含义"
  120. )
  121. message_log = [self.system_message(initial_prompt)]
  122. for example in question_sql_list:
  123. if example is None:
  124. print("example is None")
  125. else:
  126. if example is not None and "question" in example and "sql" in example:
  127. message_log.append(self.user_message(example["question"]))
  128. message_log.append(self.assistant_message(example["sql"]))
  129. message_log.append(self.user_message(question))
  130. return message_log
  131. def generate_plotly_code(self, question: str = None, sql: str = None, df_metadata: str = None, **kwargs) -> str:
  132. """
  133. 重写父类方法,添加明确的中文图表指令
  134. """
  135. # 构建更智能的中文图表指令,根据问题和数据内容生成有意义的标签
  136. chinese_chart_instructions = (
  137. "使用中文创建图表,要求:\n"
  138. "1. 根据用户问题和数据内容,为图表生成有意义的中文标题\n"
  139. "2. 根据数据列的实际含义,为X轴和Y轴生成准确的中文标签\n"
  140. "3. 如果有图例,确保图例标签使用中文\n"
  141. "4. 所有文本(包括标题、轴标签、图例、数据标签等)都必须使用中文\n"
  142. "5. 标题应该简洁明了地概括图表要展示的内容\n"
  143. "6. 轴标签应该准确反映对应数据列的业务含义\n"
  144. "7. 选择最适合数据特点的图表类型(柱状图、折线图、饼图等)"
  145. )
  146. # 构建父类方法要求的message_log
  147. system_msg_parts = []
  148. if question:
  149. system_msg_parts.append(
  150. f"用户问题:'{question}'"
  151. )
  152. system_msg_parts.append(
  153. f"以下是回答用户问题的pandas DataFrame数据:"
  154. )
  155. else:
  156. system_msg_parts.append("以下是一个pandas DataFrame数据:")
  157. if sql:
  158. system_msg_parts.append(f"数据来源SQL查询:\n{sql}")
  159. system_msg_parts.append(f"DataFrame结构信息:\n{df_metadata}")
  160. system_msg = "\n\n".join(system_msg_parts)
  161. # 构建更详细的用户消息,强调中文标签的重要性
  162. user_msg = (
  163. "请为这个DataFrame生成Python Plotly可视化代码。要求:\n\n"
  164. "1. 假设数据存储在名为'df'的pandas DataFrame中\n"
  165. "2. 如果DataFrame只有一个值,使用Indicator图表\n"
  166. "3. 只返回Python代码,不要任何解释\n"
  167. "4. 代码必须可以直接运行\n\n"
  168. f"{chinese_chart_instructions}\n\n"
  169. "特别注意:\n"
  170. "- 不要使用'图表标题'、'X轴标签'、'Y轴标签'这样的通用标签\n"
  171. "- 要根据实际数据内容和用户问题生成具体、有意义的中文标签\n"
  172. "- 例如:如果是性别统计,X轴可能是'性别',Y轴可能是'人数'或'占比'\n"
  173. "- 标题应该概括图表的主要内容,如'男女持卡比例分布'\n\n"
  174. "数据标签和悬停信息要求:\n"
  175. "- 不要使用%{text}这样的占位符变量\n"
  176. "- 使用具体的数据值和中文单位,例如:text=df['列名'].astype(str) + '人'\n"
  177. "- 悬停信息要清晰易懂,使用中文描述\n"
  178. "- 确保所有显示的文本都是实际的数据值,不是变量占位符"
  179. )
  180. message_log = [
  181. self.system_message(system_msg),
  182. self.user_message(user_msg),
  183. ]
  184. # 调用submit_prompt方法,并清理结果
  185. plotly_code = self.submit_prompt(message_log, **kwargs)
  186. return self._sanitize_plotly_code(self._extract_python_code(plotly_code))
  187. def _extract_python_code(self, response: str) -> str:
  188. """从LLM响应中提取Python代码"""
  189. if not response:
  190. return ""
  191. # 查找代码块
  192. import re
  193. # 匹配 ```python 或 ``` 代码块
  194. code_pattern = r'```(?:python)?\s*(.*?)```'
  195. matches = re.findall(code_pattern, response, re.DOTALL)
  196. if matches:
  197. return matches[0].strip()
  198. # 如果没有找到代码块,返回原始响应
  199. return response.strip()
  200. def _sanitize_plotly_code(self, code: str) -> str:
  201. """清理和验证Plotly代码"""
  202. if not code:
  203. return ""
  204. # 基本的代码清理
  205. lines = code.split('\n')
  206. cleaned_lines = []
  207. for line in lines:
  208. # 移除空行和注释行
  209. line = line.strip()
  210. if line and not line.startswith('#'):
  211. cleaned_lines.append(line)
  212. return '\n'.join(cleaned_lines)
  213. def should_generate_chart(self, df) -> bool:
  214. """
  215. 判断是否应该生成图表
  216. 对于Flask应用,这个方法决定了前端是否显示图表生成按钮
  217. """
  218. if df is None or df.empty:
  219. print(f"[DEBUG] should_generate_chart: df为空,返回False")
  220. return False
  221. # 如果数据有多行或多列,通常适合生成图表
  222. result = len(df) > 1 or len(df.columns) > 1
  223. print(f"[DEBUG] should_generate_chart: df.shape={df.shape}, 返回{result}")
  224. if result:
  225. return True
  226. return False
  227. def generate_sql(self, question: str, **kwargs) -> str:
  228. """
  229. 重写父类的 generate_sql 方法,增加异常处理和解释性文本保存
  230. """
  231. try:
  232. # 清空上次的解释性文本
  233. self.last_llm_explanation = None
  234. print(f"[DEBUG] 尝试为问题生成SQL: {question}")
  235. # 调用父类的 generate_sql
  236. sql = super().generate_sql(question, **kwargs)
  237. if not sql or sql.strip() == "":
  238. print(f"[WARNING] 生成的SQL为空")
  239. self.last_llm_explanation = "无法生成SQL查询,可能是问题描述不够清晰或缺少必要的数据表信息。"
  240. return None
  241. # 替换 "\_" 为 "_",解决特殊字符转义问题
  242. sql = sql.replace("\\_", "_")
  243. # 检查返回内容是否为有效SQL或错误信息
  244. sql_lower = sql.lower().strip()
  245. # 检查是否包含错误提示信息
  246. error_indicators = [
  247. "insufficient context", "无法生成", "sorry", "cannot", "不能",
  248. "no relevant", "no suitable", "unable to", "无法", "抱歉",
  249. "i don't have", "i cannot", "没有相关", "找不到", "不存在"
  250. ]
  251. for indicator in error_indicators:
  252. if indicator in sql_lower:
  253. print(f"[WARNING] LLM返回错误信息而非SQL: {sql}")
  254. # 保存LLM的解释性文本
  255. self.last_llm_explanation = sql
  256. return None
  257. # 简单检查是否像SQL语句(至少包含一些SQL关键词)
  258. sql_keywords = ["select", "insert", "update", "delete", "with", "from", "where"]
  259. if not any(keyword in sql_lower for keyword in sql_keywords):
  260. print(f"[WARNING] 返回内容不像有效SQL: {sql}")
  261. # 保存LLM的解释性文本
  262. self.last_llm_explanation = sql
  263. return None
  264. print(f"[SUCCESS] 成功生成SQL:\n {sql}")
  265. # 清空解释性文本
  266. self.last_llm_explanation = None
  267. return sql
  268. except Exception as e:
  269. print(f"[ERROR] SQL生成过程中出现异常: {str(e)}")
  270. print(f"[ERROR] 异常类型: {type(e).__name__}")
  271. # 导入traceback以获取详细错误信息
  272. import traceback
  273. print(f"[ERROR] 详细错误信息: {traceback.format_exc()}")
  274. self.last_llm_explanation = f"SQL生成过程中出现异常: {str(e)}"
  275. return None
  276. def generate_question(self, sql: str, **kwargs) -> str:
  277. """根据SQL生成中文问题"""
  278. prompt = [
  279. self.system_message(
  280. "请你根据下方SQL语句推测用户的业务提问,只返回清晰的自然语言问题,不要包含任何解释或SQL内容,也不要出现表名,问题要使用中文,并以问号结尾。"
  281. ),
  282. self.user_message(sql)
  283. ]
  284. response = self.submit_prompt(prompt, **kwargs)
  285. return response
  286. def chat_with_llm(self, question: str, **kwargs) -> str:
  287. """
  288. 直接与LLM对话,不涉及SQL生成
  289. """
  290. try:
  291. prompt = [
  292. self.system_message(
  293. "你是一个友好的AI助手。如果用户询问的是数据库相关问题,请建议他们重新表述问题以便进行SQL查询。对于其他问题,请尽力提供有帮助的回答。"
  294. ),
  295. self.user_message(question)
  296. ]
  297. response = self.submit_prompt(prompt, **kwargs)
  298. return response
  299. except Exception as e:
  300. print(f"[ERROR] LLM对话失败: {str(e)}")
  301. return f"抱歉,我暂时无法回答您的问题。请稍后再试。"
  302. def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str:
  303. """
  304. 重写问题合并方法,通过配置参数控制是否启用合并功能
  305. Args:
  306. last_question (str): 上一个问题
  307. new_question (str): 新问题
  308. **kwargs: 其他参数
  309. Returns:
  310. str: 如果启用合并且问题相关则返回合并后的问题,否则返回新问题
  311. """
  312. # 如果未启用合并功能或没有上一个问题,直接返回新问题
  313. if not REWRITE_QUESTION_ENABLED or last_question is None:
  314. print(f"[DEBUG] 问题合并功能{'未启用' if not REWRITE_QUESTION_ENABLED else '上一个问题为空'},直接返回新问题")
  315. return new_question
  316. print(f"[DEBUG] 启用问题合并功能,尝试合并问题")
  317. print(f"[DEBUG] 上一个问题: {last_question}")
  318. print(f"[DEBUG] 新问题: {new_question}")
  319. try:
  320. prompt = [
  321. self.system_message(
  322. "你的目标是将一系列相关的问题合并成一个单一的问题。如果第二个问题与第一个问题无关且完全独立,则返回第二个问题。"
  323. "只返回新的合并问题,不要添加任何额外的解释。该问题理论上应该能够用一个SQL语句来回答。"
  324. "请用中文回答。"
  325. ),
  326. self.user_message(f"第一个问题: {last_question}\n第二个问题: {new_question}")
  327. ]
  328. rewritten_question = self.submit_prompt(prompt=prompt, **kwargs)
  329. print(f"[DEBUG] 合并后的问题: {rewritten_question}")
  330. return rewritten_question
  331. except Exception as e:
  332. print(f"[ERROR] 问题合并失败: {str(e)}")
  333. # 如果合并失败,返回新问题
  334. return new_question
  335. def generate_summary(self, question: str, df, **kwargs) -> str:
  336. """
  337. 覆盖父类的 generate_summary 方法,添加中文思考和回答指令
  338. Args:
  339. question (str): 用户提出的问题
  340. df: 查询结果的 DataFrame
  341. **kwargs: 其他参数
  342. Returns:
  343. str: 数据摘要
  344. """
  345. try:
  346. # 导入 pandas 用于 DataFrame 处理
  347. import pandas as pd
  348. # 确保 df 是 pandas DataFrame
  349. if not isinstance(df, pd.DataFrame):
  350. print(f"[WARNING] df 不是 pandas DataFrame,类型: {type(df)}")
  351. return "无法生成摘要:数据格式不正确"
  352. if df.empty:
  353. return "查询结果为空,无数据可供摘要。"
  354. print(f"[DEBUG] 生成摘要 - 问题: {question}")
  355. print(f"[DEBUG] DataFrame 形状: {df.shape}")
  356. # 构建包含中文指令的系统消息
  357. system_content = (
  358. f"你是一个专业的数据分析助手。用户提出了问题:'{question}'\n\n"
  359. f"以下是查询结果的 pandas DataFrame 数据:\n{df.to_markdown()}\n\n"
  360. "请用中文进行思考和分析,并用中文回答。"
  361. )
  362. # 构建用户消息,强调中文思考和回答
  363. user_content = (
  364. "请基于用户提出的问题,简要总结这些数据。要求:\n"
  365. "1. 只进行简要总结,不要添加额外的解释\n"
  366. "2. 如果数据中有数字,请保留适当的精度\n"
  367. )
  368. message_log = [
  369. self.system_message(system_content),
  370. self.user_message(user_content)
  371. ]
  372. summary = self.submit_prompt(message_log, **kwargs)
  373. # 检查是否需要隐藏 thinking 内容
  374. display_thinking = kwargs.get("display_summary_thinking", DISPLAY_SUMMARY_THINKING)
  375. if not display_thinking:
  376. # 移除 <think></think> 标签及其内容
  377. original_summary = summary
  378. summary = self._remove_thinking_content(summary)
  379. print(f"[DEBUG] 隐藏thinking内容 - 原始长度: {len(original_summary)}, 处理后长度: {len(summary)}")
  380. print(f"[DEBUG] 生成的摘要: {summary[:100]}...")
  381. return summary
  382. except Exception as e:
  383. print(f"[ERROR] 生成摘要失败: {str(e)}")
  384. import traceback
  385. print(f"[ERROR] 详细错误信息: {traceback.format_exc()}")
  386. return f"生成摘要时出现错误:{str(e)}"
  387. def _remove_thinking_content(self, text: str) -> str:
  388. """
  389. 移除文本中的 <think></think> 标签及其内容
  390. Args:
  391. text (str): 包含可能的 thinking 标签的文本
  392. Returns:
  393. str: 移除 thinking 内容后的文本
  394. """
  395. if not text:
  396. return text
  397. import re
  398. # 移除 <think>...</think> 标签及其内容(支持多行)
  399. # 使用 re.DOTALL 标志使 . 匹配包括换行符在内的任何字符
  400. cleaned_text = re.sub(r'<think>.*?</think>\s*', '', text, flags=re.DOTALL | re.IGNORECASE)
  401. # 移除可能的多余空行
  402. cleaned_text = re.sub(r'\n\s*\n\s*\n', '\n\n', cleaned_text)
  403. # 去除开头和结尾的空白字符
  404. cleaned_text = cleaned_text.strip()
  405. return cleaned_text
  406. def ask(
  407. self,
  408. question: Union[str, None] = None,
  409. print_results: bool = True,
  410. auto_train: bool = True,
  411. visualize: bool = True,
  412. allow_llm_to_see_data: bool = False,
  413. ) -> Union[
  414. Tuple[
  415. Union[str, None],
  416. Union[pd.DataFrame, None],
  417. Union[plotly.graph_objs.Figure, None],
  418. ],
  419. None,
  420. ]:
  421. """
  422. 重载父类的ask方法,处理LLM解释性文本
  423. 当generate_sql无法生成SQL时,保存解释性文本供API层使用
  424. """
  425. if question is None:
  426. question = input("Enter a question: ")
  427. # 清空上次的解释性文本
  428. self.last_llm_explanation = None
  429. try:
  430. sql = self.generate_sql(question=question, allow_llm_to_see_data=allow_llm_to_see_data)
  431. except Exception as e:
  432. print(e)
  433. self.last_llm_explanation = str(e)
  434. if print_results:
  435. return None
  436. else:
  437. return None, None, None
  438. # 如果SQL为空,说明有解释性文本,按照正常流程返回None
  439. # API层会检查 last_llm_explanation 来获取解释
  440. if sql is None:
  441. print(f"[INFO] 无法生成SQL,解释: {self.last_llm_explanation}")
  442. if print_results:
  443. return None
  444. else:
  445. return None, None, None
  446. # 以下是正常的SQL执行流程(保持VannaBase原有逻辑)
  447. if print_results:
  448. print(sql)
  449. if self.run_sql_is_set is False:
  450. print("If you want to run the SQL query, connect to a database first.")
  451. if print_results:
  452. return None
  453. else:
  454. return sql, None, None
  455. try:
  456. df = self.run_sql(sql)
  457. if df is None:
  458. print("The SQL query returned no results.")
  459. if print_results:
  460. return None
  461. else:
  462. return sql, None, None
  463. if print_results:
  464. # 显示结果表格
  465. if len(df) > 10:
  466. print(df.head(10).to_string())
  467. print(f"... ({len(df)} rows)")
  468. else:
  469. print(df.to_string())
  470. # 如果启用了自动训练,添加问题-SQL对到训练集
  471. if auto_train:
  472. try:
  473. self.add_question_sql(question=question, sql=sql)
  474. except Exception as e:
  475. print(f"Could not add question and sql to training data: {e}")
  476. if visualize:
  477. try:
  478. # 检查是否应该生成图表
  479. if self.should_generate_chart(df):
  480. plotly_code = self.generate_plotly_code(
  481. question=question,
  482. sql=sql,
  483. df=df,
  484. chart_instructions=""
  485. )
  486. if plotly_code is not None and plotly_code.strip() != "":
  487. fig = self.get_plotly_figure(
  488. plotly_code=plotly_code,
  489. df=df,
  490. dark_mode=False
  491. )
  492. if fig is not None:
  493. if print_results:
  494. print("Chart generated (use fig.show() to display)")
  495. return sql, df, fig
  496. else:
  497. print("Could not generate chart")
  498. return sql, df, None
  499. else:
  500. print("No chart generated")
  501. return sql, df, None
  502. else:
  503. print("Not generating chart for this data")
  504. return sql, df, None
  505. except Exception as e:
  506. print(f"Couldn't generate chart: {e}")
  507. return sql, df, None
  508. else:
  509. return sql, df, None
  510. except Exception as e:
  511. print("Couldn't run sql: ", e)
  512. if print_results:
  513. return None
  514. else:
  515. return sql, None, None
  516. @abstractmethod
  517. def submit_prompt(self, prompt, **kwargs) -> str:
  518. """
  519. 子类必须实现的核心提交方法
  520. Args:
  521. prompt: 消息列表
  522. **kwargs: 其他参数
  523. Returns:
  524. str: LLM的响应
  525. """
  526. pass