import os
from openai import OpenAI
from vanna.base import VannaBase
# from base import VannaBase

# Example of mixing this class with a vector store:
# from vanna.chromadb import ChromaDB_VectorStore
#
# class DeepSeekVanna(ChromaDB_VectorStore, DeepSeekChat):
#     def __init__(self, config=None):
#         ChromaDB_VectorStore.__init__(self, config=config)
#         DeepSeekChat.__init__(self, config=config)
#
# vn = DeepSeekVanna(config={"api_key": "sk-************", "model": "deepseek-chat"})

class DeepSeekChat(VannaBase):
    def __init__(self, config=None):
        VannaBase.__init__(self, config=config)

        print("...DeepSeekChat init...")
        if config is None:
            raise ValueError(
                "For DeepSeek, config must be provided with an api_key and model"
            )
        if "api_key" not in config:
            raise ValueError("config must contain a DeepSeek api_key")
        if "model" not in config:
            config["model"] = "deepseek-chat"  # default model
            print(f"No model specified, using the default: {config['model']}")

        # Set defaults
        self.temperature = config.get("temperature", 0.7)
        self.model = config["model"]

        print("Config passed in:")
        for key, value in config.items():
            if key != "api_key":  # never print the API key
                print(f"  {key}: {value}")

        # Use the standard OpenAI client, but point it at DeepSeek's base URL
        self.client = OpenAI(
            api_key=config["api_key"],
            base_url="https://api.deepseek.com/v1"
        )

    def system_message(self, message: str) -> dict:
        print(f"system_content: {message}")
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> dict:
        print(f"\nuser_content: {message}")
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> dict:
        print(f"assistant_content: {message}")
        return {"role": "assistant", "content": message}
    def submit_prompt(self, prompt, **kwargs) -> str:
        if prompt is None:
            raise Exception("Prompt is None")
        if len(prompt) == 0:
            raise Exception("Prompt is empty")

        # Count the number of tokens in the message log,
        # using 4 characters per token as an approximation
        num_tokens = 0
        for message in prompt:
            num_tokens += len(message["content"]) / 4

        # Take the model from kwargs if given, otherwise fall back to the config value
        model = kwargs.get("model", self.model)

        print(f"\nUsing model {model} for {num_tokens} tokens (approx)")

        # Build the request parameters
        chat_params = {
            "model": model,
            "messages": prompt,
            "temperature": kwargs.get("temperature", self.temperature),
        }

        try:
            chat_response = self.client.chat.completions.create(**chat_params)
            # Return the generated text
            return chat_response.choices[0].message.content
        except Exception as e:
            print(f"DeepSeek API call failed: {e}")
            raise

    def generate_sql(self, question: str, **kwargs) -> str:
        # Use the parent class's generate_sql
        sql = super().generate_sql(question, **kwargs)

        # Replace "\_" with "_" to undo the model's escaping of underscores
        sql = sql.replace("\\_", "_")

        return sql

    # Override so that questions generated from SQL come back in Chinese
    # instead of English.
    def generate_question(self, sql: str, **kwargs) -> str:
        # Custom prompt, kept in Chinese on purpose: it asks the model to infer
        # the user's business question from the SQL below and return only a
        # clear natural-language question in Chinese, with no explanation, SQL,
        # or table names.
        prompt = [
            self.system_message(
                "请你根据下方SQL语句推测用户的业务提问,只返回清晰的自然语言问题,问题要使用中文,不要包含任何解释或SQL内容,也不要出现表名。"
            ),
            self.user_message(sql)
        ]
        response = self.submit_prompt(prompt, **kwargs)
        # The response could be post-processed here if needed
        return response
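
# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the adapter above. It assumes vanna is
# installed with the chromadb extra, that DEEPSEEK_API_KEY is set in the
# environment, and that the CREATE TABLE statement below is a hypothetical
# schema used only for illustration; swap in your own DDL and questions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from vanna.chromadb import ChromaDB_VectorStore

    # Mix the vector store and the DeepSeek chat adapter into one Vanna class,
    # mirroring the commented-out example at the top of the file.
    class DeepSeekVanna(ChromaDB_VectorStore, DeepSeekChat):
        def __init__(self, config=None):
            ChromaDB_VectorStore.__init__(self, config=config)
            DeepSeekChat.__init__(self, config=config)

    vn = DeepSeekVanna(
        config={"api_key": os.environ["DEEPSEEK_API_KEY"], "model": "deepseek-chat"}
    )

    # Teach the vector store a (hypothetical) schema, then generate SQL for a
    # natural-language question.
    vn.train(ddl="CREATE TABLE albums (id INTEGER, title TEXT, sales INTEGER)")
    print(vn.generate_sql("List the five best-selling albums"))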