|
@@ -8,7 +8,7 @@ from typing import List, Optional, Dict, Any, Tuple
|
|
from contextlib import AsyncExitStack
|
|
from contextlib import AsyncExitStack
|
|
|
|
|
|
from langchain_openai import ChatOpenAI
|
|
from langchain_openai import ChatOpenAI
|
|
-from langchain_core.messages import HumanMessage, ToolMessage, BaseMessage
|
|
|
|
|
|
+from langchain_core.messages import HumanMessage, ToolMessage, BaseMessage, SystemMessage
|
|
from langgraph.graph import StateGraph, END
|
|
from langgraph.graph import StateGraph, END
|
|
from langgraph.prebuilt import ToolNode
|
|
from langgraph.prebuilt import ToolNode
|
|
from redis.asyncio import Redis
|
|
from redis.asyncio import Redis
|
|
@@ -21,6 +21,7 @@ except ImportError:
|
|
from . import config
|
|
from . import config
|
|
from .state import AgentState
|
|
from .state import AgentState
|
|
from .sql_tools import sql_tools
|
|
from .sql_tools import sql_tools
|
|
|
|
+from langchain_core.runnables import RunnablePassthrough
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@@ -53,7 +54,14 @@ class CustomReactAgent:
|
|
base_url=config.QWEN_BASE_URL,
|
|
base_url=config.QWEN_BASE_URL,
|
|
model=config.QWEN_MODEL,
|
|
model=config.QWEN_MODEL,
|
|
temperature=0.1,
|
|
temperature=0.1,
|
|
- extra_body={"enable_thinking": False}
|
|
|
|
|
|
+ model_kwargs={
|
|
|
|
+ "extra_body": {
|
|
|
|
+ "enable_thinking": False,
|
|
|
|
+ "misc": {
|
|
|
|
+ "ensure_ascii": False
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ }
|
|
)
|
|
)
|
|
logger.info(f" LLM 已初始化,模型: {config.QWEN_MODEL}")
|
|
logger.info(f" LLM 已初始化,模型: {config.QWEN_MODEL}")
|
|
|
|
|
|
@@ -92,40 +100,129 @@ class CustomReactAgent:
|
|
logger.info("✅ RedisSaver 资源已通过 AsyncExitStack 释放。")
|
|
logger.info("✅ RedisSaver 资源已通过 AsyncExitStack 释放。")
|
|
|
|
|
|
def _create_graph(self):
|
|
def _create_graph(self):
|
|
- """定义并编译 StateGraph。"""
|
|
|
|
|
|
+ """定义并编译最终的、正确的 StateGraph 结构。"""
|
|
builder = StateGraph(AgentState)
|
|
builder = StateGraph(AgentState)
|
|
|
|
|
|
- # 定义节点
|
|
|
|
|
|
+ # 定义所有需要的节点
|
|
builder.add_node("agent", self._agent_node)
|
|
builder.add_node("agent", self._agent_node)
|
|
|
|
+ builder.add_node("prepare_tool_input", self._prepare_tool_input_node)
|
|
builder.add_node("tools", ToolNode(self.tools))
|
|
builder.add_node("tools", ToolNode(self.tools))
|
|
|
|
+ builder.add_node("update_state_after_tool", self._update_state_after_tool_node)
|
|
|
|
+ builder.add_node("format_final_response", self._format_final_response_node)
|
|
|
|
|
|
- # 定义边
|
|
|
|
|
|
+ # 建立正确的边连接
|
|
builder.set_entry_point("agent")
|
|
builder.set_entry_point("agent")
|
|
builder.add_conditional_edges(
|
|
builder.add_conditional_edges(
|
|
"agent",
|
|
"agent",
|
|
self._should_continue,
|
|
self._should_continue,
|
|
- {"continue": "tools", "end": END}
|
|
|
|
|
|
+ {
|
|
|
|
+ "continue": "prepare_tool_input",
|
|
|
|
+ "end": "format_final_response"
|
|
|
|
+ }
|
|
)
|
|
)
|
|
- builder.add_edge("tools", "agent")
|
|
|
|
|
|
+ builder.add_edge("prepare_tool_input", "tools")
|
|
|
|
+ builder.add_edge("tools", "update_state_after_tool")
|
|
|
|
+ builder.add_edge("update_state_after_tool", "agent")
|
|
|
|
+ builder.add_edge("format_final_response", END)
|
|
|
|
|
|
- # 编译图,并传入 checkpointer
|
|
|
|
return builder.compile(checkpointer=self.checkpointer)
|
|
return builder.compile(checkpointer=self.checkpointer)
|
|
|
|
|
|
def _should_continue(self, state: AgentState) -> str:
|
|
def _should_continue(self, state: AgentState) -> str:
|
|
- """判断是否需要继续调用工具。"""
|
|
|
|
|
|
+ """判断是继续调用工具还是结束。"""
|
|
last_message = state["messages"][-1]
|
|
last_message = state["messages"][-1]
|
|
- if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
|
|
|
|
- return "end"
|
|
|
|
- return "continue"
|
|
|
|
|
|
+ if hasattr(last_message, "tool_calls") and last_message.tool_calls:
|
|
|
|
+ return "continue"
|
|
|
|
+ return "end"
|
|
|
|
|
|
def _agent_node(self, state: AgentState) -> Dict[str, Any]:
|
|
def _agent_node(self, state: AgentState) -> Dict[str, Any]:
|
|
- """Agent 节点:调用 LLM 进行思考和决策。"""
|
|
|
|
|
|
+ """Agent 节点:只负责调用 LLM 并返回其输出。"""
|
|
logger.info(f"🧠 [Node] agent - Thread: {state['thread_id']}")
|
|
logger.info(f"🧠 [Node] agent - Thread: {state['thread_id']}")
|
|
|
|
|
|
- # LangChain 会自动从 state 中提取 messages 传递给 LLM
|
|
|
|
- response = self.llm_with_tools.invoke(state["messages"])
|
|
|
|
- logger.info(f" LLM 返回: {response.pretty_print()}")
|
|
|
|
|
|
+ messages_for_llm = list(state["messages"])
|
|
|
|
+ if state.get("suggested_next_step"):
|
|
|
|
+ instruction = f"提示:建议下一步使用工具 '{state['suggested_next_step']}'。"
|
|
|
|
+ messages_for_llm.append(SystemMessage(content=instruction))
|
|
|
|
+
|
|
|
|
+ response = self.llm_with_tools.invoke(messages_for_llm)
|
|
|
|
+ logger.info(f" LLM Response: {response.pretty_print()}")
|
|
|
|
+
|
|
|
|
+ # 只返回消息,不承担其他职责
|
|
return {"messages": [response]}
|
|
return {"messages": [response]}
|
|
|
|
+
|
|
|
|
+ def _prepare_tool_input_node(self, state: AgentState) -> Dict[str, Any]:
|
|
|
|
+ """
|
|
|
|
+ 信息组装节点:为需要上下文的工具注入历史消息。
|
|
|
|
+ """
|
|
|
|
+ logger.info(f"🛠️ [Node] prepare_tool_input - Thread: {state['thread_id']}")
|
|
|
|
+
|
|
|
|
+ last_message = state["messages"][-1]
|
|
|
|
+ if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
|
|
|
|
+ return {"messages": [last_message]}
|
|
|
|
+
|
|
|
|
+ # 创建一个新的 AIMessage 来替换,避免直接修改 state 中的对象
|
|
|
|
+ new_tool_calls = []
|
|
|
|
+ for tool_call in last_message.tool_calls:
|
|
|
|
+ if tool_call["name"] == "generate_sql":
|
|
|
|
+ logger.info(" 检测到 generate_sql 调用,注入历史消息。")
|
|
|
|
+ # 复制一份以避免修改原始 tool_call
|
|
|
|
+ modified_args = tool_call["args"].copy()
|
|
|
|
+
|
|
|
|
+ # 将消息对象列表转换为可序列化的字典列表
|
|
|
|
+ serializable_history = []
|
|
|
|
+ for msg in state["messages"]:
|
|
|
|
+ serializable_history.append({
|
|
|
|
+ "type": msg.type,
|
|
|
|
+ "content": msg.content
|
|
|
|
+ })
|
|
|
|
+
|
|
|
|
+ modified_args["history_messages"] = serializable_history
|
|
|
|
+ logger.info(f" 注入了 {len(serializable_history)} 条历史消息")
|
|
|
|
+
|
|
|
|
+ new_tool_calls.append({
|
|
|
|
+ "name": tool_call["name"],
|
|
|
|
+ "args": modified_args,
|
|
|
|
+ "id": tool_call["id"],
|
|
|
|
+ })
|
|
|
|
+ else:
|
|
|
|
+ new_tool_calls.append(tool_call)
|
|
|
|
+
|
|
|
|
+ # 用包含修改后参数的新消息替换掉原来的
|
|
|
|
+ last_message.tool_calls = new_tool_calls
|
|
|
|
+ return {"messages": [last_message]}
|
|
|
|
+
|
|
|
|
+ def _update_state_after_tool_node(self, state: AgentState) -> Dict[str, Any]:
|
|
|
|
+ """在工具执行后,更新 suggested_next_step。"""
|
|
|
|
+ logger.info(f"📝 [Node] update_state_after_tool - Thread: {state['thread_id']}")
|
|
|
|
+
|
|
|
|
+ last_tool_message = state['messages'][-1]
|
|
|
|
+ tool_name = last_tool_message.name
|
|
|
|
+ tool_output = last_tool_message.content
|
|
|
|
+ next_step = None
|
|
|
|
+
|
|
|
|
+ if tool_name == 'generate_sql':
|
|
|
|
+ if "失败" in tool_output or "无法生成" in tool_output:
|
|
|
|
+ next_step = 'answer_with_common_sense'
|
|
|
|
+ else:
|
|
|
|
+ next_step = 'valid_sql'
|
|
|
|
+
|
|
|
|
+ elif tool_name == 'valid_sql':
|
|
|
|
+ if "失败" in tool_output:
|
|
|
|
+ next_step = 'analyze_validation_error'
|
|
|
|
+ else:
|
|
|
|
+ next_step = 'run_sql'
|
|
|
|
+
|
|
|
|
+ elif tool_name == 'run_sql':
|
|
|
|
+ next_step = 'summarize_final_answer'
|
|
|
|
+
|
|
|
|
+ logger.info(f" Tool '{tool_name}' executed. Suggested next step: {next_step}")
|
|
|
|
+ return {"suggested_next_step": next_step}
|
|
|
|
+
|
|
|
|
+ def _format_final_response_node(self, state: AgentState) -> Dict[str, Any]:
|
|
|
|
+ """最终输出格式化节点。"""
|
|
|
|
+ logger.info(f"🎨 [Node] format_final_response - Thread: {state['thread_id']}")
|
|
|
|
+ last_message = state['messages'][-1]
|
|
|
|
+ last_message.content = f"[Formatted Output]\n{last_message.content}"
|
|
|
|
+ return {"messages": [last_message]}
|
|
|
|
|
|
async def chat(self, message: str, user_id: str, thread_id: Optional[str] = None) -> Dict[str, Any]:
|
|
async def chat(self, message: str, user_id: str, thread_id: Optional[str] = None) -> Dict[str, Any]:
|
|
"""
|
|
"""
|
|
@@ -134,34 +231,29 @@ class CustomReactAgent:
|
|
if not thread_id:
|
|
if not thread_id:
|
|
thread_id = f"{user_id}:{pd.Timestamp.now().strftime('%Y%m%d%H%M%S%f')}"
|
|
thread_id = f"{user_id}:{pd.Timestamp.now().strftime('%Y%m%d%H%M%S%f')}"
|
|
logger.info(f"🆕 新建会话,Thread ID: {thread_id}")
|
|
logger.info(f"🆕 新建会话,Thread ID: {thread_id}")
|
|
-
|
|
|
|
- config = {"configurable": {"thread_id": thread_id}}
|
|
|
|
|
|
|
|
- # 定义输入
|
|
|
|
|
|
+ config = {
|
|
|
|
+ "configurable": {
|
|
|
|
+ "thread_id": thread_id,
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
inputs = {
|
|
inputs = {
|
|
"messages": [HumanMessage(content=message)],
|
|
"messages": [HumanMessage(content=message)],
|
|
"user_id": user_id,
|
|
"user_id": user_id,
|
|
"thread_id": thread_id,
|
|
"thread_id": thread_id,
|
|
|
|
+ "suggested_next_step": None,
|
|
}
|
|
}
|
|
|
|
|
|
- final_state = None
|
|
|
|
try:
|
|
try:
|
|
- logger.info(f"🔄 开始处理 - Thread: {thread_id}, User: {user_id}, Message: '{message}'")
|
|
|
|
- # 使用 ainvoke 来执行完整的图流程
|
|
|
|
final_state = await self.agent_executor.ainvoke(inputs, config)
|
|
final_state = await self.agent_executor.ainvoke(inputs, config)
|
|
-
|
|
|
|
- if final_state and final_state.get("messages"):
|
|
|
|
- answer = final_state["messages"][-1].content
|
|
|
|
- logger.info(f"✅ 处理完成 - Thread: {thread_id}, Final Answer: '{answer}'")
|
|
|
|
- return {"success": True, "answer": answer, "thread_id": thread_id}
|
|
|
|
- else:
|
|
|
|
- logger.error(f"❌ 处理异常结束,最终状态为空 - Thread: {thread_id}")
|
|
|
|
- return {"success": False, "error": "Agent failed to produce a final answer.", "thread_id": thread_id}
|
|
|
|
-
|
|
|
|
|
|
+ answer = final_state["messages"][-1].content
|
|
|
|
+ logger.info(f"✅ 处理完成 - Final Answer: '{answer}'")
|
|
|
|
+ return {"success": True, "answer": answer, "thread_id": thread_id}
|
|
except Exception as e:
|
|
except Exception as e:
|
|
logger.error(f"❌ 处理过程中发生严重错误 - Thread: {thread_id}: {e}", exc_info=True)
|
|
logger.error(f"❌ 处理过程中发生严重错误 - Thread: {thread_id}: {e}", exc_info=True)
|
|
return {"success": False, "error": str(e), "thread_id": thread_id}
|
|
return {"success": False, "error": str(e), "thread_id": thread_id}
|
|
-
|
|
|
|
|
|
+
|
|
async def get_conversation_history(self, thread_id: str) -> List[Dict[str, Any]]:
|
|
async def get_conversation_history(self, thread_id: str) -> List[Dict[str, Any]]:
|
|
"""从 checkpointer 获取指定线程的对话历史。"""
|
|
"""从 checkpointer 获取指定线程的对话历史。"""
|
|
if not self.checkpointer:
|
|
if not self.checkpointer:
|