|
@@ -140,6 +140,13 @@ class CustomReactAgent:
|
|
|
# 5. 构建 StateGraph
|
|
|
self.agent_executor = self._create_graph()
|
|
|
logger.info(" StateGraph 已构建并编译。")
|
|
|
+
|
|
|
+ # 6. 显示消息裁剪配置状态
|
|
|
+ if config.MESSAGE_TRIM_ENABLED:
|
|
|
+ logger.info(f" 消息裁剪已启用: 保留消息数={config.MESSAGE_TRIM_COUNT}, 搜索限制={config.MESSAGE_TRIM_SEARCH_LIMIT}")
|
|
|
+ else:
|
|
|
+ logger.info(" 消息裁剪已禁用")
|
|
|
+
|
|
|
logger.info("✅ CustomReactAgent 初始化完成。")
|
|
|
|
|
|
async def _reinitialize_checkpointer(self):
|
|
@@ -179,10 +186,100 @@ class CustomReactAgent:
|
|
|
await self.redis_client.aclose()
|
|
|
logger.info("✅ Redis客户端已关闭。")
|
|
|
|
|
|
+ def _trim_messages_node(self, state: AgentState) -> AgentState:
|
|
|
+ """
|
|
|
+ 消息裁剪节点:确保从HumanMessage开始的完整对话轮次
|
|
|
+
|
|
|
+ 裁剪逻辑:
|
|
|
+ 1. 如果消息数 <= MESSAGE_TRIM_COUNT,不裁剪
|
|
|
+ 2. 取最近 MESSAGE_TRIM_COUNT 条消息
|
|
|
+ 3. 检查第一条是否为HumanMessage
|
|
|
+ 4. 如果不是,向前搜索最多 MESSAGE_TRIM_SEARCH_LIMIT 条消息找HumanMessage
|
|
|
+ 5. 如果找到,从HumanMessage开始保留;如果没找到,从原目标位置开始并记录WARNING
|
|
|
+ """
|
|
|
+ messages = state.get("messages", [])
|
|
|
+ thread_id = state.get("thread_id", "unknown")
|
|
|
+ original_count = len(messages)
|
|
|
+
|
|
|
+ # 1. 检查是否需要裁剪
|
|
|
+ if original_count <= config.MESSAGE_TRIM_COUNT:
|
|
|
+ logger.info(f"[{thread_id}] 消息数量 {original_count} <= {config.MESSAGE_TRIM_COUNT},无需裁剪")
|
|
|
+ return state
|
|
|
+
|
|
|
+ # 2. 开始裁剪逻辑
|
|
|
+ target_count = config.MESSAGE_TRIM_COUNT
|
|
|
+ search_limit = config.MESSAGE_TRIM_SEARCH_LIMIT
|
|
|
+
|
|
|
+ # 3. 取最近的target_count条消息
|
|
|
+ recent_start_index = original_count - target_count
|
|
|
+ recent_messages = messages[-target_count:]
|
|
|
+ first_msg = recent_messages[0]
|
|
|
+
|
|
|
+ if config.DEBUG_MODE:
|
|
|
+ logger.info(f"[{thread_id}] 开始消息裁剪分析:")
|
|
|
+ logger.info(f" 原始消息数: {original_count}")
|
|
|
+ logger.info(f" 目标保留数: {target_count}")
|
|
|
+ logger.info(f" 初始截取索引: {recent_start_index}")
|
|
|
+ logger.info(f" 第一条消息类型: {first_msg.type}")
|
|
|
+
|
|
|
+ final_start_index = recent_start_index
|
|
|
+
|
|
|
+ # 4. 检查第一条是否为HumanMessage
|
|
|
+ if first_msg.type != "human":
|
|
|
+ if config.DEBUG_MODE:
|
|
|
+ logger.info(f" 第一条不是HumanMessage,开始向前搜索...")
|
|
|
+
|
|
|
+ # 5. 向前搜索HumanMessage
|
|
|
+ found_human = False
|
|
|
+ search_start = recent_start_index - 1
|
|
|
+ search_end = max(0, recent_start_index - search_limit)
|
|
|
+
|
|
|
+ for i in range(search_start, search_end - 1, -1):
|
|
|
+ if i >= 0 and messages[i].type == "human":
|
|
|
+ final_start_index = i
|
|
|
+ found_human = True
|
|
|
+ if config.DEBUG_MODE:
|
|
|
+ logger.info(f" 在索引 {i} 找到HumanMessage,向前扩展 {recent_start_index - i} 条")
|
|
|
+ break
|
|
|
+
|
|
|
+ # 6. 如果没找到HumanMessage,记录WARNING
|
|
|
+ if not found_human:
|
|
|
+ logger.warning(f"[{thread_id}] 在向前 {search_limit} 条消息中未找到HumanMessage,从原目标位置 {recent_start_index} 开始截断")
|
|
|
+ final_start_index = recent_start_index
|
|
|
+ else:
|
|
|
+ if config.DEBUG_MODE:
|
|
|
+ logger.info(f" 第一条就是HumanMessage,无需向前搜索")
|
|
|
+
|
|
|
+ # 7. 执行裁剪
|
|
|
+ final_messages = messages[final_start_index:]
|
|
|
+ final_count = len(final_messages)
|
|
|
+
|
|
|
+ # 8. 记录裁剪结果
|
|
|
+ logger.info(f"[{thread_id}] 消息裁剪完成: {original_count} → {final_count} 条 (从索引 {final_start_index} 开始)")
|
|
|
+
|
|
|
+ if config.DEBUG_MODE:
|
|
|
+ logger.info(f" 裁剪详情:")
|
|
|
+ logger.info(f" 删除消息数: {original_count - final_count}")
|
|
|
+ logger.info(f" 保留消息范围: [{final_start_index}:{original_count}]")
|
|
|
+
|
|
|
+ # 显示前几条和后几条消息类型
|
|
|
+ if final_count > 0:
|
|
|
+ first_few = min(3, final_count)
|
|
|
+ last_few = min(3, final_count)
|
|
|
+ logger.info(f" 前{first_few}条类型: {[msg.type for msg in final_messages[:first_few]]}")
|
|
|
+ if final_count > 3:
|
|
|
+ logger.info(f" 后{last_few}条类型: {[msg.type for msg in final_messages[-last_few:]]}")
|
|
|
+
|
|
|
+ return {**state, "messages": final_messages}
|
|
|
+
|
|
|
def _create_graph(self):
|
|
|
"""定义并编译最终的、正确的 StateGraph 结构。"""
|
|
|
builder = StateGraph(AgentState)
|
|
|
|
|
|
+ # 添加消息裁剪节点(如果启用)
|
|
|
+ if config.MESSAGE_TRIM_ENABLED:
|
|
|
+ builder.add_node("trim_messages", self._trim_messages_node)
|
|
|
+
|
|
|
# 定义所有需要的节点 - 全部改为异步
|
|
|
builder.add_node("agent", self._async_agent_node)
|
|
|
builder.add_node("prepare_tool_input", self._async_prepare_tool_input_node)
|
|
@@ -191,7 +288,16 @@ class CustomReactAgent:
|
|
|
builder.add_node("format_final_response", self._async_format_final_response_node)
|
|
|
|
|
|
# 建立正确的边连接
|
|
|
- builder.set_entry_point("agent")
|
|
|
+ if config.MESSAGE_TRIM_ENABLED:
|
|
|
+ # 启用裁剪:START → trim_messages → agent
|
|
|
+ builder.set_entry_point("trim_messages")
|
|
|
+ builder.add_edge("trim_messages", "agent")
|
|
|
+ logger.info(" ✅ 消息裁剪节点已启用,工作流: START → trim_messages → agent")
|
|
|
+ else:
|
|
|
+ # 禁用裁剪:START → agent
|
|
|
+ builder.set_entry_point("agent")
|
|
|
+ logger.info(" ⚠️ 消息裁剪节点已禁用,工作流: START → agent")
|
|
|
+
|
|
|
builder.add_conditional_edges(
|
|
|
"agent",
|
|
|
self._async_should_continue,
|