
Modify the code so that it can also answer when the user's question is unrelated to the database.

wangxq · 1 month ago
parent
commit
0b55b79660
3 changed files with 579 additions and 238 deletions
  1. app.py (+241 / -22)
  2. customdeepseek/custom_deepseek_chat.py (+112 / -55)
  3. customqianwen/Custom_QianwenAI_chat.py (+226 / -161)
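At its core, the commit replaces the straight-line gen_query → execute_query → plot pipeline with a try-then-fallback flow. A condensed sketch of the pattern (illustrative names only; the real implementation is chain() in the app.py diff below):

```python
# Condensed sketch of the try-then-fallback flow introduced by this commit.
# Names mirror the steps in app.py below; this is not the literal code.
async def handle(question: str):
    sql = await gen_query(question)        # now returns None instead of raising
    if sql is None:                        # no usable SQL: answer conversationally
        return await llm_chat(question)
    df = await execute_query(sql)          # also returns None on failure
    if df is None:                         # query ran but produced no rows
        return await llm_chat(question, context=f"The SQL returned no rows:\n{sql}")
    return df                              # normal data path: table + chart
```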

+ 241 - 22
app.py

@@ -23,39 +23,256 @@ async def chat_profile():
 
 @cl.step(language="sql", name="Vanna")
 async def gen_query(human_query: str):
-    sql_query = vn.generate_sql(human_query)
-    return sql_query
+    """
+    安全的SQL生成函数,处理所有可能的异常
+    """
+    try:
+        print(f"[INFO] 开始生成SQL: {human_query}")
+        sql_query = vn.generate_sql(human_query)
+        
+        if sql_query is None:
+            print(f"[WARNING] generate_sql 返回 None")
+            return None
+            
+        if sql_query.strip() == "":
+            print(f"[WARNING] generate_sql 返回空字符串")
+            return None
+            
+        # 检查是否返回了错误信息而非SQL
+        if "insufficient context" in sql_query.lower() or "无法生成" in sql_query or "sorry" in sql_query.lower():
+            print(f"[WARNING] LLM返回无法生成SQL的消息: {sql_query}")
+            return None
+            
+        print(f"[SUCCESS] SQL生成成功: {sql_query}")
+        return sql_query
+        
+    except Exception as e:
+        print(f"[ERROR] gen_query 异常: {str(e)}")
+        print(f"[ERROR] 异常类型: {type(e).__name__}")
+        return None
 
 @cl.step(name="Vanna")
 async def execute_query(query):
     current_step = cl.context.current_step
-    df = vn.run_sql(query)
-    current_step.output = df.head().to_markdown(index=False)
-
-    return df
+    try:
+        if query is None or query.strip() == "":
+            current_step.output = "The SQL query is empty and cannot be executed."
+            return None
+            
+        print(f"[INFO] Executing SQL: {query}")
+        df = vn.run_sql(query)
+        
+        if df is None or df.empty:
+            current_step.output = "The query ran successfully but returned no data."
+            return None
+            
+        current_step.output = df.head().to_markdown(index=False)
+        print(f"[SUCCESS] SQL executed, {len(df)} rows returned")
+        return df
+        
+    except Exception as e:
+        error_msg = f"SQL execution failed: {str(e)}"
+        print(f"[ERROR] {error_msg}")
+        current_step.output = error_msg
+        return None
 
 @cl.step(name="Plot", language="python")
 async def plot(human_query, sql, df):
     current_step = cl.context.current_step
-    plotly_code = vn.generate_plotly_code(question=human_query, sql=sql, df=df)
-    fig = vn.get_plotly_figure(plotly_code=plotly_code, df=df)
+    try:
+        if df is None or df.empty:
+            current_step.output = "No data available for charting."
+            return None
+            
+        plotly_code = vn.generate_plotly_code(question=human_query, sql=sql, df=df)
+        fig = vn.get_plotly_figure(plotly_code=plotly_code, df=df)
+        current_step.output = plotly_code
+        return fig
+        
+    except Exception as e:
+        error_msg = f"Chart generation failed: {str(e)}"
+        print(f"[ERROR] {error_msg}")
+        current_step.output = error_msg
+        return None
 
-    current_step.output = plotly_code
-    return fig
+@cl.step(name="LLM Chat")
+async def llm_chat(human_query: str, context: str = None):
+    """直接与LLM对话,用于非数据库相关问题或SQL生成失败的情况"""
+    current_step = cl.context.current_step
+    try:
+        print(f"[INFO] 使用LLM直接对话: {human_query}")
+        
+        # 构建更智能的提示词
+        if context:
+            # 有上下文时(SQL生成失败)
+            system_message = (
+                "你是一个友好的数据库查询助手。用户刚才的问题无法生成有效的SQL查询,"
+                "可能是因为相关数据不在数据库中,或者问题需要重新表述。"
+                "请友好地回复用户,解释可能的原因,并建议如何重新表述问题。"
+            )
+            user_message = f"用户问题:{human_query}\n\n{context}"
+        else:
+            # 无上下文时(一般性对话)
+            system_message = (
+                "你是一个友好的AI助手。你主要专注于数据库查询,"
+                "但也可以回答一般性问题。如果用户询问数据相关问题,"
+                "请建议他们重新表述以便进行SQL查询。"
+            )
+            user_message = human_query
+        
+        # 使用我们新增的 chat_with_llm 方法
+        if hasattr(vn, 'chat_with_llm'):
+            response = vn.chat_with_llm(user_message)
+        else:
+            # 回退方案:使用 submit_prompt
+            if hasattr(vn, 'submit_prompt'):
+                messages = [
+                    {"role": "system", "content": system_message},
+                    {"role": "user", "content": user_message}
+                ]
+                response = vn.submit_prompt(messages)
+            else:
+                # 最终回退方案
+                response = f"我理解您的问题:'{human_query}'。我主要专注于数据库查询,如果您有数据相关的问题,请尝试重新表述,我可以帮您生成SQL查询并分析数据。"
+        
+        current_step.output = response
+        return response
+        
+    except Exception as e:
+        error_msg = f"LLM对话失败: {str(e)}"
+        print(f"[ERROR] {error_msg}")
+        fallback_response = f"抱歉,我暂时无法回答您的问题:'{human_query}'。请稍后重试,或者尝试重新表述您的问题。"
+        current_step.output = fallback_response
+        return fallback_response
+
+def is_database_related_query(query: str) -> bool:
+    """
+    Heuristically decide whether a query is database-related (kept for debugging
+    and possible later tuning; not used in the main flow).
+    """
+    # Database-related keywords
+    db_keywords = [
+        # Chinese keywords
+        '查询', '数据', '表', '统计', '分析', '汇总', '计算', '查找', '显示', 
+        '列出', '多少', '总计', '平均', '最大', '最小', '排序', '筛选',
+        '销售', '订单', '客户', '产品', '用户', '记录', '报表',
+        # English keywords
+        'select', 'count', 'sum', 'avg', 'max', 'min', 'table', 'data',
+        'query', 'database', 'records', 'show', 'list', 'find', 'search'
+    ]
+    
+    # Non-database keywords
+    non_db_keywords = [
+        '天气', '新闻', '今天', '明天', '时间', '日期', '你好', '谢谢',
+        '什么是', '如何', '为什么', '帮助', '介绍', '说明',
+        'weather', 'news', 'today', 'tomorrow', 'time', 'hello', 'thank',
+        'what is', 'how to', 'why', 'help', 'introduce'
+    ]
+    
+    query_lower = query.lower()
+    
+    # Any non-database keyword marks the query as unrelated
+    for keyword in non_db_keywords:
+        if keyword in query_lower:
+            return False
+    
+    # Any database keyword marks the query as related
+    for keyword in db_keywords:
+        if keyword in query_lower:
+            return True
+    
+    # Default to database-related (conservative strategy)
+    return True
 
 @cl.step(type="run", name="Vanna")
 async def chain(human_query: str):
-    sql_query = await gen_query(human_query)
-    df = await execute_query(sql_query)    
-    fig = await plot(human_query, sql_query, df)
-
-    # Build the table and chart elements
-    elements = [
-        cl.Text(name="data_table", content=df.to_markdown(index=False), display="inline"),
-        cl.Plotly(name="chart", figure=fig, display="inline")
-    ]
+    """
+    主要的处理链 - 方案二:尝试-回退策略
+    对所有查询都先尝试生成SQL,如果失败则自动fallback到LLM对话
+    """
     
-    await cl.Message(content=human_query, elements=elements, author="Vanna Assistant").send()
+    try:
+        # Step 1: try to generate SQL directly (no pre-classification)
+        print(f"[INFO] Attempting to generate SQL for: {human_query}")
+        sql_query = await gen_query(human_query)
+        
+        if sql_query is None or sql_query.strip() == "":
+            # SQL generation failed; fall back to LLM chat automatically
+            print("[INFO] SQL generation failed, falling back to LLM chat")
+            
+            # Build context for the fallback prompt
+            context = (
+                "I tried to generate a SQL query for your question, but could not. Possible reasons:\n"
+                "1. The relevant data is not in the current database\n"
+                "2. The question needs to be stated more specifically\n"
+                "3. The tables or columns involved are not in my training data"
+            )
+            
+            response = await llm_chat(human_query, context)
+            await cl.Message(content=response, author="Vanna Assistant").send()
+            return
+        
+        # Step 2: SQL generated successfully; execute it
+        print(f"[INFO] SQL generated, executing: {sql_query}")
+        df = await execute_query(sql_query)
+        
+        if df is None or df.empty:
+            # Execution failed or returned nothing; explain and suggest next steps
+            error_context = (
+                f"I generated a SQL query for you, but it returned no data.\n\n"
+                f"Generated SQL:\n```sql\n{sql_query}\n```\n\n"
+                "The filter conditions may be too strict, or the database may not "
+                "contain any matching records right now."
+            )
+            
+            response = await llm_chat(
+                f"The user asked: {human_query}, but the SQL query returned no data. Please offer suggestions.",
+                error_context
+            )
+            
+            await cl.Message(
+                content=f"{error_context}\n\n{response}", 
+                author="Vanna Assistant"
+            ).send()
+            return
+        
+        # Step 3: data retrieved; build the chart and return the results
+        print("[INFO] Data retrieved, generating chart")
+        fig = await plot(human_query, sql_query, df)
+
+        # Assemble the response elements
+        elements = [
+            cl.Text(name="data_table", content=df.to_markdown(index=False), display="inline")
+        ]
+        
+        if fig is not None:
+            elements.append(cl.Plotly(name="chart", figure=fig, display="inline"))
+        
+        await cl.Message(
+            content=f"Query complete! Here are the results for '{human_query}':", 
+            elements=elements, 
+            author="Vanna Assistant"
+        ).send()
+        
+    except Exception as e:
+        # Outermost exception handler: the final fallback
+        error_msg = f"Unexpected error while handling the request: {str(e)}"
+        print(f"[ERROR] {error_msg}")
+        print(f"[ERROR] Exception type: {type(e).__name__}")
+        
+        # Ask the LLM for a friendly error reply
+        try:
+            final_response = await llm_chat(
+                f"The system hit a technical problem while the user asked: {human_query}. Please reply helpfully with suggestions."
+            )
+            await cl.Message(
+                content=f"Sorry, the system ran into a technical problem.\n\n{final_response}", 
+                author="Vanna Assistant"
+            ).send()
+        except Exception:
+            # If even the LLM call fails, use a hard-coded reply
+            await cl.Message(
+                content="Sorry, the system is having technical problems right now. Please try again later; if the issue persists, check your network connection or contact support.", 
+                author="Vanna Assistant"
+            ).send()
 
 @cl.on_message
 async def main(message: cl.Message):
@@ -71,12 +288,14 @@ async def on_chat_start():
 - 🔍 Turn natural-language questions into SQL queries
 - 📊 Run database queries and display the results
 - 📈 Generate data visualizations
+- 💬 Answer general questions

-Please enter the data question you want to explore, for example:
+Please enter your question, for example:
 - "Who are the top 5 customers by transaction count?"
 - "Show the transaction trend over the last 30 days"
+- "Hi, how is the weather today?"

-Let's start exploring the data! ✨
+Let's get started! ✨
     """
     
     await cl.Message(
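A side note on app.py: the keyword heuristic is_database_related_query() is defined but deliberately left out of the main flow. A hypothetical spot-check of its behavior (assuming app.py can be imported without side effects, which chainlit apps do not always allow):

```python
# Hypothetical spot-check of the unused keyword heuristic in app.py.
from app import is_database_related_query

assert is_database_related_query("查询销售数据") is True          # db keyword '查询'
assert is_database_related_query("今天天气怎么样?") is False     # non-db keyword '天气'
assert is_database_related_query("show me all records") is True  # English db keyword 'show'
assert is_database_related_query("随便聊聊") is True             # no match: conservative default
```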

+ 112 - 55
customdeepseek/custom_deepseek_chat.py

@@ -18,34 +18,29 @@ from vanna.base import VannaBase
 class DeepSeekChat(VannaBase):
     def __init__(self, config=None):
         VannaBase.__init__(self, config=config)
-        
         print("...DeepSeekChat init...")
-        if config is None:
-            raise ValueError(
-                "For DeepSeek, config must be provided with an api_key and model"
-            )
-        if "api_key" not in config:
-            raise ValueError("config must contain a DeepSeek api_key")
-
-        if "model" not in config:
-            config["model"] = "deepseek-chat"  # 默认模型
-            print(f"未指定模型,使用默认模型: {config['model']}")
-        
-        # Set defaults
-        self.temperature = config.get("temperature", 0.7)
-        self.model = config["model"]
-        
+
+        # This check must run before config is accessed below; a None config
+        # would otherwise crash on self.config.items().
+        if config is None:
+            self.temperature = 0.7
+            self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
+            return
+
         print("传入的 config 参数如下:")
-        for key, value in config.items():
-            if key != "api_key":  # 不打印API密钥
-                print(f"  {key}: {value}")
-        
-        # Use the standard OpenAI client with a different base URL
-        self.client = OpenAI(
-            api_key=config["api_key"], 
-            base_url="https://api.deepseek.com/v1"
-        )
-    
+        for key, value in self.config.items():
+            print(f"  {key}: {value}")
+
+        # default parameters
+        self.temperature = 0.7
+
+        if "temperature" in config:
+            print(f"temperature is changed to: {config['temperature']}")
+            self.temperature = config["temperature"]
+
+        if "api_key" in config:
+            if "base_url" not in config:
+                self.client = OpenAI(api_key=config["api_key"], base_url="https://api.deepseek.com")
+            else:
+                self.client = OpenAI(api_key=config["api_key"], base_url=config["base_url"])
+
     def system_message(self, message: str) -> any:
         print(f"system_content: {message}")
         return {"role": "system", "content": message}
@@ -66,41 +61,85 @@ class DeepSeekChat(VannaBase):
             raise Exception("Prompt is empty")
 
         # Count the number of tokens in the message log
-        # Use 4 as an approximation for the number of characters per token
         num_tokens = 0
         for message in prompt:
             num_tokens += len(message["content"]) / 4
-        
-        # Resolve the model from config and kwargs; kwargs take priority
-        model = kwargs.get("model", self.model)
-        
+
+        model = None
+        if kwargs.get("model", None) is not None:
+            model = kwargs.get("model", None)
+        elif kwargs.get("engine", None) is not None:
+            model = kwargs.get("engine", None)
+        elif self.config is not None and "engine" in self.config:
+            model = self.config["engine"]
+        elif self.config is not None and "model" in self.config:
+            model = self.config["model"]
+        else:
+            # DeepSeek exposes a single chat model either way, so the token
+            # count does not change the choice.
+            model = "deepseek-chat"
+
         print(f"\nUsing model {model} for {num_tokens} tokens (approx)")
-        
-        # Build the request parameters
-        chat_params = {
-            "model": model,
-            "messages": prompt,
-            "temperature": kwargs.get("temperature", self.temperature),
-        }
-        
+
+        response = self.client.chat.completions.create(
+            model=model,
+            messages=prompt,
+            stop=None,
+            temperature=self.temperature,
+        )
+
+        return response.choices[0].message.content
+
+    def generate_sql(self, question: str, **kwargs) -> str:
+        """
+        Override of the parent generate_sql, adding exception handling.
+        """
         try:
-            chat_response = self.client.chat.completions.create(**chat_params)
-            # Return the generated text
-            return chat_response.choices[0].message.content
+            print(f"[DEBUG] 尝试为问题生成SQL: {question}")
+            # 使用父类的 generate_sql
+            sql = super().generate_sql(question, **kwargs)
+            
+            if not sql or sql.strip() == "":
+                print(f"[WARNING] 生成的SQL为空")
+                return None
+            
+            # 替换 "\_" 为 "_",解决特殊字符转义问题
+            sql = sql.replace("\\_", "_")
+            
+            # 检查返回内容是否为有效SQL或错误信息
+            sql_lower = sql.lower().strip()
+            
+            # 检查是否包含错误提示信息
+            error_indicators = [
+                "insufficient context", "无法生成", "sorry", "cannot", "不能",
+                "no relevant", "no suitable", "unable to", "无法", "抱歉",
+                "i don't have", "i cannot", "没有相关", "找不到", "不存在"
+            ]
+            
+            for indicator in error_indicators:
+                if indicator in sql_lower:
+                    print(f"[WARNING] LLM返回错误信息而非SQL: {sql}")
+                    return None
+            
+            # 简单检查是否像SQL语句(至少包含一些SQL关键词)
+            sql_keywords = ["select", "insert", "update", "delete", "with", "from", "where"]
+            if not any(keyword in sql_lower for keyword in sql_keywords):
+                print(f"[WARNING] 返回内容不像有效SQL: {sql}")
+                return None
+            
+            print(f"[SUCCESS] 成功生成SQL: {sql}")
+            return sql
+            
         except Exception as e:
-            print(f"DeepSeek API调用失败: {e}")
-            raise
+            print(f"[ERROR] SQL生成过程中出现异常: {str(e)}")
+            print(f"[ERROR] 异常类型: {type(e).__name__}")
+            # 导入traceback以获取详细错误信息
+            import traceback
+            print(f"[ERROR] 详细错误信息: {traceback.format_exc()}")
+            # 返回 None 而不是抛出异常
+            return None
 
-    def generate_sql(self, question: str, **kwargs) -> str:
-        # Use the parent class's generate_sql
-        sql = super().generate_sql(question, **kwargs)
-        
-        # Replace "\_" with "_" to undo special-character escaping
-        sql = sql.replace("\\_", "_")
-        
-        return sql
-    
-    # Generate the question in Chinese (questions inferred from SQL used to come back in English).
     def generate_question(self, sql: str, **kwargs) -> str:
         # Custom prompt/logic can go here
         prompt = [
@@ -111,4 +150,22 @@ class DeepSeekChat(VannaBase):
         ]
         response = self.submit_prompt(prompt, **kwargs)
         # The response can also be post-processed here
-        return response
+        return response
+    
+    # New: direct LLM chat method
+    def chat_with_llm(self, question: str, **kwargs) -> str:
+        """
+        Chat with the LLM directly, with no SQL generation involved.
+        """
+        try:
+            prompt = [
+                self.system_message(
+                    "You are a friendly AI assistant. If the user asks a database-related "
+                    "question, suggest they rephrase it so it can be answered with a SQL "
+                    "query. For anything else, give the most helpful answer you can."
+                ),
+                self.user_message(question)
+            ]
+            response = self.submit_prompt(prompt, **kwargs)
+            return response
+        except Exception as e:
+            print(f"[ERROR] LLM chat failed: {str(e)}")
+            return "Sorry, I cannot answer your question right now. Please try again later."

+ 226 - 161
customqianwen/Custom_QianwenAI_chat.py

@@ -4,166 +4,231 @@ from vanna.base import VannaBase
 
 
 class QianWenAI_Chat(VannaBase):
-  def __init__(self, client=None, config=None):
-    print("...QianWenAI_Chat init...")
-    VannaBase.__init__(self, config=config)
-
-    print("传入的 config 参数如下:")
-    for key, value in self.config.items():
-        print(f"  {key}: {value}")
-
-    # default parameters - can be overridden using config
-    self.temperature = 0.7
-
-    if "temperature" in config:
-      print(f"temperature is changed to: {config['temperature']}")
-      self.temperature = config["temperature"]
-
-    if "api_type" in config:
-      raise Exception(
-        "Passing api_type is now deprecated. Please pass an OpenAI client instead."
-      )
-
-    if "api_base" in config:
-      raise Exception(
-        "Passing api_base is now deprecated. Please pass an OpenAI client instead."
-      )
-
-    if "api_version" in config:
-      raise Exception(
-        "Passing api_version is now deprecated. Please pass an OpenAI client instead."
-      )
-
-    if client is not None:
-      self.client = client
-      return
-
-    if config is None and client is None:
-      self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
-      return
-
-    if "api_key" in config:
-      if "base_url" not in config:
-        self.client = OpenAI(api_key=config["api_key"],
-                             base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")
-      else:
-        self.client = OpenAI(api_key=config["api_key"],
-                             base_url=config["base_url"])
+    def __init__(self, client=None, config=None):
+        print("...QianWenAI_Chat init...")
+        VannaBase.__init__(self, config=config)
+
+        print("传入的 config 参数如下:")
+        for key, value in self.config.items():
+            print(f"  {key}: {value}")
+
+        # default parameters - can be overridden using config
+        self.temperature = 0.7
+
+        if "temperature" in config:
+            print(f"temperature is changed to: {config['temperature']}")
+            self.temperature = config["temperature"]
+
+        if "api_type" in config:
+            raise Exception(
+                "Passing api_type is now deprecated. Please pass an OpenAI client instead."
+            )
+
+        if "api_base" in config:
+            raise Exception(
+                "Passing api_base is now deprecated. Please pass an OpenAI client instead."
+            )
+
+        if "api_version" in config:
+            raise Exception(
+                "Passing api_version is now deprecated. Please pass an OpenAI client instead."
+            )
+
+        if client is not None:
+            self.client = client
+            return
+
+        if config is None and client is None:
+            self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
+            return
+
+        if "api_key" in config:
+            if "base_url" not in config:
+                self.client = OpenAI(api_key=config["api_key"],
+                                     base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")
+            else:
+                self.client = OpenAI(api_key=config["api_key"],
+                                     base_url=config["base_url"])
    
-  def system_message(self, message: str) -> any:
-    print(f"system_content: {message}")
-    return {"role": "system", "content": message}
-
-  def user_message(self, message: str) -> any:
-    print(f"\nuser_content: {message}")
-    return {"role": "user", "content": message}
-
-  def assistant_message(self, message: str) -> any:
-    print(f"assistant_content: {message}")
-    return {"role": "assistant", "content": message}
-
-  def submit_prompt(self, prompt, **kwargs) -> str:
-    if prompt is None:
-      raise Exception("Prompt is None")
-
-    if len(prompt) == 0:
-      raise Exception("Prompt is empty")
-
-    # Count the number of tokens in the message log
-    # Use 4 as an approximation for the number of characters per token
-    num_tokens = 0
-    for message in prompt:
-      num_tokens += len(message["content"]) / 4
-
-    # Read the enable_thinking setting from kwargs and config
-    # (the kwargs value wins; fall back to config, default False)
-    enable_thinking = kwargs.get("enable_thinking", self.config.get("enable_thinking", False))
-    
-    # Common parameters
-    common_params = {
-      "messages": prompt,
-      "stop": None,
-      "temperature": self.temperature,
-    }
-    
-    # When thinking is enabled, use streaming, but don't pass enable_thinking directly
-    if enable_thinking:
-      common_params["stream"] = True
-      # The Qianwen API does not accept enable_thinking as a parameter; it may need
-      # to be passed via headers, or thinking may simply be on whenever stream=True
-    
-    model = None
-    # Decide which model to use
-    if kwargs.get("model", None) is not None:
-      model = kwargs.get("model", None)
-      common_params["model"] = model
-    elif kwargs.get("engine", None) is not None:
-      engine = kwargs.get("engine", None)
-      common_params["engine"] = engine
-      model = engine
-    elif self.config is not None and "engine" in self.config:
-      common_params["engine"] = self.config["engine"]
-      model = self.config["engine"]
-    elif self.config is not None and "model" in self.config:
-      common_params["model"] = self.config["model"]
-      model = self.config["model"]
-    else:
-      if num_tokens > 3500:
-        model = "qwen-long"
-      else:
-        model = "qwen-plus"
-      common_params["model"] = model
-    
-    print(f"\nUsing model {model} for {num_tokens} tokens (approx)")
-    
-    if enable_thinking:
-      # Streaming mode
-      print("Streaming mode with thinking enabled")
-      
-      # Check whether enable_thinking needs to be passed via headers
-      response_stream = self.client.chat.completions.create(**common_params)
-      
-      # Collect the streamed response
-      collected_thinking = []
-      collected_content = []
-      
-      for chunk in response_stream:
-        # Handle the thinking part
-        if hasattr(chunk, 'thinking') and chunk.thinking:
-          collected_thinking.append(chunk.thinking)
+    def system_message(self, message: str) -> any:
+        print(f"system_content: {message}")
+        return {"role": "system", "content": message}
+
+    def user_message(self, message: str) -> any:
+        print(f"\nuser_content: {message}")
+        return {"role": "user", "content": message}
+
+    def assistant_message(self, message: str) -> any:
+        print(f"assistant_content: {message}")
+        return {"role": "assistant", "content": message}
+
+    def submit_prompt(self, prompt, **kwargs) -> str:
+        if prompt is None:
+            raise Exception("Prompt is None")
+
+        if len(prompt) == 0:
+            raise Exception("Prompt is empty")
+
+        # Count the number of tokens in the message log
+        # Use 4 as an approximation for the number of characters per token
+        num_tokens = 0
+        for message in prompt:
+            num_tokens += len(message["content"]) / 4
+
+        # Read the enable_thinking setting from kwargs and config
+        # (the kwargs value wins; fall back to config, default False)
+        enable_thinking = kwargs.get("enable_thinking", self.config.get("enable_thinking", False))
         
-        # Handle the content part
-        if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content:
-          collected_content.append(chunk.choices[0].delta.content)
-      
-      # Display/log handling for the thinking trace can go here
-      if collected_thinking:
-        print("Model thinking process:", "".join(collected_thinking))
-      
-      # Return the full content
-      return "".join(collected_content)
-    else:
-      # Non-streaming mode
-      print("Non-streaming mode")
-      response = self.client.chat.completions.create(**common_params)
-      
-      # Find the first response from the chatbot that has text in it (some responses may not have text)
-      for choice in response.choices:
-        if "text" in choice:
-          return choice.text
-
-      # If no response with text is found, return the first response's content (which may be empty)
-      return response.choices[0].message.content
-
-# Generate the question in Chinese (questions inferred from SQL used to come back in English).
-  def generate_question(self, sql: str, **kwargs) -> str:
-      # Custom prompt/logic can go here
-      prompt = [
-          self.system_message(
-              "Infer the user's business question from the SQL statement below. Return only a clear natural-language question, with no explanation or SQL and no table names. The question must be in Chinese and end with a question mark."
-          ),
-          self.user_message(sql)
-      ]
-      response = self.submit_prompt(prompt, **kwargs)
-      # The response can also be post-processed here
-      return response
+        # Common parameters
+        common_params = {
+            "messages": prompt,
+            "stop": None,
+            "temperature": self.temperature,
+        }
+        
+        # When thinking is enabled, use streaming, but don't pass enable_thinking directly
+        if enable_thinking:
+            common_params["stream"] = True
+            # The Qianwen API does not accept enable_thinking as a parameter; it may need
+            # to be passed via headers, or thinking may simply be on whenever stream=True
+        
+        model = None
+        # Decide which model to use
+        if kwargs.get("model", None) is not None:
+            model = kwargs.get("model", None)
+            common_params["model"] = model
+        elif kwargs.get("engine", None) is not None:
+            engine = kwargs.get("engine", None)
+            common_params["engine"] = engine
+            model = engine
+        elif self.config is not None and "engine" in self.config:
+            common_params["engine"] = self.config["engine"]
+            model = self.config["engine"]
+        elif self.config is not None and "model" in self.config:
+            common_params["model"] = self.config["model"]
+            model = self.config["model"]
+        else:
+            if num_tokens > 3500:
+                model = "qwen-long"
+            else:
+                model = "qwen-plus"
+            common_params["model"] = model
+        
+        print(f"\nUsing model {model} for {num_tokens} tokens (approx)")
+        
+        if enable_thinking:
+            # Streaming mode
+            print("Streaming mode with thinking enabled")
+            
+            # Check whether enable_thinking needs to be passed via headers
+            response_stream = self.client.chat.completions.create(**common_params)
+            
+            # Collect the streamed response
+            collected_thinking = []
+            collected_content = []
+            
+            for chunk in response_stream:
+                # Handle the thinking part
+                if hasattr(chunk, 'thinking') and chunk.thinking:
+                    collected_thinking.append(chunk.thinking)
+                
+                # Handle the content part (some chunks arrive with an empty choices list)
+                if chunk.choices and hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content:
+                    collected_content.append(chunk.choices[0].delta.content)
+            
+            # Display/log handling for the thinking trace can go here
+            if collected_thinking:
+                print("Model thinking process:", "".join(collected_thinking))
+            
+            # Return the full content
+            return "".join(collected_content)
+        else:
+            # Non-streaming mode
+            print("Non-streaming mode")
+            response = self.client.chat.completions.create(**common_params)
+            
+            # Find the first response from the chatbot that has text in it (some responses may not have text)
+            for choice in response.choices:
+                if "text" in choice:
+                    return choice.text
+
+            # If no response with text is found, return the first response's content (which may be empty)
+            return response.choices[0].message.content
+
+    # Override generate_sql to add exception handling
+    def generate_sql(self, question: str, **kwargs) -> str:
+        """
+        Override of the parent generate_sql, adding exception handling.
+        """
+        try:
+            print(f"[DEBUG] Generating SQL for question: {question}")
+            # Call the parent class's generate_sql
+            sql = super().generate_sql(question, **kwargs)
+            
+            if not sql or sql.strip() == "":
+                print("[WARNING] Generated SQL is empty")
+                return None
+            
+            # Check whether the response is valid SQL or an error message
+            sql_lower = sql.lower().strip()
+            
+            # Phrases that indicate an apology or error instead of SQL
+            error_indicators = [
+                "insufficient context", "无法生成", "sorry", "cannot", "不能",
+                "no relevant", "no suitable", "unable to", "无法", "抱歉",
+                "i don't have", "i cannot", "没有相关", "找不到", "不存在"
+            ]
+            
+            for indicator in error_indicators:
+                if indicator in sql_lower:
+                    print(f"[WARNING] LLM returned an error message instead of SQL: {sql}")
+                    return None
+            
+            # Sanity check: the text should contain at least one SQL keyword
+            sql_keywords = ["select", "insert", "update", "delete", "with", "from", "where"]
+            if not any(keyword in sql_lower for keyword in sql_keywords):
+                print(f"[WARNING] Response does not look like valid SQL: {sql}")
+                return None
+                
+            print(f"[SUCCESS] SQL generated: {sql}")
+            return sql
+            
+        except Exception as e:
+            print(f"[ERROR] Exception while generating SQL: {str(e)}")
+            print(f"[ERROR] Exception type: {type(e).__name__}")
+            # Import traceback for the full stack trace
+            import traceback
+            print(f"[ERROR] Traceback: {traceback.format_exc()}")
+            # Return None instead of re-raising
+            return None
+
+    # Generate the question in Chinese (questions inferred from SQL used to come back in English).
+    def generate_question(self, sql: str, **kwargs) -> str:
+        # Custom prompt/logic can go here
+        prompt = [
+            self.system_message(
+                "Infer the user's business question from the SQL statement below. Return only "
+                "a clear natural-language question, with no explanation or SQL and no table "
+                "names. The question must be in Chinese and end with a question mark."
+            ),
+            self.user_message(sql)
+        ]
+        response = self.submit_prompt(prompt, **kwargs)
+        # The response can also be post-processed here
+        return response
+    
+    # New: direct LLM chat method
+    def chat_with_llm(self, question: str, **kwargs) -> str:
+        """
+        Chat with the LLM directly, with no SQL generation involved.
+        """
+        try:
+            prompt = [
+                self.system_message(
+                    "You are a friendly AI assistant. If the user asks a database-related "
+                    "question, suggest they rephrase it so it can be answered with a SQL "
+                    "query. For anything else, give the most helpful answer you can."
+                ),
+                self.user_message(question)
+            ]
+            response = self.submit_prompt(prompt, **kwargs)
+            return response
+        except Exception as e:
+            print(f"[ERROR] LLM chat failed: {str(e)}")
+            return "Sorry, I cannot answer your question right now. Please try again later."