瀏覽代碼

react_agent 添加state 裁剪功能.

wangxq 3 周之前
父節點
當前提交
e5ef3966cf
共有 2 個文件被更改,包括 113 次插入和 2 次刪除
  1. 107 1
      react_agent/agent.py
  2. 6 1
      react_agent/config.py

+ 107 - 1
react_agent/agent.py

@@ -140,6 +140,13 @@ class CustomReactAgent:
         # 5. 构建 StateGraph
         self.agent_executor = self._create_graph()
         logger.info("   StateGraph 已构建并编译。")
+        
+        # 6. 显示消息裁剪配置状态
+        if config.MESSAGE_TRIM_ENABLED:
+            logger.info(f"   消息裁剪已启用: 保留消息数={config.MESSAGE_TRIM_COUNT}, 搜索限制={config.MESSAGE_TRIM_SEARCH_LIMIT}")
+        else:
+            logger.info("   消息裁剪已禁用")
+        
         logger.info("✅ CustomReactAgent 初始化完成。")
 
     async def _reinitialize_checkpointer(self):
@@ -179,10 +186,100 @@ class CustomReactAgent:
             await self.redis_client.aclose()
             logger.info("✅ Redis客户端已关闭。")
 
+    def _trim_messages_node(self, state: AgentState) -> AgentState:
+        """
+        Trim state["messages"] to the most recent turns, preferring to start on a HumanMessage.
+        
+        Steps (returns `state` itself in case 1, otherwise a copy with a shorter "messages" list):
+        1. If there are <= MESSAGE_TRIM_COUNT messages, do not trim.
+        2. Take the most recent MESSAGE_TRIM_COUNT messages.
+        3. Check whether the first of them is a HumanMessage (type == "human").
+        4. If not, search backwards over at most MESSAGE_TRIM_SEARCH_LIMIT earlier messages for one.
+        5. If found, keep from that HumanMessage onward; otherwise keep the original window and log a WARNING.
+        """
+        messages = state.get("messages", [])
+        thread_id = state.get("thread_id", "unknown")
+        original_count = len(messages)
+        
+        # 1. Check whether trimming is needed at all
+        if original_count <= config.MESSAGE_TRIM_COUNT:
+            logger.info(f"[{thread_id}] 消息数量 {original_count} <= {config.MESSAGE_TRIM_COUNT},无需裁剪")
+            return state
+        
+        # 2. Trimming is required: read the window size and the backward-search budget
+        target_count = config.MESSAGE_TRIM_COUNT
+        search_limit = config.MESSAGE_TRIM_SEARCH_LIMIT
+        
+        # 3. Take the most recent target_count messages (the initial window)
+        recent_start_index = original_count - target_count
+        recent_messages = messages[-target_count:]
+        first_msg = recent_messages[0]
+        
+        if config.DEBUG_MODE:
+            logger.info(f"[{thread_id}] 开始消息裁剪分析:")
+            logger.info(f"   原始消息数: {original_count}")
+            logger.info(f"   目标保留数: {target_count}")
+            logger.info(f"   初始截取索引: {recent_start_index}")
+            logger.info(f"   第一条消息类型: {first_msg.type}")
+        
+        final_start_index = recent_start_index
+        
+        # 4. Check whether the window already starts on a HumanMessage
+        if first_msg.type != "human":
+            if config.DEBUG_MODE:
+                logger.info(f"   第一条不是HumanMessage,开始向前搜索...")
+            
+            # 5. Search backwards (toward older messages) for a HumanMessage; search_end is clamped at index 0
+            found_human = False
+            search_start = recent_start_index - 1
+            search_end = max(0, recent_start_index - search_limit)
+            
+            for i in range(search_start, search_end - 1, -1):
+                if i >= 0 and messages[i].type == "human":
+                    final_start_index = i
+                    found_human = True
+                    if config.DEBUG_MODE:
+                        logger.info(f"   在索引 {i} 找到HumanMessage,向前扩展 {recent_start_index - i} 条")
+                    break
+            
+            # 6. No HumanMessage within the search budget: log a WARNING and keep the original window start
+            if not found_human:
+                logger.warning(f"[{thread_id}] 在向前 {search_limit} 条消息中未找到HumanMessage,从原目标位置 {recent_start_index} 开始截断")
+                final_start_index = recent_start_index
+        else:
+            if config.DEBUG_MODE:
+                logger.info(f"   第一条就是HumanMessage,无需向前搜索")
+        
+        # 7. Apply the trim: keep everything from final_start_index to the end
+        final_messages = messages[final_start_index:]
+        final_count = len(final_messages)
+        
+        # 8. Log the trimming result
+        logger.info(f"[{thread_id}] 消息裁剪完成: {original_count} → {final_count} 条 (从索引 {final_start_index} 开始)")
+        
+        if config.DEBUG_MODE:
+            logger.info(f"   裁剪详情:")
+            logger.info(f"     删除消息数: {original_count - final_count}")
+            logger.info(f"     保留消息范围: [{final_start_index}:{original_count}]")
+            
+            # Show the types of the first few and last few kept messages
+            if final_count > 0:
+                first_few = min(3, final_count)
+                last_few = min(3, final_count)
+                logger.info(f"     前{first_few}条类型: {[msg.type for msg in final_messages[:first_few]]}")
+                if final_count > 3:
+                    logger.info(f"     后{last_few}条类型: {[msg.type for msg in final_messages[-last_few:]]}")
+        
+        return {**state, "messages": final_messages}  # NOTE(review): if AgentState.messages uses an add_messages-style reducer, a shorter list may not remove anything — confirm
+
     def _create_graph(self):
         """定义并编译最终的、正确的 StateGraph 结构。"""
         builder = StateGraph(AgentState)
 
