12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299 |
- """
- 统一 API 服务
- 集成 citu_app.py 指定API 和 react_agent/api.py 的所有功能
- 提供数据库问答、Redis对话管理、QA反馈、训练数据管理、React Agent等功能
- 使用普通 Flask 应用 + ASGI 包装实现异步支持
- """
- import asyncio
- import logging
- import atexit
- import os
- import sys
- from datetime import datetime, timedelta
- from typing import Optional, Dict, Any, TYPE_CHECKING, Union
- import signal
- if TYPE_CHECKING:
- from react_agent.agent import CustomReactAgent
- # 初始化日志系统 - 必须在最前面
- from core.logging import initialize_logging, get_app_logger
- initialize_logging()
- # 标准 Flask 导入
- from flask import Flask, request, jsonify, session
- import redis.asyncio as redis
- # 基础依赖
- import pandas as pd
- import json
- import sqlparse
- # 项目模块导入
- from core.vanna_llm_factory import create_vanna_instance
- from common.redis_conversation_manager import RedisConversationManager
- from common.qa_feedback_manager import QAFeedbackManager
- from common.result import (
- success_response, bad_request_response, not_found_response, internal_error_response,
- error_response, service_unavailable_response,
- agent_success_response, agent_error_response,
- validation_failed_response
- )
- from app_config import (
- USER_MAX_CONVERSATIONS, CONVERSATION_CONTEXT_COUNT,
- DEFAULT_ANONYMOUS_USER, ENABLE_QUESTION_ANSWER_CACHE
- )
# Create the standard Flask application
app = Flask(__name__)
# Create the application logger
logger = get_app_logger("UnifiedApp")
# React Agent import: try react_agent.agent first, then the copy under
# test.custom_react_agent. If neither imports, CustomReactAgent is set to None
# and get_react_agent() refuses to initialize the agent.
try:
    from react_agent.agent import CustomReactAgent
except ImportError:
    try:
        from test.custom_react_agent.agent import CustomReactAgent
    except ImportError:
        logger.warning("无法导入 CustomReactAgent,React Agent功能将不可用")
        CustomReactAgent = None
# Initialize core components (runs at import time)
vn = create_vanna_instance()
redis_conversation_manager = RedisConversationManager()
# ==================== React Agent global instance management ====================
# Populated lazily by get_react_agent(); released by cleanup_resources() at exit.
_react_agent_instance: Optional[Any] = None
_redis_client: Optional[redis.Redis] = None
def validate_request_data(data: Dict[str, Any]) -> Dict[str, Any]:
    """Validate an ask_react_agent request body and normalize its fields.

    When ``user_id`` is absent it is inferred from the part of ``thread_id``
    before the first ``':'``; if that is not possible it defaults to
    ``'guest'``.

    Args:
        data: Parsed JSON request body. Read keys: ``question`` (required),
            ``user_id``, ``thread_id`` and its fallback ``conversation_id``.

    Returns:
        ``{'question': <stripped question>, 'user_id': ..., 'thread_id': ...}``;
        ``thread_id`` is ``None`` when the client did not pass one.

    Raises:
        ValueError: when the body is not a JSON object or any rule fails; all
            rule violations are joined with ``'; '`` into one message.
    """
    # Non-object JSON bodies (list, string, number) have no .get(); report
    # them as a validation error instead of an AttributeError (-> HTTP 500).
    if not isinstance(data, dict):
        raise ValueError('请求体必须是JSON对象')

    errors = []

    # question (required). Type-check first: .strip() on a non-string would
    # raise AttributeError, which callers do not translate into a 400.
    question = data.get('question') or ''
    if not isinstance(question, str):
        errors.append('问题必须是字符串')
        question = ''
    elif not question.strip():
        errors.append('问题不能为空')
    elif len(question) > 2000:
        errors.append('问题长度不能超过2000字符')

    # thread_id takes precedence; conversation_id is accepted as an alias.
    thread_id = data.get('thread_id') or data.get('conversation_id')
    if thread_id is not None and not isinstance(thread_id, str):
        errors.append('会话ID必须是字符串')
        thread_id = None

    user_id = data.get('user_id')
    if user_id is not None and not isinstance(user_id, str):
        errors.append('用户ID必须是字符串')
        user_id = None

    # Infer user_id from a "user_id:timestamp" thread_id when not supplied.
    if not user_id:
        inferred_user_id = ''
        if thread_id and ':' in thread_id:
            inferred_user_id = thread_id.split(':', 1)[0]
        if inferred_user_id:
            user_id = inferred_user_id
            logger.info(f"👤 未提供user_id,从 thread_id '{thread_id}' 中推断出: '{user_id}'")
        else:
            user_id = 'guest'

    # At this point user_id is always a non-empty string.
    if len(user_id) > 50:
        errors.append('用户ID长度不能超过50字符')

    # A thread must belong to the (given or inferred) user.
    if thread_id:
        if ':' not in thread_id:
            errors.append('会话ID格式无效,期望格式为 user_id:timestamp')
        elif thread_id.split(':', 1)[0] != user_id:
            errors.append(f'会话归属验证失败:会话ID [{thread_id}] 不属于当前用户 [{user_id}]')

    if errors:
        raise ValueError('; '.join(errors))

    return {
        'question': question.strip(),
        'user_id': user_id,
        'thread_id': thread_id  # optional; None means "start a new conversation"
    }
async def get_react_agent() -> Any:
    """Return the shared React Agent instance, creating it on first use.

    Initialization connects an async Redis client (stored in the module-level
    ``_redis_client`` and verified with PING), then awaits
    ``CustomReactAgent.create()``.

    The Redis URL comes from the ``REDIS_URL`` environment variable. If it is
    unset, it defaults to ``redis://localhost:6379`` and is exported so the
    agent sees the same value.

    Raises:
        ImportError: if ``CustomReactAgent`` could not be imported at startup.
        Exception: any Redis or agent-creation error is logged and re-raised.
    """
    global _react_agent_instance, _redis_client

    if _react_agent_instance is not None:
        return _react_agent_instance

    if CustomReactAgent is None:
        logger.error("❌ CustomReactAgent 未能导入,无法初始化")
        raise ImportError("CustomReactAgent 未能导入")

    logger.info("🚀 正在异步初始化 Custom React Agent...")
    try:
        # Honor an externally configured REDIS_URL; previously it was always
        # overwritten with localhost, silently ignoring deployment settings.
        redis_url = os.environ.setdefault('REDIS_URL', 'redis://localhost:6379')

        # NOTE(review): a previously created client (from an earlier failed
        # initialization) is replaced here without being closed.
        _redis_client = redis.from_url(redis_url, decode_responses=True)
        await _redis_client.ping()
        logger.info("✅ Redis客户端连接成功")

        _react_agent_instance = await CustomReactAgent.create()
        logger.info("✅ React Agent 异步初始化完成")
    except Exception as e:
        logger.error(f"❌ React Agent 异步初始化失败: {e}")
        raise

    return _react_agent_instance
