瀏覽代碼

修复chat api的返回格式和数据查询错误.

wangxq 1 月之前
父節點
當前提交
d4ffb11686
共有 3 個文件被更改,包括 239 次插入8 次删除
  1. 13 1
      test/custom_react_agent/agent.py
  2. 28 7
      test/custom_react_agent/api.py
  3. 198 0
      test/custom_react_agent/test_api_design.py

+ 13 - 1
test/custom_react_agent/agent.py

@@ -727,7 +727,13 @@ class CustomReactAgent:
         }
 
         try:
+            logger.info(f"🚀 开始处理用户消息: {message[:50]}...")
+            
             final_state = await self.agent_executor.ainvoke(inputs, config)
+            
+            # 🔍 调试:打印 final_state 的所有 keys
+            logger.info(f"🔍 Final state keys: {list(final_state.keys())}")
+            
             answer = final_state["messages"][-1].content
             
             # 🎯 提取最近的 run_sql 执行结果(不修改messages)
@@ -747,10 +753,16 @@ class CustomReactAgent:
                 result["sql_data"] = sql_data
                 logger.info("   📊 已包含SQL原始数据")
             
-            # 🎯 如果存在API格式数据,也添加到返回结果中(用于API层)
+            # 🔧 修复:检查 api_data 是否在 final_state 中
             if "api_data" in final_state:
                 result["api_data"] = final_state["api_data"]
                 logger.info("   🔌 已包含API格式数据")
+            else:
+                # 🔧 备用方案:如果 final_state 中没有 api_data,手动生成
+                logger.warning("   ⚠️ final_state 中未找到 api_data,手动生成...")
+                api_data = await self._async_generate_api_data(final_state)
+                result["api_data"] = api_data
+                logger.info("   🔌 已手动生成API格式数据")
             
             return result
             

+ 28 - 7
test/custom_react_agent/api.py

@@ -6,9 +6,16 @@ import asyncio
 import logging
 import atexit
 import os
+import sys
 from datetime import datetime
 from typing import Optional, Dict, Any
 
+# 🔧 修复模块路径问题:添加项目根目录到 sys.path
+CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
+PROJECT_ROOT = os.path.abspath(os.path.join(CURRENT_DIR, '..', '..'))
+sys.path.insert(0, CURRENT_DIR)  # 当前目录优先
+sys.path.insert(1, PROJECT_ROOT)  # 项目根目录
+
 from flask import Flask, request, jsonify
 import redis.asyncio as redis
 
@@ -31,25 +38,27 @@ def validate_request_data(data: Dict[str, Any]) -> Dict[str, Any]:
     """验证请求数据"""
     errors = []
     
-    # 验证 question
+    # 验证 question(必填)
     question = data.get('question', '')
     if not question or not question.strip():
         errors.append('问题不能为空')
     elif len(question) > 2000:
         errors.append('问题长度不能超过2000字符')
     
-    # 验证 user_id
+    # 验证 user_id(可选,默认为"guest")
     user_id = data.get('user_id', 'guest')
     if user_id and len(user_id) > 50:
         errors.append('用户ID长度不能超过50字符')
     
+    # thread_id 是可选的,不需要验证
+    
     if errors:
         raise ValueError('; '.join(errors))
     
     return {
         'question': question.strip(),
         'user_id': user_id or 'guest',
-        'thread_id': data.get('thread_id')
+        'thread_id': data.get('thread_id')  # 可选,不传则自动生成新会话
     }
 
 async def initialize_agent():
@@ -194,16 +203,28 @@ async def chat_endpoint():
                 }
             }), 500
         
-        # Agent处理成功,提取数据
+        # Agent处理成功,按照设计文档格式化响应
         api_data = agent_result.get("api_data", {})
         
-        # 构建最终响应
+        # 构建符合设计文档的响应数据
         response_data = {
-            **api_data,  # 包含Agent格式化的所有数据
+            "response": api_data.get("response", ""),
+            "react_agent_meta": api_data.get("react_agent_meta", {
+                "thread_id": agent_result.get("thread_id"),
+                "agent_version": "custom_react_v1"
+            }),
             "timestamp": datetime.now().isoformat()
         }
         
