unified_api.py 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174
  1. """
  2. 统一 API 服务
  3. 集成 citu_app.py 指定API 和 react_agent/api.py 的所有功能
  4. 提供数据库问答、Redis对话管理、QA反馈、训练数据管理、React Agent等功能
  5. 使用普通 Flask 应用 + ASGI 包装实现异步支持
  6. """
  7. import asyncio
  8. import logging
  9. import atexit
  10. import os
  11. import sys
  12. from datetime import datetime, timedelta
  13. from typing import Optional, Dict, Any, TYPE_CHECKING, Union
  14. import signal
  15. if TYPE_CHECKING:
  16. from react_agent.agent import CustomReactAgent
  17. # 初始化日志系统 - 必须在最前面
  18. from core.logging import initialize_logging, get_app_logger
  19. initialize_logging()
  20. # 标准 Flask 导入
  21. from flask import Flask, request, jsonify, session
  22. import redis.asyncio as redis
  23. # 基础依赖
  24. import pandas as pd
  25. import json
  26. import sqlparse
  27. # 项目模块导入
  28. from core.vanna_llm_factory import create_vanna_instance
  29. from common.redis_conversation_manager import RedisConversationManager
  30. from common.qa_feedback_manager import QAFeedbackManager
  31. from common.result import (
  32. success_response, bad_request_response, not_found_response, internal_error_response,
  33. error_response, service_unavailable_response,
  34. agent_success_response, agent_error_response,
  35. validation_failed_response
  36. )
  37. from app_config import (
  38. USER_MAX_CONVERSATIONS, CONVERSATION_CONTEXT_COUNT,
  39. DEFAULT_ANONYMOUS_USER, ENABLE_QUESTION_ANSWER_CACHE
  40. )
# Create the standard Flask application object that all routes below register on.
app = Flask(__name__)
# Application-wide logger, obtained from the project's logging setup.
logger = get_app_logger("UnifiedApp")
# React Agent import: try the packaged module first, then the copy under test/.
# If neither imports, CustomReactAgent is set to None and the React Agent
# endpoints report it as unavailable at request time.
try:
    from react_agent.agent import CustomReactAgent
except ImportError:
    try:
        from test.custom_react_agent.agent import CustomReactAgent
    except ImportError:
        logger.warning("无法导入 CustomReactAgent,React Agent功能将不可用")
        CustomReactAgent = None