async def ensure_agent_ready() -> bool:
    """Make sure a usable React Agent instance exists.

    Creates the agent if needed, then probes it with a cheap
    ``get_user_recent_conversations("__test__", 1)`` call. If the probe fails,
    the instance is discarded and re-created.

    Returns:
        True when an agent is available. False when (re)initialization fails;
        the error is logged instead of raised, so callers can answer with a
        503 rather than an unhandled exception.
    """
    global _react_agent_instance

    try:
        if _react_agent_instance is None:
            await get_react_agent()

        # Liveness probe on the existing instance.
        try:
            await _react_agent_instance.get_user_recent_conversations("__test__", 1)
            return True
        except Exception as e:
            logger.warning(f"⚠️ Agent实例不可用: {e}")

        # Drop the unusable instance and build a fresh one.
        # NOTE(review): the discarded instance is not closed here because
        # requests already holding a reference may still be using it.
        _react_agent_instance = None
        await get_react_agent()
        return True
    except Exception as e:
        logger.error(f"❌ React Agent 初始化失败: {e}")
        return False
def cleanup_resources():
    """Release the React Agent and the shared Redis client (registered with atexit).

    Each resource is closed independently, so a failure while closing the
    agent no longer prevents the Redis client from being closed. The globals
    are cleared first, which makes repeated calls harmless.
    """
    global _react_agent_instance, _redis_client

    agent, client = _react_agent_instance, _redis_client
    _react_agent_instance = None
    _redis_client = None

    async def async_cleanup():
        if agent:
            try:
                await agent.close()
                logger.info("✅ React Agent 资源已清理")
            except Exception as e:
                logger.error(f"清理资源失败: {e}")

        if client:
            try:
                await client.aclose()
                logger.info("✅ Redis客户端已关闭")
            except Exception as e:
                logger.error(f"清理资源失败: {e}")

    try:
        # asyncio.run fails if an event loop is already running in this thread.
        asyncio.run(async_cleanup())
    except Exception as e:
        logger.error(f"清理资源失败: {e}")


atexit.register(cleanup_resources)
- # ==================== 基础路由 ====================
@app.route("/")
def index():
    """Root endpoint: a lightweight liveness probe returning the service banner."""
    banner = {"message": "统一API服务正在运行", "version": "1.0.0"}
    return jsonify(banner)
@app.route('/health', methods=['GET'])
def health_check():
    """Health endpoint.

    Returns 200 with the React Agent initialization flag, a timestamp and
    per-service availability (Redis conversation manager, Vanna instance).
    Returns 500 with the error text if building the report fails.
    """
    try:
        report = {
            "status": "healthy",
            "react_agent_initialized": _react_agent_instance is not None,
            "timestamp": datetime.now().isoformat(),
        }
        report["services"] = {
            "redis": redis_conversation_manager.is_available(),
            "vanna": vn is not None
        }
        return jsonify(report), 200
    except Exception as exc:
        logger.error(f"健康检查失败: {exc}")
        return jsonify({"status": "unhealthy", "error": str(exc)}), 500
- # ==================== React Agent API ====================
@app.route("/api/v0/ask_react_agent", methods=["POST"])
async def ask_react_agent():
    """Async React Agent question-answering endpoint.

    Expects a JSON body accepted by ``validate_request_data`` (``question``,
    optional ``user_id`` and ``thread_id``/``conversation_id``). Returns the
    agent's answer, plus ``sql``/``records`` when the agent produced them.

    Status codes: 200 success; 400 invalid JSON or validation failure; 500
    agent failure or unexpected error; 503 when the agent cannot be initialized.
    """
    # Readiness check. An initialization exception used to escape the view
    # (outside the try below) as a non-JSON 500, leaving the 503 branch dead.
    try:
        agent_ready = await ensure_agent_ready()
    except Exception as init_error:
        logger.error(f"❌ React Agent 初始化失败: {init_error}")
        agent_ready = False

    agent = _react_agent_instance
    if not agent_ready or agent is None:
        return jsonify({
            "code": 503,
            "message": "服务未就绪",
            "success": False,
            "error": "React Agent 初始化失败"
        }), 503

    try:
        # Parse the request body
        try:
            data = request.get_json(force=True)
        except Exception as json_error:
            logger.warning(f"⚠️ JSON解析失败: {json_error}")
            return jsonify({
                "code": 400,
                "message": "请求格式错误",
                "success": False,
                "error": "无效的JSON格式",
                "details": str(json_error)
            }), 400

        if not data:
            return jsonify({
                "code": 400,
                "message": "请求参数错误",
                "success": False,
                "error": "请求体不能为空"
            }), 400

        # Raises ValueError on invalid input (handled below as HTTP 400)
        validated_data = validate_request_data(data)

        logger.info(f"📨 收到React Agent请求 - User: {validated_data['user_id']}, Question: {validated_data['question'][:50]}...")

        agent_result = await agent.chat(
            message=validated_data['question'],
            user_id=validated_data['user_id'],
            thread_id=validated_data['thread_id']
        )

        if not agent_result.get("success", False):
            error_msg = agent_result.get("error", "React Agent处理失败")
            logger.error(f"❌ React Agent处理失败: {error_msg}")

            return jsonify({
                "code": 500,
                "message": "处理失败",
                "success": False,
                "error": error_msg,
                "data": {
                    "conversation_id": agent_result.get("thread_id"),
                    "user_id": validated_data['user_id'],
                    "timestamp": datetime.now().isoformat()
                }
            }), 500

        # api_data may be present but null; treat that like a missing payload.
        api_data = agent_result.get("api_data") or {}

        # Response layout mirrors react_agent/api.py
        response_data = {
            "response": api_data.get("response", ""),
            "conversation_id": agent_result.get("thread_id"),
            "user_id": validated_data['user_id'],
            "react_agent_meta": api_data.get("react_agent_meta", {
                "thread_id": agent_result.get("thread_id"),
                "agent_version": "custom_react_v1_async"
            }),
            "timestamp": datetime.now().isoformat()
        }

        # Optional fields: only present when the agent executed SQL / got rows
        if "sql" in api_data:
            response_data["sql"] = api_data["sql"]
        if "records" in api_data:
            response_data["records"] = api_data["records"]

        return jsonify({
            "code": 200,
            "message": "处理成功",
            "success": True,
            "data": response_data
        }), 200

    except ValueError as ve:
        logger.warning(f"⚠️ 参数验证失败: {ve}")
        return jsonify({
            "code": 400,
            "message": "参数验证失败",
            "success": False,
            "error": str(ve)
        }), 400

    except Exception as e:
        logger.error(f"❌ React Agent API 异常: {e}")
        return jsonify({
            "code": 500,
            "message": "内部服务错误",
            "success": False,
            "error": "服务暂时不可用,请稍后重试"
        }), 500
