# Custom_QianwenAI_chat.py
  1. import os
  2. from openai import OpenAI
  3. from vanna.base import VannaBase
  4. class QianWenAI_Chat(VannaBase):
  5. def __init__(self, client=None, config=None):
  6. print("...QianWenAI_Chat init...")
  7. VannaBase.__init__(self, config=config)
  8. print("传入的 config 参数如下:")
  9. for key, value in self.config.items():
  10. print(f" {key}: {value}")
  11. # default parameters - can be overrided using config
  12. self.temperature = 0.7
  13. if "temperature" in config:
  14. print(f"temperature is changed to: {config['temperature']}")
  15. self.temperature = config["temperature"]
  16. if "api_type" in config:
  17. raise Exception(
  18. "Passing api_type is now deprecated. Please pass an OpenAI client instead."
  19. )
  20. if "api_base" in config:
  21. raise Exception(
  22. "Passing api_base is now deprecated. Please pass an OpenAI client instead."
  23. )
  24. if "api_version" in config:
  25. raise Exception(
  26. "Passing api_version is now deprecated. Please pass an OpenAI client instead."
  27. )
  28. if client is not None:
  29. self.client = client
  30. return
  31. if config is None and client is None:
  32. self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
  33. return
  34. if "api_key" in config:
  35. if "base_url" not in config:
  36. self.client = OpenAI(api_key=config["api_key"],
  37. base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")
  38. else:
  39. self.client = OpenAI(api_key=config["api_key"],
  40. base_url=config["base_url"])
  41. def system_message(self, message: str) -> any:
  42. print(f"system_content: {message}")
  43. return {"role": "system", "content": message}
  44. def user_message(self, message: str) -> any:
  45. print(f"\nuser_content: {message}")
  46. return {"role": "user", "content": message}
  47. def assistant_message(self, message: str) -> any:
  48. print(f"assistant_content: {message}")
  49. return {"role": "assistant", "content": message}
  50. def submit_prompt(self, prompt, **kwargs) -> str:
  51. if prompt is None:
  52. raise Exception("Prompt is None")
  53. if len(prompt) == 0:
  54. raise Exception("Prompt is empty")
  55. # Count the number of tokens in the message log
  56. # Use 4 as an approximation for the number of characters per token
  57. num_tokens = 0
  58. for message in prompt:
  59. num_tokens += len(message["content"]) / 4
  60. # 从配置和参数中获取enable_thinking设置
  61. # 优先使用参数中传入的值,如果没有则从配置中读取,默认为False
  62. enable_thinking = kwargs.get("enable_thinking", self.config.get("enable_thinking", False))
  63. # 公共参数
  64. common_params = {
  65. "messages": prompt,
  66. "stop": None,
  67. "temperature": self.temperature,
  68. }
  69. # 如果启用了thinking,则使用流式处理,但不直接传递enable_thinking参数
  70. if enable_thinking:
  71. common_params["stream"] = True
  72. # 千问API不接受enable_thinking作为参数,可能需要通过header或其他方式传递
  73. # 也可能它只是默认启用stream=True时的thinking功能
  74. model = None
  75. # 确定使用的模型
  76. if kwargs.get("model", None) is not None:
  77. model = kwargs.get("model", None)
  78. common_params["model"] = model
  79. elif kwargs.get("engine", None) is not None:
  80. engine = kwargs.get("engine", None)
  81. common_params["engine"] = engine
  82. model = engine
  83. elif self.config is not None and "engine" in self.config:
  84. common_params["engine"] = self.config["engine"]
  85. model = self.config["engine"]
  86. elif self.config is not None and "model" in self.config:
  87. common_params["model"] = self.config["model"]
  88. model = self.config["model"]
  89. else:
  90. if num_tokens > 3500:
  91. model = "qwen-long"
  92. else:
  93. model = "qwen-plus"
  94. common_params["model"] = model
  95. print(f"\nUsing model {model} for {num_tokens} tokens (approx)")
  96. if enable_thinking:
  97. # 流式处理模式
  98. print("使用流式处理模式,启用thinking功能")
  99. # 检查是否需要通过headers传递enable_thinking参数
  100. response_stream = self.client.chat.completions.create(**common_params)
  101. # 收集流式响应
  102. collected_thinking = []
  103. collected_content = []
  104. for chunk in response_stream:
  105. # 处理thinking部分
  106. if hasattr(chunk, 'thinking') and chunk.thinking:
  107. collected_thinking.append(chunk.thinking)
  108. # 处理content部分
  109. if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content:
  110. collected_content.append(chunk.choices[0].delta.content)
  111. # 可以在这里处理thinking的展示逻辑,如保存到日志等
  112. if collected_thinking:
  113. print("Model thinking process:", "".join(collected_thinking))
  114. # 返回完整的内容
  115. return "".join(collected_content)
  116. else:
  117. # 非流式处理模式
  118. print("使用非流式处理模式")
  119. response = self.client.chat.completions.create(**common_params)
  120. # Find the first response from the chatbot that has text in it (some responses may not have text)
  121. for choice in response.choices:
  122. if "text" in choice:
  123. return choice.text
  124. # If no response with text is found, return the first response's content (which may be empty)
  125. return response.choices[0].message.content
  126. # 重写 generate_sql 方法以增加异常处理
  127. def generate_sql(self, question: str, **kwargs) -> str:
  128. """
  129. 重写父类的 generate_sql 方法,增加异常处理
  130. """
  131. try:
  132. print(f"[DEBUG] 尝试为问题生成SQL: {question}")
  133. # 调用父类的 generate_sql
  134. sql = super().generate_sql(question, **kwargs)
  135. if not sql or sql.strip() == "":
  136. print(f"[WARNING] 生成的SQL为空")
  137. return None
  138. # 检查返回内容是否为有效SQL或错误信息
  139. sql_lower = sql.lower().strip()
  140. # 检查是否包含错误提示信息
  141. error_indicators = [
  142. "insufficient context", "无法生成", "sorry", "cannot", "不能",
  143. "no relevant", "no suitable", "unable to", "无法", "抱歉",
  144. "i don't have", "i cannot", "没有相关", "找不到", "不存在"
  145. ]
  146. for indicator in error_indicators:
  147. if indicator in sql_lower:
  148. print(f"[WARNING] LLM返回错误信息而非SQL: {sql}")
  149. return None
  150. # 简单检查是否像SQL语句(至少包含一些SQL关键词)
  151. sql_keywords = ["select", "insert", "update", "delete", "with", "from", "where"]
  152. if not any(keyword in sql_lower for keyword in sql_keywords):
  153. print(f"[WARNING] 返回内容不像有效SQL: {sql}")
  154. return None
  155. print(f"[SUCCESS] 成功生成SQL: {sql}")
  156. return sql
  157. except Exception as e:
  158. print(f"[ERROR] SQL生成过程中出现异常: {str(e)}")
  159. print(f"[ERROR] 异常类型: {type(e).__name__}")
  160. # 导入traceback以获取详细错误信息
  161. import traceback
  162. print(f"[ERROR] 详细错误信息: {traceback.format_exc()}")
  163. # 返回 None 而不是抛出异常
  164. return None
  165. # 为了解决通过sql生成question时,question是英文的问题。
  166. def generate_question(self, sql: str, **kwargs) -> str:
  167. # 这里可以自定义提示词/逻辑
  168. prompt = [
  169. self.system_message(
  170. "请你根据下方SQL语句推测用户的业务提问,只返回清晰的自然语言问题,不要包含任何解释或SQL内容,也不要出现表名,问题要使用中文,并以问号结尾。"
  171. ),
  172. self.user_message(sql)
  173. ]
  174. response = self.submit_prompt(prompt, **kwargs)
  175. # 你也可以在这里对response做后处理
  176. return response
  177. # 新增:直接与LLM对话的方法
  178. def chat_with_llm(self, question: str, **kwargs) -> str:
  179. """
  180. 直接与LLM对话,不涉及SQL生成
  181. """
  182. try:
  183. prompt = [
  184. self.system_message(
  185. "你是一个友好的AI助手。如果用户询问的是数据库相关问题,请建议他们重新表述问题以便进行SQL查询。对于其他问题,请尽力提供有帮助的回答。"
  186. ),
  187. self.user_message(question)
  188. ]
  189. response = self.submit_prompt(prompt, **kwargs)
  190. return response
  191. except Exception as e:
  192. print(f"[ERROR] LLM对话失败: {str(e)}")
  193. return f"抱歉,我暂时无法回答您的问题。请稍后再试。"