Complete the bug fix for the Ollama LLM and add a parameter to control whether thinking output is displayed.

wangxq 3 weeks ago
parent
commit
67cf40c56c
5 files changed with 168 additions and 25 deletions
  1. app_config.py (+11 -6)
  2. customllm/base_llm_chat.py (+95 -2)
  3. customllm/deepseek_chat.py (+38 -4)
  4. customllm/ollama_chat.py (+15 -10)
  5. customllm/qianwen_chat.py (+9 -3)

+ 11 - 6
app_config.py

@@ -10,9 +10,9 @@ load_dotenv(override=True)
 LLM_MODEL_TYPE = "api"  # api, ollama
 
 # Embedding model provider type: api or ollama
-EMBEDDING_MODEL_TYPE = "api"  # api, ollama
+EMBEDDING_MODEL_TYPE = "ollama"  # api, ollama
 
-# ===== Model name configuration =====
+# ===== API model name configuration =====
 # API LLM model name (used when LLM_MODEL_TYPE="api": qianwen or deepseek)
 API_LLM_MODEL = "deepseek"
 
@@ -25,11 +25,11 @@ API_DEEPSEEK_CONFIG = {
     "api_key": os.getenv("DEEPSEEK_API_KEY"),  # 从环境变量读取API密钥
     "model": "deepseek-reasoner",  # deepseek-chat, deepseek-reasoner
     "allow_llm_to_see_data": True,
-    "temperature": 0.7,
+    "temperature": 0.6,
     "n_results": 6,
     "language": "Chinese",
     "stream": True,  # 是否使用流式模式
-    "enable_thinking": False  # 自定义,是否支持流模式
+    "enable_thinking": True  # 自定义,是否支持流模式
 }
 
 # Qwen model configuration
@@ -131,6 +131,11 @@ TRAINING_DATA_PATH = "./training/data"
 # Whether to enable question rewriting, i.e. merging the question with its conversation context.
 REWRITE_QUESTION_ENABLED = False
 
+# Whether to show the thinking process in summaries
+# True: show <think></think> content
+# False: hide <think></think> content and show only the final answer
+DISPLAY_SUMMARY_THINKING = True
+
 # Whether to enable score-threshold filtering of vector query results
 # result = max((n + 1) // 2, 1)
 ENABLE_RESULT_VECTOR_SCORE_THRESHOLD = True
@@ -138,6 +143,6 @@ ENABLE_RESULT_VECTOR_SCORE_THRESHOLD = True
 RESULT_VECTOR_SQL_SCORE_THRESHOLD = 0.65
 RESULT_VECTOR_DDL_SCORE_THRESHOLD = 0.5
 RESULT_VECTOR_DOC_SCORE_THRESHOLD = 0.5
-RESULT_VECTOR_ERROR_SQL_SCORE_THRESHOLD = 0.5
 
-ENABLE_ERROR_SQL_PROMPT = True
+ENABLE_ERROR_SQL_PROMPT = True
+RESULT_VECTOR_ERROR_SQL_SCORE_THRESHOLD = 0.8
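
For reference, the retention formula quoted above ENABLE_RESULT_VECTOR_SCORE_THRESHOLD, result = max((n + 1) // 2, 1), keeps at least half of the candidate results (rounded up). A quick standalone check of that arithmetic (hypothetical helper, not part of this commit):

def retained_count(n: int) -> int:
    # Keep at least half of the n candidates (rounded up), never fewer than 1
    return max((n + 1) // 2, 1)

for n in (1, 2, 5, 6):
    print(n, "->", retained_count(n))  # 1 -> 1, 2 -> 1, 5 -> 3, 6 -> 3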

+ 95 - 2
customllm/base_llm_chat.py

@@ -3,7 +3,7 @@ from abc import ABC, abstractmethod
 from typing import List, Dict, Any, Optional
 from vanna.base import VannaBase
 # Import configuration parameters
-from app_config import REWRITE_QUESTION_ENABLED
+from app_config import REWRITE_QUESTION_ENABLED, DISPLAY_SUMMARY_THINKING
 
 
 class BaseLLMChat(VannaBase, ABC):
@@ -60,7 +60,7 @@ class BaseLLMChat(VannaBase, ABC):
         
         if initial_prompt is None:
             initial_prompt = f"You are a {self.dialect} expert. " + \
-            "Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "
+            "Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions."
 
         # Extract the DDL content (adapted to the new dict format)
         ddl_content_list = []
@@ -381,6 +381,99 @@ class BaseLLMChat(VannaBase, ABC):
             # If merging fails, return the new question
             return new_question
 
+    def generate_summary(self, question: str, df, **kwargs) -> str:
+        """
+        覆盖父类的 generate_summary 方法,添加中文思考和回答指令
+        
+        Args:
+            question (str): 用户提出的问题
+            df: 查询结果的 DataFrame
+            **kwargs: 其他参数
+            
+        Returns:
+            str: 数据摘要
+        """
+        try:
+            # Import pandas for DataFrame handling
+            import pandas as pd
+            
+            # Make sure df is a pandas DataFrame
+            if not isinstance(df, pd.DataFrame):
+                print(f"[WARNING] df is not a pandas DataFrame, type: {type(df)}")
+                return "无法生成摘要:数据格式不正确"
+            
+            if df.empty:
+                return "查询结果为空,无数据可供摘要。"
+            
+            print(f"[DEBUG] 生成摘要 - 问题: {question}")
+            print(f"[DEBUG] DataFrame 形状: {df.shape}")
+            
+            # Build the system message containing the Chinese-language instructions
+            system_content = (
+                f"你是一个专业的数据分析助手。用户提出了问题:'{question}'\n\n"
+                f"以下是查询结果的 pandas DataFrame 数据:\n{df.to_markdown()}\n\n"
+                "请用中文进行思考和分析,并用中文回答。"
+            )
+            
+            # Build the user message, emphasizing thinking and answering in Chinese
+            user_content = (
+                "请基于用户提出的问题,简要总结这些数据。要求:\n"             
+                "1. 只进行简要总结,不要添加额外的解释\n"
+                "2. 如果数据中有数字,请保留适当的精度\n"            
+            )
+            
+            message_log = [
+                self.system_message(system_content),
+                self.user_message(user_content)
+            ]
+            
+            summary = self.submit_prompt(message_log, **kwargs)
+            
+            # Check whether the thinking content should be hidden
+            display_thinking = kwargs.get("display_summary_thinking", DISPLAY_SUMMARY_THINKING)
+            
+            if not display_thinking:
+                # Remove the <think></think> tags and their content
+                original_summary = summary
+                summary = self._remove_thinking_content(summary)
+                print(f"[DEBUG] 隐藏thinking内容 - 原始长度: {len(original_summary)}, 处理后长度: {len(summary)}")
+            
+            print(f"[DEBUG] 生成的摘要: {summary[:100]}...")
+            return summary
+            
+        except Exception as e:
+            print(f"[ERROR] 生成摘要失败: {str(e)}")
+            import traceback
+            print(f"[ERROR] 详细错误信息: {traceback.format_exc()}")
+            return f"生成摘要时出现错误:{str(e)}"
+
+    def _remove_thinking_content(self, text: str) -> str:
+        """
+        移除文本中的 <think></think> 标签及其内容
+        
+        Args:
+            text (str): 包含可能的 thinking 标签的文本
+            
+        Returns:
+            str: 移除 thinking 内容后的文本
+        """
+        if not text:
+            return text
+        
+        import re
+        
+        # Remove <think>...</think> tags and their content (multi-line supported)
+        # The re.DOTALL flag makes . match any character, including newlines
+        cleaned_text = re.sub(r'<think>.*?</think>\s*', '', text, flags=re.DOTALL | re.IGNORECASE)
+        
+        # Collapse any extra blank lines
+        cleaned_text = re.sub(r'\n\s*\n\s*\n', '\n\n', cleaned_text)
+        
+        # Strip leading and trailing whitespace
+        cleaned_text = cleaned_text.strip()
+        
+        return cleaned_text
+
     @abstractmethod
     def submit_prompt(self, prompt, **kwargs) -> str:
         """

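To sanity-check the new _remove_thinking_content logic, the same two regexes can be re-run standalone (the sample reply string below is made up):

import re

reply = "<think>先统计行数,再计算均值。</think>\n\n共 42 行,平均值为 3.7。"
cleaned = re.sub(r'<think>.*?</think>\s*', '', reply, flags=re.DOTALL | re.IGNORECASE)
cleaned = re.sub(r'\n\s*\n\s*\n', '\n\n', cleaned)
print(cleaned.strip())  # -> 共 42 行,平均值为 3.7。
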
+ 38 - 4
customllm/deepseek_chat.py

@@ -71,6 +71,25 @@ class DeepSeekChat(BaseLLMChat):
         print(f"\nUsing model {model} for {num_tokens} tokens (approx)")
         print(f"Enable thinking: {enable_thinking}, Stream mode: {stream_mode}")
 
+        # Approach 1: control Chinese output via the system prompt (DeepSeek does not support a language parameter)
+        # Check the language setting in the config and add a Chinese instruction to the system prompt
+        # language_setting = self.config.get("language", "").lower() if self.config else ""
+        # print(f"DEBUG: language_setting='{language_setting}', model='{model}', enable_thinking={enable_thinking}")
+        
+        # if language_setting == "chinese" and enable_thinking:
+        #     print("DEBUG: ✅ Chinese instruction added")
+        #     # Add a Chinese-thinking instruction for the reasoning model
+        #     chinese_instruction = {"role": "system", "content": "请用中文进行思考和回答。在推理过程中,请使用中文进行分析和思考。<think></think>之间也请使用中文"}
+        #     # If the first message is not a system message, prepend the Chinese instruction
+        #     if not prompt or prompt[0].get("role") != "system":
+        #         prompt = [chinese_instruction] + prompt
+        #     else:
+        #         # If a system message already exists, append the Chinese instruction to its content
+        #         existing_content = prompt[0]["content"]
+        #         prompt[0]["content"] = f"{existing_content}\n\n请用中文进行思考和回答。在推理过程中,请使用中文进行分析和思考。<think></think>之间也请使用中文"
+        # else:
+        #     print(f"DEBUG: ❌ Chinese instruction not triggered - language_setting==chinese: {language_setting == 'chinese'}, model==deepseek-reasoner: {model == 'deepseek-reasoner'}, enable_thinking: {enable_thinking}")
+
         # Build the API call parameters
         api_params = {
             "model": model,
@@ -81,6 +100,7 @@ class DeepSeekChat(BaseLLMChat):
         }
 
         # Filter out custom parameters so they are not passed to the API
+        # Note: the language parameter is kept so the DeepSeek API can handle it itself
         filtered_kwargs = {k: v for k, v in kwargs.items() 
                           if k not in ['model', 'engine', 'enable_thinking', 'stream']}
 
@@ -128,9 +148,16 @@ class DeepSeekChat(BaseLLMChat):
                 
                 # Optional: print the reasoning process
                 if collected_reasoning:
-                    print("Model reasoning process:\n", "".join(collected_reasoning))
+                    reasoning_text = "".join(collected_reasoning)
+                    print("Model reasoning process:\n", reasoning_text)
                 
-                return "".join(collected_content)
+                # Approach 2: return the full content with <think></think> tags, consistent with QianWen
+                final_content = "".join(collected_content)
+                if collected_reasoning:
+                    reasoning_text = "".join(collected_reasoning)
+                    return f"<think>{reasoning_text}</think>\n\n{final_content}"
+                else:
+                    return final_content
             else:
                 # Streaming handling for other models (e.g. deepseek-chat)
                 collected_content = []
@@ -155,10 +182,17 @@ class DeepSeekChat(BaseLLMChat):
                 message = response.choices[0].message
                 
                 # Optional: print the reasoning process
+                reasoning_content = ""
                 if hasattr(message, 'reasoning_content') and message.reasoning_content:
-                    print("Model reasoning process:\n", message.reasoning_content)
+                    reasoning_content = message.reasoning_content
+                    print("Model reasoning process:\n", reasoning_content)
                 
-                return message.content
+                # Approach 2: return the full content with <think></think> tags, consistent with QianWen
+                final_content = message.content
+                if reasoning_content:
+                    return f"<think>{reasoning_content}</think>\n\n{final_content}"
+                else:
+                    return final_content
             else:
                 # Non-streaming handling for other models (e.g. deepseek-chat)
                 return response.choices[0].message.content 
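
Both branches above now follow the same convention as QianWenChat: when the API returns separate reasoning content, it is prepended to the answer wrapped in <think> tags, so generate_summary can later show or strip it with a single regex. A minimal sketch of that contract (the helper name is hypothetical):

def tag_reasoning(reasoning: str, content: str) -> str:
    # Mirrors the return pattern now shared by deepseek_chat.py and qianwen_chat.py
    if reasoning:
        return f"<think>{reasoning}</think>\n\n{content}"
    return content

assert tag_reasoning("", "2") == "2"
assert tag_reasoning("1 + 1 = 2", "2") == "<think>1 + 1 = 2</think>\n\n2"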

+ 15 - 10
customllm/ollama_chat.py

@@ -72,7 +72,7 @@ class OllamaChat(BaseLLMChat):
         
         # Model-compatibility hint (no forced model switch)
         if enable_thinking and not is_reasoning_model:
-            print(f"提示:模型 {model} 可能不支持推理功能,推理相关参数将被忽略")
+            print(f"提示:模型 {model} 不是专门的推理模型,但仍会尝试启用推理功能")
 
         print(f"\nUsing Ollama model {model} for {num_tokens} tokens (approx)")
         print(f"Enable thinking: {enable_thinking}, Stream mode: {stream_mode}")
@@ -83,26 +83,27 @@ class OllamaChat(BaseLLMChat):
             "model": model,
             "messages": prompt,
             "stream": stream_mode,
-            "options": self._build_options(kwargs, is_reasoning_model)
+            "think": enable_thinking,  # Ollama API 使用 think 参数控制推理功能
+            "options": self._build_options(kwargs, is_reasoning_model, enable_thinking)
         }
 
         try:
             if stream_mode:
                 # Streaming mode
-                if is_reasoning_model and enable_thinking:
+                if enable_thinking:
                     print("使用流式处理模式,启用推理功能")
                 else:
                     print("使用流式处理模式,常规聊天")
                 
-                return self._handle_stream_response(url, payload, is_reasoning_model and enable_thinking)
+                return self._handle_stream_response(url, payload, enable_thinking)
             else:
                 # Non-streaming mode
-                if is_reasoning_model and enable_thinking:
+                if enable_thinking:
                     print("使用非流式处理模式,启用推理功能")
                 else:
                     print("使用非流式处理模式,常规聊天")
                 
-                return self._handle_non_stream_response(url, payload, is_reasoning_model and enable_thinking)
+                return self._handle_non_stream_response(url, payload, enable_thinking)
                 
         except requests.exceptions.RequestException as e:
             print(f"Ollama API请求失败: {e}")
@@ -266,7 +267,7 @@ class OllamaChat(BaseLLMChat):
         reasoning_keywords = ['r1', 'reasoning', 'think', 'cot', 'chain-of-thought']
         return any(keyword in model.lower() for keyword in reasoning_keywords)
 
-    def _build_options(self, kwargs: dict, is_reasoning_model: bool) -> dict:
+    def _build_options(self, kwargs: dict, is_reasoning_model: bool, enable_thinking: bool = False) -> dict:
         """构建 Ollama options 参数"""
         options = {
             "temperature": self.temperature,
@@ -283,13 +284,17 @@ class OllamaChat(BaseLLMChat):
         for k, v in filtered_kwargs.items():
             options[k] = v
 
-        # Reasoning-model-specific parameter adjustments
-        if is_reasoning_model:
-            # Reasoning models may need more prediction tokens
+        # Reasoning parameter adjustments (when thinking is enabled)
+        if enable_thinking:
+            # More prediction tokens may be needed when thinking is enabled
             if options["num_predict"] == -1:
                 options["num_predict"] = 2048
             # Lower the repeat penalty to allow more repetition during reasoning
             options["repeat_penalty"] = min(options["repeat_penalty"], 1.05)
+            # For reasoning models, parameters can be tuned further
+            if is_reasoning_model:
+                # Reasoning models may need a longer context window
+                options["num_ctx"] = max(options["num_ctx"], 8192)
 
         return options
 

+ 9 - 3
customllm/qianwen_chat.py

@@ -143,10 +143,16 @@ class QianWenChat(BaseLLMChat):
             
             # Thinking display logic can be handled here, e.g. saving it to a log
             if enable_thinking and collected_thinking:
-                print("Model thinking process:\n", "".join(collected_thinking))
+                thinking_text = "".join(collected_thinking)
+                print("Model thinking process:\n", thinking_text)
             
-            # Return the full content
-            return "".join(collected_content)
+            # Return the full content including <think></think> tags, matching the UI display requirements
+            final_content = "".join(collected_content)
+            if enable_thinking and collected_thinking:
+                thinking_text = "".join(collected_thinking)
+                return f"<think>{thinking_text}</think>\n\n{final_content}"
+            else:
+                return final_content
         else:
             # Non-streaming mode
             print("使用非流式处理模式")