- # ==================== LangGraph Agent API ====================
- # 全局Agent实例(单例模式)
citu_langraph_agent = None


def get_citu_langraph_agent():
    """Return the process-wide CituLangGraphAgent, creating it on first call.

    The agent module is imported inside this function, so an import failure
    surfaces here rather than at module load.

    Raises:
        Exception: wraps the original ImportError or construction error
            (chained with ``from``, so the root cause stays in tracebacks).
    """
    global citu_langraph_agent
    if citu_langraph_agent is None:
        try:
            from agent.citu_agent import CituLangGraphAgent
            logger.info("开始创建LangGraph Agent实例...")
            citu_langraph_agent = CituLangGraphAgent()
            logger.info("LangGraph Agent实例创建成功")
        except ImportError as e:
            logger.critical(f"Agent模块导入失败: {str(e)}")
            raise Exception(f"Agent模块导入失败: {str(e)}") from e
        except Exception as e:
            logger.critical(f"LangGraph Agent实例创建失败: {str(e)}")
            raise Exception(f"Agent初始化失败: {str(e)}") from e
    return citu_langraph_agent
def _assistant_reply_text(result: Dict[str, Any]) -> str:
    """Pick the assistant text stored in history for an agent or cached result.

    For DATABASE results, prefer an explicit ``response`` (e.g. an explanation
    of an SQL-generation failure), then ``summary``, then a row-count sentence
    built from ``query_result``. Other types use ``response`` as-is.
    """
    if result.get("type", "UNKNOWN") != "DATABASE":
        return result.get("response", "")
    if result.get("response"):
        return result.get("response")
    if result.get("summary"):
        return result.get("summary")
    query_result = result.get("query_result")
    if query_result:
        row_count = query_result.get("row_count", 0)
        return f"查询执行完成,共返回 {row_count} 条记录。"
    return "数据库查询已处理。"


def _detect_context_type(conversation_id: str) -> Optional[str]:
    """Return the ``type`` metadata of the most recent typed assistant message.

    Looks at the last 10 messages, newest first. Lookup failures are logged
    and treated as "no context type".
    """
    try:
        messages = redis_conversation_manager.get_messages(conversation_id, limit=10)
        for message in reversed(messages):
            if message.get("role") != "assistant":
                continue
            # metadata may be stored as null; treat that as "no metadata"
            context_type = (message.get("metadata") or {}).get("type")
            if context_type:
                logger.info(f"[AGENT_API] 检测到上下文类型: {context_type}")
                return context_type
    except Exception as e:
        logger.warning(f"获取上下文类型失败: {str(e)}")
    return None


def _resolve_routing_mode(api_routing_mode: Optional[str]) -> str:
    """Choose the routing mode: API parameter > app_config.QUESTION_ROUTING_MODE > "hybrid"."""
    if api_routing_mode:
        logger.info(f"[AGENT_API] 使用API指定的路由模式: {api_routing_mode}")
        return api_routing_mode
    try:
        from app_config import QUESTION_ROUTING_MODE
    except ImportError:
        logger.info("[AGENT_API] 配置文件读取失败,使用默认路由模式: hybrid")
        return "hybrid"
    logger.info(f"[AGENT_API] 使用配置文件路由模式: {QUESTION_ROUTING_MODE}")
    return QUESTION_ROUTING_MODE


