Custom_QiawenAI_chat_cn.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400
  1. """
  2. 中文千问AI实现
  3. 基于对源码的正确理解,实现正确的方法
  4. """
  5. import os
  6. from openai import OpenAI
  7. from vanna.base import VannaBase
  8. from typing import List, Dict, Any, Optional
  9. class QianWenAI_Chat_CN(VannaBase):
  10. """
  11. 中文千问AI聊天类,直接继承VannaBase
  12. 实现正确的方法名(get_sql_prompt而不是generate_sql_prompt)
  13. """
  14. def __init__(self, client=None, config=None):
  15. """
  16. 初始化中文千问AI实例
  17. Args:
  18. client: 可选,OpenAI兼容的客户端
  19. config: 配置字典,包含API密钥等配置
  20. """
  21. print("初始化QianWenAI_Chat_CN...")
  22. VannaBase.__init__(self, config=config)
  23. print("传入的 config 参数如下:")
  24. for key, value in self.config.items():
  25. print(f" {key}: {value}")
  26. # 设置语言为中文
  27. self.language = "Chinese"
  28. # 默认参数 - 可通过config覆盖
  29. self.temperature = 0.7
  30. if "temperature" in config:
  31. print(f"temperature is changed to: {config['temperature']}")
  32. self.temperature = config["temperature"]
  33. if "api_type" in config:
  34. raise Exception(
  35. "Passing api_type is now deprecated. Please pass an OpenAI client instead."
  36. )
  37. if "api_base" in config:
  38. raise Exception(
  39. "Passing api_base is now deprecated. Please pass an OpenAI client instead."
  40. )
  41. if "api_version" in config:
  42. raise Exception(
  43. "Passing api_version is now deprecated. Please pass an OpenAI client instead."
  44. )
  45. if client is not None:
  46. self.client = client
  47. return
  48. if config is None and client is None:
  49. self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
  50. return
  51. if "api_key" in config:
  52. if "base_url" not in config:
  53. self.client = OpenAI(api_key=config["api_key"],
  54. base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")
  55. else:
  56. self.client = OpenAI(api_key=config["api_key"],
  57. base_url=config["base_url"])
  58. print("中文千问AI初始化完成")
  59. def _response_language(self) -> str:
  60. """
  61. 返回响应语言指示
  62. """
  63. return "请用中文回答。"
  64. def system_message(self, message: str) -> any:
  65. """
  66. 创建系统消息
  67. """
  68. print(f"[DEBUG] 系统消息: {message}")
  69. return {"role": "system", "content": message}
  70. def user_message(self, message: str) -> any:
  71. """
  72. 创建用户消息
  73. """
  74. print(f"[DEBUG] 用户消息: {message}")
  75. return {"role": "user", "content": message}
  76. def assistant_message(self, message: str) -> any:
  77. """
  78. 创建助手消息
  79. """
  80. print(f"[DEBUG] 助手消息: {message}")
  81. return {"role": "assistant", "content": message}
  82. def submit_prompt(self, prompt, **kwargs) -> str:
  83. """
  84. 提交提示词到LLM
  85. """
  86. if prompt is None:
  87. raise Exception("Prompt is None")
  88. if len(prompt) == 0:
  89. raise Exception("Prompt is empty")
  90. # Count the number of tokens in the message log
  91. # Use 4 as an approximation for the number of characters per token
  92. num_tokens = 0
  93. for message in prompt:
  94. num_tokens += len(message["content"]) / 4
  95. # 从配置和参数中获取enable_thinking设置
  96. # 优先使用参数中传入的值,如果没有则从配置中读取,默认为False
  97. enable_thinking = kwargs.get("enable_thinking", self.config.get("enable_thinking", False))
  98. # 公共参数
  99. common_params = {
  100. "messages": prompt,
  101. "stop": None,
  102. "temperature": self.temperature,
  103. }
  104. # 如果启用了thinking,则使用流式处理,但不直接传递enable_thinking参数
  105. if enable_thinking:
  106. common_params["stream"] = True
  107. # 千问API不接受enable_thinking作为参数,可能需要通过header或其他方式传递
  108. # 也可能它只是默认启用stream=True时的thinking功能
  109. model = None
  110. # 确定使用的模型
  111. if kwargs.get("model", None) is not None:
  112. model = kwargs.get("model", None)
  113. common_params["model"] = model
  114. elif kwargs.get("engine", None) is not None:
  115. engine = kwargs.get("engine", None)
  116. common_params["engine"] = engine
  117. model = engine
  118. elif self.config is not None and "engine" in self.config:
  119. common_params["engine"] = self.config["engine"]
  120. model = self.config["engine"]
  121. elif self.config is not None and "model" in self.config:
  122. common_params["model"] = self.config["model"]
  123. model = self.config["model"]
  124. else:
  125. if num_tokens > 3500:
  126. model = "qwen-long"
  127. else:
  128. model = "qwen-plus"
  129. common_params["model"] = model
  130. print(f"\nUsing model {model} for {num_tokens} tokens (approx)")
  131. if enable_thinking:
  132. # 流式处理模式
  133. print("使用流式处理模式,启用thinking功能")
  134. # 检查是否需要通过headers传递enable_thinking参数
  135. response_stream = self.client.chat.completions.create(**common_params)
  136. # 收集流式响应
  137. collected_thinking = []
  138. collected_content = []
  139. for chunk in response_stream:
  140. # 处理thinking部分
  141. if hasattr(chunk, 'thinking') and chunk.thinking:
  142. collected_thinking.append(chunk.thinking)
  143. # 处理content部分
  144. if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content:
  145. collected_content.append(chunk.choices[0].delta.content)
  146. # 可以在这里处理thinking的展示逻辑,如保存到日志等
  147. if collected_thinking:
  148. print("Model thinking process:", "".join(collected_thinking))
  149. # 返回完整的内容
  150. return "".join(collected_content)
  151. else:
  152. # 非流式处理模式
  153. print("使用非流式处理模式")
  154. response = self.client.chat.completions.create(**common_params)
  155. # Find the first response from the chatbot that has text in it (some responses may not have text)
  156. for choice in response.choices:
  157. if "text" in choice:
  158. return choice.text
  159. # If no response with text is found, return the first response's content (which may be empty)
  160. return response.choices[0].message.content
  161. # 核心方法:get_sql_prompt
  162. def get_sql_prompt(self, question: str,
  163. question_sql_list: list,
  164. ddl_list: list,
  165. doc_list: list,
  166. **kwargs) -> List[Dict[str, str]]:
  167. """
  168. 生成SQL查询的中文提示词
  169. """
  170. print("[DEBUG] 正在生成中文SQL提示词...")
  171. print(f"[DEBUG] 问题: {question}")
  172. print(f"[DEBUG] 相关SQL数量: {len(question_sql_list) if question_sql_list else 0}")
  173. print(f"[DEBUG] 相关DDL数量: {len(ddl_list) if ddl_list else 0}")
  174. print(f"[DEBUG] 相关文档数量: {len(doc_list) if doc_list else 0}")
  175. # 获取dialect
  176. dialect = getattr(self, 'dialect', 'SQL')
  177. # 创建基础提示词
  178. messages = [
  179. self.system_message(
  180. f"""你是一个专业的SQL助手,根据用户的问题生成正确的{dialect}查询语句。
  181. 你只需生成SQL语句,不需要任何解释或评论。
  182. 用户问题: {question}
  183. """
  184. )
  185. ]
  186. # 添加相关的DDL(如果有)
  187. if ddl_list and len(ddl_list) > 0:
  188. ddl_text = "\n\n".join([f"-- DDL项 {i+1}:\n{ddl}" for i, ddl in enumerate(ddl_list)])
  189. messages.append(
  190. self.user_message(
  191. f"""
  192. 以下是可能相关的数据库表结构定义,请基于这些信息生成SQL:
  193. {ddl_text}
  194. 记住,这些只是参考信息,可能并不包含所有需要的表和字段。
  195. """
  196. )
  197. )
  198. # 添加相关的文档(如果有)
  199. if doc_list and len(doc_list) > 0:
  200. doc_text = "\n\n".join([f"-- 文档项 {i+1}:\n{doc}" for i, doc in enumerate(doc_list)])
  201. messages.append(
  202. self.user_message(
  203. f"""
  204. 以下是可能有用的业务逻辑文档:
  205. {doc_text}
  206. """
  207. )
  208. )
  209. # 添加相关的问题和SQL(如果有)
  210. if question_sql_list and len(question_sql_list) > 0:
  211. qs_text = ""
  212. for i, qs_item in enumerate(question_sql_list):
  213. qs_text += f"问题 {i+1}: {qs_item.get('question', '')}\n"
  214. qs_text += f"SQL:\n```sql\n{qs_item.get('sql', '')}\n```\n\n"
  215. messages.append(
  216. self.user_message(
  217. f"""
  218. 以下是与当前问题相似的问题及其对应的SQL查询:
  219. {qs_text}
  220. 请参考这些样例来生成当前问题的SQL查询。
  221. """
  222. )
  223. )
  224. # 添加最终的用户请求和限制
  225. messages.append(
  226. self.user_message(
  227. f"""
  228. 根据以上信息,为以下问题生成一个{dialect}查询语句:
  229. 问题: {question}
  230. 要求:
  231. 1. 仅输出SQL语句,不要有任何解释或说明
  232. 2. 确保语法正确,符合{dialect}标准
  233. 3. 不要使用不存在的表或字段
  234. 4. 查询应尽可能高效
  235. """
  236. )
  237. )
  238. return messages
  239. def get_followup_questions_prompt(self,
  240. question: str,
  241. sql: str,
  242. df_metadata: str,
  243. **kwargs) -> List[Dict[str, str]]:
  244. """
  245. 生成后续问题的中文提示词
  246. """
  247. print("[DEBUG] 正在生成中文后续问题提示词...")
  248. messages = [
  249. self.system_message(
  250. f"""你是一个专业的数据分析师,能够根据已有问题提出相关的后续问题。
  251. {self._response_language()}
  252. """
  253. ),
  254. self.user_message(
  255. f"""
  256. 原始问题: {question}
  257. 已执行的SQL查询:
  258. ```sql
  259. {sql}
  260. ```
  261. 数据结构:
  262. {df_metadata}
  263. 请基于上述信息,生成3-5个相关的后续问题,这些问题应该:
  264. 1. 与原始问题和数据相关,是自然的延续
  265. 2. 提供更深入的分析视角或维度拓展
  266. 3. 探索可能的业务洞见和价值发现
  267. 4. 简洁明了,便于用户理解
  268. 5. 确保问题可以通过SQL查询解答,与现有数据结构相关
  269. 只需列出问题,不要提供任何解释或SQL。每个问题应该是完整的句子,以问号结尾。
  270. """
  271. )
  272. ]
  273. return messages
  274. def get_summary_prompt(self, question: str, df_markdown: str, **kwargs) -> List[Dict[str, str]]:
  275. """
  276. 生成摘要的中文提示词
  277. """
  278. print("[DEBUG] 正在生成中文摘要提示词...")
  279. messages = [
  280. self.system_message(
  281. f"""你是一个专业的数据分析师,能够清晰解释SQL查询的含义和结果。
  282. {self._response_language()}
  283. """
  284. ),
  285. self.user_message(
  286. f"""
  287. 你是一个有帮助的数据助手。用户问了这个问题: '{question}'
  288. 以下是一个pandas DataFrame,包含查询的结果:
  289. {df_markdown}
  290. 请用中文简明扼要地总结这些数据,回答用户的问题。不要提供任何额外的解释,只需提供摘要。
  291. """
  292. )
  293. ]
  294. return messages
  295. def get_plotly_prompt(self, question: str, sql: str, df_metadata: str,
  296. chart_instructions: Optional[str] = None, **kwargs) -> List[Dict[str, str]]:
  297. """
  298. 生成Python可视化代码的中文提示词
  299. """
  300. print("[DEBUG] 正在生成中文Python可视化提示词...")
  301. instructions = chart_instructions if chart_instructions else "生成一个适合展示数据的图表"
  302. messages = [
  303. self.system_message(
  304. f"""你是一个专业的Python数据可视化专家,擅长使用Plotly创建数据可视化图表。
  305. {self._response_language()}
  306. """
  307. ),
  308. self.user_message(
  309. f"""
  310. 问题: {question}
  311. SQL查询:
  312. ```sql
  313. {sql}
  314. ```
  315. 数据结构:
  316. {df_metadata}
  317. 请生成一个Python函数,使用Plotly库为上述数据创建一个可视化图表。要求:
  318. 1. {instructions}
  319. 2. 确保代码语法正确,可直接运行
  320. 3. 图表应直观展示数据中的关键信息和关系
  321. 4. 只需提供Python代码,不要有任何解释
  322. 5. 使用中文作为图表标题、轴标签和图例
  323. 6. 添加合适的颜色方案,保证图表美观
  324. 7. 针对数据类型选择最合适的图表类型
  325. 输出格式必须是可以直接运行的Python代码。
  326. """
  327. )
  328. ]
  329. return messages