# Custom_QianwenAI_chat.py
# Standard library
import os

# Third-party: OpenAI-compatible client (also used for DashScope endpoints)
from openai import OpenAI
# Vanna framework base class providing the text-to-SQL pipeline hooks
from vanna.base import VannaBase
  4. class QianWenAI_Chat(VannaBase):
  5. def __init__(self, client=None, config=None):
  6. print("...QianWenAI_Chat init...")
  7. VannaBase.__init__(self, config=config)
  8. print("传入的 config 参数如下:")
  9. for key, value in self.config.items():
  10. print(f" {key}: {value}")
  11. # default parameters - can be overrided using config
  12. self.temperature = 0.7
  13. if "temperature" in config:
  14. print(f"temperature is changed to: {config['temperature']}")
  15. self.temperature = config["temperature"]
  16. if "api_type" in config:
  17. raise Exception(
  18. "Passing api_type is now deprecated. Please pass an OpenAI client instead."
  19. )
  20. if "api_base" in config:
  21. raise Exception(
  22. "Passing api_base is now deprecated. Please pass an OpenAI client instead."
  23. )
  24. if "api_version" in config:
  25. raise Exception(
  26. "Passing api_version is now deprecated. Please pass an OpenAI client instead."
  27. )
  28. if client is not None:
  29. self.client = client
  30. return
  31. if config is None and client is None:
  32. self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
  33. return
  34. if "api_key" in config:
  35. if "base_url" not in config:
  36. self.client = OpenAI(api_key=config["api_key"],
  37. base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")
  38. else:
  39. self.client = OpenAI(api_key=config["api_key"],
  40. base_url=config["base_url"])
  41. # 生成SQL的时候,使用中文别名 - 基于VannaBase源码直接实现
  42. def get_sql_prompt(self, initial_prompt: str, question: str, question_sql_list: list, ddl_list: list, doc_list: list, **kwargs):
  43. """
  44. 基于VannaBase源码实现,在第7点添加中文别名指令
  45. """
  46. print(f"[DEBUG] 开始生成SQL提示词,问题: {question}")
  47. if initial_prompt is None:
  48. initial_prompt = f"You are a {self.dialect} expert. " + \
  49. "Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "
  50. initial_prompt = self.add_ddl_to_prompt(
  51. initial_prompt, ddl_list, max_tokens=self.max_tokens
  52. )
  53. if self.static_documentation != "":
  54. doc_list.append(self.static_documentation)
  55. initial_prompt = self.add_documentation_to_prompt(
  56. initial_prompt, doc_list, max_tokens=self.max_tokens
  57. )
  58. initial_prompt += (
  59. "===Response Guidelines \n"
  60. "1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n"
  61. "2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \n"
  62. "3. If the provided context is insufficient, please explain why it can't be generated. \n"
  63. "4. Please use the most relevant table(s). \n"
  64. "5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n"
  65. f"6. Ensure that the output SQL is {self.dialect}-compliant and executable, and free of syntax errors. \n"
  66. "7. 【重要】请在SQL查询中为所有SELECT的列都使用中文别名:\n"
  67. " - 每个列都必须使用 AS 中文别名 的格式,没有例外\n"
  68. " - 包括原始字段名也要添加中文别名,例如:gender AS 性别, card_category AS 卡片类型\n"
  69. " - 计算字段也要有中文别名,例如:COUNT(*) AS 持卡人数\n"
  70. " - 中文别名要准确反映字段的业务含义\n"
  71. " - 绝对不能有任何字段没有中文别名,这会影响表格的可读性\n"
  72. " - 这样可以提高图表的可读性和用户体验\n"
  73. " 正确示例:SELECT gender AS 性别, card_category AS 卡片类型, COUNT(*) AS 持卡人数 FROM table_name\n"
  74. " 错误示例:SELECT gender, card_category AS 卡片类型, COUNT(*) AS 持卡人数 FROM table_name\n"
  75. )
  76. message_log = [self.system_message(initial_prompt)]
  77. for example in question_sql_list:
  78. if example is None:
  79. print("example is None")
  80. else:
  81. if example is not None and "question" in example and "sql" in example:
  82. message_log.append(self.user_message(example["question"]))
  83. message_log.append(self.assistant_message(example["sql"]))
  84. message_log.append(self.user_message(question))
  85. print(f"[DEBUG] SQL提示词生成完成,消息数量: {len(message_log)}")
  86. return message_log
  87. # 生成图形的时候,使用中文标注
  88. def generate_plotly_code(self, question: str = None, sql: str = None, df_metadata: str = None, **kwargs) -> str:
  89. """
  90. 重写父类方法,添加明确的中文图表指令
  91. """
  92. # 构建更智能的中文图表指令,根据问题和数据内容生成有意义的标签
  93. chinese_chart_instructions = (
  94. "使用中文创建图表,要求:\n"
  95. "1. 根据用户问题和数据内容,为图表生成有意义的中文标题\n"
  96. "2. 根据数据列的实际含义,为X轴和Y轴生成准确的中文标签\n"
  97. "3. 如果有图例,确保图例标签使用中文\n"
  98. "4. 所有文本(包括标题、轴标签、图例、数据标签等)都必须使用中文\n"
  99. "5. 标题应该简洁明了地概括图表要展示的内容\n"
  100. "6. 轴标签应该准确反映对应数据列的业务含义\n"
  101. "7. 选择最适合数据特点的图表类型(柱状图、折线图、饼图等)"
  102. )
  103. # 构建父类方法要求的message_log
  104. system_msg_parts = []
  105. if question:
  106. system_msg_parts.append(
  107. f"用户问题:'{question}'"
  108. )
  109. system_msg_parts.append(
  110. f"以下是回答用户问题的pandas DataFrame数据:"
  111. )
  112. else:
  113. system_msg_parts.append("以下是一个pandas DataFrame数据:")
  114. if sql:
  115. system_msg_parts.append(f"数据来源SQL查询:\n{sql}")
  116. system_msg_parts.append(f"DataFrame结构信息:\n{df_metadata}")
  117. system_msg = "\n\n".join(system_msg_parts)
  118. # 构建更详细的用户消息,强调中文标签的重要性
  119. user_msg = (
  120. "请为这个DataFrame生成Python Plotly可视化代码。要求:\n\n"
  121. "1. 假设数据存储在名为'df'的pandas DataFrame中\n"
  122. "2. 如果DataFrame只有一个值,使用Indicator图表\n"
  123. "3. 只返回Python代码,不要任何解释\n"
  124. "4. 代码必须可以直接运行\n\n"
  125. f"{chinese_chart_instructions}\n\n"
  126. "特别注意:\n"
  127. "- 不要使用'图表标题'、'X轴标签'、'Y轴标签'这样的通用标签\n"
  128. "- 要根据实际数据内容和用户问题生成具体、有意义的中文标签\n"
  129. "- 例如:如果是性别统计,X轴可能是'性别',Y轴可能是'人数'或'占比'\n"
  130. "- 标题应该概括图表的主要内容,如'男女持卡比例分布'\n\n"
  131. "数据标签和悬停信息要求:\n"
  132. "- 不要使用%{text}这样的占位符变量\n"
  133. "- 使用具体的数据值和中文单位,例如:text=df['列名'].astype(str) + '人'\n"
  134. "- 悬停信息要清晰易懂,使用中文描述\n"
  135. "- 确保所有显示的文本都是实际的数据值,不是变量占位符"
  136. )
  137. message_log = [
  138. self.system_message(system_msg),
  139. self.user_message(user_msg),
  140. ]
  141. # 调用父类submit_prompt方法,并清理结果
  142. plotly_code = self.submit_prompt(message_log, kwargs=kwargs)
  143. return self._sanitize_plotly_code(self._extract_python_code(plotly_code))
  144. def system_message(self, message: str) -> any:
  145. print(f"system_content: {message}")
  146. return {"role": "system", "content": message}
  147. def user_message(self, message: str) -> any:
  148. print(f"\nuser_content: {message}")
  149. return {"role": "user", "content": message}
  150. def assistant_message(self, message: str) -> any:
  151. print(f"assistant_content: {message}")
  152. return {"role": "assistant", "content": message}
  153. def submit_prompt(self, prompt, **kwargs) -> str:
  154. if prompt is None:
  155. raise Exception("Prompt is None")
  156. if len(prompt) == 0:
  157. raise Exception("Prompt is empty")
  158. # Count the number of tokens in the message log
  159. # Use 4 as an approximation for the number of characters per token
  160. num_tokens = 0
  161. for message in prompt:
  162. num_tokens += len(message["content"]) / 4
  163. # 从配置和参数中获取enable_thinking设置
  164. # 优先使用参数中传入的值,如果没有则从配置中读取,默认为False
  165. enable_thinking = kwargs.get("enable_thinking", self.config.get("enable_thinking", False))
  166. # 公共参数
  167. common_params = {
  168. "messages": prompt,
  169. "stop": None,
  170. "temperature": self.temperature,
  171. }
  172. # 如果启用了thinking,则使用流式处理,但不直接传递enable_thinking参数
  173. if enable_thinking:
  174. common_params["stream"] = True
  175. # 千问API不接受enable_thinking作为参数,可能需要通过header或其他方式传递
  176. # 也可能它只是默认启用stream=True时的thinking功能
  177. model = None
  178. # 确定使用的模型
  179. if kwargs.get("model", None) is not None:
  180. model = kwargs.get("model", None)
  181. common_params["model"] = model
  182. elif kwargs.get("engine", None) is not None:
  183. engine = kwargs.get("engine", None)
  184. common_params["engine"] = engine
  185. model = engine
  186. elif self.config is not None and "engine" in self.config:
  187. common_params["engine"] = self.config["engine"]
  188. model = self.config["engine"]
  189. elif self.config is not None and "model" in self.config:
  190. common_params["model"] = self.config["model"]
  191. model = self.config["model"]
  192. else:
  193. if num_tokens > 3500:
  194. model = "qwen-long"
  195. else:
  196. model = "qwen-plus"
  197. common_params["model"] = model
  198. print(f"\nUsing model {model} for {num_tokens} tokens (approx)")
  199. if enable_thinking:
  200. # 流式处理模式
  201. print("使用流式处理模式,启用thinking功能")
  202. # 检查是否需要通过headers传递enable_thinking参数
  203. response_stream = self.client.chat.completions.create(**common_params)
  204. # 收集流式响应
  205. collected_thinking = []
  206. collected_content = []
  207. for chunk in response_stream:
  208. # 处理thinking部分
  209. if hasattr(chunk, 'thinking') and chunk.thinking:
  210. collected_thinking.append(chunk.thinking)
  211. # 处理content部分
  212. if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content:
  213. collected_content.append(chunk.choices[0].delta.content)
  214. # 可以在这里处理thinking的展示逻辑,如保存到日志等
  215. if collected_thinking:
  216. print("Model thinking process:", "".join(collected_thinking))
  217. # 返回完整的内容
  218. return "".join(collected_content)
  219. else:
  220. # 非流式处理模式
  221. print("使用非流式处理模式")
  222. response = self.client.chat.completions.create(**common_params)
  223. # Find the first response from the chatbot that has text in it (some responses may not have text)
  224. for choice in response.choices:
  225. if "text" in choice:
  226. return choice.text
  227. # If no response with text is found, return the first response's content (which may be empty)
  228. return response.choices[0].message.content
  229. # 重写 generate_sql 方法以增加异常处理
  230. def generate_sql(self, question: str, **kwargs) -> str:
  231. """
  232. 重写父类的 generate_sql 方法,增加异常处理
  233. """
  234. try:
  235. print(f"[DEBUG] 尝试为问题生成SQL: {question}")
  236. # 调用父类的 generate_sql
  237. sql = super().generate_sql(question, **kwargs)
  238. if not sql or sql.strip() == "":
  239. print(f"[WARNING] 生成的SQL为空")
  240. return None
  241. # 检查返回内容是否为有效SQL或错误信息
  242. sql_lower = sql.lower().strip()
  243. # 检查是否包含错误提示信息
  244. error_indicators = [
  245. "insufficient context", "无法生成", "sorry", "cannot", "不能",
  246. "no relevant", "no suitable", "unable to", "无法", "抱歉",
  247. "i don't have", "i cannot", "没有相关", "找不到", "不存在"
  248. ]
  249. for indicator in error_indicators:
  250. if indicator in sql_lower:
  251. print(f"[WARNING] LLM返回错误信息而非SQL: {sql}")
  252. return None
  253. # 简单检查是否像SQL语句(至少包含一些SQL关键词)
  254. sql_keywords = ["select", "insert", "update", "delete", "with", "from", "where"]
  255. if not any(keyword in sql_lower for keyword in sql_keywords):
  256. print(f"[WARNING] 返回内容不像有效SQL: {sql}")
  257. return None
  258. print(f"[SUCCESS] 成功生成SQL: {sql}")
  259. return sql
  260. except Exception as e:
  261. print(f"[ERROR] SQL生成过程中出现异常: {str(e)}")
  262. print(f"[ERROR] 异常类型: {type(e).__name__}")
  263. # 导入traceback以获取详细错误信息
  264. import traceback
  265. print(f"[ERROR] 详细错误信息: {traceback.format_exc()}")
  266. # 返回 None 而不是抛出异常
  267. return None
  268. # 为了解决通过sql生成question时,question是英文的问题。
  269. def generate_question(self, sql: str, **kwargs) -> str:
  270. # 这里可以自定义提示词/逻辑
  271. prompt = [
  272. self.system_message(
  273. "请你根据下方SQL语句推测用户的业务提问,只返回清晰的自然语言问题,不要包含任何解释或SQL内容,也不要出现表名,问题要使用中文,并以问号结尾。"
  274. ),
  275. self.user_message(sql)
  276. ]
  277. response = self.submit_prompt(prompt, **kwargs)
  278. # 你也可以在这里对response做后处理
  279. return response
  280. # 新增:直接与LLM对话的方法
  281. def chat_with_llm(self, question: str, **kwargs) -> str:
  282. """
  283. 直接与LLM对话,不涉及SQL生成
  284. """
  285. try:
  286. prompt = [
  287. self.system_message(
  288. "你是一个友好的AI助手。如果用户询问的是数据库相关问题,请建议他们重新表述问题以便进行SQL查询。对于其他问题,请尽力提供有帮助的回答。"
  289. ),
  290. self.user_message(question)
  291. ]
  292. response = self.submit_prompt(prompt, **kwargs)
  293. return response
  294. except Exception as e:
  295. print(f"[ERROR] LLM对话失败: {str(e)}")
  296. return f"抱歉,我暂时无法回答您的问题。请稍后再试。"