@app.route('/api/v0/ask_agent', methods=['POST'])
def ask_agent():
    """Context-aware question answering through the LangGraph agent.

    Request JSON fields: ``question`` (required), ``session_id``, ``user_id``,
    ``conversation_id``, ``continue_conversation`` and ``routing_mode``
    (database_direct / chat_direct / hybrid / llm_only).

    Flow:
    1. Resolve the user and conversation ids.
    2. Load the conversation context.
    3. If a cached answer exists, serve it.
    4. Otherwise run the agent, persist both turns to Redis and cache the answer.
    """
    # silent=True: malformed or empty bodies become None and get a JSON 400
    # instead of an HTML error page; non-object JSON is rejected the same way.
    req = request.get_json(force=True, silent=True)
    if not isinstance(req, dict):
        return jsonify(bad_request_response(
            response_text="请求体必须是有效的JSON对象"
        )), 400

    question = req.get("question", None)
    browser_session_id = req.get("session_id", None)
    user_id_input = req.get("user_id", None)
    conversation_id_input = req.get("conversation_id", None)
    continue_conversation = req.get("continue_conversation", False)
    api_routing_mode = req.get("routing_mode", None)

    VALID_ROUTING_MODES = ["database_direct", "chat_direct", "hybrid", "llm_only"]

    if not question:
        return jsonify(bad_request_response(
            response_text="缺少必需参数:question",
            missing_params=["question"]
        )), 400

    if api_routing_mode and api_routing_mode not in VALID_ROUTING_MODES:
        return jsonify(bad_request_response(
            response_text=f"无效的routing_mode参数值: {api_routing_mode},支持的值: {VALID_ROUTING_MODES}",
            invalid_params=["routing_mode"]
        )), 400

    try:
        login_user_id = session.get('user_id')

        # Resolve user / conversation identity
        user_id = redis_conversation_manager.resolve_user_id(
            user_id_input, browser_session_id, request.remote_addr, login_user_id
        )
        conversation_id, conversation_status = redis_conversation_manager.resolve_conversation_id(
            user_id, conversation_id_input, continue_conversation
        )

        # Context (and its type) is needed both for the cache lookup and the agent
        context = redis_conversation_manager.get_context(conversation_id)
        context_type = _detect_context_type(conversation_id) if context else None

        # Cache lookup
        cached_answer = redis_conversation_manager.get_cached_answer(question, context)
        if cached_answer:
            logger.info("[AGENT_API] 使用缓存答案")

            redis_conversation_manager.save_message(conversation_id, "user", question)
            # Store the same metadata keys as live answers (plus from_cache),
            # so the next turn's context-type detection sees this reply's type.
            redis_conversation_manager.save_message(
                conversation_id, "assistant",
                _assistant_reply_text(cached_answer),
                metadata={
                    "from_cache": True,
                    "type": cached_answer.get("type", "UNKNOWN"),
                    "sql": cached_answer.get("sql"),
                    "execution_path": cached_answer.get("execution_path", [])
                }
            )

            # The response is built from explicit arguments; the cached dict
            # itself is not mutated.
            return jsonify(agent_success_response(
                response_type=cached_answer.get("type", "UNKNOWN"),
                response=cached_answer.get("response", ""),
                sql=cached_answer.get("sql"),
                records=cached_answer.get("query_result"),
                summary=cached_answer.get("summary"),
                session_id=browser_session_id,
                execution_path=cached_answer.get("execution_path", []),
                classification_info=cached_answer.get("classification_info", {}),
                conversation_id=conversation_id,
                user_id=user_id,
                is_guest_user=(user_id == DEFAULT_ANONYMOUS_USER),
                context_used=bool(context),
                from_cache=True,
                conversation_status=conversation_status["status"],
                conversation_message=conversation_status["message"],
                requested_conversation_id=conversation_status.get("requested_id")
            ))

        # Persist the user's turn
        redis_conversation_manager.save_message(conversation_id, "user", question)

        # Prepend conversation context to the question when available
        if context:
            enhanced_question = f"\n[CONTEXT]\n{context}\n\n[CURRENT]\n{question}"
            logger.info(f"[AGENT_API] 使用上下文,长度: {len(context)}字符")
        else:
            enhanced_question = question
            logger.info("[AGENT_API] 新对话,无上下文")

        effective_routing_mode = _resolve_routing_mode(api_routing_mode)

        try:
            agent = get_citu_langraph_agent()
        except Exception as e:
            logger.critical(f"Agent初始化失败: {str(e)}")
            return jsonify(service_unavailable_response(
                response_text="AI服务暂时不可用,请稍后重试",
                can_retry=True
            )), 503

        # Synchronous view: drive the agent's coroutine on a fresh event loop
        agent_result = asyncio.run(agent.process_question(
            question=enhanced_question,
            session_id=browser_session_id,
            context_type=context_type,
            routing_mode=effective_routing_mode
        ))

        if not agent_result.get("success", False):
            error_message = agent_result.get("error", "Agent处理失败")
            error_code = agent_result.get("error_code", 500)

            return jsonify(agent_error_response(
                response_text=error_message,
                error_type="agent_processing_failed",
                code=error_code,
                session_id=browser_session_id,
                conversation_id=conversation_id,
                user_id=user_id
            )), error_code

        response_type = agent_result.get("type", "UNKNOWN")
        response_text = agent_result.get("response", "")
        sql = agent_result.get("sql")
        query_result = agent_result.get("query_result")
        summary = agent_result.get("summary")
        execution_path = agent_result.get("execution_path", [])
        classification_info = agent_result.get("classification_info", {})

        redis_conversation_manager.save_message(
            conversation_id, "assistant", _assistant_reply_text(agent_result),
            metadata={
                "type": response_type,
                "sql": sql,
                "execution_path": execution_path
            }
        )

        # The manager decides whether to store it (only context-free Q&A is cached)
        redis_conversation_manager.cache_answer(question, agent_result, context)

        return jsonify(agent_success_response(
            response_type=response_type,
            response=response_text,
            sql=sql,
            records=query_result,
            summary=summary,
            session_id=browser_session_id,
            execution_path=execution_path,
            classification_info=classification_info,
            conversation_id=conversation_id,
            user_id=user_id,
            is_guest_user=(user_id == DEFAULT_ANONYMOUS_USER),
            context_used=bool(context),
            from_cache=False,
            conversation_status=conversation_status["status"],
            conversation_message=conversation_status["message"],
            requested_conversation_id=conversation_status.get("requested_id"),
            routing_mode_used=effective_routing_mode,
            routing_mode_source="api" if api_routing_mode else "config"
        ))

    except Exception as e:
        logger.error(f"ask_agent执行失败: {str(e)}")
        return jsonify(internal_error_response(
            response_text="查询处理失败,请稍后重试"
        )), 500
- # ==================== QA反馈系统API ====================
qa_feedback_manager = None


def get_qa_feedback_manager():
    """Return the shared QAFeedbackManager, creating it on first call.

    Raises:
        Exception: wraps any construction error (chained with ``from`` so the
            original cause stays visible in tracebacks).
    """
    global qa_feedback_manager
    if qa_feedback_manager is None:
        try:
            qa_feedback_manager = QAFeedbackManager(vanna_instance=vn)
            logger.info("QA反馈管理器实例创建成功")
        except Exception as e:
            logger.critical(f"QA反馈管理器创建失败: {str(e)}")
            raise Exception(f"QA反馈管理器初始化失败: {str(e)}") from e
    return qa_feedback_manager
@app.route('/api/v0/qa_feedback/query', methods=['POST'])
def qa_feedback_query():
    """
    Query feedback records with pagination, filtering and sorting.

    JSON body: page (default 1), page_size (1-100, default 20), is_thumb_up,
    create_time_start, create_time_end, is_in_training_data,
    sort_by (default "create_time"), sort_order (default "desc").
    """
    try:
        # silent=True: malformed or empty bodies yield a JSON 400, not a 500
        req = request.get_json(force=True, silent=True)
        if not isinstance(req, dict):
            return jsonify(bad_request_response(
                response_text="请求体必须是有效的JSON对象"
            )), 400

        # Paging values may arrive as strings or null; non-integers used to
        # raise TypeError in the comparisons below and surface as HTTP 500.
        try:
            page = int(req.get('page', 1))
        except (TypeError, ValueError, OverflowError):
            return jsonify(bad_request_response(
                response_text="页码必须是整数",
                invalid_params=["page"]
            )), 400
        try:
            page_size = int(req.get('page_size', 20))
        except (TypeError, ValueError, OverflowError):
            return jsonify(bad_request_response(
                response_text="每页大小必须是整数",
                invalid_params=["page_size"]
            )), 400

        is_thumb_up = req.get('is_thumb_up')
        create_time_start = req.get('create_time_start')
        create_time_end = req.get('create_time_end')
        is_in_training_data = req.get('is_in_training_data')
        sort_by = req.get('sort_by', 'create_time')
        sort_order = req.get('sort_order', 'desc')

        if page < 1:
            return jsonify(bad_request_response(
                response_text="页码必须大于0",
                invalid_params=["page"]
            )), 400

        if page_size < 1 or page_size > 100:
            return jsonify(bad_request_response(
                response_text="每页大小必须在1-100之间",
                invalid_params=["page_size"]
            )), 400

        manager = get_qa_feedback_manager()
        records, total = manager.query_feedback(
            page=page,
            page_size=page_size,
            is_thumb_up=is_thumb_up,
            create_time_start=create_time_start,
            create_time_end=create_time_end,
            is_in_training_data=is_in_training_data,
            sort_by=sort_by,
            sort_order=sort_order
        )

        # Ceiling division
        total_pages = (total + page_size - 1) // page_size

        return jsonify(success_response(
            response_text=f"查询成功,共找到 {total} 条记录",
            data={
                "records": records,
                "pagination": {
                    "page": page,
                    "page_size": page_size,
                    "total": total,
                    "total_pages": total_pages,
                    "has_next": page < total_pages,
                    "has_prev": page > 1
                }
            }
        ))

    except Exception as e:
        logger.error(f"qa_feedback_query执行失败: {str(e)}")
        return jsonify(internal_error_response(
            response_text="查询反馈记录失败,请稍后重试"
        )), 500
