|
@@ -4,6 +4,7 @@
|
|
|
import logging
|
|
|
import json
|
|
|
import pandas as pd
|
|
|
+import httpx
|
|
|
from typing import List, Optional, Dict, Any, Tuple
|
|
|
from contextlib import AsyncExitStack
|
|
|
|
|
@@ -72,13 +73,27 @@ class CustomReactAgent:
|
|
|
model=config.QWEN_MODEL,
|
|
|
temperature=0.1,
|
|
|
timeout=config.NETWORK_TIMEOUT, # 添加超时配置
|
|
|
- max_retries=config.MAX_RETRIES, # 添加重试配置
|
|
|
+ max_retries=0, # 禁用OpenAI客户端重试,改用Agent层统一重试
|
|
|
extra_body={
|
|
|
"enable_thinking": False,
|
|
|
"misc": {
|
|
|
"ensure_ascii": False
|
|
|
}
|
|
|
- }
|
|
|
+ },
|
|
|
+ # 新增:优化HTTP连接配置
|
|
|
+ http_client=httpx.Client(
|
|
|
+ limits=httpx.Limits(
|
|
|
+ max_connections=config.HTTP_MAX_CONNECTIONS,
|
|
|
+ max_keepalive_connections=config.HTTP_MAX_KEEPALIVE_CONNECTIONS,
|
|
|
+ keepalive_expiry=config.HTTP_KEEPALIVE_EXPIRY, # 30秒keep-alive过期
|
|
|
+ ),
|
|
|
+ timeout=httpx.Timeout(
|
|
|
+ connect=config.HTTP_CONNECT_TIMEOUT, # 连接超时
|
|
|
+ read=config.NETWORK_TIMEOUT, # 读取超时
|
|
|
+ write=config.HTTP_CONNECT_TIMEOUT, # 写入超时
|
|
|
+ pool=config.HTTP_POOL_TIMEOUT # 连接池超时
|
|
|
+ )
|
|
|
+ )
|
|
|
)
|
|
|
logger.info(f" LLM 已初始化,模型: {config.QWEN_MODEL}")
|
|
|
|
|
@@ -168,8 +183,18 @@ class CustomReactAgent:
|
|
|
messages_for_llm.insert(0, SystemMessage(content=db_scope_prompt))
|
|
|
logger.info(" ✅ 已添加数据库范围判断提示词")
|
|
|
|
|
|
- if state.get("suggested_next_step"):
|
|
|
- instruction = f"Suggestion: Consider using the '{state['suggested_next_step']}' tool for the next step."
|
|
|
+ # 检查是否需要分析验证错误
|
|
|
+ next_step = state.get("suggested_next_step")
|
|
|
+ if next_step == "analyze_validation_error":
|
|
|
+ # 查找最近的 valid_sql 错误信息
|
|
|
+ for msg in reversed(state["messages"]):
|
|
|
+ if isinstance(msg, ToolMessage) and msg.name == "valid_sql":
|
|
|
+ error_guidance = self._generate_validation_error_guidance(msg.content)
|
|
|
+ messages_for_llm.append(SystemMessage(content=error_guidance))
|
|
|
+ logger.info(" ✅ 已添加SQL验证错误指导")
|
|
|
+ break
|
|
|
+ elif next_step and next_step != "analyze_validation_error":
|
|
|
+ instruction = f"Suggestion: Consider using the '{next_step}' tool for the next step."
|
|
|
messages_for_llm.append(SystemMessage(content=instruction))
|
|
|
|
|
|
# 添加重试机制处理网络连接问题
|
|
@@ -179,39 +204,156 @@ class CustomReactAgent:
|
|
|
try:
|
|
|
# 使用异步调用
|
|
|
response = await self.llm_with_tools.ainvoke(messages_for_llm)
|
|
|
- logger.info(f" ✅ 异步LLM调用成功")
|
|
|
- return {"messages": [response]}
|
|
|
+
|
|
|
+ # 新增:详细的响应检查和日志
|
|
|
+ logger.info(f" LLM原始响应内容: '{response.content}'")
|
|
|
+ logger.info(f" 响应内容长度: {len(response.content) if response.content else 0}")
|
|
|
+ logger.info(f" 响应内容类型: {type(response.content)}")
|
|
|
+ logger.info(f" LLM是否有工具调用: {hasattr(response, 'tool_calls') and response.tool_calls}")
|
|
|
+
|
|
|
+ if hasattr(response, 'tool_calls') and response.tool_calls:
|
|
|
+ logger.info(f" 工具调用数量: {len(response.tool_calls)}")
|
|
|
+ for i, tool_call in enumerate(response.tool_calls):
|
|
|
+ logger.info(f" 工具调用[{i}]: {tool_call.get('name', 'Unknown')}")
|
|
|
+
|
|
|
+ # 🎯 改进的响应检查和重试逻辑
|
|
|
+ # 检查空响应情况 - 将空响应也视为需要重试的情况
|
|
|
+ if not response.content and not (hasattr(response, 'tool_calls') and response.tool_calls):
|
|
|
+ logger.warning(" ⚠️ LLM返回空响应且无工具调用")
|
|
|
+ if attempt < max_retries - 1:
|
|
|
+ # 空响应也进行重试
|
|
|
+ wait_time = config.RETRY_BASE_DELAY * (2 ** attempt)
|
|
|
+ logger.info(f" 🔄 空响应重试,{wait_time}秒后重试...")
|
|
|
+ await asyncio.sleep(wait_time)
|
|
|
+ continue
|
|
|
+ else:
|
|
|
+ # 所有重试都失败,返回降级回答
|
|
|
+ logger.error(f" ❌ 多次尝试仍返回空响应,返回降级回答")
|
|
|
+ fallback_content = "抱歉,我现在无法正确处理您的问题。请稍后重试或重新表述您的问题。"
|
|
|
+ fallback_response = AIMessage(content=fallback_content)
|
|
|
+ return {"messages": [fallback_response]}
|
|
|
+
|
|
|
+ elif response.content and response.content.strip() == "":
|
|
|
+ logger.warning(" ⚠️ LLM返回只包含空白字符的内容")
|
|
|
+ if attempt < max_retries - 1:
|
|
|
+ # 空白字符也进行重试
|
|
|
+ wait_time = config.RETRY_BASE_DELAY * (2 ** attempt)
|
|
|
+ logger.info(f" 🔄 空白字符重试,{wait_time}秒后重试...")
|
|
|
+ await asyncio.sleep(wait_time)
|
|
|
+ continue
|
|
|
+ else:
|
|
|
+ # 所有重试都失败,返回降级回答
|
|
|
+ logger.error(f" ❌ 多次尝试仍返回空白字符,返回降级回答")
|
|
|
+ fallback_content = "抱歉,我现在无法正确处理您的问题。请稍后重试或重新表述您的问题。"
|
|
|
+ fallback_response = AIMessage(content=fallback_content)
|
|
|
+ return {"messages": [fallback_response]}
|
|
|
+
|
|
|
+ elif not response.content and hasattr(response, 'tool_calls') and response.tool_calls:
|
|
|
+ logger.info(" ✅ LLM只返回工具调用,无文本内容(正常情况)")
|
|
|
+
|
|
|
+ # 🎯 最终检查:确保响应是有效的
|
|
|
+ if ((response.content and response.content.strip()) or
|
|
|
+ (hasattr(response, 'tool_calls') and response.tool_calls)):
|
|
|
+ logger.info(f" ✅ 异步LLM调用成功,返回有效响应")
|
|
|
+ return {"messages": [response]}
|
|
|
+ else:
|
|
|
+ # 这种情况理论上不应该发生,但作为最后的保障
|
|
|
+ logger.error(f" ❌ 意外的响应格式,进行重试")
|
|
|
+ if attempt < max_retries - 1:
|
|
|
+ wait_time = config.RETRY_BASE_DELAY * (2 ** attempt)
|
|
|
+ logger.info(f" 🔄 意外响应格式重试,{wait_time}秒后重试...")
|
|
|
+ await asyncio.sleep(wait_time)
|
|
|
+ continue
|
|
|
+ else:
|
|
|
+ fallback_content = "抱歉,我现在无法正确处理您的问题。请稍后重试或重新表述您的问题。"
|
|
|
+ fallback_response = AIMessage(content=fallback_content)
|
|
|
+ return {"messages": [fallback_response]}
|
|
|
|
|
|
except Exception as e:
|
|
|
error_msg = str(e)
|
|
|
- logger.warning(f" ⚠️ LLM调用失败 (尝试 {attempt + 1}/{max_retries}): {error_msg}")
|
|
|
+ error_type = type(e).__name__
|
|
|
+ logger.warning(f" ⚠️ LLM调用失败 (尝试 {attempt + 1}/{max_retries}): {error_type}: {error_msg}")
|
|
|
+
|
|
|
+ # 🎯 改进的错误分类逻辑:检查异常类型和错误消息
|
|
|
+ is_network_error = False
|
|
|
+ is_parameter_error = False
|
|
|
+
|
|
|
+ # 1. 检查异常类型
|
|
|
+ network_exception_types = [
|
|
|
+ 'APIConnectionError', 'ConnectTimeout', 'ReadTimeout',
|
|
|
+ 'TimeoutError', 'APITimeoutError', 'ConnectError',
|
|
|
+ 'HTTPError', 'RequestException', 'ConnectionError'
|
|
|
+ ]
|
|
|
+ if error_type in network_exception_types:
|
|
|
+ is_network_error = True
|
|
|
+ logger.info(f" 📊 根据异常类型判断为网络错误: {error_type}")
|
|
|
|
|
|
- # 检查是否是网络连接错误
|
|
|
- if any(keyword in error_msg for keyword in [
|
|
|
- "Connection error", "APIConnectionError", "ConnectError",
|
|
|
- "timeout", "远程主机强迫关闭", "网络连接"
|
|
|
- ]):
|
|
|
+ # 2. 检查BadRequestError中的参数错误
|
|
|
+ if error_type == 'BadRequestError':
|
|
|
+ # 检查是否是消息格式错误
|
|
|
+ if any(keyword in error_msg.lower() for keyword in [
|
|
|
+ 'must be followed by tool messages',
|
|
|
+ 'invalid_parameter_error',
|
|
|
+ 'assistant message with "tool_calls"',
|
|
|
+ 'tool_call_id',
|
|
|
+ 'message format'
|
|
|
+ ]):
|
|
|
+ is_parameter_error = True
|
|
|
+ logger.info(f" 📊 根据错误消息判断为参数格式错误: {error_msg[:100]}...")
|
|
|
+
|
|
|
+ # 3. 检查错误消息内容(不区分大小写)
|
|
|
+ error_msg_lower = error_msg.lower()
|
|
|
+ network_keywords = [
|
|
|
+ 'connection error', 'connect error', 'timeout', 'timed out',
|
|
|
+ 'network', 'connection refused', 'connection reset',
|
|
|
+ 'remote host', '远程主机', '网络连接', '连接超时',
|
|
|
+ 'request timed out', 'read timeout', 'connect timeout'
|
|
|
+ ]
|
|
|
+
|
|
|
+ for keyword in network_keywords:
|
|
|
+ if keyword in error_msg_lower:
|
|
|
+ is_network_error = True
|
|
|
+ logger.info(f" 📊 根据错误消息判断为网络错误: '{keyword}' in '{error_msg}'")
|
|
|
+ break
|
|
|
+
|
|
|
+ # 处理可重试的错误
|
|
|
+ if is_network_error or is_parameter_error:
|
|
|
if attempt < max_retries - 1:
|
|
|
- wait_time = config.RETRY_BASE_DELAY ** attempt # 指数退避:2, 4, 8秒
|
|
|
- logger.info(f" 🔄 网络错误,{wait_time}秒后重试...")
|
|
|
+ # 渐进式重试间隔:3, 6, 12秒
|
|
|
+ wait_time = config.RETRY_BASE_DELAY * (2 ** attempt)
|
|
|
+ error_type_desc = "网络错误" if is_network_error else "参数格式错误"
|
|
|
+ logger.info(f" 🔄 {error_type_desc},{wait_time}秒后重试...")
|
|
|
+
|
|
|
+ # 🎯 对于参数错误,修复消息历史后重试
|
|
|
+ if is_parameter_error:
|
|
|
+ try:
|
|
|
+ messages_for_llm = await self._handle_parameter_error_with_retry(
|
|
|
+ messages_for_llm, error_msg, attempt
|
|
|
+ )
|
|
|
+ logger.info(f" 🔧 消息历史修复完成,继续重试...")
|
|
|
+ except Exception as fix_error:
|
|
|
+ logger.error(f" ❌ 消息历史修复失败: {fix_error}")
|
|
|
+ # 修复失败,使用原始消息继续重试
|
|
|
+
|
|
|
await asyncio.sleep(wait_time)
|
|
|
continue
|
|
|
else:
|
|
|
# 所有重试都失败了,返回一个降级的回答
|
|
|
- logger.error(f" ❌ 网络连接持续失败,返回降级回答")
|
|
|
+ error_type_desc = "网络连接" if is_network_error else "请求格式"
|
|
|
+ logger.error(f" ❌ {error_type_desc}持续失败,返回降级回答")
|
|
|
|
|
|
# 检查是否有SQL执行结果可以利用
|
|
|
sql_data = await self._async_extract_latest_sql_data(state["messages"])
|
|
|
if sql_data:
|
|
|
- fallback_content = "抱歉,由于网络连接问题,无法生成完整的文字总结。不过查询已成功执行,结果如下:\n\n" + sql_data
|
|
|
+ fallback_content = f"抱歉,由于{error_type_desc}问题,无法生成完整的文字总结。不过查询已成功执行,结果如下:\n\n" + sql_data
|
|
|
else:
|
|
|
- fallback_content = "抱歉,由于网络连接问题,无法完成此次请求。请稍后重试或检查网络连接。"
|
|
|
+ fallback_content = f"抱歉,由于{error_type_desc}问题,无法完成此次请求。请稍后重试或检查网络连接。"
|
|
|
|
|
|
fallback_response = AIMessage(content=fallback_content)
|
|
|
return {"messages": [fallback_response]}
|
|
|
else:
|
|
|
# 非网络错误,直接抛出
|
|
|
- logger.error(f" ❌ LLM调用出现非网络错误: {error_msg}")
|
|
|
+ logger.error(f" ❌ LLM调用出现非可重试错误: {error_type}: {error_msg}")
|
|
|
raise e
|
|
|
|
|
|
def _print_state_info(self, state: AgentState, node_name: str) -> None:
|
|
@@ -239,10 +381,16 @@ class CustomReactAgent:
|
|
|
|
|
|
if messages:
|
|
|
logger.info(" 最近的消息:")
|
|
|
- for i, msg in enumerate(messages[-10:], start=max(0, len(messages)-10)): # 显示最后3条消息
|
|
|
+ for i, msg in enumerate(messages[-10:], start=max(0, len(messages)-10)): # 显示最后10条消息
|
|
|
msg_type = type(msg).__name__
|
|
|
- content_preview = str(msg.content)[:100] + "..." if len(str(msg.content)) > 100 else str(msg.content)
|
|
|
- logger.info(f" [{i}] {msg_type}: {content_preview}")
|
|
|
+ content = str(msg.content)
|
|
|
+
|
|
|
+ # 对于长内容,使用多行显示
|
|
|
+ if len(content) > 200:
|
|
|
+ logger.info(f" [{i}] {msg_type}:")
|
|
|
+ logger.info(f" {content}")
|
|
|
+ else:
|
|
|
+ logger.info(f" [{i}] {msg_type}: {content}")
|
|
|
|
|
|
# 如果是 AIMessage 且有工具调用,显示工具调用信息
|
|
|
if hasattr(msg, 'tool_calls') and msg.tool_calls:
|
|
@@ -250,7 +398,16 @@ class CustomReactAgent:
|
|
|
tool_name = tool_call.get('name', 'Unknown')
|
|
|
tool_args = tool_call.get('args', {})
|
|
|
logger.info(f" 工具调用: {tool_name}")
|
|
|
- logger.info(f" 参数: {str(tool_args)[:200]}...")
|
|
|
+
|
|
|
+ # 对于复杂参数,使用JSON格式化
|
|
|
+ import json
|
|
|
+ try:
|
|
|
+ formatted_args = json.dumps(tool_args, ensure_ascii=False, indent=2)
|
|
|
+ logger.info(f" 参数:")
|
|
|
+ for line in formatted_args.split('\n'):
|
|
|
+ logger.info(f" {line}")
|
|
|
+ except Exception:
|
|
|
+ logger.info(f" 参数: {str(tool_args)}")
|
|
|
|
|
|
logger.info(" ~" * 10 + " State Print End" + " ~" * 10)
|
|
|
|
|
@@ -821,4 +978,286 @@ Please intelligently choose whether to query the database based on the nature of
|
|
|
|
|
|
except Exception as e:
|
|
|
logger.warning(f"⚠️ Unable to read database scope description file: {e}")
|
|
|
- return ""
|
|
|
+ return ""
|
|
|
+
|
|
|
+ def _generate_validation_error_guidance(self, validation_error: str) -> str:
|
|
|
+ """根据验证错误类型生成具体的修复指导"""
|
|
|
+
|
|
|
+ if "字段不存在" in validation_error or "column" in validation_error.lower():
|
|
|
+ return """SQL验证失败:字段不存在错误。
|
|
|
+处理建议:
|
|
|
+1. 检查字段名是否拼写正确
|
|
|
+2. 如果字段确实不存在,请告知用户缺少该字段,并基于常识提供答案
|
|
|
+3. 不要尝试修复不存在的字段,直接给出基于常识的解释"""
|
|
|
+
|
|
|
+ elif "表不存在" in validation_error or "table" in validation_error.lower():
|
|
|
+ return """SQL验证失败:表不存在错误。
|
|
|
+处理建议:
|
|
|
+1. 检查表名是否拼写正确
|
|
|
+2. 如果表确实不存在,请告知用户该数据不在数据库中
|
|
|
+3. 基于问题性质,提供常识性的答案或建议用户确认数据源"""
|
|
|
+
|
|
|
+ elif "语法错误" in validation_error or "syntax error" in validation_error.lower():
|
|
|
+ return """SQL验证失败:语法错误。
|
|
|
+处理建议:
|
|
|
+1. 仔细检查SQL语法(括号、引号、关键词等)
|
|
|
+2. 修复语法错误后,调用 valid_sql 工具重新验证
|
|
|
+3. 常见问题:缺少逗号、括号不匹配、关键词拼写错误"""
|
|
|
+
|
|
|
+ else:
|
|
|
+ return f"""SQL验证失败:{validation_error}
|
|
|
+处理建议:
|
|
|
+1. 如果是语法问题,请修复后重新验证
|
|
|
+2. 如果是字段/表不存在,请向用户说明并提供基于常识的答案
|
|
|
+3. 避免猜测或编造数据库中不存在的信息"""
|
|
|
+
|
|
|
+ # === 参数错误诊断和修复函数 ===
|
|
|
+
|
|
|
def _diagnose_parameter_error(self, messages: List[BaseMessage], error_msg: str) -> Dict[str, Any]:
    """Scan the message history for tool-call / tool-reply pairing problems.

    Two kinds of problems are recorded while every message is logged:

    * an AIMessage tool call with no ToolMessage carrying the same id before
      the next HumanMessage/AIMessage ("incomplete tool call");
    * a ToolMessage for which no earlier AIMessage (searching back to the
      nearest HumanMessage) declares a tool call with its id ("orphaned").

    Args:
        messages: Message history that was sent to the LLM.
        error_msg: Error text of the failed request; only logged.

    Returns:
        A dict with keys ``error_type``, ``incomplete_tool_calls``,
        ``orphaned_tool_messages``, ``total_messages`` and
        ``recommended_action`` (``"fix_incomplete_tool_calls"``,
        ``"remove_orphaned_tool_messages"`` or ``"unknown"``).
    """
    logger.error("🔍 开始诊断参数错误...")
    logger.error(f" 错误消息: {error_msg}")

    unanswered_calls = []
    orphan_replies = []
    diagnosis = {
        "error_type": "parameter_error",
        "incomplete_tool_calls": unanswered_calls,
        "orphaned_tool_messages": orphan_replies,
        "total_messages": len(messages),
        "recommended_action": None
    }

    # Walk the history once, logging each message and checking its pairing.
    logger.error("📋 消息历史分析:")
    for position, current in enumerate(messages):
        kind = type(current).__name__

        if isinstance(current, AIMessage):
            calls = getattr(current, 'tool_calls', None)
            preview = f"'{current.content[:50]}...'" if current.content else "空内容"
            logger.error(f" [{position}] {kind}: {preview}")

            if not calls:
                continue

            logger.error(f" 工具调用: {len(current.tool_calls)} 个")
            for call_no, call in enumerate(current.tool_calls):
                tool_name = call.get('name', 'Unknown')
                tool_id = call.get('id', 'Unknown')
                logger.error(f" [{call_no}] {tool_name} (ID: {tool_id})")

                # Look forward for the matching reply; a Human/AI message
                # marks the start of a new turn and ends the search.
                answered = False
                for later in messages[position + 1:]:
                    if isinstance(later, ToolMessage) and later.tool_call_id == tool_id:
                        answered = True
                        break
                    if isinstance(later, (HumanMessage, AIMessage)):
                        break

                if answered:
                    logger.error(f" ✅ 找到对应的ToolMessage")
                else:
                    unanswered_calls.append({
                        "message_index": position,
                        "tool_name": tool_name,
                        "tool_id": tool_id,
                        "ai_message_content": current.content
                    })
                    logger.error(f" ❌ 未找到对应的ToolMessage!")

        elif isinstance(current, ToolMessage):
            logger.error(f" [{position}] {kind}: {current.name} (ID: {current.tool_call_id})")

            # Look backward for the AIMessage that issued this call; stop at
            # the nearest HumanMessage.
            has_owner = False
            for earlier in reversed(messages[:position]):
                if isinstance(earlier, AIMessage) and getattr(earlier, 'tool_calls', None):
                    if any(call.get('id') == current.tool_call_id for call in earlier.tool_calls):
                        has_owner = True
                        break
                elif isinstance(earlier, HumanMessage):
                    break

            if not has_owner:
                orphan_replies.append({
                    "message_index": position,
                    "tool_name": current.name,
                    "tool_call_id": current.tool_call_id
                })
                logger.error(f" ❌ 未找到对应的AIMessage!")

        elif isinstance(current, HumanMessage):
            logger.error(f" [{position}] {kind}: '{current.content[:50]}...'")

    # Pick the recommended repair; incomplete tool calls take precedence.
    if unanswered_calls:
        logger.error(f"🔧 发现 {len(diagnosis['incomplete_tool_calls'])} 个不完整的工具调用")
        diagnosis["recommended_action"] = "fix_incomplete_tool_calls"
    elif orphan_replies:
        logger.error(f"🔧 发现 {len(diagnosis['orphaned_tool_messages'])} 个孤立的工具消息")
        diagnosis["recommended_action"] = "remove_orphaned_tool_messages"
    else:
        logger.error("🔧 未发现明显的消息格式问题")
        diagnosis["recommended_action"] = "unknown"

    return diagnosis
