Forráskód Böngészése

准备重构/test/custom_react_agent下的代码,已经完成了使用StateGraph替代create_react_agent()的方法,增加输出和中间处理节点。

wangxq 1 hónapja
szülő
commit
a415b7fe82

A különbségek nem kerülnek megjelenítésre, a fájl túl nagy
+ 864 - 0
test/4. LangGraph 实现自治循环代理(ReAct)及事件流的应用.ipynb


+ 1 - 0
test/__init__.py

@@ -0,0 +1 @@
+# This file makes the 'test' directory a Python package. 

+ 1 - 0
test/custom_react_agent/__init__.py

@@ -0,0 +1 @@
+# This file makes the 'custom_react_agent' directory a Python package. 

+ 190 - 0
test/custom_react_agent/agent.py

@@ -0,0 +1,190 @@
+"""
+基于 StateGraph 的、具备上下文感知能力的 React Agent 核心实现
+"""
+import logging
+import json
+import pandas as pd
+from typing import List, Optional, Dict, Any, Tuple
+from contextlib import AsyncExitStack
+
+from langchain_openai import ChatOpenAI
+from langchain_core.messages import HumanMessage, ToolMessage, BaseMessage
+from langgraph.graph import StateGraph, END
+from langgraph.prebuilt import ToolNode
+from redis.asyncio import Redis
+try:
+    from langgraph.checkpoint.redis import AsyncRedisSaver
+except ImportError:
+    AsyncRedisSaver = None
+
+# 从新模块导入配置、状态和工具
+from . import config
+from .state import AgentState
+from .sql_tools import sql_tools
+
+logger = logging.getLogger(__name__)
+
+class CustomReactAgent:
+    """
+    一个使用 StateGraph 构建的、具备上下文感知和持久化能力的 Agent。
+    """
+    def __init__(self):
+        """私有构造函数,请使用 create() 类方法来创建实例。"""
+        self.llm = None
+        self.tools = None
+        self.agent_executor = None
+        self.checkpointer = None
+        self._exit_stack = None
+
+    @classmethod
+    async def create(cls):
+        """异步工厂方法,创建并初始化 CustomReactAgent 实例。"""
+        instance = cls()
+        await instance._async_init()
+        return instance
+
+    async def _async_init(self):
+        """异步初始化所有组件。"""
+        logger.info("🚀 开始初始化 CustomReactAgent...")
+
+        # 1. 初始化 LLM
+        self.llm = ChatOpenAI(
+            api_key=config.QWEN_API_KEY,
+            base_url=config.QWEN_BASE_URL,
+            model=config.QWEN_MODEL,
+            temperature=0.1,
+            extra_body={"enable_thinking": False}
+        )
+        logger.info(f"   LLM 已初始化,模型: {config.QWEN_MODEL}")
+
+        # 2. 绑定工具
+        self.tools = sql_tools
+        self.llm_with_tools = self.llm.bind_tools(self.tools)
+        logger.info(f"   已绑定 {len(self.tools)} 个工具。")
+
+        # 3. 初始化 Redis Checkpointer
+        if config.REDIS_ENABLED and AsyncRedisSaver is not None:
+            try:
+                self._exit_stack = AsyncExitStack()
+                checkpointer_manager = AsyncRedisSaver.from_conn_string(config.REDIS_URL)
+                self.checkpointer = await self._exit_stack.enter_async_context(checkpointer_manager)
+                await self.checkpointer.asetup()
+                logger.info(f"   AsyncRedisSaver 持久化已启用: {config.REDIS_URL}")
+            except Exception as e:
+                logger.error(f"   ❌ RedisSaver 初始化失败: {e}", exc_info=True)
+                if self._exit_stack:
+                    await self._exit_stack.aclose()
+                self.checkpointer = None
+        else:
+            logger.warning("   Redis 持久化功能已禁用。")
+
+        # 4. 构建 StateGraph
+        self.agent_executor = self._create_graph()
+        logger.info("   StateGraph 已构建并编译。")
+        logger.info("✅ CustomReactAgent 初始化完成。")
+
+    async def close(self):
+        """清理资源,关闭 Redis 连接。"""
+        if self._exit_stack:
+            await self._exit_stack.aclose()
+            self._exit_stack = None
+            self.checkpointer = None
+            logger.info("✅ RedisSaver 资源已通过 AsyncExitStack 释放。")
+
+    def _create_graph(self):
+        """定义并编译 StateGraph。"""
+        builder = StateGraph(AgentState)
+
+        # 定义节点
+        builder.add_node("agent", self._agent_node)
+        builder.add_node("tools", ToolNode(self.tools))
+
+        # 定义边
+        builder.set_entry_point("agent")
+        builder.add_conditional_edges(
+            "agent",
+            self._should_continue,
+            {"continue": "tools", "end": END}
+        )
+        builder.add_edge("tools", "agent")
+
+        # 编译图,并传入 checkpointer
+        return builder.compile(checkpointer=self.checkpointer)
+
+    def _should_continue(self, state: AgentState) -> str:
+        """判断是否需要继续调用工具。"""
+        last_message = state["messages"][-1]
+        if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
+            return "end"
+        return "continue"
+
+    def _agent_node(self, state: AgentState) -> Dict[str, Any]:
+        """Agent 节点:调用 LLM 进行思考和决策。"""
+        logger.info(f"🧠 [Node] agent - Thread: {state['thread_id']}")
+        
+        # LangChain 会自动从 state 中提取 messages 传递给 LLM
+        response = self.llm_with_tools.invoke(state["messages"])
+        logger.info(f"   LLM 返回: {response.pretty_print()}")
+        return {"messages": [response]}
+
+    async def chat(self, message: str, user_id: str, thread_id: Optional[str] = None) -> Dict[str, Any]:
+        """
+        处理用户聊天请求。
+        """
+        if not thread_id:
+            thread_id = f"{user_id}:{pd.Timestamp.now().strftime('%Y%m%d%H%M%S%f')}"
+            logger.info(f"🆕 新建会话,Thread ID: {thread_id}")
+
+        config = {"configurable": {"thread_id": thread_id}}
+        
+        # 定义输入
+        inputs = {
+            "messages": [HumanMessage(content=message)],
+            "user_id": user_id,
+            "thread_id": thread_id,
+        }
+
+        final_state = None
+        try:
+            logger.info(f"🔄 开始处理 - Thread: {thread_id}, User: {user_id}, Message: '{message}'")
+            # 使用 ainvoke 来执行完整的图流程
+            final_state = await self.agent_executor.ainvoke(inputs, config)
+            
+            if final_state and final_state.get("messages"):
+                answer = final_state["messages"][-1].content
+                logger.info(f"✅ 处理完成 - Thread: {thread_id}, Final Answer: '{answer}'")
+                return {"success": True, "answer": answer, "thread_id": thread_id}
+            else:
+                 logger.error(f"❌ 处理异常结束,最终状态为空 - Thread: {thread_id}")
+                 return {"success": False, "error": "Agent failed to produce a final answer.", "thread_id": thread_id}
+
+        except Exception as e:
+            logger.error(f"❌ 处理过程中发生严重错误 - Thread: {thread_id}: {e}", exc_info=True)
+            return {"success": False, "error": str(e), "thread_id": thread_id}
+
+    async def get_conversation_history(self, thread_id: str) -> List[Dict[str, Any]]:
+        """从 checkpointer 获取指定线程的对话历史。"""
+        if not self.checkpointer:
+            return []
+        
+        config = {"configurable": {"thread_id": thread_id}}
+        conversation_state = await self.checkpointer.get(config)
+        
+        if not conversation_state:
+            return []
+            
+        history = []
+        for msg in conversation_state['values'].get('messages', []):
+            if isinstance(msg, HumanMessage):
+                role = "human"
+            elif isinstance(msg, ToolMessage):
+                role = "tool"
+            else: # AIMessage
+                role = "ai"
+            
+            history.append({
+                "type": role,
+                "content": msg.content,
+                "tool_calls": getattr(msg, 'tool_calls', None)
+            })
+        return history 