@app.route('/api/v0/qa_feedback/delete/<int:feedback_id>', methods=['DELETE'])
def qa_feedback_delete(feedback_id):
    """Delete one feedback record by id; respond 404 when nothing was deleted."""
    try:
        deleted = get_qa_feedback_manager().delete_feedback(feedback_id)

        if not deleted:
            return jsonify(not_found_response(
                response_text=f"反馈记录不存在 (ID: {feedback_id})"
            )), 404

        return jsonify(success_response(
            response_text="反馈记录删除成功",
            data={"deleted_id": feedback_id}
        ))

    except Exception as err:
        logger.error(f"qa_feedback_delete执行失败: {str(err)}")
        return jsonify(internal_error_response(
            response_text="删除反馈记录失败,请稍后重试"
        )), 500
@app.route('/api/v0/qa_feedback/update/<int:feedback_id>', methods=['PUT'])
def qa_feedback_update(feedback_id):
    """Update whitelisted fields of a feedback record.

    Only question, sql, is_thumb_up, user_id and is_in_training_data present in
    the JSON body are forwarded. 400 when none are given; 404 when the record
    is missing or unchanged.
    """
    try:
        payload = request.get_json(force=True)

        editable_fields = ('question', 'sql', 'is_thumb_up', 'user_id', 'is_in_training_data')
        changes = {name: payload[name] for name in editable_fields if name in payload}

        if not changes:
            return jsonify(bad_request_response(
                response_text="没有提供有效的更新字段"
            )), 400

        updated = get_qa_feedback_manager().update_feedback(feedback_id, **changes)

        if not updated:
            return jsonify(not_found_response(
                response_text=f"反馈记录不存在或无变化 (ID: {feedback_id})"
            )), 404

        return jsonify(success_response(
            response_text="反馈记录更新成功",
            data={
                "updated_id": feedback_id,
                "updated_fields": list(changes)
            }
        ))

    except Exception as err:
        logger.error(f"qa_feedback_update执行失败: {str(err)}")
        return jsonify(internal_error_response(
            response_text="更新反馈记录失败,请稍后重试"
        )), 500
@app.route('/api/v0/qa_feedback/add_to_training', methods=['POST'])
def qa_feedback_add_to_training():
    """Push feedback records into the Vanna training set.

    Thumbs-up records are trained as correct question/SQL pairs (``vn.train``),
    thumbs-down ones as error SQL (``vn.train_error_sql``). Records already
    flagged as training data are skipped. Per-record failures are logged and
    do not abort the batch. Successfully trained ids are then marked in storage.
    """
    try:
        payload = request.get_json(force=True)
        requested_ids = payload.get('feedback_ids', [])

        if not requested_ids or not isinstance(requested_ids, list):
            return jsonify(bad_request_response(
                response_text="缺少有效的反馈ID列表"
            )), 400

        manager = get_qa_feedback_manager()
        feedback_rows = manager.get_feedback_by_ids(requested_ids)

        if not feedback_rows:
            return jsonify(not_found_response(
                response_text="未找到任何有效的反馈记录"
            )), 404

        positive_total = 0
        negative_total = 0
        trained_ids = []

        for row in feedback_rows:
            try:
                if row['is_in_training_data']:
                    continue

                if row['is_thumb_up']:
                    vn.train(question=row['question'], sql=row['sql'])
                    positive_total += 1
                else:
                    vn.train_error_sql(question=row['question'], sql=row['sql'])
                    negative_total += 1

                trained_ids.append(row['id'])
            except Exception as err:
                logger.error(f"训练失败 - 反馈ID: {row['id']}, 错误: {err}")

        if trained_ids:
            manager.mark_training_status(trained_ids, True)

        processed = positive_total + negative_total
        return jsonify(success_response(
            response_text=f"训练数据添加完成,成功处理 {processed} 条记录",
            data={
                "positive_trained": positive_total,
                "negative_trained": negative_total,
                "successfully_trained_ids": trained_ids
            }
        ))

    except Exception as err:
        logger.error(f"qa_feedback_add_to_training执行失败: {str(err)}")
        return jsonify(internal_error_response(
            response_text="添加训练数据失败,请稍后重试"
        )), 500
def _parse_thumb_flag(value: Any) -> Optional[bool]:
    """Interpret a JSON thumbs-up flag; return None when it is not recognizable.

    Accepts booleans, numbers (non-zero means True) and the strings
    "true"/"false"/"1"/"0" (case-insensitive). Strings must not go through
    ``bool()``, because ``bool("false")`` is True.
    """
    if isinstance(value, bool):
        return value
    if isinstance(value, (int, float)):
        return value != 0
    if isinstance(value, str):
        normalized = value.strip().lower()
        if normalized in ("true", "1"):
            return True
        if normalized in ("false", "0"):
            return False
    return None


@app.route('/api/v0/qa_feedback/add', methods=['POST'])
def qa_feedback_add():
    """Create a feedback record.

    JSON body: question, sql, is_thumb_up (all required), user_id (optional,
    defaults to 'guest'). Responds with the new feedback_id.
    """
    try:
        req = request.get_json(force=True)
        question = req.get('question')
        sql = req.get('sql')
        raw_thumb = req.get('is_thumb_up')
        # An explicit null/empty user_id falls back to 'guest' as well
        user_id = req.get('user_id') or 'guest'

        if not question or not sql or raw_thumb is None:
            return jsonify(bad_request_response(
                response_text="缺少必需参数"
            )), 400

        is_thumb_up = _parse_thumb_flag(raw_thumb)
        if is_thumb_up is None:
            return jsonify(bad_request_response(
                response_text="is_thumb_up参数无效,应为布尔值",
                invalid_params=["is_thumb_up"]
            )), 400

        manager = get_qa_feedback_manager()
        feedback_id = manager.add_feedback(
            question=question,
            sql=sql,
            is_thumb_up=is_thumb_up,
            user_id=user_id
        )

        return jsonify(success_response(
            response_text="反馈记录创建成功",
            data={"feedback_id": feedback_id}
        ))

    except Exception as e:
        logger.error(f"qa_feedback_add执行失败: {str(e)}")
        return jsonify(internal_error_response(
            response_text="创建反馈记录失败,请稍后重试"
        )), 500