|
|
|
+
|
|
|
def _fix_by_adding_missing_tool_messages(self, messages: List[BaseMessage], diagnosis: Dict) -> List[BaseMessage]:
    """Repair the history by inserting placeholder replies for unanswered tool calls.

    For every entry in ``diagnosis["incomplete_tool_calls"]`` a ToolMessage
    with an error text is inserted directly after the AIMessage that issued
    the call. The input list is not modified.

    Args:
        messages: Original message history.
        diagnosis: Result of ``_diagnose_parameter_error``; each incomplete
            entry supplies ``message_index``, ``tool_id`` and ``tool_name``.

    Returns:
        A new list containing the original messages plus the placeholders.
    """
    logger.info("🔧 策略1: 补充缺失的ToolMessage")

    fixed_messages = list(messages)

    # ``message_index`` refers to positions in the ORIGINAL list. Each insert
    # shifts everything behind it by one, so entries are handled in ascending
    # index order and the number of earlier inserts is added to the position.
    # Without this, the placeholder for a later AIMessage would be inserted
    # in front of that AIMessage instead of behind it.
    pending = sorted(
        diagnosis["incomplete_tool_calls"],
        key=lambda entry: entry["message_index"],
    )

    for already_inserted, incomplete in enumerate(pending):
        # Synthetic error reply standing in for the missing tool result.
        error_tool_message = ToolMessage(
            content="工具调用已超时或失败,请重新尝试。",
            tool_call_id=incomplete["tool_id"],
            name=incomplete["tool_name"]
        )

        insert_index = incomplete["message_index"] + 1 + already_inserted
        fixed_messages.insert(insert_index, error_tool_message)

        logger.info(f" ✅ 为工具调用 {incomplete['tool_name']}({incomplete['tool_id']}) 添加错误响应")

    return fixed_messages