+ 26 - 0
test/custom_react_agent/config.py

@@ -0,0 +1,26 @@
+"""
+全局配置文件
+"""
+import os
+import logging
+
+# --- 项目根目录 ---
+# /test/custom_react_agent/config.py -> /
+PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+# --- LLM 配置 ---
+# 在这里写死你的千问API Key
+QWEN_API_KEY = "sk-db68e37f00974031935395315bfe07f0"
+QWEN_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
+QWEN_MODEL = "qwen3-235b-a22b"
+
+# --- Redis 配置 ---
+REDIS_URL = "redis://localhost:6379"
+REDIS_ENABLED = True
+
+# --- 日志配置 ---
+LOG_LEVEL = logging.INFO
+LOG_FORMAT = '%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s'
+
+# --- Agent 配置 ---
+DEFAULT_USER_ID = "default-user" 

+ 112 - 0
test/custom_react_agent/redesign_summary.md

@@ -0,0 +1,112 @@
+# Custom React Agent 重构概要设计
+
+本文档总结了将原有基于 `create_react_agent` 的 Demo 重构为使用 `StateGraph` 的、具备强大上下文处理能力和流程控制能力的新版 Agent 的概要设计。
+
+## 1. 重构核心目标
+
+- **解决上下文遗忘问题**:确保 Agent 在多轮对话中,尤其是在连续调用 SQL 相关工具时,能够理解并利用之前的对话历史(如上文提到的实体“南城服务区”)。
+- **增强流程控制能力**:对 `generate_sql -> valid_sql -> run_sql` 这一固定流程进行强力引导,防止 LLM“忘记”执行下一步或执行错误,提高 Agent 的可靠性和可预测性。
+- **提升代码健壮性与可维护性**:通过模块化和清晰的职责划分,使代码更易于理解、调试和扩展。
+
+## 2. 最终 `StateGraph` 架构
+
+新架构的核心是一个包含 5 个节点的 `StateGraph`,它取代了原有的 `create_react_agent` 黑盒。
+
+![StateGraph Flow](https://mermaid.ink/img/pako:eNqNVc1qwzAQ_Zde5ZCHLHiVIqUEKjRAG3pQDxuCjM1OsbSwlGTqQIL__SrZdpqkh9zuzs5-d3eDCbygCmoy7dFm9JOv_Qz9Icvob7fso-996es_yz_1b-H4fTTX9e5rWz4etrutRgzHlDxJg2b3OYjSLnoa8HgrdfuRjQZH9r6g_FXd75LpwzT_vNX_8Cufhsnvbnl4-8Xu_sKnmHa2sCX_o7Ud9PsKOluN9J_a_ZOhc179yuVmyvqd7dv_Lltf9tFbifp4eH_XeIkcftEElv_V9N7webczuFf9jOqkehZQPgeZEtbyvMBtyZNO0M2PXqP6y_NfP9V5iZF6SOpPFOlZ0FuKUOtnOcv2jiGeth-PqCKvaNHdzQGfEPSeNRe3Nu4iqtuSPYthf0vnddOZhYzunvZ0uI9k-drffxtMfuLPTNd6u3eCXdOE409UttVPisR5WOY9ZgvtmvFvUYfzdetaulpePqPTPvO9nxNI9yd7VdqNxrzfN8OU-8O4Dqn-kOi9rv9C4EfHi_VfKOG-Y9-6tN7TvFZTy3-q3-m979c_fX_XO995f5EuPnfgt8l4U3I?type=png)
+
+```mermaid
+graph TD
+    A[START] --> B(agent_node);
+    B --> C{有工具调用?};
+    C -- 是 --> D(prepare_tool_input_node);
+    C -- 否 --> G(format_final_response_node);
+    D --> E(tool_node);
+    E --> F(update_state_after_tool_node);
+    F --> B;
+    G --> H[END];
+```
+
+### 2.1. 节点职责
+
+- **`agent_node` (决策者)**
+  - **输入**: 完整的 `state`,包含 `messages` 历史和 `suggested_next_step`。
+  - **职责**:
+    1.  读取完整的对话历史。
+    2.  读取 `state.suggested_next_step` 作为强烈的行动建议 (例如: `valid_sql`, `run_sql`, `analyze_error`)。
+    3.  通过提示工程,将建议和历史结合,让 LLM 做出决策。
+    4.  **输出**: 一个“草稿版”的 `tool_calls`,或决定直接回答的 `AIMessage`。
+
+- **`prepare_tool_input_node` (信息组装者)** - **(新增节点)**
+  - **位置**: `agent_node` 之后, `tool_node` 之前。
+  - **职责**:
+    1.  检查 `agent_node` 输出的 `tool_calls`。
+    2.  如果发现需要上下文的工具(如 `generate_sql`),则从 `state.messages` 中提取完整的对话历史。
+    3.  将提取的历史作为 `history_messages` 参数,**注入**到 `tool_calls` 的 `args` 中。
+  - **输出**: 一个“精装版”的、包含了完整上下文信息的 `tool_calls`。
+
+- **`tool_node` (执行者)**
+  - **职责**: 接收“精装版”的 `tool_calls`,并忠实地调用 `sql_tools.py` 中的工具函数。
+
+- **`update_state_after_tool_node` (流程建议与错误处理器)** - **(新增节点)**
+  - **位置**: `tool_node` 之后。
+  - **职责**:
+    1.  检查刚刚执行的工具名称及其返回结果(成功/失败)。
+    2.  根据预设的逻辑,智能地更新 `state.suggested_next_step` 字段,以精确引导下一步:
+        - **`generate_sql` 成功**: `suggested_next_step` -> `"valid_sql"`
+        - **`generate_sql` 失败**: `suggested_next_step` -> `"answer_with_common_sense"` (引导LLM基于常识回答或向用户解释)
+        - **`valid_sql` 成功**: `suggested_next_step` -> `"run_sql"`
+        - **`valid_sql` 失败**: `suggested_next_step` -> `"analyze_validation_error"` (引导LLM分析错误原因)
+        - **`run_sql` 执行后**: `suggested_next_step` -> `"summarize_final_answer"` (引导LLM基于数据总结)
+  - **输出**: 更新后的 `state`。
+
+- **`format_final_response_node` (最终输出格式化器)** - **(新增节点)**
+  - **位置**: 在 `agent_node` 决定直接回答后,图结束前。
+  - **职责 (v1 - 占位)**:
+    1.  **当前阶段**: 仅作为流程占位符,证明流程已正确进入此节点。
+    2.  在日志中打印一条明确的信息,如 `"[Node] format_final_response - 准备格式化最终输出..."`。
+  - **职责 (未来)**:
+    1.  从 `state` 中提取 LLM 的最终文字总结和最近一次 `run_sql` 的数据(如果存在)。
+    2.  将数据格式化为 Markdown 表格。
+    3.  将文字总结和数据表格合并成一个对用户友好的、结构化的最终答案。
+  - **输出**: 更新 `state` 中最后一条 `AIMessage` 的内容。
+
+## 3. `AgentState` 状态设计
+
+`state.py` 文件将定义 `StateGraph` 中流转的数据结构。
+
+```python
+from typing import TypedDict, Annotated, Optional, List
+from langchain_core.messages import BaseMessage
+
+class AgentState(TypedDict):
+    messages: Annotated[List[BaseMessage], add_messages]
+    user_id: str
+    thread_id: str
+    # 新增字段,用于引导 LLM 的下一步行动
+    suggested_next_step: Optional[str]
+```
+
+- **`messages`**: 核心字段,存储完整的、包含 `HumanMessage`, `AIMessage`, `ToolMessage` 的对话历史。
+- **`suggested_next_step`**: 流程控制的关键。它由 `update_state_after_tool_node` 写入,由 `agent_node` 读取,为 LLM 提供强力的流程引导。
+
+## 4. 工具签名与实现
+
+- **`sql_tools.py`**:
+  - `generate_sql(question: str, history_messages: List[BaseMessage]) -> str`:
+    - 修改其函数签名,明确要求传入 `history_messages`。
+    - 在其内部,将 `question` 和 `history_messages` 组合成更丰富的提示,再交给 Vanna 的 LLM 进行处理,从而解决上下文理解问题。
+  - `valid_sql` 和 `run_sql` 保持简单的输入输出。
+
+## 5. 日志与持久化
+
+- **日志**: 使用 Python 内置的 `logging` 模块,由 `config.py` 控制级别。在每个节点的入口和出口、关键的逻辑判断处打印详细日志,以便清晰地追踪 Agent 的思考和执行链路。
+- **持久化**: 完全复用并保留原有的 `AsyncRedisSaver` 机制。`CustomReactAgent` 在初始化时创建 `checkpointer`,并在编译 `StateGraph` 时传入,以实现自动的状态持久化。
+
+## 6. 优势总结
+
+1.  **双重上下文保障**:
+    - **数据上下文**: 通过 `prepare_tool_input_node` 确保 `generate_sql` 能获取完整的对话历史。
+    - **流程上下文**: 通过 `update_state_after_tool_node` 和 `suggested_next_step` 确保 Agent 遵循预设的执行流程。
+2.  **职责分离**: 每个节点职责单一(决策、准备数据、执行、更新状态),代码清晰,易于维护。
+3.  **高度可控与可预测**: 在给予 LLM 思考空间的同时,通过代码逻辑保证了核心流程的稳定性和可靠性。
+4.  **易于调试**: 详细的日志输出将使追踪和定位问题变得非常简单。 

+ 128 - 0
test/custom_react_agent/shell.py

@@ -0,0 +1,128 @@
+"""
+重构后的 CustomReactAgent 的交互式命令行客户端
+"""
+import asyncio
+import logging
+import sys
+import os
+
+# 动态地将项目根目录添加到 sys.path,以支持跨模块导入
+# 这使得脚本更加健壮,无论从哪里执行
+PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
+sys.path.insert(0, PROJECT_ROOT)
+
+# 从新模块导入 Agent 和配置 (使用相对导入)
+from .agent import CustomReactAgent
+from . import config
+
+# 配置日志
+logging.basicConfig(level=config.LOG_LEVEL, format=config.LOG_FORMAT)
+logger = logging.getLogger(__name__)
+
+class CustomAgentShell:
+    """新 Agent 的交互式 Shell 客户端"""
+
+    def __init__(self, agent: CustomReactAgent):
+        """私有构造函数,请使用 create() 类方法。"""
+        self.agent = agent
+        self.user_id: str = config.DEFAULT_USER_ID
+        self.thread_id: str | None = None
+
+    @classmethod
+    async def create(cls):
+        """异步工厂方法,创建 Shell 实例。"""
+        agent = await CustomReactAgent.create()
+        return cls(agent)
+
+    async def close(self):
+        """关闭 Agent 资源。"""
+        if self.agent:
+            await self.agent.close()
+
+    async def start(self):
+        """启动 Shell 界面。"""
+        print("\n🚀 Custom React Agent Shell (StateGraph Version)")
+        print("=" * 50)
+        
+        # 获取用户ID
+        user_input = input(f"请输入您的用户ID (默认: {self.user_id}): ").strip()
+        if user_input:
+            self.user_id = user_input
+        
+        print(f"👤 当前用户: {self.user_id}")
+        # 这里可以加入显示历史会话的逻辑
+        
+        print("\n💬 开始对话 (输入 'exit' 或 'quit' 退出)")
+        print("-" * 50)
+        
+        await self._chat_loop()
+
+    async def _chat_loop(self):
+        """主要的聊天循环。"""
+        while True:
+            user_input = input(f"👤 [{self.user_id[:8]}]> ").strip()
+            
+            if not user_input:
+                continue
+            
+            if user_input.lower() in ['quit', 'exit']:
+                raise KeyboardInterrupt  # 优雅退出
+            
+            if user_input.lower() == 'new':
+                self.thread_id = None
+                print("🆕 已开始新会话。")
+                continue
+
+            if user_input.lower() == 'history':
+                await self._show_current_history()
+                continue
+            
+            # 正常对话
+            print("🤖 Agent 正在思考...")
+            result = await self.agent.chat(user_input, self.user_id, self.thread_id)
+            
+            if result.get("success"):
+                print(f"🤖 Agent: {result.get('answer')}")
+                # 更新 thread_id 以便在同一会话中继续
+                self.thread_id = result.get("thread_id")
+            else:
+                print(f"❌ 发生错误: {result.get('error')}")
+
+    async def _show_current_history(self):
+        """显示当前会话的历史记录。"""
+        if not self.thread_id:
+            print("当前没有活跃的会话。请先开始对话。")
+            return
+        
+        print(f"\n--- 对话历史: {self.thread_id} ---")
+        history = await self.agent.get_conversation_history(self.thread_id)
+        if not history:
+            print("无法获取历史或历史为空。")
+            return
+            
+        for msg in history:
+            print(f"[{msg['type']}] {msg['content']}")
+        print("--- 历史结束 ---")
+
+
+async def main():
+    """主函数入口"""
+    shell = None
+    try:
+        shell = await CustomAgentShell.create()
+        await shell.start()
+    except KeyboardInterrupt:
+        logger.info("\n👋 检测到退出指令,正在清理资源...")
+    except Exception as e:
+        logger.error(f"❌ 程序发生严重错误: {e}", exc_info=True)
+    finally:
+        if shell:
+            await shell.close()
+        print("✅ 程序已成功关闭。")
+
+if __name__ == "__main__":
+    try:
+        asyncio.run(main())
+    except KeyboardInterrupt:
+        # 这个捕获是为了处理在 main 之外的 Ctrl+C
+        print("\n👋 程序被强制退出。") 

+ 119 - 0
test/custom_react_agent/sql_tools.py

@@ -0,0 +1,119 @@
+"""
+数据库查询相关的工具集
+"""
+import re
+import json
+import logging
+from langchain_core.tools import tool
+import pandas as pd
+
+logger = logging.getLogger(__name__)
+
+# --- 工具函数 ---
+
+@tool
+def generate_sql(question: str) -> str:
+    """
+    根据用户问题生成SQL查询语句。
+
+    Args:
+        question: 用户的原始问题。
+
+    Returns:
+        生成的SQL语句或错误信息。
+    """
+    logger.info(f"🔧 [Tool] generate_sql - 问题: '{question}'")
+
+    try:
+        from common.vanna_instance import get_vanna_instance
+        vn = get_vanna_instance()
+        sql = vn.generate_sql(question)
+
+        if not sql or sql.strip() == "":
+            if hasattr(vn, 'last_llm_explanation') and vn.last_llm_explanation:
+                error_info = vn.last_llm_explanation
+                logger.warning(f"   Vanna返回了错误解释: {error_info}")
+                return f"数据库查询失败,具体原因:{error_info}"
+            else:
+                logger.warning("   Vanna未能生成SQL且无解释。")
+                return "无法生成SQL:问题可能不适合数据库查询"
+
+        sql_upper = sql.upper().strip()
+        if not any(keyword in sql_upper for keyword in ['SELECT', 'WITH']):
+            logger.warning(f"   Vanna返回了疑似错误信息而非SQL: {sql}")
+            return f"数据库查询失败,具体原因:{sql}"
+
+        logger.info(f"   ✅ 成功生成SQL: {sql}")
+        return sql
+
+    except Exception as e:
+        logger.error(f"   SQL生成过程中发生异常: {e}", exc_info=True)
+        return f"SQL生成失败: {str(e)}"
+
+@tool
+def valid_sql(sql: str) -> str:
+    """
+    验证SQL语句的正确性和安全性。
+
+    Args:
+        sql: 待验证的SQL语句。
+
+    Returns:
+        验证结果。
+    """
+    logger.info(f"🔧 [Tool] valid_sql - 待验证SQL (前100字符): {sql[:100]}...")
+
+    if not sql or sql.strip() == "":
+        logger.warning("   SQL验证失败:SQL语句为空。")
+        return "SQL验证失败:SQL语句为空"
+
+    sql_upper = sql.upper().strip()
+    if not any(keyword in sql_upper for keyword in ['SELECT', 'WITH']):
+         logger.warning(f"   SQL验证失败:不是有效的查询语句。SQL: {sql}")
+         return "SQL验证失败:不是有效的查询语句"
+    
+    # 简单的安全检查
+    dangerous_patterns = [r'\bDROP\b', r'\bDELETE\b', r'\bTRUNCATE\b', r'\bALTER\b', r'\bCREATE\b', r'\bUPDATE\b']
+    for pattern in dangerous_patterns:
+        if re.search(pattern, sql_upper):
+            keyword = pattern.replace(r'\b', '').replace('\\', '')
+            logger.error(f"   SQL验证失败:包含危险操作 {keyword}。SQL: {sql}")
+            return f"SQL验证失败:包含危险操作 {keyword}"
+
+    logger.info(f"   ✅ SQL验证通过。")
+    return "SQL验证通过:语法正确"
+
+@tool
+def run_sql(sql: str) -> str:
+    """
+    执行SQL查询并以JSON字符串格式返回结果。
+
+    Args:
+        sql: 待执行的SQL语句。
+
+    Returns:
+        JSON字符串格式的查询结果,或包含错误的JSON字符串。
+    """
+    logger.info(f"🔧 [Tool] run_sql - 待执行SQL (前100字符): {sql[:100]}...")
+
+    try:
+        from common.vanna_instance import get_vanna_instance
+        vn = get_vanna_instance()
+        df = vn.run_sql(sql)
+
+        if df is None:
+            logger.warning("   SQL执行成功,但查询结果为空。")
+            result = {"status": "success", "data": [], "message": "查询无结果"}
+            return json.dumps(result, ensure_ascii=False)
+
+        logger.info(f"   ✅ SQL执行成功,返回 {len(df)} 条记录。")
+        # 将DataFrame转换为JSON,并妥善处理datetime等特殊类型
+        return df.to_json(orient='records', date_format='iso')
+
+    except Exception as e:
+        logger.error(f"   SQL执行过程中发生异常: {e}", exc_info=True)
+        error_result = {"status": "error", "error_message": str(e)}
+        return json.dumps(error_result, ensure_ascii=False)
+
+# 将所有工具函数收集到一个列表中,方便Agent导入和使用
+sql_tools = [generate_sql, valid_sql, run_sql] 

+ 19 - 0
test/custom_react_agent/state.py

@@ -0,0 +1,19 @@
+"""
+定义 StateGraph 的状态
+"""
+from typing import TypedDict, Annotated, Optional, List, Dict, Any
+from langchain_core.messages import BaseMessage
+from langgraph.graph.message import add_messages
+
+class AgentState(TypedDict):
+    """
+    StateGraph 中流转的状态对象定义。
+
+    Attributes:
+        messages: 对话消息列表,使用 add_messages 聚合。
+        user_id: 当前用户ID。
+        thread_id: 当前会话的线程ID。
+    """
+    messages: Annotated[List[BaseMessage], add_messages]
+    user_id: str
+    thread_id: str 

+ 97 - 0
test/custom_react_agent/参考方案.md

@@ -0,0 +1,97 @@
+明白了,我将为你准备一个基于 StateGraph 重构的方案,保留 ReAct 模型结构(即 LangChain Agent 使用 tools 推理的能力),同时确保:
+
+* 使用 RedisSaver 进行异步 checkpoint 和历史追踪
+* 每个节点都能访问上下文(如 user\_id、thread\_id)
+* 明确 thread\_id 作为会话隔离标识
+* 保留工具调用逻辑和错误处理能力
+* 结构上为向生产过渡做准备,如拆出 config.py 和增强日志输出
+
+我将整理一份详细的改造建议和新的代码框架提案,请稍等,我整理好后马上给你。
+
+
+# 重构整体架构
+
+* **模块划分**:引入 `config.py` 管理全局配置(模型名称、`redis_url`、日志等级等),主逻辑放在如 `qwen3_agent.py` 的模块中,将 Graph 定义、Agent 类等拆分。`client_api.py` 与 `shell.py` 保持调用接口不变(`get_conversation_history`、`get_user_conversations` 等),只是内部调用新 Agent 类即可。
+* **状态定义**:为 StateGraph 定义自定义状态(如 `ChatState`),包括至少 `messages`(保存对话列表)、`thread_id`、`user_id` 等字段;可使用 `typing.Annotated` 和 `add_messages` 注解以维护消息列表。例如:`class ChatState(TypedDict): messages: Annotated[List[BaseMessage], add_messages]; thread_id: str; user_id: str`。。这样所有节点(包括工具函数)都可访问上下文信息;必要时可在工具函数参数中使用 `Annotated[..., InjectedState]` 注解注入状态字段。
+* **配置文件**:`config.py` 示例:
+
+  ```python
+  MODEL_NAME = "qwen3-235b-a22b"
+  REDIS_URL = "redis://localhost:6379"
+  REDIS_ENABLED = True
+  LOG_LEVEL = logging.INFO
+  ```
+
+  主文件中读取这些常量来初始化模型、Redis、日志等配置。
+
+## Graph 定义流程示意
+
+重构后使用 LangGraph 的 `StateGraph` 明确描述 ReAct 流程。基本流程为:**用户输入→LLM(助手)思考→根据需要调用工具→工具返回结果→LLM 继续思考→…→最终输出答案**。可参考如下伪代码流程:
+
+```python
+from langgraph.graph import StateGraph, START
+from langgraph.prebuilt import ToolNode, tools_condition
+
+# 创建 StateGraph,指定状态类型 ChatState
+builder = StateGraph(ChatState)
+
+# 节点:assistant 调用 LLM(绑定工具)
+def assistant_node(state: ChatState) -> dict:
+    # 调用绑定工具的模型,输入当前消息列表
+    response = llm.bind_tools(tools).invoke(state["messages"])
+    return {"messages": response}
+
+builder.add_node("assistant", assistant_node)
+builder.add_node("tools", ToolNode(tools))  # 工具节点
+
+# 边:开始进入 assistant 节点
+builder.add_edge(START, "assistant")
+# 如果 assistant 输出包含工具调用,则流转到 tools 节点,否则结束
+builder.add_conditional_edges(
+    "assistant",
+    tools_condition  # 有工具调用则进 tools,否则结束
+)
+# tools 处理后回到 assistant 节点(形成循环)
+builder.add_edge("tools", "assistant")
+
+# 编译 StateGraph(稍后传入 checkpointer)
+graph = builder.compile()
+```
+
+如示例所示,**ReAct 图** 有两个核心节点:“assistant” 节点用于调用模型并产生 `ToolCall`;“tools” 节点用于并行执行这些工具调用。通过 `add_conditional_edges` 将 **assistant→tools** 或 **assistant→END** 的流转条件化(`tools_condition` 判断最新 AIMessage 中是否有工具调用)。整体流程为:用户消息进 `assistant`,若有工具调用则进入 `tools` 执行后再回 `assistant` 继续,直至无工具调用后结束并返回最终答案。
+
+## RedisSaver 持久化生命周期
+
+采用 `langgraph-checkpoint-redis` 提供的 **AsyncRedisSaver** 进行短期(线程级)持久化,以便跨会话保持对话历史。**初始化**时,用 Redis URL 创建 AsyncRedisSaver 实例并 `await saver.asetup()` 建立所需索引,如:
+
+```python
+self._exit_stack = AsyncExitStack()
+saver_mgr = AsyncRedisSaver.from_conn_string(config.REDIS_URL)
+self.checkpointer = await self._exit_stack.enter_async_context(saver_mgr)
+await self.checkpointer.asetup()
+```
+
+参照官方示例,可在编译图时将 `checkpointer` 传入 `StateGraph.compile(checkpointer=...)`。这样,图的每次执行都会自动保存状态到 Redis。**关闭**时,通过 `await self._exit_stack.aclose()` 释放 Redis 连接(或使用 `async with AsyncRedisSaver.from_conn_string(...)` 上下文管理器)。
+
+`thread_id` 用作对话流水号:首次对话时自动生成(如 `userID:timestamp`),并在后续调用时传入图的 `config` 部分(`{"configurable":{"thread_id": thread_id}}`),以检索或续接该会话的历史。通过 `checkpointer.get(config)` 可异步取回当前线程的全部消息列表,以实现 `get_conversation_history` 等功能(可参考原代码的取值逻辑)。
+
+## 日志输出与追踪
+
+* **日志框架**:使用 Python 内置的 `logging` 模块,设置基本配置输出到控制台。例如:
+
+  ```python
+  import logging
+  logging.basicConfig(
+      level=config.LOG_LEVEL,
+      format="%(asctime)s %(levelname)s: %(message)s"
+  )
+  logger = logging.getLogger(__name__)
+  ```
+
+  在关键步骤(如初始化模型/Redis、节点执行前后、工具调用等)使用 `logger.info()/debug()/warning()` 记录状态和统计信息,以便实时追踪流程。
+* **控制台追踪**:避免过于复杂的日志管理,简单的 `print` 或 `logger` 输出即可。建议在 `assistant` 节点前后输出提示(如“调用模型,Thread=xxx”),在工具函数开始时输出工具名和参数,在异常时使用 `logger.error()` 打印堆栈信息。这样可在终端实时观察 Agent 的运行轨迹,而无需额外工具监控。
+* **日志等级**:通过 `config.py` 中的 `LOG_LEVEL` 配置调试信息输出级别(如 DEBUG, INFO)。开发时可设为 DEBUG 以观察细节,生产时切换为 INFO 以减少冗余输出。
+
+以上方案在保留原有 LangChain Agent/工具调用风格的同时,采用 StateGraph 明确化流程,各节点可访问共享的上下文状态。使用 AsyncRedisSaver 实现对话历史的持久化,利用 `thread_id` 管理不同会话;日志输出则通过标准 `logging` 模块实现可控的实时跟踪输出。
+
+**参考资料:** LangGraph ReAct 架构示例;RedisSaver 用法指南;状态注入示例。

Nem az összes módosított fájl került megjelenítésre, mert túl sok fájl változott