# custom_deepseek_chat.py
  1. import os
  2. from openai import OpenAI
  3. from vanna.base import VannaBase
  4. #from base import VannaBase
  5. # from vanna.chromadb import ChromaDB_VectorStore
  6. # class DeepSeekVanna(ChromaDB_VectorStore, DeepSeekChat):
  7. # def __init__(self, config=None):
  8. # ChromaDB_VectorStore.__init__(self, config=config)
  9. # DeepSeekChat.__init__(self, config=config)
  10. # vn = DeepSeekVanna(config={"api_key": "sk-************", "model": "deepseek-chat"})
  11. class DeepSeekChat(VannaBase):
  12. def __init__(self, config=None):
  13. VannaBase.__init__(self, config=config)
  14. print("...DeepSeekChat init...")
  15. print("传入的 config 参数如下:")
  16. for key, value in self.config.items():
  17. print(f" {key}: {value}")
  18. # default parameters
  19. self.temperature = 0.7
  20. if "temperature" in config:
  21. print(f"temperature is changed to: {config['temperature']}")
  22. self.temperature = config["temperature"]
  23. if config is None:
  24. self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
  25. return
  26. if "api_key" in config:
  27. if "base_url" not in config:
  28. self.client = OpenAI(api_key=config["api_key"], base_url="https://api.deepseek.com")
  29. else:
  30. self.client = OpenAI(api_key=config["api_key"], base_url=config["base_url"])
  31. def system_message(self, message: str) -> any:
  32. print(f"system_content: {message}")
  33. return {"role": "system", "content": message}
  34. def user_message(self, message: str) -> any:
  35. print(f"\nuser_content: {message}")
  36. return {"role": "user", "content": message}
  37. def assistant_message(self, message: str) -> any:
  38. print(f"assistant_content: {message}")
  39. return {"role": "assistant", "content": message}
  40. def submit_prompt(self, prompt, **kwargs) -> str:
  41. if prompt is None:
  42. raise Exception("Prompt is None")
  43. if len(prompt) == 0:
  44. raise Exception("Prompt is empty")
  45. # Count the number of tokens in the message log
  46. num_tokens = 0
  47. for message in prompt:
  48. num_tokens += len(message["content"]) / 4
  49. model = None
  50. if kwargs.get("model", None) is not None:
  51. model = kwargs.get("model", None)
  52. elif kwargs.get("engine", None) is not None:
  53. model = kwargs.get("engine", None)
  54. elif self.config is not None and "engine" in self.config:
  55. model = self.config["engine"]
  56. elif self.config is not None and "model" in self.config:
  57. model = self.config["model"]
  58. else:
  59. if num_tokens > 3500:
  60. model = "deepseek-chat"
  61. else:
  62. model = "deepseek-chat"
  63. print(f"\nUsing model {model} for {num_tokens} tokens (approx)")
  64. response = self.client.chat.completions.create(
  65. model=model,
  66. messages=prompt,
  67. stop=None,
  68. temperature=self.temperature,
  69. )
  70. return response.choices[0].message.content
  71. def generate_sql(self, question: str, **kwargs) -> str:
  72. """
  73. 重写父类的 generate_sql 方法,增加异常处理
  74. """
  75. try:
  76. print(f"[DEBUG] 尝试为问题生成SQL: {question}")
  77. # 使用父类的 generate_sql
  78. sql = super().generate_sql(question, **kwargs)
  79. if not sql or sql.strip() == "":
  80. print(f"[WARNING] 生成的SQL为空")
  81. return None
  82. # 替换 "\_" 为 "_",解决特殊字符转义问题
  83. sql = sql.replace("\\_", "_")
  84. # 检查返回内容是否为有效SQL或错误信息
  85. sql_lower = sql.lower().strip()
  86. # 检查是否包含错误提示信息
  87. error_indicators = [
  88. "insufficient context", "无法生成", "sorry", "cannot", "不能",
  89. "no relevant", "no suitable", "unable to", "无法", "抱歉",
  90. "i don't have", "i cannot", "没有相关", "找不到", "不存在"
  91. ]
  92. for indicator in error_indicators:
  93. if indicator in sql_lower:
  94. print(f"[WARNING] LLM返回错误信息而非SQL: {sql}")
  95. return None
  96. # 简单检查是否像SQL语句(至少包含一些SQL关键词)
  97. sql_keywords = ["select", "insert", "update", "delete", "with", "from", "where"]
  98. if not any(keyword in sql_lower for keyword in sql_keywords):
  99. print(f"[WARNING] 返回内容不像有效SQL: {sql}")
  100. return None
  101. print(f"[SUCCESS] 成功生成SQL: {sql}")
  102. return sql
  103. except Exception as e:
  104. print(f"[ERROR] SQL生成过程中出现异常: {str(e)}")
  105. print(f"[ERROR] 异常类型: {type(e).__name__}")
  106. # 导入traceback以获取详细错误信息
  107. import traceback
  108. print(f"[ERROR] 详细错误信息: {traceback.format_exc()}")
  109. # 返回 None 而不是抛出异常
  110. return None
  111. def generate_question(self, sql: str, **kwargs) -> str:
  112. # 这里可以自定义提示词/逻辑
  113. prompt = [
  114. self.system_message(
  115. "请你根据下方SQL语句推测用户的业务提问,只返回清晰的自然语言问题,问题要使用中文,不要包含任何解释或SQL内容,也不要出现表名。"
  116. ),
  117. self.user_message(sql)
  118. ]
  119. response = self.submit_prompt(prompt, **kwargs)
  120. # 你也可以在这里对response做后处理
  121. return response
  122. # 新增:直接与LLM对话的方法
  123. def chat_with_llm(self, question: str, **kwargs) -> str:
  124. """
  125. 直接与LLM对话,不涉及SQL生成
  126. """
  127. try:
  128. prompt = [
  129. self.system_message(
  130. "你是一个友好的AI助手。如果用户询问的是数据库相关问题,请建议他们重新表述问题以便进行SQL查询。对于其他问题,请尽力提供有帮助的回答。"
  131. ),
  132. self.user_message(question)
  133. ]
  134. response = self.submit_prompt(prompt, **kwargs)
  135. return response
  136. except Exception as e:
  137. print(f"[ERROR] LLM对话失败: {str(e)}")
  138. return f"抱歉,我暂时无法回答您的问题。请稍后再试。"