|
|
|
+
|
|
|
def _fix_by_removing_incomplete_tool_calls(self, messages: List[BaseMessage], diagnosis: Dict) -> List[BaseMessage]:
    """Repair the history by stripping tool calls from affected AIMessages.

    Every AIMessage listed in ``diagnosis["incomplete_tool_calls"]`` is
    replaced by a plain AIMessage without tool calls: its own content when it
    has any, otherwise a short placeholder text. ToolMessages that answered
    one of the stripped calls are dropped as well, because a tool reply
    without a preceding assistant tool call is itself an invalid history.

    Args:
        messages: Original message history.
        diagnosis: Result of ``_diagnose_parameter_error``.

    Returns:
        A new list of messages; the input list is not modified.
    """
    logger.info("🔧 策略2: 删除不完整的工具调用")

    affected_indexes = {
        inc["message_index"] for inc in diagnosis["incomplete_tool_calls"]
    }
    # Ids of every call removed so far; replies to them must go too.
    stripped_call_ids = set()
    fixed_messages = []

    for i, msg in enumerate(messages):
        if isinstance(msg, ToolMessage) and msg.tool_call_id in stripped_call_ids:
            logger.info(f" 🔧 删除已失效工具调用对应的ToolMessage (ID: {msg.tool_call_id})")
            continue

        is_affected = (
            isinstance(msg, AIMessage)
            and bool(getattr(msg, 'tool_calls', None))
            and i in affected_indexes
        )
        if not is_affected:
            fixed_messages.append(msg)
            continue

        stripped_call_ids.update(
            tc.get('id') for tc in msg.tool_calls if tc.get('id') is not None
        )

        # Message content may be a plain string or a list of content parts;
        # only strings can be checked for being blank.
        if isinstance(msg.content, str):
            has_content = bool(msg.content.strip())
        else:
            has_content = bool(msg.content)

        if has_content:
            # Keep the text, drop the tool calls.
            logger.info(f" 🔧 保留文本内容,删除工具调用: '{msg.content[:50]}...'")
            fixed_messages.append(AIMessage(content=msg.content))
        else:
            # Nothing to keep: substitute an explanatory placeholder.
            logger.info(f" 🔧 创建说明性消息替换空的工具调用")
            fixed_messages.append(AIMessage(content="我需要重新分析这个问题。"))

    return fixed_messages
