Browse Source

成功添加ask_agent_stream api,准备增加ask_react_agent_stream api.

wangxq 2 weeks ago
parent
commit
7d9f7a35a7
4 changed files with 544 additions and 274 deletions
  1. 175 0
      agent/citu_agent.py
  2. 0 154
      test_redis_modules.py
  3. 0 120
      test_vector_backup_only.py
  4. 369 0
      unified_api.py

+ 175 - 0
agent/citu_agent.py

@@ -776,6 +776,80 @@ class CituLangGraphAgent:
                 "error_code": 500,
                 "execution_path": ["error"]
             }
+
+    async def process_question_stream(self, question: str, user_id: str, conversation_id: str = None, context_type: str = None, routing_mode: str = None):
+        """
+        Stream the processing of a user question, one event per workflow node.
+
+        Builds the workflow and the initial state through ``_create_workflow``
+        and ``_create_initial_state`` and forwards the node updates produced
+        by ``workflow.astream``.
+
+        Args:
+            question: The user question.
+            user_id: User ID; only used to generate a conversation_id when
+                none is supplied.
+            conversation_id: Optional conversation ID; generated when falsy.
+            context_type: Context type; logged and passed to the initial state.
+            routing_mode: Optional routing mode; passed to the workflow
+                factory and to the initial state.
+
+        Yields:
+            dict: a ``{"type": "progress", ...}`` event for every node update
+            that ``_map_node_to_progress`` recognises, followed by one
+            ``{"type": "completed", "result": ...}`` event. If anything
+            raises, a single ``{"type": "error", "error": ...}`` event is
+            yielded instead.
+        """
+        try:
+            self.logger.info(f"🌊 [STREAM] 开始流式处理问题: {question}")
+            if context_type:
+                self.logger.info(f"🌊 [STREAM] 上下文类型: {context_type}")
+            if routing_mode:
+                self.logger.info(f"🌊 [STREAM] 使用指定路由模式: {routing_mode}")
+
+            # Generate a conversation_id when the caller did not provide one.
+            if not conversation_id:
+                conversation_id = self._generate_conversation_id(user_id)
+
+            # 1. Build the workflow for the requested routing mode.
+            self.logger.info(f"🌊 [STREAM] 动态创建workflow")
+            workflow = self._create_workflow(routing_mode)
+
+            # 2. Build the initial state.
+            initial_state = self._create_initial_state(question, conversation_id, context_type, routing_mode)
+
+            # The final response is captured while streaming rather than read
+            # from the loop variable afterwards: the loop variable is unbound
+            # when the stream is empty and may belong to a node that carries
+            # no "final_response" at all.
+            final_result = {}
+            received_output = False
+
+            # 3. Execute the workflow as a stream.
+            self.logger.info(f"🌊 [STREAM] 开始流式执行workflow")
+            async for chunk in workflow.astream(
+                initial_state,
+                # conversation_id is always set at this point (see above).
+                config={"configurable": {"conversation_id": conversation_id}}
+            ):
+                # Each chunk maps node names to that node's state update.
+                for node_name, node_data in chunk.items():
+                    received_output = True
+                    self.logger.debug(f"🌊 [STREAM] 收到节点输出: {node_name}")
+
+                    # Updates that are not dicts carry no state to report.
+                    if not isinstance(node_data, dict):
+                        continue
+
+                    # Remember the most recent non-empty final response.
+                    if node_data.get("final_response"):
+                        final_result = node_data["final_response"]
+
+                    # Map the node update to user-facing progress information.
+                    progress_info = self._map_node_to_progress(node_name, node_data)
+                    if progress_info:
+                        yield {
+                            "type": "progress",
+                            "node": node_name,
+                            "progress": progress_info,
+                            "state_data": self._extract_relevant_state(node_data),
+                            "conversation_id": conversation_id
+                        }
+
+            # An empty stream is a failure; report it with a clear message
+            # (handled by the except branch below) instead of a NameError.
+            if not received_output:
+                raise RuntimeError("workflow未返回任何节点输出")
+
+            # 4. Emit the final result.
+            self.logger.info(f"🌊 [STREAM] 流式处理完成: {final_result.get('success', False)}")
+            yield {
+                "type": "completed",
+                "result": final_result,
+                "conversation_id": conversation_id
+            }
+
+        except Exception as e:
+            self.logger.error(f"🌊 [STREAM] Agent流式执行异常: {str(e)}")
+            yield {
+                "type": "error", 
+                "error": str(e),
+                "conversation_id": conversation_id
+            }
     
     def _create_initial_state(self, question: str, conversation_id: str = None, context_type: str = None, routing_mode: str = None) -> AgentState:
         """创建初始状态 - 支持兼容性参数"""
