|
@@ -385,9 +385,86 @@ def ask_agent():
|
|
|
user_id, conversation_id_input, continue_conversation
|
|
|
)
|
|
|
|
|
|
- # 获取上下文
|
|
|
+ # 获取上下文和上下文类型(提前到缓存检查之前)
|
|
|
context = redis_conversation_manager.get_context(conversation_id)
|
|
|
|
|
|
+ # 获取上下文类型:从最后一条助手消息的metadata中获取类型
|
|
|
+ context_type = None
|
|
|
+ if context:
|
|
|
+ try:
|
|
|
+ # 获取最后一条助手消息的metadata
|
|
|
+ messages = redis_conversation_manager.get_messages(conversation_id, limit=10)
|
|
|
+ for message in reversed(messages): # 从最新的开始找
|
|
|
+ if message.get("role") == "assistant":
|
|
|
+ metadata = message.get("metadata", {})
|
|
|
+ context_type = metadata.get("type")
|
|
|
+ if context_type:
|
|
|
+ logger.info(f"[AGENT_API] 检测到上下文类型: {context_type}")
|
|
|
+ break
|
|
|
+ except Exception as e:
|
|
|
+ logger.warning(f"获取上下文类型失败: {str(e)}")
|
|
|
+
|
|
|
+ # 检查缓存(新逻辑:放宽使用条件,严控存储条件)
|
|
|
+ cached_answer = redis_conversation_manager.get_cached_answer(question, context)
|
|
|
+ if cached_answer:
|
|
|
+ logger.info(f"[AGENT_API] 使用缓存答案")
|
|
|
+
|
|
|
+ # 确定缓存答案的助手回复内容(使用与非缓存相同的优先级逻辑)
|
|
|
+ cached_response_type = cached_answer.get("type", "UNKNOWN")
|
|
|
+ if cached_response_type == "DATABASE":
|
|
|
+ # DATABASE类型:按优先级选择内容
|
|
|
+ if cached_answer.get("response"):
|
|
|
+ # 优先级1:错误或解释性回复(如SQL生成失败)
|
|
|
+ assistant_response = cached_answer.get("response")
|
|
|
+ elif cached_answer.get("summary"):
|
|
|
+ # 优先级2:查询成功的摘要
|
|
|
+ assistant_response = cached_answer.get("summary")
|
|
|
+ elif cached_answer.get("query_result"):
|
|
|
+ # 优先级3:构造简单描述
|
|
|
+ query_result = cached_answer.get("query_result")
|
|
|
+ row_count = query_result.get("row_count", 0)
|
|
|
+ assistant_response = f"查询执行完成,共返回 {row_count} 条记录。"
|
|
|
+ else:
|
|
|
+ # 异常情况
|
|
|
+ assistant_response = "数据库查询已处理。"
|
|
|
+ else:
|
|
|
+ # CHAT类型:直接使用response
|
|
|
+ assistant_response = cached_answer.get("response", "")
|
|
|
+
|
|
|
+ # 更新对话历史
|
|
|
+ redis_conversation_manager.save_message(conversation_id, "user", question)
|
|
|
+ redis_conversation_manager.save_message(
|
|
|
+ conversation_id, "assistant",
|
|
|
+ assistant_response,
|
|
|
+ metadata={"from_cache": True}
|
|
|
+ )
|
|
|
+
|
|
|
+ # 添加对话信息到缓存结果
|
|
|
+ cached_answer["conversation_id"] = conversation_id
|
|
|
+ cached_answer["user_id"] = user_id
|
|
|
+ cached_answer["from_cache"] = True
|
|
|
+ cached_answer.update(conversation_status)
|
|
|
+
|
|
|
+ # 使用agent_success_response返回标准格式
|
|
|
+ return jsonify(agent_success_response(
|
|
|
+ response_type=cached_answer.get("type", "UNKNOWN"),
|
|
|
+ response=cached_answer.get("response", ""),
|
|
|
+ sql=cached_answer.get("sql"),
|
|
|
+ records=cached_answer.get("query_result"),
|
|
|
+ summary=cached_answer.get("summary"),
|
|
|
+ session_id=browser_session_id,
|
|
|
+ execution_path=cached_answer.get("execution_path", []),
|
|
|
+ classification_info=cached_answer.get("classification_info", {}),
|
|
|
+ conversation_id=conversation_id,
|
|
|
+ user_id=user_id,
|
|
|
+ is_guest_user=(user_id == DEFAULT_ANONYMOUS_USER),
|
|
|
+ context_used=bool(context),
|
|
|
+ from_cache=True,
|
|
|
+ conversation_status=conversation_status["status"],
|
|
|
+ conversation_message=conversation_status["message"],
|
|
|
+ requested_conversation_id=conversation_status.get("requested_id")
|
|
|
+ ))
|
|
|
+
|
|
|
# 保存用户消息
|
|
|
redis_conversation_manager.save_message(conversation_id, "user", question)
|
|
|
|
|
@@ -399,6 +476,21 @@ def ask_agent():
|
|
|
enhanced_question = question
|
|
|
logger.info(f"[AGENT_API] 新对话,无上下文")
|
|
|
|
|
|
+ # 确定最终使用的路由模式(优先级逻辑)
|
|
|
+ if api_routing_mode:
|
|
|
+ # API传了参数,优先使用
|
|
|
+ effective_routing_mode = api_routing_mode
|
|
|
+ logger.info(f"[AGENT_API] 使用API指定的路由模式: {effective_routing_mode}")
|
|
|
+ else:
|
|
|
+ # API没传参数,使用配置文件
|
|
|
+ try:
|
|
|
+ from app_config import QUESTION_ROUTING_MODE
|
|
|
+ effective_routing_mode = QUESTION_ROUTING_MODE
|
|
|
+ logger.info(f"[AGENT_API] 使用配置文件路由模式: {effective_routing_mode}")
|
|
|
+ except ImportError:
|
|
|
+ effective_routing_mode = "hybrid"
|
|
|
+ logger.info(f"[AGENT_API] 配置文件读取失败,使用默认路由模式: {effective_routing_mode}")
|
|
|
+
|
|
|
# Agent处理
|
|
|
try:
|
|
|
agent = get_citu_langraph_agent()
|
|
@@ -412,8 +504,10 @@ def ask_agent():
|
|
|
# 异步调用Agent处理问题
|
|
|
import asyncio
|
|
|
agent_result = asyncio.run(agent.process_question(
|
|
|
- question=enhanced_question,
|
|
|
- session_id=browser_session_id
|
|
|
+ question=enhanced_question, # 使用增强后的问题
|
|
|
+ session_id=browser_session_id,
|
|
|
+ context_type=context_type, # 传递上下文类型
|
|
|
+ routing_mode=effective_routing_mode # 新增:传递路由模式
|
|
|
))
|
|
|
|
|
|
# 处理Agent结果
|
|
@@ -450,6 +544,11 @@ def ask_agent():
|
|
|
}
|
|
|
)
|
|
|
|
|
|
+ # 缓存成功的答案(新逻辑:只缓存无上下文的问答)
|
|
|
+ # 直接缓存agent_result,它已经包含所有需要的字段
|
|
|
+ redis_conversation_manager.cache_answer(question, agent_result, context)
|
|
|
+
|
|
|
+ # 使用agent_success_response的正确方式
|
|
|
return jsonify(agent_success_response(
|
|
|
response_type=response_type,
|
|
|
response=response_text,
|
|
@@ -465,7 +564,10 @@ def ask_agent():
|
|
|
context_used=bool(context),
|
|
|
from_cache=False,
|
|
|
conversation_status=conversation_status["status"],
|
|
|
- conversation_message=conversation_status["message"]
|
|
|
+ conversation_message=conversation_status["message"],
|
|
|
+ requested_conversation_id=conversation_status.get("requested_id"),
|
|
|
+ routing_mode_used=effective_routing_mode, # 新增:实际使用的路由模式
|
|
|
+ routing_mode_source="api" if api_routing_mode else "config" # 新增:路由模式来源
|
|
|
))
|
|
|
else:
|
|
|
# 错误处理
|
|
@@ -505,24 +607,47 @@ def get_qa_feedback_manager():
|
|
|
|
|
|
@app.route('/api/v0/qa_feedback/query', methods=['POST'])
|
|
|
def qa_feedback_query():
|
|
|
- """查询反馈记录API"""
|
|
|
+ """
|
|
|
+ 查询反馈记录API
|
|
|
+ 支持分页、筛选和排序功能
|
|
|
+ """
|
|
|
try:
|
|
|
req = request.get_json(force=True)
|
|
|
|
|
|
+ # 解析参数,设置默认值
|
|
|
page = req.get('page', 1)
|
|
|
page_size = req.get('page_size', 20)
|
|
|
is_thumb_up = req.get('is_thumb_up')
|
|
|
+ create_time_start = req.get('create_time_start')
|
|
|
+ create_time_end = req.get('create_time_end')
|
|
|
+ is_in_training_data = req.get('is_in_training_data')
|
|
|
+ sort_by = req.get('sort_by', 'create_time')
|
|
|
+ sort_order = req.get('sort_order', 'desc')
|
|
|
+
|
|
|
+ # 参数验证
|
|
|
+ if page < 1:
|
|
|
+ return jsonify(bad_request_response(
|
|
|
+ response_text="页码必须大于0",
|
|
|
+ invalid_params=["page"]
|
|
|
+ )), 400
|
|
|
|
|
|
- if page < 1 or page_size < 1 or page_size > 100:
|
|
|
+ if page_size < 1 or page_size > 100:
|
|
|
return jsonify(bad_request_response(
|
|
|
- response_text="参数错误"
|
|
|
+ response_text="每页大小必须在1-100之间",
|
|
|
+ invalid_params=["page_size"]
|
|
|
)), 400
|
|
|
|
|
|
+ # 获取反馈管理器并查询
|
|
|
manager = get_qa_feedback_manager()
|
|
|
records, total = manager.query_feedback(
|
|
|
page=page,
|
|
|
page_size=page_size,
|
|
|
- is_thumb_up=is_thumb_up
|
|
|
+ is_thumb_up=is_thumb_up,
|
|
|
+ create_time_start=create_time_start,
|
|
|
+ create_time_end=create_time_end,
|
|
|
+ is_in_training_data=is_in_training_data,
|
|
|
+ sort_by=sort_by,
|
|
|
+ sort_order=sort_order
|
|
|
)
|
|
|
|
|
|
total_pages = (total + page_size - 1) // page_size
|