|
@@ -185,17 +185,40 @@ class CustomReactAgent:
|
|
|
|
|
|
# 检查是否需要分析验证错误
|
|
|
next_step = state.get("suggested_next_step")
|
|
|
- if next_step == "analyze_validation_error":
|
|
|
- # 查找最近的 valid_sql 错误信息
|
|
|
- for msg in reversed(state["messages"]):
|
|
|
- if isinstance(msg, ToolMessage) and msg.name == "valid_sql":
|
|
|
- error_guidance = self._generate_validation_error_guidance(msg.content)
|
|
|
- messages_for_llm.append(SystemMessage(content=error_guidance))
|
|
|
- logger.info(" ✅ 已添加SQL验证错误指导")
|
|
|
- break
|
|
|
- elif next_step and next_step != "analyze_validation_error":
|
|
|
- instruction = f"Suggestion: Consider using the '{next_step}' tool for the next step."
|
|
|
- messages_for_llm.append(SystemMessage(content=instruction))
|
|
|
+
|
|
|
+ # 行为指令与工具建议分离
|
|
|
+ real_tools = {'valid_sql', 'run_sql'}
|
|
|
+
|
|
|
+ if next_step:
|
|
|
+ if next_step in real_tools:
|
|
|
+ # 场景1: 建议调用一个真实的工具
|
|
|
+ instruction = f"Suggestion: Based on the previous step, please use the '{next_step}' tool to continue."
|
|
|
+ messages_for_llm.append(SystemMessage(content=instruction))
|
|
|
+ logger.info(f" ✅ 已添加工具建议: {next_step}")
|
|
|
+
|
|
|
+ elif next_step == "analyze_validation_error":
|
|
|
+ # 场景2: 分析SQL验证错误(特殊指令)
|
|
|
+ for msg in reversed(state["messages"]):
|
|
|
+ if isinstance(msg, ToolMessage) and msg.name == "valid_sql":
|
|
|
+ error_guidance = self._generate_validation_error_guidance(msg.content)
|
|
|
+ messages_for_llm.append(SystemMessage(content=error_guidance))
|
|
|
+ logger.info(" ✅ 已添加SQL验证错误指导")
|
|
|
+ break
|
|
|
+
|
|
|
+ elif next_step == 'summarize_final_answer':
|
|
|
+ # 场景3: 总结最终答案(行为指令)
|
|
|
+ instruction = "System Instruction: The SQL query was executed successfully. Please analyze the JSON data in the last message and summarize it in natural, user-friendly language as the final answer. Do not expose the raw JSON data or SQL statements in your response."
|
|
|
+ messages_for_llm.append(SystemMessage(content=instruction))
|
|
|
+ logger.info(" ✅ 已添加 '总结答案' 行为指令")
|
|
|
+
|
|
|
+ elif next_step == 'answer_with_common_sense':
|
|
|
+ # 场景4: 基于常识回答(特殊指令)
|
|
|
+ instruction = (
|
|
|
+ "无法为当前问题生成有效的SQL查询。失败原因已在上下文中提供。"
|
|
|
+ "请你直接利用自身的知识库来回答用户的问题,不要再重复解释失败的原因。"
|
|
|
+ )
|
|
|
+ messages_for_llm.append(SystemMessage(content=instruction))
|
|
|
+ logger.info("✅ 已添加 '常识回答' 行为指令")
|
|
|
|
|
|
# 添加重试机制处理网络连接问题
|
|
|
import asyncio
|
|
@@ -413,50 +436,47 @@ class CustomReactAgent:
|
|
|
|
|
|
async def _async_prepare_tool_input_node(self, state: AgentState) -> Dict[str, Any]:
|
|
|
"""
|
|
|
- 异步信息组装节点:为需要上下文的工具注入历史消息。
|
|
|
+ 准备工具输入。
|
|
|
+ - 强制修正generate_sql的question参数,确保使用用户原始问题。
|
|
|
+ - 为generate_sql注入经过严格过滤的、干净的对话历史。
|
|
|
"""
|
|
|
- logger.info(f"🛠️ [Async Node] prepare_tool_input - Thread: {state['thread_id']}")
|
|
|
-
|
|
|
- # 🎯 打印 state 全部信息
|
|
|
- # self._print_state_info(state, "prepare_tool_input")
|
|
|
-
|
|
|
- last_message = state["messages"][-1]
|
|
|
- if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
|
|
|
+ last_message = state['messages'][-1]
|
|
|
+ if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
|
|
|
return {"messages": [last_message]}
|
|
|
|
|
|
- # 创建一个新的 AIMessage 来替换,避免直接修改 state 中的对象
|
|
|
+ # 强制修正LLM幻觉出的问题
|
|
|
+ for tool_call in last_message.tool_calls:
|
|
|
+ if tool_call['name'] == 'generate_sql':
|
|
|
+ original_user_question = next((msg.content for msg in reversed(state['messages']) if isinstance(msg, HumanMessage)), None)
|
|
|
+ if original_user_question and tool_call['args'].get('question') != original_user_question:
|
|
|
+ logger.warning(
|
|
|
+ f"修正 'generate_sql' 的问题参数。\n"
|
|
|
+ f" - LLM提供: '{tool_call['args'].get('question')}'\n"
|
|
|
+ f" + 修正为: '{original_user_question}'"
|
|
|
+ )
|
|
|
+ tool_call['args']['question'] = original_user_question
|
|
|
+
|
|
|
+ # 恢复原始的、更健壮的历史消息过滤和注入逻辑
|
|
|
new_tool_calls = []
|
|
|
for tool_call in last_message.tool_calls:
|
|
|
if tool_call["name"] == "generate_sql":
|
|
|
- logger.info(" 检测到 generate_sql 调用,注入历史消息。")
|
|
|
- # 复制一份以避免修改原始 tool_call
|
|
|
+ logger.info("检测到 generate_sql 调用,开始注入历史消息。")
|
|
|
modified_args = tool_call["args"].copy()
|
|
|
|
|
|
- # 🎯 改进的消息过滤逻辑:只保留有用的对话上下文,排除当前问题
|
|
|
clean_history = []
|
|
|
- messages_except_current = state["messages"][:-1] # 排除最后一个消息(当前问题)
|
|
|
+ messages_except_current = state["messages"][:-1]
|
|
|
|
|
|
for msg in messages_except_current:
|
|
|
if isinstance(msg, HumanMessage):
|
|
|
- # 保留历史用户消息(但不包括当前问题)
|
|
|
- clean_history.append({
|
|
|
- "type": "human",
|
|
|
- "content": msg.content
|
|
|
- })
|
|
|
+ clean_history.append({"type": "human", "content": msg.content})
|
|
|
elif isinstance(msg, AIMessage):
|
|
|
- # 只保留最终的、面向用户的回答(包含"[Formatted Output]"的消息)
|
|
|
- if msg.content and "[Formatted Output]" in msg.content:
|
|
|
- # 去掉 "[Formatted Output]" 标记,只保留真正的回答
|
|
|
- clean_content = msg.content.replace("[Formatted Output]\n", "")
|
|
|
- clean_history.append({
|
|
|
- "type": "ai",
|
|
|
- "content": clean_content
|
|
|
- })
|
|
|
- # 跳过包含工具调用的 AIMessage(中间步骤)
|
|
|
- # 跳过所有 ToolMessage(工具执行结果)
|
|
|
+ if not msg.tool_calls and msg.content:
|
|
|
+ clean_content = msg.content.replace("[Formatted Output]\n", "").strip()
|
|
|
+ if clean_content:
|
|
|
+ clean_history.append({"type": "ai", "content": clean_content})
|
|
|
|
|
|
modified_args["history_messages"] = clean_history
|
|
|
- logger.info(f" 注入了 {len(clean_history)} 条过滤后的历史消息")
|
|
|
+ logger.info(f"注入了 {len(clean_history)} 条过滤后的历史消息")
|
|
|
|
|
|
new_tool_calls.append({
|
|
|
"name": tool_call["name"],
|
|
@@ -466,10 +486,31 @@ class CustomReactAgent:
|
|
|
else:
|
|
|
new_tool_calls.append(tool_call)
|
|
|
|
|
|
- # 用包含修改后参数的新消息替换掉原来的
|
|
|
last_message.tool_calls = new_tool_calls
|
|
|
return {"messages": [last_message]}
|
|
|
|
|
|
+ def _filter_and_format_history(self, messages: list) -> list:
|
|
|
+ """
|
|
|
+ 过滤和格式化历史消息,为generate_sql工具提供干净的上下文。
|
|
|
+ 只保留历史中的用户提问和AI的最终回答。
|
|
|
+ """
|
|
|
+ clean_history = []
|
|
|
+ # 处理除最后一个(即当前的工具调用)之外的所有消息
|
|
|
+ messages_to_process = messages[:-1]
|
|
|
+
|
|
|
+ for msg in messages_to_process:
|
|
|
+ if isinstance(msg, HumanMessage):
|
|
|
+ clean_history.append({"type": "human", "content": msg.content})
|
|
|
+ elif isinstance(msg, AIMessage):
|
|
|
+ # 只保留最终的、面向用户的回答(不包含工具调用的纯文本回答)
|
|
|
+ if not msg.tool_calls and msg.content:
|
|
|
+ # 移除可能存在的格式化标记
|
|
|
+ clean_content = msg.content.replace("[Formatted Output]\n", "").strip()
|
|
|
+ if clean_content:
|
|
|
+ clean_history.append({"type": "ai", "content": clean_content})
|
|
|
+
|
|
|
+ return clean_history
|
|
|
+
|
|
|
async def _async_update_state_after_tool_node(self, state: AgentState) -> Dict[str, Any]:
|
|
|
"""在工具执行后,更新 suggested_next_step 并清理参数。"""
|
|
|
logger.info(f"📝 [Node] update_state_after_tool - Thread: {state['thread_id']}")
|
|
@@ -483,13 +524,12 @@ class CustomReactAgent:
|
|
|
next_step = None
|
|
|
|
|
|
if tool_name == 'generate_sql':
|
|
|
- if "失败" in tool_output or "无法生成" in tool_output:
|
|
|
+ # 使用 .lower() 将输出转为小写,可以同时捕获 "failed" 和 "Failed" 等情况
|
|
|
+ tool_output_lower = tool_output.lower()
|
|
|
+ if "failed" in tool_output_lower or "无法生成" in tool_output_lower or "失败" in tool_output_lower:
|
|
|
next_step = 'answer_with_common_sense'
|
|
|
else:
|
|
|
next_step = 'valid_sql'
|
|
|
-
|
|
|
- # 🎯 清理 generate_sql 的 history_messages 参数,设置为空字符串
|
|
|
- # self._clear_history_messages_parameter(state['messages'])
|
|
|
|
|
|
elif tool_name == 'valid_sql':
|
|
|
if "失败" in tool_output:
|
|
@@ -539,20 +579,26 @@ class CustomReactAgent:
|
|
|
"""异步生成API格式的数据结构"""
|
|
|
logger.info("📊 异步生成API格式数据...")
|
|
|
|
|
|
- # 提取基础响应内容
|
|
|
last_message = state['messages'][-1]
|
|
|
response_content = last_message.content
|
|
|
|
|
|
- # 去掉格式化标记,获取纯净的回答
|
|
|
if response_content.startswith("[Formatted Output]\n"):
|
|
|
response_content = response_content.replace("[Formatted Output]\n", "")
|
|
|
|
|
|
- # 初始化API数据结构
|
|
|
api_data = {
|
|
|
"response": response_content
|
|
|
}
|
|
|
+
|
|
|
+ # --- 新增逻辑:为 answer_with_common_sense 场景拼接响应 ---
|
|
|
+ if state.get("suggested_next_step") == 'answer_with_common_sense':
|
|
|
+ failure_reason = self._find_generate_sql_failure_reason(state['messages'])
|
|
|
+ if failure_reason:
|
|
|
+ # 将 "Database query failed. Reason: " 前缀移除,使其更自然
|
|
|
+ cleaned_reason = failure_reason.replace("Database query failed. Reason:", "").strip()
|
|
|
+ # 拼接失败原因和LLM的常识回答
|
|
|
+ api_data["response"] = f"{cleaned_reason}\n\n{response_content}"
|
|
|
+ logger.info(" ✅ 已成功拼接 '失败原因' 和 '常识回答'")
|
|
|
|
|
|
- # 提取SQL和数据记录
|
|
|
sql_info = await self._async_extract_sql_and_data(state['messages'])
|
|
|
if sql_info['sql']:
|
|
|
api_data["sql"] = sql_info['sql']
|
|
@@ -565,6 +611,18 @@ class CustomReactAgent:
|
|
|
logger.info(f" API数据生成完成,包含字段: {list(api_data.keys())}")
|
|
|
return api_data
|
|
|
|
|
|
+ def _find_generate_sql_failure_reason(self, messages: List[BaseMessage]) -> Optional[str]:
|
|
|
+ """从后向前查找最近一次generate_sql失败的原因"""
|
|
|
+ for msg in reversed(messages):
|
|
|
+ if isinstance(msg, ToolMessage) and msg.name == 'generate_sql':
|
|
|
+ # 找到最近的generate_sql工具消息
|
|
|
+ if "failed" in msg.content.lower() or "失败" in msg.content.lower():
|
|
|
+ return msg.content
|
|
|
+ else:
|
|
|
+ # 如果是成功的消息,说明当前轮次没有失败,停止查找
|
|
|
+ return None
|
|
|
+ return None
|
|
|
+
|
|
|
async def _async_extract_sql_and_data(self, messages: List[BaseMessage]) -> Dict[str, Any]:
|
|
|
"""异步从消息历史中提取SQL和数据记录"""
|
|
|
result = {"sql": None, "records": None}
|
|
@@ -995,33 +1053,32 @@ Please intelligently choose whether to query the database based on the nature of
|
|
|
def _generate_validation_error_guidance(self, validation_error: str) -> str:
|
|
|
"""根据验证错误类型生成具体的修复指导"""
|
|
|
|
|
|
- if "字段不存在" in validation_error or "column" in validation_error.lower():
|
|
|
- return """SQL验证失败:字段不存在错误。
|
|
|
-处理建议:
|
|
|
-1. 检查字段名是否拼写正确
|
|
|
-2. 如果字段确实不存在,请告知用户缺少该字段,并基于常识提供答案
|
|
|
-3. 不要尝试修复不存在的字段,直接给出基于常识的解释"""
|
|
|
-
|
|
|
- elif "表不存在" in validation_error or "table" in validation_error.lower():
|
|
|
- return """SQL验证失败:表不存在错误。
|
|
|
-处理建议:
|
|
|
-1. 检查表名是否拼写正确
|
|
|
-2. 如果表确实不存在,请告知用户该数据不在数据库中
|
|
|
-3. 基于问题性质,提供常识性的答案或建议用户确认数据源"""
|
|
|
-
|
|
|
- elif "语法错误" in validation_error or "syntax error" in validation_error.lower():
|
|
|
+ # 优先处理最常见的语法错误
|
|
|
+ if "语法错误" in validation_error or "syntax error" in validation_error.lower():
|
|
|
return """SQL验证失败:语法错误。
|
|
|
处理建议:
|
|
|
1. 仔细检查SQL语法(括号、引号、关键词等)
|
|
|
2. 修复语法错误后,调用 valid_sql 工具重新验证
|
|
|
3. 常见问题:缺少逗号、括号不匹配、关键词拼写错误"""
|
|
|
|
|
|
+ # 新增的合并条件,处理所有“不存在”类型的错误
|
|
|
+ elif ("不存在" in validation_error or
|
|
|
+ "no such table" in validation_error.lower() or
|
|
|
+ "does not exist" in validation_error.lower()):
|
|
|
+ return """SQL验证失败:表或字段不存在。
|
|
|
+处理建议:
|
|
|
+1. 请明确告知用户,因数据库缺少相应的表或字段,无法通过SQL查询获取准确答案。
|
|
|
+2. 请基于你的通用知识和常识,直接回答用户的问题或提供相关解释。
|
|
|
+3. 请不要再尝试生成或修复SQL。"""
|
|
|
+
|
|
|
+ # 其他原有分支可以被新逻辑覆盖,故移除
|
|
|
+ # Fallback 到通用的错误处理
|
|
|
else:
|
|
|
return f"""SQL验证失败:{validation_error}
|
|
|
处理建议:
|
|
|
-1. 如果是语法问题,请修复后重新验证
|
|
|
-2. 如果是字段/表不存在,请向用户说明并提供基于常识的答案
|
|
|
-3. 避免猜测或编造数据库中不存在的信息"""
|
|
|
+1. 如果这是一个可以修复的错误,请尝试修正并再次验证。
|
|
|
+2. 如果错误表明数据缺失,请直接向用户说明情况。
|
|
|
+3. 避免猜测或编造数据库中不存在的信息。"""
|
|
|
|
|
|
# === 参数错误诊断和修复函数 ===
|
|
|
|