@app.route('/api/v0/qa_feedback/stats', methods=['GET'])
def qa_feedback_stats():
    """Return total / positive / negative feedback counts and the positive rate (%)."""
    try:
        manager = get_qa_feedback_manager()

        # Only the totals are needed; page_size=1 keeps the fetched rows minimal.
        _, total_count = manager.query_feedback(page=1, page_size=1)
        _, positive_count = manager.query_feedback(page=1, page_size=1, is_thumb_up=True)
        _, negative_count = manager.query_feedback(page=1, page_size=1, is_thumb_up=False)

        # max(..., 1) avoids division by zero when there is no feedback yet
        positive_rate = round(positive_count / max(total_count, 1) * 100, 2)

        return jsonify(success_response(
            response_text="统计信息获取成功",
            data={
                "total_feedback": total_count,
                "positive_feedback": positive_count,
                "negative_feedback": negative_count,
                "positive_rate": positive_rate
            }
        ))

    except Exception as err:
        logger.error(f"qa_feedback_stats执行失败: {str(err)}")
        return jsonify(internal_error_response(
            response_text="获取统计信息失败,请稍后重试"
        )), 500
- # ==================== Redis对话管理API ====================
@app.route('/api/v0/user/<user_id>/conversations', methods=['GET'])
def get_user_conversations_redis(user_id: str):
    """List a user's conversations (query param ``limit``, default USER_MAX_CONVERSATIONS)."""
    try:
        limit = request.args.get('limit', USER_MAX_CONVERSATIONS, type=int)
        conversations = redis_conversation_manager.get_conversations(user_id, limit)

        return jsonify(success_response(
            response_text="获取用户对话列表成功",
            data={
                "user_id": user_id,
                "conversations": conversations,
                "total_count": len(conversations)
            }
        ))

    except Exception as e:
        # Previously the error was swallowed without any log entry
        logger.error(f"get_user_conversations_redis执行失败: {str(e)}")
        return jsonify(internal_error_response(
            response_text="获取对话列表失败,请稍后重试"
        )), 500
@app.route('/api/v0/conversation/<conversation_id>/messages', methods=['GET'])
def get_conversation_messages_redis(conversation_id: str):
    """Return a conversation's metadata and message history (optional query param ``limit``)."""
    try:
        limit = request.args.get('limit', type=int)
        messages = redis_conversation_manager.get_conversation_messages(conversation_id, limit)
        meta = redis_conversation_manager.get_conversation_meta(conversation_id)

        return jsonify(success_response(
            response_text="获取对话消息成功",
            data={
                "conversation_id": conversation_id,
                "conversation_meta": meta,
                "messages": messages,
                "message_count": len(messages)
            }
        ))

    except Exception as e:
        # Previously the error was swallowed without any log entry
        logger.error(f"get_conversation_messages_redis执行失败: {str(e)}")
        return jsonify(internal_error_response(
            response_text="获取对话消息失败"
        )), 500
@app.route('/api/v0/conversation_stats', methods=['GET'])
def conversation_stats():
    """Return the conversation manager's statistics."""
    try:
        stats = redis_conversation_manager.get_stats()

        return jsonify(success_response(
            response_text="获取统计信息成功",
            data=stats
        ))

    except Exception as e:
        # Previously the error was swallowed without any log entry
        logger.error(f"conversation_stats执行失败: {str(e)}")
        return jsonify(internal_error_response(
            response_text="获取统计信息失败,请稍后重试"
        )), 500
@app.route('/api/v0/conversation_cleanup', methods=['POST'])
def conversation_cleanup():
    """Manually trigger cleanup of expired conversations."""
    try:
        redis_conversation_manager.cleanup_expired_conversations()

        return jsonify(success_response(
            response_text="对话清理完成"
        ))

    except Exception as e:
        # Previously the error was swallowed without any log entry
        logger.error(f"conversation_cleanup执行失败: {str(e)}")
        return jsonify(internal_error_response(
            response_text="对话清理失败,请稍后重试"
        )), 500
@app.route('/api/v0/embedding_cache_stats', methods=['GET'])
def embedding_cache_stats():
    """Return statistics reported by the embedding cache manager."""
    try:
        # Imported lazily, so a missing cache module only affects this endpoint
        from common.embedding_cache_manager import get_embedding_cache_manager

        cache_stats = get_embedding_cache_manager().get_cache_stats()
        return jsonify(success_response(
            response_text="获取embedding缓存统计成功",
            data=cache_stats
        ))

    except Exception as err:
        logger.error(f"获取embedding缓存统计失败: {str(err)}")
        return jsonify(internal_error_response(
            response_text="获取embedding缓存统计失败,请稍后重试"
        )), 500
@app.route('/api/v0/embedding_cache_cleanup', methods=['POST'])
def embedding_cache_cleanup():
    """Clear the entire embedding cache.

    Responds 400 when the cache is disabled or unavailable, and 500 when
    clearing fails.
    """
    try:
        from common.embedding_cache_manager import get_embedding_cache_manager

        cache_manager = get_embedding_cache_manager()

        if not cache_manager.is_available():
            # Keep HTTP 400 but use a matching response body; previously an
            # internal_error_response body was sent with a 400 status.
            return jsonify(bad_request_response(
                response_text="Embedding缓存功能未启用或不可用"
            )), 400

        success = cache_manager.clear_all_cache()

        if success:
            return jsonify(success_response(
                response_text="所有embedding缓存已清空",
                data={"cleared": True}
            ))
        return jsonify(internal_error_response(
            response_text="清空embedding缓存失败"
        )), 500

    except Exception as e:
        logger.error(f"清空embedding缓存失败: {str(e)}")
        return jsonify(internal_error_response(
            response_text="清空embedding缓存失败,请稍后重试"
        )), 500
