custom_deepseek_chat.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. import os
  2. from openai import OpenAI
  3. from vanna.base import VannaBase
  4. #from base import VannaBase
  5. # 导入配置参数
  6. from app_config import REWRITE_QUESTION_ENABLED
  7. # from vanna.chromadb import ChromaDB_VectorStore
  8. # class DeepSeekVanna(ChromaDB_VectorStore, DeepSeekChat):
  9. # def __init__(self, config=None):
  10. # ChromaDB_VectorStore.__init__(self, config=config)
  11. # DeepSeekChat.__init__(self, config=config)
  12. # vn = DeepSeekVanna(config={"api_key": "sk-************", "model": "deepseek-chat"})
  class DeepSeekChat(VannaBase):
      """Vanna LLM adapter that sends chat prompts to DeepSeek via the ``openai`` client.

      The class overrides the message builders, ``submit_prompt``,
      ``generate_sql``, ``generate_question`` and
      ``generate_rewritten_question``, and adds ``chat_with_llm`` for plain
      conversation that does not involve SQL generation.

      Recognised ``config`` keys (all optional):
          api_key:     key handed to the ``OpenAI`` client.
          base_url:    API endpoint; defaults to ``DEEPSEEK_DEFAULT_BASE_URL``.
          temperature: sampling temperature; defaults to 0.7.
          engine / model: model name used by ``submit_prompt`` when the call
              itself does not name one.
      """

      # Model used when neither the call kwargs nor the config name one.
      DEEPSEEK_DEFAULT_MODEL = "deepseek-chat"
      # Endpoint used when the config has an ``api_key`` but no ``base_url``.
      DEEPSEEK_DEFAULT_BASE_URL = "https://api.deepseek.com"
      # Config keys whose name contains one of these markers are masked when
      # the config is echoed, so credentials do not end up in stdout/logs.
      _SENSITIVE_CONFIG_MARKERS = ("api_key", "secret", "password")
      # Phrases that mark an LLM reply as an apology/refusal instead of SQL.
      _SQL_ERROR_INDICATORS = (
          "insufficient context", "无法生成", "sorry", "cannot", "不能",
          "no relevant", "no suitable", "unable to", "无法", "抱歉",
          "i don't have", "i cannot", "没有相关", "找不到", "不存在",
      )
      # A reply must contain at least one of these to be treated as SQL.
      _SQL_KEYWORDS = ("select", "insert", "update", "delete", "with", "from", "where")

      def __init__(self, config=None):
          """Initialise the base class, the sampling temperature and the API client.

          Args:
              config: Optional dict of settings (see the class docstring).
                  When it is ``None`` the client is built from the
                  ``OPENAI_API_KEY`` environment variable.  When a config is
                  given without ``api_key`` no client is created here.
          """
          VannaBase.__init__(self, config=config)

          print("...DeepSeekChat init...")
          print("传入的 config 参数如下:")
          # ``self.config`` is whatever the base class stored; guard against None.
          for key, value in (self.config or {}).items():
              key_name = str(key).lower()
              if any(marker in key_name for marker in self._SENSITIVE_CONFIG_MARKERS):
                  value = "***"  # never echo credentials
              print(f" {key}: {value}")

          # default parameters
          self.temperature = 0.7

          # The None check must come before any ``in config`` test; the
          # original order raised TypeError for ``DeepSeekChat()``.
          if config is None:
              self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
              return

          if "temperature" in config:
              print(f"temperature is changed to: {config['temperature']}")
              self.temperature = config["temperature"]

          if "api_key" in config:
              self.client = OpenAI(
                  api_key=config["api_key"],
                  base_url=config.get("base_url", self.DEEPSEEK_DEFAULT_BASE_URL),
              )

      def system_message(self, message: str) -> dict:
          """Log *message* and wrap it as a ``system`` chat message dict."""
          print(f"system_content: {message}")
          return {"role": "system", "content": message}

      def user_message(self, message: str) -> dict:
          """Log *message* and wrap it as a ``user`` chat message dict."""
          print(f"\nuser_content: {message}")
          return {"role": "user", "content": message}

      def assistant_message(self, message: str) -> dict:
          """Log *message* and wrap it as an ``assistant`` chat message dict."""
          print(f"assistant_content: {message}")
          return {"role": "assistant", "content": message}

      def _resolve_model(self, kwargs: dict) -> str:
          """Pick the model name: call kwargs first, then config, then the default.

          Lookup order is ``kwargs["model"]``, ``kwargs["engine"]``,
          ``config["engine"]``, ``config["model"]``.  Entries that are missing
          or ``None`` are skipped so ``None`` is never sent as the model.
          """
          config = self.config or {}
          lookups = (
              (kwargs, "model"),
              (kwargs, "engine"),
              (config, "engine"),
              (config, "model"),
          )
          for source, name in lookups:
              candidate = source.get(name)
              if candidate is not None:
                  return candidate
          return self.DEEPSEEK_DEFAULT_MODEL

      def submit_prompt(self, prompt, **kwargs) -> str:
          """Send a chat prompt to the model and return the first reply's text.

          Args:
              prompt: Non-empty list of message dicts, each with a ``content``
                  string (as produced by the ``*_message`` helpers).
              **kwargs: May carry ``model`` or ``engine`` to override the model.

          Returns:
              ``response.choices[0].message.content`` from the completion call.

          Raises:
              Exception: If *prompt* is ``None`` or empty.
          """
          if prompt is None:
              raise Exception("Prompt is None")
          if len(prompt) == 0:
              raise Exception("Prompt is empty")

          # Rough size estimate for the log line only: ~4 characters per token.
          num_tokens = sum(len(message["content"]) for message in prompt) / 4

          model = self._resolve_model(kwargs)
          print(f"\nUsing model {model} for {num_tokens} tokens (approx)")

          response = self.client.chat.completions.create(
              model=model,
              messages=prompt,
              stop=None,
              temperature=self.temperature,
          )
          return response.choices[0].message.content

      def generate_sql(self, question: str, **kwargs) -> str:
          """Generate SQL via the parent class, with validation and error handling.

          The parent's result is rejected (``None`` is returned) when it is
          empty, contains one of the refusal/apology phrases, or contains no
          SQL keyword at all.  Any exception is logged with its traceback and
          also results in ``None`` instead of propagating.

          Args:
              question: Natural-language question to translate.
              **kwargs: Forwarded unchanged to the parent ``generate_sql``.

          Returns:
              The SQL text (with ``\\_`` unescaped to ``_``), or ``None``.
          """
          try:
              print(f"[DEBUG] 尝试为问题生成SQL: {question}")
              sql = super().generate_sql(question, **kwargs)

              if not sql or sql.strip() == "":
                  print("[WARNING] 生成的SQL为空")
                  return None

              # Replace "\_" with "_": the model sometimes escapes underscores.
              sql = sql.replace("\\_", "_")
              sql_lower = sql.lower().strip()

              # A refusal/apology phrase means the reply is not SQL.
              if any(indicator in sql_lower for indicator in self._SQL_ERROR_INDICATORS):
                  print(f"[WARNING] LLM返回错误信息而非SQL: {sql}")
                  return None

              # Cheap sanity check: real SQL contains at least one keyword.
              if not any(keyword in sql_lower for keyword in self._SQL_KEYWORDS):
                  print(f"[WARNING] 返回内容不像有效SQL: {sql}")
                  return None

              print(f"[SUCCESS] 成功生成SQL: {sql}")
              return sql
          except Exception as e:
              # Deliberate best-effort: report everything, return None.
              import traceback

              print(f"[ERROR] SQL生成过程中出现异常: {e}")
              print(f"[ERROR] 异常类型: {type(e).__name__}")
              print(f"[ERROR] 详细错误信息: {traceback.format_exc()}")
              return None

      def generate_question(self, sql: str, **kwargs) -> str:
          """Ask the model which business question the given SQL answers.

          The system prompt requests a Chinese natural-language question with
          no explanation, no SQL and no table names.

          Args:
              sql: SQL statement to describe.
              **kwargs: Forwarded to ``submit_prompt``.

          Returns:
              The model's reply, unmodified.
          """
          prompt = [
              self.system_message(
                  "请你根据下方SQL语句推测用户的业务提问,只返回清晰的自然语言问题,问题要使用中文,不要包含任何解释或SQL内容,也不要出现表名。"
              ),
              self.user_message(sql),
          ]
          # Post-processing of the reply could be added here if needed.
          return self.submit_prompt(prompt, **kwargs)

      def chat_with_llm(self, question: str, **kwargs) -> str:
          """Talk to the model directly, without any SQL generation.

          Args:
              question: The user's message.
              **kwargs: Forwarded to ``submit_prompt``.

          Returns:
              The model's reply, or a fixed apology string if the call fails.
          """
          try:
              prompt = [
                  self.system_message(
                      "你是一个友好的AI助手。如果用户询问的是数据库相关问题,请建议他们重新表述问题以便进行SQL查询。对于其他问题,请尽力提供有帮助的回答。"
                  ),
                  self.user_message(question),
              ]
              return self.submit_prompt(prompt, **kwargs)
          except Exception as e:
              print(f"[ERROR] LLM对话失败: {e}")
              return "抱歉,我暂时无法回答您的问题。请稍后再试。"

      def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str:
          """Merge a follow-up question with the previous one, if enabled.

          Merging is skipped and *new_question* is returned unchanged when
          ``REWRITE_QUESTION_ENABLED`` is falsy or when there is no previous
          question (``None`` or an empty string).

          Args:
              last_question: The previous question, if any.
              new_question: The new question.
              **kwargs: Forwarded to ``submit_prompt``.

          Returns:
              The model's merged question, or *new_question* when merging is
              skipped or the model call fails.
          """
          if not REWRITE_QUESTION_ENABLED or not last_question:
              reason = "未启用" if not REWRITE_QUESTION_ENABLED else "上一个问题为空"
              print(f"[DEBUG] 问题合并功能{reason},直接返回新问题")
              return new_question

          print("[DEBUG] 启用问题合并功能,尝试合并问题")
          print(f"[DEBUG] 上一个问题: {last_question}")
          print(f"[DEBUG] 新问题: {new_question}")

          try:
              prompt = [
                  self.system_message(
                      "你的目标是将一系列相关的问题合并成一个单一的问题。如果第二个问题与第一个问题无关且完全独立,则返回第二个问题。"
                      "只返回新的合并问题,不要添加任何额外的解释。该问题理论上应该能够用一个SQL语句来回答。"
                      "请用中文回答。"
                  ),
                  self.user_message(f"第一个问题: {last_question}\n第二个问题: {new_question}"),
              ]
              rewritten_question = self.submit_prompt(prompt=prompt, **kwargs)
              print(f"[DEBUG] 合并后的问题: {rewritten_question}")
              return rewritten_question
          except Exception as e:
              print(f"[ERROR] 问题合并失败: {e}")
              # Fall back to the new question when merging fails.
              return new_question