|
@@ -1335,6 +1335,375 @@ def ask_agent():
|
|
|
response_text="查询处理失败,请稍后重试"
|
|
|
)), 500
|
|
|
|
|
|
+@app.route('/api/v0/ask_agent_stream', methods=['GET'])
|
|
|
+def ask_agent_stream():
|
|
|
+ """Citu Agent 流式API - 支持实时进度显示(EventSource只支持GET请求)
|
|
|
+ 功能与ask_agent完全相同,除了采用流式输出"""
|
|
|
+
|
|
|
+ def generate():
|
|
|
+ try:
|
|
|
+ # 从URL参数获取数据(EventSource只支持GET请求)
|
|
|
+ question = request.args.get('question')
|
|
|
+ user_id_input = request.args.get('user_id')
|
|
|
+ conversation_id_input = request.args.get('conversation_id')
|
|
|
+ continue_conversation = request.args.get('continue_conversation', 'false').lower() == 'true'
|
|
|
+ api_routing_mode = request.args.get('routing_mode')
|
|
|
+
|
|
|
+ VALID_ROUTING_MODES = ["database_direct", "chat_direct", "hybrid", "llm_only"]
|
|
|
+
|
|
|
+ # 参数验证
|
|
|
+ if not question:
|
|
|
+ yield format_sse_error("缺少必需参数:question")
|
|
|
+ return
|
|
|
+
|
|
|
+ if api_routing_mode and api_routing_mode not in VALID_ROUTING_MODES:
|
|
|
+ yield format_sse_error(f"无效的routing_mode参数值: {api_routing_mode}")
|
|
|
+ return
|
|
|
+
|
|
|
+ # 🆕 用户ID和对话ID一致性校验(与ask_agent相同)
|
|
|
+ try:
|
|
|
+ # 获取登录用户ID
|
|
|
+ login_user_id = session.get('user_id') if 'user_id' in session else None
|
|
|
+
|
|
|
+ # 用户ID和对话ID一致性校验
|
|
|
+ from common.session_aware_cache import ConversationAwareMemoryCache
|
|
|
+
|
|
|
+ # 如果传递了conversation_id,从中解析user_id
|
|
|
+ extracted_user_id = None
|
|
|
+ if conversation_id_input:
|
|
|
+ extracted_user_id = ConversationAwareMemoryCache.extract_user_id(conversation_id_input)
|
|
|
+
|
|
|
+ # 如果同时传递了user_id和conversation_id,进行一致性校验
|
|
|
+ if user_id_input:
|
|
|
+ is_valid, error_msg = ConversationAwareMemoryCache.validate_user_id_consistency(
|
|
|
+ conversation_id_input, user_id_input
|
|
|
+ )
|
|
|
+ if not is_valid:
|
|
|
+ yield format_sse_error(error_msg)
|
|
|
+ return
|
|
|
+
|
|
|
+ # 如果没有传递user_id,但有conversation_id,则从conversation_id中解析
|
|
|
+ elif not user_id_input and extracted_user_id:
|
|
|
+ user_id_input = extracted_user_id
|
|
|
+ logger.info(f"从conversation_id解析出user_id: {user_id_input}")
|
|
|
+
|
|
|
+ # 如果没有传递user_id,使用默认值guest
|
|
|
+ if not user_id_input:
|
|
|
+ user_id_input = "guest"
|
|
|
+ logger.info("未传递user_id,使用默认值: guest")
|
|
|
+
|
|
|
+ # 🆕 智能ID解析(与ask_agent相同)
|
|
|
+ user_id = redis_conversation_manager.resolve_user_id(
|
|
|
+ user_id_input, None, request.remote_addr, login_user_id
|
|
|
+ )
|
|
|
+ conversation_id, conversation_status = redis_conversation_manager.resolve_conversation_id(
|
|
|
+ user_id, conversation_id_input, continue_conversation
|
|
|
+ )
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"用户ID和对话ID解析失败: {str(e)}")
|
|
|
+ yield format_sse_error(f"参数解析失败: {str(e)}")
|
|
|
+ return
|
|
|
+
|
|
|
+ logger.info(f"[STREAM_API] 收到请求 - 问题: {question[:50]}..., 用户: {user_id}, 对话: {conversation_id}")
|
|
|
+
|
|
|
+ # 🆕 获取上下文和上下文类型(与ask_agent相同)
|
|
|
+ context = redis_conversation_manager.get_context(conversation_id)
|
|
|
+
|
|
|
+ # 获取上下文类型:从最后一条助手消息的metadata中获取类型
|
|
|
+ context_type = None
|
|
|
+ if context:
|
|
|
+ try:
|
|
|
+ # 获取最后一条助手消息的metadata
|
|
|
+ messages = redis_conversation_manager.get_conversation_messages(conversation_id, limit=10)
|
|
|
+ for message in reversed(messages): # 从最新的开始找
|
|
|
+ if message.get("role") == "assistant":
|
|
|
+ metadata = message.get("metadata", {})
|
|
|
+ context_type = metadata.get("type")
|
|
|
+ if context_type:
|
|
|
+ logger.info(f"[STREAM_API] 检测到上下文类型: {context_type}")
|
|
|
+ break
|
|
|
+ except Exception as e:
|
|
|
+ logger.warning(f"获取上下文类型失败: {str(e)}")
|
|
|
+
|
|
|
+ # 🆕 检查缓存(与ask_agent相同)
|
|
|
+ cached_answer = redis_conversation_manager.get_cached_answer(question, context)
|
|
|
+ if cached_answer:
|
|
|
+ logger.info(f"[STREAM_API] 使用缓存答案")
|
|
|
+
|
|
|
+ # 确定缓存答案的助手回复内容(使用与非缓存相同的优先级逻辑)
|
|
|
+ cached_response_type = cached_answer.get("type", "UNKNOWN")
|
|
|
+ if cached_response_type == "DATABASE":
|
|
|
+ # DATABASE类型:按优先级选择内容
|
|
|
+ if cached_answer.get("response"):
|
|
|
+ assistant_response = cached_answer.get("response")
|
|
|
+ elif cached_answer.get("summary"):
|
|
|
+ assistant_response = cached_answer.get("summary")
|
|
|
+ elif cached_answer.get("query_result"):
|
|
|
+ query_result = cached_answer.get("query_result")
|
|
|
+ row_count = query_result.get("row_count", 0)
|
|
|
+ assistant_response = f"查询执行完成,共返回 {row_count} 条记录。"
|
|
|
+ else:
|
|
|
+ assistant_response = "查询处理完成"
|
|
|
+ else:
|
|
|
+ assistant_response = cached_answer.get("response", "处理完成")
|
|
|
+
|
|
|
+ # 返回缓存结果的SSE格式
|
|
|
+ yield format_sse_completed({
|
|
|
+ "type": "completed",
|
|
|
+ "result": {
|
|
|
+ "success": True,
|
|
|
+ "type": cached_response_type,
|
|
|
+ "response": assistant_response,
|
|
|
+ "sql": cached_answer.get("sql"),
|
|
|
+ "query_result": cached_answer.get("query_result"),
|
|
|
+ "summary": cached_answer.get("summary"),
|
|
|
+ "conversation_id": conversation_id,
|
|
|
+ "execution_path": cached_answer.get("execution_path", []),
|
|
|
+ "classification_info": cached_answer.get("classification_info", {}),
|
|
|
+ "user_id": user_id,
|
|
|
+ "context_used": bool(context),
|
|
|
+ "from_cache": True,
|
|
|
+ "conversation_status": conversation_status["status"],
|
|
|
+ "requested_conversation_id": conversation_status.get("requested_id")
|
|
|
+ }
|
|
|
+ })
|
|
|
+ return
|
|
|
+
|
|
|
+ # 🆕 保存用户消息(与ask_agent相同)
|
|
|
+ redis_conversation_manager.save_message(conversation_id, "user", question)
|
|
|
+
|
|
|
+ # 🆕 构建带上下文的问题(与ask_agent相同)
|
|
|
+ if context:
|
|
|
+ enhanced_question = f"\n[CONTEXT]\n{context}\n\n[CURRENT]\n{question}"
|
|
|
+ logger.info(f"[STREAM_API] 使用上下文,长度: {len(context)}字符")
|
|
|
+ else:
|
|
|
+ enhanced_question = question
|
|
|
+ logger.info(f"[STREAM_API] 新对话,无上下文")
|
|
|
+
|
|
|
+ # 获取Agent实例
|
|
|
+ try:
|
|
|
+ agent = get_citu_langraph_agent()
|
|
|
+ if not agent:
|
|
|
+ yield format_sse_error("Agent实例获取失败")
|
|
|
+ return
|
|
|
+
|
|
|
+ # 检查是否有process_question_stream方法
|
|
|
+ if not hasattr(agent, 'process_question_stream'):
|
|
|
+ yield format_sse_error("Agent不支持流式处理")
|
|
|
+ return
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"Agent初始化失败: {str(e)}")
|
|
|
+ yield format_sse_error("AI服务暂时不可用,请稍后重试")
|
|
|
+ return
|
|
|
+
|
|
|
+ # 🆕 确定最终使用的路由模式(与ask_agent相同)
|
|
|
+ if api_routing_mode:
|
|
|
+ # API传了参数,优先使用
|
|
|
+ effective_routing_mode = api_routing_mode
|
|
|
+ logger.info(f"[STREAM_API] 使用API指定的路由模式: {effective_routing_mode}")
|
|
|
+ else:
|
|
|
+ # API没传参数,使用配置文件
|
|
|
+ try:
|
|
|
+ from app_config import QUESTION_ROUTING_MODE
|
|
|
+ effective_routing_mode = QUESTION_ROUTING_MODE
|
|
|
+ logger.info(f"[STREAM_API] 使用配置文件路由模式: {effective_routing_mode}")
|
|
|
+ except ImportError:
|
|
|
+ effective_routing_mode = "hybrid"
|
|
|
+ logger.info(f"[STREAM_API] 配置文件读取失败,使用默认路由模式: {effective_routing_mode}")
|
|
|
+
|
|
|
+ # 流式处理 - 实时转发
|
|
|
+ try:
|
|
|
+ import asyncio
|
|
|
+
|
|
|
+ # 获取当前事件循环,如果没有则创建新的
|
|
|
+ try:
|
|
|
+ loop = asyncio.get_event_loop()
|
|
|
+ except RuntimeError:
|
|
|
+ loop = asyncio.new_event_loop()
|
|
|
+ asyncio.set_event_loop(loop)
|
|
|
+
|
|
|
+ # 用于收集最终结果,以便保存到Redis
|
|
|
+ final_result = None
|
|
|
+
|
|
|
+ # 异步生成器,实时yield数据
|
|
|
+ async def stream_generator():
|
|
|
+ nonlocal final_result
|
|
|
+ try:
|
|
|
+ async for chunk in agent.process_question_stream(
|
|
|
+ question=enhanced_question, # 🆕 使用增强后的问题
|
|
|
+ user_id=user_id,
|
|
|
+ conversation_id=conversation_id,
|
|
|
+ context_type=context_type, # 🆕 传递上下文类型
|
|
|
+ routing_mode=effective_routing_mode
|
|
|
+ ):
|
|
|
+ # 如果是完成的chunk,保存最终结果
|
|
|
+ if chunk.get("type") == "completed":
|
|
|
+ final_result = chunk.get("result")
|
|
|
+ yield chunk
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"流式处理异常: {str(e)}")
|
|
|
+ yield {"type": "error", "error": str(e)}
|
|
|
+
|
|
|
+ # 同步包装器,实时转发数据
|
|
|
+ def sync_stream_wrapper():
|
|
|
+ # 创建异步任务
|
|
|
+ async_gen = stream_generator()
|
|
|
+
|
|
|
+ while True:
|
|
|
+ try:
|
|
|
+ # 获取下一个chunk
|
|
|
+ chunk = loop.run_until_complete(async_gen.__anext__())
|
|
|
+
|
|
|
+ if chunk["type"] == "progress":
|
|
|
+ yield format_sse_progress(chunk)
|
|
|
+ elif chunk["type"] == "completed":
|
|
|
+ yield format_sse_completed(chunk)
|
|
|
+ break # 完成后退出循环
|
|
|
+ elif chunk["type"] == "error":
|
|
|
+ yield format_sse_error(chunk.get("error", "未知错误"))
|
|
|
+ break # 错误后退出循环
|
|
|
+
|
|
|
+ except StopAsyncIteration:
|
|
|
+ # 异步生成器结束
|
|
|
+ break
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"流式转发异常: {str(e)}")
|
|
|
+ yield format_sse_error(f"流式处理异常: {str(e)}")
|
|
|
+ break
|
|
|
+
|
|
|
+ # 返回同步生成器
|
|
|
+ yield from sync_stream_wrapper()
|
|
|
+
|
|
|
+ # 🆕 保存助手消息和缓存结果(与ask_agent相同)
|
|
|
+ if final_result and final_result.get("success", False):
|
|
|
+ try:
|
|
|
+ response_type = final_result.get("type", "UNKNOWN")
|
|
|
+ response_text = final_result.get("response", "")
|
|
|
+ sql = final_result.get("sql")
|
|
|
+ query_result = final_result.get("query_result")
|
|
|
+ summary = final_result.get("summary")
|
|
|
+ execution_path = final_result.get("execution_path", [])
|
|
|
+ classification_info = final_result.get("classification_info", {})
|
|
|
+
|
|
|
+ # 确定助手回复内容的优先级
|
|
|
+ if response_type == "DATABASE":
|
|
|
+ if response_text:
|
|
|
+ assistant_response = response_text
|
|
|
+ elif summary:
|
|
|
+ assistant_response = summary
|
|
|
+ elif query_result:
|
|
|
+ row_count = query_result.get("row_count", 0)
|
|
|
+ assistant_response = f"查询执行完成,共返回 {row_count} 条记录。"
|
|
|
+ else:
|
|
|
+ assistant_response = "查询处理完成"
|
|
|
+ else:
|
|
|
+ assistant_response = response_text or "处理完成"
|
|
|
+
|
|
|
+ # 保存助手消息
|
|
|
+ metadata = {
|
|
|
+ "type": response_type,
|
|
|
+ "sql": sql,
|
|
|
+ "execution_path": execution_path,
|
|
|
+ "classification_info": classification_info
|
|
|
+ }
|
|
|
+ redis_conversation_manager.save_message(
|
|
|
+ conversation_id, "assistant", assistant_response, metadata
|
|
|
+ )
|
|
|
+
|
|
|
+ # 缓存结果(仅缓存成功的结果)- 与ask_agent相同的调用方式
|
|
|
+ redis_conversation_manager.cache_answer(question, final_result, context)
|
|
|
+ logger.info(f"[STREAM_API] 结果已缓存")
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"保存结果和缓存失败: {str(e)}")
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"流式处理异常: {str(e)}")
|
|
|
+ import traceback
|
|
|
+ traceback.print_exc()
|
|
|
+ yield format_sse_error(f"处理异常: {str(e)}")
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"流式API异常: {str(e)}")
|
|
|
+ yield format_sse_error(f"服务异常: {str(e)}")
|
|
|
+
|
|
|
+ return Response(stream_with_context(generate()), mimetype='text/event-stream')
|
|
|
+
|
|
|
+def format_sse_progress(chunk: dict) -> str:
|
|
|
+ """格式化进度事件为SSE格式"""
|
|
|
+ progress = chunk.get("progress", {})
|
|
|
+ node = chunk.get("node")
|
|
|
+
|
|
|
+ # 🆕 特殊处理:格式化响应节点的显示内容
|
|
|
+ if node == "format_response":
|
|
|
+ display_name = "格式化响应结果"
|
|
|
+ message = "正在执行:格式化响应结果"
|
|
|
+ else:
|
|
|
+ display_name = progress.get("display_name")
|
|
|
+ message = f"正在执行: {progress.get('display_name', '处理中')}"
|
|
|
+
|
|
|
+ data = {
|
|
|
+ "code": 200,
|
|
|
+ "success": True,
|
|
|
+ "message": message,
|
|
|
+ "data": {
|
|
|
+ "type": "progress",
|
|
|
+ "node": node,
|
|
|
+ "display_name": display_name,
|
|
|
+ # 🆕 删除icon字段
|
|
|
+ "details": progress.get("details"),
|
|
|
+ "sub_status": progress.get("sub_status"),
|
|
|
+ "conversation_id": chunk.get("conversation_id"),
|
|
|
+ "timestamp": datetime.now().isoformat()
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ import json
|
|
|
+ return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
|
|
|
+
|
|
|
+def format_sse_completed(chunk: dict) -> str:
|
|
|
+ """格式化完成事件为SSE格式"""
|
|
|
+ result = chunk.get("result", {})
|
|
|
+
|
|
|
+ data = {
|
|
|
+ "code": 200,
|
|
|
+ "success": True,
|
|
|
+ "message": "处理完成",
|
|
|
+ "data": {
|
|
|
+ "type": "completed",
|
|
|
+ "response": result.get("response", ""),
|
|
|
+ "response_type": result.get("type", "UNKNOWN"),
|
|
|
+ "sql": result.get("sql"),
|
|
|
+ "query_result": result.get("query_result"),
|
|
|
+ "summary": result.get("summary"),
|
|
|
+ "conversation_id": chunk.get("conversation_id"),
|
|
|
+ "execution_path": result.get("execution_path", []),
|
|
|
+ "classification_info": result.get("classification_info", {}),
|
|
|
+ "timestamp": datetime.now().isoformat()
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ import json
|
|
|
+ return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
|
|
|
+
|
|
|
+def format_sse_error(error_message: str) -> str:
|
|
|
+ """格式化错误事件为SSE格式"""
|
|
|
+ data = {
|
|
|
+ "code": 500,
|
|
|
+ "success": False,
|
|
|
+ "message": "处理失败",
|
|
|
+ "data": {
|
|
|
+ "type": "error",
|
|
|
+ "error": error_message,
|
|
|
+ "timestamp": datetime.now().isoformat()
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ import json
|
|
|
+ return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
|
|
|
+
|
|
|
# ==================== QA反馈系统API ====================
|
|
|
|
|
|
qa_feedback_manager = None
|