|
|
|
+
|
|
|
def _fix_by_rebuilding_history(self, messages: List[BaseMessage]) -> List[BaseMessage]:
    """Rebuild the history, keeping only complete conversation rounds.

    The history is split into rounds, each starting at a HumanMessage (any
    messages before the first HumanMessage form a round of their own). A
    non-final round is kept only if ``_is_conversation_complete`` accepts it.
    The final round is kept whole when complete; otherwise only its
    HumanMessages are kept.

    Args:
        messages: Original message history.

    Returns:
        A new list containing the retained messages in their original order.
    """
    logger.info("🔧 策略3: 重建消息历史")

    # Phase 1: partition into rounds. The leading bucket collects anything
    # that precedes the first HumanMessage and may stay empty.
    rounds = [[]]
    for item in messages:
        if isinstance(item, HumanMessage):
            rounds.append([item])
        else:
            rounds[-1].append(item)
    rounds = [chunk for chunk in rounds if chunk]

    # Phase 2: filter the rounds.
    clean_messages = []
    if rounds:
        *earlier_rounds, final_round = rounds

        for chunk in earlier_rounds:
            if self._is_conversation_complete(chunk):
                clean_messages.extend(chunk)
                logger.info(f" ✅ 保留完整的对话轮次 ({len(chunk)} 条消息)")
            else:
                logger.info(f" ❌ 跳过不完整的对话轮次 ({len(chunk)} 条消息)")

        if self._is_conversation_complete(final_round):
            clean_messages.extend(final_round)
        else:
            # Incomplete final round: keep only what the user said.
            for item in final_round:
                if isinstance(item, HumanMessage):
                    clean_messages.append(item)

    logger.info(f" 📊 重建完成: {len(messages)} -> {len(clean_messages)} 条消息")
    return clean_messages
