|
@@ -1019,41 +1019,133 @@ class CustomReactAgent:
|
|
logger.error(f"❌ 处理过程中发生严重错误 - Thread: {thread_id}: {e}", exc_info=True)
|
|
logger.error(f"❌ 处理过程中发生严重错误 - Thread: {thread_id}: {e}", exc_info=True)
|
|
return {"success": False, "error": str(e), "thread_id": thread_id}
|
|
return {"success": False, "error": str(e), "thread_id": thread_id}
|
|
|
|
|
|
- async def get_conversation_history(self, thread_id: str) -> List[Dict[str, Any]]:
|
|
|
|
- """从 checkpointer 获取指定线程的对话历史。"""
|
|
|
|
|
|
+ async def get_conversation_history(self, thread_id: str, include_tools: bool = False) -> Dict[str, Any]:
|
|
|
|
+ """
|
|
|
|
+ 从 checkpointer 获取指定线程的对话历史,支持消息过滤和时间戳。
|
|
|
|
+
|
|
|
|
+ Args:
|
|
|
|
+ thread_id: 线程ID
|
|
|
|
+ include_tools: 是否包含工具消息,默认False(只返回human和ai消息)
|
|
|
|
+
|
|
|
|
+ Returns:
|
|
|
|
+ Dict包含: {
|
|
|
|
+ "messages": List[Dict], # 消息列表
|
|
|
|
+ "thread_created_at": str, # 线程创建时间
|
|
|
|
+ "total_checkpoints": int # 总checkpoint数
|
|
|
|
+ }
|
|
|
|
+ """
|
|
if not self.checkpointer:
|
|
if not self.checkpointer:
|
|
- return []
|
|
|
|
|
|
+ return {"messages": [], "thread_created_at": None, "total_checkpoints": 0}
|
|
|
|
|
|
thread_config = {"configurable": {"thread_id": thread_id}}
|
|
thread_config = {"configurable": {"thread_id": thread_id}}
|
|
|
|
+
|
|
try:
|
|
try:
|
|
- conversation_state = await self.checkpointer.aget(thread_config)
|
|
|
|
|
|
+ # 获取所有历史checkpoint,按时间倒序
|
|
|
|
+ checkpoints = []
|
|
|
|
+ async for checkpoint_tuple in self.checkpointer.alist(thread_config):
|
|
|
|
+ checkpoints.append(checkpoint_tuple)
|
|
|
|
+
|
|
|
|
+ if not checkpoints:
|
|
|
|
+ return {"messages": [], "thread_created_at": None, "total_checkpoints": 0}
|
|
|
|
+
|
|
|
|
+ # 解析thread_id中的创建时间
|
|
|
|
+ thread_created_at = self._parse_thread_creation_time(thread_id)
|
|
|
|
+
|
|
|
|
+ # 构建消息到时间戳的映射
|
|
|
|
+ message_timestamps = self._build_message_timestamp_map(checkpoints)
|
|
|
|
+
|
|
|
|
+ # 获取最新状态的消息
|
|
|
|
+ latest_checkpoint = checkpoints[0]
|
|
|
|
+ messages = latest_checkpoint.checkpoint.get('channel_values', {}).get('messages', [])
|
|
|
|
+
|
|
|
|
+ # 过滤和格式化消息
|
|
|
|
+ filtered_messages = []
|
|
|
|
+ for msg in messages:
|
|
|
|
+ # 确定消息类型
|
|
|
|
+ if isinstance(msg, HumanMessage):
|
|
|
|
+ msg_type = "human"
|
|
|
|
+ elif isinstance(msg, ToolMessage):
|
|
|
|
+ if not include_tools:
|
|
|
|
+ continue # 跳过工具消息
|
|
|
|
+ msg_type = "tool"
|
|
|
|
+ else: # AIMessage
|
|
|
|
+ msg_type = "ai"
|
|
|
|
+ # 如果不包含工具消息,跳过只有工具调用没有内容的AI消息
|
|
|
|
+ if not include_tools and (not msg.content and hasattr(msg, 'tool_calls') and msg.tool_calls):
|
|
|
|
+ continue
|
|
|
|
+
|
|
|
|
+ # 获取消息ID
|
|
|
|
+ msg_id = getattr(msg, 'id', None)
|
|
|
|
+ if not msg_id:
|
|
|
|
+ continue # 跳过没有ID的消息
|
|
|
|
+
|
|
|
|
+ # 获取时间戳
|
|
|
|
+ timestamp = message_timestamps.get(msg_id)
|
|
|
|
+ if not timestamp:
|
|
|
|
+ # 如果没有找到精确时间戳,使用最新checkpoint的时间
|
|
|
|
+ timestamp = latest_checkpoint.checkpoint.get('ts')
|
|
|
|
+
|
|
|
|
+ filtered_messages.append({
|
|
|
|
+ "id": msg_id,
|
|
|
|
+ "type": msg_type,
|
|
|
|
+ "content": msg.content,
|
|
|
|
+ "timestamp": timestamp
|
|
|
|
+ })
|
|
|
|
+
|
|
|
|
+ return {
|
|
|
|
+ "messages": filtered_messages,
|
|
|
|
+ "thread_created_at": thread_created_at,
|
|
|
|
+ "total_checkpoints": len(checkpoints)
|
|
|
|
+ }
|
|
|
|
+
|
|
except RuntimeError as e:
|
|
except RuntimeError as e:
|
|
if "Event loop is closed" in str(e):
|
|
if "Event loop is closed" in str(e):
|
|
logger.warning(f"⚠️ Event loop已关闭,尝试重新获取对话历史: {thread_id}")
|
|
logger.warning(f"⚠️ Event loop已关闭,尝试重新获取对话历史: {thread_id}")
|
|
- # 如果事件循环关闭,返回空结果而不是抛出异常
|
|
|
|
- return []
|
|
|
|
|
|
+ return {"messages": [], "thread_created_at": None, "total_checkpoints": 0}
|
|
else:
|
|
else:
|
|
raise
|
|
raise
|
|
|
|
+ except Exception as e:
|
|
|
|
+ logger.error(f"❌ 获取对话历史失败: {e}")
|
|
|
|
+ return {"messages": [], "thread_created_at": None, "total_checkpoints": 0}
|
|
|
|
+
|
|
|
|
+ def _parse_thread_creation_time(self, thread_id: str) -> str:
|
|
|
|
+ """解析thread_id中的创建时间,返回带毫秒的格式"""
|
|
|
|
+ try:
|
|
|
|
+ if ':' in thread_id:
|
|
|
|
+ parts = thread_id.split(':')
|
|
|
|
+ if len(parts) >= 2:
|
|
|
|
+ timestamp_part = parts[1]
|
|
|
|
+ if len(timestamp_part) >= 14:
|
|
|
|
+ year = timestamp_part[:4]
|
|
|
|
+ month = timestamp_part[4:6]
|
|
|
|
+ day = timestamp_part[6:8]
|
|
|
|
+ hour = timestamp_part[8:10]
|
|
|
|
+ minute = timestamp_part[10:12]
|
|
|
|
+ second = timestamp_part[12:14]
|
|
|
|
+ ms = timestamp_part[14:17] if len(timestamp_part) > 14 else "000"
|
|
|
|
+ return f"{year}-{month}-{day} {hour}:{minute}:{second}.{ms}"
|
|
|
|
+ except Exception as e:
|
|
|
|
+ logger.warning(f"⚠️ 解析thread创建时间失败: {e}")
|
|
|
|
+ return None
|
|
|
|
+
|
|
|
|
+ def _build_message_timestamp_map(self, checkpoints: List) -> Dict[str, str]:
|
|
|
|
+ """构建消息ID到时间戳的映射"""
|
|
|
|
+ message_timestamps = {}
|
|
|
|
|
|
- if not conversation_state:
|
|
|
|
- return []
|
|
|
|
-
|
|
|
|
- history = []
|
|
|
|
- messages = conversation_state.get('channel_values', {}).get('messages', [])
|
|
|
|
- for msg in messages:
|
|
|
|
- if isinstance(msg, HumanMessage):
|
|
|
|
- role = "human"
|
|
|
|
- elif isinstance(msg, ToolMessage):
|
|
|
|
- role = "tool"
|
|
|
|
- else: # AIMessage
|
|
|
|
- role = "ai"
|
|
|
|
|
|
+ # 按时间正序排列checkpoint(最早的在前)
|
|
|
|
+ checkpoints_sorted = sorted(checkpoints, key=lambda x: x.checkpoint.get('ts', ''))
|
|
|
|
+
|
|
|
|
+ for checkpoint_tuple in checkpoints_sorted:
|
|
|
|
+ checkpoint_ts = checkpoint_tuple.checkpoint.get('ts')
|
|
|
|
+ messages = checkpoint_tuple.checkpoint.get('channel_values', {}).get('messages', [])
|
|
|
|
|
|
- history.append({
|
|
|
|
- "type": role,
|
|
|
|
- "content": msg.content,
|
|
|
|
- "tool_calls": getattr(msg, 'tool_calls', None)
|
|
|
|
- })
|
|
|
|
- return history
|
|
|
|
|
|
+ # 为这个checkpoint中的新消息分配时间戳
|
|
|
|
+ for msg in messages:
|
|
|
|
+ msg_id = getattr(msg, 'id', None)
|
|
|
|
+ if msg_id and msg_id not in message_timestamps:
|
|
|
|
+ message_timestamps[msg_id] = checkpoint_ts
|
|
|
|
+
|
|
|
|
+ return message_timestamps
|
|
|
|
|
|
async def get_user_recent_conversations(self, user_id: str, limit: int = 10) -> List[Dict[str, Any]]:
|
|
async def get_user_recent_conversations(self, user_id: str, limit: int = 10) -> List[Dict[str, Any]]:
|
|
"""
|
|
"""
|
|
@@ -1147,15 +1239,23 @@ class CustomReactAgent:
|
|
# 生成对话预览
|
|
# 生成对话预览
|
|
preview = self._generate_conversation_preview(messages)
|
|
preview = self._generate_conversation_preview(messages)
|
|
|
|
|
|
|
|
+ # 获取最后一条用户消息
|
|
|
|
+ last_human_message = None
|
|
|
|
+ if messages:
|
|
|
|
+ for msg in reversed(messages):
|
|
|
|
+ if isinstance(msg, HumanMessage):
|
|
|
|
+ last_human_message = msg.content
|
|
|
|
+ break
|
|
|
|
+
|
|
conversations.append({
|
|
conversations.append({
|
|
|
|
+ "conversation_id": thread_id,
|
|
"thread_id": thread_id,
|
|
"thread_id": thread_id,
|
|
"user_id": user_id,
|
|
"user_id": user_id,
|
|
- "timestamp": thread_info["timestamp"],
|
|
|
|
"message_count": len(messages),
|
|
"message_count": len(messages),
|
|
- "last_message": messages[-1].content if messages else None,
|
|
|
|
- "last_updated": state.get('created_at'),
|
|
|
|
- "conversation_preview": preview,
|
|
|
|
- "formatted_time": self._format_timestamp(thread_info["timestamp"])
|
|
|
|
|
|
+ "last_message": last_human_message,
|
|
|
|
+ "updated_at": self._format_utc_to_china_time(state.get('ts')) if state.get('ts') else None,
|
|
|
|
+ "conversation_title": preview,
|
|
|
|
+ "created_at": self._format_timestamp(thread_info["timestamp"])
|
|
})
|
|
})
|
|
|
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
@@ -1183,7 +1283,7 @@ class CustomReactAgent:
|
|
return "系统消息"
|
|
return "系统消息"
|
|
|
|
|
|
def _format_timestamp(self, timestamp: str) -> str:
|
|
def _format_timestamp(self, timestamp: str) -> str:
|
|
- """格式化时间戳为可读格式"""
|
|
|
|
|
|
+ """格式化时间戳为可读格式,包含毫秒"""
|
|
try:
|
|
try:
|
|
# timestamp格式: 20250710123137984
|
|
# timestamp格式: 20250710123137984
|
|
if len(timestamp) >= 14:
|
|
if len(timestamp) >= 14:
|
|
@@ -1193,10 +1293,31 @@ class CustomReactAgent:
|
|
hour = timestamp[8:10]
|
|
hour = timestamp[8:10]
|
|
minute = timestamp[10:12]
|
|
minute = timestamp[10:12]
|
|
second = timestamp[12:14]
|
|
second = timestamp[12:14]
|
|
- return f"{year}-{month}-{day} {hour}:{minute}:{second}"
|
|
|
|
|
|
+ # 提取毫秒部分(如果存在)
|
|
|
|
+ millisecond = timestamp[14:17] if len(timestamp) > 14 else "000"
|
|
|
|
+ return f"{year}-{month}-{day} {hour}:{minute}:{second}.{millisecond}"
|
|
except Exception:
|
|
except Exception:
|
|
pass
|
|
pass
|
|
- return timestamp
|
|
|
|
|
|
+ return timestamp
|
|
|
|
+
|
|
|
|
+ def _format_utc_to_china_time(self, utc_time_str: str) -> str:
|
|
|
|
+ """将UTC时间转换为中国时区时间格式"""
|
|
|
|
+ try:
|
|
|
|
+ from datetime import datetime, timezone, timedelta
|
|
|
|
+
|
|
|
|
+ # 解析UTC时间字符串
|
|
|
|
+ # 格式: "2025-07-17T13:21:52.868292+00:00"
|
|
|
|
+ dt = datetime.fromisoformat(utc_time_str.replace('Z', '+00:00'))
|
|
|
|
+
|
|
|
|
+ # 转换为中国时区 (UTC+8)
|
|
|
|
+ china_tz = timezone(timedelta(hours=8))
|
|
|
|
+ china_time = dt.astimezone(china_tz)
|
|
|
|
+
|
|
|
|
+ # 格式化为目标格式: "2025-07-17 21:12:02.456"
|
|
|
|
+ return china_time.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3] # 只保留3位毫秒
|
|
|
|
+ except Exception as e:
|
|
|
|
+ logger.warning(f"时间格式转换失败: {e}")
|
|
|
|
+ return utc_time_str
|
|
|
|
|
|
def _get_database_scope_prompt(self) -> str:
|
|
def _get_database_scope_prompt(self) -> str:
|
|
"""Get database scope prompt for intelligent query decision making"""
|
|
"""Get database scope prompt for intelligent query decision making"""
|