# Core components, created once at import time and shared by all requests.
vn = create_vanna_instance()
redis_conversation_manager = RedisConversationManager()
# ==================== React Agent global instance management ====================
# Lazily created by get_react_agent(); None until initialisation succeeds.
_react_agent_instance: Optional[Any] = None
# Async Redis client created alongside the agent; closed by cleanup_resources().
_redis_client: Optional[redis.Redis] = None
  60. def validate_request_data(data: Dict[str, Any]) -> Dict[str, Any]:
  61. """验证请求数据,并支持从thread_id中推断user_id"""
  62. errors = []
  63. # 验证 question(必填)
  64. question = data.get('question', '')
  65. if not question or not question.strip():
  66. errors.append('问题不能为空')
  67. elif len(question) > 2000:
  68. errors.append('问题长度不能超过2000字符')
  69. # 优先获取 thread_id
  70. thread_id = data.get('thread_id') or data.get('conversation_id')
  71. # 获取 user_id,但暂不设置默认值
  72. user_id = data.get('user_id')
  73. # 如果没有传递 user_id,则尝试从 thread_id 中推断
  74. if not user_id:
  75. if thread_id and ':' in thread_id:
  76. inferred_user_id = thread_id.split(':', 1)[0]
  77. if inferred_user_id:
  78. user_id = inferred_user_id
  79. logger.info(f"👤 未提供user_id,从 thread_id '{thread_id}' 中推断出: '{user_id}'")
  80. else:
  81. user_id = 'guest'
  82. else:
  83. user_id = 'guest'
  84. # 验证 user_id 长度
  85. if user_id and len(user_id) > 50:
  86. errors.append('用户ID长度不能超过50字符')
  87. # 用户ID与会话ID一致性校验
  88. if thread_id:
  89. if ':' not in thread_id:
  90. errors.append('会话ID格式无效,期望格式为 user_id:timestamp')
  91. else:
  92. thread_user_id = thread_id.split(':', 1)[0]
  93. if thread_user_id != user_id:
  94. errors.append(f'会话归属验证失败:会话ID [{thread_id}] 不属于当前用户 [{user_id}]')
  95. if errors:
  96. raise ValueError('; '.join(errors))
  97. return {
  98. 'question': question.strip(),
  99. 'user_id': user_id,
  100. 'thread_id': thread_id # 可选,不传则自动生成新会话
  101. }
  102. async def get_react_agent() -> Any:
  103. """获取 React Agent 实例(懒加载)"""
  104. global _react_agent_instance, _redis_client
  105. if _react_agent_instance is None:
  106. if CustomReactAgent is None:
  107. logger.error("❌ CustomReactAgent 未能导入,无法初始化")
  108. raise ImportError("CustomReactAgent 未能导入")
  109. logger.info("🚀 正在异步初始化 Custom React Agent...")
  110. try:
  111. # 设置环境变量
  112. os.environ['REDIS_URL'] = 'redis://localhost:6379'
  113. # 初始化共享的Redis客户端
  114. _redis_client = redis.from_url('redis://localhost:6379', decode_responses=True)
  115. await _redis_client.ping()
  116. logger.info("✅ Redis客户端连接成功")
  117. _react_agent_instance = await CustomReactAgent.create()
  118. logger.info("✅ React Agent 异步初始化完成")
  119. except Exception as e:
  120. logger.error(f"❌ React Agent 异步初始化失败: {e}")
  121. raise
  122. return _react_agent_instance
  123. async def ensure_agent_ready() -> bool:
  124. """异步确保Agent实例可用"""
  125. global _react_agent_instance
  126. if _react_agent_instance is None:
  127. await get_react_agent()
  128. # 测试Agent是否还可用
  129. try:
  130. test_result = await _react_agent_instance.get_user_recent_conversations("__test__", 1)
  131. return True
  132. except Exception as e:
  133. logger.warning(f"⚠️ Agent实例不可用: {e}")
  134. _react_agent_instance = None
  135. await get_react_agent()
  136. return True
  137. def cleanup_resources():
  138. """清理资源"""
  139. global _react_agent_instance, _redis_client
  140. async def async_cleanup():
  141. if _react_agent_instance:
  142. await _react_agent_instance.close()
  143. logger.info("✅ React Agent 资源已清理")
  144. if _redis_client:
  145. await _redis_client.aclose()
  146. logger.info("✅ Redis客户端已关闭")
  147. try:
  148. asyncio.run(async_cleanup())
  149. except Exception as e:
  150. logger.error(f"清理资源失败: {e}")
  151. atexit.register(cleanup_resources)
  152. # ==================== 基础路由 ====================
  153. @app.route("/")
  154. def index():
  155. """根路径健康检查"""
  156. return jsonify({"message": "统一API服务正在运行", "version": "1.0.0"})
  157. @app.route('/health', methods=['GET'])
  158. def health_check():
  159. """健康检查端点"""
  160. try:
  161. health_status = {
  162. "status": "healthy",
  163. "react_agent_initialized": _react_agent_instance is not None,
  164. "timestamp": datetime.now().isoformat(),
  165. "services": {
  166. "redis": redis_conversation_manager.is_available(),
  167. "vanna": vn is not None
  168. }
  169. }
  170. return jsonify(health_status), 200
  171. except Exception as e:
  172. logger.error(f"健康检查失败: {e}")
  173. return jsonify({"status": "unhealthy", "error": str(e)}), 500
  174. # ==================== React Agent API ====================
  175. @app.route("/api/v0/ask_react_agent", methods=["POST"])
  176. async def ask_react_agent():
  177. """异步React Agent智能问答接口"""
  178. global _react_agent_instance
  179. # 确保Agent已初始化
  180. if not await ensure_agent_ready():
  181. return jsonify({
  182. "code": 503,
  183. "message": "服务未就绪",
  184. "success": False,
  185. "error": "React Agent 初始化失败"
  186. }), 503
  187. try:
  188. # 获取请求数据
  189. try:
  190. data = request.get_json(force=True)
  191. except Exception as json_error:
  192. logger.warning(f"⚠️ JSON解析失败: {json_error}")
  193. return jsonify({
  194. "code": 400,
  195. "message": "请求格式错误",
  196. "success": False,
  197. "error": "无效的JSON格式",
  198. "details": str(json_error)
  199. }), 400
  200. if not data:
  201. return jsonify({
  202. "code": 400,
  203. "message": "请求参数错误",
  204. "success": False,
  205. "error": "请求体不能为空"
  206. }), 400
  207. # 验证请求数据
  208. validated_data = validate_request_data(data)
  209. logger.info(f"📨 收到React Agent请求 - User: {validated_data['user_id']}, Question: {validated_data['question'][:50]}...")
  210. # 异步调用处理
  211. agent_result = await _react_agent_instance.chat(
  212. message=validated_data['question'],
  213. user_id=validated_data['user_id'],
  214. thread_id=validated_data['thread_id']
  215. )
  216. if not agent_result.get("success", False):
  217. # Agent处理失败
  218. error_msg = agent_result.get("error", "React Agent处理失败")
  219. logger.error(f"❌ React Agent处理失败: {error_msg}")
  220. return jsonify({
  221. "code": 500,
  222. "message": "处理失败",
  223. "success": False,
  224. "error": error_msg,
  225. "data": {
  226. "conversation_id": agent_result.get("thread_id"),
  227. "user_id": validated_data['user_id'],
  228. "timestamp": datetime.now().isoformat()
  229. }
  230. }), 500
  231. # Agent处理成功
  232. api_data = agent_result.get("api_data", {})
  233. # 构建响应数据(按照 react_agent/api.py 的正确格式)
  234. response_data = {
  235. "response": api_data.get("response", ""),
  236. "conversation_id": agent_result.get("thread_id"),
  237. "user_id": validated_data['user_id'],
  238. "react_agent_meta": api_data.get("react_agent_meta", {
  239. "thread_id": agent_result.get("thread_id"),
  240. "agent_version": "custom_react_v1_async"
  241. }),
  242. "timestamp": datetime.now().isoformat()
  243. }
  244. # 可选字段:SQL(仅当执行SQL时存在)
  245. if "sql" in api_data:
  246. response_data["sql"] = api_data["sql"]
  247. # 可选字段:records(仅当有查询结果时存在)
  248. if "records" in api_data:
  249. response_data["records"] = api_data["records"]
  250. return jsonify({
  251. "code": 200,
  252. "message": "处理成功",
  253. "success": True,
  254. "data": response_data
  255. }), 200
  256. except ValueError as ve:
  257. # 参数验证错误
  258. logger.warning(f"⚠️ 参数验证失败: {ve}")
  259. return jsonify({
  260. "code": 400,
  261. "message": "参数验证失败",
  262. "success": False,
  263. "error": str(ve)
  264. }), 400
  265. except Exception as e:
  266. logger.error(f"❌ React Agent API 异常: {e}")
  267. return jsonify({
  268. "code": 500,
  269. "message": "内部服务错误",
  270. "success": False,
  271. "error": "服务暂时不可用,请稍后重试"
  272. }), 500
  273. # ==================== LangGraph Agent API ====================
  274. # 全局Agent实例(单例模式)
  275. citu_langraph_agent = None
  276. def get_citu_langraph_agent():
  277. """获取LangGraph Agent实例(懒加载)"""
  278. global citu_langraph_agent
  279. if citu_langraph_agent is None:
  280. try:
  281. from agent.citu_agent import CituLangGraphAgent
  282. logger.info("开始创建LangGraph Agent实例...")
  283. citu_langraph_agent = CituLangGraphAgent()
  284. logger.info("LangGraph Agent实例创建成功")
  285. except ImportError as e:
  286. logger.critical(f"Agent模块导入失败: {str(e)}")
  287. raise Exception(f"Agent模块导入失败: {str(e)}")
  288. except Exception as e:
  289. logger.critical(f"LangGraph Agent实例创建失败: {str(e)}")
  290. raise Exception(f"Agent初始化失败: {str(e)}")
  291. return citu_langraph_agent
  292. @app.route('/api/v0/ask_agent', methods=['POST'])
  293. def ask_agent():
  294. """支持对话上下文的ask_agent API"""
  295. req = request.get_json(force=True)
  296. question = req.get("question", None)
  297. browser_session_id = req.get("session_id", None)
  298. user_id_input = req.get("user_id", None)
  299. conversation_id_input = req.get("conversation_id", None)
  300. continue_conversation = req.get("continue_conversation", False)
  301. api_routing_mode = req.get("routing_mode", None)
  302. VALID_ROUTING_MODES = ["database_direct", "chat_direct", "hybrid", "llm_only"]
  303. if not question:
  304. return jsonify(bad_request_response(
  305. response_text="缺少必需参数:question",
  306. missing_params=["question"]
  307. )), 400
  308. if api_routing_mode and api_routing_mode not in VALID_ROUTING_MODES:
  309. return jsonify(bad_request_response(
  310. response_text=f"无效的routing_mode参数值: {api_routing_mode},支持的值: {VALID_ROUTING_MODES}",
  311. invalid_params=["routing_mode"]
  312. )), 400
  313. try:
  314. # 获取登录用户ID
  315. login_user_id = session.get('user_id') if 'user_id' in session else None
  316. # 智能ID解析
  317. user_id = redis_conversation_manager.resolve_user_id(
  318. user_id_input, browser_session_id, request.remote_addr, login_user_id
  319. )
  320. conversation_id, conversation_status = redis_conversation_manager.resolve_conversation_id(
  321. user_id, conversation_id_input, continue_conversation
  322. )
  323. # 获取上下文
  324. context = redis_conversation_manager.get_context(conversation_id)
  325. # 保存用户消息
  326. redis_conversation_manager.save_message(conversation_id, "user", question)
  327. # 构建带上下文的问题
  328. if context:
  329. enhanced_question = f"\n[CONTEXT]\n{context}\n\n[CURRENT]\n{question}"
  330. logger.info(f"[AGENT_API] 使用上下文,长度: {len(context)}字符")
  331. else:
  332. enhanced_question = question
  333. logger.info(f"[AGENT_API] 新对话,无上下文")
  334. # Agent处理
  335. try:
  336. agent = get_citu_langraph_agent()
  337. except Exception as e:
  338. logger.critical(f"Agent初始化失败: {str(e)}")
  339. return jsonify(service_unavailable_response(
  340. response_text="AI服务暂时不可用,请稍后重试",
  341. can_retry=True
  342. )), 503
  343. # 异步调用Agent处理问题
  344. import asyncio
  345. agent_result = asyncio.run(agent.process_question(
  346. question=enhanced_question,
  347. session_id=browser_session_id
  348. ))
  349. # 处理Agent结果
  350. if agent_result.get("success", False):
  351. response_type = agent_result.get("type", "UNKNOWN")
  352. response_text = agent_result.get("response", "")
  353. sql = agent_result.get("sql")
  354. query_result = agent_result.get("query_result")
  355. summary = agent_result.get("summary")
  356. execution_path = agent_result.get("execution_path", [])
  357. classification_info = agent_result.get("classification_info", {})
  358. # 确定助手回复内容的优先级
  359. if response_type == "DATABASE":
  360. if response_text:
  361. assistant_response = response_text
  362. elif summary:
  363. assistant_response = summary
  364. elif query_result:
  365. row_count = query_result.get("row_count", 0)
  366. assistant_response = f"查询执行完成,共返回 {row_count} 条记录。"
  367. else:
  368. assistant_response = "数据库查询已处理。"
  369. else:
  370. assistant_response = response_text
  371. # 保存助手回复
  372. redis_conversation_manager.save_message(
  373. conversation_id, "assistant", assistant_response,
  374. metadata={
  375. "type": response_type,
  376. "sql": sql,
  377. "execution_path": execution_path
  378. }
  379. )
  380. return jsonify(agent_success_response(
  381. response_type=response_type,
  382. response=response_text,
  383. sql=sql,
  384. records=query_result,
  385. summary=summary,
  386. session_id=browser_session_id,
  387. execution_path=execution_path,
  388. classification_info=classification_info,
  389. conversation_id=conversation_id,
  390. user_id=user_id,
  391. is_guest_user=(user_id == DEFAULT_ANONYMOUS_USER),
  392. context_used=bool(context),
  393. from_cache=False,
  394. conversation_status=conversation_status["status"],
  395. conversation_message=conversation_status["message"]
  396. ))
  397. else:
  398. # 错误处理
  399. error_message = agent_result.get("error", "Agent处理失败")
  400. error_code = agent_result.get("error_code", 500)
  401. return jsonify(agent_error_response(
  402. response_text=error_message,
  403. error_type="agent_processing_failed",
  404. code=error_code,
  405. session_id=browser_session_id,
  406. conversation_id=conversation_id,
  407. user_id=user_id
  408. )), error_code
  409. except Exception as e:
  410. logger.error(f"ask_agent执行失败: {str(e)}")
  411. return jsonify(internal_error_response(
  412. response_text="查询处理失败,请稍后重试"
  413. )), 500
  414. # ==================== QA反馈系统API ====================
  415. qa_feedback_manager = None
  416. def get_qa_feedback_manager():
  417. """获取QA反馈管理器实例(懒加载)"""
  418. global qa_feedback_manager
  419. if qa_feedback_manager is None:
  420. try:
  421. qa_feedback_manager = QAFeedbackManager(vanna_instance=vn)
  422. logger.info("QA反馈管理器实例创建成功")
  423. except Exception as e:
  424. logger.critical(f"QA反馈管理器创建失败: {str(e)}")
  425. raise Exception(f"QA反馈管理器初始化失败: {str(e)}")
  426. return qa_feedback_manager
  427. @app.route('/api/v0/qa_feedback/query', methods=['POST'])
  428. def qa_feedback_query():
  429. """查询反馈记录API"""
  430. try:
  431. req = request.get_json(force=True)
  432. page = req.get('page', 1)
  433. page_size = req.get('page_size', 20)
  434. is_thumb_up = req.get('is_thumb_up')
  435. if page < 1 or page_size < 1 or page_size > 100:
  436. return jsonify(bad_request_response(
  437. response_text="参数错误"
  438. )), 400
  439. manager = get_qa_feedback_manager()
  440. records, total = manager.query_feedback(
  441. page=page,
  442. page_size=page_size,
  443. is_thumb_up=is_thumb_up
  444. )
  445. total_pages = (total + page_size - 1) // page_size
  446. return jsonify(success_response(
  447. response_text=f"查询成功,共找到 {total} 条记录",
  448. data={
  449. "records": records,
  450. "pagination": {
  451. "page": page,
  452. "page_size": page_size,
  453. "total": total,
  454. "total_pages": total_pages,
  455. "has_next": page < total_pages,
  456. "has_prev": page > 1
  457. }
  458. }
  459. ))
  460. except Exception as e:
  461. logger.error(f"qa_feedback_query执行失败: {str(e)}")
  462. return jsonify(internal_error_response(
  463. response_text="查询反馈记录失败,请稍后重试"
  464. )), 500
  465. @app.route('/api/v0/qa_feedback/delete/<int:feedback_id>', methods=['DELETE'])
  466. def qa_feedback_delete(feedback_id):
  467. """删除反馈记录API"""
  468. try:
  469. manager = get_qa_feedback_manager()
  470. success = manager.delete_feedback(feedback_id)
  471. if success:
  472. return jsonify(success_response(
  473. response_text=f"反馈记录删除成功",
  474. data={"deleted_id": feedback_id}
  475. ))
  476. else:
  477. return jsonify(not_found_response(
  478. response_text=f"反馈记录不存在 (ID: {feedback_id})"
  479. )), 404
  480. except Exception as e:
  481. logger.error(f"qa_feedback_delete执行失败: {str(e)}")
  482. return jsonify(internal_error_response(
  483. response_text="删除反馈记录失败,请稍后重试"
  484. )), 500
  485. @app.route('/api/v0/qa_feedback/update/<int:feedback_id>', methods=['PUT'])
  486. def qa_feedback_update(feedback_id):
  487. """更新反馈记录API"""
  488. try:
  489. req = request.get_json(force=True)
  490. allowed_fields = ['question', 'sql', 'is_thumb_up', 'user_id', 'is_in_training_data']
  491. update_data = {}
  492. for field in allowed_fields:
  493. if field in req:
  494. update_data[field] = req[field]
  495. if not update_data:
  496. return jsonify(bad_request_response(
  497. response_text="没有提供有效的更新字段"
  498. )), 400
  499. manager = get_qa_feedback_manager()
  500. success = manager.update_feedback(feedback_id, **update_data)
  501. if success:
  502. return jsonify(success_response(
  503. response_text="反馈记录更新成功",
  504. data={
  505. "updated_id": feedback_id,
  506. "updated_fields": list(update_data.keys())
  507. }
  508. ))
  509. else:
  510. return jsonify(not_found_response(
  511. response_text=f"反馈记录不存在或无变化 (ID: {feedback_id})"
  512. )), 404
  513. except Exception as e:
  514. logger.error(f"qa_feedback_update执行失败: {str(e)}")
  515. return jsonify(internal_error_response(
  516. response_text="更新反馈记录失败,请稍后重试"
  517. )), 500
  518. @app.route('/api/v0/qa_feedback/add_to_training', methods=['POST'])
  519. def qa_feedback_add_to_training():
  520. """将反馈记录添加到训练数据集API"""
  521. try:
  522. req = request.get_json(force=True)
  523. feedback_ids = req.get('feedback_ids', [])
  524. if not feedback_ids or not isinstance(feedback_ids, list):
  525. return jsonify(bad_request_response(
  526. response_text="缺少有效的反馈ID列表"
  527. )), 400
  528. manager = get_qa_feedback_manager()
  529. records = manager.get_feedback_by_ids(feedback_ids)
  530. if not records:
  531. return jsonify(not_found_response(
  532. response_text="未找到任何有效的反馈记录"
  533. )), 404
  534. positive_count = 0
  535. negative_count = 0
  536. successfully_trained_ids = []
  537. for record in records:
  538. try:
  539. if record['is_in_training_data']:
  540. continue
  541. if record['is_thumb_up']:
  542. training_id = vn.train(
  543. question=record['question'],
  544. sql=record['sql']
  545. )
  546. positive_count += 1
  547. else:
  548. training_id = vn.train_error_sql(
  549. question=record['question'],
  550. sql=record['sql']
  551. )
  552. negative_count += 1
  553. successfully_trained_ids.append(record['id'])
  554. except Exception as e:
  555. logger.error(f"训练失败 - 反馈ID: {record['id']}, 错误: {e}")
  556. if successfully_trained_ids:
  557. manager.mark_training_status(successfully_trained_ids, True)
  558. return jsonify(success_response(
  559. response_text=f"训练数据添加完成,成功处理 {positive_count + negative_count} 条记录",
  560. data={
  561. "positive_trained": positive_count,
  562. "negative_trained": negative_count,
  563. "successfully_trained_ids": successfully_trained_ids
  564. }
  565. ))
  566. except Exception as e:
  567. logger.error(f"qa_feedback_add_to_training执行失败: {str(e)}")
  568. return jsonify(internal_error_response(
  569. response_text="添加训练数据失败,请稍后重试"
  570. )), 500
  571. @app.route('/api/v0/qa_feedback/add', methods=['POST'])
  572. def qa_feedback_add():
  573. """添加反馈记录API"""
  574. try:
  575. req = request.get_json(force=True)
  576. question = req.get('question')
  577. sql = req.get('sql')
  578. is_thumb_up = req.get('is_thumb_up')
  579. user_id = req.get('user_id', 'guest')
  580. if not question or not sql or is_thumb_up is None:
  581. return jsonify(bad_request_response(
  582. response_text="缺少必需参数"
  583. )), 400
  584. manager = get_qa_feedback_manager()
  585. feedback_id = manager.add_feedback(
  586. question=question,
  587. sql=sql,
  588. is_thumb_up=bool(is_thumb_up),
  589. user_id=user_id
  590. )
  591. return jsonify(success_response(
  592. response_text="反馈记录创建成功",
  593. data={"feedback_id": feedback_id}
  594. ))
  595. except Exception as e:
  596. logger.error(f"qa_feedback_add执行失败: {str(e)}")
  597. return jsonify(internal_error_response(
  598. response_text="创建反馈记录失败,请稍后重试"
  599. )), 500
  600. @app.route('/api/v0/qa_feedback/stats', methods=['GET'])
  601. def qa_feedback_stats():
  602. """反馈统计API"""
  603. try:
  604. manager = get_qa_feedback_manager()
  605. all_records, total_count = manager.query_feedback(page=1, page_size=1)
  606. positive_records, positive_count = manager.query_feedback(page=1, page_size=1, is_thumb_up=True)
  607. negative_records, negative_count = manager.query_feedback(page=1, page_size=1, is_thumb_up=False)
  608. return jsonify(success_response(
  609. response_text="统计信息获取成功",
  610. data={
  611. "total_feedback": total_count,
  612. "positive_feedback": positive_count,
  613. "negative_feedback": negative_count,
  614. "positive_rate": round(positive_count / max(total_count, 1) * 100, 2)
  615. }
  616. ))
  617. except Exception as e:
  618. logger.error(f"qa_feedback_stats执行失败: {str(e)}")
  619. return jsonify(internal_error_response(
  620. response_text="获取统计信息失败,请稍后重试"
  621. )), 500
  622. # ==================== Redis对话管理API ====================
  623. @app.route('/api/v0/user/<user_id>/conversations', methods=['GET'])
  624. def get_user_conversations_redis(user_id: str):
  625. """获取用户的对话列表"""
  626. try:
  627. limit = request.args.get('limit', USER_MAX_CONVERSATIONS, type=int)
  628. conversations = redis_conversation_manager.get_conversations(user_id, limit)
  629. return jsonify(success_response(
  630. response_text="获取用户对话列表成功",
  631. data={
  632. "user_id": user_id,
  633. "conversations": conversations,
  634. "total_count": len(conversations)
  635. }
  636. ))
  637. except Exception as e:
  638. return jsonify(internal_error_response(
  639. response_text="获取对话列表失败,请稍后重试"
  640. )), 500
  641. @app.route('/api/v0/conversation/<conversation_id>/messages', methods=['GET'])
  642. def get_conversation_messages_redis(conversation_id: str):
  643. """获取特定对话的消息历史"""
  644. try:
  645. limit = request.args.get('limit', type=int)
  646. messages = redis_conversation_manager.get_conversation_messages(conversation_id, limit)
  647. meta = redis_conversation_manager.get_conversation_meta(conversation_id)
  648. return jsonify(success_response(
  649. response_text="获取对话消息成功",
  650. data={
  651. "conversation_id": conversation_id,
  652. "conversation_meta": meta,
  653. "messages": messages,
  654. "message_count": len(messages)
  655. }
  656. ))
  657. except Exception as e:
  658. return jsonify(internal_error_response(
  659. response_text="获取对话消息失败"
  660. )), 500
  661. @app.route('/api/v0/conversation_stats', methods=['GET'])
  662. def conversation_stats():
  663. """获取对话系统统计信息"""
  664. try:
  665. stats = redis_conversation_manager.get_stats()
  666. return jsonify(success_response(
  667. response_text="获取统计信息成功",
  668. data=stats
  669. ))
  670. except Exception as e:
  671. return jsonify(internal_error_response(
  672. response_text="获取统计信息失败,请稍后重试"
  673. )), 500
  674. @app.route('/api/v0/conversation_cleanup', methods=['POST'])
  675. def conversation_cleanup():
  676. """手动清理过期对话"""
  677. try:
  678. redis_conversation_manager.cleanup_expired_conversations()
  679. return jsonify(success_response(
  680. response_text="对话清理完成"
  681. ))
  682. except Exception as e:
  683. return jsonify(internal_error_response(
  684. response_text="对话清理失败,请稍后重试"
  685. )), 500
  686. @app.route('/api/v0/embedding_cache_stats', methods=['GET'])
  687. def embedding_cache_stats():
  688. """获取embedding缓存统计信息"""
  689. try:
  690. from common.embedding_cache_manager import get_embedding_cache_manager
  691. cache_manager = get_embedding_cache_manager()
  692. stats = cache_manager.get_cache_stats()
  693. return jsonify(success_response(
  694. response_text="获取embedding缓存统计成功",
  695. data=stats
  696. ))
  697. except Exception as e:
  698. logger.error(f"获取embedding缓存统计失败: {str(e)}")
  699. return jsonify(internal_error_response(
  700. response_text="获取embedding缓存统计失败,请稍后重试"
  701. )), 500
  702. @app.route('/api/v0/embedding_cache_cleanup', methods=['POST'])
  703. def embedding_cache_cleanup():
  704. """清空所有embedding缓存"""
  705. try:
  706. from common.embedding_cache_manager import get_embedding_cache_manager
  707. cache_manager = get_embedding_cache_manager()
  708. if not cache_manager.is_available():
  709. return jsonify(internal_error_response(
  710. response_text="Embedding缓存功能未启用或不可用"
  711. )), 400
  712. success = cache_manager.clear_all_cache()
  713. if success:
  714. return jsonify(success_response(
  715. response_text="所有embedding缓存已清空",
  716. data={"cleared": True}
  717. ))
  718. else:
  719. return jsonify(internal_error_response(
  720. response_text="清空embedding缓存失败"
  721. )), 500
  722. except Exception as e:
  723. logger.error(f"清空embedding缓存失败: {str(e)}")
  724. return jsonify(internal_error_response(
  725. response_text="清空embedding缓存失败,请稍后重试"
  726. )), 500
  727. # ==================== 训练数据管理API ====================
  728. def validate_sql_syntax(sql: str) -> tuple[bool, str]:
  729. """SQL语法检查"""
  730. try:
  731. parsed = sqlparse.parse(sql.strip())
  732. if not parsed or not parsed[0].tokens:
  733. return False, "SQL语法错误:空语句"
  734. sql_upper = sql.strip().upper()
  735. if not any(sql_upper.startswith(keyword) for keyword in
  736. ['SELECT', 'INSERT', 'UPDATE', 'DELETE', 'CREATE', 'ALTER', 'DROP']):
  737. return False, "SQL语法错误:不是有效的SQL语句"
  738. return True, ""
  739. except Exception as e:
  740. return False, f"SQL语法错误:{str(e)}"
  741. @app.route('/api/v0/training_data/stats', methods=['GET'])
  742. def training_data_stats():
  743. """获取训练数据统计信息API"""
  744. try:
  745. training_data = vn.get_training_data()
  746. if training_data is None or training_data.empty:
  747. return jsonify(success_response(
  748. response_text="统计信息获取成功",
  749. data={
  750. "total_count": 0,
  751. "last_updated": datetime.now().isoformat()
  752. }
  753. ))
  754. total_count = len(training_data)
  755. return jsonify(success_response(
  756. response_text="统计信息获取成功",
  757. data={
  758. "total_count": total_count,
  759. "last_updated": datetime.now().isoformat()
  760. }
  761. ))
  762. except Exception as e:
  763. logger.error(f"training_data_stats执行失败: {str(e)}")
  764. return jsonify(internal_error_response(
  765. response_text="获取统计信息失败,请稍后重试"
  766. )), 500
  767. @app.route('/api/v0/training_data/query', methods=['POST'])
  768. def training_data_query():
  769. """分页查询训练数据API"""
  770. try:
  771. req = request.get_json(force=True)
  772. page = req.get('page', 1)
  773. page_size = req.get('page_size', 20)
  774. if page < 1 or page_size < 1 or page_size > 100:
  775. return jsonify(bad_request_response(
  776. response_text="参数错误"
  777. )), 400
  778. training_data = vn.get_training_data()
  779. if training_data is None or training_data.empty:
  780. return jsonify(success_response(
  781. response_text="查询成功,暂无训练数据",
  782. data={
  783. "records": [],
  784. "pagination": {
  785. "page": page,
  786. "page_size": page_size,
  787. "total": 0,
  788. "total_pages": 0,
  789. "has_next": False,
  790. "has_prev": False
  791. }
  792. }
  793. ))
  794. records = training_data.to_dict(orient="records")
  795. total = len(records)
  796. start_idx = (page - 1) * page_size
  797. end_idx = start_idx + page_size
  798. page_data = records[start_idx:end_idx]
  799. total_pages = (total + page_size - 1) // page_size
  800. return jsonify(success_response(
  801. response_text=f"查询成功,共找到 {total} 条记录",
  802. data={
  803. "records": page_data,
  804. "pagination": {
  805. "page": page,
  806. "page_size": page_size,
  807. "total": total,
  808. "total_pages": total_pages,
  809. "has_next": end_idx < total,
  810. "has_prev": page > 1
  811. }
  812. }
  813. ))
  814. except Exception as e:
  815. logger.error(f"training_data_query执行失败: {str(e)}")
  816. return jsonify(internal_error_response(
  817. response_text="查询训练数据失败,请稍后重试"
  818. )), 500
  819. @app.route('/api/v0/training_data/create', methods=['POST'])
  820. def training_data_create():
  821. """创建训练数据API"""
  822. try:
  823. req = request.get_json(force=True)
  824. data = req.get('data')
  825. if not data:
  826. return jsonify(bad_request_response(
  827. response_text="缺少必需参数:data"
  828. )), 400
  829. if isinstance(data, dict):
  830. data_list = [data]
  831. elif isinstance(data, list):
  832. data_list = data
  833. else:
  834. return jsonify(bad_request_response(
  835. response_text="data字段格式错误,应为对象或数组"
  836. )), 400
  837. if len(data_list) > 50:
  838. return jsonify(bad_request_response(
  839. response_text="批量操作最大支持50条记录"
  840. )), 400
  841. results = []
  842. successful_count = 0
  843. for index, item in enumerate(data_list):
  844. try:
  845. training_type = item.get('training_data_type')
  846. if training_type == 'sql':
  847. sql = item.get('sql')
  848. if not sql:
  849. raise ValueError("SQL字段是必需的")
  850. is_valid, error_msg = validate_sql_syntax(sql)
  851. if not is_valid:
  852. raise ValueError(error_msg)
  853. question = item.get('question')
  854. if question:
  855. training_id = vn.train(question=question, sql=sql)
  856. else:
  857. training_id = vn.train(sql=sql)
  858. elif training_type == 'documentation':
  859. content = item.get('content')
  860. if not content:
  861. raise ValueError("content字段是必需的")
  862. training_id = vn.train(documentation=content)
  863. elif training_type == 'ddl':
  864. ddl = item.get('ddl')
  865. if not ddl:
  866. raise ValueError("ddl字段是必需的")
  867. training_id = vn.train(ddl=ddl)
  868. else:
  869. raise ValueError(f"不支持的训练数据类型: {training_type}")
  870. results.append({
  871. "index": index,
  872. "success": True,
  873. "training_id": training_id,
  874. "type": training_type,
  875. "message": f"{training_type}训练数据创建成功"
  876. })
  877. successful_count += 1
  878. except Exception as e:
  879. results.append({
  880. "index": index,
  881. "success": False,
  882. "type": item.get('training_data_type', 'unknown'),
  883. "error": str(e),
  884. "message": "创建失败"
  885. })
  886. failed_count = len(data_list) - successful_count
  887. if failed_count == 0:
  888. return jsonify(success_response(
  889. response_text="训练数据创建完成",
  890. data={
  891. "total_requested": len(data_list),
  892. "successfully_created": successful_count,
  893. "failed_count": failed_count,
  894. "results": results
  895. }
  896. ))
  897. else:
  898. return jsonify(error_response(
  899. response_text=f"训练数据创建部分成功,成功{successful_count}条,失败{failed_count}条",
  900. data={
  901. "total_requested": len(data_list),
  902. "successfully_created": successful_count,
  903. "failed_count": failed_count,
  904. "results": results
  905. }
  906. )), 207
  907. except Exception as e:
  908. logger.error(f"training_data_create执行失败: {str(e)}")
  909. return jsonify(internal_error_response(
  910. response_text="创建训练数据失败,请稍后重试"
  911. )), 500
  912. @app.route('/api/v0/training_data/delete', methods=['POST'])
  913. def training_data_delete():
  914. """删除训练数据API"""
  915. try:
  916. req = request.get_json(force=True)
  917. ids = req.get('ids', [])
  918. confirm = req.get('confirm', False)
  919. if not ids or not isinstance(ids, list):
  920. return jsonify(bad_request_response(
  921. response_text="缺少有效的ID列表"
  922. )), 400
  923. if not confirm:
  924. return jsonify(bad_request_response(
  925. response_text="删除操作需要确认,请设置confirm为true"
  926. )), 400
  927. if len(ids) > 50:
  928. return jsonify(bad_request_response(
  929. response_text="批量删除最大支持50条记录"
  930. )), 400
  931. deleted_ids = []
  932. failed_ids = []
  933. for training_id in ids:
  934. try:
  935. success = vn.remove_training_data(training_id)
  936. if success:
  937. deleted_ids.append(training_id)
  938. else:
  939. failed_ids.append(training_id)
  940. except Exception as e:
  941. failed_ids.append(training_id)
  942. failed_count = len(failed_ids)
  943. if failed_count == 0:
  944. return jsonify(success_response(
  945. response_text="训练数据删除完成",
  946. data={
  947. "total_requested": len(ids),
  948. "successfully_deleted": len(deleted_ids),
  949. "failed_count": failed_count,
  950. "deleted_ids": deleted_ids,
  951. "failed_ids": failed_ids
  952. }
  953. ))
  954. else:
  955. return jsonify(error_response(
  956. response_text=f"训练数据删除部分成功,成功{len(deleted_ids)}条,失败{failed_count}条",
  957. data={
  958. "total_requested": len(ids),
  959. "successfully_deleted": len(deleted_ids),
  960. "failed_count": failed_count,
  961. "deleted_ids": deleted_ids,
  962. "failed_ids": failed_ids
  963. }
  964. )), 207
  965. except Exception as e:
  966. logger.error(f"training_data_delete执行失败: {str(e)}")
  967. return jsonify(internal_error_response(
  968. response_text="删除训练数据失败,请稍后重试"
  969. )), 500
  970. # ==================== 启动逻辑 ====================
  971. def signal_handler(signum, frame):
  972. """信号处理器,优雅退出"""
  973. logger.info(f"接收到信号 {signum},准备退出...")
  974. cleanup_resources()
  975. sys.exit(0)
  976. if __name__ == '__main__':
  977. # 注册信号处理器
  978. signal.signal(signal.SIGINT, signal_handler)
  979. signal.signal(signal.SIGTERM, signal_handler)
  980. logger.info("🚀 启动统一API服务...")
  981. logger.info("📍 服务地址: http://localhost:8084")
  982. logger.info("🔗 健康检查: http://localhost:8084/health")
  983. logger.info("📘 React Agent API: http://localhost:8084/api/v0/ask_react_agent")
  984. logger.info("📘 LangGraph Agent API: http://localhost:8084/api/v0/ask_agent")
  985. # 启动标准Flask应用(支持异步路由)
  986. app.run(host="0.0.0.0", port=8084, debug=False, threaded=True)