Explorar el Código

统一了DISPLAY_RESULT_THINKING的配置.

wangxq hace 2 semanas
padre
commit
2cbe2d2dca

+ 97 - 163
agent/citu_agent.py

@@ -27,20 +27,11 @@ class CituLangGraphAgent:
         self.tools = TOOLS
         self.llm = get_compatible_llm()
         
-        # 预创建Agent实例以提升性能
-        enable_reuse = self.config.get("performance", {}).get("enable_agent_reuse", True)
-        if enable_reuse:
-            print("[CITU_AGENT] 预创建Agent实例中...")
-            self._database_executor = self._create_database_agent()
-            self._chat_executor = self._create_chat_agent()
-            print("[CITU_AGENT] Agent实例预创建完成")
-        else:
-            self._database_executor = None
-            self._chat_executor = None
-            print("[CITU_AGENT] Agent实例重用已禁用,将在运行时创建")
+        # 注意:现在使用直接工具调用模式,不再需要预创建Agent执行器
+        print("[CITU_AGENT] 使用直接工具调用模式")
         
         self.workflow = self._create_workflow()
-        print("[CITU_AGENT] LangGraph Agent with Tools初始化完成")
+        print("[CITU_AGENT] LangGraph Agent with Direct Tools初始化完成")
     
     def _create_workflow(self) -> StateGraph:
         """创建LangGraph工作流"""
@@ -99,63 +90,81 @@ class CituLangGraphAgent:
             state["execution_path"].append("classify_error")
             return state
     
-    def _create_database_agent(self):
-        """创建数据库专用Agent(预创建)"""
-        from agent.config import get_nested_config
-        
-        # 获取配置
-        max_iterations = get_nested_config(self.config, "database_agent.max_iterations", 5)
-        enable_verbose = get_nested_config(self.config, "database_agent.enable_verbose", True)
-        early_stopping_method = get_nested_config(self.config, "database_agent.early_stopping_method", "generate")
-        
-        database_prompt = ChatPromptTemplate.from_messages([
-            SystemMessage(content="""
-你是一个专业的数据库查询助手。你的任务是帮助用户查询数据库并生成报告。
-
-工具使用流程:
-1. 首先使用 generate_sql 工具将用户问题转换为SQL
-2. 然后使用 execute_sql 工具执行SQL查询
-3. 最后使用 generate_summary 工具为结果生成自然语言摘要
-
-如果任何步骤失败,请提供清晰的错误信息并建议解决方案。
-"""),
-            MessagesPlaceholder(variable_name="chat_history", optional=True),
-            HumanMessage(content="{input}"),
-            MessagesPlaceholder(variable_name="agent_scratchpad")
-        ])
-        
-        database_tools = [generate_sql, execute_sql, generate_summary]
-        agent = create_openai_tools_agent(self.llm, database_tools, database_prompt)
-        
-        return AgentExecutor(
-            agent=agent,
-            tools=database_tools,
-            verbose=enable_verbose,
-            handle_parsing_errors=True,
-            max_iterations=max_iterations,
-            early_stopping_method=early_stopping_method
-        )
-    
     def _agent_database_node(self, state: AgentState) -> AgentState:
-        """数据库Agent节点 - 使用预创建或动态创建的Agent"""
+        """数据库Agent节点 - 直接工具调用模式"""
         try:
             print(f"[DATABASE_AGENT] 开始处理数据库查询: {state['question']}")
             
