Selaa lähdekoodia

修复unified_api.py中的ask_react_agent的返回结果的问题,准备迁移api.py中的其它漏掉的api.

wangxq 1 kuukausi sitten
vanhempi
commit
f3e38aca1d
1 muutettua tiedostoa jossa 133 lisäystä ja 8 poistoa
  1. 133 8
      unified_api.py

+ 133 - 8
unified_api.py

@@ -385,9 +385,86 @@ def ask_agent():
             user_id, conversation_id_input, continue_conversation
         )
         
-        # 获取上下文
+        # 获取上下文和上下文类型(提前到缓存检查之前)
         context = redis_conversation_manager.get_context(conversation_id)
         
+        # 获取上下文类型:从最后一条助手消息的metadata中获取类型
+        context_type = None
+        if context:
+            try:
+                # 获取最后一条助手消息的metadata
+                messages = redis_conversation_manager.get_messages(conversation_id, limit=10)
+                for message in reversed(messages):  # 从最新的开始找
+                    if message.get("role") == "assistant":
+                        metadata = message.get("metadata", {})
+                        context_type = metadata.get("type")
+                        if context_type:
+                            logger.info(f"[AGENT_API] 检测到上下文类型: {context_type}")
+                            break
+            except Exception as e:
+                logger.warning(f"获取上下文类型失败: {str(e)}")
+        
+        # 检查缓存(新逻辑:放宽使用条件,严控存储条件)
+        cached_answer = redis_conversation_manager.get_cached_answer(question, context)
+        if cached_answer:
+            logger.info(f"[AGENT_API] 使用缓存答案")
+            
+            # 确定缓存答案的助手回复内容(使用与非缓存相同的优先级逻辑)
+            cached_response_type = cached_answer.get("type", "UNKNOWN")
+            if cached_response_type == "DATABASE":
+                # DATABASE类型:按优先级选择内容
+                if cached_answer.get("response"):
+                    # 优先级1:错误或解释性回复(如SQL生成失败)
+                    assistant_response = cached_answer.get("response")
+                elif cached_answer.get("summary"):
+                    # 优先级2:查询成功的摘要
+                    assistant_response = cached_answer.get("summary")
+                elif cached_answer.get("query_result"):
+                    # 优先级3:构造简单描述
+                    query_result = cached_answer.get("query_result")
+                    row_count = query_result.get("row_count", 0)
+                    assistant_response = f"查询执行完成,共返回 {row_count} 条记录。"
+                else:
+                    # 异常情况
+                    assistant_response = "数据库查询已处理。"
+            else:
+                # CHAT类型:直接使用response
+                assistant_response = cached_answer.get("response", "")
+            
+            # 更新对话历史
+            redis_conversation_manager.save_message(conversation_id, "user", question)
+            redis_conversation_manager.save_message(
+                conversation_id, "assistant", 
+                assistant_response,
+                metadata={"from_cache": True}
+            )
+            
+            # 添加对话信息到缓存结果
+            cached_answer["conversation_id"] = conversation_id
+            cached_answer["user_id"] = user_id
+            cached_answer["from_cache"] = True
+            cached_answer.update(conversation_status)
+            
+            # 使用agent_success_response返回标准格式
+            return jsonify(agent_success_response(
+                response_type=cached_answer.get("type", "UNKNOWN"),
+                response=cached_answer.get("response", ""),
+                sql=cached_answer.get("sql"),
+                records=cached_answer.get("query_result"),
+                summary=cached_answer.get("summary"),
+                session_id=browser_session_id,
+                execution_path=cached_answer.get("execution_path", []),
+                classification_info=cached_answer.get("classification_info", {}),
+                conversation_id=conversation_id,
+                user_id=user_id,
+                is_guest_user=(user_id == DEFAULT_ANONYMOUS_USER),
+                context_used=bool(context),
+                from_cache=True,
+                conversation_status=conversation_status["status"],
+                conversation_message=conversation_status["message"],
+                requested_conversation_id=conversation_status.get("requested_id")
+            ))
+        
         # 保存用户消息
         redis_conversation_manager.save_message(conversation_id, "user", question)
         
