Browse Source

准备开始修改 /aip/chat.

wangxq 1 month ago
parent
commit
2cae793303
2 changed files with 222 additions and 2 deletions
  1. 141 2
      test/custom_react_agent/shell.py
  2. 81 0
      test/custom_react_agent/test_shell_features.py

+ 141 - 2
test/custom_react_agent/shell.py

@@ -8,6 +8,7 @@ import logging
 import sys
 import os
 import json
+from typing import List, Dict, Any
 
 # 将当前目录和项目根目录添加到 sys.path
 CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
@@ -31,6 +32,7 @@ class CustomAgentShell:
         self.agent = agent
         self.user_id: str = config.DEFAULT_USER_ID
         self.thread_id: str | None = None
+        self.recent_conversations: List[Dict[str, Any]] = []  # 存储最近的对话列表
 
     @classmethod
     async def create(cls):
@@ -43,6 +45,101 @@ class CustomAgentShell:
         if self.agent:
             await self.agent.close()
 
+    async def _fetch_recent_conversations(self, user_id: str, limit: int = 5) -> List[Dict[str, Any]]:
+        """获取最近的对话列表"""
+        try:
+            logger.info(f"🔍 获取用户 {user_id} 的最近 {limit} 次对话...")
+            conversations = await self.agent.get_user_recent_conversations(user_id, limit)
+            logger.info(f"✅ 成功获取 {len(conversations)} 个对话")
+            return conversations
+        except Exception as e:
+            logger.error(f"❌ 获取对话列表失败: {e}")
+            print(f"⚠️ 获取历史对话失败: {e}")
+            print("   将直接开始新对话...")
+            return []
+
+    def _display_conversation_list(self, conversations: List[Dict[str, Any]]) -> None:
+        """显示对话列表"""
+        if not conversations:
+            print("📭 暂无历史对话,将开始新对话。")
+            return
+        
+        print("\n📋 最近的对话记录:")
+        print("-" * 60)
+        
+        for i, conv in enumerate(conversations, 1):
+            thread_id = conv.get('thread_id', '')
+            formatted_time = conv.get('formatted_time', '')
+            preview = conv.get('conversation_preview', '无预览')
+            message_count = conv.get('message_count', 0)
+            
+            print(f"[{i}] {formatted_time} - {preview}")
+            print(f"    Thread ID: {thread_id} | 消息数: {message_count}")
+            print()
+        
+        print("💡 选择方式:")
+        print("   - 输入序号 (1-5): 选择对应的对话")
+        print("   - 输入 Thread ID: 直接指定对话")
+        print("   - 输入日期 (YYYY-MM-DD): 选择当天最新对话")
+        print("   - 输入 'new': 开始新对话")
+        print("   - 直接输入问题: 开始新对话")
+        print("-" * 60)
+
+    def _parse_conversation_selection(self, user_input: str) -> Dict[str, Any]:
+        """解析用户的对话选择"""
+        user_input = user_input.strip()
+        
+        # 检查是否是数字序号 (1-5)
+        if user_input.isdigit():
+            index = int(user_input)
+            if 1 <= index <= len(self.recent_conversations):
+                selected_conv = self.recent_conversations[index - 1]
+                return {
+                    "type": "select_by_index",
+                    "thread_id": selected_conv["thread_id"],
+                    "preview": selected_conv["conversation_preview"]
+                }
+            else:
+                return {"type": "invalid_index", "message": f"序号 {index} 无效,请输入 1-{len(self.recent_conversations)}"}
+        
+        # 检查是否是 Thread ID 格式 (包含冒号)
+        if ':' in user_input and len(user_input.split(':')) == 2:
+            user_part, timestamp_part = user_input.split(':')
+            # 简单验证格式
+            if user_part == self.user_id and timestamp_part.isdigit():
+                # 检查该Thread ID是否存在于历史对话中
+                for conv in self.recent_conversations:
+                    if conv["thread_id"] == user_input:
+                        return {
+                            "type": "select_by_thread_id",
+                            "thread_id": user_input,
+                            "preview": conv["conversation_preview"]
+                        }
+                return {"type": "thread_not_found", "message": f"Thread ID {user_input} 不存在于最近的对话中"}
+        
+        # 检查是否是日期格式 (YYYY-MM-DD)
+        import re
+        date_pattern = r'^\d{4}-\d{2}-\d{2}$'
+        if re.match(date_pattern, user_input):
+            # 查找该日期的最新对话
+            target_date = user_input.replace('-', '')  # 转换为 YYYYMMDD 格式
+            for conv in self.recent_conversations:
+                timestamp = conv.get('timestamp', '')
+                if timestamp.startswith(target_date):
+                    return {
+                        "type": "select_by_date",
+                        "thread_id": conv["thread_id"],
+                        "preview": f"日期 {user_input} 的对话: {conv['conversation_preview']}"
+                    }
+            return {"type": "no_date_match", "message": f"未找到 {user_input} 的对话"}
+        
+        # 检查是否是 'new' 命令
+        if user_input.lower() == 'new':
+            return {"type": "new_conversation"}
+        
+        # 其他情况当作新问题处理
+        return {"type": "new_question", "question": user_input}
+
     async def start(self):
         """启动 Shell 界面。"""
         print("\n🚀 Custom React Agent Shell (StateGraph Version)")