+        # 添加消息裁剪节点(如果启用)
+        if config.MESSAGE_TRIM_ENABLED:
+            builder.add_node("trim_messages", self._trim_messages_node)
+        
         # 定义所有需要的节点 - 全部改为异步
         builder.add_node("agent", self._async_agent_node)
         builder.add_node("prepare_tool_input", self._async_prepare_tool_input_node)
@@ -191,7 +288,16 @@ class CustomReactAgent:
         builder.add_node("format_final_response", self._async_format_final_response_node)
 
         # 建立正确的边连接
-        builder.set_entry_point("agent")
+        if config.MESSAGE_TRIM_ENABLED:
+            # 启用裁剪:START → trim_messages → agent
+            builder.set_entry_point("trim_messages")
+            builder.add_edge("trim_messages", "agent")
+            logger.info("   ✅ 消息裁剪节点已启用,工作流: START → trim_messages → agent")
+        else:
+            # 禁用裁剪:START → agent
+            builder.set_entry_point("agent")
+            logger.info("   ⚠️ 消息裁剪节点已禁用,工作流: START → agent")
+        
         builder.add_conditional_edges(
             "agent",
             self._async_should_continue,

+ 6 - 1
react_agent/config.py

@@ -60,4 +60,9 @@ HTTP_POOL_TIMEOUT = 5.0            # 连接池超时(秒)
 
 # --- Debug configuration ---
 DEBUG_MODE = True                  # Debug mode: True = full logs, False = abbreviated logs
-MAX_LOG_LENGTH = 1000              # Maximum log length when not in debug mode 
+MAX_LOG_LENGTH = 1000              # Maximum log length when not in debug mode
+
+# --- State management configuration ---
+MESSAGE_TRIM_ENABLED = True        # Whether message trimming is enabled
+MESSAGE_TRIM_COUNT = 100          # Trimming triggers above this many messages; at least this many are kept afterwards
+MESSAGE_TRIM_SEARCH_LIMIT = 20    # Maximum number of earlier messages searched backwards for a HumanMessage