|
|
|
+
|
|
|
def _is_conversation_complete(self, conversation: List[BaseMessage]) -> bool:
    """Report whether every tool call in a conversation round has a reply.

    A round is complete when, for each AIMessage that carries tool calls,
    every call id appears as the ``tool_call_id`` of some ToolMessage in the
    same round. Matching is done per id rather than by counting replies, so
    duplicate replies to one call cannot hide a missing reply to another.

    Args:
        conversation: Messages belonging to one conversation round.

    Returns:
        True if no tool call is left unanswered (including when the round
        contains no tool calls at all), False otherwise.
    """
    # Collect the answered ids once instead of rescanning per AIMessage.
    answered_ids = {
        m.tool_call_id for m in conversation if isinstance(m, ToolMessage)
    }

    for msg in conversation:
        if not isinstance(msg, AIMessage):
            continue
        tool_calls = getattr(msg, 'tool_calls', None)
        if not tool_calls:
            continue
        if any(tc.get('id') not in answered_ids for tc in tool_calls):
            return False
    return True
|
|
|
+
|
|
|
async def _handle_parameter_error_with_retry(self, messages: List[BaseMessage], error_msg: str, attempt: int) -> List[BaseMessage]:
    """Diagnose a parameter error and repair the history for the next retry.

    The repair strategy escalates with the attempt number:

    * ``attempt == 0``: add placeholder ToolMessages for unanswered calls;
    * ``attempt == 1``: strip the incomplete tool calls;
    * any other value: rebuild the history from complete rounds only.

    Args:
        messages: Message history that triggered the error.
        error_msg: Error text of the failed request, passed to the diagnosis.
        attempt: Zero-based index of the attempt that just failed.

    Returns:
        The repaired message list produced by the selected strategy.
    """
    logger.error(f"🔧 处理参数错误 (重试 {attempt + 1}/3)")

    # Step 1: find out what is wrong with the history.
    report = self._diagnose_parameter_error(messages, error_msg)

    # Step 2: dispatch on the attempt number; unknown attempts fall back to
    # the most aggressive strategy.
    strategies = {
        0: lambda: self._fix_by_adding_missing_tool_messages(messages, report),
        1: lambda: self._fix_by_removing_incomplete_tool_calls(messages, report),
    }
    repair = strategies.get(attempt, lambda: self._fix_by_rebuilding_history(messages))
    fixed_messages = repair()

    logger.info(f"🔧 修复完成: {len(messages)} -> {len(fixed_messages)} 条消息")
    return fixed_messages
