فهرست منبع

解决了无法生成SQL时,仍然发送给valid_sql()的问题。

wangxq 1 ماه پیش
والد
کامیت
29dab53747
2فایلهای تغییر یافته به همراه149 افزوده شده و 76 حذف شده
  1. 124 67
      test/custom_react_agent/agent.py
  2. 25 9
      test/custom_react_agent/api.py

+ 124 - 67
test/custom_react_agent/agent.py

@@ -185,17 +185,40 @@ class CustomReactAgent:
         
         # 检查是否需要分析验证错误
         next_step = state.get("suggested_next_step")
-        if next_step == "analyze_validation_error":
-            # 查找最近的 valid_sql 错误信息
-            for msg in reversed(state["messages"]):
-                if isinstance(msg, ToolMessage) and msg.name == "valid_sql":
-                    error_guidance = self._generate_validation_error_guidance(msg.content)
-                    messages_for_llm.append(SystemMessage(content=error_guidance))
-                    logger.info("   ✅ 已添加SQL验证错误指导")
-                    break
-        elif next_step and next_step != "analyze_validation_error":
-            instruction = f"Suggestion: Consider using the '{next_step}' tool for the next step."
-            messages_for_llm.append(SystemMessage(content=instruction))
+        
+        # 行为指令与工具建议分离
+        real_tools = {'valid_sql', 'run_sql'}
+        
+        if next_step:
+            if next_step in real_tools:
+                # 场景1: 建议调用一个真实的工具
+                instruction = f"Suggestion: Based on the previous step, please use the '{next_step}' tool to continue."
+                messages_for_llm.append(SystemMessage(content=instruction))
+                logger.info(f"   ✅ 已添加工具建议: {next_step}")
+
+            elif next_step == "analyze_validation_error":
+                # 场景2: 分析SQL验证错误(特殊指令)
+                for msg in reversed(state["messages"]):
+                    if isinstance(msg, ToolMessage) and msg.name == "valid_sql":
+                        error_guidance = self._generate_validation_error_guidance(msg.content)
+                        messages_for_llm.append(SystemMessage(content=error_guidance))
+                        logger.info("   ✅ 已添加SQL验证错误指导")
+                        break
+            
+            elif next_step == 'summarize_final_answer':
+                # 场景3: 总结最终答案(行为指令)
+                instruction = "System Instruction: The SQL query was executed successfully. Please analyze the JSON data in the last message and summarize it in natural, user-friendly language as the final answer. Do not expose the raw JSON data or SQL statements in your response."
+                messages_for_llm.append(SystemMessage(content=instruction))
+                logger.info("   ✅ 已添加 '总结答案' 行为指令")
+
+            elif next_step == 'answer_with_common_sense':
+                # 场景4: 基于常识回答(特殊指令)
+                instruction = (
+                    "无法为当前问题生成有效的SQL查询。失败原因已在上下文中提供。"
+                    "请你直接利用自身的知识库来回答用户的问题,不要再重复解释失败的原因。"
+                )
+                messages_for_llm.append(SystemMessage(content=instruction))
+                logger.info("✅ 已添加 '常识回答' 行为指令")
 
         # 添加重试机制处理网络连接问题
         import asyncio
@@ -413,50 +436,47 @@ class CustomReactAgent:
 
     async def _async_prepare_tool_input_node(self, state: AgentState) -> Dict[str, Any]:
         """
-        异步信息组装节点:为需要上下文的工具注入历史消息。
+        准备工具输入。
+        - 强制修正generate_sql的question参数,确保使用用户原始问题。
+        - 为generate_sql注入经过严格过滤的、干净的对话历史。
         """
-        logger.info(f"🛠️ [Async Node] prepare_tool_input - Thread: {state['thread_id']}")
-        
-        # 🎯 打印 state 全部信息
-        # self._print_state_info(state, "prepare_tool_input")
-        
-        last_message = state["messages"][-1]
-        if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
+        last_message = state['messages'][-1]
+        if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
             return {"messages": [last_message]}
 