@@ -399,6 +476,21 @@ def ask_agent():
             enhanced_question = question
             logger.info(f"[AGENT_API] 新对话,无上下文")
         
+        # 确定最终使用的路由模式(优先级逻辑)
+        if api_routing_mode:
+            # API传了参数,优先使用
+            effective_routing_mode = api_routing_mode
+            logger.info(f"[AGENT_API] 使用API指定的路由模式: {effective_routing_mode}")
+        else:
+            # API没传参数,使用配置文件
+            try:
+                from app_config import QUESTION_ROUTING_MODE
+                effective_routing_mode = QUESTION_ROUTING_MODE
+                logger.info(f"[AGENT_API] 使用配置文件路由模式: {effective_routing_mode}")
+            except ImportError:
+                effective_routing_mode = "hybrid"
+                logger.info(f"[AGENT_API] 配置文件读取失败,使用默认路由模式: {effective_routing_mode}")
+        
         # Agent处理
         try:
             agent = get_citu_langraph_agent()
@@ -412,8 +504,10 @@ def ask_agent():
         # 异步调用Agent处理问题
         import asyncio
         agent_result = asyncio.run(agent.process_question(
-            question=enhanced_question,
-            session_id=browser_session_id
+            question=enhanced_question,  # 使用增强后的问题
+            session_id=browser_session_id,
+            context_type=context_type,  # 传递上下文类型
+            routing_mode=effective_routing_mode  # 新增:传递路由模式
         ))
         
         # 处理Agent结果
@@ -450,6 +544,11 @@ def ask_agent():
                 }
             )
             
+            # 缓存成功的答案(新逻辑:只缓存无上下文的问答)
+            # 直接缓存agent_result,它已经包含所有需要的字段
+            redis_conversation_manager.cache_answer(question, agent_result, context)
+            
+            # 使用agent_success_response的正确方式
             return jsonify(agent_success_response(
                 response_type=response_type,
                 response=response_text,
@@ -465,7 +564,10 @@ def ask_agent():
                 context_used=bool(context),
                 from_cache=False,
                 conversation_status=conversation_status["status"],
-                conversation_message=conversation_status["message"]
+                conversation_message=conversation_status["message"],
+                requested_conversation_id=conversation_status.get("requested_id"),
+                routing_mode_used=effective_routing_mode,  # 新增:实际使用的路由模式
+                routing_mode_source="api" if api_routing_mode else "config"  # 新增:路由模式来源
             ))
         else:
             # 错误处理
@@ -505,24 +607,47 @@ def get_qa_feedback_manager():
 
 @app.route('/api/v0/qa_feedback/query', methods=['POST'])
 def qa_feedback_query():
-    """查询反馈记录API"""
+    """
+    查询反馈记录API
+    支持分页、筛选和排序功能
+    """
     try:
         req = request.get_json(force=True)
         
+        # 解析参数,设置默认值
         page = req.get('page', 1)
         page_size = req.get('page_size', 20)
         is_thumb_up = req.get('is_thumb_up')
+        create_time_start = req.get('create_time_start')
+        create_time_end = req.get('create_time_end')
+        is_in_training_data = req.get('is_in_training_data')
+        sort_by = req.get('sort_by', 'create_time')
+        sort_order = req.get('sort_order', 'desc')
+        
+        # 参数验证
+        if page < 1:
+            return jsonify(bad_request_response(
+                response_text="页码必须大于0",
+                invalid_params=["page"]
+            )), 400
         
-        if page < 1 or page_size < 1 or page_size > 100:
+        if page_size < 1 or page_size > 100:
             return jsonify(bad_request_response(
-                response_text="参数错误"
+                response_text="每页大小必须在1-100之间",
+                invalid_params=["page_size"]
             )), 400
         
+        # 获取反馈管理器并查询
         manager = get_qa_feedback_manager()
         records, total = manager.query_feedback(
             page=page,
             page_size=page_size,
-            is_thumb_up=is_thumb_up
+            is_thumb_up=is_thumb_up,
+            create_time_start=create_time_start,
+            create_time_end=create_time_end,
+            is_in_training_data=is_in_training_data,
+            sort_by=sort_by,
+            sort_order=sort_order
         )
         
         total_pages = (total + page_size - 1) // page_size