@@ -54,7 +151,11 @@ class CustomAgentShell:
             self.user_id = user_input
         
         print(f"👤 当前用户: {self.user_id}")
-        # 这里可以加入显示历史会话的逻辑
+        
+        # 获取并显示最近的对话列表
+        print("\n🔍 正在获取历史对话...")
+        self.recent_conversations = await self._fetch_recent_conversations(self.user_id, 5)
+        self._display_conversation_list(self.recent_conversations)
         
         print("\n💬 开始对话 (输入 'exit' 或 'quit' 退出)")
         print("-" * 50)
@@ -81,7 +182,45 @@ class CustomAgentShell:
                 await self._show_current_history()
                 continue
             
-            # 正常对话
+            # 如果还没有选择对话,且有历史对话,则处理对话选择
+            if self.thread_id is None and self.recent_conversations:
+                selection = self._parse_conversation_selection(user_input)
+                
+                if selection["type"] == "select_by_index":
+                    self.thread_id = selection["thread_id"]
+                    print(f"📖 已选择对话: {selection['preview']}")
+                    print(f"💬 Thread ID: {self.thread_id}")
+                    print("现在可以在此对话中继续聊天...\n")
+                    continue
+                
+                elif selection["type"] == "select_by_thread_id":
+                    self.thread_id = selection["thread_id"]
+                    print(f"📖 已选择对话: {selection['preview']}")
+                    print("现在可以在此对话中继续聊天...\n")
+                    continue
+                
+                elif selection["type"] == "select_by_date":
+                    self.thread_id = selection["thread_id"]
+                    print(f"📖 已选择对话: {selection['preview']}")
+                    print("现在可以在此对话中继续聊天...\n")
+                    continue
+                
+                elif selection["type"] == "new_conversation":
+                    self.thread_id = None
+                    print("🆕 已开始新会话。")
+                    continue
+                
+                elif selection["type"] == "new_question":
+                    # 当作新问题处理,继续下面的正常对话流程
+                    user_input = selection["question"]
+                    self.thread_id = None
+                    print("🆕 开始新对话...")
+                
+                elif selection["type"] in ["invalid_index", "no_date_match", "thread_not_found"]:
+                    print(f"❌ {selection['message']}")
+                    continue
+            
+            # 正常对话流程
             print("🤖 Agent 正在思考...")
             result = await self.agent.chat(user_input, self.user_id, self.thread_id)
             

+ 81 - 0
test/custom_react_agent/test_shell_features.py

@@ -0,0 +1,81 @@
+#!/usr/bin/env python3
+"""
+测试 shell.py 新增的对话选择功能
+"""
+import asyncio
+import sys
+import os
+
+# 确保导入路径正确
+CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, CURRENT_DIR)
+
+from shell import CustomAgentShell
+
+async def test_conversation_selection():
+    """测试对话选择功能"""
+    print("🧪 测试对话选择功能...")
+    
+    try:
+        # 创建shell实例
+        shell = await CustomAgentShell.create()
+        print("✅ Shell创建成功!")
+        
+        # 设置测试数据
+        shell.user_id = 'test_user'
+        shell.recent_conversations = [
+            {
+                'thread_id': 'test_user:20250101120000001', 
+                'conversation_preview': 'Python编程问题',
+                'timestamp': '20250101120000001',
+                'formatted_time': '2025-01-01 12:00:00'
+            },
+            {
+                'thread_id': 'test_user:20250101130000001', 
+                'conversation_preview': 'SQL查询帮助',
+                'timestamp': '20250101130000001',
+                'formatted_time': '2025-01-01 13:00:00'
+            },
+        ]
+        
+        print("\n📋 测试对话选择解析:")
+        
+        # 测试不同的选择类型
+        test_cases = [
+            ('1', '数字序号选择'),
+            ('test_user:20250101120000001', 'Thread ID选择'),
+            ('2025-01-01', '日期选择'),
+            ('new', '新对话命令'),
+            ('What is Python?', '新问题'),
+            ('999', '无效序号'),
+            ('wrong_user:20250101120000001', '无效Thread ID'),
+            ('2025-12-31', '无效日期'),
+        ]
+        
+        for user_input, description in test_cases:
+            result = shell._parse_conversation_selection(user_input)
+            print(f"   输入: '{user_input}' ({description})")
+            print(f"   结果: {result['type']}")
+            if 'message' in result:
+                print(f"   消息: {result['message']}")
+            elif 'thread_id' in result:
+                print(f"   Thread ID: {result['thread_id']}")
+            print()
+        
+        print("📄 测试对话列表显示:")
+        shell._display_conversation_list(shell.recent_conversations)
+        
+        # 测试获取对话功能(这个需要真实的Agent连接)
+        print("\n🔍 测试获取对话功能:")
+        print("   (需要Redis和Agent连接,此处跳过)")
+        
+        await shell.close()
+        print("✅ 所有测试完成!")
+        
+    except Exception as e:
+        print(f"❌ 测试失败: {e}")
+        import traceback
+        traceback.print_exc()
+
+if __name__ == "__main__":
+    asyncio.run(test_conversation_selection())