|
|
|
+
|
|
|
def _generate_contextual_fallback(self, messages: List[BaseMessage], diagnosis: Dict) -> str:
    """Produce a fallback answer based on the user's most recent question.

    Args:
        messages: Message history; the last HumanMessage in it is taken as
            the current question.
        diagnosis: Accepted for interface compatibility; not read here.

    Returns:
        An apology string. When the question contains one of the
        database-related keywords the reply quotes the question; otherwise a
        generic reply is returned. If there is no HumanMessage at all, a
        "cannot understand" reply is returned.
    """
    # Most recent user message, or None when the history has none.
    last_human_message = next(
        (entry for entry in reversed(messages) if isinstance(entry, HumanMessage)),
        None,
    )
    if not last_human_message:
        return "抱歉,我无法理解您的问题。"

    # Keyword probe deciding whether the question looks database-related.
    lowered_question = last_human_message.content.lower()
    database_keywords = ('查询', '数据', '服务区', '收入', '车流量')
    for keyword in database_keywords:
        if keyword in lowered_question:
            return f"抱歉,在处理您关于「{last_human_message.content}」的查询时遇到了技术问题。请稍后重试,或者重新描述您的问题。"

    return "抱歉,我现在无法正确处理您的问题。请稍后重试或重新表述您的问题。"
|