- # ==================== 训练数据管理API ====================
def validate_sql_syntax(sql: str) -> tuple[bool, str]:
    """Lightweight SQL sanity check.

    Verifies that ``sql`` is a non-empty string that sqlparse can tokenize and
    that it starts with a supported statement keyword. This is not a full
    grammar check.

    Returns:
        ``(True, "")`` when the statement looks valid, otherwise ``(False, message)``.
    """
    # Common CTE queries start with WITH and were previously rejected.
    allowed_leading_keywords = (
        'SELECT', 'WITH', 'INSERT', 'UPDATE', 'DELETE', 'CREATE', 'ALTER', 'DROP'
    )

    # Guard before calling str methods: None/non-strings used to surface as an
    # opaque "'NoneType' object has no attribute 'strip'" message.
    if sql is None:
        return False, "SQL语法错误:空语句"
    if not isinstance(sql, str):
        return False, "SQL语法错误:SQL必须是字符串"

    statement = sql.strip()
    if not statement:
        return False, "SQL语法错误:空语句"

    try:
        parsed = sqlparse.parse(statement)

        if not parsed or not parsed[0].tokens:
            return False, "SQL语法错误:空语句"

        if not statement.upper().startswith(allowed_leading_keywords):
            return False, "SQL语法错误:不是有效的SQL语句"

        return True, ""
    except Exception as e:
        return False, f"SQL语法错误:{str(e)}"
@app.route('/api/v0/training_data/stats', methods=['GET'])
def training_data_stats():
    """Return the number of training records and the response timestamp.

    A missing (None) or empty training-data frame counts as 0 records.
    """
    try:
        frame = vn.get_training_data()
        has_rows = frame is not None and not frame.empty
        record_total = len(frame) if has_rows else 0

        return jsonify(success_response(
            response_text="统计信息获取成功",
            data={
                "total_count": record_total,
                "last_updated": datetime.now().isoformat()
            }
        ))

    except Exception as err:
        logger.error(f"training_data_stats执行失败: {str(err)}")
        return jsonify(internal_error_response(
            response_text="获取统计信息失败,请稍后重试"
        )), 500
@app.route('/api/v0/training_data/query', methods=['POST'])
def training_data_query():
    """Return one page of training data records.

    Request JSON (an object):
        page      -- 1-based page number, integer >= 1 (default 1)
        page_size -- records per page, integer in 1..100 (default 20)

    Responses:
        200 -- ``records`` for the requested page plus ``pagination`` metadata.
               No data at all still yields 200 with ``total`` 0; a page past
               the end yields 200 with an empty ``records`` list.
        400 -- body is not a JSON object, or page/page_size is not an integer
               in range.
        500 -- any unexpected failure.
    """
    try:
        # silent=True turns an unparsable body into None so it is reported as
        # a 400 below instead of escaping as an exception (and a 500).
        req = request.get_json(force=True, silent=True)
        if not isinstance(req, dict):
            return jsonify(bad_request_response(
                response_text="参数错误"
            )), 400

        page = req.get('page', 1)
        page_size = req.get('page_size', 20)

        # Only genuine integers are accepted (bool is excluded because it is an
        # int subclass). Strings or floats would otherwise fail in the range
        # comparison or the slicing below and surface as a 500. The `and`
        # chain short-circuits, so comparisons only ever see ints.
        params_valid = (
            all(isinstance(v, int) and not isinstance(v, bool)
                for v in (page, page_size))
            and page >= 1
            and 1 <= page_size <= 100
        )
        if not params_valid:
            return jsonify(bad_request_response(
                response_text="参数错误"
            )), 400

        training_data = vn.get_training_data()

        if training_data is None or training_data.empty:
            return jsonify(success_response(
                response_text="查询成功,暂无训练数据",
                data={
                    "records": [],
                    "pagination": {
                        "page": page,
                        "page_size": page_size,
                        "total": 0,
                        "total_pages": 0,
                        "has_next": False,
                        "has_prev": False
                    }
                }
            ))

        total = len(training_data)
        start_idx = (page - 1) * page_size
        end_idx = start_idx + page_size
        total_pages = (total + page_size - 1) // page_size  # ceiling division

        # Slice positionally before converting so only the requested page is
        # turned into dicts. Assumes a pandas DataFrame, as the use of
        # .empty / .to_dict(orient=...) implies -- confirm against vn.
        page_data = training_data.iloc[start_idx:end_idx].to_dict(orient="records")

        return jsonify(success_response(
            response_text=f"查询成功,共找到 {total} 条记录",
            data={
                "records": page_data,
                "pagination": {
                    "page": page,
                    "page_size": page_size,
                    "total": total,
                    "total_pages": total_pages,
                    "has_next": end_idx < total,
                    "has_prev": page > 1
                }
            }
        ))

    except Exception as e:
        logger.error(f"training_data_query执行失败: {str(e)}")
        return jsonify(internal_error_response(
            response_text="查询训练数据失败,请稍后重试"
        )), 500