@@ -1034,6 +1108,107 @@ class CituLangGraphAgent:
                 "error": f"修复异常: {str(e)}"
             }
 
+    def _generate_conversation_id(self, user_id: str) -> str:
+        """Build a conversation ID of the form ``<user_id>:<timestamp>``.
+
+        The timestamp is the current local time formatted as
+        ``YYYYmmddHHMMSS`` followed by three millisecond digits.
+
+        NOTE(review): the original comment states this matches the React
+        Agent's ID format; that format is not visible here, please confirm.
+        """
+        # The standard library is enough here; pandas was only used for now().
+        from datetime import datetime
+        # %f yields six microsecond digits; dropping the last three leaves milliseconds.
+        timestamp = datetime.now().strftime('%Y%m%d%H%M%S%f')[:-3]
+        return f"{user_id}:{timestamp}"
+
+    def _map_node_to_progress(self, node_name: str, node_data: dict) -> dict:
+        """Translate one node's state update into user-facing progress info.
+
+        Args:
+            node_name: Name of the workflow node that produced the update.
+            node_data: The node's state update. Missing keys and keys whose
+                value is ``None`` are treated alike and replaced by defaults.
+
+        Returns:
+            A dict with ``display_name``, ``icon``, ``details`` and
+            ``sub_status`` for the known nodes, or ``None`` for any other
+            node name.
+        """
+        # Tolerate a missing update so a progress message never aborts the stream.
+        node_data = node_data or {}
+
+        if node_name == "classify_question":
+            question_type = node_data.get("question_type", "UNCERTAIN")
+            # "or 0" also covers an explicit None, which ":.2f" cannot format.
+            confidence = node_data.get("classification_confidence") or 0
+            return {
+                "display_name": "分析问题类型",
+                "icon": "🤔",
+                "details": f"问题类型: {question_type} (置信度: {confidence:.2f})",
+                "sub_status": f"使用{node_data.get('classification_method', '未知')}方法分类"
+            }
+
+        elif node_name == "agent_sql_generation":
+            if node_data.get("sql_generation_success"):
+                # An explicit None must not reach len() / slicing.
+                sql = node_data.get("sql") or ""
+                sql_preview = sql[:50] + "..." if len(sql) > 50 else sql
+                return {
+                    "display_name": "SQL生成成功",
+                    "icon": "✅",
+                    "details": f"生成SQL: {sql_preview}",
+                    "sub_status": "验证通过,准备执行"
+                }
+            else:
+                error_type = node_data.get("validation_error_type", "unknown")
+                return {
+                    "display_name": "SQL生成处理中",
+                    "icon": "🔧",
+                    "details": f"验证状态: {error_type}",
+                    "sub_status": node_data.get("user_prompt", "正在处理")
+                }
+
+        elif node_name == "agent_sql_execution":
+            query_result = node_data.get("query_result")
+            # Only a dict can report a row count (same check as _extract_relevant_state).
+            if not isinstance(query_result, dict):
+                query_result = {}
+            row_count = query_result.get("row_count") or 0
+            return {
+                "display_name": "执行数据查询", 
+                "icon": "⚙️",
+                "details": f"查询完成,返回 {row_count} 行数据",
+                "sub_status": "正在生成摘要" if row_count > 0 else "查询执行完成"
+            }
+
+        elif node_name == "agent_chat":
+            return {
+                "display_name": "思考回答",
+                "icon": "💭", 
+                "details": "正在处理您的问题",
+                "sub_status": "使用智能对话模式"
+            }
+
+        elif node_name == "format_response":
+            return {
+                "display_name": "整理结果",
+                "icon": "📝",
+                "details": "正在格式化响应结果",
+                "sub_status": "即将完成"
+            }
+
+        # Unknown nodes produce no progress event.
+        return None
+
+    def _extract_relevant_state(self, node_data: dict) -> dict:
+        """Reduce a node's state update to a small, whitelisted summary.
+
+        Whitelisted keys are copied as-is. The SQL text is shortened to a
+        100-character preview and a dict query result is replaced by its row
+        and column counts. Any exception is logged as a warning and reported
+        as ``{"error": "state_extraction_failed"}``.
+        """
+        whitelist = (
+            "current_step", "execution_path", "question_type",
+            "classification_confidence", "classification_method",
+            "sql_generation_success", "sql_validation_success",
+            "routing_mode",
+        )
+        try:
+            summary = {name: node_data[name] for name in whitelist if name in node_data}
+
+            # SQL: expose at most the first 100 characters.
+            if "sql" in node_data:
+                sql_text = node_data["sql"]
+                if sql_text:
+                    sql_text = str(sql_text)
+                    if len(sql_text) > 100:
+                        summary["sql_preview"] = sql_text[:100] + "..."
+                    else:
+                        summary["sql_preview"] = sql_text
+
+            # Query result: expose only the row/column counts.
+            if "query_result" in node_data:
+                rows_info = node_data["query_result"]
+                if rows_info and isinstance(rows_info, dict):
+                    summary["query_summary"] = {
+                        "row_count": rows_info.get("row_count", 0),
+                        "column_count": len(rows_info.get("columns", []))
+                    }
+
+            return summary
+
+        except Exception as e:
+            self.logger.warning(f"提取状态信息失败: {str(e)}")
+            return {"error": "state_extraction_failed"}
+
     # ==================== 原有方法 ====================
     
     def _extract_original_question(self, question: str) -> str:

+ 0 - 154
test_redis_modules.py

@@ -1,154 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-"""
-Redis模块测试脚本
-用于检测Redis服务器是否安装了RediSearch和ReJSON模块
-"""
-
-import redis
-import sys
-from typing import Dict, Any
-
-
-def test_redis_modules(host: str = 'localhost', port: int = 6379, password: str = None, db: int = 0) -> Dict[str, Any]:
-    """
-    测试Redis服务器是否安装了RediSearch和ReJSON模块
-    
-    Args:
-        host: Redis服务器地址
-        port: Redis服务器端口
-        password: Redis密码(可选)
-        db: 数据库编号
-    
-    Returns:
-        包含测试结果的字典
-    """
-    results = {
-        'redis_connection': False,
-        'redijson_available': False,
-        'redisearch_available': False,
-        'errors': []
-    }
-    
-    try:
-        # 连接Redis
-        r = redis.Redis(host=host, port=port, password=password, db=db, decode_responses=True)
-        
-        # 测试连接
-        r.ping()
-        results['redis_connection'] = True
-        print(f"✅ Redis连接成功 - {host}:{port}")
-        
-    except Exception as e:
-        error_msg = f"❌ Redis连接失败: {str(e)}"
-        results['errors'].append(error_msg)
-        print(error_msg)
-        return results
-    
-    # 测试RedisJSON
-    try:
-        # 尝试设置JSON文档
-        r.execute_command('JSON.SET', 'test_doc', '$', '{"test":"value"}')
-        # 尝试获取JSON文档
-        result = r.execute_command('JSON.GET', 'test_doc')
-        # 清理测试数据
-        r.execute_command('JSON.DEL', 'test_doc')
-        
-        results['redijson_available'] = True
-        print("✅ RedisJSON 模块可用")
-        
-    except redis.exceptions.ResponseError as e:
-        error_msg = f"❌ RedisJSON 模块不可用: {str(e)}"
-        results['errors'].append(error_msg)
-        print(error_msg)
-    except Exception as e:
-        error_msg = f"❌ RedisJSON 测试失败: {str(e)}"
-        results['errors'].append(error_msg)
-        print(error_msg)
-    
-    # 测试RediSearch
-    try:
-        # 尝试创建索引
-        r.execute_command('FT.CREATE', 'test_idx', 'ON', 'HASH', 'PREFIX', '1', 'test:', 'SCHEMA', 'title', 'TEXT')
-        # 清理测试索引
-        r.execute_command('FT.DROPINDEX', 'test_idx')
-        
-        results['redisearch_available'] = True
-        print("✅ RediSearch 模块可用")
-        
-    except redis.exceptions.ResponseError as e:
-        error_msg = f"❌ RediSearch 模块不可用: {str(e)}"
-        results['errors'].append(error_msg)
-        print(error_msg)
-    except Exception as e:
-        error_msg = f"❌ RediSearch 测试失败: {str(e)}"
-        results['errors'].append(error_msg)
-        print(error_msg)
-    
-    return results
-
-
-def main():
-    """主函数"""
-    print("=" * 60)
-    print("Redis模块测试工具")
-    print("=" * 60)
-    
-    # 获取用户输入的Redis连接信息
-    print("\n请输入Redis服务器连接信息:")
-    host = 'localhost'
-    port_input = '6379'
-    password =  None
-    db_input = '0'
-    
-    try:
-        port = int(port_input)
-        db = int(db_input)
-    except ValueError:
-        print("❌ 端口和数据库编号必须是数字")
-        sys.exit(1)
-    
-    print(f"\n正在测试Redis服务器: {host}:{port}")
-    print("-" * 40)
-    
-    # 执行测试
-    results = test_redis_modules(host=host, port=port, password=password, db=db)
-    
-    # 输出测试总结
-    print("\n" + "=" * 60)
-    print("测试结果总结:")
-    print("=" * 60)
-    
-    if results['redis_connection']:
-        print("✅ Redis连接: 成功")
-    else:
-        print("❌ Redis连接: 失败")
-    
-    if results['redijson_available']:
-        print("✅ RedisJSON: 已安装")
-    else:
-        print("❌ RedisJSON: 未安装")
-    
-    if results['redisearch_available']:
-        print("✅ RediSearch: 已安装")
-    else:
-        print("❌ RediSearch: 未安装")
-    
-    if results['errors']:
-        print(f"\n错误信息:")
-        for error in results['errors']:
-            print(f"  - {error}")
-    
-    print("\n" + "=" * 60)
-    
-    # 返回适当的退出码
-    if results['redis_connection'] and results['redijson_available'] and results['redisearch_available']:
-        print("🎉 所有模块都可用!")
-        sys.exit(0)
-    else:
-        print("⚠️  部分模块不可用,请检查Redis配置")
-        sys.exit(1)
-
-
-if __name__ == "__main__":
-    main() 