-        logger.info(f"✅ 请求处理成功 - Thread: {api_data.get('react_agent_meta', {}).get('thread_id')}")
+        # 可选字段:SQL(仅当执行SQL时存在)
+        if "sql" in api_data:
+            response_data["sql"] = api_data["sql"]
+        
+        # 可选字段:records(仅当有查询结果时存在)
+        if "records" in api_data:
+            response_data["records"] = api_data["records"]
+        
+        logger.info(f"✅ 请求处理成功 - Thread: {response_data['react_agent_meta'].get('thread_id')}")
         
         return jsonify({
             "code": 200,

+ 198 - 0
test/custom_react_agent/test_api_design.py

@@ -0,0 +1,198 @@
+#!/usr/bin/env python3
+"""
+测试修改后的 API 是否符合设计文档要求
+"""
+import json
+import asyncio
+import aiohttp
+from typing import Dict, Any
+
+async def test_api_design_compliance():
+    """测试 API 设计文档合规性"""
+    
+    base_url = "http://localhost:8000"
+    
+    # 测试用例
+    test_cases = [
+        {
+            "name": "基本聊天测试",
+            "payload": {
+                "question": "你好,我想了解一下今天的天气",
+                "user_id": "wang"
+            },
+            "expected_fields": ["response", "react_agent_meta", "timestamp"]
+        },
+        {
+            "name": "SQL查询测试",
+            "payload": {
+                "question": "请查询服务区的收入数据",
+                "user_id": "test_user"
+            },
+            "expected_fields": ["response", "sql", "records", "react_agent_meta", "timestamp"]
+        },
+        {
+            "name": "继续对话测试",
+            "payload": {
+                "question": "请详细说明一下",
+                "user_id": "wang",
+                "thread_id": None  # 将在第一个测试后设置
+            },
+            "expected_fields": ["response", "react_agent_meta", "timestamp"]
+        }
+    ]
+    
+    session = aiohttp.ClientSession()
+    
+    try:
+        print("🧪 开始测试 API 设计文档合规性...")
+        print("=" * 60)
+        
+        thread_id = None
+        
+        for i, test_case in enumerate(test_cases, 1):
+            print(f"\n📋 测试 {i}: {test_case['name']}")
+            print("-" * 40)
+            
+            # 如果是继续对话测试,使用之前的 thread_id
+            if test_case["name"] == "继续对话测试" and thread_id:
+                test_case["payload"]["thread_id"] = thread_id
+            
+            # 发送请求
+            async with session.post(
+                f"{base_url}/api/chat",
+                json=test_case["payload"],
+                headers={"Content-Type": "application/json"}
+            ) as response:
+                
+                print(f"📊 HTTP状态码: {response.status}")
+                
+                if response.status != 200:
+                    print(f"❌ 请求失败,状态码: {response.status}")
+                    continue
+                
+                # 解析响应
+                result = await response.json()
+                
+                # 验证顶级结构
+                required_top_fields = ["code", "message", "success", "data"]
+                for field in required_top_fields:
+                    if field not in result:
+                        print(f"❌ 缺少顶级字段: {field}")
+                    else:
+                        print(f"✅ 顶级字段 {field}: {result[field]}")
+                
+                # 验证 data 字段结构
+                if "data" in result:
+                    data = result["data"]
+                    print(f"\n📦 data 字段包含: {list(data.keys())}")
+                    
+                    # 验证必需字段
+                    required_fields = ["response", "react_agent_meta", "timestamp"]
+                    for field in required_fields:
+                        if field not in data:
+                            print(f"❌ data 中缺少必需字段: {field}")
+                        else:
+                            print(f"✅ 必需字段 {field}: 存在")
+                    
+                    # 验证可选字段
+                    optional_fields = ["sql", "records"]
+                    for field in optional_fields:
+                        if field in data:
+                            print(f"✅ 可选字段 {field}: 存在")
+                        else:
+                            print(f"ℹ️  可选字段 {field}: 不存在(正常)")
+                    
+                    # 验证 react_agent_meta 结构
+                    if "react_agent_meta" in data:
+                        meta = data["react_agent_meta"]
+                        print(f"\n🔧 react_agent_meta 字段: {list(meta.keys())}")
+                        
+                        # 保存 thread_id 用于后续测试
+                        if "thread_id" in meta:
+                            thread_id = meta["thread_id"]
+                            print(f"🆔 Thread ID: {thread_id}")
+                    
+                    # 验证 records 结构(如果存在)
+                    if "records" in data:
+                        records = data["records"]
+                        print(f"\n📊 records 字段: {list(records.keys())}")
+                        required_record_fields = ["columns", "rows", "total_row_count", "is_limited"]
+                        for field in required_record_fields:
+                            if field not in records:
+                                print(f"❌ records 中缺少字段: {field}")
+                            else:
+                                print(f"✅ records 字段 {field}: 存在")
+                
+                print(f"\n✅ 测试 {i} 完成")
+        
+        print("\n" + "=" * 60)
+        print("🎉 所有测试完成!")
+        
+    except Exception as e:
+        print(f"❌ 测试过程中发生错误: {e}")
+        import traceback
+        traceback.print_exc()
+    
+    finally:
+        await session.close()
+
+async def test_error_handling():
+    """测试错误处理"""
+    
+    base_url = "http://localhost:8000"
+    session = aiohttp.ClientSession()
+    
+    try:
+        print("\n🧪 测试错误处理...")
+        print("=" * 60)
+        
+        # 测试参数错误
+        test_cases = [
+            {
+                "name": "缺少问题",
+                "payload": {"user_id": "test"},
+                "expected_code": 400
+            },
+            {
+                "name": "空问题",
+                "payload": {"question": "", "user_id": "test"},
+                "expected_code": 400
+            },
+            {
+                "name": "问题过长",
+                "payload": {"question": "x" * 2001, "user_id": "test"},
+                "expected_code": 400
+            }
+        ]
+        
+        for test_case in test_cases:
+            print(f"\n📋 错误测试: {test_case['name']}")
+            
+            async with session.post(
+                f"{base_url}/api/chat",
+                json=test_case["payload"],
+                headers={"Content-Type": "application/json"}
+            ) as response:
+                
+                result = await response.json()
+                
+                print(f"📊 HTTP状态码: {response.status}")
+                print(f"📋 响应代码: {result.get('code')}")
+                print(f"🎯 成功状态: {result.get('success')}")
+                print(f"❌ 错误信息: {result.get('error')}")
+                
+                if response.status == test_case["expected_code"]:
+                    print("✅ 错误处理正确")
+                else:
+                    print(f"❌ 期望状态码 {test_case['expected_code']}, 实际 {response.status}")
+    
+    finally:
+        await session.close()
+
+if __name__ == "__main__":
+    print("🚀 启动 API 设计文档合规性测试")
+    print("请确保 API 服务已启动 (python api.py)")
+    print("=" * 60)
+    
+    asyncio.run(test_api_design_compliance())
+    asyncio.run(test_error_handling())