@app.route('/api/v0/training_data/create', methods=['POST'])
def training_data_create():
    """Create training data entries through ``vn.train``.

    Request JSON: ``{"data": item}`` or ``{"data": [item, ...]}`` with at most
    50 items. Each item is an object whose ``training_data_type`` selects the
    fields that are read:

        "sql"           -- ``sql`` (required, checked with validate_sql_syntax)
                           and optional ``question``
        "documentation" -- ``content`` (required)
        "ddl"           -- ``ddl`` (required)

    Items are processed one at a time. A failing item is recorded in
    ``results`` and does not stop the remaining items.

    Responses:
        200 -- every item was created.
        207 -- at least one item failed (including the case where all failed);
               see the per-item ``results``.
        400 -- body is not a JSON object, ``data`` is missing or ill-typed, or
               more than 50 items were sent.
        500 -- unexpected failure outside the per-item processing.
    """
    try:
        # An unparsable or non-object body yields data=None -> the 400 below.
        req = request.get_json(force=True, silent=True)
        data = req.get('data') if isinstance(req, dict) else None

        if not data:
            return jsonify(bad_request_response(
                response_text="缺少必需参数:data"
            )), 400

        if isinstance(data, dict):
            data_list = [data]
        elif isinstance(data, list):
            data_list = data
        else:
            return jsonify(bad_request_response(
                response_text="data字段格式错误,应为对象或数组"
            )), 400

        if len(data_list) > 50:
            return jsonify(bad_request_response(
                response_text="批量操作最大支持50条记录"
            )), 400

        results = []
        successful_count = 0

        for index, item in enumerate(data_list):
            try:
                # A non-object item cannot be read with .get(); report it as a
                # per-item failure rather than letting it break the batch.
                if not isinstance(item, dict):
                    raise ValueError("训练数据项格式错误,应为对象")

                training_type = item.get('training_data_type')

                if training_type == 'sql':
                    sql = item.get('sql')
                    if not sql:
                        raise ValueError("SQL字段是必需的")

                    is_valid, error_msg = validate_sql_syntax(sql)
                    if not is_valid:
                        raise ValueError(error_msg)

                    question = item.get('question')
                    if question:
                        training_id = vn.train(question=question, sql=sql)
                    else:
                        training_id = vn.train(sql=sql)

                elif training_type == 'documentation':
                    content = item.get('content')
                    if not content:
                        raise ValueError("content字段是必需的")
                    training_id = vn.train(documentation=content)

                elif training_type == 'ddl':
                    ddl = item.get('ddl')
                    if not ddl:
                        raise ValueError("ddl字段是必需的")
                    training_id = vn.train(ddl=ddl)

                else:
                    raise ValueError(f"不支持的训练数据类型: {training_type}")

                results.append({
                    "index": index,
                    "success": True,
                    "training_id": training_id,
                    "type": training_type,
                    "message": f"{training_type}训练数据创建成功"
                })
                successful_count += 1

            except Exception as e:
                # Never call item.get() on a non-dict here: an exception raised
                # from this handler would abort the whole request with a 500
                # even though earlier items were already trained.
                failed_type = (
                    item.get('training_data_type', 'unknown')
                    if isinstance(item, dict) else 'unknown'
                )
                results.append({
                    "index": index,
                    "success": False,
                    "type": failed_type,
                    "error": str(e),
                    "message": "创建失败"
                })

        failed_count = len(data_list) - successful_count
        summary = {
            "total_requested": len(data_list),
            "successfully_created": successful_count,
            "failed_count": failed_count,
            "results": results
        }

        if failed_count == 0:
            return jsonify(success_response(
                response_text="训练数据创建完成",
                data=summary
            ))

        # NOTE(review): 207 is also returned when every item failed; confirm
        # whether clients expect a 4xx in that case.
        return jsonify(error_response(
            response_text=f"训练数据创建部分成功,成功{successful_count}条,失败{failed_count}条",
            data=summary
        )), 207

    except Exception as e:
        logger.error(f"training_data_create执行失败: {str(e)}")
        return jsonify(internal_error_response(
            response_text="创建训练数据失败,请稍后重试"
        )), 500
@app.route('/api/v0/training_data/delete', methods=['POST'])
def training_data_delete():
    """Delete training data entries by ID via ``vn.remove_training_data``.

    Request JSON: ``{"ids": [id, ...], "confirm": true}`` with 1..50 IDs.

    Each ID is attempted independently. An ID counts as failed when
    ``remove_training_data`` returns a falsy value or raises.

    Responses:
        200 -- every ID was removed.
        207 -- at least one ID failed; see ``deleted_ids`` / ``failed_ids``.
        400 -- body is not a JSON object, ``ids`` is missing/empty/not a list,
               ``confirm`` is falsy, or more than 50 IDs were sent.
        500 -- unexpected failure outside the per-ID loop.
    """
    try:
        req = request.get_json(force=True, silent=True)
        if not isinstance(req, dict):
            # Unparsable or non-object body: treat as empty so it is reported
            # by the "missing ids" 400 below instead of raising (500).
            req = {}
        ids = req.get('ids', [])
        confirm = req.get('confirm', False)

        if not ids or not isinstance(ids, list):
            return jsonify(bad_request_response(
                response_text="缺少有效的ID列表"
            )), 400

        # NOTE(review): any truthy value counts as confirmation, including the
        # string "false"; consider requiring `confirm is True`.
        if not confirm:
            return jsonify(bad_request_response(
                response_text="删除操作需要确认,请设置confirm为true"
            )), 400

        if len(ids) > 50:
            return jsonify(bad_request_response(
                response_text="批量删除最大支持50条记录"
            )), 400

        deleted_ids = []
        failed_ids = []

        for training_id in ids:
            try:
                removed = vn.remove_training_data(training_id)
            except Exception as e:
                # Best-effort batch: keep going with the remaining IDs, but log
                # the cause instead of discarding it silently.
                logger.warning(f"删除训练数据失败 (id={training_id}): {str(e)}")
                failed_ids.append(training_id)
                continue

            if removed:
                deleted_ids.append(training_id)
            else:
                failed_ids.append(training_id)

        failed_count = len(failed_ids)
        summary = {
            "total_requested": len(ids),
            "successfully_deleted": len(deleted_ids),
            "failed_count": failed_count,
            "deleted_ids": deleted_ids,
            "failed_ids": failed_ids
        }

        if failed_count == 0:
            return jsonify(success_response(
                response_text="训练数据删除完成",
                data=summary
            ))

        return jsonify(error_response(
            response_text=f"训练数据删除部分成功,成功{len(deleted_ids)}条,失败{failed_count}条",
            data=summary
        )), 207

    except Exception as e:
        logger.error(f"training_data_delete执行失败: {str(e)}")
        return jsonify(internal_error_response(
            response_text="删除训练数据失败,请稍后重试"
        )), 500
- # ==================== 启动逻辑 ====================
def signal_handler(signum, frame):
    """Log the received signal, release resources, then exit with status 0.

    Installed for SIGINT and SIGTERM by the ``__main__`` block. ``frame`` is
    required by the ``signal.signal`` handler signature and is unused.
    """
    shutdown_notice = f"接收到信号 {signum},准备退出..."
    logger.info(shutdown_notice)
    cleanup_resources()
    # Equivalent to sys.exit(0): both raise SystemExit with code 0.
    raise SystemExit(0)
if __name__ == '__main__':
    # Route Ctrl+C (SIGINT) and termination requests (SIGTERM) through
    # signal_handler so resources are cleaned up before exiting.
    for _shutdown_signal in (signal.SIGINT, signal.SIGTERM):
        signal.signal(_shutdown_signal, signal_handler)

    # Startup banner, logged in this fixed order.
    for _banner_line in (
        "🚀 启动统一API服务...",
        "📍 服务地址: http://localhost:8084",
        "🔗 健康检查: http://localhost:8084/health",
        "📘 React Agent API: http://localhost:8084/api/v0/ask_react_agent",
        "📘 LangGraph Agent API: http://localhost:8084/api/v0/ask_agent",
    ):
        logger.info(_banner_line)

    # Run Flask's built-in server on all interfaces, port 8084, with debug off
    # and threaded request handling enabled.
    app.run(host="0.0.0.0", port=8084, debug=False, threaded=True)
|