
./custom_react_agent has reverted to the state where the JSON output contains ASCII escape codes; preparing a change to fix this issue.

wangxq 1 month ago
parent commit
106a60d0f7

+ 289 - 0
test/agent_old.py

@@ -0,0 +1,289 @@
+"""
+Core implementation of a context-aware React Agent built on StateGraph.
+"""
+import logging
+import json
+import pandas as pd
+from typing import List, Optional, Dict, Any, Tuple
+from contextlib import AsyncExitStack
+
+from langchain_openai import ChatOpenAI
+from langchain_core.messages import HumanMessage, ToolMessage, BaseMessage
+from langgraph.graph import StateGraph, END
+from langgraph.prebuilt import ToolNode
+from redis.asyncio import Redis
+try:
+    from langgraph.checkpoint.redis import AsyncRedisSaver
+except ImportError:
+    AsyncRedisSaver = None
+
+# Import config, state, and tools from the new modules
+from . import config
+from .state import AgentState
+from .sql_tools import sql_tools
+
+logger = logging.getLogger(__name__)
+
+class CustomReactAgent:
+    """
+    An Agent built with StateGraph that is context-aware and persistent.
+    """
+    def __init__(self):
+        """私有构造函数,请使用 create() 类方法来创建实例。"""
+        self.llm = None
+        self.tools = None
+        self.agent_executor = None
+        self.checkpointer = None
+        self._exit_stack = None
+
+    @classmethod
+    async def create(cls):
+        """异步工厂方法,创建并初始化 CustomReactAgent 实例。"""
+        instance = cls()
+        await instance._async_init()
+        return instance
+
+    async def _async_init(self):
+        """异步初始化所有组件。"""
+        logger.info("🚀 开始初始化 CustomReactAgent...")
+
+        # 1. Initialize the LLM
+        self.llm = ChatOpenAI(
+            api_key=config.QWEN_API_KEY,
+            base_url=config.QWEN_BASE_URL,
+            model=config.QWEN_MODEL,
+            temperature=0.1,
+            model_kwargs={
+                "extra_body": {
+                    "enable_thinking": False,
+                    "misc": {
+                        "ensure_ascii": False
+                    }
+                }
+            }
+        )
+        logger.info(f"   LLM 已初始化,模型: {config.QWEN_MODEL}")
+
+        # 2. 绑定工具
+        self.tools = sql_tools
+        self.llm_with_tools = self.llm.bind_tools(self.tools)
+        logger.info(f"   已绑定 {len(self.tools)} 个工具。")
+
+        # 3. 初始化 Redis Checkpointer
+        if config.REDIS_ENABLED and AsyncRedisSaver is not None:
+            try:
+                self._exit_stack = AsyncExitStack()
+                checkpointer_manager = AsyncRedisSaver.from_conn_string(config.REDIS_URL)
+                self.checkpointer = await self._exit_stack.enter_async_context(checkpointer_manager)
+                await self.checkpointer.asetup()
+                logger.info(f"   AsyncRedisSaver 持久化已启用: {config.REDIS_URL}")
+            except Exception as e:
+                logger.error(f"   ❌ RedisSaver 初始化失败: {e}", exc_info=True)
+                if self._exit_stack:
+                    await self._exit_stack.aclose()
+                self.checkpointer = None
+        else:
+            logger.warning("   Redis 持久化功能已禁用。")
+
+        # 4. 构建 StateGraph
+        self.agent_executor = self._create_graph()
+        logger.info("   StateGraph 已构建并编译。")
+        logger.info("✅ CustomReactAgent 初始化完成。")
+
+    async def close(self):
+        """清理资源,关闭 Redis 连接。"""
+        if self._exit_stack:
+            await self._exit_stack.aclose()
+            self._exit_stack = None
+            self.checkpointer = None
+            logger.info("✅ RedisSaver 资源已通过 AsyncExitStack 释放。")
+
+    def _create_graph(self):
+        """定义并编译 StateGraph。"""
+        builder = StateGraph(AgentState)
+
+        # Define nodes
+        builder.add_node("agent", self._agent_node)
+        builder.add_node("prepare_tool_input", self._prepare_tool_input_node)
+        builder.add_node("tools", ToolNode(self.tools))
+        builder.add_node("update_state_after_tool", self._update_state_after_tool_node)
+        builder.add_node("format_final_response", self._format_final_response_node)
+
+        # Define edges
+        builder.set_entry_point("agent")
+        builder.add_conditional_edges(
+            "agent",
+            self._should_continue,
+            {
+                "continue": "prepare_tool_input",
+                "end": "format_final_response"
+            }
+        )
+        builder.add_edge("prepare_tool_input", "tools")
+        builder.add_edge("tools", "update_state_after_tool")
+        builder.add_edge("update_state_after_tool", "agent")
+        builder.add_edge("format_final_response", END)
+
+        # Compile the graph with the checkpointer
+        return builder.compile(checkpointer=self.checkpointer)
+
+    def _should_continue(self, state: AgentState) -> str:
+        """判断是继续调用工具还是结束。"""
+        last_message = state["messages"][-1]
+        if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
+            return "end"
+        return "continue"
+
+    def _agent_node(self, state: AgentState) -> Dict[str, Any]:
+        """Agent 节点:调用 LLM 进行思考和决策。"""
+        logger.info(f"🧠 [Node] agent - Thread: {state['thread_id']}")
+        
+        messages_for_llm = list(state["messages"])
+        if state.get("suggested_next_step"):
+            instruction = f"基于之前的步骤,强烈建议你下一步执行 '{state['suggested_next_step']}' 操作。"
+            # 为了避免污染历史,可以考虑不同的注入方式,但这里为了简单直接添加
+            messages_for_llm.append(HumanMessage(content=instruction, name="system_instruction"))
+
+        response = self.llm_with_tools.invoke(messages_for_llm)
+        logger.info(f"   LLM 返回: {response.pretty_print()}")
+        return {"messages": [response]}
+    
+    def _prepare_tool_input_node(self, state: AgentState) -> Dict[str, Any]:
+        """信息组装节点:为需要上下文的工具注入历史消息。"""
+        logger.info(f"🛠️ [Node] prepare_tool_input - Thread: {state['thread_id']}")
+        
+        last_message = state["messages"][-1]
+        if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
+            return {}
+
+        # Build a new tool_calls list with modified arguments; it is attached back to the message below
+        new_tool_calls = []
+        for tool_call in last_message.tool_calls:
+            if tool_call["name"] == "generate_sql":
+                logger.info("   检测到 generate_sql 调用,注入可序列化的历史消息。")
+                # 复制一份以避免修改原始 tool_call
+                modified_args = tool_call["args"].copy()
+                
+                # Convert the message objects into a serializable list of dicts
+                serializable_history = []
+                for msg in state["messages"]:
+                    serializable_history.append({
+                        "type": msg.type,
+                        "content": msg.content
+                    })
+                
+                modified_args["history_messages"] = serializable_history
+                new_tool_calls.append({
+                    "name": tool_call["name"],
+                    "args": modified_args,
+                    "id": tool_call["id"],
+                })
+            else:
+                new_tool_calls.append(tool_call)
+        
+        # Attach the tool_calls with modified arguments back onto the message
+        last_message.tool_calls = new_tool_calls
+        return {"messages": [last_message]}
+
+    def _update_state_after_tool_node(self, state: AgentState) -> Dict[str, Any]:
+        """流程建议与错误处理节点:在工具执行后更新状态。"""
+        logger.info(f"📝 [Node] update_state_after_tool - Thread: {state['thread_id']}")
+        
+        last_tool_message = state['messages'][-1]
+        tool_name = last_tool_message.name
+        tool_output = last_tool_message.content
+        next_step = None
+
+        if tool_name == 'generate_sql':
+            if "失败" in tool_output or "无法生成" in tool_output:
+                next_step = 'answer_with_common_sense'
+                logger.warning(f"   generate_sql 失败,建议下一步: {next_step}")
+            else:
+                next_step = 'valid_sql'
+                logger.info(f"   generate_sql 成功,建议下一步: {next_step}")
+        
+        elif tool_name == 'valid_sql':
+            if "失败" in tool_output:
+                next_step = 'analyze_validation_error'
+                logger.warning(f"   valid_sql 失败,建议下一步: {next_step}")
+            else:
+                next_step = 'run_sql'
+                logger.info(f"   valid_sql 成功,建议下一步: {next_step}")
+
+        elif tool_name == 'run_sql':
+            next_step = 'summarize_final_answer'
+            logger.info(f"   run_sql 执行完毕,建议下一步: {next_step}")
+
+        return {"suggested_next_step": next_step}
+
+    def _format_final_response_node(self, state: AgentState) -> Dict[str, Any]:
+        """最终输出格式化节点(当前为占位符)。"""
+        logger.info(f"🎨 [Node] format_final_response - Thread: {state['thread_id']} - 准备格式化最终输出...")
+        # 这里可以添加一个标记,表示这是格式化后的输出
+        last_message = state['messages'][-1]
+        formatted_content = f"[Formatted Output]\n{last_message.content}"
+        last_message.content = formatted_content
+        return {"messages": [last_message]}
+
+    async def chat(self, message: str, user_id: str, thread_id: Optional[str] = None) -> Dict[str, Any]:
+        """
+        Handle a user chat request.
+        """
+        if not thread_id:
+            thread_id = f"{user_id}:{pd.Timestamp.now().strftime('%Y%m%d%H%M%S%f')}"
+            logger.info(f"🆕 新建会话,Thread ID: {thread_id}")
+
+        config = {"configurable": {"thread_id": thread_id}}
+        
+        # Define the input
+        inputs = {
+            "messages": [HumanMessage(content=message)],
+            "user_id": user_id,
+            "thread_id": thread_id,
+            "suggested_next_step": None, # 初始化建议
+        }
+
+        final_state = None
+        try:
+            logger.info(f"🔄 开始处理 - Thread: {thread_id}, User: {user_id}, Message: '{message}'")
+            # 使用 ainvoke 来执行完整的图流程
+            final_state = await self.agent_executor.ainvoke(inputs, config)
+            
+            if final_state and final_state.get("messages"):
+                answer = final_state["messages"][-1].content
+                logger.info(f"✅ 处理完成 - Thread: {thread_id}, Final Answer: '{answer}'")
+                return {"success": True, "answer": answer, "thread_id": thread_id}
+            else:
+                logger.error(f"❌ Processing ended abnormally; the final state is empty - Thread: {thread_id}")
+                return {"success": False, "error": "Agent failed to produce a final answer.", "thread_id": thread_id}
+
+        except Exception as e:
+            logger.error(f"❌ 处理过程中发生严重错误 - Thread: {thread_id}: {e}", exc_info=True)
+            return {"success": False, "error": str(e), "thread_id": thread_id}
+
+    async def get_conversation_history(self, thread_id: str) -> List[Dict[str, Any]]:
+        """从 checkpointer 获取指定线程的对话历史。"""
+        if not self.checkpointer:
+            return []
+        
+        config = {"configurable": {"thread_id": thread_id}}
+        conversation_state = await self.checkpointer.get(config)
+        
+        if not conversation_state:
+            return []
+            
+        history = []
+        for msg in conversation_state['values'].get('messages', []):
+            if isinstance(msg, HumanMessage):
+                role = "human"
+            elif isinstance(msg, ToolMessage):
+                role = "tool"
+            else: # AIMessage
+                role = "ai"
+            
+            history.append({
+                "type": role,
+                "content": msg.content,
+                "tool_calls": getattr(msg, 'tool_calls', None)
+            })
+        return history 

+ 124 - 32
test/custom_react_agent/agent.py

@@ -8,7 +8,7 @@ from typing import List, Optional, Dict, Any, Tuple
 from contextlib import AsyncExitStack
 
 from langchain_openai import ChatOpenAI
-from langchain_core.messages import HumanMessage, ToolMessage, BaseMessage
+from langchain_core.messages import HumanMessage, ToolMessage, BaseMessage, SystemMessage
 from langgraph.graph import StateGraph, END
 from langgraph.prebuilt import ToolNode
 from redis.asyncio import Redis
@@ -21,6 +21,7 @@ except ImportError:
 from . import config
 from .state import AgentState
 from .sql_tools import sql_tools
+from langchain_core.runnables import RunnablePassthrough
 
 logger = logging.getLogger(__name__)
 
@@ -53,7 +54,14 @@ class CustomReactAgent:
             base_url=config.QWEN_BASE_URL,
             model=config.QWEN_MODEL,
             temperature=0.1,
-            extra_body={"enable_thinking": False}
+            model_kwargs={
+                "extra_body": {
+                    "enable_thinking": False,
+                    "misc": {
+                        "ensure_ascii": False
+                    }
+                }
+            }
         )
         logger.info(f"   LLM 已初始化,模型: {config.QWEN_MODEL}")
 
@@ -92,40 +100,129 @@ class CustomReactAgent:
             logger.info("✅ RedisSaver 资源已通过 AsyncExitStack 释放。")
 
     def _create_graph(self):
-        """定义并编译 StateGraph。"""
+        """定义并编译最终的、正确的 StateGraph 结构。"""
         builder = StateGraph(AgentState)
 
-        # Define nodes
+        # Define all required nodes
         builder.add_node("agent", self._agent_node)
+        builder.add_node("prepare_tool_input", self._prepare_tool_input_node)
         builder.add_node("tools", ToolNode(self.tools))
+        builder.add_node("update_state_after_tool", self._update_state_after_tool_node)
+        builder.add_node("format_final_response", self._format_final_response_node)
 
-        # Define edges
+        # Wire up the correct edges
         builder.set_entry_point("agent")
         builder.add_conditional_edges(
             "agent",
             self._should_continue,
-            {"continue": "tools", "end": END}
+            {
+                "continue": "prepare_tool_input",
+                "end": "format_final_response"
+            }
         )
-        builder.add_edge("tools", "agent")
+        builder.add_edge("prepare_tool_input", "tools")
+        builder.add_edge("tools", "update_state_after_tool")
+        builder.add_edge("update_state_after_tool", "agent")
+        builder.add_edge("format_final_response", END)
 
-        # Compile the graph with the checkpointer
         return builder.compile(checkpointer=self.checkpointer)
 
     def _should_continue(self, state: AgentState) -> str:
-        """判断是否需要继续调用工具。"""
+        """判断是继续调用工具还是结束。"""
         last_message = state["messages"][-1]
-        if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
-            return "end"
-        return "continue"
+        if hasattr(last_message, "tool_calls") and last_message.tool_calls:
+            return "continue"
+        return "end"
 
     def _agent_node(self, state: AgentState) -> Dict[str, Any]:
-        """Agent 节点:调用 LLM 进行思考和决策。"""
+        """Agent 节点:只负责调用 LLM 并返回其输出。"""
         logger.info(f"🧠 [Node] agent - Thread: {state['thread_id']}")
         
-        # LangChain automatically extracts messages from the state and passes them to the LLM
-        response = self.llm_with_tools.invoke(state["messages"])
-        logger.info(f"   LLM 返回: {response.pretty_print()}")
+        messages_for_llm = list(state["messages"])
+        if state.get("suggested_next_step"):
+            instruction = f"提示:建议下一步使用工具 '{state['suggested_next_step']}'。"
+            messages_for_llm.append(SystemMessage(content=instruction))
+
+        response = self.llm_with_tools.invoke(messages_for_llm)
+        logger.info(f"   LLM Response: {response.pretty_print()}")
+        
+        # Return only the message; this node has no other responsibilities
         return {"messages": [response]}
+    
+    def _prepare_tool_input_node(self, state: AgentState) -> Dict[str, Any]:
+        """
+        Context-assembly node: inject history messages for tools that need context.
+        """
+        logger.info(f"🛠️ [Node] prepare_tool_input - Thread: {state['thread_id']}")
+        
+        last_message = state["messages"][-1]
+        if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
+            return {"messages": [last_message]}
+
+        # Build a new tool_calls list with modified arguments; it is attached back to the message below
+        new_tool_calls = []
+        for tool_call in last_message.tool_calls:
+            if tool_call["name"] == "generate_sql":
+                logger.info("   检测到 generate_sql 调用,注入历史消息。")
+                # 复制一份以避免修改原始 tool_call
+                modified_args = tool_call["args"].copy()
+                
+                # Convert the message objects into a serializable list of dicts
+                serializable_history = []
+                for msg in state["messages"]:
+                    serializable_history.append({
+                        "type": msg.type,
+                        "content": msg.content
+                    })
+                
+                modified_args["history_messages"] = serializable_history
+                logger.info(f"   注入了 {len(serializable_history)} 条历史消息")
+                
+                new_tool_calls.append({
+                    "name": tool_call["name"],
+                    "args": modified_args,
+                    "id": tool_call["id"],
+                })
+            else:
+                new_tool_calls.append(tool_call)
+        
+        # Attach the tool_calls with modified arguments back onto the message
+        last_message.tool_calls = new_tool_calls
+        return {"messages": [last_message]}
+
+    def _update_state_after_tool_node(self, state: AgentState) -> Dict[str, Any]:
+        """在工具执行后,更新 suggested_next_step。"""
+        logger.info(f"📝 [Node] update_state_after_tool - Thread: {state['thread_id']}")
+        
+        last_tool_message = state['messages'][-1]
+        tool_name = last_tool_message.name
+        tool_output = last_tool_message.content
+        next_step = None
+
+        if tool_name == 'generate_sql':
+            if "失败" in tool_output or "无法生成" in tool_output:
+                next_step = 'answer_with_common_sense'
+            else:
+                next_step = 'valid_sql'
+        
+        elif tool_name == 'valid_sql':
+            if "失败" in tool_output:
+                next_step = 'analyze_validation_error'
+            else:
+                next_step = 'run_sql'
+
+        elif tool_name == 'run_sql':
+            next_step = 'summarize_final_answer'
+            
+        logger.info(f"   Tool '{tool_name}' executed. Suggested next step: {next_step}")
+        return {"suggested_next_step": next_step}
+
+    def _format_final_response_node(self, state: AgentState) -> Dict[str, Any]:
+        """最终输出格式化节点。"""
+        logger.info(f"🎨 [Node] format_final_response - Thread: {state['thread_id']}")
+        last_message = state['messages'][-1]
+        last_message.content = f"[Formatted Output]\n{last_message.content}"
+        return {"messages": [last_message]}
 
     async def chat(self, message: str, user_id: str, thread_id: Optional[str] = None) -> Dict[str, Any]:
         """
@@ -134,34 +231,29 @@ class CustomReactAgent:
         if not thread_id:
             thread_id = f"{user_id}:{pd.Timestamp.now().strftime('%Y%m%d%H%M%S%f')}"
             logger.info(f"🆕 新建会话,Thread ID: {thread_id}")
-
-        config = {"configurable": {"thread_id": thread_id}}
         
-        # Define the input
+        config = {
+            "configurable": {
+                "thread_id": thread_id,
+            }
+        }
+        
         inputs = {
             "messages": [HumanMessage(content=message)],
             "user_id": user_id,
             "thread_id": thread_id,
+            "suggested_next_step": None,
         }
 
-        final_state = None
         try:
-            logger.info(f"🔄 开始处理 - Thread: {thread_id}, User: {user_id}, Message: '{message}'")
-            # 使用 ainvoke 来执行完整的图流程
             final_state = await self.agent_executor.ainvoke(inputs, config)
-            
-            if final_state and final_state.get("messages"):
-                answer = final_state["messages"][-1].content
-                logger.info(f"✅ 处理完成 - Thread: {thread_id}, Final Answer: '{answer}'")
-                return {"success": True, "answer": answer, "thread_id": thread_id}
-            else:
-                 logger.error(f"❌ Processing ended abnormally; the final state is empty - Thread: {thread_id}")
-                 return {"success": False, "error": "Agent failed to produce a final answer.", "thread_id": thread_id}
-
+            answer = final_state["messages"][-1].content
+            logger.info(f"✅ 处理完成 - Final Answer: '{answer}'")
+            return {"success": True, "answer": answer, "thread_id": thread_id}
         except Exception as e:
             logger.error(f"❌ 处理过程中发生严重错误 - Thread: {thread_id}: {e}", exc_info=True)
             return {"success": False, "error": str(e), "thread_id": thread_id}
-
+    
     async def get_conversation_history(self, thread_id: str) -> List[Dict[str, Any]]:
         """从 checkpointer 获取指定线程的对话历史。"""
         if not self.checkpointer:

+ 109 - 0
test/custom_react_agent/community_help_request.md

@@ -0,0 +1,109 @@
+# Urgent help needed: double JSON encoding and serialization errors when passing complex, Chinese-containing arguments to tools in LangGraph
+
+Hi everyone,
+
+We are building a ReAct Agent on LangGraph's `StateGraph`, and its core requirement is to understand and use context across multi-turn conversations when generating SQL queries. To achieve this, we designed a flow that passes the full conversation history (`history_messages`) as an argument to our custom `generate_sql` tool.
+
+However, when the conversation history contains Chinese characters, we hit a stubborn `JSONDecodeError` in LangChain's tool-call serialization stage, apparently caused by double JSON encoding or improper `\uXXXX` escaping. We have tried several approaches, none of which worked. Any guidance from the community's experts would be greatly appreciated!
+
+---
+
+## 1. Core goal and architecture
+
+Our goal is to give the `generate_sql` tool access to the full conversation history, so that when it handles referential questions like "How is this service area doing?", it knows exactly which service area "this service area" refers to.
+
+我们的 `StateGraph` 设计如下:
+
+```mermaid
+graph TD
+    A[START] --> B(agent_node);
+    B --> C{Tool calls present?};
+    C -- yes --> D(prepare_tool_input_node);
+    C -- no --> G[END];
+    D --> E(tool_node);
+    E --> F(update_state_after_tool_node);
+    F --> B;
+```
+
+- **`prepare_tool_input_node`**: The key responsibility of this node is that, when it detects that `agent_node` has decided to call `generate_sql`, it extracts the full `messages` list from the `state` and injects it into that tool call's `args` as the `history_messages` parameter.
+
+## 2. The core problem
+
+Once `prepare_tool_input_node` has successfully injected `history_messages` (a list of dicts such as `{'type': 'human', 'content': '你好'}`), the graph crashes as execution continues.
+
+**The error log is as follows:**
+
+```
+Invalid Tool Calls:
+  generate_sql (call_e58f408879664da99cd18d)
+ Call ID: call_e58f408879664da99cd18d
+  Error: Function generate_sql arguments:
+
+{"question": "\u8bf7\u95ee\u8fd9\u4e2a\u9ad8\u901f...ff1f", "history_messages": [{"type": "human", "content": "\u8bf7\u95ee\u7cfb\u7edf..."}, ...]}
+
+are not valid JSON. Received JSONDecodeError Invalid \escape: line 1 column 1539 (char 1538)
+For troubleshooting, visit: https://python.langchain.com/docs/troubleshooting/errors/OUTPUT_PARSING_FAILURE
+```
+
+## 3. Our analysis
+
+As the log shows, when the `args` dict passed to `generate_sql` is serialized, every Chinese character is converted into the `\uXXXX` ASCII-escaped form.
+
+Our inference is that the root cause is that LangChain, when preparing to send the `tool_calls` to the LLM API or to process them internally, **serializes the `args` dict with `json.dumps(..., ensure_ascii=True)`**. When this already-encoded string, full of `\` escape characters, is later parsed as JSON again, the invalid `\u` escape sequences trigger the `JSONDecodeError`. It looks like a "double encoding" problem that we cannot easily configure away.
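+
+To make the escaping behavior concrete, here is a minimal, standalone illustration with plain `json` (outside LangChain) of the difference we believe is at play; the dict contents are just made-up sample values:
+
+```python
+import json
+
+args = {
+    "question": "请问这个服务区怎么样?",
+    "history_messages": [{"type": "human", "content": "你好"}],
+}
+
+# Default behavior: every non-ASCII character becomes a \uXXXX escape.
+print(json.dumps(args))
+
+# With ensure_ascii=False the characters stay readable.
+print(json.dumps(args, ensure_ascii=False))
+
+# A single dumps/loads round trip is still valid JSON; the failure we observe
+# needs a second pass somewhere that mangles the backslashes, leaving sequences
+# that json.loads rejects with "Invalid \escape", as in the log above.
+```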
+
+## 4. Approaches we have tried (all failed)
+
+1.  **Setting `ensure_ascii=False` when initializing `ChatOpenAI`**:
+    - We tried to influence the serialization behavior via `model_kwargs={"extra_body": {"misc": {"ensure_ascii": False}}}`.
+    - **Result**: this only affected how the final answer returned by the LLM is rendered; it **did not change how LangChain serializes the tool-call arguments**, and the error persisted.
+
+2.  **Decoding on the tool side**:
+    - We tried, inside the `generate_sql` function, to apply `json.loads()` or other forms of decoding to the received `history_messages` string.
+    - **Result**: failure. The error occurs in the serialization stage **before** LangChain calls our tool, so execution never reaches the inside of the tool function.
+
+## 5. The core questions we need help with
+
+We feel caught in a dilemma: to achieve context awareness we have to pass complex data to the tool, yet LangChain's serialization machinery does not seem to allow it, especially when the data contains non-ASCII characters.
+
+We would like to ask the community:
+
+1.  **What is the actual best practice in LangGraph for passing complex data structures (such as objects or lists of dicts) containing non-ASCII characters to a tool?**
+2.  **Is there any way to override or configure how `ToolNode` (or the machinery underneath it) serializes the `tool_calls` `args`, forcing it to use `ensure_ascii=False`?**
+3.  If this route is a dead end, is there a more elegant, recommended design pattern for the common scenario of "a tool that needs the full conversation history"? (For example, besides the "argument injection" pattern we are currently trying, what other possibilities exist? A rough sketch of the kind of alternative we mean follows below.)
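+
+For question 3, one possibility we have wondered about (but not yet tested) is reading the history from the graph state inside the tool rather than serializing it into the tool-call args, for example via langgraph's `InjectedState` annotation, if our version supports it. A rough sketch of what we mean:
+
+```python
+from typing import Annotated, Any, Dict
+
+from langchain_core.tools import tool
+from langgraph.prebuilt import InjectedState
+
+
+@tool
+def generate_sql(
+    question: str,
+    # Arguments annotated with InjectedState are filled in by ToolNode from the
+    # graph state and are hidden from the LLM, so they would never go through
+    # the tool-call args serialization that is currently failing.
+    state: Annotated[Dict[str, Any], InjectedState],
+) -> str:
+    """Generate SQL from the question plus the conversation history taken from state."""
+    history = state.get("messages", [])
+    context = "\n".join(f"{m.type}: {m.content}" for m in history)
+    # ... build the enriched prompt from `context` and call Vanna as in our current implementation ...
+    return f"-- SQL for: {question} (derived from {len(history)} history messages)"
+```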
+
+---
+
+### Appendix: key code snippets
+
+**`_prepare_tool_input_node`**:
+```python
+def _prepare_tool_input_node(self, state: AgentState) -> Dict[str, Any]:
+    last_message = state["messages"][-1]
+    if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
+        return {}
+    
+    new_tool_calls = []
+    for tool_call in last_message.tool_calls:
+        if tool_call["name"] == "generate_sql":
+            # Convert the message objects into a serializable list of dicts
+            serializable_history = [
+                {"type": msg.type, "content": msg.content} 
+                for msg in state["messages"]
+            ]
+            
+            modified_args = tool_call["args"].copy()
+            modified_args["history_messages"] = serializable_history
+            new_tool_calls.append({
+                "name": tool_call["name"],
+                "args": modified_args,
+                "id": tool_call["id"],
+            })
+        else:
+            new_tool_calls.append(tool_call)
+    
+    last_message.tool_calls = new_tool_calls
+    return {"messages": [last_message]}
+```
+
+**`generate_sql` tool signature**:
+```python
+@tool
+def generate_sql(question: str, history_messages: List[Dict[str, Any]]) -> str:
+    # ...
+```
+
+Any suggestions or ideas would be a huge help to us. Thanks in advance!

+ 40 - 20
test/custom_react_agent/sql_tools.py

@@ -5,50 +5,70 @@ import re
 import json
 import logging
 from langchain_core.tools import tool
+from pydantic.v1 import BaseModel, Field
+from typing import List, Dict, Any
 import pandas as pd
 
 logger = logging.getLogger(__name__)
 
-# --- 工具函数 ---
+# --- Pydantic Schema for Tool Arguments ---
 
-@tool
-def generate_sql(question: str) -> str:
-    """
-    根据用户问题生成SQL查询语句。
+class GenerateSqlArgs(BaseModel):
+    """Input schema for the generate_sql tool."""
+    question: str = Field(description="The user's question to be converted to SQL.")
+    history_messages: List[Dict[str, Any]] = Field(
+        default=[],
+        description="The conversation history messages for context."
+    )
 
-    Args:
-        question: 用户的原始问题。
+# --- Tool Functions ---
 
-    Returns:
-        生成的SQL语句或错误信息。
+@tool(args_schema=GenerateSqlArgs)
+def generate_sql(question: str, history_messages: List[Dict[str, Any]] = None) -> str:
+    """
+    Generates an SQL query based on the user's question and the conversation history.
     """
-    logger.info(f"🔧 [Tool] generate_sql - 问题: '{question}'")
+    logger.info(f"🔧 [Tool] generate_sql - Question: '{question}'")
+    
+    if history_messages is None:
+        history_messages = []
+    
+    logger.info(f"   History contains {len(history_messages)} messages.")
+
+    # Combine history and the current question to form a rich prompt
+    history_str = "\n".join([f"{msg['type']}: {msg.get('content', '') or ''}" for msg in history_messages])
+    enriched_question = f"""Based on the following conversation history:
+---
+{history_str}
+---
+
+Please provide an SQL query that answers this specific question: {question}"""
 
     try:
         from common.vanna_instance import get_vanna_instance
         vn = get_vanna_instance()
-        sql = vn.generate_sql(question)
+        sql = vn.generate_sql(enriched_question)
 
         if not sql or sql.strip() == "":
             if hasattr(vn, 'last_llm_explanation') and vn.last_llm_explanation:
                 error_info = vn.last_llm_explanation
-                logger.warning(f"   Vanna返回了错误解释: {error_info}")
-                return f"数据库查询失败,具体原因:{error_info}"
+                logger.warning(f"   Vanna returned an explanation instead of SQL: {error_info}")
+                return f"Database query failed. Reason: {error_info}"
             else:
-                logger.warning("   Vanna未能生成SQL且无解释。")
-                return "无法生成SQL:问题可能不适合数据库查询"
+                logger.warning("   Vanna failed to generate SQL and provided no explanation.")
+                return "Could not generate SQL: The question may not be suitable for a database query."
 
         sql_upper = sql.upper().strip()
         if not any(keyword in sql_upper for keyword in ['SELECT', 'WITH']):
-            logger.warning(f"   Vanna返回了疑似错误信息而非SQL: {sql}")
-            return f"数据库查询失败,具体原因:{sql}"
+            logger.warning(f"   Vanna returned a message that does not appear to be a valid SQL query: {sql}")
+            return f"Database query failed. Reason: {sql}"
 
-        logger.info(f"   ✅ 成功生成SQL: {sql}")
+        logger.info(f"   ✅ SQL Generated Successfully: {sql}")
         return sql
 
     except Exception as e:
-        logger.error(f"   SQL生成过程中发生异常: {e}", exc_info=True)
-        return f"SQL生成失败: {str(e)}"
+        logger.error(f"   An exception occurred during SQL generation: {e}", exc_info=True)
+        return f"SQL generation failed: {str(e)}"
 
 @tool
 def valid_sql(sql: str) -> str:

+ 3 - 1
test/custom_react_agent/state.py

@@ -13,7 +13,9 @@ class AgentState(TypedDict):
         messages: The list of conversation messages, aggregated with add_messages.
         user_id: The current user's ID.
         thread_id: The thread ID of the current conversation.
+        suggested_next_step: A suggested instruction to guide the LLM's next action.
     """
     messages: Annotated[List[BaseMessage], add_messages]
     user_id: str
-    thread_id: str 
+    thread_id: str
+    suggested_next_step: Optional[str]