让我基于你的工作流图来分析state裁剪的最佳实践:基于你的工作流图,我来详细回答这两个关键问题:
推荐:在Agent节点之前添加专门的裁剪节点
def trim_messages_node(state):
"""在Agent处理前裁剪消息"""
messages = state.get("messages", [])
if len(messages) <= 100:
return state
# 保留最后100条消息
trimmed_messages = messages[-100:]
return {**state, "messages": trimmed_messages}
# 修改工作流:START → trim_messages → Agent
builder.add_edge(START, "trim_messages")
builder.add_edge("trim_messages", "agent")
为什么选择这个位置?
基于你的checkpoint数据,你的工作流是:
用户问题 → 生成SQL → 验证SQL → 执行SQL
问题场景:
# 如果用户说:"修改一下这个SQL的WHERE条件"
# 但只保留最后10条消息,可能会丢失:
# 1. 原始用户问题:"请问系统中哪个服务区档口最多?"
# 2. 生成的SQL语句
# 3. 验证结果
# 结果:Agent不知道要修改哪个SQL!
如果只保留最后1条消息,你的SQL工具链会完全失效:
def smart_trim_for_sql_workflow(state):
"""专门为SQL工作流设计的智能裁剪"""
messages = state.get("messages", [])
if len(messages) <= 100:
return state
# 1. 总是保留最后20条消息(当前工作上下文)
recent_messages = messages[-20:]
# 2. 从历史中保留关键消息
important_messages = []
for msg in messages[:-20]:
# 保留所有用户问题(重要上下文)
if msg.type == "human":
important_messages.append(msg)
# 保留SQL相关的工具结果
elif (msg.type == "tool" and
msg.name in ["generate_sql", "valid_sql", "run_sql"]):
important_messages.append(msg)
# 3. 限制历史消息数量,保持在80条以内
if len(important_messages) > 80:
important_messages = important_messages[-80:]
# 4. 组合:重要历史 + 最近消息
final_messages = important_messages + recent_messages
return {**state, "messages": final_messages}
# 推荐的裁剪节点
def trim_messages_node(state):
messages = state.get("messages", [])
if len(messages) <= 100:
return state
# 保留最后50条 + 关键历史消息
recent = messages[-50:]
# 从更早的历史中提取关键消息
historical = messages[:-50]
key_messages = [
msg for msg in historical
if msg.type == "human" or # 用户问题
(msg.type == "tool" and msg.name in ["generate_sql", "run_sql"]) # 关键工具结果
][-50:] # 最多50条关键历史
return {**state, "messages": key_messages + recent}
总结:对于你的SQL工作流,建议保留足够的上下文(50-100条),并使用智能裁剪确保保留SQL相关的关键信息。
你的理解非常精准!这是一个很重要的观察。你的理解完全正确!这是一个非常重要的架构设计洞察。
# START → trim_messages → Agent → Tools → ...
def trim_at_start(state):
messages = state.get("messages", [])
# 此时的messages包含:
# 1. 从checkpoint恢复的完整历史
# 2. 用户刚输入的新消息
# 但是!当前对话轮次还没开始处理
return {"messages": messages[-100:]}
关键优势:当前轮次还没开始,所以裁剪是"安全"的
想象这个危险场景:
# 错误的流程:Agent → Tools → trim_messages → update_state
# 当前轮次进行中:
messages = [
HumanMessage("请修改刚才的SQL查询"), # 1. 用户请求
AIMessage("调用generate_sql工具", tool_calls=[...]), # 2. Agent响应
ToolMessage("SELECT * FROM ...", name="generate_sql"), # 3. 生成新SQL
# 如果在这里裁剪!可能会裁掉步骤1或2
]
# 问题:破坏了当前轮次的完整性
# Agent后续可能不知道用户想要什么
# START时的state:
{
"messages": [
# ... 历史消息(可能很多)
HumanMessage("请问哪个服务区档口最多?"), # 历史上下文
AIMessage("调用generate_sql..."), # 历史上下文
ToolMessage("SELECT sa.service_area..."), # 重要!之前的SQL
HumanMessage("修改刚才的SQL查询") # 新输入
]
}
# 裁剪后保留了关键上下文,当前轮次可以正常开始
# 在Agent→Tools之间裁剪:
{
"messages": [
AIMessage("调用generate_sql..."), # Agent刚决定调用工具
ToolMessage("SELECT sa.service_area..."), # 工具刚返回结果
# 如果这时裁剪,可能丢失用户的原始请求!
]
}
# Agent后续不知道为什么要处理这个SQL
🎯 保护当前对话轮次的完整性
# 你的工作流应该是:
builder.add_edge(START, "trim_messages") # 在最开始裁剪
builder.add_edge("trim_messages", "agent") # 然后开始正常流程
def safe_trim_at_start(state):
"""在开始时安全裁剪 - 不影响当前轮次"""
messages = state.get("messages", [])
if len(messages) <= 100:
return state
# 安全裁剪:当前轮次还没开始
return {
**state,
"messages": messages[-100:]
}
你的这个洞察非常有价值!在START之后裁剪确实是最安全的策略,因为它保证了当前对话轮次的完整性不会被破坏。
这是一个很好的架构设计原则:状态管理操作(如裁剪)应该在业务逻辑开始之前完成,而不是在业务逻辑进行中进行。
如果messages >= 100条:
def trim_messages_simple(state, target_count=100):
"""简单的消息裁剪:保证从HumanMessage开始"""
messages = state.get("messages", [])
if len(messages) <= target_count:
return state # 不需要裁剪
# 1. 先取最近的target_count条消息
recent_messages = messages[-target_count:]
# 2. 检查第一条是否为HumanMessage
start_index = 0
first_msg = recent_messages[0]
if first_msg.type != "human":
# 3. 往前找HumanMessage
for i in range(len(messages) - target_count - 1, -1, -1):
if messages[i].type == "human":
# 找到了,从这里开始保留
start_index = i
break
else:
# 没找到HumanMessage,就用原来的
start_index = len(messages) - target_count
else:
# 第一条就是HumanMessage,直接用
start_index = len(messages) - target_count
# 4. 保留从start_index开始的所有消息
final_messages = messages[start_index:]
print(f"消息裁剪: {len(messages)} → {len(final_messages)}条")
return {**state, "messages": final_messages}
# 在你的graph中使用
def trim_messages_node(state):
"""消息裁剪节点"""
return trim_messages_simple(state, target_count=100)
# 添加到工作流
builder.add_node("trim_messages", trim_messages_node)
builder.add_edge(START, "trim_messages")
builder.add_edge("trim_messages", "agent")
# 假设有120条消息(索引0-119)
messages = [
# ... 前20条消息 (索引0-19)
HumanMessage("问题A"), # 索引20
AIMessage("回答A"), # 索引21
# ... 中间消息 (索引22-98)
AIMessage("工具调用..."), # 索引99
ToolMessage("结果..."), # 索引100
HumanMessage("问题B"), # 索引101
AIMessage("回答B"), # 索引102
# ... 最近的消息 (索引103-119)
]
# 执行裁剪:
# 1. 取最近100条:messages[-100:] → 从索引20开始到索引119
# 2. 第一条是messages[20] = HumanMessage("问题A") → 是Human,符合要求
# 3. 保留从索引20开始的100条消息
# 另一个场景:如果第一条不是Human
messages2 = [
# ... 前18条消息 (索引0-17)
HumanMessage("问题A"), # 索引18
AIMessage("回答A"), # 索引19
# ... 中间消息 (索引20-98)
AIMessage("工具调用..."), # 索引99
ToolMessage("结果..."), # 索引100 ← 这是最近100条中的第一条,不是Human
HumanMessage("问题B"), # 索引101
# ... 最近的消息 (索引102-119)
]
# 执行裁剪:
# 1. 取最近100条:messages2[-100:] → 从索引20开始
# 2. 第一条是messages2[20] = AIMessage("工具调用...") → 不是Human
# 3. 往前找Human:找到索引18的HumanMessage("问题A")
# 4. 保留从索引18开始的102条消息