custom_deepseek_chat.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. import os
  2. from openai import OpenAI
  3. from vanna.base import VannaBase
# from base import VannaBase
#
# Example usage: combine DeepSeekChat with a ChromaDB vector store.
# from vanna.chromadb import ChromaDB_VectorStore
# class DeepSeekVanna(ChromaDB_VectorStore, DeepSeekChat):
#     def __init__(self, config=None):
#         ChromaDB_VectorStore.__init__(self, config=config)
#         DeepSeekChat.__init__(self, config=config)
# vn = DeepSeekVanna(config={"api_key": "sk-************", "model": "deepseek-chat"})
  11. class DeepSeekChat(VannaBase):
  12. def __init__(self, config=None):
  13. VannaBase.__init__(self, config=config)
  14. print("...DeepSeekChat init...")
  15. if config is None:
  16. raise ValueError(
  17. "For DeepSeek, config must be provided with an api_key and model"
  18. )
  19. if "api_key" not in config:
  20. raise ValueError("config must contain a DeepSeek api_key")
  21. if "model" not in config:
  22. config["model"] = "deepseek-chat" # 默认模型
  23. print(f"未指定模型,使用默认模型: {config['model']}")
  24. # 设置默认值
  25. self.temperature = config.get("temperature", 0.7)
  26. self.model = config["model"]
  27. print("传入的 config 参数如下:")
  28. for key, value in config.items():
  29. if key != "api_key": # 不打印API密钥
  30. print(f" {key}: {value}")
  31. # 使用标准的OpenAI客户端,但更改基础URL
  32. self.client = OpenAI(
  33. api_key=config["api_key"],
  34. base_url="https://api.deepseek.com/v1"
  35. )
  36. def system_message(self, message: str) -> any:
  37. print(f"system_content: {message}")
  38. return {"role": "system", "content": message}
  39. def user_message(self, message: str) -> any:
  40. print(f"\nuser_content: {message}")
  41. return {"role": "user", "content": message}
  42. def assistant_message(self, message: str) -> any:
  43. print(f"assistant_content: {message}")
  44. return {"role": "assistant", "content": message}
  45. def submit_prompt(self, prompt, **kwargs) -> str:
  46. if prompt is None:
  47. raise Exception("Prompt is None")
  48. if len(prompt) == 0:
  49. raise Exception("Prompt is empty")
  50. # Count the number of tokens in the message log
  51. # Use 4 as an approximation for the number of characters per token
  52. num_tokens = 0
  53. for message in prompt:
  54. num_tokens += len(message["content"]) / 4
  55. # 从配置和参数中获取model设置,kwargs优先
  56. model = kwargs.get("model", self.model)
  57. print(f"\nUsing model {model} for {num_tokens} tokens (approx)")
  58. # 创建请求参数
  59. chat_params = {
  60. "model": model,
  61. "messages": prompt,
  62. "temperature": kwargs.get("temperature", self.temperature),
  63. }
  64. try:
  65. chat_response = self.client.chat.completions.create(**chat_params)
  66. # 返回生成的文本
  67. return chat_response.choices[0].message.content
  68. except Exception as e:
  69. print(f"DeepSeek API调用失败: {e}")
  70. raise
  71. def generate_sql(self, question: str, **kwargs) -> str:
  72. # 使用父类的 generate_sql
  73. sql = super().generate_sql(question, **kwargs)
  74. # 替换 "\_" 为 "_",解决特殊字符转义问题
  75. sql = sql.replace("\\_", "_")
  76. return sql
  77. # 为了解决通过sql生成question时,question是英文的问题。
  78. def generate_question(self, sql: str, **kwargs) -> str:
  79. # 这里可以自定义提示词/逻辑
  80. prompt = [
  81. self.system_message(
  82. "请你根据下方SQL语句推测用户的业务提问,只返回清晰的自然语言问题,问题要使用中文,不要包含任何解释或SQL内容,也不要出现表名。"
  83. ),
  84. self.user_message(sql)
  85. ]
  86. response = self.submit_prompt(prompt, **kwargs)
  87. # 你也可以在这里对response做后处理
  88. return response