
./custom_react_agent has regressed to the state where the JSON output contains ASCII escape codes; preparing changes to fix this issue.

wangxq 1 month ago
Parent
Commit
106a60d0f7

+ 289 - 0
test/agent_old.py

@@ -0,0 +1,289 @@
+"""
+基于 StateGraph 的、具备上下文感知能力的 React Agent 核心实现
+"""
+import logging
+import json
+import pandas as pd
+from typing import List, Optional, Dict, Any, Tuple
+from contextlib import AsyncExitStack
+
+from langchain_openai import ChatOpenAI
+from langchain_core.messages import HumanMessage, ToolMessage, BaseMessage
+from langgraph.graph import StateGraph, END
+from langgraph.prebuilt import ToolNode
+from redis.asyncio import Redis
+try:
+    from langgraph.checkpoint.redis import AsyncRedisSaver
+except ImportError:
+    AsyncRedisSaver = None
+
+# 从新模块导入配置、状态和工具
+from . import config
+from .state import AgentState
+from .sql_tools import sql_tools
+
+logger = logging.getLogger(__name__)
+
+class CustomReactAgent:
+    """
+    一个使用 StateGraph 构建的、具备上下文感知和持久化能力的 Agent。
+    """
+    def __init__(self):
+        """私有构造函数,请使用 create() 类方法来创建实例。"""
+        self.llm = None
+        self.tools = None
+        self.agent_executor = None
+        self.checkpointer = None
+        self._exit_stack = None
+
+    @classmethod
+    async def create(cls):
+        """异步工厂方法,创建并初始化 CustomReactAgent 实例。"""
+        instance = cls()
+        await instance._async_init()
+        return instance
+
+    async def _async_init(self):
+        """异步初始化所有组件。"""
+        logger.info("🚀 开始初始化 CustomReactAgent...")
+
+        # 1. 初始化 LLM
+        self.llm = ChatOpenAI(
+            api_key=config.QWEN_API_KEY,
+            base_url=config.QWEN_BASE_URL,
+            model=config.QWEN_MODEL,
+            temperature=0.1,
+            model_kwargs={
+                "extra_body": {
+                    "enable_thinking": False,
+                    "misc": {
+                        "ensure_ascii": False
+                    }
+                }
+            }
+        )
+        logger.info(f"   LLM 已初始化,模型: {config.QWEN_MODEL}")
+
+        # 2. 绑定工具
+        self.tools = sql_tools
+        self.llm_with_tools = self.llm.bind_tools(self.tools)
+        logger.info(f"   已绑定 {len(self.tools)} 个工具。")
+
+        # 3. 初始化 Redis Checkpointer
+        if config.REDIS_ENABLED and AsyncRedisSaver is not None:
+            try:
+                self._exit_stack = AsyncExitStack()
+                checkpointer_manager = AsyncRedisSaver.from_conn_string(config.REDIS_URL)
+                self.checkpointer = await self._exit_stack.enter_async_context(checkpointer_manager)
+                await self.checkpointer.asetup()
+                logger.info(f"   AsyncRedisSaver 持久化已启用: {config.REDIS_URL}")
+            except Exception as e:
+                logger.error(f"   ❌ RedisSaver 初始化失败: {e}", exc_info=True)
+                if self._exit_stack:
+                    await self._exit_stack.aclose()
+                self.checkpointer = None
+        else:
+            logger.warning("   Redis 持久化功能已禁用。")
+
+        # 4. 构建 StateGraph
+        self.agent_executor = self._create_graph()
+        logger.info("   StateGraph 已构建并编译。")
+        logger.info("✅ CustomReactAgent 初始化完成。")
+
+    async def close(self):
+        """清理资源,关闭 Redis 连接。"""
+        if self._exit_stack:
+            await self._exit_stack.aclose()
+            self._exit_stack = None
+            self.checkpointer = None
+            logger.info("✅ RedisSaver 资源已通过 AsyncExitStack 释放。")
+
+    def _create_graph(self):
+        """定义并编译 StateGraph。"""
+        builder = StateGraph(AgentState)
+
+        # 定义节点
+        builder.add_node("agent", self._agent_node)
+        builder.add_node("prepare_tool_input", self._prepare_tool_input_node)
+        builder.add_node("tools", ToolNode(self.tools))
+        builder.add_node("update_state_after_tool", self._update_state_after_tool_node)
+        builder.add_node("format_final_response", self._format_final_response_node)
+
+        # 定义边
+        builder.set_entry_point("agent")
+        builder.add_conditional_edges(
+            "agent",
+            self._should_continue,
+            {
+                "continue": "prepare_tool_input",
+                "end": "format_final_response"
+            }
+        )
+        builder.add_edge("prepare_tool_input", "tools")
+        builder.add_edge("tools", "update_state_after_tool")
+        builder.add_edge("update_state_after_tool", "agent")
+        builder.add_edge("format_final_response", END)
+
+        # 编译图,并传入 checkpointer
+        return builder.compile(checkpointer=self.checkpointer)
+
+    def _should_continue(self, state: AgentState) -> str:
+        """判断是继续调用工具还是结束。"""
+        last_message = state["messages"][-1]
+        if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
+            return "end"
+        return "continue"
+
+    def _agent_node(self, state: AgentState) -> Dict[str, Any]:
+        """Agent 节点:调用 LLM 进行思考和决策。"""
+        logger.info(f"🧠 [Node] agent - Thread: {state['thread_id']}")
+        
+        messages_for_llm = list(state["messages"])
+        if state.get("suggested_next_step"):
+            instruction = f"基于之前的步骤,强烈建议你下一步执行 '{state['suggested_next_step']}' 操作。"
+            # 为了避免污染历史,可以考虑不同的注入方式,但这里为了简单直接添加
+            messages_for_llm.append(HumanMessage(content=instruction, name="system_instruction"))
+
+        response = self.llm_with_tools.invoke(messages_for_llm)
+        logger.info(f"   LLM 返回: {response.pretty_repr()}")
+        return {"messages": [response]}
+    
+    def _prepare_tool_input_node(self, state: AgentState) -> Dict[str, Any]:
+        """信息组装节点:为需要上下文的工具注入历史消息。"""
+        logger.info(f"🛠️ [Node] prepare_tool_input - Thread: {state['thread_id']}")
+        
+        last_message = state["messages"][-1]
+        if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
+            return {}
+
+        # 创建一个新的 AIMessage 来替换,避免直接修改 state 中的对象
+        new_tool_calls = []
+        for tool_call in last_message.tool_calls:
+            if tool_call["name"] == "generate_sql":
+                logger.info("   检测到 generate_sql 调用,注入可序列化的历史消息。")
+                # 复制一份以避免修改原始 tool_call
+                modified_args = tool_call["args"].copy()
+                
+                # 将消息对象列表转换为可序列化的字典列表
+                serializable_history = []
+                for msg in state["messages"]:
+                    serializable_history.append({
+                        "type": msg.type,
+                        "content": msg.content
+                    })
+                
+                modified_args["history_messages"] = serializable_history
+                new_tool_calls.append({
+                    "name": tool_call["name"],
+                    "args": modified_args,
+                    "id": tool_call["id"],
+                })
+            else:
+                new_tool_calls.append(tool_call)
+        
+        # 用包含修改后参数的新消息替换掉原来的
+        last_message.tool_calls = new_tool_calls
+        return {"messages": [last_message]}
+
+    def _update_state_after_tool_node(self, state: AgentState) -> Dict[str, Any]:
+        """流程建议与错误处理节点:在工具执行后更新状态。"""
+        logger.info(f"📝 [Node] update_state_after_tool - Thread: {state['thread_id']}")
+        
+        last_tool_message = state['messages'][-1]
+        tool_name = last_tool_message.name
+        tool_output = last_tool_message.content
+        next_step = None
+
+        if tool_name == 'generate_sql':
+            if "失败" in tool_output or "无法生成" in tool_output:
+                next_step = 'answer_with_common_sense'
+                logger.warning(f"   generate_sql 失败,建议下一步: {next_step}")
+            else:
+                next_step = 'valid_sql'
+                logger.info(f"   generate_sql 成功,建议下一步: {next_step}")
+        
+        elif tool_name == 'valid_sql':
+            if "失败" in tool_output:
+                next_step = 'analyze_validation_error'
+                logger.warning(f"   valid_sql 失败,建议下一步: {next_step}")
+            else:
+                next_step = 'run_sql'
+                logger.info(f"   valid_sql 成功,建议下一步: {next_step}")
+
+        elif tool_name == 'run_sql':
+            next_step = 'summarize_final_answer'
+            logger.info(f"   run_sql 执行完毕,建议下一步: {next_step}")
+
+        return {"suggested_next_step": next_step}
+
+    def _format_final_response_node(self, state: AgentState) -> Dict[str, Any]:
+        """最终输出格式化节点(当前为占位符)。"""
+        logger.info(f"🎨 [Node] format_final_response - Thread: {state['thread_id']} - 准备格式化最终输出...")
+        # 这里可以添加一个标记,表示这是格式化后的输出
+        last_message = state['messages'][-1]
+        formatted_content = f"[Formatted Output]\n{last_message.content}"
+        last_message.content = formatted_content
+        return {"messages": [last_message]}
+
+    async def chat(self, message: str, user_id: str, thread_id: Optional[str] = None) -> Dict[str, Any]:
+        """
+        处理用户聊天请求。
+        """
+        if not thread_id:
+            thread_id = f"{user_id}:{pd.Timestamp.now().strftime('%Y%m%d%H%M%S%f')}"
+            logger.info(f"🆕 新建会话,Thread ID: {thread_id}")
+
+        config = {"configurable": {"thread_id": thread_id}}
+        
+        # 定义输入
+        inputs = {
+            "messages": [HumanMessage(content=message)],
+            "user_id": user_id,
+            "thread_id": thread_id,
+            "suggested_next_step": None, # 初始化建议
+        }
+
+        final_state = None
+        try:
+            logger.info(f"🔄 开始处理 - Thread: {thread_id}, User: {user_id}, Message: '{message}'")
+            # 使用 ainvoke 来执行完整的图流程
+            final_state = await self.agent_executor.ainvoke(inputs, config)
+            
+            if final_state and final_state.get("messages"):
+                answer = final_state["messages"][-1].content
+                logger.info(f"✅ 处理完成 - Thread: {thread_id}, Final Answer: '{answer}'")
+                return {"success": True, "answer": answer, "thread_id": thread_id}
+            else:
+                 logger.error(f"❌ 处理异常结束,最终状态为空 - Thread: {thread_id}")
+                 return {"success": False, "error": "Agent failed to produce a final answer.", "thread_id": thread_id}
+
+        except Exception as e:
+            logger.error(f"❌ 处理过程中发生严重错误 - Thread: {thread_id}: {e}", exc_info=True)
+            return {"success": False, "error": str(e), "thread_id": thread_id}
+
+    async def get_conversation_history(self, thread_id: str) -> List[Dict[str, Any]]:
+        """从 checkpointer 获取指定线程的对话历史。"""
+        if not self.checkpointer:
+            return []
+        
+        config = {"configurable": {"thread_id": thread_id}}
+        conversation_state = await self.checkpointer.get(config)
+        
+        if not conversation_state:
+            return []
+            
+        history = []
+        for msg in conversation_state['values'].get('messages', []):
+            if isinstance(msg, HumanMessage):
+                role = "human"
+            elif isinstance(msg, ToolMessage):
+                role = "tool"
+            else: # AIMessage
+                role = "ai"
+            
+            history.append({
+                "type": role,
+                "content": msg.content,
+                "tool_calls": getattr(msg, 'tool_calls', None)
+            })
+        return history 

+ 124 - 32
test/custom_react_agent/agent.py

@@ -8,7 +8,7 @@ from typing import List, Optional, Dict, Any, Tuple
 from contextlib import AsyncExitStack
 
 from langchain_openai import ChatOpenAI
-from langchain_core.messages import HumanMessage, ToolMessage, BaseMessage
+from langchain_core.messages import HumanMessage, ToolMessage, BaseMessage, SystemMessage
 from langgraph.graph import StateGraph, END
 from langgraph.prebuilt import ToolNode
 from redis.asyncio import Redis
@@ -21,6 +21,7 @@ except ImportError:
 from . import config
 from .state import AgentState
 from .sql_tools import sql_tools
+from langchain_core.runnables import RunnablePassthrough
 
 logger = logging.getLogger(__name__)
 
@@ -53,7 +54,14 @@ class CustomReactAgent:
             base_url=config.QWEN_BASE_URL,
             model=config.QWEN_MODEL,
             temperature=0.1,
-            extra_body={"enable_thinking": False}
+            model_kwargs={
+                "extra_body": {
+                    "enable_thinking": False,
+                    "misc": {
+                        "ensure_ascii": False
+                    }
+                }
+            }
         )
         logger.info(f"   LLM 已初始化,模型: {config.QWEN_MODEL}")
 
@@ -92,40 +100,129 @@ class CustomReactAgent:
             logger.info("✅ RedisSaver 资源已通过 AsyncExitStack 释放。")
 
     def _create_graph(self):
-        """定义并编译 StateGraph。"""
+        """定义并编译最终的、正确的 StateGraph 结构。"""
         builder = StateGraph(AgentState)
 
-        # 定义节点
+        # 定义所有需要的节点
         builder.add_node("agent", self._agent_node)
+        builder.add_node("prepare_tool_input", self._prepare_tool_input_node)
         builder.add_node("tools", ToolNode(self.tools))
+        builder.add_node("update_state_after_tool", self._update_state_after_tool_node)
+        builder.add_node("format_final_response", self._format_final_response_node)
 
-        # 定义边
+        # 建立正确的边连接
         builder.set_entry_point("agent")
         builder.add_conditional_edges(
             "agent",
             self._should_continue,
-            {"continue": "tools", "end": END}
+            {
+                "continue": "prepare_tool_input",
+                "end": "format_final_response"
+            }
         )
-        builder.add_edge("tools", "agent")
+        builder.add_edge("prepare_tool_input", "tools")
+        builder.add_edge("tools", "update_state_after_tool")
+        builder.add_edge("update_state_after_tool", "agent")
+        builder.add_edge("format_final_response", END)
 
-        # 编译图,并传入 checkpointer
         return builder.compile(checkpointer=self.checkpointer)
 
     def _should_continue(self, state: AgentState) -> str:
-        """判断是否需要继续调用工具。"""
+        """判断是继续调用工具还是结束。"""
         last_message = state["messages"][-1]
-        if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
-            return "end"
-        return "continue"
+        if hasattr(last_message, "tool_calls") and last_message.tool_calls:
+            return "continue"
+        return "end"
 
     def _agent_node(self, state: AgentState) -> Dict[str, Any]:
-        """Agent 节点:调用 LLM 进行思考和决策。"""
+        """Agent 节点:只负责调用 LLM 并返回其输出。"""
         logger.info(f"🧠 [Node] agent - Thread: {state['thread_id']}")
 
-        # LangChain 会自动从 state 中提取 messages 传递给 LLM
-        response = self.llm_with_tools.invoke(state["messages"])
-        logger.info(f"   LLM 返回: {response.pretty_print()}")
+        messages_for_llm = list(state["messages"])
+        if state.get("suggested_next_step"):
+            instruction = f"提示:建议下一步使用工具 '{state['suggested_next_step']}'。"
+            messages_for_llm.append(SystemMessage(content=instruction))
+
+        response = self.llm_with_tools.invoke(messages_for_llm)
+        logger.info(f"   LLM Response: {response.pretty_repr()}")
+        
+        # 只返回消息,不承担其他职责
         return {"messages": [response]}
+    
+    def _prepare_tool_input_node(self, state: AgentState) -> Dict[str, Any]:
+        """
+        信息组装节点:为需要上下文的工具注入历史消息。
+        """
+        logger.info(f"🛠️ [Node] prepare_tool_input - Thread: {state['thread_id']}")
+        
+        last_message = state["messages"][-1]
+        if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
+            return {"messages": [last_message]}
+
+        # 创建一个新的 AIMessage 来替换,避免直接修改 state 中的对象
+        new_tool_calls = []
+        for tool_call in last_message.tool_calls:
+            if tool_call["name"] == "generate_sql":
+                logger.info("   检测到 generate_sql 调用,注入历史消息。")
+                # 复制一份以避免修改原始 tool_call
+                modified_args = tool_call["args"].copy()
+                
+                # 将消息对象列表转换为可序列化的字典列表
+                serializable_history = []
+                for msg in state["messages"]:
+                    serializable_history.append({
+                        "type": msg.type,
+                        "content": msg.content
+                    })
+                
+                modified_args["history_messages"] = serializable_history
+                logger.info(f"   注入了 {len(serializable_history)} 条历史消息")
+                
+                new_tool_calls.append({
+                    "name": tool_call["name"],
+                    "args": modified_args,
+                    "id": tool_call["id"],
+                })
+            else:
+                new_tool_calls.append(tool_call)
+        
+        # 用包含修改后参数的新消息替换掉原来的
+        last_message.tool_calls = new_tool_calls
+        return {"messages": [last_message]}
+
+    def _update_state_after_tool_node(self, state: AgentState) -> Dict[str, Any]:
+        """在工具执行后,更新 suggested_next_step。"""
+        logger.info(f"📝 [Node] update_state_after_tool - Thread: {state['thread_id']}")
+        
+        last_tool_message = state['messages'][-1]
+        tool_name = last_tool_message.name
+        tool_output = last_tool_message.content
+        next_step = None
+
+        if tool_name == 'generate_sql':
+            if any(marker in tool_output for marker in ("失败", "无法生成", "Database query failed", "Could not generate SQL", "SQL generation failed")):
+                next_step = 'answer_with_common_sense'
+            else:
+                next_step = 'valid_sql'
+        
+        elif tool_name == 'valid_sql':
+            if "失败" in tool_output:
+                next_step = 'analyze_validation_error'
+            else:
+                next_step = 'run_sql'
+
+        elif tool_name == 'run_sql':
+            next_step = 'summarize_final_answer'
+            
+        logger.info(f"   Tool '{tool_name}' executed. Suggested next step: {next_step}")
+        return {"suggested_next_step": next_step}
+
+    def _format_final_response_node(self, state: AgentState) -> Dict[str, Any]:
+        """最终输出格式化节点。"""
+        logger.info(f"🎨 [Node] format_final_response - Thread: {state['thread_id']}")
+        last_message = state['messages'][-1]
+        last_message.content = f"[Formatted Output]\n{last_message.content}"
+        return {"messages": [last_message]}
 
     async def chat(self, message: str, user_id: str, thread_id: Optional[str] = None) -> Dict[str, Any]:
         """
@@ -134,34 +231,29 @@ class CustomReactAgent:
         if not thread_id:
             thread_id = f"{user_id}:{pd.Timestamp.now().strftime('%Y%m%d%H%M%S%f')}"
             logger.info(f"🆕 新建会话,Thread ID: {thread_id}")
-
-        config = {"configurable": {"thread_id": thread_id}}
         
-        # 定义输入
+        config = {
+            "configurable": {
+                "thread_id": thread_id,
+            }
+        }
+        
         inputs = {
             "messages": [HumanMessage(content=message)],
             "user_id": user_id,
             "thread_id": thread_id,
+            "suggested_next_step": None,
         }
 
-        final_state = None
         try:
-            logger.info(f"🔄 开始处理 - Thread: {thread_id}, User: {user_id}, Message: '{message}'")
-            # 使用 ainvoke 来执行完整的图流程
             final_state = await self.agent_executor.ainvoke(inputs, config)
-            
-            if final_state and final_state.get("messages"):
-                answer = final_state["messages"][-1].content
-                logger.info(f"✅ 处理完成 - Thread: {thread_id}, Final Answer: '{answer}'")
-                return {"success": True, "answer": answer, "thread_id": thread_id}
-            else:
-                 logger.error(f"❌ 处理异常结束,最终状态为空 - Thread: {thread_id}")
-                 return {"success": False, "error": "Agent failed to produce a final answer.", "thread_id": thread_id}
-
+            answer = final_state["messages"][-1].content
+            logger.info(f"✅ 处理完成 - Final Answer: '{answer}'")
+            return {"success": True, "answer": answer, "thread_id": thread_id}
         except Exception as e:
             logger.error(f"❌ 处理过程中发生严重错误 - Thread: {thread_id}: {e}", exc_info=True)
             return {"success": False, "error": str(e), "thread_id": thread_id}
-
+    
     async def get_conversation_history(self, thread_id: str) -> List[Dict[str, Any]]:
         """从 checkpointer 获取指定线程的对话历史。"""
         if not self.checkpointer:

+ 109 - 0
test/custom_react_agent/community_help_request.md

@@ -0,0 +1,109 @@
+# Urgent help needed: double JSON encoding and serialization errors in LangGraph when passing complex tool arguments that contain Chinese text
+
+Hi everyone,
+
+We are using LangGraph to build a `StateGraph`-based ReAct agent whose core requirement is to understand and use context across a multi-turn conversation when generating SQL queries. To achieve this, we designed a flow that passes the full conversation history (`history_messages`) as an argument to our custom `generate_sql` tool.
+
+However, as soon as the conversation history contains Chinese characters, we hit a stubborn `JSONDecodeError` during LangChain's tool-call serialization, apparently caused by double JSON encoding or improper `\uXXXX` escaping. We have tried several approaches without success, and would be very grateful for advice from the community's experts.
+
+---
+
+## 1. Core goal and architecture
+
+Our goal is to let the `generate_sql` tool see the full conversation history, so that when it handles an anaphoric question such as "这个服务区怎么样?" ("How is this service area doing?"), it knows which specific service area "this" refers to.
+
+Our `StateGraph` is designed as follows:
+
+```mermaid
+graph TD
+    A[START] --> B(agent_node);
+    B --> C{Tool call requested?};
+    C -- Yes --> D(prepare_tool_input_node);
+    C -- No --> G[END];
+    D --> E(tool_node);
+    E --> F(update_state_after_tool_node);
+    F --> B;
+```
+
+- **`prepare_tool_input_node`**: the key responsibility of this node is that, whenever it detects that `agent_node` has decided to call `generate_sql`, it extracts the full `messages` list from `state` and injects it into that tool call's `args` as the `history_messages` parameter (a sketch of the resulting payload is shown right below).
+
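+For concreteness, after this injection one modified tool call on the `AIMessage` has roughly the following shape (values abridged from the error log in section 2; only the structure matters):
+
+```python
+# Approximate shape of a tool_call after prepare_tool_input_node has run (abridged).
+example_tool_call = {
+    "name": "generate_sql",
+    "args": {
+        "question": "请问这个高速...",
+        "history_messages": [
+            {"type": "human", "content": "请问系统..."},
+            # ...earlier turns of the conversation...
+        ],
+    },
+    "id": "call_e58f408879664da99cd18d",
+}
+```
+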
+## 2. The core problem
+
+Once `prepare_tool_input_node` has successfully injected `history_messages` (a list of dicts such as `{'type': 'human', 'content': '你好'}`), the graph crashes as execution continues.
+
+**The error log:**
+
+```
+Invalid Tool Calls:
+  generate_sql (call_e58f408879664da99cd18d)
+ Call ID: call_e58f408879664da99cd18d
+  Error: Function generate_sql arguments:
+
+{"question": "\u8bf7\u95ee\u8fd9\u4e2a\u9ad8\u901f...ff1f", "history_messages": [{"type": "human", "content": "\u8bf7\u95ee\u7cfb\u7edf..."}, ...]}
+
+are not valid JSON. Received JSONDecodeError Invalid \escape: line 1 column 1539 (char 1538)
+For troubleshooting, visit: https://python.langchain.com/docs/troubleshooting/errors/OUTPUT_PARSING_FAILURE
+```
+
+## 3. Our analysis
+
+The log shows that when the `args` dict passed to `generate_sql` is serialized, every Chinese character gets converted to its `\uXXXX` ASCII-escaped form.
+
+Our inference is that the root cause is LangChain **serializing the `args` dict with `json.dumps(..., ensure_ascii=True)`** while preparing the `tool_calls` for the LLM API or for internal processing. When this already-encoded string, full of `\` escapes, is later parsed as JSON again further down the pipeline, the invalid `\u` escape sequences trigger a `JSONDecodeError`. It looks like a "double encoding" problem that we cannot easily configure away.
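+
+To make the behavior we suspect concrete, here is a minimal, self-contained illustration using only the standard `json` module (no LangChain involved; the question string is just an example):
+
+```python
+import json
+
+args = {"question": "请问这个高速公路服务区怎么样?"}
+
+print(json.dumps(args))
+# {"question": "\u8bf7\u95ee..."}   <- default ensure_ascii=True escapes every non-ASCII character
+
+print(json.dumps(args, ensure_ascii=False))
+# {"question": "请问..."}           <- readable UTF-8, no \uXXXX escapes
+
+# If the already-escaped string is dumped a second time but parsed only once,
+# the parser returns a plain str instead of a dict; with a mismatched escaping
+# layer in between, json.loads can instead fail with "Invalid \escape", which
+# is the error we are seeing.
+double_encoded = json.dumps(json.dumps(args))
+print(type(json.loads(double_encoded)))  # <class 'str'>, not <class 'dict'>
+```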
+
+## 4. Approaches we have already tried (all failed)
+
+1.  **Setting `ensure_ascii=False` when initializing `ChatOpenAI`**:
+    - We tried to influence the serialization behavior via `model_kwargs={"extra_body": {"misc": {"ensure_ascii": False}}}`.
+    - **Result**: this only affected how the final answer returned by the LLM is rendered; it **did not change how LangChain serializes tool-call arguments**, and the error remained.
+
+2.  **Decoding on the tool side**:
+    - We tried to `json.loads()` (or otherwise decode) the received `history_messages` inside the `generate_sql` function.
+    - **Result**: failure. The error occurs in the serialization stage **before** LangChain ever invokes our tool, so execution never reaches the inside of the tool function.
+
+## 5. The questions we need help with
+
+We feel caught in a dilemma: to make the tool context-aware we must pass it complex data, yet LangChain's serialization machinery does not seem to allow it, especially when the data contains non-ASCII characters.
+
+We would like to ask the community:
+
+1.  **In LangGraph, what is the best practice for passing complex data structures (such as objects or lists of dicts) that contain non-ASCII characters to a tool?**
+2.  **Is there any way to override or configure how `ToolNode` (or the layer beneath it) serializes the `args` of `tool_calls`, forcing it to use `ensure_ascii=False`?**
+3.  If this path is a dead end, is there a more elegant, recommended design pattern for the common scenario of "a tool that needs to see the full conversation history"? (For example, besides the "argument injection" pattern we are trying, what other options exist? One candidate we are considering is sketched right after this list.)
+
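+One alternative we are evaluating (a minimal sketch, not yet verified in our environment; it assumes our installed `langgraph` version exposes `langgraph.prebuilt.InjectedState`): let `ToolNode` inject the graph state into the tool at call time, so the history never travels through the model-facing tool-call JSON at all:
+
+```python
+from typing import Annotated, Any, Dict
+
+from langchain_core.tools import tool
+from langgraph.prebuilt import InjectedState
+
+
+@tool
+def generate_sql(question: str, state: Annotated[Dict[str, Any], InjectedState]) -> str:
+    """Generate SQL from the question plus the conversation history held in graph state."""
+    # Arguments annotated with InjectedState are hidden from the schema shown to
+    # the LLM and are filled in by ToolNode from the current graph state, so no
+    # Chinese text has to round-trip through the tool-call arguments.
+    history = [{"type": m.type, "content": m.content} for m in state["messages"]]
+    # ...build the enriched prompt from `history` and `question` as before...
+    return "SELECT 1;"  # placeholder
+```
+
+If this works as documented, the `prepare_tool_input_node` step (and the serialization problem along with it) should disappear, because the model only ever has to emit `question`.
+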
+---
+
+### Appendix: key code snippets
+
+**`_prepare_tool_input_node`**:
+```python
+def _prepare_tool_input_node(self, state: AgentState) -> Dict[str, Any]:
+    last_message = state["messages"][-1]
+    if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
+        return {}
+    
+    new_tool_calls = []
+    for tool_call in last_message.tool_calls:
+        if tool_call["name"] == "generate_sql":
+            # Convert the message objects into a JSON-serializable list of dicts
+            serializable_history = [
+                {"type": msg.type, "content": msg.content} 
+                for msg in state["messages"]
+            ]
+            
+            modified_args = tool_call["args"].copy()
+            modified_args["history_messages"] = serializable_history
+            new_tool_calls.append({
+                "name": tool_call["name"],
+                "args": modified_args,
+                "id": tool_call["id"],
+            })
+        else:
+            new_tool_calls.append(tool_call)
+    
+    last_message.tool_calls = new_tool_calls
+    return {"messages": [last_message]}
+```
+
+**`generate_sql` tool signature**:
+```python
+@tool
+def generate_sql(question: str, history_messages: List[Dict[str, Any]]) -> str:
+    # ...
+```
+
+Any suggestions or ideas would help us enormously. Thanks in advance!

+ 40 - 20
test/custom_react_agent/sql_tools.py

@@ -5,50 +5,70 @@ import re
 import json
 import logging
 from langchain_core.tools import tool
+from pydantic.v1 import BaseModel, Field
+from typing import List, Dict, Any
 import pandas as pd
 
 logger = logging.getLogger(__name__)
 
-# --- 工具函数 ---
+# --- Pydantic Schema for Tool Arguments ---
 
-@tool
-def generate_sql(question: str) -> str:
-    """
-    根据用户问题生成SQL查询语句。
+class GenerateSqlArgs(BaseModel):
+    """Input schema for the generate_sql tool."""
+    question: str = Field(description="The user's question to be converted to SQL.")
+    history_messages: List[Dict[str, Any]] = Field(
+        default=[],
+        description="The conversation history messages for context."
+    )
 
-    Args:
-        question: 用户的原始问题。
+# --- Tool Functions ---
 
-    Returns:
-        生成的SQL语句或错误信息。
+@tool(args_schema=GenerateSqlArgs)
+def generate_sql(question: str, history_messages: List[Dict[str, Any]] = None) -> str:
+    """
+    Generates an SQL query based on the user's question and the conversation history.
     """
-    logger.info(f"🔧 [Tool] generate_sql - 问题: '{question}'")
+    logger.info(f"🔧 [Tool] generate_sql - Question: '{question}'")
+    
+    if history_messages is None:
+        history_messages = []
+    
+    logger.info(f"   History contains {len(history_messages)} messages.")
+
+    # Combine history and the current question to form a rich prompt
+    history_str = "\n".join([f"{msg['type']}: {msg.get('content', '') or ''}" for msg in history_messages])
+    enriched_question = f"""Based on the following conversation history:
+---
+{history_str}
+---
+
+Please provide an SQL query that answers this specific question: {question}"""
 
     try:
         from common.vanna_instance import get_vanna_instance
         vn = get_vanna_instance()
-        sql = vn.generate_sql(question)
+        sql = vn.generate_sql(enriched_question)
 
         if not sql or sql.strip() == "":
             if hasattr(vn, 'last_llm_explanation') and vn.last_llm_explanation:
                 error_info = vn.last_llm_explanation
-                logger.warning(f"   Vanna返回了错误解释: {error_info}")
-                return f"数据库查询失败,具体原因:{error_info}"
+                logger.warning(f"   Vanna returned an explanation instead of SQL: {error_info}")
+                return f"Database query failed. Reason: {error_info}"
             else:
-                logger.warning("   Vanna未能生成SQL且无解释。")
-                return "无法生成SQL:问题可能不适合数据库查询"
+                logger.warning("   Vanna failed to generate SQL and provided no explanation.")
+                return "Could not generate SQL: The question may not be suitable for a database query."
 
         sql_upper = sql.upper().strip()
         if not any(keyword in sql_upper for keyword in ['SELECT', 'WITH']):
-            logger.warning(f"   Vanna返回了疑似错误信息而非SQL: {sql}")
-            return f"数据库查询失败,具体原因:{sql}"
+            logger.warning(f"   Vanna returned a message that does not appear to be a valid SQL query: {sql}")
+            return f"Database query failed. Reason: {sql}"
 
-        logger.info(f"   ✅ 成功生成SQL: {sql}")
+        logger.info(f"   ✅ SQL Generated Successfully: {sql}")
         return sql
 
     except Exception as e:
-        logger.error(f"   SQL生成过程中发生异常: {e}", exc_info=True)
-        return f"SQL生成失败: {str(e)}"
+        logger.error(f"   An exception occurred during SQL generation: {e}", exc_info=True)
+        return f"SQL generation failed: {str(e)}"
 
 @tool
 def valid_sql(sql: str) -> str:

+ 3 - 1
test/custom_react_agent/state.py

@@ -13,7 +13,9 @@ class AgentState(TypedDict):
         messages: 对话消息列表,使用 add_messages 聚合。
         user_id: 当前用户ID。
         thread_id: 当前会话的线程ID。
+        suggested_next_step: 用于引导LLM下一步行动的建议指令。
     """
     messages: Annotated[List[BaseMessage], add_messages]
    user_id: str
-    thread_id: str 
+    thread_id: str
+    suggested_next_step: Optional[str]