+ 0 - 120
test_vector_backup_only.py

@@ -1,120 +0,0 @@
-#!/usr/bin/env python3
-"""
-独立测试Vector表备份功能
-只备份langchain_pg_collection和langchain_pg_embedding表
-"""
-
-import asyncio
-import os
-from pathlib import Path
-from datetime import datetime
-
-
-async def test_vector_backup():
-    """测试vector表备份功能"""
-    
-    print("🧪 开始测试Vector表备份功能...")
-    print("=" * 50)
-    
-    # 1. 设置测试输出目录
-    test_dir = Path("./test_vector_backup_output")
-    test_dir.mkdir(exist_ok=True)
-    
-    print(f"📁 测试输出目录: {test_dir.resolve()}")
-    
-    try:
-        # 2. 导入VectorTableManager
-        from data_pipeline.trainer.vector_table_manager import VectorTableManager
-        
-        # 3. 创建管理器实例
-        task_id = f"test_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
-        vector_manager = VectorTableManager(
-            task_output_dir=str(test_dir), 
-            task_id=task_id
-        )
-        
-        print(f"🆔 任务ID: {task_id}")
-        print("🔧 VectorTableManager 创建成功")
-        
-        # 4. 执行备份(只备份,不清空)
-        print("\n🗂️ 开始执行备份...")
-        result = vector_manager.execute_vector_management(
-            backup=True,    # 执行备份
-            truncate=False  # 不清空表
-        )
-        
-        # 5. 显示结果
-        print("\n📊 备份结果:")
-        print("=" * 30)
-        
-        if result.get("backup_performed", False):
-            print("✅ 备份状态: 已执行")
-            
-            tables_info = result.get("tables_backed_up", {})
-            for table_name, info in tables_info.items():
-                if info.get("success", False):
-                    print(f"  ✅ {table_name}: {info['row_count']}行 -> {info['backup_file']} ({info['file_size']})")
-                else:
-                    print(f"  ❌ {table_name}: 失败 - {info.get('error', '未知错误')}")
-        else:
-            print("❌ 备份状态: 未执行")
-        
-        duration = result.get("duration", 0)
-        print(f"⏱️  总耗时: {duration:.2f}秒")
-        
-        errors = result.get("errors", [])
-        if errors:
-            print(f"⚠️  错误信息: {'; '.join(errors)}")
-        
-        # 6. 检查生成的文件
-        backup_dir = test_dir / "vector_bak"
-        if backup_dir.exists():
-            print(f"\n📂 备份文件目录: {backup_dir.resolve()}")
-            backup_files = list(backup_dir.glob("*.csv"))
-            if backup_files:
-                print("📄 生成的备份文件:")
-                for file in backup_files:
-                    file_size = file.stat().st_size
-                    print(f"  📄 {file.name} ({file_size} bytes)")
-            else:
-                print("⚠️  未找到CSV备份文件")
-                
-            log_files = list(backup_dir.glob("*.txt"))
-            if log_files:
-                print("📋 日志文件:")
-                for file in log_files:
-                    print(f"  📋 {file.name}")
-        else:
-            print("❌ 备份目录不存在")
-        
-        print("\n🎉 测试完成!")
-        return True
-        
-    except Exception as e:
-        print(f"\n❌ 测试失败: {e}")
-        import traceback
-        print("详细错误信息:")
-        print(traceback.format_exc())
-        return False
-
-
-def main():
-    """主函数"""
-    print("Vector表备份功能独立测试")
-    print("测试目标: langchain_pg_collection, langchain_pg_embedding")
-    print("数据库: 从 data_pipeline.config 自动获取连接配置")
-    print()
-    
-    # 运行异步测试
-    success = asyncio.run(test_vector_backup())
-    
-    if success:
-        print("\n✅ 所有测试通过!")
-        exit(0)
-    else:
-        print("\n❌ 测试失败!")
-        exit(1)
-
-
-if __name__ == "__main__":
-    main() 