-        # 创建一个新的 AIMessage 来替换,避免直接修改 state 中的对象
+        # 强制修正LLM幻觉出的问题
+        for tool_call in last_message.tool_calls:
+            if tool_call['name'] == 'generate_sql':
+                original_user_question = next((msg.content for msg in reversed(state['messages']) if isinstance(msg, HumanMessage)), None)
+                if original_user_question and tool_call['args'].get('question') != original_user_question:
+                    logger.warning(
+                        f"修正 'generate_sql' 的问题参数。\n"
+                        f"  - LLM提供: '{tool_call['args'].get('question')}'\n"
+                        f"  + 修正为: '{original_user_question}'"
+                    )
+                    tool_call['args']['question'] = original_user_question
+
+        # 恢复原始的、更健壮的历史消息过滤和注入逻辑
         new_tool_calls = []
         for tool_call in last_message.tool_calls:
             if tool_call["name"] == "generate_sql":
-                logger.info("   检测到 generate_sql 调用,注入历史消息。")
-                # 复制一份以避免修改原始 tool_call
+                logger.info("检测到 generate_sql 调用,开始注入历史消息。")
                 modified_args = tool_call["args"].copy()
                 
-                # 🎯 改进的消息过滤逻辑:只保留有用的对话上下文,排除当前问题
                 clean_history = []
-                messages_except_current = state["messages"][:-1]  # 排除最后一个消息(当前问题)
+                messages_except_current = state["messages"][:-1]
                 
                 for msg in messages_except_current:
                     if isinstance(msg, HumanMessage):
-                        # 保留历史用户消息(但不包括当前问题)
-                        clean_history.append({
-                            "type": "human",
-                            "content": msg.content
-                        })
+                        clean_history.append({"type": "human", "content": msg.content})
                     elif isinstance(msg, AIMessage):
-                        # 只保留最终的、面向用户的回答(包含"[Formatted Output]"的消息)
-                        if msg.content and "[Formatted Output]" in msg.content:
-                            # 去掉 "[Formatted Output]" 标记,只保留真正的回答
-                            clean_content = msg.content.replace("[Formatted Output]\n", "")
-                            clean_history.append({
-                                "type": "ai",
-                                "content": clean_content
-                            })
-                        # 跳过包含工具调用的 AIMessage(中间步骤)
-                    # 跳过所有 ToolMessage(工具执行结果)
+                        if not msg.tool_calls and msg.content:
+                            clean_content = msg.content.replace("[Formatted Output]\n", "").strip()
+                            if clean_content:
+                                clean_history.append({"type": "ai", "content": clean_content})
                 
                 modified_args["history_messages"] = clean_history
