agent.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722
  1. """
  2. 基于 StateGraph 的、具备上下文感知能力的 React Agent 核心实现
  3. """
  4. import logging
  5. import json
  6. import pandas as pd
  7. from typing import List, Optional, Dict, Any, Tuple
  8. from contextlib import AsyncExitStack
  9. from langchain_openai import ChatOpenAI
  10. from langchain_core.messages import HumanMessage, ToolMessage, BaseMessage, SystemMessage, AIMessage
  11. from langgraph.graph import StateGraph, END
  12. from langgraph.prebuilt import ToolNode
  13. from redis.asyncio import Redis
  14. try:
  15. from langgraph.checkpoint.redis import AsyncRedisSaver
  16. except ImportError:
  17. AsyncRedisSaver = None
  18. # 从新模块导入配置、状态和工具
  19. try:
  20. # 尝试相对导入(当作为模块导入时)
  21. from . import config
  22. from .state import AgentState
  23. from .sql_tools import sql_tools
  24. except ImportError:
  25. # 如果相对导入失败,尝试绝对导入(直接运行时)
  26. import config
  27. from state import AgentState
  28. from sql_tools import sql_tools
  29. from langchain_core.runnables import RunnablePassthrough
  30. logger = logging.getLogger(__name__)
  31. class CustomReactAgent:
  32. """
  33. 一个使用 StateGraph 构建的、具备上下文感知和持久化能力的 Agent。
  34. """
  35. def __init__(self):
  36. """私有构造函数,请使用 create() 类方法来创建实例。"""
  37. self.llm = None
  38. self.tools = None
  39. self.agent_executor = None
  40. self.checkpointer = None
  41. self._exit_stack = None
  42. @classmethod
  43. async def create(cls):
  44. """异步工厂方法,创建并初始化 CustomReactAgent 实例。"""
  45. instance = cls()
  46. await instance._async_init()
  47. return instance
  48. async def _async_init(self):
  49. """异步初始化所有组件。"""
  50. logger.info("🚀 开始初始化 CustomReactAgent...")
  51. # 1. 初始化 LLM
  52. self.llm = ChatOpenAI(
  53. api_key=config.QWEN_API_KEY,
  54. base_url=config.QWEN_BASE_URL,
  55. model=config.QWEN_MODEL,
  56. temperature=0.1,
  57. extra_body={
  58. "enable_thinking": False,
  59. "misc": {
  60. "ensure_ascii": False
  61. }
  62. }
  63. )
  64. logger.info(f" LLM 已初始化,模型: {config.QWEN_MODEL}")
  65. # 2. 绑定工具
  66. self.tools = sql_tools
  67. self.llm_with_tools = self.llm.bind_tools(self.tools)
  68. logger.info(f" 已绑定 {len(self.tools)} 个工具。")
  69. # 3. 初始化 Redis Checkpointer
  70. if config.REDIS_ENABLED and AsyncRedisSaver is not None:
  71. try:
  72. self._exit_stack = AsyncExitStack()
  73. checkpointer_manager = AsyncRedisSaver.from_conn_string(config.REDIS_URL)
  74. self.checkpointer = await self._exit_stack.enter_async_context(checkpointer_manager)
  75. await self.checkpointer.asetup()
  76. logger.info(f" AsyncRedisSaver 持久化已启用: {config.REDIS_URL}")
  77. except Exception as e:
  78. logger.error(f" ❌ RedisSaver 初始化失败: {e}", exc_info=True)
  79. if self._exit_stack:
  80. await self._exit_stack.aclose()
  81. self.checkpointer = None
  82. else:
  83. logger.warning(" Redis 持久化功能已禁用。")
  84. # 4. 构建 StateGraph
  85. self.agent_executor = self._create_graph()
  86. logger.info(" StateGraph 已构建并编译。")
  87. logger.info("✅ CustomReactAgent 初始化完成。")
  88. async def close(self):
  89. """清理资源,关闭 Redis 连接。"""
  90. if self._exit_stack:
  91. await self._exit_stack.aclose()
  92. self._exit_stack = None
  93. self.checkpointer = None
  94. logger.info("✅ RedisSaver 资源已通过 AsyncExitStack 释放。")
  95. def _create_graph(self):
  96. """定义并编译最终的、正确的 StateGraph 结构。"""
  97. builder = StateGraph(AgentState)
  98. # 定义所有需要的节点
  99. builder.add_node("agent", self._agent_node)
  100. builder.add_node("prepare_tool_input", self._prepare_tool_input_node)
  101. builder.add_node("tools", ToolNode(self.tools))
  102. builder.add_node("update_state_after_tool", self._update_state_after_tool_node)
  103. builder.add_node("format_final_response", self._format_final_response_node)
  104. # 建立正确的边连接
  105. builder.set_entry_point("agent")
  106. builder.add_conditional_edges(
  107. "agent",
  108. self._should_continue,
  109. {
  110. "continue": "prepare_tool_input",
  111. "end": "format_final_response"
  112. }
  113. )
  114. builder.add_edge("prepare_tool_input", "tools")
  115. builder.add_edge("tools", "update_state_after_tool")
  116. builder.add_edge("update_state_after_tool", "agent")
  117. builder.add_edge("format_final_response", END)
  118. return builder.compile(checkpointer=self.checkpointer)
  119. def _should_continue(self, state: AgentState) -> str:
  120. """判断是继续调用工具还是结束。"""
  121. last_message = state["messages"][-1]
  122. if hasattr(last_message, "tool_calls") and last_message.tool_calls:
  123. return "continue"
  124. return "end"
  125. def _agent_node(self, state: AgentState) -> Dict[str, Any]:
  126. """Agent 节点:只负责调用 LLM 并返回其输出。"""
  127. logger.info(f"🧠 [Node] agent - Thread: {state['thread_id']}")
  128. messages_for_llm = list(state["messages"])
  129. if state.get("suggested_next_step"):
  130. instruction = f"提示:建议下一步使用工具 '{state['suggested_next_step']}'。"
  131. messages_for_llm.append(SystemMessage(content=instruction))
  132. response = self.llm_with_tools.invoke(messages_for_llm)
  133. logger.info(f" LLM Response: {response.pretty_print()}")
  134. # 只返回消息,不承担其他职责
  135. return {"messages": [response]}
  136. def _print_state_info(self, state: AgentState, node_name: str) -> None:
  137. """
  138. 打印 state 的全部信息,用于调试
  139. """
  140. logger.info(" ~" * 10 + " State Print Start" + " ~" * 10)
  141. logger.info(f"📋 [State Debug] {node_name} - 当前状态信息:")
  142. # 🎯 打印 state 中的所有字段
  143. logger.info(" State中的所有字段:")
  144. for key, value in state.items():
  145. if key == "messages":
  146. logger.info(f" {key}: {len(value)} 条消息")
  147. else:
  148. logger.info(f" {key}: {value}")
  149. # 原有的详细消息信息
  150. logger.info(f" 用户ID: {state.get('user_id', 'N/A')}")
  151. logger.info(f" 线程ID: {state.get('thread_id', 'N/A')}")
  152. logger.info(f" 建议下一步: {state.get('suggested_next_step', 'N/A')}")
  153. messages = state.get("messages", [])
  154. logger.info(f" 消息历史数量: {len(messages)}")
  155. if messages:
  156. logger.info(" 最近的消息:")
  157. for i, msg in enumerate(messages[-10:], start=max(0, len(messages)-10)): # 显示最后3条消息
  158. msg_type = type(msg).__name__
  159. content_preview = str(msg.content)[:100] + "..." if len(str(msg.content)) > 100 else str(msg.content)
  160. logger.info(f" [{i}] {msg_type}: {content_preview}")
  161. # 如果是 AIMessage 且有工具调用,显示工具调用信息
  162. if hasattr(msg, 'tool_calls') and msg.tool_calls:
  163. for tool_call in msg.tool_calls:
  164. tool_name = tool_call.get('name', 'Unknown')
  165. tool_args = tool_call.get('args', {})
  166. logger.info(f" 工具调用: {tool_name}")
  167. logger.info(f" 参数: {str(tool_args)[:200]}...")
  168. logger.info(" ~" * 10 + " State Print End" + " ~" * 10)
  169. def _prepare_tool_input_node(self, state: AgentState) -> Dict[str, Any]:
  170. """
  171. 信息组装节点:为需要上下文的工具注入历史消息。
  172. """
  173. logger.info(f"🛠️ [Node] prepare_tool_input - Thread: {state['thread_id']}")
  174. # 🎯 打印 state 全部信息
  175. # self._print_state_info(state, "prepare_tool_input")
  176. last_message = state["messages"][-1]
  177. if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
  178. return {"messages": [last_message]}
  179. # 创建一个新的 AIMessage 来替换,避免直接修改 state 中的对象
  180. new_tool_calls = []
  181. for tool_call in last_message.tool_calls:
  182. if tool_call["name"] == "generate_sql":
  183. logger.info(" 检测到 generate_sql 调用,注入历史消息。")
  184. # 复制一份以避免修改原始 tool_call
  185. modified_args = tool_call["args"].copy()
  186. # 🎯 改进的消息过滤逻辑:只保留有用的对话上下文,排除当前问题
  187. clean_history = []
  188. messages_except_current = state["messages"][:-1] # 排除最后一个消息(当前问题)
  189. for msg in messages_except_current:
  190. if isinstance(msg, HumanMessage):
  191. # 保留历史用户消息(但不包括当前问题)
  192. clean_history.append({
  193. "type": "human",
  194. "content": msg.content
  195. })
  196. elif isinstance(msg, AIMessage):
  197. # 只保留最终的、面向用户的回答(包含"[Formatted Output]"的消息)
  198. if msg.content and "[Formatted Output]" in msg.content:
  199. # 去掉 "[Formatted Output]" 标记,只保留真正的回答
  200. clean_content = msg.content.replace("[Formatted Output]\n", "")
  201. clean_history.append({
  202. "type": "ai",
  203. "content": clean_content
  204. })
  205. # 跳过包含工具调用的 AIMessage(中间步骤)
  206. # 跳过所有 ToolMessage(工具执行结果)
  207. modified_args["history_messages"] = clean_history
  208. logger.info(f" 注入了 {len(clean_history)} 条过滤后的历史消息")
  209. new_tool_calls.append({
  210. "name": tool_call["name"],
  211. "args": modified_args,
  212. "id": tool_call["id"],
  213. })
  214. else:
  215. new_tool_calls.append(tool_call)
  216. # 用包含修改后参数的新消息替换掉原来的
  217. last_message.tool_calls = new_tool_calls
  218. return {"messages": [last_message]}
  219. def _update_state_after_tool_node(self, state: AgentState) -> Dict[str, Any]:
  220. """在工具执行后,更新 suggested_next_step 并清理参数。"""
  221. logger.info(f"📝 [Node] update_state_after_tool - Thread: {state['thread_id']}")
  222. # 🎯 打印 state 全部信息
  223. self._print_state_info(state, "update_state_after_tool")
  224. last_tool_message = state['messages'][-1]
  225. tool_name = last_tool_message.name
  226. tool_output = last_tool_message.content
  227. next_step = None
  228. if tool_name == 'generate_sql':
  229. if "失败" in tool_output or "无法生成" in tool_output:
  230. next_step = 'answer_with_common_sense'
  231. else:
  232. next_step = 'valid_sql'
  233. # 🎯 清理 generate_sql 的 history_messages 参数,设置为空字符串
  234. # self._clear_history_messages_parameter(state['messages'])
  235. elif tool_name == 'valid_sql':
  236. if "失败" in tool_output:
  237. next_step = 'analyze_validation_error'
  238. else:
  239. next_step = 'run_sql'
  240. elif tool_name == 'run_sql':
  241. next_step = 'summarize_final_answer'
  242. logger.info(f" Tool '{tool_name}' executed. Suggested next step: {next_step}")
  243. return {"suggested_next_step": next_step}
  244. def _clear_history_messages_parameter(self, messages: List[BaseMessage]) -> None:
  245. """
  246. 将 generate_sql 工具的 history_messages 参数设置为空字符串
  247. """
  248. for message in messages:
  249. if hasattr(message, "tool_calls") and message.tool_calls:
  250. for tool_call in message.tool_calls:
  251. if tool_call["name"] == "generate_sql" and "history_messages" in tool_call["args"]:
  252. tool_call["args"]["history_messages"] = ""
  253. logger.info(f" 已将 generate_sql 的 history_messages 设置为空字符串")
  254. def _format_final_response_node(self, state: AgentState) -> Dict[str, Any]:
  255. """最终输出格式化节点。"""
  256. logger.info(f"🎨 [Node] format_final_response - Thread: {state['thread_id']}")
  257. # 保持原有的消息格式化(用于shell.py兼容)
  258. last_message = state['messages'][-1]
  259. last_message.content = f"[Formatted Output]\n{last_message.content}"
  260. # 生成API格式的数据
  261. api_data = self._generate_api_data(state)
  262. # 打印api_data
  263. print("-"*20+"api_data_start"+"-"*20)
  264. print(api_data)
  265. print("-"*20+"api_data_end"+"-"*20)
  266. return {
  267. "messages": [last_message],
  268. "api_data": api_data # 新增:API格式数据
  269. }
  270. def _generate_api_data(self, state: AgentState) -> Dict[str, Any]:
  271. """生成API格式的数据结构"""
  272. logger.info("📊 生成API格式数据...")
  273. # 提取基础响应内容
  274. last_message = state['messages'][-1]
  275. response_content = last_message.content
  276. # 去掉格式化标记,获取纯净的回答
  277. if response_content.startswith("[Formatted Output]\n"):
  278. response_content = response_content.replace("[Formatted Output]\n", "")
  279. # 初始化API数据结构
  280. api_data = {
  281. "response": response_content
  282. }
  283. # 提取SQL和数据记录
  284. sql_info = self._extract_sql_and_data(state['messages'])
  285. if sql_info['sql']:
  286. api_data["sql"] = sql_info['sql']
  287. if sql_info['records']:
  288. api_data["records"] = sql_info['records']
  289. # 生成Agent元数据
  290. api_data["react_agent_meta"] = self._collect_agent_metadata(state)
  291. logger.info(f" API数据生成完成,包含字段: {list(api_data.keys())}")
  292. return api_data
  293. def _extract_sql_and_data(self, messages: List[BaseMessage]) -> Dict[str, Any]:
  294. """从消息历史中提取SQL和数据记录"""
  295. result = {"sql": None, "records": None}
  296. # 查找最后一个HumanMessage之后的工具执行结果
  297. last_human_index = -1
  298. for i in range(len(messages) - 1, -1, -1):
  299. if isinstance(messages[i], HumanMessage):
  300. last_human_index = i
  301. break
  302. if last_human_index == -1:
  303. return result
  304. # 在当前对话轮次中查找工具执行结果
  305. current_conversation = messages[last_human_index:]
  306. sql_query = None
  307. sql_data = None
  308. for msg in current_conversation:
  309. if isinstance(msg, ToolMessage):
  310. if msg.name == 'generate_sql':
  311. # 提取生成的SQL
  312. content = msg.content
  313. if content and not any(keyword in content for keyword in ["失败", "无法生成", "Database query failed"]):
  314. sql_query = content.strip()
  315. elif msg.name == 'run_sql':
  316. # 提取SQL执行结果
  317. try:
  318. import json
  319. parsed_data = json.loads(msg.content)
  320. if isinstance(parsed_data, list) and len(parsed_data) > 0:
  321. # DataFrame.to_json(orient='records') 格式
  322. columns = list(parsed_data[0].keys()) if parsed_data else []
  323. sql_data = {
  324. "columns": columns,
  325. "rows": parsed_data,
  326. "total_row_count": len(parsed_data),
  327. "is_limited": False # 当前版本没有实现限制
  328. }
  329. except (json.JSONDecodeError, Exception) as e:
  330. logger.warning(f" 解析SQL结果失败: {e}")
  331. if sql_query:
  332. result["sql"] = sql_query
  333. if sql_data:
  334. result["records"] = sql_data
  335. return result
  336. def _collect_agent_metadata(self, state: AgentState) -> Dict[str, Any]:
  337. """收集Agent元数据"""
  338. messages = state['messages']
  339. # 统计工具使用情况
  340. tools_used = []
  341. sql_execution_count = 0
  342. context_injected = False
  343. # 计算对话轮次(HumanMessage的数量)
  344. conversation_rounds = sum(1 for msg in messages if isinstance(msg, HumanMessage))
  345. # 分析工具调用和执行
  346. for msg in messages:
  347. if isinstance(msg, ToolMessage):
  348. if msg.name not in tools_used:
  349. tools_used.append(msg.name)
  350. if msg.name == 'run_sql':
  351. sql_execution_count += 1
  352. elif isinstance(msg, AIMessage) and hasattr(msg, 'tool_calls') and msg.tool_calls:
  353. for tool_call in msg.tool_calls:
  354. tool_name = tool_call.get('name')
  355. if tool_name and tool_name not in tools_used:
  356. tools_used.append(tool_name)
  357. # 检查是否注入了历史上下文
  358. if (tool_name == 'generate_sql' and
  359. tool_call.get('args', {}).get('history_messages')):
  360. context_injected = True
  361. # 构建执行路径(简化版本)
  362. execution_path = ["agent"]
  363. if tools_used:
  364. execution_path.extend(["prepare_tool_input", "tools"])
  365. execution_path.append("format_final_response")
  366. return {
  367. "thread_id": state['thread_id'],
  368. "conversation_rounds": conversation_rounds,
  369. "tools_used": tools_used,
  370. "execution_path": execution_path,
  371. "total_messages": len(messages),
  372. "sql_execution_count": sql_execution_count,
  373. "context_injected": context_injected,
  374. "agent_version": "custom_react_v1"
  375. }
  376. def _extract_latest_sql_data(self, messages: List[BaseMessage]) -> Optional[str]:
  377. """从消息历史中提取最近的run_sql执行结果,但仅限于当前对话轮次。"""
  378. logger.info("🔍 提取最新的SQL执行结果...")
  379. # 🎯 只查找最后一个HumanMessage之后的SQL执行结果
  380. last_human_index = -1
  381. for i in range(len(messages) - 1, -1, -1):
  382. if isinstance(messages[i], HumanMessage):
  383. last_human_index = i
  384. break
  385. if last_human_index == -1:
  386. logger.info(" 未找到用户消息,跳过SQL数据提取")
  387. return None
  388. # 只在当前对话轮次中查找SQL结果
  389. current_conversation = messages[last_human_index:]
  390. logger.info(f" 当前对话轮次包含 {len(current_conversation)} 条消息")
  391. for msg in reversed(current_conversation):
  392. if isinstance(msg, ToolMessage) and msg.name == 'run_sql':
  393. logger.info(f" 找到当前对话轮次的run_sql结果: {msg.content[:100]}...")
  394. # 🎯 处理Unicode转义序列,将其转换为正常的中文字符
  395. try:
  396. # 先尝试解析JSON以验证格式
  397. parsed_data = json.loads(msg.content)
  398. # 重新序列化,确保中文字符正常显示
  399. formatted_content = json.dumps(parsed_data, ensure_ascii=False, separators=(',', ':'))
  400. logger.info(f" 已转换Unicode转义序列为中文字符")
  401. return formatted_content
  402. except json.JSONDecodeError:
  403. # 如果不是有效JSON,直接返回原内容
  404. logger.warning(f" SQL结果不是有效JSON格式,返回原始内容")
  405. return msg.content
  406. logger.info(" 当前对话轮次中未找到run_sql执行结果")
  407. return None
  408. async def chat(self, message: str, user_id: str, thread_id: Optional[str] = None) -> Dict[str, Any]:
  409. """
  410. 处理用户聊天请求。
  411. """
  412. if not thread_id:
  413. now = pd.Timestamp.now()
  414. milliseconds = int(now.microsecond / 1000)
  415. thread_id = f"{user_id}:{now.strftime('%Y%m%d%H%M%S')}{milliseconds:03d}"
  416. logger.info(f"🆕 新建会话,Thread ID: {thread_id}")
  417. config = {
  418. "configurable": {
  419. "thread_id": thread_id,
  420. }
  421. }
  422. inputs = {
  423. "messages": [HumanMessage(content=message)],
  424. "user_id": user_id,
  425. "thread_id": thread_id,
  426. "suggested_next_step": None,
  427. }
  428. try:
  429. final_state = await self.agent_executor.ainvoke(inputs, config)
  430. answer = final_state["messages"][-1].content
  431. # 🎯 提取最近的 run_sql 执行结果(不修改messages)
  432. sql_data = self._extract_latest_sql_data(final_state["messages"])
  433. logger.info(f"✅ 处理完成 - Final Answer: '{answer}'")
  434. # 构建返回结果(保持简化格式用于shell.py)
  435. result = {
  436. "success": True,
  437. "answer": answer,
  438. "thread_id": thread_id
  439. }
  440. # 只有当存在SQL数据时才添加到返回结果中
  441. if sql_data:
  442. result["sql_data"] = sql_data
  443. logger.info(" 📊 已包含SQL原始数据")
  444. # 🎯 如果存在API格式数据,也添加到返回结果中(用于API层)
  445. if "api_data" in final_state:
  446. result["api_data"] = final_state["api_data"]
  447. logger.info(" 🔌 已包含API格式数据")
  448. return result
  449. except Exception as e:
  450. logger.error(f"❌ 处理过程中发生严重错误 - Thread: {thread_id}: {e}", exc_info=True)
  451. return {"success": False, "error": str(e), "thread_id": thread_id}
  452. async def get_conversation_history(self, thread_id: str) -> List[Dict[str, Any]]:
  453. """从 checkpointer 获取指定线程的对话历史。"""
  454. if not self.checkpointer:
  455. return []
  456. config = {"configurable": {"thread_id": thread_id}}
  457. try:
  458. conversation_state = await self.checkpointer.aget(config)
  459. except RuntimeError as e:
  460. if "Event loop is closed" in str(e):
  461. logger.warning(f"⚠️ Event loop已关闭,尝试重新获取对话历史: {thread_id}")
  462. # 如果事件循环关闭,返回空结果而不是抛出异常
  463. return []
  464. else:
  465. raise
  466. if not conversation_state:
  467. return []
  468. history = []
  469. messages = conversation_state.get('channel_values', {}).get('messages', [])
  470. for msg in messages:
  471. if isinstance(msg, HumanMessage):
  472. role = "human"
  473. elif isinstance(msg, ToolMessage):
  474. role = "tool"
  475. else: # AIMessage
  476. role = "ai"
  477. history.append({
  478. "type": role,
  479. "content": msg.content,
  480. "tool_calls": getattr(msg, 'tool_calls', None)
  481. })
  482. return history
  483. async def get_user_recent_conversations(self, user_id: str, limit: int = 10) -> List[Dict[str, Any]]:
  484. """
  485. 获取指定用户的最近聊天记录列表
  486. 利用thread_id格式 'user_id:timestamp' 来查询
  487. """
  488. if not self.checkpointer:
  489. return []
  490. try:
  491. # 创建Redis连接 - 使用与checkpointer相同的连接配置
  492. from redis.asyncio import Redis
  493. redis_client = Redis.from_url(config.REDIS_URL, decode_responses=True)
  494. # 1. 扫描匹配该用户的所有checkpoint keys
  495. # checkpointer的key格式通常是: checkpoint:thread_id:checkpoint_id
  496. pattern = f"checkpoint:{user_id}:*"
  497. logger.info(f"🔍 扫描模式: {pattern}")
  498. user_threads = {}
  499. cursor = 0
  500. while True:
  501. cursor, keys = await redis_client.scan(
  502. cursor=cursor,
  503. match=pattern,
  504. count=1000
  505. )
  506. for key in keys:
  507. try:
  508. # 解析key获取thread_id和checkpoint信息
  509. # key格式: checkpoint:user_id:timestamp:status:checkpoint_id
  510. key_str = key.decode() if isinstance(key, bytes) else key
  511. parts = key_str.split(':')
  512. if len(parts) >= 4:
  513. # thread_id = user_id:timestamp
  514. thread_id = f"{parts[1]}:{parts[2]}"
  515. timestamp = parts[2]
  516. # 跟踪每个thread的最新checkpoint
  517. if thread_id not in user_threads:
  518. user_threads[thread_id] = {
  519. "thread_id": thread_id,
  520. "timestamp": timestamp,
  521. "latest_key": key_str
  522. }
  523. else:
  524. # 保留最新的checkpoint key(通常checkpoint_id越大越新)
  525. if len(parts) > 4 and parts[4] > user_threads[thread_id]["latest_key"].split(':')[4]:
  526. user_threads[thread_id]["latest_key"] = key_str
  527. except Exception as e:
  528. logger.warning(f"解析key {key} 失败: {e}")
  529. continue
  530. if cursor == 0:
  531. break
  532. # 关闭临时Redis连接
  533. await redis_client.close()
  534. # 2. 按时间戳排序(新的在前)
  535. sorted_threads = sorted(
  536. user_threads.values(),
  537. key=lambda x: x["timestamp"],
  538. reverse=True
  539. )[:limit]
  540. # 3. 获取每个thread的详细信息
  541. conversations = []
  542. for thread_info in sorted_threads:
  543. try:
  544. thread_id = thread_info["thread_id"]
  545. thread_config = {"configurable": {"thread_id": thread_id}}
  546. try:
  547. state = await self.checkpointer.aget(thread_config)
  548. except RuntimeError as e:
  549. if "Event loop is closed" in str(e):
  550. logger.warning(f"⚠️ Event loop已关闭,跳过thread: {thread_id}")
  551. continue
  552. else:
  553. raise
  554. if state and state.get('channel_values', {}).get('messages'):
  555. messages = state['channel_values']['messages']
  556. # 生成对话预览
  557. preview = self._generate_conversation_preview(messages)
  558. conversations.append({
  559. "thread_id": thread_id,
  560. "user_id": user_id,
  561. "timestamp": thread_info["timestamp"],
  562. "message_count": len(messages),
  563. "last_message": messages[-1].content if messages else None,
  564. "last_updated": state.get('created_at'),
  565. "conversation_preview": preview,
  566. "formatted_time": self._format_timestamp(thread_info["timestamp"])
  567. })
  568. except Exception as e:
  569. logger.error(f"获取thread {thread_info['thread_id']} 详情失败: {e}")
  570. continue
  571. logger.info(f"✅ 找到用户 {user_id} 的 {len(conversations)} 个对话")
  572. return conversations
  573. except Exception as e:
  574. logger.error(f"❌ 获取用户 {user_id} 对话列表失败: {e}")
  575. return []
  576. def _generate_conversation_preview(self, messages: List[BaseMessage]) -> str:
  577. """生成对话预览"""
  578. if not messages:
  579. return "空对话"
  580. # 获取第一个用户消息作为预览
  581. for msg in messages:
  582. if isinstance(msg, HumanMessage):
  583. content = str(msg.content)
  584. return content[:50] + "..." if len(content) > 50 else content
  585. return "系统消息"
  586. def _format_timestamp(self, timestamp: str) -> str:
  587. """格式化时间戳为可读格式"""
  588. try:
  589. # timestamp格式: 20250710123137984
  590. if len(timestamp) >= 14:
  591. year = timestamp[:4]
  592. month = timestamp[4:6]
  593. day = timestamp[6:8]
  594. hour = timestamp[8:10]
  595. minute = timestamp[10:12]
  596. second = timestamp[12:14]
  597. return f"{year}-{month}-{day} {hour}:{minute}:{second}"
  598. except Exception:
  599. pass
  600. return timestamp