+ 369 - 0
unified_api.py

@@ -1335,6 +1335,375 @@ def ask_agent():
             response_text="查询处理失败,请稍后重试"
         )), 500
 
+@app.route('/api/v0/ask_agent_stream', methods=['GET'])
+def ask_agent_stream():
+    """Citu Agent streaming API: sends progress and the final answer as Server-Sent Events.
+
+    Parameters come from the query string because the browser EventSource
+    API can only issue GET requests:
+
+    - ``question`` (required)
+    - ``user_id``, ``conversation_id``, ``continue_conversation``,
+      ``routing_mode`` (optional)
+
+    Returns:
+        A ``text/event-stream`` response whose events are built by
+        ``format_sse_progress``, ``format_sse_completed`` and
+        ``format_sse_error``.
+    """
+
+    def generate():
+        try:
+            # Read the request data from the query string (EventSource is GET-only).
+            question = request.args.get('question')
+            user_id_input = request.args.get('user_id')
+            conversation_id_input = request.args.get('conversation_id')
+            continue_conversation = request.args.get('continue_conversation', 'false').lower() == 'true'
+            api_routing_mode = request.args.get('routing_mode')
+
+            VALID_ROUTING_MODES = ["database_direct", "chat_direct", "hybrid", "llm_only"]
+
+            # Parameter validation.
+            if not question:
+                yield format_sse_error("缺少必需参数:question")
+                return
+
+            if api_routing_mode and api_routing_mode not in VALID_ROUTING_MODES:
+                yield format_sse_error(f"无效的routing_mode参数值: {api_routing_mode}")
+                return
+
+            # Consistency check between user ID and conversation ID.
+            try:
+                # ID of the logged-in user, if any.
+                login_user_id = session.get('user_id')
+
+                from common.session_aware_cache import ConversationAwareMemoryCache
+
+                # When a conversation_id is given, derive the user_id from it.
+                extracted_user_id = None
+                if conversation_id_input:
+                    extracted_user_id = ConversationAwareMemoryCache.extract_user_id(conversation_id_input)
+
+                    # Both IDs supplied: they must agree.
+                    if user_id_input:
+                        is_valid, error_msg = ConversationAwareMemoryCache.validate_user_id_consistency(
+                            conversation_id_input, user_id_input
+                        )
+                        if not is_valid:
+                            yield format_sse_error(error_msg)
+                            return
+
+                    # Only conversation_id supplied: use the user_id parsed from it.
+                    elif not user_id_input and extracted_user_id:
+                        user_id_input = extracted_user_id
+                        logger.info(f"从conversation_id解析出user_id: {user_id_input}")
+
+                # Fall back to the default user "guest".
+                if not user_id_input:
+                    user_id_input = "guest"
+                    logger.info("未传递user_id,使用默认值: guest")
+
+                # Resolve the effective user and conversation IDs.
+                user_id = redis_conversation_manager.resolve_user_id(
+                    user_id_input, None, request.remote_addr, login_user_id
+                )
+                conversation_id, conversation_status = redis_conversation_manager.resolve_conversation_id(
+                    user_id, conversation_id_input, continue_conversation
+                )
+
+            except Exception as e:
+                logger.error(f"用户ID和对话ID解析失败: {str(e)}")
+                yield format_sse_error(f"参数解析失败: {str(e)}")
+                return
+
+            logger.info(f"[STREAM_API] 收到请求 - 问题: {question[:50]}..., 用户: {user_id}, 对话: {conversation_id}")
+
+            # Conversation context and its type.
+            context = redis_conversation_manager.get_context(conversation_id)
+
+            # The context type is the "type" stored in the metadata of the
+            # most recent assistant message that has one.
+            context_type = None
+            if context:
+                try:
+                    messages = redis_conversation_manager.get_conversation_messages(conversation_id, limit=10)
+                    for message in reversed(messages):  # newest first
+                        if message.get("role") == "assistant":
+                            metadata = message.get("metadata", {})
+                            context_type = metadata.get("type")
+                            if context_type:
+                                logger.info(f"[STREAM_API] 检测到上下文类型: {context_type}")
+                                break
+                except Exception as e:
+                    logger.warning(f"获取上下文类型失败: {str(e)}")
+
+            # Answer cache lookup.
+            cached_answer = redis_conversation_manager.get_cached_answer(question, context)
+            if cached_answer:
+                logger.info(f"[STREAM_API] 使用缓存答案")
+
+                # Choose the assistant text with the same priority as the
+                # non-cached path below.
+                cached_response_type = cached_answer.get("type", "UNKNOWN")
+                if cached_response_type == "DATABASE":
+                    # DATABASE: response, then summary, then a row-count sentence.
+                    if cached_answer.get("response"):
+                        assistant_response = cached_answer.get("response")
+                    elif cached_answer.get("summary"):
+                        assistant_response = cached_answer.get("summary")
+                    elif cached_answer.get("query_result"):
+                        query_result = cached_answer.get("query_result")
+                        row_count = query_result.get("row_count", 0)
+                        assistant_response = f"查询执行完成,共返回 {row_count} 条记录。"
+                    else:
+                        assistant_response = "查询处理完成"
+                else:
+                    assistant_response = cached_answer.get("response", "处理完成")
+
+                # Return the cached result as a "completed" SSE event.
+                yield format_sse_completed({
+                    "type": "completed",
+                    # format_sse_completed reads conversation_id from the top
+                    # level of the chunk, so it has to be present here too.
+                    "conversation_id": conversation_id,
+                    "result": {
+                        "success": True,
+                        "type": cached_response_type,
+                        "response": assistant_response,
+                        "sql": cached_answer.get("sql"),
+                        "query_result": cached_answer.get("query_result"),
+                        "summary": cached_answer.get("summary"),
+                        "conversation_id": conversation_id,
+                        "execution_path": cached_answer.get("execution_path", []),
+                        "classification_info": cached_answer.get("classification_info", {}),
+                        "user_id": user_id,
+                        "context_used": bool(context),
+                        "from_cache": True,
+                        "conversation_status": conversation_status["status"],
+                        "requested_conversation_id": conversation_status.get("requested_id")
+                    }
+                })
+                return
+
+            # Store the user message.
+            redis_conversation_manager.save_message(conversation_id, "user", question)
+
+            # Prefix the question with the conversation context, if any.
+            if context:
+                enhanced_question = f"\n[CONTEXT]\n{context}\n\n[CURRENT]\n{question}"
+                logger.info(f"[STREAM_API] 使用上下文,长度: {len(context)}字符")
+            else:
+                enhanced_question = question
+                logger.info(f"[STREAM_API] 新对话,无上下文")
+
+            # Obtain the agent instance.
+            try:
+                agent = get_citu_langraph_agent()
+                if not agent:
+                    yield format_sse_error("Agent实例获取失败")
+                    return
+
+                # The agent must provide process_question_stream.
+                if not hasattr(agent, 'process_question_stream'):
+                    yield format_sse_error("Agent不支持流式处理")
+                    return
+
+            except Exception as e:
+                logger.error(f"Agent初始化失败: {str(e)}")
+                yield format_sse_error("AI服务暂时不可用,请稍后重试")
+                return
+
+            # Effective routing mode: request parameter first, then the
+            # configuration file, then "hybrid".
+            if api_routing_mode:
+                effective_routing_mode = api_routing_mode
+                logger.info(f"[STREAM_API] 使用API指定的路由模式: {effective_routing_mode}")
+            else:
+                try:
+                    from app_config import QUESTION_ROUTING_MODE
+                    effective_routing_mode = QUESTION_ROUTING_MODE
+                    logger.info(f"[STREAM_API] 使用配置文件路由模式: {effective_routing_mode}")
+                except ImportError:
+                    effective_routing_mode = "hybrid"
+                    logger.info(f"[STREAM_API] 配置文件读取失败,使用默认路由模式: {effective_routing_mode}")
+
+            # Streaming: drive the agent's async generator from this
+            # synchronous generator and forward every event immediately.
+            loop = None
+            owns_loop = False  # True when this request created the loop
+            try:
+                import asyncio
+
+                # Use the current event loop, or create one if there is none.
+                try:
+                    loop = asyncio.get_event_loop()
+                except RuntimeError:
+                    loop = asyncio.new_event_loop()
+                    asyncio.set_event_loop(loop)
+                    owns_loop = True
+
+                # Final result, kept so it can be stored in Redis afterwards.
+                final_result = None
+
+                # Async generator that relays the agent's events.
+                async def stream_generator():
+                    nonlocal final_result
+                    try:
+                        async for chunk in agent.process_question_stream(
+                            question=enhanced_question,  # question including context
+                            user_id=user_id,
+                            conversation_id=conversation_id,
+                            context_type=context_type,
+                            routing_mode=effective_routing_mode
+                        ):
+                            # Remember the result carried by the "completed" event.
+                            if chunk.get("type") == "completed":
+                                final_result = chunk.get("result")
+                            yield chunk
+                    except Exception as e:
+                        logger.error(f"流式处理异常: {str(e)}")
+                        yield {"type": "error", "error": str(e)}
+
+                # Synchronous wrapper: pulls one event at a time from the
+                # async generator and formats it as SSE.
+                def sync_stream_wrapper():
+                    async_gen = stream_generator()
+                    try:
+                        while True:
+                            try:
+                                # Fetch the next event.
+                                chunk = loop.run_until_complete(async_gen.__anext__())
+
+                                if chunk["type"] == "progress":
+                                    yield format_sse_progress(chunk)
+                                elif chunk["type"] == "completed":
+                                    yield format_sse_completed(chunk)
+                                    break  # done
+                                elif chunk["type"] == "error":
+                                    yield format_sse_error(chunk.get("error", "未知错误"))
+                                    break  # stop after an error
+
+                            except StopAsyncIteration:
+                                # The async generator is exhausted.
+                                break
+                            except Exception as e:
+                                logger.error(f"流式转发异常: {str(e)}")
+                                yield format_sse_error(f"流式处理异常: {str(e)}")
+                                break
+                    finally:
+                        # After a break or a client disconnect the async
+                        # generator is still suspended; close it explicitly
+                        # so its cleanup runs on this loop.
+                        try:
+                            loop.run_until_complete(async_gen.aclose())
+                        except Exception as e:
+                            logger.warning(f"关闭流式生成器失败: {str(e)}")
+
+                yield from sync_stream_wrapper()
+
+                # Store the assistant message and cache the result
+                # (successful results only).
+                if final_result and final_result.get("success", False):
+                    try:
+                        response_type = final_result.get("type", "UNKNOWN")
+                        response_text = final_result.get("response", "")
+                        sql = final_result.get("sql")
+                        query_result = final_result.get("query_result")
+                        summary = final_result.get("summary")
+                        execution_path = final_result.get("execution_path", [])
+                        classification_info = final_result.get("classification_info", {})
+
+                        # Priority of the assistant text.
+                        if response_type == "DATABASE":
+                            if response_text:
+                                assistant_response = response_text
+                            elif summary:
+                                assistant_response = summary
+                            elif query_result:
+                                row_count = query_result.get("row_count", 0)
+                                assistant_response = f"查询执行完成,共返回 {row_count} 条记录。"
+                            else:
+                                assistant_response = "查询处理完成"
+                        else:
+                            assistant_response = response_text or "处理完成"
+
+                        # Store the assistant message.
+                        metadata = {
+                            "type": response_type,
+                            "sql": sql,
+                            "execution_path": execution_path,
+                            "classification_info": classification_info
+                        }
+                        redis_conversation_manager.save_message(
+                            conversation_id, "assistant", assistant_response, metadata
+                        )
+
+                        # Cache under the original (non-enhanced) question.
+                        redis_conversation_manager.cache_answer(question, final_result, context)
+                        logger.info(f"[STREAM_API] 结果已缓存")
+
+                    except Exception as e:
+                        logger.error(f"保存结果和缓存失败: {str(e)}")
+
+            except Exception as e:
+                # exc_info routes the traceback through the logger.
+                logger.error(f"流式处理异常: {str(e)}", exc_info=True)
+                yield format_sse_error(f"处理异常: {str(e)}")
+            finally:
+                # Close a loop created for this request; clearing the
+                # thread's loop lets a reused worker thread create a fresh
+                # one instead of hitting "Event loop is closed".
+                if owns_loop:
+                    try:
+                        loop.run_until_complete(loop.shutdown_asyncgens())
+                    finally:
+                        loop.close()
+                        asyncio.set_event_loop(None)
+
+        except Exception as e:
+            logger.error(f"流式API异常: {str(e)}")
+            yield format_sse_error(f"服务异常: {str(e)}")
+
+    return Response(stream_with_context(generate()), mimetype='text/event-stream')
+
+def format_sse_progress(chunk: dict) -> str:
+    """Serialize a progress event as one SSE ``data:`` frame.
+
+    The ``format_response`` node gets a fixed display name and message; every
+    other node uses the ``display_name`` found in ``chunk["progress"]``. The
+    progress ``icon`` is deliberately not forwarded.
+    """
+    import json
+
+    node_name = chunk.get("node")
+    progress_info = chunk.get("progress", {})
+
+    if node_name != "format_response":
+        shown_name = progress_info.get("display_name")
+        status_line = f"正在执行: {progress_info.get('display_name', '处理中')}"
+    else:
+        # Fixed wording for the response-formatting node.
+        shown_name = "格式化响应结果"
+        status_line = "正在执行:格式化响应结果"
+
+    event = {
+        "code": 200,
+        "success": True,
+        "message": status_line,
+        "data": {
+            "type": "progress",
+            "node": node_name,
+            "display_name": shown_name,
+            "details": progress_info.get("details"),
+            "sub_status": progress_info.get("sub_status"),
+            "conversation_id": chunk.get("conversation_id"),
+            "timestamp": datetime.now().isoformat()
+        }
+    }
+
+    # An SSE frame is "data: <payload>" terminated by a blank line.
+    return "data: " + json.dumps(event, ensure_ascii=False) + "\n\n"
+
+def format_sse_completed(chunk: dict) -> str:
+    """Serialize a completion event as one SSE ``data:`` frame.
+
+    Args:
+        chunk: Event dict with a ``result`` dict and, normally, a top-level
+            ``conversation_id``. When the top-level ID is missing, the one
+            inside ``result`` is used (the cached-answer path of
+            ``ask_agent_stream`` only provides it there).
+
+    Returns:
+        The frame ``"data: <json>"`` followed by a blank line.
+    """
+    import json
+
+    # "or {}" also covers a chunk whose result is explicitly None.
+    result = chunk.get("result") or {}
+
+    data = {
+        "code": 200,
+        "success": True,
+        "message": "处理完成",
+        "data": {
+            "type": "completed",
+            "response": result.get("response", ""),
+            "response_type": result.get("type", "UNKNOWN"),
+            "sql": result.get("sql"),
+            "query_result": result.get("query_result"),
+            "summary": result.get("summary"),
+            # Prefer the top-level ID, fall back to the one in the result.
+            "conversation_id": chunk.get("conversation_id") or result.get("conversation_id"),
+            "execution_path": result.get("execution_path", []),
+            "classification_info": result.get("classification_info", {}),
+            "timestamp": datetime.now().isoformat()
+        }
+    }
+
+    return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
+
+def format_sse_error(error_message: str) -> str:
+    """Serialize an error message as one SSE ``data:`` frame.
+
+    The envelope always carries ``code`` 500 and ``success`` False; the
+    message itself is placed in ``data.error``.
+    """
+    import json
+
+    detail = {
+        "type": "error",
+        "error": error_message,
+        "timestamp": datetime.now().isoformat()
+    }
+    envelope = {
+        "code": 500,
+        "success": False,
+        "message": "处理失败",
+        "data": detail
+    }
+
+    # An SSE frame is "data: <payload>" terminated by a blank line.
+    return "data: " + json.dumps(envelope, ensure_ascii=False) + "\n\n"
+
 # ==================== QA反馈系统API ====================
 
 qa_feedback_manager = None