-                logger.info(f"   注入了 {len(clean_history)} 条过滤后的历史消息")
+                logger.info(f"注入了 {len(clean_history)} 条过滤后的历史消息")
                 
                 new_tool_calls.append({
                     "name": tool_call["name"],
@@ -466,10 +486,31 @@ class CustomReactAgent:
             else:
                 new_tool_calls.append(tool_call)
         
-        # 用包含修改后参数的新消息替换掉原来的
         last_message.tool_calls = new_tool_calls
         return {"messages": [last_message]}
 
+    def _filter_and_format_history(self, messages: list) -> list:
+        """
+        过滤和格式化历史消息,为generate_sql工具提供干净的上下文。
+        只保留历史中的用户提问和AI的最终回答。
+        """
+        clean_history = []
+        # 处理除最后一个(即当前的工具调用)之外的所有消息
+        messages_to_process = messages[:-1]
+
+        for msg in messages_to_process:
+            if isinstance(msg, HumanMessage):
+                clean_history.append({"type": "human", "content": msg.content})
+            elif isinstance(msg, AIMessage):
+                # 只保留最终的、面向用户的回答(不包含工具调用的纯文本回答)
+                if not msg.tool_calls and msg.content:
+                    # 移除可能存在的格式化标记
+                    clean_content = msg.content.replace("[Formatted Output]\n", "").strip()
+                    if clean_content:
+                        clean_history.append({"type": "ai", "content": clean_content})
+        
+        return clean_history
+
     async def _async_update_state_after_tool_node(self, state: AgentState) -> Dict[str, Any]:
         """在工具执行后,更新 suggested_next_step 并清理参数。"""
         logger.info(f"📝 [Node] update_state_after_tool - Thread: {state['thread_id']}")
@@ -483,13 +524,12 @@ class CustomReactAgent:
         next_step = None
 
         if tool_name == 'generate_sql':
-            if "失败" in tool_output or "无法生成" in tool_output:
+            # 使用 .lower() 将输出转为小写,可以同时捕获 "failed" 和 "Failed" 等情况
+            tool_output_lower = tool_output.lower()
+            if "failed" in tool_output_lower or "无法生成" in tool_output_lower or "失败" in tool_output_lower:
                 next_step = 'answer_with_common_sense'
             else:
                 next_step = 'valid_sql'
-            
-            # 🎯 清理 generate_sql 的 history_messages 参数,设置为空字符串
-            # self._clear_history_messages_parameter(state['messages'])
         
         elif tool_name == 'valid_sql':
             if "失败" in tool_output:
@@ -539,20 +579,26 @@ class CustomReactAgent:
         """异步生成API格式的数据结构"""
         logger.info("📊 异步生成API格式数据...")
         
-        # 提取基础响应内容
         last_message = state['messages'][-1]
         response_content = last_message.content
         
-        # 去掉格式化标记,获取纯净的回答
         if response_content.startswith("[Formatted Output]\n"):
             response_content = response_content.replace("[Formatted Output]\n", "")
         
-        # 初始化API数据结构
         api_data = {
             "response": response_content
         }
+
+        # --- 新增逻辑:为 answer_with_common_sense 场景拼接响应 ---
+        if state.get("suggested_next_step") == 'answer_with_common_sense':
+            failure_reason = self._find_generate_sql_failure_reason(state['messages'])
+            if failure_reason:
+                # 将 "Database query failed. Reason: " 前缀移除,使其更自然
+                cleaned_reason = failure_reason.replace("Database query failed. Reason:", "").strip()
+                # 拼接失败原因和LLM的常识回答
+                api_data["response"] = f"{cleaned_reason}\n\n{response_content}"
+                logger.info("   ✅ 已成功拼接 '失败原因' 和 '常识回答'")
         
-        # 提取SQL和数据记录
         sql_info = await self._async_extract_sql_and_data(state['messages'])
         if sql_info['sql']:
             api_data["sql"] = sql_info['sql']
@@ -565,6 +611,18 @@ class CustomReactAgent:
         logger.info(f"   API数据生成完成,包含字段: {list(api_data.keys())}")
         return api_data
 
+    def _find_generate_sql_failure_reason(self, messages: List[BaseMessage]) -> Optional[str]:
+        """从后向前查找最近一次generate_sql失败的原因"""
+        for msg in reversed(messages):
+            if isinstance(msg, ToolMessage) and msg.name == 'generate_sql':
+                # 找到最近的generate_sql工具消息
+                if "failed" in msg.content.lower() or "失败" in msg.content.lower():
+                    return msg.content
+                else:
+                    # 如果是成功的消息,说明当前轮次没有失败,停止查找
+                    return None
+        return None
+
     async def _async_extract_sql_and_data(self, messages: List[BaseMessage]) -> Dict[str, Any]:
         """异步从消息历史中提取SQL和数据记录"""
         result = {"sql": None, "records": None}
@@ -995,33 +1053,32 @@ Please intelligently choose whether to query the database based on the nature of
     def _generate_validation_error_guidance(self, validation_error: str) -> str:
         """根据验证错误类型生成具体的修复指导"""
         
-        if "字段不存在" in validation_error or "column" in validation_error.lower():
-            return """SQL验证失败:字段不存在错误。
-处理建议:
-1. 检查字段名是否拼写正确
-2. 如果字段确实不存在,请告知用户缺少该字段,并基于常识提供答案
-3. 不要尝试修复不存在的字段,直接给出基于常识的解释"""
-
-        elif "表不存在" in validation_error or "table" in validation_error.lower():
-            return """SQL验证失败:表不存在错误。
-处理建议:
-1. 检查表名是否拼写正确
-2. 如果表确实不存在,请告知用户该数据不在数据库中
-3. 基于问题性质,提供常识性的答案或建议用户确认数据源"""
-
-        elif "语法错误" in validation_error or "syntax error" in validation_error.lower():
+        # 优先处理最常见的语法错误
+        if "语法错误" in validation_error or "syntax error" in validation_error.lower():
             return """SQL验证失败:语法错误。
 处理建议:
 1. 仔细检查SQL语法(括号、引号、关键词等)
 2. 修复语法错误后,调用 valid_sql 工具重新验证
 3. 常见问题:缺少逗号、括号不匹配、关键词拼写错误"""
 
+        # 新增的合并条件,处理所有“不存在”类型的错误
+        elif ("不存在" in validation_error or 
+              "no such table" in validation_error.lower() or
+              "does not exist" in validation_error.lower()):
+            return """SQL验证失败:表或字段不存在。
+处理建议:
+1. 请明确告知用户,因数据库缺少相应的表或字段,无法通过SQL查询获取准确答案。
+2. 请基于你的通用知识和常识,直接回答用户的问题或提供相关解释。
+3. 请不要再尝试生成或修复SQL。"""
+
+        # 其他原有分支可以被新逻辑覆盖,故移除
+        # Fallback 到通用的错误处理
         else:
             return f"""SQL验证失败:{validation_error}
 处理建议:
-1. 如果是语法问题,请修复后重新验证
-2. 如果是字段/表不存在,请向用户说明并提供基于常识的答案
-3. 避免猜测或编造数据库中不存在的信息"""
+1. 如果这是一个可以修复的错误,请尝试修正并再次验证。
+2. 如果错误表明数据缺失,请直接向用户说明情况。
+3. 避免猜测或编造数据库中不存在的信息"""
 
     # === 参数错误诊断和修复函数 ===
     

+ 25 - 9
test/custom_react_agent/api.py

@@ -35,7 +35,7 @@ _agent_instance: Optional[CustomReactAgent] = None
 _redis_client: Optional[redis.Redis] = None
 
 def validate_request_data(data: Dict[str, Any]) -> Dict[str, Any]:
-    """验证请求数据"""
+    """验证请求数据,并支持从thread_id中推断user_id"""
     errors = []
     
     # 验证 question(必填)
@@ -45,20 +45,36 @@ def validate_request_data(data: Dict[str, Any]) -> Dict[str, Any]:
     elif len(question) > 2000:
         errors.append('问题长度不能超过2000字符')
     
-    # 验证 user_id(可选,默认为"guest")
-    user_id = data.get('user_id', 'guest')
+    # 优先获取 thread_id
+    thread_id = data.get('thread_id') or data.get('conversation_id')
+    
+    # 获取 user_id,但暂不设置默认值
+    user_id = data.get('user_id')
+
+    # 如果没有传递 user_id,则尝试从 thread_id 中推断
+    if not user_id:
+        if thread_id and ':' in thread_id:
+            inferred_user_id = thread_id.split(':', 1)[0]
+            if inferred_user_id:
+                user_id = inferred_user_id
+                logger.info(f"👤 未提供user_id,从 thread_id '{thread_id}' 中推断出: '{user_id}'")
+            else:
+                # 如果拆分结果为空,则使用默认值
+                user_id = 'guest'
+        else:
+            # 如果 thread_id 不存在或格式不正确,则使用默认值
+            user_id = 'guest'
+
+    # 验证 user_id 长度
     if user_id and len(user_id) > 50:
         errors.append('用户ID长度不能超过50字符')
     
-    # thread_id 和 conversation_id 处理:如果thread_id没有值,就使用conversation_id的值
-    thread_id = data.get('thread_id') or data.get('conversation_id')
-    
     # 用户ID与会话ID一致性校验
-    if thread_id and user_id != 'guest':
+    if thread_id:
         if ':' not in thread_id:
             errors.append('会话ID格式无效,期望格式为 user_id:timestamp')
         else:
-            thread_user_id = thread_id.split(':', 1)[0]  # 取冒号前的部分作为用户ID
+            thread_user_id = thread_id.split(':', 1)[0]
             if thread_user_id != user_id:
                 errors.append(f'会话归属验证失败:会话ID [{thread_id}] 不属于当前用户 [{user_id}]')
     
@@ -67,7 +83,7 @@ def validate_request_data(data: Dict[str, Any]) -> Dict[str, Any]:
     
     return {
         'question': question.strip(),
-        'user_id': user_id or 'guest',
+        'user_id': user_id,
         'thread_id': thread_id  # 可选,不传则自动生成新会话
     }