Custom_QianwenAI_chat.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496
  1. import os
  2. from openai import OpenAI
  3. from vanna.base import VannaBase
  4. # 导入配置参数
  5. from app_config import REWRITE_QUESTION_ENABLED
  6. class QianWenAI_Chat(VannaBase):
  7. def __init__(self, client=None, config=None):
  8. print("...QianWenAI_Chat init...")
  9. VannaBase.__init__(self, config=config)
  10. print("传入的 config 参数如下:")
  11. for key, value in self.config.items():
  12. print(f" {key}: {value}")
  13. # default parameters - can be overrided using config
  14. self.temperature = 0.7
  15. if "temperature" in config:
  16. print(f"temperature is changed to: {config['temperature']}")
  17. self.temperature = config["temperature"]
  18. if "api_type" in config:
  19. raise Exception(
  20. "Passing api_type is now deprecated. Please pass an OpenAI client instead."
  21. )
  22. if "api_base" in config:
  23. raise Exception(
  24. "Passing api_base is now deprecated. Please pass an OpenAI client instead."
  25. )
  26. if "api_version" in config:
  27. raise Exception(
  28. "Passing api_version is now deprecated. Please pass an OpenAI client instead."
  29. )
  30. if client is not None:
  31. self.client = client
  32. return
  33. if config is None and client is None:
  34. self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
  35. return
  36. if "api_key" in config:
  37. if "base_url" not in config:
  38. self.client = OpenAI(api_key=config["api_key"],
  39. base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")
  40. else:
  41. self.client = OpenAI(api_key=config["api_key"],
  42. base_url=config["base_url"])
  43. # 新增:加载错误SQL提示配置
  44. self.enable_error_sql_prompt = self._load_error_sql_prompt_config()
  45. def _load_error_sql_prompt_config(self) -> bool:
  46. """从app_config.py加载错误SQL提示配置"""
  47. try:
  48. import app_config
  49. enable_error_sql = getattr(app_config, 'ENABLE_ERROR_SQL_PROMPT', False)
  50. print(f"[DEBUG] 错误SQL提示配置: ENABLE_ERROR_SQL_PROMPT = {enable_error_sql}")
  51. return enable_error_sql
  52. except (ImportError, AttributeError) as e:
  53. print(f"[WARNING] 无法加载错误SQL提示配置: {e},使用默认值 False")
  54. return False
  55. # 生成SQL的时候,使用中文别名 - 基于VannaBase源码直接实现
  56. def get_sql_prompt(self, initial_prompt: str, question: str, question_sql_list: list, ddl_list: list, doc_list: list, **kwargs):
  57. """
  58. 基于VannaBase源码实现,在第7点添加中文别名指令
  59. """
  60. print(f"[DEBUG] 开始生成SQL提示词,问题: {question}")
  61. if initial_prompt is None:
  62. initial_prompt = f"You are a {self.dialect} expert. " + \
  63. "Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "
  64. # 提取DDL内容(适配新的字典格式)
  65. ddl_content_list = []
  66. if ddl_list:
  67. for item in ddl_list:
  68. if isinstance(item, dict) and "content" in item:
  69. ddl_content_list.append(item["content"])
  70. elif isinstance(item, str):
  71. ddl_content_list.append(item)
  72. initial_prompt = self.add_ddl_to_prompt(
  73. initial_prompt, ddl_content_list, max_tokens=self.max_tokens
  74. )
  75. # 提取文档内容(适配新的字典格式)
  76. doc_content_list = []
  77. if doc_list:
  78. for item in doc_list:
  79. if isinstance(item, dict) and "content" in item:
  80. doc_content_list.append(item["content"])
  81. elif isinstance(item, str):
  82. doc_content_list.append(item)
  83. if self.static_documentation != "":
  84. doc_content_list.append(self.static_documentation)
  85. initial_prompt = self.add_documentation_to_prompt(
  86. initial_prompt, doc_content_list, max_tokens=self.max_tokens
  87. )
  88. # 新增:添加错误SQL示例作为负面示例(放在Response Guidelines之前)
  89. if self.enable_error_sql_prompt:
  90. try:
  91. error_sql_list = self.get_related_error_sql(question, **kwargs)
  92. if error_sql_list:
  93. print(f"[DEBUG] 找到 {len(error_sql_list)} 个相关的错误SQL示例")
  94. # 构建格式化的负面提示内容
  95. negative_prompt_content = "===Negative Examples\n"
  96. negative_prompt_content += "下面是错误的SQL示例,请分析这些错误SQL的问题所在,并在生成新SQL时避免类似错误:\n\n"
  97. for i, error_example in enumerate(error_sql_list, 1):
  98. if "question" in error_example and "sql" in error_example:
  99. similarity = error_example.get('similarity', 'N/A')
  100. print(f"[DEBUG] 错误SQL示例 {i}: 相似度={similarity}")
  101. negative_prompt_content += f"问题: {error_example['question']}\n"
  102. negative_prompt_content += f"错误的SQL: {error_example['sql']}\n\n"
  103. # 将负面提示添加到初始提示中
  104. initial_prompt += negative_prompt_content
  105. else:
  106. print("[DEBUG] 未找到相关的错误SQL示例")
  107. except Exception as e:
  108. print(f"[WARNING] 获取错误SQL示例失败: {e}")
  109. initial_prompt += (
  110. "===Response Guidelines \n"
  111. "1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n"
  112. "2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \n"
  113. "3. If the provided context is insufficient, please explain why it can't be generated. \n"
  114. "4. Please use the most relevant table(s). \n"
  115. "5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n"
  116. f"6. Ensure that the output SQL is {self.dialect}-compliant and executable, and free of syntax errors. \n"
  117. "7. 【重要】请在SQL查询中为所有SELECT的列都使用中文别名:\n"
  118. " - 每个列都必须使用 AS 中文别名 的格式,没有例外\n"
  119. " - 包括原始字段名也要添加中文别名,例如:SELECT gender AS 性别, card_category AS 卡片类型\n"
  120. " - 计算字段也要有中文别名,例如:SELECT COUNT(*) AS 持卡人数\n"
  121. " - 中文别名要准确反映字段的业务含义"
  122. )
  123. message_log = [self.system_message(initial_prompt)]
  124. for example in question_sql_list:
  125. if example is None:
  126. print("example is None")
  127. else:
  128. if example is not None and "question" in example and "sql" in example:
  129. message_log.append(self.user_message(example["question"]))
  130. message_log.append(self.assistant_message(example["sql"]))
  131. message_log.append(self.user_message(question))
  132. return message_log
  133. # 生成图形的时候,使用中文标注
  134. def generate_plotly_code(self, question: str = None, sql: str = None, df_metadata: str = None, **kwargs) -> str:
  135. """
  136. 重写父类方法,添加明确的中文图表指令
  137. """
  138. # 构建更智能的中文图表指令,根据问题和数据内容生成有意义的标签
  139. chinese_chart_instructions = (
  140. "使用中文创建图表,要求:\n"
  141. "1. 根据用户问题和数据内容,为图表生成有意义的中文标题\n"
  142. "2. 根据数据列的实际含义,为X轴和Y轴生成准确的中文标签\n"
  143. "3. 如果有图例,确保图例标签使用中文\n"
  144. "4. 所有文本(包括标题、轴标签、图例、数据标签等)都必须使用中文\n"
  145. "5. 标题应该简洁明了地概括图表要展示的内容\n"
  146. "6. 轴标签应该准确反映对应数据列的业务含义\n"
  147. "7. 选择最适合数据特点的图表类型(柱状图、折线图、饼图等)"
  148. )
  149. # 构建父类方法要求的message_log
  150. system_msg_parts = []
  151. if question:
  152. system_msg_parts.append(
  153. f"用户问题:'{question}'"
  154. )
  155. system_msg_parts.append(
  156. f"以下是回答用户问题的pandas DataFrame数据:"
  157. )
  158. else:
  159. system_msg_parts.append("以下是一个pandas DataFrame数据:")
  160. if sql:
  161. system_msg_parts.append(f"数据来源SQL查询:\n{sql}")
  162. system_msg_parts.append(f"DataFrame结构信息:\n{df_metadata}")
  163. system_msg = "\n\n".join(system_msg_parts)
  164. # 构建更详细的用户消息,强调中文标签的重要性
  165. user_msg = (
  166. "请为这个DataFrame生成Python Plotly可视化代码。要求:\n\n"
  167. "1. 假设数据存储在名为'df'的pandas DataFrame中\n"
  168. "2. 如果DataFrame只有一个值,使用Indicator图表\n"
  169. "3. 只返回Python代码,不要任何解释\n"
  170. "4. 代码必须可以直接运行\n\n"
  171. f"{chinese_chart_instructions}\n\n"
  172. "特别注意:\n"
  173. "- 不要使用'图表标题'、'X轴标签'、'Y轴标签'这样的通用标签\n"
  174. "- 要根据实际数据内容和用户问题生成具体、有意义的中文标签\n"
  175. "- 例如:如果是性别统计,X轴可能是'性别',Y轴可能是'人数'或'占比'\n"
  176. "- 标题应该概括图表的主要内容,如'男女持卡比例分布'\n\n"
  177. "数据标签和悬停信息要求:\n"
  178. "- 不要使用%{text}这样的占位符变量\n"
  179. "- 使用具体的数据值和中文单位,例如:text=df['列名'].astype(str) + '人'\n"
  180. "- 悬停信息要清晰易懂,使用中文描述\n"
  181. "- 确保所有显示的文本都是实际的数据值,不是变量占位符"
  182. )
  183. message_log = [
  184. self.system_message(system_msg),
  185. self.user_message(user_msg),
  186. ]
  187. # 调用父类submit_prompt方法,并清理结果
  188. plotly_code = self.submit_prompt(message_log, kwargs=kwargs)
  189. return self._sanitize_plotly_code(self._extract_python_code(plotly_code))
  190. def system_message(self, message: str) -> any:
  191. print(f"system_content: {message}")
  192. return {"role": "system", "content": message}
  193. def user_message(self, message: str) -> any:
  194. print(f"\nuser_content: {message}")
  195. return {"role": "user", "content": message}
  196. def assistant_message(self, message: str) -> any:
  197. print(f"assistant_content: {message}")
  198. return {"role": "assistant", "content": message}
  199. def should_generate_chart(self, df) -> bool:
  200. """
  201. 判断是否应该生成图表
  202. 对于Flask应用,这个方法决定了前端是否显示图表生成按钮
  203. """
  204. if df is None or df.empty:
  205. print(f"[DEBUG] should_generate_chart: df为空,返回False")
  206. return False
  207. # 如果数据有多行或多列,通常适合生成图表
  208. result = len(df) > 1 or len(df.columns) > 1
  209. print(f"[DEBUG] should_generate_chart: df.shape={df.shape}, 返回{result}")
  210. if result:
  211. return True
  212. return False
  213. # def get_plotly_figure(self, plotly_code: str, df, dark_mode: bool = True):
  214. # """
  215. # 重写父类方法,确保Flask应用也使用我们的自定义图表生成逻辑
  216. # 这个方法会被VannaFlaskApp调用,而不是generate_plotly_code
  217. # """
  218. # print(f"[DEBUG] get_plotly_figure被调用,plotly_code长度: {len(plotly_code) if plotly_code else 0}")
  219. # # 如果没有提供plotly_code,尝试生成一个
  220. # if not plotly_code or plotly_code.strip() == "":
  221. # print(f"[DEBUG] plotly_code为空,尝试生成默认图表")
  222. # # 生成一个简单的默认图表
  223. # df_metadata = f"DataFrame形状: {df.shape}\n列名: {list(df.columns)}\n数据类型:\n{df.dtypes}"
  224. # plotly_code = self.generate_plotly_code(
  225. # question="数据可视化",
  226. # sql=None,
  227. # df_metadata=df_metadata
  228. # )
  229. # # 调用父类方法执行plotly代码
  230. # try:
  231. # return super().get_plotly_figure(plotly_code=plotly_code, df=df, dark_mode=dark_mode)
  232. # except Exception as e:
  233. # print(f"[ERROR] 执行plotly代码失败: {e}")
  234. # print(f"[ERROR] plotly_code: {plotly_code}")
  235. # # 如果执行失败,返回None或生成一个简单的备用图表
  236. # return None
  237. def submit_prompt(self, prompt, **kwargs) -> str:
  238. if prompt is None:
  239. raise Exception("Prompt is None")
  240. if len(prompt) == 0:
  241. raise Exception("Prompt is empty")
  242. # Count the number of tokens in the message log
  243. # Use 4 as an approximation for the number of characters per token
  244. num_tokens = 0
  245. for message in prompt:
  246. num_tokens += len(message["content"]) / 4
  247. # 从配置和参数中获取enable_thinking设置
  248. # 优先使用参数中传入的值,如果没有则从配置中读取,默认为False
  249. enable_thinking = kwargs.get("enable_thinking", self.config.get("enable_thinking", False))
  250. # 公共参数
  251. common_params = {
  252. "messages": prompt,
  253. "stop": None,
  254. "temperature": self.temperature,
  255. }
  256. # 如果启用了thinking,则使用流式处理,但不直接传递enable_thinking参数
  257. if enable_thinking:
  258. common_params["stream"] = True
  259. # 千问API不接受enable_thinking作为参数,可能需要通过header或其他方式传递
  260. # 也可能它只是默认启用stream=True时的thinking功能
  261. model = None
  262. # 确定使用的模型
  263. if kwargs.get("model", None) is not None:
  264. model = kwargs.get("model", None)
  265. common_params["model"] = model
  266. elif kwargs.get("engine", None) is not None:
  267. engine = kwargs.get("engine", None)
  268. common_params["engine"] = engine
  269. model = engine
  270. elif self.config is not None and "engine" in self.config:
  271. common_params["engine"] = self.config["engine"]
  272. model = self.config["engine"]
  273. elif self.config is not None and "model" in self.config:
  274. common_params["model"] = self.config["model"]
  275. model = self.config["model"]
  276. else:
  277. if num_tokens > 3500:
  278. model = "qwen-long"
  279. else:
  280. model = "qwen-plus"
  281. common_params["model"] = model
  282. print(f"\nUsing model {model} for {num_tokens} tokens (approx)")
  283. if enable_thinking:
  284. # 流式处理模式
  285. print("使用流式处理模式,启用thinking功能")
  286. # 检查是否需要通过headers传递enable_thinking参数
  287. response_stream = self.client.chat.completions.create(**common_params)
  288. # 收集流式响应
  289. collected_thinking = []
  290. collected_content = []
  291. for chunk in response_stream:
  292. # 处理thinking部分
  293. if hasattr(chunk, 'thinking') and chunk.thinking:
  294. collected_thinking.append(chunk.thinking)
  295. # 处理content部分
  296. if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content:
  297. collected_content.append(chunk.choices[0].delta.content)
  298. # 可以在这里处理thinking的展示逻辑,如保存到日志等
  299. if collected_thinking:
  300. print("Model thinking process:", "".join(collected_thinking))
  301. # 返回完整的内容
  302. return "".join(collected_content)
  303. else:
  304. # 非流式处理模式
  305. print("使用非流式处理模式")
  306. response = self.client.chat.completions.create(**common_params)
  307. # Find the first response from the chatbot that has text in it (some responses may not have text)
  308. for choice in response.choices:
  309. if "text" in choice:
  310. return choice.text
  311. # If no response with text is found, return the first response's content (which may be empty)
  312. return response.choices[0].message.content
  313. # 重写 generate_sql 方法以增加异常处理
  314. def generate_sql(self, question: str, **kwargs) -> str:
  315. """
  316. 重写父类的 generate_sql 方法,增加异常处理
  317. """
  318. try:
  319. print(f"[DEBUG] 尝试为问题生成SQL: {question}")
  320. # 调用父类的 generate_sql
  321. sql = super().generate_sql(question, **kwargs)
  322. if not sql or sql.strip() == "":
  323. print(f"[WARNING] 生成的SQL为空")
  324. return None
  325. # 检查返回内容是否为有效SQL或错误信息
  326. sql_lower = sql.lower().strip()
  327. # 检查是否包含错误提示信息
  328. error_indicators = [
  329. "insufficient context", "无法生成", "sorry", "cannot", "不能",
  330. "no relevant", "no suitable", "unable to", "无法", "抱歉",
  331. "i don't have", "i cannot", "没有相关", "找不到", "不存在"
  332. ]
  333. for indicator in error_indicators:
  334. if indicator in sql_lower:
  335. print(f"[WARNING] LLM返回错误信息而非SQL: {sql}")
  336. return None
  337. # 简单检查是否像SQL语句(至少包含一些SQL关键词)
  338. sql_keywords = ["select", "insert", "update", "delete", "with", "from", "where"]
  339. if not any(keyword in sql_lower for keyword in sql_keywords):
  340. print(f"[WARNING] 返回内容不像有效SQL: {sql}")
  341. return None
  342. print(f"[SUCCESS] 成功生成SQL: {sql}")
  343. return sql
  344. except Exception as e:
  345. print(f"[ERROR] SQL生成过程中出现异常: {str(e)}")
  346. print(f"[ERROR] 异常类型: {type(e).__name__}")
  347. # 导入traceback以获取详细错误信息
  348. import traceback
  349. print(f"[ERROR] 详细错误信息: {traceback.format_exc()}")
  350. # 返回 None 而不是抛出异常
  351. return None
  352. # 为了解决通过sql生成question时,question是英文的问题。
  353. def generate_question(self, sql: str, **kwargs) -> str:
  354. # 这里可以自定义提示词/逻辑
  355. prompt = [
  356. self.system_message(
  357. "请你根据下方SQL语句推测用户的业务提问,只返回清晰的自然语言问题,不要包含任何解释或SQL内容,也不要出现表名,问题要使用中文,并以问号结尾。"
  358. ),
  359. self.user_message(sql)
  360. ]
  361. response = self.submit_prompt(prompt, **kwargs)
  362. # 你也可以在这里对response做后处理
  363. return response
  364. # 新增:直接与LLM对话的方法
  365. def chat_with_llm(self, question: str, **kwargs) -> str:
  366. """
  367. 直接与LLM对话,不涉及SQL生成
  368. """
  369. try:
  370. prompt = [
  371. self.system_message(
  372. "你是一个友好的AI助手。如果用户询问的是数据库相关问题,请建议他们重新表述问题以便进行SQL查询。对于其他问题,请尽力提供有帮助的回答。"
  373. ),
  374. self.user_message(question)
  375. ]
  376. response = self.submit_prompt(prompt, **kwargs)
  377. return response
  378. except Exception as e:
  379. print(f"[ERROR] LLM对话失败: {str(e)}")
  380. return f"抱歉,我暂时无法回答您的问题。请稍后再试。"
  381. def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str:
  382. """
  383. 重写问题合并方法,通过配置参数控制是否启用合并功能
  384. Args:
  385. last_question (str): 上一个问题
  386. new_question (str): 新问题
  387. **kwargs: 其他参数
  388. Returns:
  389. str: 如果启用合并且问题相关则返回合并后的问题,否则返回新问题
  390. """
  391. # 如果未启用合并功能或没有上一个问题,直接返回新问题
  392. if not REWRITE_QUESTION_ENABLED or last_question is None:
  393. print(f"[DEBUG] 问题合并功能{'未启用' if not REWRITE_QUESTION_ENABLED else '上一个问题为空'},直接返回新问题")
  394. return new_question
  395. print(f"[DEBUG] 启用问题合并功能,尝试合并问题")
  396. print(f"[DEBUG] 上一个问题: {last_question}")
  397. print(f"[DEBUG] 新问题: {new_question}")
  398. try:
  399. prompt = [
  400. self.system_message(
  401. "你的目标是将一系列相关的问题合并成一个单一的问题。如果第二个问题与第一个问题无关且完全独立,则返回第二个问题。"
  402. "只返回新的合并问题,不要添加任何额外的解释。该问题理论上应该能够用一个SQL语句来回答。"
  403. "请用中文回答。"
  404. ),
  405. self.user_message(f"第一个问题: {last_question}\n第二个问题: {new_question}")
  406. ]
  407. rewritten_question = self.submit_prompt(prompt=prompt, **kwargs)
  408. print(f"[DEBUG] 合并后的问题: {rewritten_question}")
  409. return rewritten_question
  410. except Exception as e:
  411. print(f"[ERROR] 问题合并失败: {str(e)}")
  412. # 如果合并失败,返回新问题
  413. return new_question