|
@@ -153,7 +153,7 @@ class CustomReactAgent:
|
|
|
"""
|
|
|
打印 state 的全部信息,用于调试
|
|
|
"""
|
|
|
- logger.info(" =" * 20)
|
|
|
+ logger.info(" ~" * 10 + " State Print Start" + "~" * 10)
|
|
|
logger.info(f"📋 [State Debug] {node_name} - 当前状态信息:")
|
|
|
|
|
|
# 🎯 打印 state 中的所有字段
|
|
@@ -187,7 +187,7 @@ class CustomReactAgent:
|
|
|
logger.info(f" 工具调用: {tool_name}")
|
|
|
logger.info(f" 参数: {str(tool_args)[:200]}...")
|
|
|
|
|
|
- logger.info(" =" * 20)
|
|
|
+ logger.info(" ~" * 10 + " State Print End" + "~" * 10)
|
|
|
|
|
|
def _prepare_tool_input_node(self, state: AgentState) -> Dict[str, Any]:
|
|
|
"""
|
|
@@ -251,7 +251,7 @@ class CustomReactAgent:
|
|
|
logger.info(f"📝 [Node] update_state_after_tool - Thread: {state['thread_id']}")
|
|
|
|
|
|
# 🎯 打印 state 全部信息
|
|
|
- # self._print_state_info(state, "update_state_after_tool")
|
|
|
+ self._print_state_info(state, "update_state_after_tool")
|
|
|
|
|
|
last_tool_message = state['messages'][-1]
|
|
|
tool_name = last_tool_message.name
|
|
@@ -297,12 +297,53 @@ class CustomReactAgent:
|
|
|
last_message.content = f"[Formatted Output]\n{last_message.content}"
|
|
|
return {"messages": [last_message]}
|
|
|
|
|
|
+ def _extract_latest_sql_data(self, messages: List[BaseMessage]) -> Optional[str]:
|
|
|
+ """从消息历史中提取最近的run_sql执行结果,但仅限于当前对话轮次。"""
|
|
|
+ logger.info("🔍 提取最新的SQL执行结果...")
|
|
|
+
|
|
|
+ # 🎯 只查找最后一个HumanMessage之后的SQL执行结果
|
|
|
+ last_human_index = -1
|
|
|
+ for i in range(len(messages) - 1, -1, -1):
|
|
|
+ if isinstance(messages[i], HumanMessage):
|
|
|
+ last_human_index = i
|
|
|
+ break
|
|
|
+
|
|
|
+ if last_human_index == -1:
|
|
|
+ logger.info(" 未找到用户消息,跳过SQL数据提取")
|
|
|
+ return None
|
|
|
+
|
|
|
+ # 只在当前对话轮次中查找SQL结果
|
|
|
+ current_conversation = messages[last_human_index:]
|
|
|
+ logger.info(f" 当前对话轮次包含 {len(current_conversation)} 条消息")
|
|
|
+
|
|
|
+ for msg in reversed(current_conversation):
|
|
|
+ if isinstance(msg, ToolMessage) and msg.name == 'run_sql':
|
|
|
+ logger.info(f" 找到当前对话轮次的run_sql结果: {msg.content[:100]}...")
|
|
|
+
|
|
|
+ # 🎯 处理Unicode转义序列,将其转换为正常的中文字符
|
|
|
+ try:
|
|
|
+ # 先尝试解析JSON以验证格式
|
|
|
+ parsed_data = json.loads(msg.content)
|
|
|
+ # 重新序列化,确保中文字符正常显示
|
|
|
+ formatted_content = json.dumps(parsed_data, ensure_ascii=False, separators=(',', ':'))
|
|
|
+ logger.info(f" 已转换Unicode转义序列为中文字符")
|
|
|
+ return formatted_content
|
|
|
+ except json.JSONDecodeError:
|
|
|
+ # 如果不是有效JSON,直接返回原内容
|
|
|
+ logger.warning(f" SQL结果不是有效JSON格式,返回原始内容")
|
|
|
+ return msg.content
|
|
|
+
|
|
|
+ logger.info(" 当前对话轮次中未找到run_sql执行结果")
|
|
|
+ return None
|
|
|
+
|
|
|
async def chat(self, message: str, user_id: str, thread_id: Optional[str] = None) -> Dict[str, Any]:
|
|
|
"""
|
|
|
处理用户聊天请求。
|
|
|
"""
|
|
|
if not thread_id:
|
|
|
- thread_id = f"{user_id}:{pd.Timestamp.now().strftime('%Y%m%d%H%M%S%f')}"
|
|
|
+ now = pd.Timestamp.now()
|
|
|
+ milliseconds = int(now.microsecond / 1000)
|
|
|
+ thread_id = f"{user_id}:{now.strftime('%Y%m%d%H%M%S')}{milliseconds:03d}"
|
|
|
logger.info(f"🆕 新建会话,Thread ID: {thread_id}")
|
|
|
|
|
|
config = {
|
|
@@ -321,8 +362,26 @@ class CustomReactAgent:
|
|
|
try:
|
|
|
final_state = await self.agent_executor.ainvoke(inputs, config)
|
|
|
answer = final_state["messages"][-1].content
|
|
|
+
|
|
|
+ # 🎯 提取最近的 run_sql 执行结果(不修改messages)
|
|
|
+ sql_data = self._extract_latest_sql_data(final_state["messages"])
|
|
|
+
|
|
|
logger.info(f"✅ 处理完成 - Final Answer: '{answer}'")
|
|
|
- return {"success": True, "answer": answer, "thread_id": thread_id}
|
|
|
+
|
|
|
+ # 构建返回结果
|
|
|
+ result = {
|
|
|
+ "success": True,
|
|
|
+ "answer": answer,
|
|
|
+ "thread_id": thread_id
|
|
|
+ }
|
|
|
+
|
|
|
+ # 只有当存在SQL数据时才添加到返回结果中
|
|
|
+ if sql_data:
|
|
|
+ result["sql_data"] = sql_data
|
|
|
+ logger.info(" 📊 已包含SQL原始数据")
|
|
|
+
|
|
|
+ return result
|
|
|
+
|
|
|
except Exception as e:
|
|
|
logger.error(f"❌ 处理过程中发生严重错误 - Thread: {thread_id}: {e}", exc_info=True)
|
|
|
return {"success": False, "error": str(e), "thread_id": thread_id}
|