-            # 使用预创建的Agent或动态创建
-            if self._database_executor is not None:
-                executor = self._database_executor
-                print(f"[DATABASE_AGENT] 使用预创建的Agent实例")
-            else:
-                executor = self._create_database_agent()
-                print(f"[DATABASE_AGENT] 动态创建Agent实例")
-            
-            # 执行Agent
-            result = executor.invoke({
-                "input": state["question"]
+            question = state["question"]
+            
+            # 步骤1:生成SQL
+            print(f"[DATABASE_AGENT] 步骤1:生成SQL")
+            sql_result = generate_sql.invoke({"question": question, "allow_llm_to_see_data": True})
+            
+            if not sql_result.get("success"):
+                print(f"[DATABASE_AGENT] SQL生成失败: {sql_result.get('error')}")
+                state["error"] = sql_result.get("error", "SQL生成失败")
+                state["error_code"] = 500
+                state["current_step"] = "database_error"
+                state["execution_path"].append("agent_database_error")
+                return state
+            
+            sql = sql_result.get("sql")
+            state["sql"] = sql
+            print(f"[DATABASE_AGENT] SQL生成成功: {sql}")
+            
+            # 步骤1.5:检查是否为解释性响应而非SQL
+            error_type = sql_result.get("error_type")
+            if error_type == "llm_explanation":
+                # LLM返回了解释性文本,直接作为最终答案
+                explanation = sql_result.get("error", "")
+                state["summary"] = explanation + " 请尝试提问其它问题。"
+                state["current_step"] = "database_completed"
+                state["execution_path"].append("agent_database")
+                print(f"[DATABASE_AGENT] 返回LLM解释性答案: {explanation}")
+                return state
+            
+            # 额外验证:检查SQL格式(防止工具误判)
+            from agent.utils import _is_valid_sql_format
+            if not _is_valid_sql_format(sql):
+                # 内容看起来不是SQL,当作解释性响应处理
+                state["summary"] = sql + " 请尝试提问其它问题。"
+                state["current_step"] = "database_completed"  
+                state["execution_path"].append("agent_database")
+                print(f"[DATABASE_AGENT] 内容不是有效SQL,当作解释返回: {sql}")
+                return state
+            
+            # 步骤2:执行SQL
+            print(f"[DATABASE_AGENT] 步骤2:执行SQL")
+            execute_result = execute_sql.invoke({"sql": sql, "max_rows": 200})
+            
+            if not execute_result.get("success"):
+                print(f"[DATABASE_AGENT] SQL执行失败: {execute_result.get('error')}")
+                state["error"] = execute_result.get("error", "SQL执行失败")
+                state["error_code"] = 500
+                state["current_step"] = "database_error"
+                state["execution_path"].append("agent_database_error")
+                return state
+            
+            data_result = execute_result.get("data_result")
+            state["data_result"] = data_result
+            print(f"[DATABASE_AGENT] SQL执行成功,返回 {data_result.get('row_count', 0)} 行数据")
+            
+            # 步骤3:生成摘要
+            print(f"[DATABASE_AGENT] 步骤3:生成摘要")
+            summary_result = generate_summary.invoke({
+                "question": question,
+                "data_result": data_result,
+                "sql": sql
             })
             
-            # 解析Agent执行结果
-            self._parse_database_agent_result(state, result)
+            if not summary_result.get("success"):
+                print(f"[DATABASE_AGENT] 摘要生成失败: {summary_result.get('message')}")
+                # 摘要生成失败不是致命错误,使用默认摘要
+                state["summary"] = f"查询执行完成,共返回 {data_result.get('row_count', 0)} 条记录。"
+            else:
+                state["summary"] = summary_result.get("summary")
+                print(f"[DATABASE_AGENT] 摘要生成成功")
             
             state["current_step"] = "database_completed"
             state["execution_path"].append("agent_database")
@@ -165,60 +174,20 @@ class CituLangGraphAgent:
             
         except Exception as e:
             print(f"[ERROR] 数据库Agent异常: {str(e)}")
+            import traceback
+            print(f"[ERROR] 详细错误信息: {traceback.format_exc()}")
             state["error"] = f"数据库查询失败: {str(e)}"
             state["error_code"] = 500
             state["current_step"] = "database_error"
             state["execution_path"].append("agent_database_error")
             return state
     
-    def _create_chat_agent(self):
-        """创建聊天专用Agent(预创建)"""
-        from agent.config import get_nested_config
-        
-        # 获取配置
-        max_iterations = get_nested_config(self.config, "chat_agent.max_iterations", 3)
-        enable_verbose = get_nested_config(self.config, "chat_agent.enable_verbose", True)
-        
-        chat_prompt = ChatPromptTemplate.from_messages([
-            SystemMessage(content="""
-你是Citu智能数据问答平台的友好助手。
-
-使用 general_chat 工具来处理用户的一般性问题、概念解释、操作指导等。
-
-特别注意:
-- 如果用户的问题可能涉及数据查询,建议他们尝试数据库查询功能
-- 如果问题不够明确,主动询问更多细节以便更好地帮助用户
-- 对于模糊的问题,可以提供多种可能的解决方案
-- 当遇到不确定的问题时,通过友好的对话来澄清用户意图
-"""),
-            MessagesPlaceholder(variable_name="chat_history", optional=True),
-            HumanMessage(content="{input}"),
-            MessagesPlaceholder(variable_name="agent_scratchpad")
-        ])
-        
-        chat_tools = [general_chat]
-        agent = create_openai_tools_agent(self.llm, chat_tools, chat_prompt)
-        
-        return AgentExecutor(
-            agent=agent,
-            tools=chat_tools,
-            verbose=enable_verbose,
-            handle_parsing_errors=True,
-            max_iterations=max_iterations
-        )
-    
     def _agent_chat_node(self, state: AgentState) -> AgentState:
-        """聊天Agent节点 - 使用预创建或动态创建的Agent"""
+        """聊天Agent节点 - 直接工具调用模式"""
         try:
             print(f"[CHAT_AGENT] 开始处理聊天: {state['question']}")
             
-            # 使用预创建的Agent或动态创建
-            if self._chat_executor is not None:
-                executor = self._chat_executor
-                print(f"[CHAT_AGENT] 使用预创建的Agent实例")
-            else:
-                executor = self._create_chat_agent()
-                print(f"[CHAT_AGENT] 动态创建Agent实例")
+            question = state["question"]
             
             # 构建上下文
             enable_context_injection = self.config.get("chat_agent", {}).get("enable_context_injection", True)
@@ -226,17 +195,21 @@ class CituLangGraphAgent:
             if enable_context_injection and state.get("classification_reason"):
                 context = f"分类原因: {state['classification_reason']}"
             
-            # 执行Agent
-            input_text = state["question"]
-            if context:
-                input_text = f"{state['question']}\n\n上下文: {context}"
-            
-            result = executor.invoke({
-                "input": input_text
+            # 直接调用general_chat工具
+            print(f"[CHAT_AGENT] 调用general_chat工具")
+            chat_result = general_chat.invoke({
+                "question": question,
+                "context": context
             })
             
-            # 提取聊天响应
-            state["chat_response"] = result.get("output", "")
+            if chat_result.get("success"):
+                state["chat_response"] = chat_result.get("response", "")
+                print(f"[CHAT_AGENT] 聊天处理成功")
+            else:
+                # 处理失败,使用备用响应
+                state["chat_response"] = chat_result.get("response", "抱歉,我暂时无法处理您的问题。请稍后再试。")
+                print(f"[CHAT_AGENT] 聊天处理失败,使用备用响应: {chat_result.get('error')}")
+            
             state["current_step"] = "chat_completed"
             state["execution_path"].append("agent_chat")
             
@@ -245,6 +218,8 @@ class CituLangGraphAgent:
             
         except Exception as e:
             print(f"[ERROR] 聊天Agent异常: {str(e)}")
+            import traceback
+            print(f"[ERROR] 详细错误信息: {traceback.format_exc()}")
             state["chat_response"] = "抱歉,我暂时无法处理您的问题。请稍后再试,或者尝试询问数据相关的问题。"
             state["current_step"] = "chat_error"
             state["execution_path"].append("agent_chat_error")
@@ -276,14 +251,14 @@ class CituLangGraphAgent:
             
             elif state["question_type"] == "DATABASE":
                 # 数据库查询类型
-                if state.get("data_result") and state.get("summary"):
-                    # 完整的数据库查询流程
+                if state.get("summary"):
+                    # 有摘要的情况(包括解释性响应和完整查询结果)
                     state["final_response"] = {
                         "success": True,
                         "response": state["summary"],
                         "type": "DATABASE",
                         "sql": state.get("sql"),
-                        "data_result": state["data_result"],
+                        "data_result": state.get("data_result"),  # 可能为None(解释性响应)
                         "summary": state["summary"],
                         "execution_path": state["execution_path"],
                         "classification_info": {
@@ -293,7 +268,7 @@ class CituLangGraphAgent:
                         }
                     }
                 else:
-                    # 数据库查询失败,但有部分结果
+                    # 数据库查询失败,没有任何结果
                     state["final_response"] = {
                         "success": False,
                         "error": state.get("error", "数据库查询未完成"),
@@ -351,47 +326,6 @@ class CituLangGraphAgent:
             # 聊天Agent可以处理不确定的情况,并在必要时引导用户提供更多信息
             return "CHAT"
     
-    def _parse_database_agent_result(self, state: AgentState, agent_result: Dict[str, Any]):
-        """解析数据库Agent的执行结果"""
-        try:
-            output = agent_result.get("output", "")
-            intermediate_steps = agent_result.get("intermediate_steps", [])
-            
-            # 从intermediate_steps中提取工具调用结果
-            for step in intermediate_steps:
-                if len(step) >= 2:
-                    action, observation = step[0], step[1]
-                    
-                    if hasattr(action, 'tool') and hasattr(action, 'tool_input'):
-                        tool_name = action.tool
-                        tool_result = observation
-                        
-                        # 解析工具结果
-                        if tool_name == "generate_sql" and isinstance(tool_result, dict):
-                            if tool_result.get("success"):
-                                state["sql"] = tool_result.get("sql")
-                            else:
-                                state["error"] = tool_result.get("error")
-                        
-                        elif tool_name == "execute_sql" and isinstance(tool_result, dict):
-                            if tool_result.get("success"):
-                                state["data_result"] = tool_result.get("data_result")
-                            else:
-                                state["error"] = tool_result.get("error")
-                        
-                        elif tool_name == "generate_summary" and isinstance(tool_result, dict):
-                            if tool_result.get("success"):
-                                state["summary"] = tool_result.get("summary")
-            
-            # 如果没有从工具结果中获取到摘要,使用Agent的最终输出
-            if not state.get("summary") and output:
-                state["summary"] = output
-                
-        except Exception as e:
-            print(f"[WARNING] 解析数据库Agent结果失败: {str(e)}")
-            # 使用Agent的输出作为摘要
-            state["summary"] = agent_result.get("output", "查询处理完成")
-    
     def process_question(self, question: str, session_id: str = None) -> Dict[str, Any]:
         """
         统一的问题处理入口
@@ -489,7 +423,7 @@ class CituLangGraphAgent:
                     "test_result": test_result.get("success", False),
                     "workflow_compiled": self.workflow is not None,
                     "tools_count": len(self.tools),
-                    "agent_reuse_enabled": self._database_executor is not None and self._chat_executor is not None,
+                    "agent_reuse_enabled": False,
                     "message": "Agent健康检查完成"
                 }
             else:
@@ -499,7 +433,7 @@ class CituLangGraphAgent:
                     "test_result": True,
                     "workflow_compiled": self.workflow is not None,
                     "tools_count": len(self.tools),
-                    "agent_reuse_enabled": self._database_executor is not None and self._chat_executor is not None,
+                    "agent_reuse_enabled": False,
                     "message": "Agent简单健康检查完成"
                 }
             

+ 28 - 25
agent/classifier.py

@@ -17,7 +17,7 @@ class QuestionClassifier:
     """
     
     def __init__(self):
-        # 从配置文件加载阈值参数
+    # 从配置文件加载阈值参数
         try:
             from agent.config import get_current_config, get_nested_config
             config = get_current_config()
@@ -40,10 +40,12 @@ class QuestionClassifier:
             self.uncertain_confidence = 0.2
             print("[CLASSIFIER] 配置文件不可用,使用默认分类器参数")
         
+        # 移除了 LLM 实例存储,现在使用 Vanna 实例
+        
         self.db_keywords = {
             "数据类": [
                 "收入", "销量", "数量", "平均", "总计", "统计", "合计", "累计",
-                "营业额", "利润", "成本", "费用", "金额", "价格", "单价"
+                "营业额", "利润", "成本", "费用", "金额", "价格", "单价", "服务区", "多少个"
             ],
             "分析类": [
                 "分组", "排行", "排名", "增长率", "趋势", "对比", "比较", "占比",
@@ -62,7 +64,7 @@ class QuestionClassifier:
         # SQL关键词
         self.sql_patterns = [
             r"\b(select|from|where|group by|order by|having|join)\b",
-            r"\b(查询|统计|汇总|计算|分析)\b",
+            r"\b(查询|统计|汇总|计算|分析|有多少)\b",
             r"\b(表|字段|数据库)\b"
         ]
         
@@ -156,39 +158,40 @@ class QuestionClassifier:
                 method="rule_based"
             )
     
+
     def _llm_classify(self, question: str) -> ClassificationResult:
         """基于LLM的分类"""
         try:
-            from common.utils import get_current_llm_config
-            from customllm.qianwen_chat import QianWenChat
-            
-            llm_config = get_current_llm_config()
-            llm = QianWenChat(config=llm_config)
+            # 使用 Vanna 实例进行分类
+            from common.vanna_instance import get_vanna_instance
+            vn = get_vanna_instance()
             
             # 分类提示词
             classification_prompt = f"""
-请判断以下问题是否需要查询数据库。
+    请判断以下问题是否需要查询数据库。
 
-问题: {question}
+    问题: {question}
 
-判断标准:
-1. 如果问题涉及数据查询、统计、分析、报表等,返回 "DATABASE"
-2. 如果问题是一般性咨询、概念解释、操作指导、闲聊等,返回 "CHAT"
+    判断标准:
+    1. 如果问题涉及数据查询、统计、分析、报表等,返回 "DATABASE"
+    2. 如果问题是一般性咨询、概念解释、操作指导、闲聊等,返回 "CHAT"
 
-请只返回 "DATABASE" 或 "CHAT",并在下一行简要说明理由。
+    请只返回 "DATABASE" 或 "CHAT",并在下一行简要说明理由。
 
-格式:
-分类: [DATABASE/CHAT]
-理由: [简要说明]
-置信度: [0.0-1.0之间的数字]
-"""
+    格式:
+    分类: [DATABASE/CHAT]
+    理由: [简要说明]
+    置信度: [0.0-1.0之间的数字]
+    """
             
-            prompt = [
-                llm.system_message("你是一个专业的问题分类助手,能准确判断问题类型。"),
-                llm.user_message(classification_prompt)
-            ]
+            # 分类专用的系统提示词
+            system_prompt = "你是一个专业的问题分类助手,能准确判断问题类型。请严格按照要求的格式返回分类结果。"
             
-            response = llm.submit_prompt(prompt)
+            # 使用 Vanna 实例的 chat_with_llm 方法
+            response = vn.chat_with_llm(
+                question=classification_prompt,
+                system_prompt=system_prompt
+            )
             
             # 解析响应
             return self._parse_llm_response(response)
@@ -201,7 +204,7 @@ class QuestionClassifier:
                 reason=f"LLM分类异常: {str(e)}",
                 method="llm_error"
             )
-    
+        
     def _parse_llm_response(self, response: str) -> ClassificationResult:
         """解析LLM响应"""
         try:

+ 14 - 15
agent/tools/general_chat.py

@@ -1,8 +1,7 @@
 # agent/tools/general_chat.py
 from langchain.tools import tool
 from typing import Dict, Any, Optional
-from common.utils import get_current_llm_config
-from customllm.qianwen_chat import QianWenChat
+from common.vanna_instance import get_vanna_instance
 
 @tool
 def general_chat(question: str, context: Optional[str] = None) -> Dict[str, Any]:
@@ -25,7 +24,7 @@ def general_chat(question: str, context: Optional[str] = None) -> Dict[str, Any]
         print(f"[TOOL:general_chat] 处理聊天问题: {question}")
         
         system_prompt = """
-你是Cito智能数据问答平台的AI助手,专门为用户提供帮助和支持。
+你是Citu智能数据问答平台的AI助手,专门为用户提供帮助和支持。
 
 你的职责包括:
 1. 回答关于平台功能和使用方法的问题
@@ -42,18 +41,18 @@ def general_chat(question: str, context: Optional[str] = None) -> Dict[str, Any]
 - 保持中文回答,语言自然流畅
 """
         
-        # 生成聊天响应
-        llm_config = get_current_llm_config()
-        llm = QianWenChat(config=llm_config)
-        
-        messages = [llm.system_message(system_prompt)]
-        
+        # 构建完整的问题(如果有上下文)
         if context:
-            messages.append(llm.user_message(f"上下文信息:{context}"))
-        
-        messages.append(llm.user_message(question))
+            full_question = f"上下文信息:{context}\n\n用户问题:{question}"
+        else:
+            full_question = question
         
-        response = llm.submit_prompt(messages)
+        # 使用 Vanna 实例进行聊天
+        vn = get_vanna_instance()
+        response = vn.chat_with_llm(
+            question=full_question,
+            system_prompt=system_prompt
+        )
         
         if response:
             print(f"[TOOL:general_chat] 聊天响应生成成功: {response[:100]}...")
@@ -82,7 +81,7 @@ def _get_fallback_response(question: str) -> str:
     question_lower = question.lower()
     
     if any(keyword in question_lower for keyword in ["你好", "hello", "hi"]):
-        return "您好!我是Cito智能数据问答平台的AI助手。我可以帮助您进行数据查询和分析,也可以回答关于平台使用的问题。有什么可以帮助您的吗?"
+        return "您好!我是Citu智能数据问答平台的AI助手。我可以帮助您进行数据查询和分析,也可以回答关于平台使用的问题。有什么可以帮助您的吗?"
     
     elif any(keyword in question_lower for keyword in ["谢谢", "thank"]):
         return "不客气!如果您还有其他问题,随时可以问我。我可以帮您查询数据或解答疑问。"
@@ -91,7 +90,7 @@ def _get_fallback_response(question: str) -> str:
         return "再见!期待下次为您服务。如果需要数据查询或其他帮助,随时欢迎回来!"
     
     elif any(keyword in question_lower for keyword in ["怎么", "如何", "怎样"]):
-        return "我理解您想了解使用方法。Cito平台支持自然语言数据查询,您可以直接用中文描述您想要查询的数据,比如'查询本月销售额'或'统计各部门人数'等。有具体问题欢迎继续询问!"
+        return "我理解您想了解使用方法。Citu平台支持自然语言数据查询,您可以直接用中文描述您想要查询的数据,比如'查询本月销售额'或'统计各部门人数'等。有具体问题欢迎继续询问!"
     
     elif any(keyword in question_lower for keyword in ["功能", "作用", "能做"]):
         return "我主要可以帮助您:\n1. 进行数据库查询和分析\n2. 解答平台使用问题\n3. 解释数据相关概念\n4. 提供操作指导\n\n您可以用自然语言描述数据需求,我会帮您生成相应的查询。"

+ 2 - 1
agent/tools/sql_generation.py

@@ -28,7 +28,7 @@ def generate_sql(question: str, allow_llm_to_see_data: bool = True) -> Dict[str,
         sql = vn.generate_sql(question=question, allow_llm_to_see_data=allow_llm_to_see_data)
         
         if sql is None:
-            # 检查是否有LLM解释性文本
+            # 检查是否有LLM解释性文本(已在base_llm_chat.py中处理thinking内容)
             explanation = getattr(vn, 'last_llm_explanation', None)
             if explanation:
                 return {
@@ -65,6 +65,7 @@ def generate_sql(question: str, allow_llm_to_see_data: bool = True) -> Dict[str,
         ]
         
         if any(indicator in sql_clean.lower() for indicator in error_indicators):
+            # 这是解释性文本(已在base_llm_chat.py中处理thinking内容)
             return {
                 "success": False,
                 "sql": None,

+ 3 - 22
agent/tools/summary_generation.py

@@ -2,9 +2,7 @@
 from langchain.tools import tool
 from typing import Dict, Any
 import pandas as pd
-import re
 from common.vanna_instance import get_vanna_instance
-import app_config
 
 @tool
 def generate_summary(question: str, data_result: Dict[str, Any], sql: str) -> Dict[str, Any]:
@@ -44,7 +42,7 @@ def generate_summary(question: str, data_result: Dict[str, Any], sql: str) -> Di
                 "message": "空数据摘要"
             }
         
-        # 调用Vanna生成摘要
+        # 调用Vanna生成摘要(thinking内容已在base_llm_chat.py中统一处理)
         vn = get_vanna_instance()
         summary = vn.generate_summary(question=question, df=df)
         
@@ -52,15 +50,11 @@ def generate_summary(question: str, data_result: Dict[str, Any], sql: str) -> Di
             # 生成默认摘要
             summary = _generate_default_summary(question, data_result, sql)
         
-        # 处理thinking内容
-        display_summary_thinking = getattr(app_config, 'DISPLAY_SUMMARY_THINKING', False)
-        processed_summary = _process_thinking_content(summary, display_summary_thinking)
-        
-        print(f"[TOOL:generate_summary] 摘要生成成功: {processed_summary[:100]}...")
+        print(f"[TOOL:generate_summary] 摘要生成成功: {summary[:100]}...")
         
         return {
             "success": True,
-            "summary": processed_summary,
+            "summary": summary,
             "message": "摘要生成成功"
         }
         
@@ -91,19 +85,6 @@ def _reconstruct_dataframe(data_result: Dict[str, Any]) -> pd.DataFrame:
         print(f"[WARNING] DataFrame重构失败: {str(e)}")
         return pd.DataFrame()
 
-def _process_thinking_content(summary: str, display_thinking: bool) -> str:
-    """处理thinking内容"""
-    if not summary:
-        return ""
-    
-    if not display_thinking:
-        # 移除thinking标签内容
-        cleaned_summary = re.sub(r'<think>.*?</think>\s*', '', summary, flags=re.DOTALL | re.IGNORECASE)
-        cleaned_summary = re.sub(r'\n\s*\n\s*\n', '\n\n', cleaned_summary)
-        return cleaned_summary.strip()
-    
-    return summary
-
 def _generate_default_summary(question: str, data_result: Dict[str, Any], sql: str) -> str:
     """生成默认摘要"""
     try:

+ 174 - 28
agent/utils.py

@@ -3,8 +3,10 @@
 Agent相关的工具函数
 """
 import functools
-from typing import Dict, Any, Callable
-from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage
+import json
+from typing import Dict, Any, Callable, List, Optional
+from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage, ToolMessage
+from langchain_core.tools import BaseTool
 
 def handle_tool_errors(func: Callable) -> Callable:
     """
@@ -24,11 +26,12 @@ def handle_tool_errors(func: Callable) -> Callable:
     return wrapper
 
 class LLMWrapper:
-    """自定义LLM的LangChain兼容包装器"""
+    """自定义LLM的LangChain兼容包装器,支持工具调用"""
     
     def __init__(self, llm_instance):
         self.llm = llm_instance
         self._model_name = getattr(llm_instance, 'model', 'custom_llm')
+        self._bound_tools = []
     
     def invoke(self, input_data, **kwargs):
         """LangChain invoke接口"""
@@ -40,27 +43,143 @@ class LLMWrapper:
             else:
                 messages = [HumanMessage(content=str(input_data))]
             
-            # 转换消息格式
-            prompt = []
-            for msg in messages:
-                if isinstance(msg, SystemMessage):
-                    prompt.append(self.llm.system_message(msg.content))
-                elif isinstance(msg, HumanMessage):
-                    prompt.append(self.llm.user_message(msg.content))
-                elif isinstance(msg, AIMessage):
-                    prompt.append(self.llm.assistant_message(msg.content))
-                else:
-                    prompt.append(self.llm.user_message(str(msg.content)))
+            # 检查是否需要工具调用
+            if self._bound_tools and self._should_use_tools(messages):
+                return self._invoke_with_tools(messages, **kwargs)
+            else:
+                return self._invoke_without_tools(messages, **kwargs)
+                
+        except Exception as e:
+            print(f"[ERROR] LLM包装器调用失败: {str(e)}")
+            return AIMessage(content=f"LLM调用失败: {str(e)}")
+    
+    def _should_use_tools(self, messages: List[BaseMessage]) -> bool:
+        """判断是否应该使用工具"""
+        # 检查最后一条消息是否包含工具相关的指令
+        if messages:
+            last_message = messages[-1]
+            if isinstance(last_message, HumanMessage):
+                content = last_message.content.lower()
+                # 检查是否包含工具相关的关键词
+                tool_keywords = ["生成sql", "执行sql", "generate sql", "execute sql", "查询", "数据库"]
+                return any(keyword in content for keyword in tool_keywords)
+        return True  # 默认使用工具
+    
+    def _invoke_with_tools(self, messages: List[BaseMessage], **kwargs):
+        """使用工具调用的方式"""
+        try:
+            # 构建工具调用提示
+            tool_prompt = self._build_tool_prompt(messages)
             
             # 调用底层LLM
-            response = self.llm.submit_prompt(prompt, **kwargs)
+            response = self.llm.submit_prompt(tool_prompt, **kwargs)
             
-            # 返回LangChain格式的结果
-            return AIMessage(content=response)
+            # 解析工具调用
+            tool_calls = self._parse_tool_calls(response)
             
+            if tool_calls:
+                # 如果有工具调用,返回包含工具调用的AIMessage
+                return AIMessage(
+                    content=response,
+                    tool_calls=tool_calls
+                )
+            else:
+                # 没有工具调用,返回普通响应
+                return AIMessage(content=response)
+                
         except Exception as e:
-            print(f"[ERROR] LLM包装器调用失败: {str(e)}")
-            return AIMessage(content=f"LLM调用失败: {str(e)}")
+            print(f"[ERROR] 工具调用失败: {str(e)}")
+            return self._invoke_without_tools(messages, **kwargs)
+    
+    def _invoke_without_tools(self, messages: List[BaseMessage], **kwargs):
+        """不使用工具的普通调用"""
+        # 转换消息格式
+        prompt = []
+        for msg in messages:
+            if isinstance(msg, SystemMessage):
+                prompt.append(self.llm.system_message(msg.content))
+            elif isinstance(msg, HumanMessage):
+                prompt.append(self.llm.user_message(msg.content))
+            elif isinstance(msg, AIMessage):
+                prompt.append(self.llm.assistant_message(msg.content))
+            else:
+                prompt.append(self.llm.user_message(str(msg.content)))
+        
+        # 调用底层LLM
+        response = self.llm.submit_prompt(prompt, **kwargs)
+        
+        # 返回LangChain格式的结果
+        return AIMessage(content=response)
+    
+    def _build_tool_prompt(self, messages: List[BaseMessage]) -> List[Dict[str, str]]:
+        """构建包含工具信息的提示"""
+        prompt = []
+        
+        # 添加系统消息,包含工具定义
+        system_content = self._get_system_message_with_tools(messages)
+        prompt.append(self.llm.system_message(system_content))
+        
+        # 添加用户消息
+        for msg in messages:
+            if isinstance(msg, HumanMessage):
+                prompt.append(self.llm.user_message(msg.content))
+            elif isinstance(msg, AIMessage) and not isinstance(msg, SystemMessage):
+                prompt.append(self.llm.assistant_message(msg.content))
+        
+        return prompt
+    
+    def _get_system_message_with_tools(self, messages: List[BaseMessage]) -> str:
+        """获取包含工具定义的系统消息"""
+        # 查找原始系统消息
+        original_system = ""
+        for msg in messages:
+            if isinstance(msg, SystemMessage):
+                original_system = msg.content
+                break
+        
+        # 构建工具定义
+        tool_definitions = []
+        for tool in self._bound_tools:
+            tool_def = {
+                "name": tool.name,
+                "description": tool.description,
+                "parameters": getattr(tool, 'args_schema', {})
+            }
+            tool_definitions.append(f"- {tool.name}: {tool.description}")
+        
+        # 组合系统消息
+        if tool_definitions:
+            tools_text = "\n".join(tool_definitions)
+            return f"""{original_system}
+
+你有以下工具可以使用:
+{tools_text}
+
+使用工具时,请明确说明你要调用哪个工具以及需要的参数。对于数据库查询问题,请按照以下步骤:
+1. 使用 generate_sql 工具生成SQL查询
+2. 使用 execute_sql 工具执行SQL查询
+3. 使用 generate_summary 工具生成结果摘要
+
+请直接开始执行工具调用,不要只是提供指导。"""
+        else:
+            return original_system
+    
+    def _parse_tool_calls(self, response: str) -> List[Dict[str, Any]]:
+        """解析LLM响应中的工具调用"""
+        tool_calls = []
+        
+        # 简单的工具调用解析逻辑
+        # 这里可以根据实际的LLM响应格式进行调整
+        
+        response_lower = response.lower()
+        if "generate_sql" in response_lower:
+            tool_calls.append({
+                "name": "generate_sql",
+                "args": {},
+                "id": "generate_sql_call"
+            })
+        
+        return tool_calls
     
     @property
     def model_name(self) -> str:
@@ -68,6 +187,7 @@ class LLMWrapper:
     
     def bind_tools(self, tools):
         """绑定工具(用于支持工具调用)"""
+        self._bound_tools = tools if isinstance(tools, list) else [tools]
         return self
 
 def get_compatible_llm():
@@ -80,26 +200,52 @@ def get_compatible_llm():
         if llm_config.get("base_url") and llm_config.get("api_key"):
             try:
                 from langchain_openai import ChatOpenAI
-                return ChatOpenAI(
+                llm = ChatOpenAI(
                     base_url=llm_config.get("base_url"),
                     api_key=llm_config.get("api_key"),
                     model=llm_config.get("model"),
                     temperature=llm_config.get("temperature", 0.7)
                 )
+                print("[INFO] 使用标准OpenAI兼容API")
+                return llm
             except ImportError:
-                print("[WARNING] langchain_openai 未安装,使用自定义包装器")
+                print("[WARNING] langchain_openai 未安装,使用 Vanna 实例包装器")
         
-        # 使用自定义LLM包装器
-        from customllm.qianwen_chat import QianWenChat
-        custom_llm = QianWenChat(config=llm_config)
-        return LLMWrapper(custom_llm)
+        # 优先使用统一的 Vanna 实例
+        from common.vanna_instance import get_vanna_instance
+        vn = get_vanna_instance()
+        print("[INFO] 使用Vanna实例包装器")
+        return LLMWrapper(vn)
         
     except Exception as e:
-        print(f"[ERROR] 获取LLM失败: {str(e)}")
-        # 返回基础包装器
+        print(f"[ERROR] 获取 Vanna 实例失败: {str(e)}")
+        # 回退到原有逻辑
         from common.utils import get_current_llm_config
         from customllm.qianwen_chat import QianWenChat
         
         llm_config = get_current_llm_config()
         custom_llm = QianWenChat(config=llm_config)
-        return LLMWrapper(custom_llm)
+        print("[INFO] 使用QianWen包装器")
+        return LLMWrapper(custom_llm)
+
+def _is_valid_sql_format(sql_text: str) -> bool:
+    """验证文本是否为有效的SQL查询格式"""
+    if not sql_text or not sql_text.strip():
+        return False
+    
+    sql_clean = sql_text.strip().upper()
+    
+    # 检查是否以SQL关键字开头
+    sql_keywords = ['SELECT', 'INSERT', 'UPDATE', 'DELETE', 'CREATE', 'ALTER', 'DROP', 'WITH']
+    starts_with_sql = any(sql_clean.startswith(keyword) for keyword in sql_keywords)
+    
+    # 检查是否包含解释性语言
+    explanation_phrases = [
+        '无法', '不能', '抱歉', 'SORRY', 'UNABLE', 'CANNOT', 
+        '需要更多信息', '请提供', '表不存在', '字段不存在',
+        '不清楚', '不确定', '没有足够', '无法理解', '无法生成',
+        '无法确定', '不支持', '不可用', '缺少', '未找到'
+    ]
+    contains_explanation = any(phrase in sql_clean for phrase in explanation_phrases)
+    
+    return starts_with_sql and not contains_explanation

+ 3 - 2
app_config.py

@@ -134,10 +134,11 @@ TRAINING_DATA_PATH = "./training/data"
 # 是否启用问题重写功能,也就是上下文问题合并。
 REWRITE_QUESTION_ENABLED = False
 
-# 是否在摘要中显示thinking过程
+# 是否在返回结果中显示thinking过程
 # True: 显示 <think></think> 内容
 # False: 隐藏 <think></think> 内容,只显示最终答案
-DISPLAY_SUMMARY_THINKING = False
+# 此参数影响:摘要生成、SQL生成解释性文本、API返回结果等所有输出内容
+DISPLAY_RESULT_THINKING = False
 
 # 是否启用向量查询结果得分阈值过滤
 # result = max((n + 1) // 2, 1)

+ 8 - 52
citu_app.py

@@ -6,8 +6,10 @@ import pandas as pd
 import common.result as result
 from datetime import datetime, timedelta
 from common.session_aware_cache import WebSessionAwareMemoryCache
-from app_config import API_MAX_RETURN_ROWS, DISPLAY_SUMMARY_THINKING
+from app_config import API_MAX_RETURN_ROWS
 import re
+import chainlit as cl
+import json
 
 # 设置默认的最大返回行数
 DEFAULT_MAX_RETURN_ROWS = 200
@@ -32,34 +34,6 @@ app = VannaFlaskApp(
     debug=True
 )
 
-
-def _remove_thinking_content(text: str) -> str:
-    """
-    移除文本中的 <think></think> 标签及其内容
-    复用自 base_llm_chat.py 中的同名方法
-    
-    Args:
-        text (str): 包含可能的 thinking 标签的文本
-        
-    Returns:
-        str: 移除 thinking 内容后的文本
-    """
-    if not text:
-        return text
-    
-    # 移除 <think>...</think> 标签及其内容(支持多行)
-    # 使用 re.DOTALL 标志使 . 匹配包括换行符在内的任何字符
-    cleaned_text = re.sub(r'<think>.*?</think>\s*', '', text, flags=re.DOTALL | re.IGNORECASE)
-    
-    # 移除可能的多余空行
-    cleaned_text = re.sub(r'\n\s*\n\s*\n', '\n\n', cleaned_text)
-    
-    # 去除开头和结尾的空白字符
-    cleaned_text = cleaned_text.strip()
-    
-    return cleaned_text
-
-
 # 修改ask接口,支持前端传递session_id
 @app.flask_app.route('/api/v0/ask', methods=['POST'])
 def ask_full():
@@ -89,18 +63,12 @@ def ask_full():
 
         # 关键:检查是否有LLM解释性文本(无法生成SQL的情况)
         if sql is None and hasattr(vn, 'last_llm_explanation') and vn.last_llm_explanation:
-            # 根据 DISPLAY_SUMMARY_THINKING 参数决定是否移除 thinking 内容
-            explanation_message = vn.last_llm_explanation
-            if not DISPLAY_SUMMARY_THINKING:
-                explanation_message = _remove_thinking_content(explanation_message)
-                print(f"[DEBUG] 隐藏thinking内容 - 原始长度: {len(vn.last_llm_explanation)}, 处理后长度: {len(explanation_message)}")
-            
             # 在解释性文本末尾添加提示语
-            explanation_message = explanation_message + "请尝试提问其它问题。"
+            explanation_message = vn.last_llm_explanation + "请尝试提问其它问题。"
             
             # 使用 result.failed 返回,success为false,但在message中包含LLM友好的解释
             return jsonify(result.failed(
-                message=explanation_message,  # 处理的解释性文本
+                message=explanation_message,  # 已处理的解释性文本
                 code=400,  # 业务逻辑错误,使用400
                 data={
                     "sql": None,
@@ -135,7 +103,7 @@ def ask_full():
             rows = df.head(MAX_RETURN_ROWS).to_dict(orient="records")
             columns = list(df.columns)
             
-            # 生成数据摘要
+            # 生成数据摘要(thinking内容已在base_llm_chat.py中统一处理)
             try:
                 summary = vn.generate_summary(question=question, df=df)
                 print(f"[INFO] 成功生成摘要: {summary}")
@@ -157,14 +125,8 @@ def ask_full():
         
         # 即使发生异常,也检查是否有业务层面的解释
         if hasattr(vn, 'last_llm_explanation') and vn.last_llm_explanation:
-            # 根据 DISPLAY_SUMMARY_THINKING 参数决定是否移除 thinking 内容
-            explanation_message = vn.last_llm_explanation
-            if not DISPLAY_SUMMARY_THINKING:
-                explanation_message = _remove_thinking_content(explanation_message)
-                print(f"[DEBUG] 异常处理中隐藏thinking内容 - 原始长度: {len(vn.last_llm_explanation)}, 处理后长度: {len(explanation_message)}")
-            
             # 在解释性文本末尾添加提示语
-            explanation_message = explanation_message + "请尝试提问其它问题。"
+            explanation_message = vn.last_llm_explanation + "请尝试提问其它问题。"
             
             return jsonify(result.failed(
                 message=explanation_message,
@@ -262,14 +224,8 @@ def ask_cached():
             
             # 检查是否有LLM解释性文本(无法生成SQL的情况)
             if sql is None and hasattr(vn, 'last_llm_explanation') and vn.last_llm_explanation:
-                # 根据 DISPLAY_SUMMARY_THINKING 参数决定是否移除 thinking 内容
-                explanation_message = vn.last_llm_explanation
-                if not DISPLAY_SUMMARY_THINKING:
-                    explanation_message = _remove_thinking_content(explanation_message)
-                    print(f"[DEBUG] ask_cached中隐藏thinking内容 - 原始长度: {len(vn.last_llm_explanation)}, 处理后长度: {len(explanation_message)}")
-                
                 # 在解释性文本末尾添加提示语
-                explanation_message = explanation_message + "请尝试用其它方式提问。"
+                explanation_message = vn.last_llm_explanation + "请尝试用其它方式提问。"
                 
                 return jsonify(result.failed(
                     message=explanation_message,

+ 85 - 13
customllm/base_llm_chat.py

@@ -5,7 +5,7 @@ import pandas as pd
 import plotly.graph_objs
 from vanna.base import VannaBase
 # 导入配置参数
-from app_config import REWRITE_QUESTION_ENABLED, DISPLAY_SUMMARY_THINKING
+from app_config import REWRITE_QUESTION_ENABLED, DISPLAY_RESULT_THINKING
 
 
 class BaseLLMChat(VannaBase, ABC):
@@ -220,6 +220,12 @@ class BaseLLMChat(VannaBase, ABC):
 
         # 调用submit_prompt方法,并清理结果
         plotly_code = self.submit_prompt(message_log, **kwargs)
+        
+        # 根据 DISPLAY_RESULT_THINKING 参数处理thinking内容
+        if not DISPLAY_RESULT_THINKING:
+            original_code = plotly_code
+            plotly_code = self._remove_thinking_content(plotly_code)
+            print(f"[DEBUG] generate_plotly_code隐藏thinking内容 - 原始长度: {len(original_code)}, 处理后长度: {len(plotly_code)}")
 
         return self._sanitize_plotly_code(self._extract_python_code(plotly_code))
 
@@ -290,7 +296,11 @@ class BaseLLMChat(VannaBase, ABC):
             
             if not sql or sql.strip() == "":
                 print(f"[WARNING] 生成的SQL为空")
-                self.last_llm_explanation = "无法生成SQL查询,可能是问题描述不够清晰或缺少必要的数据表信息。"
+                explanation = "无法生成SQL查询,可能是问题描述不够清晰或缺少必要的数据表信息。"
+                # 根据 DISPLAY_RESULT_THINKING 参数处理thinking内容
+                if not DISPLAY_RESULT_THINKING:
+                    explanation = self._remove_thinking_content(explanation)
+                self.last_llm_explanation = explanation
                 return None
             
             # 替换 "\_" 为 "_",解决特殊字符转义问题
@@ -309,16 +319,24 @@ class BaseLLMChat(VannaBase, ABC):
             for indicator in error_indicators:
                 if indicator in sql_lower:
                     print(f"[WARNING] LLM返回错误信息而非SQL: {sql}")
-                    # 保存LLM的解释性文本
-                    self.last_llm_explanation = sql
+                    # 保存LLM的解释性文本,并根据配置处理thinking内容
+                    explanation = sql
+                    if not DISPLAY_RESULT_THINKING:
+                        explanation = self._remove_thinking_content(explanation)
+                        print(f"[DEBUG] 隐藏thinking内容 - SQL生成解释性文本")
+                    self.last_llm_explanation = explanation
                     return None
             
             # 简单检查是否像SQL语句(至少包含一些SQL关键词)
             sql_keywords = ["select", "insert", "update", "delete", "with", "from", "where"]
             if not any(keyword in sql_lower for keyword in sql_keywords):
                 print(f"[WARNING] 返回内容不像有效SQL: {sql}")
-                # 保存LLM的解释性文本
-                self.last_llm_explanation = sql
+                # 保存LLM的解释性文本,并根据配置处理thinking内容
+                explanation = sql
+                if not DISPLAY_RESULT_THINKING:
+                    explanation = self._remove_thinking_content(explanation)
+                    print(f"[DEBUG] 隐藏thinking内容 - SQL生成非有效SQL内容")
+                self.last_llm_explanation = explanation
                 return None
                 
             print(f"[SUCCESS] 成功生成SQL:\n {sql}")
@@ -332,7 +350,11 @@ class BaseLLMChat(VannaBase, ABC):
             # 导入traceback以获取详细错误信息
             import traceback
             print(f"[ERROR] 详细错误信息: {traceback.format_exc()}")
-            self.last_llm_explanation = f"SQL生成过程中出现异常: {str(e)}"
+            explanation = f"SQL生成过程中出现异常: {str(e)}"
+            # 根据 DISPLAY_RESULT_THINKING 参数处理thinking内容
+            if not DISPLAY_RESULT_THINKING:
+                explanation = self._remove_thinking_content(explanation)
+            self.last_llm_explanation = explanation
             return None
 
     def generate_question(self, sql: str, **kwargs) -> str:
@@ -344,21 +366,64 @@ class BaseLLMChat(VannaBase, ABC):
             self.user_message(sql)
         ]
         response = self.submit_prompt(prompt, **kwargs)
+        
+        # 根据 DISPLAY_RESULT_THINKING 参数处理thinking内容
+        if not DISPLAY_RESULT_THINKING:
+            original_response = response
+            response = self._remove_thinking_content(response)
+            print(f"[DEBUG] generate_question隐藏thinking内容 - 原始长度: {len(original_response)}, 处理后长度: {len(response)}")
+        
         return response
 
-    def chat_with_llm(self, question: str, **kwargs) -> str:
+    # def chat_with_llm(self, question: str, **kwargs) -> str:
+    #     """
+    #     直接与LLM对话,不涉及SQL生成
+    #     """
+    #     try:
+    #         prompt = [
+    #             self.system_message(
+    #                 "你是一个友好的AI助手。如果用户询问的是数据库相关问题,请建议他们重新表述问题以便进行SQL查询。对于其他问题,请尽力提供有帮助的回答。"
+    #             ),
+    #             self.user_message(question)
+    #         ]
+    #         response = self.submit_prompt(prompt, **kwargs)
+    #         return response
+    #     except Exception as e:
+    #         print(f"[ERROR] LLM对话失败: {str(e)}")
+    #         return f"抱歉,我暂时无法回答您的问题。请稍后再试。"
+
+    def chat_with_llm(self, question: str, system_prompt: str = None, **kwargs) -> str:
         """
-        直接与LLM对话,不涉及SQL生成
+        直接与LLM对话,不涉及SQL生成        
+        Args:
+            question: 用户问题
+            system_prompt: 自定义系统提示词,如果为None则使用默认提示词
+            **kwargs: 其他传递给submit_prompt的参数            
+        Returns:
+            LLM的响应文本
         """
         try:
+            # 如果没有提供自定义系统提示词,使用默认的
+            if system_prompt is None:
+                system_prompt = (
+                    "你是一个友好的AI助手,请用中文回答用户的问题。"
+                )
+            
             prompt = [
-                self.system_message(
-                    "你是一个友好的AI助手。如果用户询问的是数据库相关问题,请建议他们重新表述问题以便进行SQL查询。对于其他问题,请尽力提供有帮助的回答。"
-                ),
+                self.system_message(system_prompt),
                 self.user_message(question)
             ]
+            
             response = self.submit_prompt(prompt, **kwargs)
+            
+            # 根据 DISPLAY_RESULT_THINKING 参数处理thinking内容
+            if not DISPLAY_RESULT_THINKING:
+                original_response = response
+                response = self._remove_thinking_content(response)
+                print(f"[DEBUG] chat_with_llm隐藏thinking内容 - 原始长度: {len(original_response)}, 处理后长度: {len(response)}")
+            
             return response
+            
         except Exception as e:
             print(f"[ERROR] LLM对话失败: {str(e)}")
             return f"抱歉,我暂时无法回答您的问题。请稍后再试。"
@@ -395,6 +460,13 @@ class BaseLLMChat(VannaBase, ABC):
             ]
             
             rewritten_question = self.submit_prompt(prompt=prompt, **kwargs)
+            
+            # 根据 DISPLAY_RESULT_THINKING 参数处理thinking内容
+            if not DISPLAY_RESULT_THINKING:
+                original_question = rewritten_question
+                rewritten_question = self._remove_thinking_content(rewritten_question)
+                print(f"[DEBUG] generate_rewritten_question隐藏thinking内容 - 原始长度: {len(original_question)}, 处理后长度: {len(rewritten_question)}")
+            
             print(f"[DEBUG] 合并后的问题: {rewritten_question}")
             return rewritten_question
             
@@ -452,7 +524,7 @@ class BaseLLMChat(VannaBase, ABC):
             summary = self.submit_prompt(message_log, **kwargs)
             
             # 检查是否需要隐藏 thinking 内容
-            display_thinking = kwargs.get("display_summary_thinking", DISPLAY_SUMMARY_THINKING)
+            display_thinking = kwargs.get("display_result_thinking", DISPLAY_RESULT_THINKING)
             
             if not display_thinking:
                 # 移除 <think></think> 标签及其内容

+ 168 - 0
docs/thinking_content_refactor_summary.md

@@ -0,0 +1,168 @@
+# Thinking内容处理重构总结
+
+## 重构目标
+
+统一所有 `<thinking></thinking>` 标签的处理逻辑,避免重复代码,确保 `DISPLAY_RESULT_THINKING` 参数在整个系统中的一致性控制。
+
+## 重构前的问题
+
+### 1. 重复的处理逻辑
+在多个文件中都有 `_remove_thinking_content()` 函数的重复实现:
+- `agent/tools/sql_generation.py`
+- `agent/tools/summary_generation.py` 
+- `citu_app.py`
+- `customllm/base_llm_chat.py`
+
+### 2. 多重处理问题
+某些场景下thinking内容被多次处理:
+- **ask() API**: `customllm/base_llm_chat.py:generate_summary()` 已经处理,但 `citu_app.py` 又重复处理
+- **ask_agent() API**: `agent/tools/summary_generation.py` 处理后,`customllm/base_llm_chat.py:generate_summary()` 又处理一次
+
+### 3. 处理遗漏
+`agent/tools/sql_generation.py` 中的解释性文本可能包含thinking内容,但没有被处理。
+
+## 重构方案
+
+### 核心原则
+**统一在最底层处理** - 只在 `customllm/base_llm_chat.py` 中处理thinking内容,其他地方不再重复处理。
+
+### 架构合理性
+- `customllm/base_llm_chat.py` 是所有LLM响应的**统一出口**
+- 无论是 `ask()` API 还是 `ask_agent()` API,最终都会调用到这里
+- 在数据源头处理thinking内容,确保一致性
+
+## 具体修改内容
+
+### 1. 保留的处理逻辑
+
+**`customllm/base_llm_chat.py`** - 统一处理中心:
+
+```python
+# generate_sql() 方法中处理解释性文本
+if not DISPLAY_RESULT_THINKING:
+    explanation = self._remove_thinking_content(explanation)
+
+# generate_summary() 方法中处理摘要内容  
+if not display_thinking:
+    summary = self._remove_thinking_content(summary)
+
+# chat_with_llm() 方法中处理聊天对话内容
+if not DISPLAY_RESULT_THINKING:
+    response = self._remove_thinking_content(response)
+
+# generate_question() 方法中处理问题生成内容
+if not DISPLAY_RESULT_THINKING:
+    response = self._remove_thinking_content(response)
+
+# generate_rewritten_question() 方法中处理问题合并内容
+if not DISPLAY_RESULT_THINKING:
+    rewritten_question = self._remove_thinking_content(rewritten_question)
+
+# generate_plotly_code() 方法中处理图表代码生成内容
+if not DISPLAY_RESULT_THINKING:
+    plotly_code = self._remove_thinking_content(plotly_code)
+```
+
+### 2. 移除的重复代码
+
+#### `agent/tools/sql_generation.py`
+- ❌ 移除 `_remove_thinking_content()` 函数定义
+- ❌ 移除 `from app_config import DISPLAY_RESULT_THINKING` 导入
+- ❌ 移除所有thinking内容处理逻辑
+- ✅ 依赖 `customllm/base_llm_chat.py` 中的统一处理
+
+#### `agent/tools/summary_generation.py`  
+- ❌ 移除 `_process_thinking_content()` 函数
+- ❌ 移除 `import app_config` 导入
+- ❌ 移除thinking内容处理调用
+- ✅ 依赖 `customllm/base_llm_chat.py:generate_summary()` 中的统一处理
+
+#### `citu_app.py`
+- ❌ 移除 `_remove_thinking_content()` 函数定义
+- ❌ 移除 `DISPLAY_RESULT_THINKING` 导入
+- ❌ 移除所有thinking内容处理调用
+- ✅ 直接使用已处理的结果
+
+## 调用链路分析
+
+### ask() API 调用链路
+```
+citu_app.py:ask_full() 
+→ vn.ask() 
+→ customllm/base_llm_chat.py:ask() 
+→ customllm/base_llm_chat.py:generate_summary()  # ✅ 在这里统一处理thinking
+→ 返回到 citu_app.py  # ✅ 使用已处理的结果
+```
+
+### ask_agent() API 调用链路  
+```
+citu_app.py:ask_agent() 
+→ agent/citu_agent.py:process_question()
+→ agent/tools/summary_generation.py:generate_summary()  # ✅ 不再处理thinking
+→ 内部调用 vn.generate_summary() 
+→ customllm/base_llm_chat.py:generate_summary()  # ✅ 在这里统一处理thinking
+→ 返回到 agent 层  # ✅ 使用已处理的结果
+→ 最终返回到 citu_app.py
+```
+
+### SQL生成解释性文本处理
+```
+agent/tools/sql_generation.py:generate_sql()
+→ vn.generate_sql()
+→ customllm/base_llm_chat.py:generate_sql()  # ✅ 在这里统一处理thinking
+→ 保存到 vn.last_llm_explanation  # ✅ 已处理的解释性文本
+→ 返回到 agent 层  # ✅ 使用已处理的结果
+```
+
+## 配置参数控制
+
+### DISPLAY_RESULT_THINKING 参数作用范围
+当 `DISPLAY_RESULT_THINKING = False` 时,以下所有内容的thinking标签都会被自动移除:
+
+1. **摘要生成**: `customllm/base_llm_chat.py:generate_summary()`
+2. **SQL生成解释性文本**: `customllm/base_llm_chat.py:generate_sql()` 
+3. **聊天对话**: `customllm/base_llm_chat.py:chat_with_llm()`
+4. **问题生成**: `customllm/base_llm_chat.py:generate_question()`
+5. **问题合并**: `customllm/base_llm_chat.py:generate_rewritten_question()`
+6. **图表代码生成**: `customllm/base_llm_chat.py:generate_plotly_code()`
+7. **API返回结果**: 所有通过Vanna实例返回的内容
+
+## 测试验证
+
+创建了测试脚本 `test_thinking_control.py` 来验证:
+
+1. **thinking内容移除功能**: 测试各种thinking标签格式的正确移除
+2. **配置集成**: 验证配置参数的正确导入和使用
+3. **Vanna实例**: 验证实际Vanna实例的thinking处理功能
+
+## 优势总结
+
+### 1. 避免重复处理
+- 消除了多重处理导致的潜在问题
+- 减少了性能开销
+
+### 2. 统一控制点
+- 只需要在一个地方维护thinking处理逻辑
+- 配置参数的影响范围清晰明确
+
+### 3. 架构简化
+- 移除了大量重复代码
+- 降低了维护复杂性
+
+### 4. 一致性保证
+- 所有thinking内容处理都遵循相同的逻辑
+- 避免了不同处理方式导致的不一致问题
+
+## 后续维护
+
+### 新增LLM响应处理时
+- 只需要在 `customllm/base_llm_chat.py` 中添加thinking处理
+- 不需要在其他地方重复实现
+
+### 修改thinking处理逻辑时
+- 只需要修改 `customllm/base_llm_chat.py` 中的 `_remove_thinking_content()` 方法
+- 修改会自动影响所有使用场景
+
+### 配置参数调整时
+- 只需要修改 `app_config.py` 中的 `DISPLAY_RESULT_THINKING` 值
+- 所有相关功能会自动响应配置变化