|
@@ -11,7 +11,7 @@ from langchain_openai import ChatOpenAI
|
|
|
from langchain_core.messages import HumanMessage, ToolMessage, BaseMessage, SystemMessage, AIMessage
|
|
|
from langgraph.graph import StateGraph, END
|
|
|
from langgraph.prebuilt import ToolNode
|
|
|
-from redis.asyncio import Redis
|
|
|
+import redis.asyncio as redis
|
|
|
try:
|
|
|
from langgraph.checkpoint.redis import AsyncRedisSaver
|
|
|
except ImportError:
|
|
@@ -43,6 +43,7 @@ class CustomReactAgent:
|
|
|
self.agent_executor = None
|
|
|
self.checkpointer = None
|
|
|
self._exit_stack = None
|
|
|
+ self.redis_client = None
|
|
|
|
|
|
@classmethod
|
|
|
async def create(cls):
|
|
@@ -55,7 +56,16 @@ class CustomReactAgent:
|
|
|
"""异步初始化所有组件。"""
|
|
|
logger.info("🚀 开始初始化 CustomReactAgent...")
|
|
|
|
|
|
- # 1. 初始化 LLM
|
|
|
+ # 1. 初始化异步Redis客户端
|
|
|
+ self.redis_client = redis.from_url(config.REDIS_URL, decode_responses=True)
|
|
|
+ try:
|
|
|
+ await self.redis_client.ping()
|
|
|
+ logger.info(f" ✅ Redis连接成功: {config.REDIS_URL}")
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f" ❌ Redis连接失败: {e}")
|
|
|
+ raise
|
|
|
+
|
|
|
+ # 2. 初始化 LLM
|
|
|
self.llm = ChatOpenAI(
|
|
|
api_key=config.QWEN_API_KEY,
|
|
|
base_url=config.QWEN_BASE_URL,
|
|
@@ -72,12 +82,12 @@ class CustomReactAgent:
|
|
|
)
|
|
|
logger.info(f" LLM 已初始化,模型: {config.QWEN_MODEL}")
|
|
|
|
|
|
- # 2. 绑定工具
|
|
|
+ # 3. 绑定工具
|
|
|
self.tools = sql_tools
|
|
|
self.llm_with_tools = self.llm.bind_tools(self.tools)
|
|
|
logger.info(f" 已绑定 {len(self.tools)} 个工具。")
|
|
|
|
|
|
- # 3. 初始化 Redis Checkpointer
|
|
|
+ # 4. 初始化 Redis Checkpointer
|
|
|
if config.REDIS_ENABLED and AsyncRedisSaver is not None:
|
|
|
try:
|
|
|
self._exit_stack = AsyncExitStack()
|
|
@@ -93,7 +103,7 @@ class CustomReactAgent:
|
|
|
else:
|
|
|
logger.warning(" Redis 持久化功能已禁用。")
|
|
|
|
|
|
- # 4. 构建 StateGraph
|
|
|
+ # 5. 构建 StateGraph
|
|
|
self.agent_executor = self._create_graph()
|
|
|
logger.info(" StateGraph 已构建并编译。")
|
|
|
logger.info("✅ CustomReactAgent 初始化完成。")
|
|
@@ -105,23 +115,27 @@ class CustomReactAgent:
|
|
|
self._exit_stack = None
|
|
|
self.checkpointer = None
|
|
|
logger.info("✅ RedisSaver 资源已通过 AsyncExitStack 释放。")
|
|
|
+
|
|
|
+ if self.redis_client:
|
|
|
+ await self.redis_client.aclose()
|
|
|
+ logger.info("✅ Redis客户端已关闭。")
|
|
|
|
|
|
def _create_graph(self):
|
|
|
"""定义并编译最终的、正确的 StateGraph 结构。"""
|
|
|
builder = StateGraph(AgentState)
|
|
|
|
|
|
- # 定义所有需要的节点
|
|
|
- builder.add_node("agent", self._agent_node)
|
|
|
- builder.add_node("prepare_tool_input", self._prepare_tool_input_node)
|
|
|
+ # 定义所有需要的节点 - 全部改为异步
|
|
|
+ builder.add_node("agent", self._async_agent_node)
|
|
|
+ builder.add_node("prepare_tool_input", self._async_prepare_tool_input_node)
|
|
|
builder.add_node("tools", ToolNode(self.tools))
|
|
|
- builder.add_node("update_state_after_tool", self._update_state_after_tool_node)
|
|
|
- builder.add_node("format_final_response", self._format_final_response_node)
|
|
|
+ builder.add_node("update_state_after_tool", self._async_update_state_after_tool_node)
|
|
|
+ builder.add_node("format_final_response", self._async_format_final_response_node)
|
|
|
|
|
|
# 建立正确的边连接
|
|
|
builder.set_entry_point("agent")
|
|
|
builder.add_conditional_edges(
|
|
|
"agent",
|
|
|
- self._should_continue,
|
|
|
+ self._async_should_continue,
|
|
|
{
|
|
|
"continue": "prepare_tool_input",
|
|
|
"end": "format_final_response"
|
|
@@ -134,30 +148,38 @@ class CustomReactAgent:
|
|
|
|
|
|
return builder.compile(checkpointer=self.checkpointer)
|
|
|
|
|
|
- def _should_continue(self, state: AgentState) -> str:
|
|
|
- """判断是继续调用工具还是结束。"""
|
|
|
+ async def _async_should_continue(self, state: AgentState) -> str:
|
|
|
+ """异步判断是继续调用工具还是结束。"""
|
|
|
last_message = state["messages"][-1]
|
|
|
if hasattr(last_message, "tool_calls") and last_message.tool_calls:
|
|
|
return "continue"
|
|
|
return "end"
|
|
|
|
|
|
- def _agent_node(self, state: AgentState) -> Dict[str, Any]:
|
|
|
- """Agent 节点:只负责调用 LLM 并返回其输出。"""
|
|
|
- logger.info(f"🧠 [Node] agent - Thread: {state['thread_id']}")
|
|
|
+ async def _async_agent_node(self, state: AgentState) -> Dict[str, Any]:
|
|
|
+ """异步Agent 节点:使用异步LLM调用。"""
|
|
|
+ logger.info(f"🧠 [Async Node] agent - Thread: {state['thread_id']}")
|
|
|
|
|
|
messages_for_llm = list(state["messages"])
|
|
|
+
|
|
|
+ # 🎯 添加数据库范围系统提示词(仅在对话开始时添加)
|
|
|
+ if len(state["messages"]) == 1 and isinstance(state["messages"][0], HumanMessage):
|
|
|
+ db_scope_prompt = self._get_database_scope_prompt()
|
|
|
+ if db_scope_prompt:
|
|
|
+ messages_for_llm.insert(0, SystemMessage(content=db_scope_prompt))
|
|
|
+ logger.info(" ✅ 已添加数据库范围判断提示词")
|
|
|
+
|
|
|
if state.get("suggested_next_step"):
|
|
|
- instruction = f"提示:建议下一步使用工具 '{state['suggested_next_step']}'。"
|
|
|
+ instruction = f"Suggestion: Consider using the '{state['suggested_next_step']}' tool for the next step."
|
|
|
messages_for_llm.append(SystemMessage(content=instruction))
|
|
|
|
|
|
# 添加重试机制处理网络连接问题
|
|
|
- import time
|
|
|
+ import asyncio
|
|
|
max_retries = config.MAX_RETRIES
|
|
|
for attempt in range(max_retries):
|
|
|
try:
|
|
|
- response = self.llm_with_tools.invoke(messages_for_llm)
|
|
|
- logger.info(f" LLM Response: {response.pretty_print()}")
|
|
|
- # 只返回消息,不承担其他职责
|
|
|
+ # 使用异步调用
|
|
|
+ response = await self.llm_with_tools.ainvoke(messages_for_llm)
|
|
|
+ logger.info(f" ✅ 异步LLM调用成功")
|
|
|
return {"messages": [response]}
|
|
|
|
|
|
except Exception as e:
|
|
@@ -172,14 +194,14 @@ class CustomReactAgent:
|
|
|
if attempt < max_retries - 1:
|
|
|
wait_time = config.RETRY_BASE_DELAY ** attempt # 指数退避:2, 4, 8秒
|
|
|
logger.info(f" 🔄 网络错误,{wait_time}秒后重试...")
|
|
|
- time.sleep(wait_time)
|
|
|
+ await asyncio.sleep(wait_time)
|
|
|
continue
|
|
|
else:
|
|
|
# 所有重试都失败了,返回一个降级的回答
|
|
|
logger.error(f" ❌ 网络连接持续失败,返回降级回答")
|
|
|
|
|
|
# 检查是否有SQL执行结果可以利用
|
|
|
- sql_data = self._extract_latest_sql_data(state["messages"])
|
|
|
+ sql_data = await self._async_extract_latest_sql_data(state["messages"])
|
|
|
if sql_data:
|
|
|
fallback_content = "抱歉,由于网络连接问题,无法生成完整的文字总结。不过查询已成功执行,结果如下:\n\n" + sql_data
|
|
|
else:
|
|
@@ -232,11 +254,11 @@ class CustomReactAgent:
|
|
|
|
|
|
logger.info(" ~" * 10 + " State Print End" + " ~" * 10)
|
|
|
|
|
|
- def _prepare_tool_input_node(self, state: AgentState) -> Dict[str, Any]:
|
|
|
+ async def _async_prepare_tool_input_node(self, state: AgentState) -> Dict[str, Any]:
|
|
|
"""
|
|
|
- 信息组装节点:为需要上下文的工具注入历史消息。
|
|
|
+ 异步信息组装节点:为需要上下文的工具注入历史消息。
|
|
|
"""
|
|
|
- logger.info(f"🛠️ [Node] prepare_tool_input - Thread: {state['thread_id']}")
|
|
|
+ logger.info(f"🛠️ [Async Node] prepare_tool_input - Thread: {state['thread_id']}")
|
|
|
|
|
|
# 🎯 打印 state 全部信息
|
|
|
# self._print_state_info(state, "prepare_tool_input")
|
|
@@ -291,7 +313,7 @@ class CustomReactAgent:
|
|
|
last_message.tool_calls = new_tool_calls
|
|
|
return {"messages": [last_message]}
|
|
|
|
|
|
- def _update_state_after_tool_node(self, state: AgentState) -> Dict[str, Any]:
|
|
|
+ async def _async_update_state_after_tool_node(self, state: AgentState) -> Dict[str, Any]:
|
|
|
"""在工具执行后,更新 suggested_next_step 并清理参数。"""
|
|
|
logger.info(f"📝 [Node] update_state_after_tool - Thread: {state['thread_id']}")
|
|
|
|
|
@@ -335,16 +357,16 @@ class CustomReactAgent:
|
|
|
tool_call["args"]["history_messages"] = ""
|
|
|
logger.info(f" 已将 generate_sql 的 history_messages 设置为空字符串")
|
|
|
|
|
|
- def _format_final_response_node(self, state: AgentState) -> Dict[str, Any]:
|
|
|
- """最终输出格式化节点。"""
|
|
|
- logger.info(f"🎨 [Node] format_final_response - Thread: {state['thread_id']}")
|
|
|
+ async def _async_format_final_response_node(self, state: AgentState) -> Dict[str, Any]:
|
|
|
+ """异步最终输出格式化节点。"""
|
|
|
+ logger.info(f"🎨 [Async Node] format_final_response - Thread: {state['thread_id']}")
|
|
|
|
|
|
# 保持原有的消息格式化(用于shell.py兼容)
|
|
|
last_message = state['messages'][-1]
|
|
|
last_message.content = f"[Formatted Output]\n{last_message.content}"
|
|
|
|
|
|
# 生成API格式的数据
|
|
|
- api_data = self._generate_api_data(state)
|
|
|
+ api_data = await self._async_generate_api_data(state)
|
|
|
|
|
|
# 打印api_data
|
|
|
print("-"*20+"api_data_start"+"-"*20)
|
|
@@ -356,9 +378,9 @@ class CustomReactAgent:
|
|
|
"api_data": api_data # 新增:API格式数据
|
|
|
}
|
|
|
|
|
|
- def _generate_api_data(self, state: AgentState) -> Dict[str, Any]:
|
|
|
- """生成API格式的数据结构"""
|
|
|
- logger.info("📊 生成API格式数据...")
|
|
|
+ async def _async_generate_api_data(self, state: AgentState) -> Dict[str, Any]:
|
|
|
+ """异步生成API格式的数据结构"""
|
|
|
+ logger.info("📊 异步生成API格式数据...")
|
|
|
|
|
|
# 提取基础响应内容
|
|
|
last_message = state['messages'][-1]
|
|
@@ -374,20 +396,20 @@ class CustomReactAgent:
|
|
|
}
|
|
|
|
|
|
# 提取SQL和数据记录
|
|
|
- sql_info = self._extract_sql_and_data(state['messages'])
|
|
|
+ sql_info = await self._async_extract_sql_and_data(state['messages'])
|
|
|
if sql_info['sql']:
|
|
|
api_data["sql"] = sql_info['sql']
|
|
|
if sql_info['records']:
|
|
|
api_data["records"] = sql_info['records']
|
|
|
|
|
|
# 生成Agent元数据
|
|
|
- api_data["react_agent_meta"] = self._collect_agent_metadata(state)
|
|
|
+ api_data["react_agent_meta"] = await self._async_collect_agent_metadata(state)
|
|
|
|
|
|
logger.info(f" API数据生成完成,包含字段: {list(api_data.keys())}")
|
|
|
return api_data
|
|
|
|
|
|
- def _extract_sql_and_data(self, messages: List[BaseMessage]) -> Dict[str, Any]:
|
|
|
- """从消息历史中提取SQL和数据记录"""
|
|
|
+ async def _async_extract_sql_and_data(self, messages: List[BaseMessage]) -> Dict[str, Any]:
|
|
|
+ """异步从消息历史中提取SQL和数据记录"""
|
|
|
result = {"sql": None, "records": None}
|
|
|
|
|
|
# 查找最后一个HumanMessage之后的工具执行结果
|
|
@@ -438,7 +460,7 @@ class CustomReactAgent:
|
|
|
|
|
|
return result
|
|
|
|
|
|
- def _collect_agent_metadata(self, state: AgentState) -> Dict[str, Any]:
|
|
|
+ async def _async_collect_agent_metadata(self, state: AgentState) -> Dict[str, Any]:
|
|
|
"""收集Agent元数据"""
|
|
|
messages = state['messages']
|
|
|
|
|
@@ -485,7 +507,7 @@ class CustomReactAgent:
|
|
|
"agent_version": "custom_react_v1"
|
|
|
}
|
|
|
|
|
|
- def _extract_latest_sql_data(self, messages: List[BaseMessage]) -> Optional[str]:
|
|
|
+ async def _async_extract_latest_sql_data(self, messages: List[BaseMessage]) -> Optional[str]:
|
|
|
"""从消息历史中提取最近的run_sql执行结果,但仅限于当前对话轮次。"""
|
|
|
logger.info("🔍 提取最新的SQL执行结果...")
|
|
|
|
|
@@ -552,7 +574,7 @@ class CustomReactAgent:
|
|
|
answer = final_state["messages"][-1].content
|
|
|
|
|
|
# 🎯 提取最近的 run_sql 执行结果(不修改messages)
|
|
|
- sql_data = self._extract_latest_sql_data(final_state["messages"])
|
|
|
+ sql_data = await self._async_extract_latest_sql_data(final_state["messages"])
|
|
|
|
|
|
logger.info(f"✅ 处理完成 - Final Answer: '{answer}'")
|
|
|
|
|
@@ -624,9 +646,8 @@ class CustomReactAgent:
|
|
|
return []
|
|
|
|
|
|
try:
|
|
|
- # 创建Redis连接 - 使用与checkpointer相同的连接配置
|
|
|
- from redis.asyncio import Redis
|
|
|
- redis_client = Redis.from_url(config.REDIS_URL, decode_responses=True)
|
|
|
+ # 使用统一的异步Redis客户端
|
|
|
+ redis_client = self.redis_client
|
|
|
|
|
|
# 1. 扫描匹配该用户的所有checkpoint keys
|
|
|
# checkpointer的key格式通常是: checkpoint:thread_id:checkpoint_id
|
|
@@ -757,4 +778,47 @@ class CustomReactAgent:
|
|
|
return f"{year}-{month}-{day} {hour}:{minute}:{second}"
|
|
|
except Exception:
|
|
|
pass
|
|
|
- return timestamp
|
|
|
+ return timestamp
|
|
|
+
|
|
|
+ def _get_database_scope_prompt(self) -> str:
|
|
|
+ """Get database scope prompt for intelligent query decision making"""
|
|
|
+ try:
|
|
|
+ import os
|
|
|
+ # Read agent/tools/db_query_decision_prompt.txt
|
|
|
+ project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
+ db_scope_file = os.path.join(project_root, "agent", "tools", "db_query_decision_prompt.txt")
|
|
|
+
|
|
|
+ with open(db_scope_file, 'r', encoding='utf-8') as f:
|
|
|
+ db_scope_content = f.read().strip()
|
|
|
+
|
|
|
+ prompt = f"""You are an intelligent database query assistant. When deciding whether to use database query tools, please follow these rules:
|
|
|
+
|
|
|
+=== DATABASE BUSINESS SCOPE ===
|
|
|
+{db_scope_content}
|
|
|
+
|
|
|
+=== DECISION RULES ===
|
|
|
+1. If the question involves data within the above business scope (service areas, branches, revenue, traffic flow, etc.), use the generate_sql tool
|
|
|
+2. If the question is about general knowledge (like "when do lychees ripen?", weather, historical events, etc.), answer directly based on your knowledge WITHOUT using database tools
|
|
|
+3. When answering general knowledge questions, clearly indicate that this is based on general knowledge, not database data
|
|
|
+
|
|
|
+=== FALLBACK STRATEGY ===
|
|
|
+When generate_sql returns an error message or when queries return no results:
|
|
|
+1. First, check if the question is within the database scope described above
|
|
|
+2. For questions clearly OUTSIDE the database scope (world events, general knowledge, etc.):
|
|
|
+ - Provide the answer based on your knowledge immediately
|
|
|
+ - CLEARLY indicate this is based on general knowledge, not database data
|
|
|
+ - Use format: "Based on general knowledge (not database data): [answer]"
|
|
|
+3. For questions within database scope but queries return no results:
|
|
|
+ - If it's a reasonable question that might have a general answer, provide it
|
|
|
+ - Still indicate the source: "Based on general knowledge (database had no results): [answer]"
|
|
|
+4. For questions that definitely require specific database data:
|
|
|
+ - Acknowledge the limitation and suggest the data may not be available
|
|
|
+ - Do not attempt to guess or fabricate specific data
|
|
|
+
|
|
|
+Please intelligently choose whether to query the database based on the nature of the user's question."""
|
|
|
+
|
|
|
+ return prompt
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.warning(f"⚠️ Unable to read database scope description file: {e}")
|
|
|
+ return ""
|