unified_api.py 243 KB


  1. """
  2. 统一 API 服务
  3. 集成 citu_app.py 指定API 和 react_agent/api.py 的所有功能
  4. 提供数据库问答、Redis对话管理、QA反馈、训练数据管理、React Agent等功能
  5. 使用普通 Flask 应用 + ASGI 包装实现异步支持
  6. """
  7. import asyncio
  8. import logging
  9. import atexit
  10. import os
  11. import sys
  12. import time
  13. from datetime import datetime, timedelta, timezone
  14. import pytz
  15. from typing import Optional, Dict, Any, TYPE_CHECKING, Union
  16. import signal
  17. from threading import Thread
  18. from pathlib import Path
  19. if TYPE_CHECKING:
  20. from react_agent.agent import CustomReactAgent
  21. # 初始化日志系统 - 必须在最前面
  22. from core.logging import initialize_logging, get_app_logger
  23. initialize_logging()
  24. # 标准 Flask 导入
  25. from flask import Flask, request, jsonify, session, send_file, Response, stream_with_context
  26. import redis.asyncio as redis
  27. from werkzeug.utils import secure_filename
  28. # 导入标准化响应格式
  29. from common.result import success_response, internal_error_response, bad_request_response
  30. # 基础依赖
  31. import pandas as pd
  32. import json
  33. import sqlparse
  34. import tempfile
  35. import os
  36. import psycopg2
  37. import re
  38. # 项目模块导入
  39. from core.vanna_llm_factory import create_vanna_instance
  40. from common.redis_conversation_manager import RedisConversationManager
  41. from common.qa_feedback_manager import QAFeedbackManager
  42. # Data Pipeline 相关导入 - 从 citu_app.py 迁移
  43. from data_pipeline.api.simple_workflow import SimpleWorkflowManager, SimpleWorkflowExecutor
  44. from data_pipeline.api.simple_file_manager import SimpleFileManager
  45. from data_pipeline.api.table_inspector_api import TableInspectorAPI
  46. from common.result import (
  47. success_response, bad_request_response, not_found_response, internal_error_response,
  48. error_response, service_unavailable_response,
  49. agent_success_response, agent_error_response,
  50. validation_failed_response
  51. )
  52. from app_config import (
  53. USER_MAX_CONVERSATIONS, CONVERSATION_CONTEXT_COUNT,
  54. DEFAULT_ANONYMOUS_USER, ENABLE_QUESTION_ANSWER_CACHE
  55. )
  56. # 创建标准 Flask 应用
  57. app = Flask(__name__)
  58. # 创建日志记录器
  59. logger = get_app_logger("UnifiedApp")
  60. # React Agent 导入
  61. try:
  62. from react_agent.agent import CustomReactAgent
  63. from react_agent.enhanced_redis_api import get_conversation_detail_from_redis
  64. from react_agent import config as react_agent_config
  65. except ImportError:
  66. try:
  67. from test.custom_react_agent.agent import CustomReactAgent
  68. from test.custom_react_agent.enhanced_redis_api import get_conversation_detail_from_redis
  69. from test.custom_react_agent import config as react_agent_config
  70. except ImportError:
  71. logger.warning("无法导入 CustomReactAgent,React Agent功能将不可用")
  72. CustomReactAgent = None
  73. get_conversation_detail_from_redis = None
  74. react_agent_config = None
# Initialise core components (runs once, at import time)
vn = create_vanna_instance()
redis_conversation_manager = RedisConversationManager()
# ==================== React Agent global instance management ====================
_react_agent_instance: Optional[Any] = None  # Shared agent (sync tools), used by ask_react_agent; set lazily by get_react_agent()
_react_agent_stream_instance: Optional[Any] = None  # Intended for ask_react_agent_stream (async tools); NOTE(review): not assigned in the visible code - the stream endpoint creates a fresh agent per request
_redis_client: Optional[redis.Redis] = None  # Shared async Redis client (module-level `redis` is redis.asyncio); set lazily by get_react_agent()
  82. def _format_timestamp_to_china_time(timestamp_str):
  83. """将ISO时间戳转换为中国时区的指定格式"""
  84. if not timestamp_str:
  85. return None
  86. try:
  87. # 解析ISO时间戳
  88. dt = datetime.fromisoformat(timestamp_str.replace('Z', '+00:00'))
  89. # 转换为中国时区
  90. china_tz = pytz.timezone('Asia/Shanghai')
  91. china_dt = dt.astimezone(china_tz)
  92. # 格式化为指定格式
  93. return china_dt.strftime('%Y-%m-%d %H:%M:%S.%f')[:-3] # 保留3位毫秒
  94. except Exception as e:
  95. logger.warning(f"⚠️ 时间格式化失败: {e}")
  96. return timestamp_str
  97. def _parse_conversation_created_time(conversation_id: str) -> Optional[str]:
  98. """从conversation_id解析创建时间并转换为中国时区格式"""
  99. try:
  100. # conversation_id格式: "wang10:20250717211620915"
  101. if ':' not in conversation_id:
  102. return None
  103. parts = conversation_id.split(':')
  104. if len(parts) < 2:
  105. return None
  106. timestamp_str = parts[1] # "20250717211620915"
  107. # 解析时间戳: YYYYMMDDHHMMSSMMM (17位)
  108. if len(timestamp_str) != 17:
  109. logger.warning(f"⚠️ conversation_id时间戳长度不正确: {timestamp_str}")
  110. return None
  111. year = timestamp_str[:4]
  112. month = timestamp_str[4:6]
  113. day = timestamp_str[6:8]
  114. hour = timestamp_str[8:10]
  115. minute = timestamp_str[10:12]
  116. second = timestamp_str[12:14]
  117. millisecond = timestamp_str[14:17]
  118. # 构造datetime对象
  119. dt = datetime(
  120. int(year), int(month), int(day),
  121. int(hour), int(minute), int(second),
  122. int(millisecond) * 1000 # 毫秒转微秒
  123. )
  124. # 转换为中国时区
  125. china_tz = pytz.timezone('Asia/Shanghai')
  126. # 假设原始时间戳是中国时区
  127. china_dt = china_tz.localize(dt)
  128. # 格式化为要求的格式
  129. formatted_time = china_dt.strftime('%Y-%m-%d %H:%M:%S.%f')[:-3] # 保留3位毫秒
  130. return formatted_time
  131. except Exception as e:
  132. logger.warning(f"⚠️ 解析conversation_id时间戳失败: {e}")
  133. return None
  134. def _get_conversation_updated_time(redis_client, thread_id: str) -> Optional[str]:
  135. """获取对话的最后更新时间(从Redis checkpoint数据中的ts字段)"""
  136. try:
  137. # 扫描该thread的所有checkpoint keys
  138. pattern = f"checkpoint:{thread_id}:*"
  139. keys = []
  140. cursor = 0
  141. while True:
  142. cursor, batch = redis_client.scan(cursor=cursor, match=pattern, count=1000)
  143. keys.extend(batch)
  144. if cursor == 0:
  145. break
  146. if not keys:
  147. return None
  148. # 获取最新的checkpoint(按key排序,最大的是最新的)
  149. latest_key = max(keys)
  150. # 检查key类型并获取数据
  151. key_type = redis_client.type(latest_key)
  152. data = None
  153. if key_type == 'string':
  154. data = redis_client.get(latest_key)
  155. elif key_type == 'ReJSON-RL':
  156. # RedisJSON类型
  157. try:
  158. data = redis_client.execute_command('JSON.GET', latest_key)
  159. except Exception as json_error:
  160. logger.error(f"❌ JSON.GET失败: {json_error}")
  161. return None
  162. else:
  163. return None
  164. if not data:
  165. return None
  166. # 解析JSON数据
  167. try:
  168. checkpoint_data = json.loads(data)
  169. except json.JSONDecodeError:
  170. return None
  171. # 检查checkpoint中的ts字段
  172. if ('checkpoint' in checkpoint_data and
  173. isinstance(checkpoint_data['checkpoint'], dict) and
  174. 'ts' in checkpoint_data['checkpoint']):
  175. ts_value = checkpoint_data['checkpoint']['ts']
  176. # 解析ts字段(应该是ISO格式的时间戳)
  177. if isinstance(ts_value, str):
  178. try:
  179. dt = datetime.fromisoformat(ts_value.replace('Z', '+00:00'))
  180. china_tz = pytz.timezone('Asia/Shanghai')
  181. china_dt = dt.astimezone(china_tz)
  182. formatted_time = china_dt.strftime('%Y-%m-%d %H:%M:%S.%f')[:-3]
  183. return formatted_time
  184. except Exception:
  185. pass
  186. return None
  187. except Exception as e:
  188. logger.warning(f"⚠️ 获取对话更新时间失败: {e}")
  189. return None
  190. def validate_request_data(data: Dict[str, Any]) -> Dict[str, Any]:
  191. """验证请求数据,并支持从thread_id中推断user_id"""
  192. errors = []
  193. # 验证 question(必填)
  194. question = data.get('question', '')
  195. if not question or not question.strip():
  196. errors.append('问题不能为空')
  197. elif len(question) > 2000:
  198. errors.append('问题长度不能超过2000字符')
  199. # 优先获取 thread_id
  200. thread_id = data.get('thread_id') or data.get('conversation_id')
  201. # 获取 user_id,但暂不设置默认值
  202. user_id = data.get('user_id')
  203. # 如果没有传递 user_id,则尝试从 thread_id 中推断
  204. if not user_id:
  205. if thread_id and ':' in thread_id:
  206. inferred_user_id = thread_id.split(':', 1)[0]
  207. if inferred_user_id:
  208. user_id = inferred_user_id
  209. logger.info(f"👤 未提供user_id,从 thread_id '{thread_id}' 中推断出: '{user_id}'")
  210. else:
  211. user_id = 'guest'
  212. else:
  213. user_id = 'guest'
  214. # 验证 user_id 长度
  215. if user_id and len(user_id) > 50:
  216. errors.append('用户ID长度不能超过50字符')
  217. # 用户ID与会话ID一致性校验
  218. if thread_id:
  219. if ':' not in thread_id:
  220. errors.append('会话ID格式无效,期望格式为 user_id:timestamp')
  221. else:
  222. thread_user_id = thread_id.split(':', 1)[0]
  223. if thread_user_id != user_id:
  224. errors.append(f'会话归属验证失败:会话ID [{thread_id}] 不属于当前用户 [{user_id}]')
  225. if errors:
  226. raise ValueError('; '.join(errors))
  227. return {
  228. 'question': question.strip(),
  229. 'user_id': user_id,
  230. 'thread_id': thread_id # 可选,不传则自动生成新会话
  231. }
  232. async def get_react_agent() -> Any:
  233. """获取 React Agent 实例(懒加载)"""
  234. global _react_agent_instance, _redis_client
  235. if _react_agent_instance is None:
  236. if CustomReactAgent is None:
  237. logger.error("❌ CustomReactAgent 未能导入,无法初始化")
  238. raise ImportError("CustomReactAgent 未能导入")
  239. logger.info("🚀 正在异步初始化 Custom React Agent...")
  240. try:
  241. # 使用React Agent配置中的Redis URL
  242. redis_url = react_agent_config.REDIS_URL if react_agent_config else 'redis://localhost:6379'
  243. # 设置环境变量
  244. os.environ['REDIS_URL'] = redis_url
  245. # 初始化共享的Redis客户端
  246. _redis_client = redis.from_url(redis_url, decode_responses=True)
  247. await _redis_client.ping()
  248. logger.info("✅ Redis客户端连接成功")
  249. _react_agent_instance = await CustomReactAgent.create()
  250. logger.info("✅ React Agent 异步初始化完成")
  251. except Exception as e:
  252. logger.error(f"❌ React Agent 异步初始化失败: {e}")
  253. raise
  254. return _react_agent_instance
  255. async def ensure_agent_ready() -> bool:
  256. """异步确保Agent实例可用"""
  257. global _react_agent_instance
  258. if _react_agent_instance is None:
  259. await get_react_agent()
  260. # 测试Agent是否还可用
  261. try:
  262. test_result = await _react_agent_instance.get_user_recent_conversations("__test__", 1)
  263. return True
  264. except Exception as e:
  265. logger.warning(f"⚠️ Agent实例不可用: {e}")
  266. _react_agent_instance = None
  267. await get_react_agent()
  268. return True
  269. async def create_stream_agent_instance():
  270. """为每个流式请求创建新的Agent实例(使用异步工具)"""
  271. if CustomReactAgent is None:
  272. logger.error("❌ CustomReactAgent 未能导入,无法初始化流式Agent")
  273. raise ImportError("CustomReactAgent 未能导入")
  274. logger.info("🚀 正在为流式请求创建新的 React Agent 实例...")
  275. try:
  276. # 创建流式专用 Agent 实例
  277. stream_agent = await CustomReactAgent.create()
  278. # 配置使用异步 SQL 工具
  279. from react_agent.async_sql_tools import async_sql_tools
  280. stream_agent.tools = async_sql_tools
  281. stream_agent.llm_with_tools = stream_agent.llm.bind_tools(async_sql_tools)
  282. logger.info("✅ 流式 React Agent 实例创建完成(配置异步工具)")
  283. return stream_agent
  284. except Exception as e:
  285. logger.error(f"❌ 流式 React Agent 实例创建失败: {e}")
  286. raise
  287. def get_user_conversations_simple_sync(user_id: str, limit: int = 10):
  288. """直接从Redis获取用户对话,测试版本"""
  289. import redis
  290. import json
  291. try:
  292. # 创建Redis连接
  293. if react_agent_config:
  294. redis_client = redis.Redis(
  295. host=react_agent_config.REDIS_HOST,
  296. port=react_agent_config.REDIS_PORT,
  297. db=react_agent_config.REDIS_DB,
  298. password=react_agent_config.REDIS_PASSWORD,
  299. decode_responses=True
  300. )
  301. else:
  302. redis_client = redis.Redis(host='localhost', port=6379, decode_responses=True)
  303. redis_client.ping()
  304. # 扫描用户的checkpoint keys
  305. pattern = f"checkpoint:{user_id}:*"
  306. logger.info(f"🔍 扫描模式: {pattern}")
  307. keys = []
  308. cursor = 0
  309. while True:
  310. cursor, batch = redis_client.scan(cursor=cursor, match=pattern, count=1000)
  311. keys.extend(batch)
  312. if cursor == 0:
  313. break
  314. logger.info(f"📋 找到 {len(keys)} 个keys")
  315. # 解析thread信息
  316. thread_data = {}
  317. for key in keys:
  318. try:
  319. parts = key.split(':')
  320. if len(parts) >= 4:
  321. thread_id = f"{parts[1]}:{parts[2]}" # user_id:timestamp
  322. timestamp = parts[2]
  323. if thread_id not in thread_data:
  324. thread_data[thread_id] = {
  325. "thread_id": thread_id,
  326. "timestamp": timestamp,
  327. "keys": []
  328. }
  329. thread_data[thread_id]["keys"].append(key)
  330. except Exception as e:
  331. logger.warning(f"解析key失败 {key}: {e}")
  332. continue
  333. logger.info(f"📊 找到 {len(thread_data)} 个thread")
  334. # 按时间戳排序
  335. sorted_threads = sorted(
  336. thread_data.values(),
  337. key=lambda x: x["timestamp"],
  338. reverse=True
  339. )[:limit]
  340. # 获取每个thread的详细信息
  341. conversations = []
  342. for thread_info in sorted_threads:
  343. try:
  344. thread_id = thread_info["thread_id"]
  345. # 获取最新的checkpoint数据
  346. latest_key = max(thread_info["keys"])
  347. # 先检查key的数据类型
  348. key_type = redis_client.type(latest_key)
  349. logger.info(f"🔍 Key {latest_key} 的类型: {key_type}")
  350. data = None
  351. if key_type == 'string':
  352. data = redis_client.get(latest_key)
  353. elif key_type == 'hash':
  354. # 如果是hash类型,获取所有字段
  355. hash_data = redis_client.hgetall(latest_key)
  356. logger.info(f"🔍 Hash字段: {list(hash_data.keys())}")
  357. # 尝试获取可能的数据字段
  358. for field in ['data', 'state', 'value', 'checkpoint']:
  359. if field in hash_data:
  360. data = hash_data[field]
  361. break
  362. if not data and hash_data:
  363. # 如果没找到预期字段,取第一个值试试
  364. data = list(hash_data.values())[0]
  365. elif key_type == 'list':
  366. # 如果是list类型,获取最后一个元素
  367. data = redis_client.lindex(latest_key, -1)
  368. elif key_type == 'ReJSON-RL':
  369. # 这是RedisJSON类型,使用JSON.GET命令
  370. logger.info(f"🔍 使用JSON.GET获取RedisJSON数据")
  371. try:
  372. # 使用JSON.GET命令获取整个JSON对象
  373. json_data = redis_client.execute_command('JSON.GET', latest_key)
  374. if json_data:
  375. data = json_data # JSON.GET返回的就是JSON字符串
  376. logger.info(f"🔍 JSON数据长度: {len(data)} 字符")
  377. else:
  378. logger.warning(f"⚠️ JSON.GET 返回空数据")
  379. continue
  380. except Exception as json_error:
  381. logger.error(f"❌ JSON.GET 失败: {json_error}")
  382. continue
  383. else:
  384. logger.warning(f"⚠️ 未知的key类型: {key_type}")
  385. continue
  386. if data:
  387. try:
  388. checkpoint_data = json.loads(data)
  389. # 调试:查看JSON数据结构
  390. logger.info(f"🔍 JSON顶级keys: {list(checkpoint_data.keys())}")
  391. # 根据您提供的JSON结构,消息在 checkpoint.channel_values.messages
  392. messages = []
  393. # 首先检查是否有checkpoint字段
  394. if 'checkpoint' in checkpoint_data:
  395. checkpoint = checkpoint_data['checkpoint']
  396. if isinstance(checkpoint, dict) and 'channel_values' in checkpoint:
  397. channel_values = checkpoint['channel_values']
  398. if isinstance(channel_values, dict) and 'messages' in channel_values:
  399. messages = channel_values['messages']
  400. logger.info(f"🔍 找到messages: {len(messages)} 条消息")
  401. # 如果没有checkpoint字段,尝试直接在channel_values
  402. if not messages and 'channel_values' in checkpoint_data:
  403. channel_values = checkpoint_data['channel_values']
  404. if isinstance(channel_values, dict) and 'messages' in channel_values:
  405. messages = channel_values['messages']
  406. logger.info(f"🔍 找到messages(直接路径): {len(messages)} 条消息")
  407. # 生成对话预览
  408. preview = "空对话"
  409. if messages:
  410. for msg in messages:
  411. # 处理LangChain消息格式:{"lc": 1, "type": "constructor", "id": ["langchain", "schema", "messages", "HumanMessage"], "kwargs": {"content": "...", "type": "human"}}
  412. if isinstance(msg, dict):
  413. # 检查是否是LangChain格式的HumanMessage
  414. if (msg.get('lc') == 1 and
  415. msg.get('type') == 'constructor' and
  416. 'id' in msg and
  417. isinstance(msg['id'], list) and
  418. len(msg['id']) >= 4 and
  419. msg['id'][3] == 'HumanMessage' and
  420. 'kwargs' in msg):
  421. kwargs = msg['kwargs']
  422. if kwargs.get('type') == 'human' and 'content' in kwargs:
  423. content = str(kwargs['content'])
  424. preview = content[:50] + "..." if len(content) > 50 else content
  425. break
  426. # 兼容其他格式
  427. elif msg.get('type') == 'human' and 'content' in msg:
  428. content = str(msg['content'])
  429. preview = content[:50] + "..." if len(content) > 50 else content
  430. break
  431. # 解析时间戳
  432. created_at = _parse_conversation_created_time(thread_id)
  433. updated_at = _get_conversation_updated_time(redis_client, thread_id)
  434. # 如果无法获取updated_at,使用created_at作为备选
  435. if not updated_at:
  436. updated_at = created_at
  437. conversations.append({
  438. "conversation_id": thread_id, # thread_id -> conversation_id
  439. "user_id": user_id,
  440. "message_count": len(messages),
  441. "conversation_title": preview, # conversation_preview -> conversation_title
  442. "created_at": created_at,
  443. "updated_at": updated_at
  444. })
  445. except json.JSONDecodeError:
  446. logger.error(f"❌ JSON解析失败,数据类型: {type(data)}, 长度: {len(str(data))}")
  447. logger.error(f"❌ 数据开头: {str(data)[:200]}...")
  448. continue
  449. except Exception as e:
  450. logger.error(f"处理thread {thread_info['thread_id']} 失败: {e}")
  451. continue
  452. redis_client.close()
  453. logger.info(f"✅ 返回 {len(conversations)} 个对话")
  454. return conversations
  455. except Exception as e:
  456. logger.error(f"❌ Redis查询失败: {e}")
  457. return []
  458. def cleanup_resources():
  459. """清理资源"""
  460. global _react_agent_instance, _redis_client
  461. async def async_cleanup():
  462. if _react_agent_instance:
  463. await _react_agent_instance.close()
  464. logger.info("✅ React Agent 资源已清理")
  465. if _redis_client:
  466. await _redis_client.aclose()
  467. logger.info("✅ Redis客户端已关闭")
  468. try:
  469. asyncio.run(async_cleanup())
  470. except Exception as e:
  471. logger.error(f"清理资源失败: {e}")
  472. atexit.register(cleanup_resources)
  473. # ==================== 基础路由 ====================
  474. @app.route("/")
  475. def index():
  476. """根路径健康检查"""
  477. return jsonify({"message": "统一API服务正在运行", "version": "1.0.0"})
  478. @app.route('/health', methods=['GET'])
  479. def health_check():
  480. """健康检查端点"""
  481. try:
  482. health_status = {
  483. "status": "healthy",
  484. "react_agent_initialized": _react_agent_instance is not None,
  485. "timestamp": datetime.now().isoformat(),
  486. "services": {
  487. "redis": redis_conversation_manager.is_available(),
  488. "vanna": vn is not None
  489. }
  490. }
  491. return jsonify(health_status), 200
  492. except Exception as e:
  493. logger.error(f"健康检查失败: {e}")
  494. return jsonify({"status": "unhealthy", "error": str(e)}), 500
  495. # ==================== React Agent API ====================
  496. @app.route("/api/v0/ask_react_agent", methods=["POST"])
  497. async def ask_react_agent():
  498. """异步React Agent智能问答接口(从 custom_react_agent 迁移,原路由:/api/chat)"""
  499. global _react_agent_instance
  500. # 确保Agent已初始化
  501. if not await ensure_agent_ready():
  502. return jsonify({
  503. "code": 503,
  504. "message": "服务未就绪",
  505. "success": False,
  506. "error": "React Agent 初始化失败"
  507. }), 503
  508. try:
  509. # 获取请求数据
  510. try:
  511. data = request.get_json(force=True)
  512. except Exception as json_error:
  513. logger.warning(f"⚠️ JSON解析失败: {json_error}")
  514. return jsonify({
  515. "code": 400,
  516. "message": "请求格式错误",
  517. "success": False,
  518. "error": "无效的JSON格式,请检查请求体中是否存在语法错误(如多余的逗号、引号不匹配等)",
  519. "details": str(json_error)
  520. }), 400
  521. if not data:
  522. return jsonify({
  523. "code": 400,
  524. "message": "请求参数错误",
  525. "success": False,
  526. "error": "请求体不能为空"
  527. }), 400
  528. # 验证请求数据
  529. validated_data = validate_request_data(data)
  530. logger.info(f"📨 收到React Agent请求 - User: {validated_data['user_id']}, Question: {validated_data['question'][:50]}...")
  531. # 异步调用处理
  532. agent_result = await _react_agent_instance.chat(
  533. message=validated_data['question'],
  534. user_id=validated_data['user_id'],
  535. thread_id=validated_data['thread_id']
  536. )
  537. if not agent_result.get("success", False):
  538. # Agent处理失败
  539. error_msg = agent_result.get("error", "React Agent处理失败")
  540. logger.error(f"❌ React Agent处理失败: {error_msg}")
  541. # 检查是否建议重试
  542. retry_suggested = agent_result.get("retry_suggested", False)
  543. error_code = 503 if retry_suggested else 500
  544. message = "服务暂时不可用,请稍后重试" if retry_suggested else "处理失败"
  545. return jsonify({
  546. "code": error_code,
  547. "message": message,
  548. "success": False,
  549. "error": error_msg,
  550. "retry_suggested": retry_suggested,
  551. "data": {
  552. "conversation_id": agent_result.get("thread_id"),
  553. "user_id": validated_data['user_id'],
  554. "timestamp": datetime.now().isoformat()
  555. }
  556. }), error_code
  557. # Agent处理成功
  558. api_data = agent_result.get("api_data", {})
  559. # 构建响应数据(按照 react_agent/api.py 的正确格式)
  560. response_data = {
  561. "response": api_data.get("response", ""),
  562. "conversation_id": agent_result.get("thread_id"),
  563. "user_id": validated_data['user_id'],
  564. "react_agent_meta": api_data.get("react_agent_meta", {
  565. "thread_id": agent_result.get("thread_id"),
  566. "agent_version": "custom_react_v1_async"
  567. }),
  568. "timestamp": datetime.now().isoformat()
  569. }
  570. # 可选字段:SQL(仅当执行SQL时存在)
  571. if "sql" in api_data:
  572. response_data["sql"] = api_data["sql"]
  573. # 可选字段:records(仅当有查询结果时存在)
  574. if "records" in api_data:
  575. response_data["records"] = api_data["records"]
  576. return jsonify({
  577. "code": 200,
  578. "message": "处理成功",
  579. "success": True,
  580. "data": response_data
  581. }), 200
  582. except ValueError as ve:
  583. # 参数验证错误
  584. logger.warning(f"⚠️ 参数验证失败: {ve}")
  585. return jsonify({
  586. "code": 400,
  587. "message": "参数验证失败",
  588. "success": False,
  589. "error": str(ve)
  590. }), 400
  591. except Exception as e:
  592. logger.error(f"❌ React Agent API 异常: {e}")
  593. return jsonify({
  594. "code": 500,
  595. "message": "内部服务错误",
  596. "success": False,
  597. "error": "服务暂时不可用,请稍后重试"
  598. }), 500
  599. @app.route('/api/v0/ask_react_agent_stream', methods=['GET'])
  600. def ask_react_agent_stream():
  601. """React Agent 流式API - 使用异步工具的专用 Agent 实例
  602. 功能与ask_react_agent完全相同,除了采用流式输出
  603. """
  604. def generate():
  605. try:
  606. # 1. 参数获取和验证(从URL参数,因为EventSource只支持GET)
  607. question = request.args.get('question')
  608. user_id_input = request.args.get('user_id')
  609. thread_id_input = request.args.get('thread_id')
  610. # 参数验证(复用现有validate_request_data逻辑)
  611. if not question:
  612. yield format_sse_error("缺少必需参数:question")
  613. return
  614. # 2. 数据预处理(与ask_react_agent相同)
  615. try:
  616. validated_data = validate_request_data({
  617. 'question': question,
  618. 'user_id': user_id_input,
  619. 'thread_id': thread_id_input
  620. })
  621. except ValueError as ve:
  622. yield format_sse_error(f"参数验证失败: {str(ve)}")
  623. return
  624. logger.info(f"📨 收到React Agent流式请求 - User: {validated_data['user_id']}, Question: {validated_data['question'][:50]}...")
  625. # 3. 为当前请求创建新的事件循环和Agent实例
  626. import asyncio
  627. # 创建新的事件循环
  628. loop = asyncio.new_event_loop()
  629. asyncio.set_event_loop(loop)
  630. stream_agent = None
  631. try:
  632. # 为当前请求创建新的Agent实例
  633. stream_agent = loop.run_until_complete(create_stream_agent_instance())
  634. if not stream_agent:
  635. yield format_sse_error("流式 React Agent 初始化失败")
  636. return
  637. except Exception as e:
  638. logger.error(f"流式 Agent 初始化异常: {str(e)}")
  639. yield format_sse_error(f"流式 Agent 初始化失败: {str(e)}")
  640. return
  641. # 4. 在同一个事件循环中执行流式处理
  642. try:
  643. # 创建异步生成器
  644. async def stream_worker():
  645. try:
  646. # 使用当前请求的 Agent 实例(已配置异步工具)
  647. async for chunk in stream_agent.chat_stream(
  648. message=validated_data['question'],
  649. user_id=validated_data['user_id'],
  650. thread_id=validated_data['thread_id']
  651. ):
  652. yield chunk
  653. if chunk.get("type") == "completed":
  654. break
  655. except Exception as e:
  656. logger.error(f"流式处理异常: {str(e)}", exc_info=True)
  657. yield {
  658. "type": "error",
  659. "error": f"流式处理异常: {str(e)}"
  660. }
  661. # 在当前事件循环中运行异步生成器
  662. async_gen = stream_worker()
  663. # 同步迭代异步生成器
  664. while True:
  665. try:
  666. chunk = loop.run_until_complete(async_gen.__anext__())
  667. if chunk["type"] == "progress":
  668. yield format_sse_react_progress(chunk)
  669. elif chunk["type"] == "completed":
  670. yield format_sse_react_completed(chunk)
  671. break
  672. elif chunk["type"] == "error":
  673. yield format_sse_error(chunk.get("error", "未知错误"))
  674. break
  675. except StopAsyncIteration:
  676. break
  677. except Exception as e:
  678. logger.error(f"处理流式数据异常: {str(e)}")
  679. yield format_sse_error(f"处理异常: {str(e)}")
  680. break
  681. except Exception as e:
  682. logger.error(f"React Agent流式处理异常: {str(e)}")
  683. yield format_sse_error(f"流式处理异常: {str(e)}")
  684. finally:
  685. # 清理:流式处理完成后关闭事件循环
  686. try:
  687. loop.close()
  688. except Exception as e:
  689. logger.warning(f"关闭事件循环时出错: {e}")
  690. except Exception as e:
  691. logger.error(f"React Agent流式API异常: {str(e)}")
  692. yield format_sse_error(f"服务异常: {str(e)}")
  693. return Response(stream_with_context(generate()), mimetype='text/event-stream')
  694. @app.route('/api/v0/react/status/<thread_id>', methods=['GET'])
  695. async def get_react_agent_status(thread_id: str):
  696. """获取React Agent执行状态,使用LangGraph API"""
  697. try:
  698. global _react_agent_instance
  699. if not _react_agent_instance:
  700. from common.result import failed
  701. return jsonify(failed(message="Agent实例未初始化", code=500)), 500
  702. # 工具状态映射
  703. TOOL_STATUS_MAPPING = {
  704. "generate_sql": {"name": "生成SQL中", "icon": "🔍"},
  705. "valid_sql": {"name": "验证SQL中", "icon": "✅"},
  706. "run_sql": {"name": "执行查询中", "icon": "⚡"},
  707. }
  708. # 使用LangGraph API获取checkpoint
  709. read_config = {"configurable": {"thread_id": thread_id}}
  710. checkpoint_tuple = await _react_agent_instance.checkpointer.aget_tuple(read_config)
  711. if not checkpoint_tuple or not checkpoint_tuple.checkpoint:
  712. from common.result import failed
  713. return jsonify(failed(message="未找到执行线程", code=404)), 404
  714. # 获取checkpoint数据
  715. checkpoint = checkpoint_tuple.checkpoint
  716. channel_values = checkpoint.get("channel_values", {})
  717. messages = channel_values.get("messages", [])
  718. if not messages:
  719. from common.result import success
  720. return jsonify(success(data={
  721. "status": "running",
  722. "name": "初始化中",
  723. "icon": "🚀",
  724. "timestamp": datetime.now().isoformat()
  725. }, message="获取状态成功"))
  726. # 分析最后一条消息确定状态
  727. last_message = messages[-1]
  728. last_msg_type = last_message.get("type", "") if hasattr(last_message, 'get') else getattr(last_message, 'type', "")
  729. # 如果last_message是对象,需要转换为字典格式
  730. if hasattr(last_message, '__dict__'):
  731. last_message_dict = {
  732. 'type': getattr(last_message, 'type', ''),
  733. 'content': getattr(last_message, 'content', ''),
  734. 'tool_calls': getattr(last_message, 'tool_calls', []) if hasattr(last_message, 'tool_calls') else [],
  735. }
  736. # 如果有additional_kwargs,也包含进来
  737. if hasattr(last_message, 'additional_kwargs'):
  738. last_message_dict.update(last_message.additional_kwargs)
  739. else:
  740. last_message_dict = last_message
  741. # 判断执行状态
  742. if (last_msg_type == "ai" and
  743. not last_message_dict.get("tool_calls", []) and
  744. last_message_dict.get("content", "").strip()):
  745. from common.result import success
  746. return jsonify(success(data={
  747. "status": "completed",
  748. "name": "完成",
  749. "icon": "✅",
  750. "timestamp": datetime.now().isoformat()
  751. }, message="获取状态成功"))
  752. elif (last_msg_type == "ai" and
  753. last_message_dict.get("tool_calls", [])):
  754. tool_calls = last_message_dict.get("tool_calls", [])
  755. tool_name = tool_calls[0].get("name", "") if tool_calls else ""
  756. tool_info = TOOL_STATUS_MAPPING.get(tool_name, {
  757. "name": f"调用{tool_name}中" if tool_name else "调用工具中",
  758. "icon": "🔧"
  759. })
  760. from common.result import success
  761. return jsonify(success(data={
  762. "status": "running",
  763. "name": tool_info["name"],
  764. "icon": tool_info["icon"],
  765. "timestamp": datetime.now().isoformat()
  766. }, message="获取状态成功"))
  767. elif last_msg_type == "tool":
  768. tool_name = last_message_dict.get("name", "")
  769. tool_status = last_message_dict.get("status", "")
  770. if tool_status == "success":
  771. tool_info = TOOL_STATUS_MAPPING.get(tool_name, {"name": "处理中", "icon": "🔄"})
  772. from common.result import success
  773. return jsonify(success(data={
  774. "status": "running",
  775. "name": f"{tool_info['name'].replace('中', '')}完成,AI处理中",
  776. "icon": "🤖",
  777. "timestamp": datetime.now().isoformat()
  778. }, message="获取状态成功"))
  779. else:
  780. tool_info = TOOL_STATUS_MAPPING.get(tool_name, {
  781. "name": f"执行{tool_name}中",
  782. "icon": "⚙️"
  783. })
  784. from common.result import success
  785. return jsonify(success(data={
  786. "status": "running",
  787. "name": tool_info["name"],
  788. "icon": tool_info["icon"],
  789. "timestamp": datetime.now().isoformat()
  790. }, message="获取状态成功"))
  791. else:
  792. from common.result import success
  793. return jsonify(success(data={
  794. "status": "running",
  795. "name": "执行中",
  796. "icon": "⚙️",
  797. "timestamp": datetime.now().isoformat()
  798. }, message="获取状态成功"))
  799. except Exception as e:
  800. from common.result import failed
  801. logger.error(f"获取React Agent状态失败: {e}")
  802. return jsonify(failed(message=f"获取状态失败: {str(e)}", code=500)), 500
  803. @app.route('/api/v0/react/direct/status/<thread_id>', methods=['GET'])
  804. async def get_react_agent_status_direct(thread_id: str):
  805. """直接访问Redis获取React Agent执行状态,绕过Agent实例资源竞争"""
  806. try:
  807. # 工具状态映射
  808. TOOL_STATUS_MAPPING = {
  809. "generate_sql": {"name": "生成SQL中", "icon": "🔍"},
  810. "valid_sql": {"name": "验证SQL中", "icon": "✅"},
  811. "run_sql": {"name": "执行查询中", "icon": "⚡"},
  812. }
  813. # 创建独立的Redis连接,不使用Agent的连接
  814. redis_client = redis.from_url("redis://localhost:6379", decode_responses=True)
  815. try:
  816. # 1. 查找该thread_id的所有checkpoint键
  817. pattern = f"checkpoint:{thread_id}:*"
  818. keys = await redis_client.keys(pattern)
  819. if not keys:
  820. from common.result import failed
  821. return jsonify(failed(message="未找到执行线程", code=404)), 404
  822. # 2. 获取最新的checkpoint键
  823. latest_key = sorted(keys)[-1]
  824. # 3. 检查Redis key的数据类型
  825. key_type = await redis_client.type(latest_key)
  826. logger.info(f"🔍 Redis key类型: {key_type}, key: {latest_key}")
  827. # 4. 根据数据类型获取checkpoint数据
  828. if key_type == "string":
  829. # 字符串类型,直接使用GET
  830. raw_checkpoint_data = await redis_client.get(latest_key)
  831. if raw_checkpoint_data:
  832. checkpoint = json.loads(raw_checkpoint_data)
  833. else:
  834. from common.result import failed
  835. return jsonify(failed(message="无法读取checkpoint数据", code=500)), 500
  836. elif key_type == "ReJSON-RL":
  837. # RedisJSON类型,使用JSON.GET命令
  838. try:
  839. # 使用execute_command执行JSON.GET
  840. checkpoint = await redis_client.execute_command("JSON.GET", latest_key)
  841. if checkpoint:
  842. # JSON.GET返回的是JSON字符串,需要解析
  843. if isinstance(checkpoint, str):
  844. checkpoint = json.loads(checkpoint)
  845. logger.info(f"✅ 成功从RedisJSON获取checkpoint数据")
  846. else:
  847. from common.result import failed
  848. return jsonify(failed(message="无法读取RedisJSON数据", code=500)), 500
  849. except Exception as json_error:
  850. logger.error(f"❌ RedisJSON操作失败: {json_error}")
  851. from common.result import failed
  852. return jsonify(failed(message=f"RedisJSON操作失败: {str(json_error)}", code=500)), 500
  853. elif key_type == "hash":
  854. # Hash类型,使用HGETALL
  855. hash_data = await redis_client.hgetall(latest_key)
  856. logger.info(f"🔍 Hash数据字段: {list(hash_data.keys())}")
  857. # 尝试不同的字段名获取checkpoint
  858. checkpoint_fields = ['checkpoint', 'data', 'value']
  859. checkpoint_data = None
  860. for field in checkpoint_fields:
  861. if field in hash_data:
  862. checkpoint_data = hash_data[field]
  863. break
  864. if not checkpoint_data:
  865. # 如果没有找到标准字段,返回整个hash结构
  866. checkpoint = {"hash_data": hash_data}
  867. else:
  868. try:
  869. checkpoint = json.loads(checkpoint_data)
  870. except json.JSONDecodeError:
  871. # 如果不是JSON,可能是其他格式
  872. checkpoint = {"raw_data": checkpoint_data}
  873. elif key_type == "list":
  874. # List类型,获取所有元素
  875. list_data = await redis_client.lrange(latest_key, 0, -1)
  876. logger.info(f"🔍 List数据长度: {len(list_data)}")
  877. checkpoint = {"list_data": list_data}
  878. else:
  879. from common.result import failed
  880. return jsonify(failed(message=f"不支持的Redis数据类型: {key_type}", code=500)), 500
  881. # 5. 提取messages
  882. messages = []
  883. # 根据不同的checkpoint结构提取messages
  884. if "checkpoint" in checkpoint and "channel_values" in checkpoint["checkpoint"]:
  885. # 标准checkpoint结构(与您的数据匹配)
  886. messages = checkpoint["checkpoint"]["channel_values"].get("messages", [])
  887. logger.info(f"✅ 从标准checkpoint结构提取到 {len(messages)} 条messages")
  888. elif "channel_values" in checkpoint:
  889. # 直接的channel_values结构
  890. messages = checkpoint["channel_values"].get("messages", [])
  891. logger.info(f"✅ 从直接channel_values结构提取到 {len(messages)} 条messages")
  892. elif "hash_data" in checkpoint:
  893. # Hash数据结构,尝试从不同字段提取
  894. hash_data = checkpoint["hash_data"]
  895. logger.info(f"🔍 Hash字段详情: {list(hash_data.keys())}")
  896. # 尝试解析可能包含messages的字段
  897. for key, value in hash_data.items():
  898. try:
  899. parsed_data = json.loads(value)
  900. if isinstance(parsed_data, dict):
  901. if "channel_values" in parsed_data and "messages" in parsed_data["channel_values"]:
  902. messages = parsed_data["channel_values"]["messages"]
  903. logger.info(f"✅ 从Hash字段 {key} 提取到 {len(messages)} 条messages")
  904. break
  905. elif "messages" in parsed_data:
  906. messages = parsed_data["messages"]
  907. logger.info(f"✅ 从Hash字段 {key} 直接提取到 {len(messages)} 条messages")
  908. break
  909. except (json.JSONDecodeError, TypeError):
  910. continue
  911. elif "list_data" in checkpoint:
  912. # List数据结构
  913. logger.info(f"🔍 List数据: {len(checkpoint['list_data'])} 个元素")
  914. # 如果无法提取messages或为空,返回初始化状态
  915. if not messages:
  916. logger.warning(f"⚠️ 无法从checkpoint中提取messages,checkpoint结构: {list(checkpoint.keys())}")
  917. status_data = {
  918. "status": "running",
  919. "name": "初始化中",
  920. "icon": "🚀",
  921. "timestamp": datetime.now().isoformat(),
  922. "debug_info": {
  923. "key_type": key_type,
  924. "checkpoint_keys": list(checkpoint.keys()),
  925. "has_checkpoint": "checkpoint" in checkpoint,
  926. "has_channel_values": "channel_values" in checkpoint.get("checkpoint", {})
  927. }
  928. }
  929. from common.result import success
  930. return jsonify(success(data=status_data, message="获取状态成功")), 200
  931. # 6. 分析最后一条消息
  932. last_message = messages[-1]
  933. last_msg_type = last_message.get("kwargs", {}).get("type", "")
  934. # 7. 判断执行状态
  935. if (last_msg_type == "ai" and
  936. not last_message.get("kwargs", {}).get("tool_calls", []) and
  937. last_message.get("kwargs", {}).get("content", "").strip()):
  938. # 完成状态:AIMessage有完整回答且无tool_calls
  939. status_data = {
  940. "status": "completed",
  941. "name": "完成",
  942. "icon": "✅",
  943. "timestamp": datetime.now().isoformat()
  944. }
  945. elif (last_msg_type == "ai" and
  946. last_message.get("kwargs", {}).get("tool_calls", [])):
  947. # AI正在调用工具
  948. tool_calls = last_message.get("kwargs", {}).get("tool_calls", [])
  949. tool_name = tool_calls[0].get("name", "") if tool_calls else ""
  950. tool_info = TOOL_STATUS_MAPPING.get(tool_name, {
  951. "name": f"调用{tool_name}中" if tool_name else "调用工具中",
  952. "icon": "🔧"
  953. })
  954. status_data = {
  955. "status": "running",
  956. "name": tool_info["name"],
  957. "icon": tool_info["icon"],
  958. "timestamp": datetime.now().isoformat()
  959. }
  960. elif last_msg_type == "tool":
  961. # 工具执行完成,等待AI处理
  962. tool_name = last_message.get("kwargs", {}).get("name", "")
  963. tool_status = last_message.get("kwargs", {}).get("status", "")
  964. if tool_status == "success":
  965. tool_info = TOOL_STATUS_MAPPING.get(tool_name, {"name": "处理中", "icon": "🔄"})
  966. status_data = {
  967. "status": "running",
  968. "name": f"{tool_info['name'].replace('中', '')}完成,AI处理中",
  969. "icon": "🤖",
  970. "timestamp": datetime.now().isoformat()
  971. }
  972. else:
  973. tool_info = TOOL_STATUS_MAPPING.get(tool_name, {
  974. "name": f"执行{tool_name}中",
  975. "icon": "⚙️"
  976. })
  977. status_data = {
  978. "status": "running",
  979. "name": tool_info["name"],
  980. "icon": tool_info["icon"],
  981. "timestamp": datetime.now().isoformat()
  982. }
  983. elif last_msg_type == "human":
  984. # 用户刚提问,AI开始思考
  985. status_data = {
  986. "status": "running",
  987. "name": "AI思考中",
  988. "icon": "🤖",
  989. "timestamp": datetime.now().isoformat()
  990. }
  991. else:
  992. # 默认执行中状态
  993. status_data = {
  994. "status": "running",
  995. "name": "执行中",
  996. "icon": "⚙️",
  997. "timestamp": datetime.now().isoformat()
  998. }
  999. from common.result import success
  1000. return jsonify(success(data=status_data, message="获取状态成功")), 200
  1001. finally:
  1002. await redis_client.aclose()
  1003. except Exception as e:
  1004. logger.error(f"获取React Agent状态失败: {e}")
  1005. from common.result import failed
  1006. return jsonify(failed(message=f"获取状态失败: {str(e)}", code=500)), 500
  1007. # ==================== LangGraph Agent API ====================
  1008. # 全局Agent实例(单例模式)
  1009. citu_langraph_agent = None
  1010. def get_citu_langraph_agent():
  1011. """获取LangGraph Agent实例(懒加载)"""
  1012. global citu_langraph_agent
  1013. if citu_langraph_agent is None:
  1014. try:
  1015. from agent.citu_agent import CituLangGraphAgent
  1016. logger.info("开始创建LangGraph Agent实例...")
  1017. citu_langraph_agent = CituLangGraphAgent()
  1018. logger.info("LangGraph Agent实例创建成功")
  1019. except ImportError as e:
  1020. logger.critical(f"Agent模块导入失败: {str(e)}")
  1021. raise Exception(f"Agent模块导入失败: {str(e)}")
  1022. except Exception as e:
  1023. logger.critical(f"LangGraph Agent实例创建失败: {str(e)}")
  1024. raise Exception(f"Agent初始化失败: {str(e)}")
  1025. return citu_langraph_agent
  1026. @app.route('/api/v0/ask_agent', methods=['POST'])
  1027. def ask_agent():
  1028. """支持对话上下文的ask_agent API"""
  1029. req = request.get_json(force=True)
  1030. question = req.get("question", None)
  1031. user_id_input = req.get("user_id", None)
  1032. conversation_id_input = req.get("conversation_id", None)
  1033. continue_conversation = req.get("continue_conversation", False)
  1034. api_routing_mode = req.get("routing_mode", None)
  1035. VALID_ROUTING_MODES = ["database_direct", "chat_direct", "hybrid", "llm_only"]
  1036. if not question:
  1037. return jsonify(bad_request_response(
  1038. response_text="缺少必需参数:question",
  1039. missing_params=["question"]
  1040. )), 400
  1041. if api_routing_mode and api_routing_mode not in VALID_ROUTING_MODES:
  1042. return jsonify(bad_request_response(
  1043. response_text=f"无效的routing_mode参数值: {api_routing_mode},支持的值: {VALID_ROUTING_MODES}",
  1044. invalid_params=["routing_mode"]
  1045. )), 400
  1046. try:
  1047. # 获取登录用户ID
  1048. login_user_id = session.get('user_id') if 'user_id' in session else None
  1049. # 用户ID和对话ID一致性校验
  1050. from common.session_aware_cache import ConversationAwareMemoryCache
  1051. # 如果传递了conversation_id,从中解析user_id
  1052. extracted_user_id = None
  1053. if conversation_id_input:
  1054. extracted_user_id = ConversationAwareMemoryCache.extract_user_id(conversation_id_input)
  1055. # 如果同时传递了user_id和conversation_id,进行一致性校验
  1056. if user_id_input:
  1057. is_valid, error_msg = ConversationAwareMemoryCache.validate_user_id_consistency(
  1058. conversation_id_input, user_id_input
  1059. )
  1060. if not is_valid:
  1061. return jsonify(bad_request_response(
  1062. response_text=error_msg,
  1063. invalid_params=["user_id", "conversation_id"]
  1064. )), 400
  1065. # 如果没有传递user_id,但有conversation_id,则从conversation_id中解析
  1066. elif not user_id_input and extracted_user_id:
  1067. user_id_input = extracted_user_id
  1068. logger.info(f"从conversation_id解析出user_id: {user_id_input}")
  1069. # 如果没有传递user_id,使用默认值guest
  1070. if not user_id_input:
  1071. user_id_input = "guest"
  1072. logger.info("未传递user_id,使用默认值: guest")
  1073. # 智能ID解析
  1074. user_id = redis_conversation_manager.resolve_user_id(
  1075. user_id_input, None, request.remote_addr, login_user_id
  1076. )
  1077. conversation_id, conversation_status = redis_conversation_manager.resolve_conversation_id(
  1078. user_id, conversation_id_input, continue_conversation
  1079. )
  1080. # 获取上下文和上下文类型(提前到缓存检查之前)
  1081. context = redis_conversation_manager.get_context(conversation_id)
  1082. # 获取上下文类型:从最后一条助手消息的metadata中获取类型
  1083. context_type = None
  1084. if context:
  1085. try:
  1086. # 获取最后一条助手消息的metadata
  1087. messages = redis_conversation_manager.get_conversation_messages(conversation_id, limit=10)
  1088. for message in reversed(messages): # 从最新的开始找
  1089. if message.get("role") == "assistant":
  1090. metadata = message.get("metadata", {})
  1091. context_type = metadata.get("type")
  1092. if context_type:
  1093. logger.info(f"[AGENT_API] 检测到上下文类型: {context_type}")
  1094. break
  1095. except Exception as e:
  1096. logger.warning(f"获取上下文类型失败: {str(e)}")
  1097. # 检查缓存(新逻辑:放宽使用条件,严控存储条件)
  1098. cached_answer = redis_conversation_manager.get_cached_answer(question, context)
  1099. if cached_answer:
  1100. logger.info(f"[AGENT_API] 使用缓存答案")
  1101. # 确定缓存答案的助手回复内容(使用与非缓存相同的优先级逻辑)
  1102. cached_response_type = cached_answer.get("type", "UNKNOWN")
  1103. if cached_response_type == "DATABASE":
  1104. # DATABASE类型:按优先级选择内容
  1105. if cached_answer.get("response"):
  1106. # 优先级1:错误或解释性回复(如SQL生成失败)
  1107. assistant_response = cached_answer.get("response")
  1108. elif cached_answer.get("summary"):
  1109. # 优先级2:查询成功的摘要
  1110. assistant_response = cached_answer.get("summary")
  1111. elif cached_answer.get("query_result"):
  1112. # 优先级3:构造简单描述
  1113. query_result = cached_answer.get("query_result")
  1114. row_count = query_result.get("row_count", 0)
  1115. assistant_response = f"查询执行完成,共返回 {row_count} 条记录。"
  1116. else:
  1117. # 异常情况
  1118. assistant_response = "数据库查询已处理。"
  1119. else:
  1120. # CHAT类型:直接使用response
  1121. assistant_response = cached_answer.get("response", "")
  1122. # 更新对话历史
  1123. redis_conversation_manager.save_message(conversation_id, "user", question)
  1124. redis_conversation_manager.save_message(
  1125. conversation_id, "assistant",
  1126. assistant_response,
  1127. metadata={"from_cache": True}
  1128. )
  1129. # 添加对话信息到缓存结果
  1130. cached_answer["conversation_id"] = conversation_id
  1131. cached_answer["user_id"] = user_id
  1132. cached_answer["from_cache"] = True
  1133. cached_answer.update(conversation_status)
  1134. # 使用agent_success_response返回标准格式
  1135. return jsonify(agent_success_response(
  1136. response_type=cached_answer.get("type", "UNKNOWN"),
  1137. response=cached_answer.get("response", ""),
  1138. sql=cached_answer.get("sql"),
  1139. records=cached_answer.get("query_result"),
  1140. summary=cached_answer.get("summary"),
  1141. conversation_id=conversation_id,
  1142. execution_path=cached_answer.get("execution_path", []),
  1143. classification_info=cached_answer.get("classification_info", {}),
  1144. user_id=user_id,
  1145. context_used=bool(context),
  1146. from_cache=True,
  1147. conversation_status=conversation_status["status"],
  1148. requested_conversation_id=conversation_status.get("requested_id")
  1149. ))
  1150. # 保存用户消息
  1151. redis_conversation_manager.save_message(conversation_id, "user", question)
  1152. # 构建带上下文的问题
  1153. if context:
  1154. enhanced_question = f"\n[CONTEXT]\n{context}\n\n[CURRENT]\n{question}"
  1155. logger.info(f"[AGENT_API] 使用上下文,长度: {len(context)}字符")
  1156. else:
  1157. enhanced_question = question
  1158. logger.info(f"[AGENT_API] 新对话,无上下文")
  1159. # 确定最终使用的路由模式(优先级逻辑)
  1160. if api_routing_mode:
  1161. # API传了参数,优先使用
  1162. effective_routing_mode = api_routing_mode
  1163. logger.info(f"[AGENT_API] 使用API指定的路由模式: {effective_routing_mode}")
  1164. else:
  1165. # API没传参数,使用配置文件
  1166. try:
  1167. from app_config import QUESTION_ROUTING_MODE
  1168. effective_routing_mode = QUESTION_ROUTING_MODE
  1169. logger.info(f"[AGENT_API] 使用配置文件路由模式: {effective_routing_mode}")
  1170. except ImportError:
  1171. effective_routing_mode = "hybrid"
  1172. logger.info(f"[AGENT_API] 配置文件读取失败,使用默认路由模式: {effective_routing_mode}")
  1173. # Agent处理
  1174. try:
  1175. agent = get_citu_langraph_agent()
  1176. except Exception as e:
  1177. logger.critical(f"Agent初始化失败: {str(e)}")
  1178. return jsonify(service_unavailable_response(
  1179. response_text="AI服务暂时不可用,请稍后重试",
  1180. can_retry=True
  1181. )), 503
  1182. # 异步调用Agent处理问题
  1183. import asyncio
  1184. agent_result = asyncio.run(agent.process_question(
  1185. question=enhanced_question, # 使用增强后的问题
  1186. conversation_id=conversation_id,
  1187. context_type=context_type, # 传递上下文类型
  1188. routing_mode=effective_routing_mode # 新增:传递路由模式
  1189. ))
  1190. # 处理Agent结果
  1191. if agent_result.get("success", False):
  1192. response_type = agent_result.get("type", "UNKNOWN")
  1193. response_text = agent_result.get("response", "")
  1194. sql = agent_result.get("sql")
  1195. query_result = agent_result.get("query_result")
  1196. summary = agent_result.get("summary")
  1197. execution_path = agent_result.get("execution_path", [])
  1198. classification_info = agent_result.get("classification_info", {})
  1199. # 确定助手回复内容的优先级
  1200. if response_type == "DATABASE":
  1201. if response_text:
  1202. assistant_response = response_text
  1203. elif summary:
  1204. assistant_response = summary
  1205. elif query_result:
  1206. row_count = query_result.get("row_count", 0)
  1207. assistant_response = f"查询执行完成,共返回 {row_count} 条记录。"
  1208. else:
  1209. assistant_response = "数据库查询已处理。"
  1210. else:
  1211. assistant_response = response_text
  1212. # 保存助手回复
  1213. redis_conversation_manager.save_message(
  1214. conversation_id, "assistant", assistant_response,
  1215. metadata={
  1216. "type": response_type,
  1217. "sql": sql,
  1218. "execution_path": execution_path
  1219. }
  1220. )
  1221. # 缓存成功的答案(新逻辑:只缓存无上下文的问答)
  1222. # 直接缓存agent_result,它已经包含所有需要的字段
  1223. redis_conversation_manager.cache_answer(question, agent_result, context)
  1224. # 使用agent_success_response的正确方式
  1225. return jsonify(agent_success_response(
  1226. response_type=response_type,
  1227. response=response_text,
  1228. sql=sql,
  1229. records=query_result,
  1230. summary=summary,
  1231. conversation_id=conversation_id,
  1232. execution_path=execution_path,
  1233. classification_info=classification_info,
  1234. user_id=user_id,
  1235. context_used=bool(context),
  1236. from_cache=False,
  1237. conversation_status=conversation_status["status"],
  1238. requested_conversation_id=conversation_status.get("requested_id"),
  1239. routing_mode_used=effective_routing_mode, # 新增:实际使用的路由模式
  1240. routing_mode_source="api" if api_routing_mode else "config" # 新增:路由模式来源
  1241. ))
  1242. else:
  1243. # 错误处理
  1244. error_message = agent_result.get("error", "Agent处理失败")
  1245. error_code = agent_result.get("error_code", 500)
  1246. return jsonify(agent_error_response(
  1247. response_text=error_message,
  1248. error_type="agent_processing_failed",
  1249. code=error_code,
  1250. conversation_id=conversation_id,
  1251. user_id=user_id
  1252. )), error_code
  1253. except Exception as e:
  1254. logger.error(f"ask_agent执行失败: {str(e)}")
  1255. return jsonify(internal_error_response(
  1256. response_text="查询处理失败,请稍后重试"
  1257. )), 500
  1258. @app.route('/api/v0/ask_agent_stream', methods=['GET'])
  1259. def ask_agent_stream():
  1260. """Citu Agent 流式API - 支持实时进度显示(EventSource只支持GET请求)
  1261. 功能与ask_agent完全相同,除了采用流式输出"""
  1262. def generate():
  1263. try:
  1264. # 从URL参数获取数据(EventSource只支持GET请求)
  1265. question = request.args.get('question')
  1266. user_id_input = request.args.get('user_id')
  1267. conversation_id_input = request.args.get('conversation_id')
  1268. continue_conversation = request.args.get('continue_conversation', 'false').lower() == 'true'
  1269. api_routing_mode = request.args.get('routing_mode')
  1270. VALID_ROUTING_MODES = ["database_direct", "chat_direct", "hybrid", "llm_only"]
  1271. # 参数验证
  1272. if not question:
  1273. yield format_sse_error("缺少必需参数:question")
  1274. return
  1275. if api_routing_mode and api_routing_mode not in VALID_ROUTING_MODES:
  1276. yield format_sse_error(f"无效的routing_mode参数值: {api_routing_mode}")
  1277. return
  1278. # 🆕 用户ID和对话ID一致性校验(与ask_agent相同)
  1279. try:
  1280. # 获取登录用户ID
  1281. login_user_id = session.get('user_id') if 'user_id' in session else None
  1282. # 用户ID和对话ID一致性校验
  1283. from common.session_aware_cache import ConversationAwareMemoryCache
  1284. # 如果传递了conversation_id,从中解析user_id
  1285. extracted_user_id = None
  1286. if conversation_id_input:
  1287. extracted_user_id = ConversationAwareMemoryCache.extract_user_id(conversation_id_input)
  1288. # 如果同时传递了user_id和conversation_id,进行一致性校验
  1289. if user_id_input:
  1290. is_valid, error_msg = ConversationAwareMemoryCache.validate_user_id_consistency(
  1291. conversation_id_input, user_id_input
  1292. )
  1293. if not is_valid:
  1294. yield format_sse_error(error_msg)
  1295. return
  1296. # 如果没有传递user_id,但有conversation_id,则从conversation_id中解析
  1297. elif not user_id_input and extracted_user_id:
  1298. user_id_input = extracted_user_id
  1299. logger.info(f"从conversation_id解析出user_id: {user_id_input}")
  1300. # 如果没有传递user_id,使用默认值guest
  1301. if not user_id_input:
  1302. user_id_input = "guest"
  1303. logger.info("未传递user_id,使用默认值: guest")
  1304. # 🆕 智能ID解析(与ask_agent相同)
  1305. user_id = redis_conversation_manager.resolve_user_id(
  1306. user_id_input, None, request.remote_addr, login_user_id
  1307. )
  1308. conversation_id, conversation_status = redis_conversation_manager.resolve_conversation_id(
  1309. user_id, conversation_id_input, continue_conversation
  1310. )
  1311. except Exception as e:
  1312. logger.error(f"用户ID和对话ID解析失败: {str(e)}")
  1313. yield format_sse_error(f"参数解析失败: {str(e)}")
  1314. return
  1315. logger.info(f"[STREAM_API] 收到请求 - 问题: {question[:50]}..., 用户: {user_id}, 对话: {conversation_id}")
  1316. # 🆕 获取上下文和上下文类型(与ask_agent相同)
  1317. context = redis_conversation_manager.get_context(conversation_id)
  1318. # 获取上下文类型:从最后一条助手消息的metadata中获取类型
  1319. context_type = None
  1320. if context:
  1321. try:
  1322. # 获取最后一条助手消息的metadata
  1323. messages = redis_conversation_manager.get_conversation_messages(conversation_id, limit=10)
  1324. for message in reversed(messages): # 从最新的开始找
  1325. if message.get("role") == "assistant":
  1326. metadata = message.get("metadata", {})
  1327. context_type = metadata.get("type")
  1328. if context_type:
  1329. logger.info(f"[STREAM_API] 检测到上下文类型: {context_type}")
  1330. break
  1331. except Exception as e:
  1332. logger.warning(f"获取上下文类型失败: {str(e)}")
  1333. # 🆕 检查缓存(与ask_agent相同)
  1334. cached_answer = redis_conversation_manager.get_cached_answer(question, context)
  1335. if cached_answer:
  1336. logger.info(f"[STREAM_API] 使用缓存答案")
  1337. # 确定缓存答案的助手回复内容(使用与非缓存相同的优先级逻辑)
  1338. cached_response_type = cached_answer.get("type", "UNKNOWN")
  1339. if cached_response_type == "DATABASE":
  1340. # DATABASE类型:按优先级选择内容
  1341. if cached_answer.get("response"):
  1342. assistant_response = cached_answer.get("response")
  1343. elif cached_answer.get("summary"):
  1344. assistant_response = cached_answer.get("summary")
  1345. elif cached_answer.get("query_result"):
  1346. query_result = cached_answer.get("query_result")
  1347. row_count = query_result.get("row_count", 0)
  1348. assistant_response = f"查询执行完成,共返回 {row_count} 条记录。"
  1349. else:
  1350. assistant_response = "查询处理完成"
  1351. else:
  1352. assistant_response = cached_answer.get("response", "处理完成")
  1353. # 返回缓存结果的SSE格式
  1354. yield format_sse_completed({
  1355. "type": "completed",
  1356. "result": {
  1357. "success": True,
  1358. "type": cached_response_type,
  1359. "response": assistant_response,
  1360. "sql": cached_answer.get("sql"),
  1361. "query_result": cached_answer.get("query_result"),
  1362. "summary": cached_answer.get("summary"),
  1363. "conversation_id": conversation_id,
  1364. "execution_path": cached_answer.get("execution_path", []),
  1365. "classification_info": cached_answer.get("classification_info", {}),
  1366. "user_id": user_id,
  1367. "context_used": bool(context),
  1368. "from_cache": True,
  1369. "conversation_status": conversation_status["status"],
  1370. "requested_conversation_id": conversation_status.get("requested_id")
  1371. }
  1372. })
  1373. return
  1374. # 🆕 保存用户消息(与ask_agent相同)
  1375. redis_conversation_manager.save_message(conversation_id, "user", question)
  1376. # 🆕 构建带上下文的问题(与ask_agent相同)
  1377. if context:
  1378. enhanced_question = f"\n[CONTEXT]\n{context}\n\n[CURRENT]\n{question}"
  1379. logger.info(f"[STREAM_API] 使用上下文,长度: {len(context)}字符")
  1380. else:
  1381. enhanced_question = question
  1382. logger.info(f"[STREAM_API] 新对话,无上下文")
  1383. # 获取Agent实例
  1384. try:
  1385. agent = get_citu_langraph_agent()
  1386. if not agent:
  1387. yield format_sse_error("Agent实例获取失败")
  1388. return
  1389. # 检查是否有process_question_stream方法
  1390. if not hasattr(agent, 'process_question_stream'):
  1391. yield format_sse_error("Agent不支持流式处理")
  1392. return
  1393. except Exception as e:
  1394. logger.error(f"Agent初始化失败: {str(e)}")
  1395. yield format_sse_error("AI服务暂时不可用,请稍后重试")
  1396. return
  1397. # 🆕 确定最终使用的路由模式(与ask_agent相同)
  1398. if api_routing_mode:
  1399. # API传了参数,优先使用
  1400. effective_routing_mode = api_routing_mode
  1401. logger.info(f"[STREAM_API] 使用API指定的路由模式: {effective_routing_mode}")
  1402. else:
  1403. # API没传参数,使用配置文件
  1404. try:
  1405. from app_config import QUESTION_ROUTING_MODE
  1406. effective_routing_mode = QUESTION_ROUTING_MODE
  1407. logger.info(f"[STREAM_API] 使用配置文件路由模式: {effective_routing_mode}")
  1408. except ImportError:
  1409. effective_routing_mode = "hybrid"
  1410. logger.info(f"[STREAM_API] 配置文件读取失败,使用默认路由模式: {effective_routing_mode}")
  1411. # 流式处理 - 实时转发
  1412. try:
  1413. import asyncio
  1414. # 获取当前事件循环,如果没有则创建新的
  1415. try:
  1416. loop = asyncio.get_event_loop()
  1417. except RuntimeError:
  1418. loop = asyncio.new_event_loop()
  1419. asyncio.set_event_loop(loop)
  1420. # 用于收集最终结果,以便保存到Redis
  1421. final_result = None
  1422. # 异步生成器,实时yield数据
  1423. async def stream_generator():
  1424. nonlocal final_result
  1425. try:
  1426. async for chunk in agent.process_question_stream(
  1427. question=enhanced_question, # 🆕 使用增强后的问题
  1428. user_id=user_id,
  1429. conversation_id=conversation_id,
  1430. context_type=context_type, # 🆕 传递上下文类型
  1431. routing_mode=effective_routing_mode
  1432. ):
  1433. # 如果是完成的chunk,保存最终结果
  1434. if chunk.get("type") == "completed":
  1435. final_result = chunk.get("result")
  1436. yield chunk
  1437. except Exception as e:
  1438. logger.error(f"流式处理异常: {str(e)}")
  1439. yield {"type": "error", "error": str(e)}
  1440. # 同步包装器,实时转发数据
  1441. def sync_stream_wrapper():
  1442. # 创建异步任务
  1443. async_gen = stream_generator()
  1444. while True:
  1445. try:
  1446. # 获取下一个chunk
  1447. chunk = loop.run_until_complete(async_gen.__anext__())
  1448. if chunk["type"] == "progress":
  1449. yield format_sse_progress(chunk)
  1450. elif chunk["type"] == "completed":
  1451. yield format_sse_completed(chunk)
  1452. break # 完成后退出循环
  1453. elif chunk["type"] == "error":
  1454. yield format_sse_error(chunk.get("error", "未知错误"))
  1455. break # 错误后退出循环
  1456. except StopAsyncIteration:
  1457. # 异步生成器结束
  1458. break
  1459. except Exception as e:
  1460. logger.error(f"流式转发异常: {str(e)}")
  1461. yield format_sse_error(f"流式处理异常: {str(e)}")
  1462. break
  1463. # 返回同步生成器
  1464. yield from sync_stream_wrapper()
  1465. # 🆕 保存助手消息和缓存结果(与ask_agent相同)
  1466. if final_result and final_result.get("success", False):
  1467. try:
  1468. response_type = final_result.get("type", "UNKNOWN")
  1469. response_text = final_result.get("response", "")
  1470. sql = final_result.get("sql")
  1471. query_result = final_result.get("query_result")
  1472. summary = final_result.get("summary")
  1473. execution_path = final_result.get("execution_path", [])
  1474. classification_info = final_result.get("classification_info", {})
  1475. # 确定助手回复内容的优先级
  1476. if response_type == "DATABASE":
  1477. if response_text:
  1478. assistant_response = response_text
  1479. elif summary:
  1480. assistant_response = summary
  1481. elif query_result:
  1482. row_count = query_result.get("row_count", 0)
  1483. assistant_response = f"查询执行完成,共返回 {row_count} 条记录。"
  1484. else:
  1485. assistant_response = "查询处理完成"
  1486. else:
  1487. assistant_response = response_text or "处理完成"
  1488. # 保存助手消息
  1489. metadata = {
  1490. "type": response_type,
  1491. "sql": sql,
  1492. "execution_path": execution_path,
  1493. "classification_info": classification_info
  1494. }
  1495. redis_conversation_manager.save_message(
  1496. conversation_id, "assistant", assistant_response, metadata
  1497. )
  1498. # 缓存结果(仅缓存成功的结果)- 与ask_agent相同的调用方式
  1499. redis_conversation_manager.cache_answer(question, final_result, context)
  1500. logger.info(f"[STREAM_API] 结果已缓存")
  1501. except Exception as e:
  1502. logger.error(f"保存结果和缓存失败: {str(e)}")
  1503. except Exception as e:
  1504. logger.error(f"流式处理异常: {str(e)}")
  1505. import traceback
  1506. traceback.print_exc()
  1507. yield format_sse_error(f"处理异常: {str(e)}")
  1508. except Exception as e:
  1509. logger.error(f"流式API异常: {str(e)}")
  1510. yield format_sse_error(f"服务异常: {str(e)}")
  1511. return Response(stream_with_context(generate()), mimetype='text/event-stream')
  1512. def format_sse_progress(chunk: dict) -> str:
  1513. """格式化进度事件为SSE格式"""
  1514. progress = chunk.get("progress", {})
  1515. node = chunk.get("node")
  1516. # 🆕 特殊处理:格式化响应节点的显示内容
  1517. if node == "format_response":
  1518. display_name = "格式化响应结果"
  1519. message = "正在执行:格式化响应结果"
  1520. else:
  1521. display_name = progress.get("display_name")
  1522. message = f"正在执行: {progress.get('display_name', '处理中')}"
  1523. data = {
  1524. "code": 200,
  1525. "success": True,
  1526. "message": message,
  1527. "data": {
  1528. "type": "progress",
  1529. "node": node,
  1530. "display_name": display_name,
  1531. # 🆕 删除icon字段
  1532. "details": progress.get("details"),
  1533. "sub_status": progress.get("sub_status"),
  1534. "conversation_id": chunk.get("conversation_id"),
  1535. "timestamp": datetime.now().isoformat()
  1536. }
  1537. }
  1538. import json
  1539. return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
  1540. def format_sse_completed(chunk: dict) -> str:
  1541. """格式化完成事件为SSE格式"""
  1542. result = chunk.get("result", {})
  1543. data = {
  1544. "code": 200,
  1545. "success": True,
  1546. "message": "处理完成",
  1547. "data": {
  1548. "type": "completed",
  1549. "response": result.get("response", ""),
  1550. "response_type": result.get("type", "UNKNOWN"),
  1551. "sql": result.get("sql"),
  1552. "query_result": result.get("query_result"),
  1553. "summary": result.get("summary"),
  1554. "conversation_id": chunk.get("conversation_id"),
  1555. "execution_path": result.get("execution_path", []),
  1556. "classification_info": result.get("classification_info", {}),
  1557. "timestamp": datetime.now().isoformat()
  1558. }
  1559. }
  1560. import json
  1561. return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
  1562. def format_sse_react_progress(chunk: dict) -> str:
  1563. """格式化React Agent进度事件为SSE格式"""
  1564. node = chunk.get("node")
  1565. thread_id = chunk.get("thread_id")
  1566. node_data = chunk.get("data", {})
  1567. # 基础节点显示名称映射
  1568. base_node_display_map = {
  1569. "__start__": "开始处理",
  1570. "trim_messages": "检查问题上下文",
  1571. "agent": "AI分析中",
  1572. "prepare_tool_input": "准备工具参数",
  1573. "tools": "执行操作",
  1574. "update_state_after_tool": "检查结果",
  1575. "format_final_response": "生成回答",
  1576. "__end__": "完成"
  1577. }
  1578. # 工具名称映射
  1579. tool_display_map = {
  1580. "generate_sql": "生成SQL语句",
  1581. "valid_sql": "验证SQL语法",
  1582. "run_sql": "执行SQL查询"
  1583. }
  1584. display_name = base_node_display_map.get(node, "处理中")
  1585. tool_name = None
  1586. # 特殊处理:提取具体工具信息
  1587. if node_data and "messages" in node_data and node_data["messages"]:
  1588. messages = node_data["messages"]
  1589. last_message = messages[-1]
  1590. # 方法1:从 AIMessage 的 tool_calls 中提取(agent节点输出)
  1591. if (hasattr(last_message, 'tool_calls') and
  1592. last_message.tool_calls and
  1593. len(last_message.tool_calls) > 0):
  1594. tool_call = last_message.tool_calls[0]
  1595. tool_name = tool_call.get('name')
  1596. # 方法2:从 ToolMessage 的 name 属性中提取(tools节点输出)
  1597. elif (hasattr(last_message, 'name') and
  1598. last_message.name):
  1599. tool_name = last_message.name
  1600. # 根据节点和工具信息生成更精确的显示名称
  1601. if tool_name and tool_name in tool_display_map:
  1602. if node == "agent":
  1603. display_name = f"准备{tool_display_map[tool_name]}"
  1604. elif node == "tools":
  1605. display_name = tool_display_map[tool_name] # 去掉"正在"前缀
  1606. elif node == "update_state_after_tool":
  1607. display_name = "检查结果" # 统一为"检查结果"
  1608. elif node == "prepare_tool_input":
  1609. display_name = "准备工具参数" # 统一为通用描述
  1610. # 构建响应数据
  1611. data = {
  1612. "code": 200,
  1613. "success": True,
  1614. "message": f"正在执行: {display_name}",
  1615. "data": {
  1616. "type": "progress",
  1617. "node": node,
  1618. "display_name": display_name,
  1619. "thread_id": thread_id,
  1620. "timestamp": datetime.now().isoformat()
  1621. }
  1622. }
  1623. # 可选:在调试模式下添加工具信息
  1624. if tool_name:
  1625. data["data"]["tool_name"] = tool_name
  1626. import json
  1627. return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
  1628. def format_sse_react_completed(chunk: dict) -> str:
  1629. """格式化React Agent完成事件为SSE格式"""
  1630. result = chunk.get("result", {})
  1631. api_data = result.get("api_data", {})
  1632. thread_id = result.get("thread_id")
  1633. # 构建与ask_react_agent相同的响应格式
  1634. response_data = {
  1635. "response": api_data.get("response", ""),
  1636. "conversation_id": thread_id,
  1637. "user_id": api_data.get("react_agent_meta", {}).get("user_id", ""),
  1638. "react_agent_meta": api_data.get("react_agent_meta", {
  1639. "thread_id": thread_id,
  1640. "agent_version": "custom_react_v1_async"
  1641. }),
  1642. "timestamp": datetime.now().isoformat()
  1643. }
  1644. # 可选字段
  1645. if "sql" in api_data:
  1646. response_data["sql"] = api_data["sql"]
  1647. if "records" in api_data:
  1648. response_data["records"] = api_data["records"]
  1649. data = {
  1650. "code": 200,
  1651. "success": True,
  1652. "message": "处理完成",
  1653. "data": {
  1654. "type": "completed",
  1655. **response_data
  1656. }
  1657. }
  1658. import json
  1659. return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
  1660. def format_sse_error(error_message: str) -> str:
  1661. """格式化错误事件为SSE格式"""
  1662. data = {
  1663. "code": 500,
  1664. "success": False,
  1665. "message": "处理失败",
  1666. "data": {
  1667. "type": "error",
  1668. "error": error_message,
  1669. "timestamp": datetime.now().isoformat()
  1670. }
  1671. }
  1672. import json
  1673. return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
  1674. def format_sse_data(data: dict) -> str:
  1675. """格式化普通数据事件为SSE格式"""
  1676. import json
  1677. return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
  1678. # ==================== QA反馈系统API ====================
  1679. qa_feedback_manager = None
  1680. def get_qa_feedback_manager():
  1681. """获取QA反馈管理器实例(懒加载)"""
  1682. global qa_feedback_manager
  1683. if qa_feedback_manager is None:
  1684. try:
  1685. qa_feedback_manager = QAFeedbackManager(vanna_instance=vn)
  1686. logger.info("QA反馈管理器实例创建成功")
  1687. except Exception as e:
  1688. logger.critical(f"QA反馈管理器创建失败: {str(e)}")
  1689. raise Exception(f"QA反馈管理器初始化失败: {str(e)}")
  1690. return qa_feedback_manager
  1691. @app.route('/api/v0/qa_feedback/query', methods=['POST'])
  1692. def qa_feedback_query():
  1693. """
  1694. 查询反馈记录API
  1695. 支持分页、筛选和排序功能
  1696. """
  1697. try:
  1698. req = request.get_json(force=True)
  1699. # 解析参数,设置默认值
  1700. page = req.get('page', 1)
  1701. page_size = req.get('page_size', 20)
  1702. is_thumb_up = req.get('is_thumb_up')
  1703. create_time_start = req.get('create_time_start')
  1704. create_time_end = req.get('create_time_end')
  1705. is_in_training_data = req.get('is_in_training_data')
  1706. sort_by = req.get('sort_by', 'create_time')
  1707. sort_order = req.get('sort_order', 'desc')
  1708. # 参数验证
  1709. if page < 1:
  1710. return jsonify(bad_request_response(
  1711. response_text="页码必须大于0",
  1712. invalid_params=["page"]
  1713. )), 400
  1714. if page_size < 1 or page_size > 100:
  1715. return jsonify(bad_request_response(
  1716. response_text="每页大小必须在1-100之间",
  1717. invalid_params=["page_size"]
  1718. )), 400
  1719. # 获取反馈管理器并查询
  1720. manager = get_qa_feedback_manager()
  1721. records, total = manager.query_feedback(
  1722. page=page,
  1723. page_size=page_size,
  1724. is_thumb_up=is_thumb_up,
  1725. create_time_start=create_time_start,
  1726. create_time_end=create_time_end,
  1727. is_in_training_data=is_in_training_data,
  1728. sort_by=sort_by,
  1729. sort_order=sort_order
  1730. )
  1731. total_pages = (total + page_size - 1) // page_size
  1732. return jsonify(success_response(
  1733. response_text=f"查询成功,共找到 {total} 条记录",
  1734. data={
  1735. "records": records,
  1736. "pagination": {
  1737. "page": page,
  1738. "page_size": page_size,
  1739. "total": total,
  1740. "total_pages": total_pages,
  1741. "has_next": page < total_pages,
  1742. "has_prev": page > 1
  1743. }
  1744. }
  1745. ))
  1746. except Exception as e:
  1747. logger.error(f"qa_feedback_query执行失败: {str(e)}")
  1748. return jsonify(internal_error_response(
  1749. response_text="查询反馈记录失败,请稍后重试"
  1750. )), 500
  1751. @app.route('/api/v0/qa_feedback/delete/<int:feedback_id>', methods=['DELETE'])
  1752. def qa_feedback_delete(feedback_id):
  1753. """删除反馈记录API"""
  1754. try:
  1755. manager = get_qa_feedback_manager()
  1756. success = manager.delete_feedback(feedback_id)
  1757. if success:
  1758. return jsonify(success_response(
  1759. response_text=f"反馈记录删除成功",
  1760. data={"deleted_id": feedback_id}
  1761. ))
  1762. else:
  1763. return jsonify(not_found_response(
  1764. response_text=f"反馈记录不存在 (ID: {feedback_id})"
  1765. )), 404
  1766. except Exception as e:
  1767. logger.error(f"qa_feedback_delete执行失败: {str(e)}")
  1768. return jsonify(internal_error_response(
  1769. response_text="删除反馈记录失败,请稍后重试"
  1770. )), 500
  1771. @app.route('/api/v0/qa_feedback/update/<int:feedback_id>', methods=['PUT'])
  1772. def qa_feedback_update(feedback_id):
  1773. """更新反馈记录API"""
  1774. try:
  1775. req = request.get_json(force=True)
  1776. allowed_fields = ['question', 'sql', 'is_thumb_up', 'user_id', 'is_in_training_data']
  1777. update_data = {}
  1778. for field in allowed_fields:
  1779. if field in req:
  1780. update_data[field] = req[field]
  1781. if not update_data:
  1782. return jsonify(bad_request_response(
  1783. response_text="没有提供有效的更新字段",
  1784. missing_params=allowed_fields
  1785. )), 400
  1786. manager = get_qa_feedback_manager()
  1787. success = manager.update_feedback(feedback_id, **update_data)
  1788. if success:
  1789. return jsonify(success_response(
  1790. response_text="反馈记录更新成功",
  1791. data={
  1792. "updated_id": feedback_id,
  1793. "updated_fields": list(update_data.keys())
  1794. }
  1795. ))
  1796. else:
  1797. return jsonify(not_found_response(
  1798. response_text=f"反馈记录不存在或无变化 (ID: {feedback_id})"
  1799. )), 404
  1800. except Exception as e:
  1801. logger.error(f"qa_feedback_update执行失败: {str(e)}")
  1802. return jsonify(internal_error_response(
  1803. response_text="更新反馈记录失败,请稍后重试"
  1804. )), 500
  1805. @app.route('/api/v0/qa_feedback/add_to_training', methods=['POST'])
  1806. def qa_feedback_add_to_training():
  1807. """
  1808. 将反馈记录添加到训练数据集API
  1809. 支持混合批量处理:正向反馈加入SQL训练集,负向反馈加入error_sql训练集
  1810. """
  1811. try:
  1812. req = request.get_json(force=True)
  1813. feedback_ids = req.get('feedback_ids', [])
  1814. if not feedback_ids or not isinstance(feedback_ids, list):
  1815. return jsonify(bad_request_response(
  1816. response_text="缺少有效的反馈ID列表",
  1817. missing_params=["feedback_ids"]
  1818. )), 400
  1819. manager = get_qa_feedback_manager()
  1820. # 获取反馈记录
  1821. records = manager.get_feedback_by_ids(feedback_ids)
  1822. if not records:
  1823. return jsonify(not_found_response(
  1824. response_text="未找到任何有效的反馈记录"
  1825. )), 404
  1826. # 分别处理正向和负向反馈
  1827. positive_count = 0 # 正向训练计数
  1828. negative_count = 0 # 负向训练计数
  1829. already_trained_count = 0 # 已训练计数
  1830. error_count = 0 # 错误计数
  1831. successfully_trained_ids = [] # 成功训练的ID列表
  1832. for record in records:
  1833. try:
  1834. # 检查是否已经在训练数据中
  1835. if record['is_in_training_data']:
  1836. already_trained_count += 1
  1837. continue
  1838. if record['is_thumb_up']:
  1839. # 正向反馈 - 加入标准SQL训练集
  1840. training_id = vn.train(
  1841. question=record['question'],
  1842. sql=record['sql']
  1843. )
  1844. positive_count += 1
  1845. logger.info(f"正向训练成功 - ID: {record['id']}, TrainingID: {training_id}")
  1846. else:
  1847. # 负向反馈 - 加入错误SQL训练集
  1848. training_id = vn.train_error_sql(
  1849. question=record['question'],
  1850. sql=record['sql']
  1851. )
  1852. negative_count += 1
  1853. logger.info(f"负向训练成功 - ID: {record['id']}, TrainingID: {training_id}")
  1854. successfully_trained_ids.append(record['id'])
  1855. except Exception as e:
  1856. logger.error(f"训练失败 - 反馈ID: {record['id']}, 错误: {e}")
  1857. error_count += 1
  1858. # 更新训练状态
  1859. if successfully_trained_ids:
  1860. updated_count = manager.mark_training_status(successfully_trained_ids, True)
  1861. logger.info(f"批量更新训练状态完成,影响 {updated_count} 条记录")
  1862. # 构建响应
  1863. total_processed = positive_count + negative_count + already_trained_count + error_count
  1864. return jsonify(success_response(
  1865. response_text=f"训练数据添加完成,成功处理 {positive_count + negative_count} 条记录",
  1866. data={
  1867. "summary": {
  1868. "total_requested": len(feedback_ids),
  1869. "total_processed": total_processed,
  1870. "positive_trained": positive_count,
  1871. "negative_trained": negative_count,
  1872. "already_trained": already_trained_count,
  1873. "errors": error_count
  1874. },
  1875. "successfully_trained_ids": successfully_trained_ids,
  1876. "training_details": {
  1877. "sql_training_count": positive_count,
  1878. "error_sql_training_count": negative_count
  1879. }
  1880. }
  1881. ))
  1882. except Exception as e:
  1883. logger.error(f"qa_feedback_add_to_training执行失败: {str(e)}")
  1884. return jsonify(internal_error_response(
  1885. response_text="添加训练数据失败,请稍后重试"
  1886. )), 500
  1887. @app.route('/api/v0/qa_feedback/add', methods=['POST'])
  1888. def qa_feedback_add():
  1889. """
  1890. 添加反馈记录API
  1891. 用于前端直接创建反馈记录
  1892. """
  1893. try:
  1894. req = request.get_json(force=True)
  1895. question = req.get('question')
  1896. sql = req.get('sql')
  1897. is_thumb_up = req.get('is_thumb_up')
  1898. user_id = req.get('user_id', 'guest')
  1899. # 参数验证
  1900. if not question:
  1901. return jsonify(bad_request_response(
  1902. response_text="缺少必需参数:question",
  1903. missing_params=["question"]
  1904. )), 400
  1905. if not sql:
  1906. return jsonify(bad_request_response(
  1907. response_text="缺少必需参数:sql",
  1908. missing_params=["sql"]
  1909. )), 400
  1910. if is_thumb_up is None:
  1911. return jsonify(bad_request_response(
  1912. response_text="缺少必需参数:is_thumb_up",
  1913. missing_params=["is_thumb_up"]
  1914. )), 400
  1915. manager = get_qa_feedback_manager()
  1916. feedback_id = manager.add_feedback(
  1917. question=question,
  1918. sql=sql,
  1919. is_thumb_up=bool(is_thumb_up),
  1920. user_id=user_id
  1921. )
  1922. return jsonify(success_response(
  1923. response_text="反馈记录创建成功",
  1924. data={
  1925. "feedback_id": feedback_id
  1926. }
  1927. ))
  1928. except Exception as e:
  1929. logger.error(f"qa_feedback_add执行失败: {str(e)}")
  1930. return jsonify(internal_error_response(
  1931. response_text="创建反馈记录失败,请稍后重试"
  1932. )), 500
  1933. @app.route('/api/v0/qa_feedback/stats', methods=['GET'])
  1934. def qa_feedback_stats():
  1935. """
  1936. 反馈统计API
  1937. 返回反馈数据的统计信息
  1938. """
  1939. try:
  1940. manager = get_qa_feedback_manager()
  1941. # 查询各种统计数据
  1942. all_records, total_count = manager.query_feedback(page=1, page_size=1)
  1943. positive_records, positive_count = manager.query_feedback(page=1, page_size=1, is_thumb_up=True)
  1944. negative_records, negative_count = manager.query_feedback(page=1, page_size=1, is_thumb_up=False)
  1945. trained_records, trained_count = manager.query_feedback(page=1, page_size=1, is_in_training_data=True)
  1946. untrained_records, untrained_count = manager.query_feedback(page=1, page_size=1, is_in_training_data=False)
  1947. return jsonify(success_response(
  1948. response_text="统计信息获取成功",
  1949. data={
  1950. "total_feedback": total_count,
  1951. "positive_feedback": positive_count,
  1952. "negative_feedback": negative_count,
  1953. "trained_feedback": trained_count,
  1954. "untrained_feedback": untrained_count,
  1955. "positive_rate": round(positive_count / max(total_count, 1) * 100, 2),
  1956. "training_rate": round(trained_count / max(total_count, 1) * 100, 2)
  1957. }
  1958. ))
  1959. except Exception as e:
  1960. logger.error(f"qa_feedback_stats执行失败: {str(e)}")
  1961. return jsonify(internal_error_response(
  1962. response_text="获取统计信息失败,请稍后重试"
  1963. )), 500
  1964. # ==================== Redis对话管理API ====================
  1965. @app.route('/api/v0/user/<user_id>/conversations', methods=['GET'])
  1966. def get_user_conversations_redis(user_id: str):
  1967. """获取用户的对话列表"""
  1968. try:
  1969. limit = request.args.get('limit', USER_MAX_CONVERSATIONS, type=int)
  1970. conversations = redis_conversation_manager.get_conversations(user_id, limit)
  1971. # 为每个对话动态获取标题(第一条用户消息)
  1972. for conversation in conversations:
  1973. conversation_id = conversation['conversation_id']
  1974. try:
  1975. # 获取所有消息,然后取第一条用户消息作为标题
  1976. messages = redis_conversation_manager.get_conversation_messages(conversation_id)
  1977. if messages and len(messages) > 0:
  1978. # 找到第一条用户消息(按时间顺序)
  1979. first_user_message = None
  1980. for message in messages:
  1981. if message.get('role') == 'user':
  1982. first_user_message = message
  1983. break
  1984. if first_user_message:
  1985. title = first_user_message.get('content', '对话').strip()
  1986. # 限制标题长度,保持整洁
  1987. if len(title) > 50:
  1988. conversation['conversation_title'] = title[:47] + "..."
  1989. else:
  1990. conversation['conversation_title'] = title
  1991. else:
  1992. conversation['conversation_title'] = "对话"
  1993. else:
  1994. conversation['conversation_title'] = "空对话"
  1995. except Exception as e:
  1996. logger.warning(f"获取对话标题失败 {conversation_id}: {str(e)}")
  1997. conversation['conversation_title'] = "对话"
  1998. return jsonify(success_response(
  1999. response_text="获取用户对话列表成功",
  2000. data={
  2001. "user_id": user_id,
  2002. "conversations": conversations,
  2003. "total_count": len(conversations)
  2004. }
  2005. ))
  2006. except Exception as e:
  2007. return jsonify(internal_error_response(
  2008. response_text="获取对话列表失败,请稍后重试"
  2009. )), 500
  2010. @app.route('/api/v0/conversation/<conversation_id>/messages', methods=['GET'])
  2011. def get_conversation_messages_redis(conversation_id: str):
  2012. """获取特定对话的消息历史"""
  2013. try:
  2014. limit = request.args.get('limit', type=int)
  2015. messages = redis_conversation_manager.get_conversation_messages(conversation_id, limit)
  2016. meta = redis_conversation_manager.get_conversation_meta(conversation_id)
  2017. return jsonify(success_response(
  2018. response_text="获取对话消息成功",
  2019. data={
  2020. "conversation_id": conversation_id,
  2021. "conversation_meta": meta,
  2022. "messages": messages,
  2023. "message_count": len(messages)
  2024. }
  2025. ))
  2026. except Exception as e:
  2027. return jsonify(internal_error_response(
  2028. response_text="获取对话消息失败"
  2029. )), 500
  2030. @app.route('/api/v0/conversation_stats', methods=['GET'])
  2031. def conversation_stats():
  2032. """获取对话系统统计信息"""
  2033. try:
  2034. stats = redis_conversation_manager.get_stats()
  2035. return jsonify(success_response(
  2036. response_text="获取统计信息成功",
  2037. data=stats
  2038. ))
  2039. except Exception as e:
  2040. return jsonify(internal_error_response(
  2041. response_text="获取统计信息失败,请稍后重试"
  2042. )), 500
  2043. @app.route('/api/v0/conversation_cleanup', methods=['POST'])
  2044. def conversation_cleanup():
  2045. """手动清理过期对话"""
  2046. try:
  2047. redis_conversation_manager.cleanup_expired_conversations()
  2048. return jsonify(success_response(
  2049. response_text="对话清理完成"
  2050. ))
  2051. except Exception as e:
  2052. return jsonify(internal_error_response(
  2053. response_text="对话清理失败,请稍后重试"
  2054. )), 500
  2055. @app.route('/api/v0/embedding_cache_stats', methods=['GET'])
  2056. def embedding_cache_stats():
  2057. """获取embedding缓存统计信息"""
  2058. try:
  2059. from common.embedding_cache_manager import get_embedding_cache_manager
  2060. cache_manager = get_embedding_cache_manager()
  2061. stats = cache_manager.get_cache_stats()
  2062. return jsonify(success_response(
  2063. response_text="获取embedding缓存统计成功",
  2064. data=stats
  2065. ))
  2066. except Exception as e:
  2067. logger.error(f"获取embedding缓存统计失败: {str(e)}")
  2068. return jsonify(internal_error_response(
  2069. response_text="获取embedding缓存统计失败,请稍后重试"
  2070. )), 500
  2071. @app.route('/api/v0/embedding_cache_cleanup', methods=['POST'])
  2072. def embedding_cache_cleanup():
  2073. """清空所有embedding缓存"""
  2074. try:
  2075. from common.embedding_cache_manager import get_embedding_cache_manager
  2076. cache_manager = get_embedding_cache_manager()
  2077. if not cache_manager.is_available():
  2078. return jsonify(internal_error_response(
  2079. response_text="Embedding缓存功能未启用或不可用"
  2080. )), 400
  2081. success = cache_manager.clear_all_cache()
  2082. if success:
  2083. return jsonify(success_response(
  2084. response_text="所有embedding缓存已清空",
  2085. data={"cleared": True}
  2086. ))
  2087. else:
  2088. return jsonify(internal_error_response(
  2089. response_text="清空embedding缓存失败"
  2090. )), 500
  2091. except Exception as e:
  2092. logger.error(f"清空embedding缓存失败: {str(e)}")
  2093. return jsonify(internal_error_response(
  2094. response_text="清空embedding缓存失败,请稍后重试"
  2095. )), 500
  2096. # ==================== 训练数据管理API ====================
  2097. def validate_sql_syntax(sql: str) -> tuple[bool, str]:
  2098. """SQL语法检查"""
  2099. try:
  2100. parsed = sqlparse.parse(sql.strip())
  2101. if not parsed or not parsed[0].tokens:
  2102. return False, "SQL语法错误:空语句"
  2103. sql_upper = sql.strip().upper()
  2104. if not any(sql_upper.startswith(keyword) for keyword in
  2105. ['SELECT', 'INSERT', 'UPDATE', 'DELETE', 'CREATE', 'ALTER', 'DROP']):
  2106. return False, "SQL语法错误:不是有效的SQL语句"
  2107. return True, ""
  2108. except Exception as e:
  2109. return False, f"SQL语法错误:{str(e)}"
  2110. def paginate_data(data_list: list, page: int, page_size: int):
  2111. """分页算法"""
  2112. total = len(data_list)
  2113. start_idx = (page - 1) * page_size
  2114. end_idx = start_idx + page_size
  2115. page_data = data_list[start_idx:end_idx]
  2116. total_pages = (total + page_size - 1) // page_size
  2117. return {
  2118. "data": page_data,
  2119. "pagination": {
  2120. "page": page,
  2121. "page_size": page_size,
  2122. "total": total,
  2123. "total_pages": total_pages,
  2124. "has_next": end_idx < total,
  2125. "has_prev": page > 1
  2126. }
  2127. }
  2128. def filter_by_type(data_list: list, training_data_type: str):
  2129. """按类型筛选算法"""
  2130. if not training_data_type:
  2131. return data_list
  2132. return [
  2133. record for record in data_list
  2134. if record.get('training_data_type') == training_data_type
  2135. ]
  2136. def search_in_data(data_list: list, search_keyword: str):
  2137. """在数据中搜索关键词"""
  2138. if not search_keyword:
  2139. return data_list
  2140. keyword_lower = search_keyword.lower()
  2141. return [
  2142. record for record in data_list
  2143. if (record.get('question') and keyword_lower in record['question'].lower()) or
  2144. (record.get('content') and keyword_lower in record['content'].lower())
  2145. ]
  2146. def get_total_training_count():
  2147. """获取当前训练数据总数"""
  2148. try:
  2149. training_data = vn.get_training_data()
  2150. if training_data is not None and not training_data.empty:
  2151. return len(training_data)
  2152. return 0
  2153. except Exception as e:
  2154. logger.warning(f"获取训练数据总数失败: {e}")
  2155. return 0
  2156. def process_single_training_item(item: dict, index: int) -> dict:
  2157. """处理单个训练数据项"""
  2158. training_type = item.get('training_data_type')
  2159. if training_type == 'sql':
  2160. sql = item.get('sql')
  2161. if not sql:
  2162. raise ValueError("SQL字段是必需的")
  2163. # SQL语法检查
  2164. is_valid, error_msg = validate_sql_syntax(sql)
  2165. if not is_valid:
  2166. raise ValueError(error_msg)
  2167. question = item.get('question')
  2168. if question:
  2169. training_id = vn.train(question=question, sql=sql)
  2170. else:
  2171. training_id = vn.train(sql=sql)
  2172. elif training_type == 'error_sql':
  2173. # error_sql不需要语法检查
  2174. question = item.get('question')
  2175. sql = item.get('sql')
  2176. if not question or not sql:
  2177. raise ValueError("question和sql字段都是必需的")
  2178. training_id = vn.train_error_sql(question=question, sql=sql)
  2179. elif training_type == 'documentation':
  2180. content = item.get('content')
  2181. if not content:
  2182. raise ValueError("content字段是必需的")
  2183. training_id = vn.train(documentation=content)
  2184. elif training_type == 'ddl':
  2185. ddl = item.get('ddl')
  2186. if not ddl:
  2187. raise ValueError("ddl字段是必需的")
  2188. training_id = vn.train(ddl=ddl)
  2189. else:
  2190. raise ValueError(f"不支持的训练数据类型: {training_type}")
  2191. return {
  2192. "index": index,
  2193. "success": True,
  2194. "training_id": training_id,
  2195. "type": training_type,
  2196. "message": f"{training_type}训练数据创建成功"
  2197. }
  2198. @app.route('/api/v0/training_data/stats', methods=['GET'])
  2199. def training_data_stats():
  2200. """获取训练数据统计信息API"""
  2201. try:
  2202. training_data = vn.get_training_data()
  2203. if training_data is None or training_data.empty:
  2204. return jsonify(success_response(
  2205. response_text="统计信息获取成功",
  2206. data={
  2207. "total_count": 0,
  2208. "type_breakdown": {
  2209. "sql": 0,
  2210. "documentation": 0,
  2211. "ddl": 0,
  2212. "error_sql": 0
  2213. },
  2214. "type_percentages": {
  2215. "sql": 0.0,
  2216. "documentation": 0.0,
  2217. "ddl": 0.0,
  2218. "error_sql": 0.0
  2219. },
  2220. "last_updated": datetime.now().isoformat()
  2221. }
  2222. ))
  2223. total_count = len(training_data)
  2224. # 统计各类型数量
  2225. type_breakdown = {"sql": 0, "documentation": 0, "ddl": 0, "error_sql": 0}
  2226. if 'training_data_type' in training_data.columns:
  2227. type_counts = training_data['training_data_type'].value_counts()
  2228. for data_type, count in type_counts.items():
  2229. if data_type in type_breakdown:
  2230. type_breakdown[data_type] = int(count)
  2231. # 计算百分比
  2232. type_percentages = {}
  2233. for data_type, count in type_breakdown.items():
  2234. type_percentages[data_type] = round(count / max(total_count, 1) * 100, 2)
  2235. return jsonify(success_response(
  2236. response_text="统计信息获取成功",
  2237. data={
  2238. "total_count": total_count,
  2239. "type_breakdown": type_breakdown,
  2240. "type_percentages": type_percentages,
  2241. "last_updated": datetime.now().isoformat()
  2242. }
  2243. ))
  2244. except Exception as e:
  2245. logger.error(f"training_data_stats执行失败: {str(e)}")
  2246. return jsonify(internal_error_response(
  2247. response_text="获取统计信息失败,请稍后重试"
  2248. )), 500
  2249. @app.route('/api/v0/training_data/query', methods=['POST'])
  2250. def training_data_query():
  2251. """分页查询训练数据API - 支持类型筛选、搜索和排序功能"""
  2252. try:
  2253. req = request.get_json(force=True)
  2254. # 解析参数,设置默认值
  2255. page = req.get('page', 1)
  2256. page_size = req.get('page_size', 20)
  2257. training_data_type = req.get('training_data_type')
  2258. sort_by = req.get('sort_by', 'id')
  2259. sort_order = req.get('sort_order', 'desc')
  2260. search_keyword = req.get('search_keyword')
  2261. # 参数验证
  2262. if page < 1:
  2263. return jsonify(bad_request_response(
  2264. response_text="页码必须大于0",
  2265. missing_params=["page"]
  2266. )), 400
  2267. if page_size < 1 or page_size > 100:
  2268. return jsonify(bad_request_response(
  2269. response_text="每页大小必须在1-100之间",
  2270. missing_params=["page_size"]
  2271. )), 400
  2272. if search_keyword and len(search_keyword) > 100:
  2273. return jsonify(bad_request_response(
  2274. response_text="搜索关键词最大长度为100字符",
  2275. missing_params=["search_keyword"]
  2276. )), 400
  2277. # 获取训练数据
  2278. training_data = vn.get_training_data()
  2279. if training_data is None or training_data.empty:
  2280. return jsonify(success_response(
  2281. response_text="查询成功,暂无训练数据",
  2282. data={
  2283. "records": [],
  2284. "pagination": {
  2285. "page": page,
  2286. "page_size": page_size,
  2287. "total": 0,
  2288. "total_pages": 0,
  2289. "has_next": False,
  2290. "has_prev": False
  2291. },
  2292. "filters_applied": {
  2293. "training_data_type": training_data_type,
  2294. "search_keyword": search_keyword
  2295. }
  2296. }
  2297. ))
  2298. # 转换为列表格式
  2299. records = training_data.to_dict(orient="records")
  2300. # 应用筛选条件
  2301. if training_data_type:
  2302. records = filter_by_type(records, training_data_type)
  2303. if search_keyword:
  2304. records = search_in_data(records, search_keyword)
  2305. # 排序
  2306. if sort_by in ['id', 'training_data_type']:
  2307. reverse = (sort_order.lower() == 'desc')
  2308. records.sort(key=lambda x: x.get(sort_by, ''), reverse=reverse)
  2309. # 分页
  2310. paginated_result = paginate_data(records, page, page_size)
  2311. return jsonify(success_response(
  2312. response_text=f"查询成功,共找到 {paginated_result['pagination']['total']} 条记录",
  2313. data={
  2314. "records": paginated_result["data"],
  2315. "pagination": paginated_result["pagination"],
  2316. "filters_applied": {
  2317. "training_data_type": training_data_type,
  2318. "search_keyword": search_keyword
  2319. }
  2320. }
  2321. ))
  2322. except Exception as e:
  2323. logger.error(f"training_data_query执行失败: {str(e)}")
  2324. return jsonify(internal_error_response(
  2325. response_text="查询训练数据失败,请稍后重试"
  2326. )), 500
  2327. @app.route('/api/v0/training_data/create', methods=['POST'])
  2328. def training_data_create():
  2329. """创建训练数据API - 支持单条和批量创建,支持四种数据类型"""
  2330. try:
  2331. req = request.get_json(force=True)
  2332. data = req.get('data')
  2333. if not data:
  2334. return jsonify(bad_request_response(
  2335. response_text="缺少必需参数:data",
  2336. missing_params=["data"]
  2337. )), 400
  2338. # 统一处理为列表格式
  2339. if isinstance(data, dict):
  2340. data_list = [data]
  2341. elif isinstance(data, list):
  2342. data_list = data
  2343. else:
  2344. return jsonify(bad_request_response(
  2345. response_text="data字段格式错误,应为对象或数组"
  2346. )), 400
  2347. # 批量操作限制
  2348. if len(data_list) > 50:
  2349. return jsonify(bad_request_response(
  2350. response_text="批量操作最大支持50条记录"
  2351. )), 400
  2352. results = []
  2353. successful_count = 0
  2354. type_summary = {"sql": 0, "documentation": 0, "ddl": 0, "error_sql": 0}
  2355. for index, item in enumerate(data_list):
  2356. try:
  2357. result = process_single_training_item(item, index)
  2358. results.append(result)
  2359. if result['success']:
  2360. successful_count += 1
  2361. type_summary[result['type']] += 1
  2362. except Exception as e:
  2363. results.append({
  2364. "index": index,
  2365. "success": False,
  2366. "type": item.get('training_data_type', 'unknown'),
  2367. "error": str(e),
  2368. "message": "创建失败"
  2369. })
  2370. # 获取创建后的总记录数
  2371. current_total = get_total_training_count()
  2372. # 根据实际执行结果决定响应状态
  2373. failed_count = len(data_list) - successful_count
  2374. if failed_count == 0:
  2375. # 全部成功
  2376. return jsonify(success_response(
  2377. response_text="训练数据创建完成",
  2378. data={
  2379. "total_requested": len(data_list),
  2380. "successfully_created": successful_count,
  2381. "failed_count": failed_count,
  2382. "results": results,
  2383. "summary": type_summary,
  2384. "current_total_count": current_total
  2385. }
  2386. ))
  2387. elif successful_count == 0:
  2388. # 全部失败
  2389. return jsonify(error_response(
  2390. response_text="训练数据创建失败",
  2391. data={
  2392. "total_requested": len(data_list),
  2393. "successfully_created": successful_count,
  2394. "failed_count": failed_count,
  2395. "results": results,
  2396. "summary": type_summary,
  2397. "current_total_count": current_total
  2398. }
  2399. )), 400
  2400. else:
  2401. # 部分成功,部分失败
  2402. return jsonify(error_response(
  2403. response_text=f"训练数据创建部分成功,成功{successful_count}条,失败{failed_count}条",
  2404. data={
  2405. "total_requested": len(data_list),
  2406. "successfully_created": successful_count,
  2407. "failed_count": failed_count,
  2408. "results": results,
  2409. "summary": type_summary,
  2410. "current_total_count": current_total
  2411. }
  2412. )), 207
  2413. except Exception as e:
  2414. logger.error(f"training_data_create执行失败: {str(e)}")
  2415. return jsonify(internal_error_response(
  2416. response_text="创建训练数据失败,请稍后重试"
  2417. )), 500
  2418. @app.route('/api/v0/training_data/delete', methods=['POST'])
  2419. def training_data_delete():
  2420. """删除训练数据API - 支持批量删除"""
  2421. try:
  2422. req = request.get_json(force=True)
  2423. ids = req.get('ids', [])
  2424. confirm = req.get('confirm', False)
  2425. if not ids or not isinstance(ids, list):
  2426. return jsonify(bad_request_response(
  2427. response_text="缺少有效的ID列表",
  2428. missing_params=["ids"]
  2429. )), 400
  2430. if not confirm:
  2431. return jsonify(bad_request_response(
  2432. response_text="删除操作需要确认,请设置confirm为true"
  2433. )), 400
  2434. # 批量操作限制
  2435. if len(ids) > 50:
  2436. return jsonify(bad_request_response(
  2437. response_text="批量删除最大支持50条记录"
  2438. )), 400
  2439. deleted_ids = []
  2440. failed_ids = []
  2441. failed_details = []
  2442. for training_id in ids:
  2443. try:
  2444. success = vn.remove_training_data(training_id)
  2445. if success:
  2446. deleted_ids.append(training_id)
  2447. else:
  2448. failed_ids.append(training_id)
  2449. failed_details.append({
  2450. "id": training_id,
  2451. "error": "记录不存在或删除失败"
  2452. })
  2453. except Exception as e:
  2454. failed_ids.append(training_id)
  2455. failed_details.append({
  2456. "id": training_id,
  2457. "error": str(e)
  2458. })
  2459. # 获取删除后的总记录数
  2460. current_total = get_total_training_count()
  2461. # 根据实际执行结果决定响应状态
  2462. failed_count = len(failed_ids)
  2463. if failed_count == 0:
  2464. # 全部成功
  2465. return jsonify(success_response(
  2466. response_text="训练数据删除完成",
  2467. data={
  2468. "total_requested": len(ids),
  2469. "successfully_deleted": len(deleted_ids),
  2470. "failed_count": failed_count,
  2471. "deleted_ids": deleted_ids,
  2472. "failed_ids": failed_ids,
  2473. "failed_details": failed_details,
  2474. "current_total_count": current_total
  2475. }
  2476. ))
  2477. elif len(deleted_ids) == 0:
  2478. # 全部失败
  2479. return jsonify(error_response(
  2480. response_text="训练数据删除失败",
  2481. data={
  2482. "total_requested": len(ids),
  2483. "successfully_deleted": len(deleted_ids),
  2484. "failed_count": failed_count,
  2485. "deleted_ids": deleted_ids,
  2486. "failed_ids": failed_ids,
  2487. "failed_details": failed_details,
  2488. "current_total_count": current_total
  2489. }
  2490. )), 400
  2491. else:
  2492. # 部分成功,部分失败
  2493. return jsonify(error_response(
  2494. response_text=f"训练数据删除部分成功,成功{len(deleted_ids)}条,失败{failed_count}条",
  2495. data={
  2496. "total_requested": len(ids),
  2497. "successfully_deleted": len(deleted_ids),
  2498. "failed_count": failed_count,
  2499. "deleted_ids": deleted_ids,
  2500. "failed_ids": failed_ids,
  2501. "failed_details": failed_details,
  2502. "current_total_count": current_total
  2503. }
  2504. )), 207
  2505. except Exception as e:
  2506. logger.error(f"training_data_delete执行失败: {str(e)}")
  2507. return jsonify(internal_error_response(
  2508. response_text="删除训练数据失败,请稍后重试"
  2509. )), 500
  2510. @app.route('/api/v0/training_data/update', methods=['POST'])
  2511. def training_data_update():
  2512. """更新训练数据API - 支持单条更新,采用先删除后插入策略"""
  2513. try:
  2514. req = request.get_json(force=True)
  2515. # 1. 参数验证
  2516. original_id = req.get('id')
  2517. if not original_id:
  2518. return jsonify(bad_request_response(
  2519. response_text="缺少必需参数:id",
  2520. missing_params=["id"]
  2521. )), 400
  2522. training_type = req.get('training_data_type')
  2523. if not training_type:
  2524. return jsonify(bad_request_response(
  2525. response_text="缺少必需参数:training_data_type",
  2526. missing_params=["training_data_type"]
  2527. )), 400
  2528. # 2. 先删除原始记录
  2529. try:
  2530. success = vn.remove_training_data(original_id)
  2531. if not success:
  2532. return jsonify(bad_request_response(
  2533. response_text=f"原始记录 {original_id} 不存在或删除失败"
  2534. )), 400
  2535. except Exception as e:
  2536. return jsonify(internal_error_response(
  2537. response_text=f"删除原始记录失败: {str(e)}"
  2538. )), 500
  2539. # 3. 根据类型验证和准备新数据
  2540. try:
  2541. if training_type == 'sql':
  2542. sql = req.get('sql')
  2543. if not sql:
  2544. return jsonify(bad_request_response(
  2545. response_text="SQL字段是必需的",
  2546. missing_params=["sql"]
  2547. )), 400
  2548. # SQL语法检查
  2549. is_valid, error_msg = validate_sql_syntax(sql)
  2550. if not is_valid:
  2551. return jsonify(bad_request_response(
  2552. response_text=f"SQL语法错误: {error_msg}"
  2553. )), 400
  2554. question = req.get('question')
  2555. if question:
  2556. training_id = vn.train(question=question, sql=sql)
  2557. else:
  2558. training_id = vn.train(sql=sql)
  2559. elif training_type == 'error_sql':
  2560. question = req.get('question')
  2561. sql = req.get('sql')
  2562. if not question or not sql:
  2563. return jsonify(bad_request_response(
  2564. response_text="question和sql字段都是必需的",
  2565. missing_params=["question", "sql"]
  2566. )), 400
  2567. training_id = vn.train_error_sql(question=question, sql=sql)
  2568. elif training_type == 'documentation':
  2569. content = req.get('content')
  2570. if not content:
  2571. return jsonify(bad_request_response(
  2572. response_text="content字段是必需的",
  2573. missing_params=["content"]
  2574. )), 400
  2575. training_id = vn.train(documentation=content)
  2576. elif training_type == 'ddl':
  2577. ddl = req.get('ddl')
  2578. if not ddl:
  2579. return jsonify(bad_request_response(
  2580. response_text="ddl字段是必需的",
  2581. missing_params=["ddl"]
  2582. )), 400
  2583. training_id = vn.train(ddl=ddl)
  2584. else:
  2585. return jsonify(bad_request_response(
  2586. response_text=f"不支持的训练数据类型: {training_type}"
  2587. )), 400
  2588. except Exception as e:
  2589. return jsonify(internal_error_response(
  2590. response_text=f"创建新训练数据失败: {str(e)}"
  2591. )), 500
  2592. # 4. 获取更新后的总记录数
  2593. current_total = get_total_training_count()
  2594. return jsonify(success_response(
  2595. response_text="训练数据更新成功",
  2596. data={
  2597. "original_id": original_id,
  2598. "new_training_id": training_id,
  2599. "type": training_type,
  2600. "current_total_count": current_total
  2601. }
  2602. ))
  2603. except Exception as e:
  2604. logger.error(f"training_data_update执行失败: {str(e)}")
  2605. return jsonify(internal_error_response(
  2606. response_text="更新训练数据失败,请稍后重试"
  2607. )), 500
  2608. # 导入现有的专业训练函数
  2609. from data_pipeline.trainer.run_training import (
  2610. train_ddl_statements,
  2611. train_documentation_blocks,
  2612. train_json_question_sql_pairs,
  2613. train_formatted_question_sql_pairs,
  2614. train_sql_examples
  2615. )
  2616. def get_allowed_extensions(file_type: str) -> list:
  2617. """根据文件类型返回允许的扩展名"""
  2618. type_specific_extensions = {
  2619. 'ddl': ['ddl', 'sql', 'txt', ''], # 支持无扩展名
  2620. 'markdown': ['md', 'markdown'], # 不支持无扩展名
  2621. 'sql_pair_json': ['json', 'txt', ''], # 支持无扩展名
  2622. 'sql_pair': ['sql', 'txt', ''], # 支持无扩展名
  2623. 'sql': ['sql', 'txt', ''] # 支持无扩展名
  2624. }
  2625. return type_specific_extensions.get(file_type, [])
  2626. def validate_file_content(content: str, file_type: str) -> dict:
  2627. """验证文件内容格式"""
  2628. try:
  2629. if file_type == 'ddl':
  2630. # 检查是否包含CREATE语句
  2631. if not re.search(r'\bCREATE\b', content, re.IGNORECASE):
  2632. return {'valid': False, 'error': '文件内容不符合DDL格式,必须包含CREATE语句'}
  2633. elif file_type == 'markdown':
  2634. # 检查是否包含##标题
  2635. if '##' not in content:
  2636. return {'valid': False, 'error': '文件内容不符合Markdown格式,必须包含##标题'}
  2637. elif file_type == 'sql_pair_json':
  2638. # 检查是否为有效JSON
  2639. try:
  2640. data = json.loads(content)
  2641. if not isinstance(data, list) or not data:
  2642. return {'valid': False, 'error': '文件内容不符合JSON问答对格式,必须是非空数组'}
  2643. # 检查是否包含question和sql字段
  2644. for item in data:
  2645. if not isinstance(item, dict):
  2646. return {'valid': False, 'error': '文件内容不符合JSON问答对格式,数组元素必须是对象'}
  2647. has_question = any(key.lower() == 'question' for key in item.keys())
  2648. has_sql = any(key.lower() == 'sql' for key in item.keys())
  2649. if not has_question or not has_sql:
  2650. return {'valid': False, 'error': '文件内容不符合JSON问答对格式,必须包含question和sql字段'}
  2651. except json.JSONDecodeError:
  2652. return {'valid': False, 'error': '文件内容不符合JSON问答对格式,JSON格式错误'}
  2653. elif file_type == 'sql_pair':
  2654. # 检查是否包含Question:和SQL:
  2655. if not re.search(r'\bQuestion\s*:', content, re.IGNORECASE):
  2656. return {'valid': False, 'error': '文件内容不符合问答对格式,必须包含Question:'}
  2657. if not re.search(r'\bSQL\s*:', content, re.IGNORECASE):
  2658. return {'valid': False, 'error': '文件内容不符合问答对格式,必须包含SQL:'}
  2659. elif file_type == 'sql':
  2660. # 检查是否包含;分隔符
  2661. if ';' not in content:
  2662. return {'valid': False, 'error': '文件内容不符合SQL格式,必须包含;分隔符'}
  2663. return {'valid': True}
  2664. except Exception as e:
  2665. return {'valid': False, 'error': f'文件内容验证失败: {str(e)}'}
  2666. @app.route('/api/v0/training_data/upload', methods=['POST'])
  2667. def upload_training_data():
  2668. """上传训练数据文件API - 支持多种文件格式的自动解析和导入"""
  2669. try:
  2670. # 1. 参数验证
  2671. if 'file' not in request.files:
  2672. return jsonify(bad_request_response("未提供文件"))
  2673. file = request.files['file']
  2674. if file.filename == '':
  2675. return jsonify(bad_request_response("未选择文件"))
  2676. # 获取file_type参数
  2677. file_type = request.form.get('file_type')
  2678. if not file_type:
  2679. return jsonify(bad_request_response("缺少必需参数:file_type"))
  2680. # 验证file_type参数
  2681. valid_file_types = ['ddl', 'markdown', 'sql_pair_json', 'sql_pair', 'sql']
  2682. if file_type not in valid_file_types:
  2683. return jsonify(bad_request_response(f"不支持的文件类型:{file_type},支持的类型:{', '.join(valid_file_types)}"))
  2684. # 2. 文件大小验证 (500KB)
  2685. file.seek(0, 2)
  2686. file_size = file.tell()
  2687. file.seek(0)
  2688. if file_size > 500 * 1024: # 500KB
  2689. return jsonify(bad_request_response("文件大小不能超过500KB"))
  2690. # 3. 验证文件扩展名(基于file_type)
  2691. filename = secure_filename(file.filename)
  2692. allowed_extensions = get_allowed_extensions(file_type)
  2693. file_ext = filename.split('.')[-1].lower() if '.' in filename else ''
  2694. if file_ext not in allowed_extensions:
  2695. # 构建友好的错误信息
  2696. non_empty_extensions = [ext for ext in allowed_extensions if ext]
  2697. if '' in allowed_extensions:
  2698. ext_message = f"{', '.join(non_empty_extensions)} 或无扩展名"
  2699. else:
  2700. ext_message = ', '.join(non_empty_extensions)
  2701. return jsonify(bad_request_response(f"文件类型 {file_type} 不支持的文件扩展名:{file_ext},支持的扩展名:{ext_message}"))
  2702. # 4. 读取文件内容并验证格式
  2703. file.seek(0)
  2704. content = file.read().decode('utf-8')
  2705. # 格式验证
  2706. validation_result = validate_file_content(content, file_type)
  2707. if not validation_result['valid']:
  2708. return jsonify(bad_request_response(validation_result['error']))
  2709. # 5. 创建临时文件(复用现有函数需要文件路径)
  2710. temp_file_path = None
  2711. try:
  2712. with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.tmp', encoding='utf-8') as tmp_file:
  2713. tmp_file.write(content)
  2714. temp_file_path = tmp_file.name
  2715. # 6. 根据文件类型调用现有的训练函数
  2716. if file_type == 'ddl':
  2717. train_ddl_statements(temp_file_path)
  2718. elif file_type == 'markdown':
  2719. train_documentation_blocks(temp_file_path)
  2720. elif file_type == 'sql_pair_json':
  2721. train_json_question_sql_pairs(temp_file_path)
  2722. elif file_type == 'sql_pair':
  2723. train_formatted_question_sql_pairs(temp_file_path)
  2724. elif file_type == 'sql':
  2725. train_sql_examples(temp_file_path)
  2726. return jsonify(success_response(
  2727. response_text=f"文件上传并训练成功:{filename}",
  2728. data={
  2729. "filename": filename,
  2730. "file_type": file_type,
  2731. "file_size": file_size,
  2732. "status": "completed"
  2733. }
  2734. ))
  2735. except Exception as e:
  2736. logger.error(f"训练失败: {str(e)}")
  2737. return jsonify(internal_error_response(f"训练失败: {str(e)}"))
  2738. finally:
  2739. # 清理临时文件
  2740. if temp_file_path and os.path.exists(temp_file_path):
  2741. try:
  2742. os.unlink(temp_file_path)
  2743. except Exception as e:
  2744. logger.warning(f"清理临时文件失败: {str(e)}")
  2745. except Exception as e:
  2746. logger.error(f"文件上传失败: {str(e)}")
  2747. return jsonify(internal_error_response(f"文件上传失败: {str(e)}"))
  2748. def get_db_connection():
  2749. """获取数据库连接"""
  2750. try:
  2751. from app_config import PGVECTOR_CONFIG
  2752. return psycopg2.connect(**PGVECTOR_CONFIG)
  2753. except Exception as e:
  2754. logger.error(f"数据库连接失败: {str(e)}")
  2755. raise
  2756. def get_db_connection_for_transaction():
  2757. """获取用于事务操作的数据库连接(非自动提交模式)"""
  2758. try:
  2759. from app_config import PGVECTOR_CONFIG
  2760. conn = psycopg2.connect(**PGVECTOR_CONFIG)
  2761. conn.autocommit = False # 设置为非自动提交模式,允许手动控制事务
  2762. return conn
  2763. except Exception as e:
  2764. logger.error(f"数据库连接失败: {str(e)}")
  2765. raise
  2766. @app.route('/api/v0/training_data/combine', methods=['POST'])
  2767. def combine_training_data():
  2768. """合并训练数据API - 支持合并重复记录"""
  2769. try:
  2770. # 1. 参数验证
  2771. data = request.get_json()
  2772. if not data:
  2773. return jsonify(bad_request_response("请求体不能为空"))
  2774. collection_names = data.get('collection_names', [])
  2775. if not collection_names or not isinstance(collection_names, list):
  2776. return jsonify(bad_request_response("collection_names 参数必须是非空数组"))
  2777. # 验证集合名称
  2778. valid_collections = ['sql', 'ddl', 'documentation', 'error_sql']
  2779. invalid_collections = [name for name in collection_names if name not in valid_collections]
  2780. if invalid_collections:
  2781. return jsonify(bad_request_response(f"不支持的集合名称: {invalid_collections}"))
  2782. dry_run = data.get('dry_run', True)
  2783. keep_strategy = data.get('keep_strategy', 'first')
  2784. if keep_strategy not in ['first', 'last', 'by_metadata_time']:
  2785. return jsonify(bad_request_response("keep_strategy 必须是 'first', 'last' 或 'by_metadata_time'"))
  2786. # 2. 获取数据库连接(用于事务操作)
  2787. conn = get_db_connection_for_transaction()
  2788. cursor = conn.cursor()
  2789. # 3. 查找重复记录
  2790. duplicate_groups = []
  2791. total_before = 0
  2792. total_duplicates = 0
  2793. collections_stats = {}
  2794. for collection_name in collection_names:
  2795. # 获取集合ID
  2796. cursor.execute(
  2797. "SELECT uuid FROM langchain_pg_collection WHERE name = %s",
  2798. (collection_name,)
  2799. )
  2800. collection_result = cursor.fetchone()
  2801. if not collection_result:
  2802. continue
  2803. collection_id = collection_result[0]
  2804. # 统计该集合的记录数
  2805. cursor.execute(
  2806. "SELECT COUNT(*) FROM langchain_pg_embedding WHERE collection_id = %s",
  2807. (collection_id,)
  2808. )
  2809. collection_before = cursor.fetchone()[0]
  2810. total_before += collection_before
  2811. # 查找重复记录
  2812. if keep_strategy in ['first', 'last']:
  2813. order_by = "id"
  2814. else:
  2815. order_by = "COALESCE((cmetadata->>'createdat')::timestamp, '1970-01-01'::timestamp) DESC, id"
  2816. cursor.execute(f"""
  2817. SELECT document, COUNT(*) as duplicate_count,
  2818. array_agg(id ORDER BY {order_by}) as record_ids
  2819. FROM langchain_pg_embedding
  2820. WHERE collection_id = %s
  2821. GROUP BY document
  2822. HAVING COUNT(*) > 1
  2823. """, (collection_id,))
  2824. collection_duplicates = 0
  2825. for row in cursor.fetchall():
  2826. document_content, duplicate_count, record_ids = row
  2827. collection_duplicates += duplicate_count - 1 # 减去要保留的一条
  2828. # 根据保留策略选择要保留的记录
  2829. if keep_strategy == 'first':
  2830. keep_id = record_ids[0]
  2831. remove_ids = record_ids[1:]
  2832. elif keep_strategy == 'last':
  2833. keep_id = record_ids[-1]
  2834. remove_ids = record_ids[:-1]
  2835. else: # by_metadata_time
  2836. keep_id = record_ids[0] # 已经按时间排序
  2837. remove_ids = record_ids[1:]
  2838. duplicate_groups.append({
  2839. "collection_name": collection_name,
  2840. "document_content": document_content[:100] + "..." if len(document_content) > 100 else document_content,
  2841. "duplicate_count": duplicate_count,
  2842. "kept_record_id" if not dry_run else "records_to_keep": keep_id,
  2843. "removed_record_ids" if not dry_run else "records_to_remove": remove_ids
  2844. })
  2845. total_duplicates += collection_duplicates
  2846. collections_stats[collection_name] = {
  2847. "before": collection_before,
  2848. "after": collection_before - collection_duplicates,
  2849. "duplicates_removed" if not dry_run else "duplicates_to_remove": collection_duplicates
  2850. }
  2851. # 4. 执行合并操作(如果不是dry_run)
  2852. if not dry_run:
  2853. try:
  2854. # 连接已经设置为非自动提交模式,直接开始事务
  2855. for group in duplicate_groups:
  2856. remove_ids = group["removed_record_ids"]
  2857. if remove_ids:
  2858. cursor.execute(
  2859. "DELETE FROM langchain_pg_embedding WHERE id = ANY(%s)",
  2860. (remove_ids,)
  2861. )
  2862. conn.commit()
  2863. except Exception as e:
  2864. conn.rollback()
  2865. return jsonify(internal_error_response(f"合并操作失败: {str(e)}"))
  2866. # 5. 构建响应
  2867. total_after = total_before - total_duplicates
  2868. summary = {
  2869. "total_records_before": total_before,
  2870. "total_records_after": total_after,
  2871. "duplicates_removed" if not dry_run else "duplicates_to_remove": total_duplicates,
  2872. "collections_stats": collections_stats
  2873. }
  2874. if dry_run:
  2875. response_text = f"发现 {total_duplicates} 条重复记录,预计删除后将从 {total_before} 条减少到 {total_after} 条记录"
  2876. data_key = "duplicate_groups"
  2877. else:
  2878. response_text = f"成功合并重复记录,删除了 {total_duplicates} 条重复记录,从 {total_before} 条减少到 {total_after} 条记录"
  2879. data_key = "merged_groups"
  2880. return jsonify(success_response(
  2881. response_text=response_text,
  2882. data={
  2883. "dry_run": dry_run,
  2884. "collections_processed": collection_names,
  2885. "summary": summary,
  2886. data_key: duplicate_groups
  2887. }
  2888. ))
  2889. except Exception as e:
  2890. return jsonify(internal_error_response(f"合并操作失败: {str(e)}"))
  2891. finally:
  2892. if 'cursor' in locals():
  2893. cursor.close()
  2894. if 'conn' in locals():
  2895. conn.close()
  2896. # ==================== React Agent 扩展API ====================
  2897. @app.route('/api/v0/react/users/<user_id>/conversations', methods=['GET'])
  2898. async def get_user_conversations_react(user_id: str):
  2899. """异步获取用户的聊天记录列表(从 custom_react_agent 迁移)"""
  2900. global _react_agent_instance
  2901. try:
  2902. # 获取查询参数
  2903. limit = request.args.get('limit', 10, type=int)
  2904. # 限制limit的范围
  2905. limit = max(1, min(limit, 50)) # 限制在1-50之间
  2906. logger.info(f"📋 异步获取用户 {user_id} 的对话列表,限制 {limit} 条")
  2907. # 确保Agent可用
  2908. if not await ensure_agent_ready():
  2909. return jsonify(service_unavailable_response(
  2910. response_text="Agent 未就绪"
  2911. )), 503
  2912. # 直接调用异步方法
  2913. conversations = await _react_agent_instance.get_user_recent_conversations(user_id, limit)
  2914. return jsonify(success_response(
  2915. response_text="获取用户对话列表成功",
  2916. data={
  2917. "user_id": user_id,
  2918. "conversations": conversations,
  2919. "total_count": len(conversations),
  2920. "limit": limit
  2921. }
  2922. )), 200
  2923. except Exception as e:
  2924. logger.error(f"❌ 异步获取用户 {user_id} 对话列表失败: {e}")
  2925. return jsonify(internal_error_response(
  2926. response_text=f"获取用户对话列表失败: {str(e)}"
  2927. )), 500
  2928. @app.route('/api/v0/react/conversations/<thread_id>', methods=['GET'])
  2929. async def get_user_conversation_detail_react(thread_id: str):
  2930. """异步获取特定对话的详细历史(从 custom_react_agent 迁移)"""
  2931. global _react_agent_instance
  2932. try:
  2933. # 从thread_id中提取user_id
  2934. user_id = thread_id.split(':')[0] if ':' in thread_id else 'unknown'
  2935. logger.info(f"📖 异步获取用户 {user_id} 的对话 {thread_id} 详情")
  2936. # 确保Agent可用
  2937. if not await ensure_agent_ready():
  2938. return jsonify(service_unavailable_response(
  2939. response_text="Agent 未就绪"
  2940. )), 503
  2941. # 获取查询参数
  2942. include_tools = request.args.get('include_tools', 'false').lower() == 'true'
  2943. # 直接调用异步方法
  2944. conversation_data = await _react_agent_instance.get_conversation_history(thread_id, include_tools=include_tools)
  2945. messages = conversation_data.get("messages", [])
  2946. logger.info(f"✅ 异步成功获取对话历史,消息数量: {len(messages)}")
  2947. if not messages:
  2948. return jsonify(not_found_response(
  2949. response_text=f"未找到对话 {thread_id}"
  2950. )), 404
  2951. # 格式化消息
  2952. formatted_messages = []
  2953. for msg in messages:
  2954. formatted_msg = {
  2955. "message_id": msg["id"], # id -> message_id
  2956. "role": msg["type"], # type -> role
  2957. "content": msg["content"],
  2958. "timestamp": _format_timestamp_to_china_time(msg["timestamp"]) # 转换为中国时区
  2959. }
  2960. formatted_messages.append(formatted_msg)
  2961. return jsonify(success_response(
  2962. response_text="获取对话详情成功",
  2963. data={
  2964. "user_id": user_id,
  2965. "thread_id": thread_id,
  2966. "conversation_id": thread_id, # 新增conversation_id字段
  2967. "message_count": len(formatted_messages),
  2968. "messages": formatted_messages,
  2969. "created_at": conversation_data.get("thread_created_at"), # 已经包含毫秒
  2970. "total_checkpoints": conversation_data.get("total_checkpoints", 0)
  2971. }
  2972. )), 200
  2973. except Exception as e:
  2974. import traceback
  2975. logger.error(f"❌ 异步获取对话 {thread_id} 详情失败: {e}")
  2976. logger.error(f"❌ 详细错误信息: {traceback.format_exc()}")
  2977. return jsonify(internal_error_response(
  2978. response_text=f"获取对话详情失败: {str(e)}"
  2979. )), 500
  2980. @app.route('/api/test/redis', methods=['GET'])
  2981. def test_redis_connection():
  2982. """测试Redis连接和基本查询(从 custom_react_agent 迁移)"""
  2983. try:
  2984. import redis
  2985. # 创建Redis连接
  2986. if react_agent_config:
  2987. r = redis.Redis(
  2988. host=react_agent_config.REDIS_HOST,
  2989. port=react_agent_config.REDIS_PORT,
  2990. db=react_agent_config.REDIS_DB,
  2991. password=react_agent_config.REDIS_PASSWORD,
  2992. decode_responses=True
  2993. )
  2994. else:
  2995. r = redis.Redis(host='localhost', port=6379, decode_responses=True)
  2996. r.ping()
  2997. # 扫描checkpoint keys
  2998. pattern = "checkpoint:*"
  2999. keys = []
  3000. cursor = 0
  3001. count = 0
  3002. while True:
  3003. cursor, batch = r.scan(cursor=cursor, match=pattern, count=100)
  3004. keys.extend(batch)
  3005. count += len(batch)
  3006. if cursor == 0 or count > 500: # 限制扫描数量
  3007. break
  3008. # 统计用户
  3009. users = {}
  3010. for key in keys:
  3011. try:
  3012. parts = key.split(':')
  3013. if len(parts) >= 2:
  3014. user_id = parts[1]
  3015. users[user_id] = users.get(user_id, 0) + 1
  3016. except:
  3017. continue
  3018. r.close()
  3019. return jsonify({
  3020. "success": True,
  3021. "data": {
  3022. "redis_connected": True,
  3023. "total_checkpoint_keys": len(keys),
  3024. "users_found": list(users.keys()),
  3025. "user_key_counts": users,
  3026. "sample_keys": keys[:5] if keys else []
  3027. },
  3028. "timestamp": datetime.now().isoformat()
  3029. }), 200
  3030. except Exception as e:
  3031. logger.error(f"❌ Redis测试失败: {e}")
  3032. return jsonify({
  3033. "success": False,
  3034. "error": str(e),
  3035. "timestamp": datetime.now().isoformat()
  3036. }), 500
  3037. @app.route('/api/v0/react/direct/users/<user_id>/conversations', methods=['GET'])
  3038. def test_get_user_conversations_simple(user_id: str):
  3039. """直接从Redis获取用户对话列表"""
  3040. try:
  3041. limit = request.args.get('limit', 10, type=int)
  3042. limit = max(1, min(limit, 50))
  3043. logger.info(f"📋 获取用户 {user_id} 的对话列表(直接Redis方式)")
  3044. # 使用简单Redis查询
  3045. conversations = get_user_conversations_simple_sync(user_id, limit)
  3046. return jsonify(success_response(
  3047. response_text="获取用户对话列表成功",
  3048. data={
  3049. "user_id": user_id,
  3050. "conversations": conversations,
  3051. "total_count": len(conversations),
  3052. "limit": limit
  3053. }
  3054. )), 200
  3055. except Exception as e:
  3056. logger.error(f"❌ 获取用户对话列表失败: {e}")
  3057. return jsonify(internal_error_response(
  3058. response_text=f"获取用户对话列表失败: {str(e)}"
  3059. )), 500
  3060. @app.route('/api/v0/react/direct/conversations/<thread_id>', methods=['GET'])
  3061. def get_conversation_detail_api(thread_id: str):
  3062. """
  3063. 获取特定对话的详细信息 - 支持include_tools开关参数(从 custom_react_agent 迁移)
  3064. Query Parameters:
  3065. - include_tools: bool, 是否包含工具调用信息,默认false
  3066. true: 返回完整对话(human/ai/tool/system)
  3067. false: 只返回human/ai消息,清理工具调用信息
  3068. - user_id: str, 可选的用户ID验证
  3069. Examples:
  3070. GET /api/conversations/wang:20250709195048728?include_tools=true # 完整模式
  3071. GET /api/conversations/wang:20250709195048728?include_tools=false # 简化模式(默认)
  3072. GET /api/conversations/wang:20250709195048728 # 简化模式(默认)
  3073. """
  3074. try:
  3075. # 获取查询参数
  3076. include_tools = request.args.get('include_tools', 'false').lower() == 'true'
  3077. user_id = request.args.get('user_id')
  3078. # 验证thread_id格式
  3079. if ':' not in thread_id:
  3080. return jsonify({
  3081. "success": False,
  3082. "error": "Invalid thread_id format. Expected format: user_id:timestamp",
  3083. "timestamp": datetime.now().isoformat()
  3084. }), 400
  3085. # 如果提供了user_id,验证thread_id是否属于该用户
  3086. thread_user_id = thread_id.split(':')[0]
  3087. if user_id and thread_user_id != user_id:
  3088. return jsonify(bad_request_response(
  3089. response_text=f"Thread ID {thread_id} does not belong to user {user_id}"
  3090. )), 400
  3091. logger.info(f"📖 获取对话详情 - Thread: {thread_id}, Include Tools: {include_tools}")
  3092. # 检查enhanced_redis_api是否可用
  3093. if get_conversation_detail_from_redis is None:
  3094. return jsonify(service_unavailable_response(
  3095. response_text="enhanced_redis_api 模块不可用"
  3096. )), 503
  3097. # 从Redis获取对话详情(使用我们的新函数)
  3098. result = get_conversation_detail_from_redis(thread_id, include_tools)
  3099. if not result['success']:
  3100. logger.warning(f"⚠️ 获取对话详情失败: {result['error']}")
  3101. return jsonify(internal_error_response(
  3102. response_text=result['error']
  3103. )), 404
  3104. # 添加API元数据
  3105. result['data']['api_metadata'] = {
  3106. "timestamp": datetime.now().isoformat(),
  3107. "api_version": "v1",
  3108. "endpoint": "get_conversation_detail",
  3109. "query_params": {
  3110. "include_tools": include_tools,
  3111. "user_id": user_id
  3112. }
  3113. }
  3114. mode_desc = "完整模式" if include_tools else "简化模式"
  3115. logger.info(f"✅ 成功获取对话详情 - Messages: {result['data']['message_count']}, Mode: {mode_desc}")
  3116. return jsonify(success_response(
  3117. response_text=f"获取对话详情成功 ({mode_desc})",
  3118. data=result['data']
  3119. )), 200
  3120. except Exception as e:
  3121. import traceback
  3122. logger.error(f"❌ 获取对话详情异常: {e}")
  3123. logger.error(f"❌ 详细错误信息: {traceback.format_exc()}")
  3124. return jsonify(internal_error_response(
  3125. response_text=f"获取对话详情失败: {str(e)}"
  3126. )), 500
  3127. @app.route('/api/v0/react/direct/conversations/<thread_id>/compare', methods=['GET'])
  3128. def compare_conversation_modes_api(thread_id: str):
  3129. """
  3130. 比较完整模式和简化模式的对话内容
  3131. 用于调试和理解两种模式的差异(从 custom_react_agent 迁移)
  3132. Examples:
  3133. GET /api/conversations/wang:20250709195048728/compare
  3134. """
  3135. try:
  3136. logger.info(f"🔍 比较对话模式 - Thread: {thread_id}")
  3137. # 检查enhanced_redis_api是否可用
  3138. if get_conversation_detail_from_redis is None:
  3139. return jsonify({
  3140. "success": False,
  3141. "error": "enhanced_redis_api 模块不可用",
  3142. "timestamp": datetime.now().isoformat()
  3143. }), 503
  3144. # 获取完整模式
  3145. full_result = get_conversation_detail_from_redis(thread_id, include_tools=True)
  3146. # 获取简化模式
  3147. simple_result = get_conversation_detail_from_redis(thread_id, include_tools=False)
  3148. if not (full_result['success'] and simple_result['success']):
  3149. return jsonify({
  3150. "success": False,
  3151. "error": "无法获取对话数据进行比较",
  3152. "timestamp": datetime.now().isoformat()
  3153. }), 404
  3154. # 构建比较结果
  3155. comparison = {
  3156. "thread_id": thread_id,
  3157. "full_mode": {
  3158. "message_count": full_result['data']['message_count'],
  3159. "stats": full_result['data']['stats'],
  3160. "sample_messages": full_result['data']['messages'][:3] # 只显示前3条作为示例
  3161. },
  3162. "simple_mode": {
  3163. "message_count": simple_result['data']['message_count'],
  3164. "stats": simple_result['data']['stats'],
  3165. "sample_messages": simple_result['data']['messages'][:3] # 只显示前3条作为示例
  3166. },
  3167. "comparison_summary": {
  3168. "message_count_difference": full_result['data']['message_count'] - simple_result['data']['message_count'],
  3169. "tools_filtered_out": full_result['data']['stats'].get('tool_messages', 0),
  3170. "ai_messages_with_tools": full_result['data']['stats'].get('messages_with_tools', 0),
  3171. "filtering_effectiveness": "有效" if (full_result['data']['message_count'] - simple_result['data']['message_count']) > 0 else "无差异"
  3172. },
  3173. "metadata": {
  3174. "timestamp": datetime.now().isoformat(),
  3175. "note": "sample_messages 只显示前3条消息作为示例,完整消息请使用相应的详情API"
  3176. }
  3177. }
  3178. logger.info(f"✅ 模式比较完成 - 完整: {comparison['full_mode']['message_count']}, 简化: {comparison['simple_mode']['message_count']}")
  3179. return jsonify({
  3180. "success": True,
  3181. "data": comparison,
  3182. "timestamp": datetime.now().isoformat()
  3183. }), 200
  3184. except Exception as e:
  3185. logger.error(f"❌ 对话模式比较失败: {e}")
  3186. return jsonify({
  3187. "success": False,
  3188. "error": str(e),
  3189. "timestamp": datetime.now().isoformat()
  3190. }), 500
  3191. @app.route('/api/v0/react/direct/conversations/<thread_id>/summary', methods=['GET'])
  3192. def get_conversation_summary_api(thread_id: str):
  3193. """
  3194. 获取对话摘要信息(只包含基本统计,不返回具体消息)(从 custom_react_agent 迁移)
  3195. Query Parameters:
  3196. - include_tools: bool, 影响统计信息的计算方式
  3197. Examples:
  3198. GET /api/conversations/wang:20250709195048728/summary?include_tools=true
  3199. """
  3200. try:
  3201. include_tools = request.args.get('include_tools', 'false').lower() == 'true'
  3202. # 验证thread_id格式
  3203. if ':' not in thread_id:
  3204. return jsonify(bad_request_response(
  3205. response_text="Invalid thread_id format. Expected format: user_id:timestamp"
  3206. )), 400
  3207. logger.info(f"📊 获取对话摘要 - Thread: {thread_id}, Include Tools: {include_tools}")
  3208. # 检查enhanced_redis_api是否可用
  3209. if get_conversation_detail_from_redis is None:
  3210. return jsonify(service_unavailable_response(
  3211. response_text="enhanced_redis_api 模块不可用"
  3212. )), 503
  3213. # 获取完整对话信息
  3214. result = get_conversation_detail_from_redis(thread_id, include_tools)
  3215. if not result['success']:
  3216. return jsonify(internal_error_response(
  3217. response_text=result['error']
  3218. )), 404
  3219. # 只返回摘要信息,不包含具体消息
  3220. data = result['data']
  3221. summary = {
  3222. "thread_id": data['thread_id'],
  3223. "user_id": data['user_id'],
  3224. "include_tools": data['include_tools'],
  3225. "message_count": data['message_count'],
  3226. "stats": data['stats'],
  3227. "metadata": data['metadata'],
  3228. "first_message_preview": None,
  3229. "last_message_preview": None,
  3230. "conversation_preview": None
  3231. }
  3232. # 添加消息预览
  3233. messages = data.get('messages', [])
  3234. if messages:
  3235. # 第一条human消息预览
  3236. for msg in messages:
  3237. if msg['role'] == 'human':
  3238. content = str(msg['content'])
  3239. summary['first_message_preview'] = content[:100] + "..." if len(content) > 100 else content
  3240. break
  3241. # 最后一条ai消息预览
  3242. for msg in reversed(messages):
  3243. if msg['role'] == 'ai' and msg.get('content', '').strip():
  3244. content = str(msg['content'])
  3245. summary['last_message_preview'] = content[:100] + "..." if len(content) > 100 else content
  3246. break
  3247. # 生成对话预览(第一条human消息)
  3248. summary['conversation_preview'] = summary['first_message_preview']
  3249. # 添加API元数据
  3250. summary['api_metadata'] = {
  3251. "timestamp": datetime.now().isoformat(),
  3252. "api_version": "v1",
  3253. "endpoint": "get_conversation_summary"
  3254. }
  3255. logger.info(f"✅ 成功获取对话摘要")
  3256. return jsonify(success_response(
  3257. response_text="获取对话摘要成功",
  3258. data=summary
  3259. )), 200
  3260. except Exception as e:
  3261. logger.error(f"❌ 获取对话摘要失败: {e}")
  3262. return jsonify(internal_error_response(
  3263. response_text=f"获取对话摘要失败: {str(e)}"
  3264. )), 500
  3265. # ================== Checkpoint 管理 API ==================
  3266. @app.route('/api/v0/checkpoint/direct/cleanup', methods=['POST'])
  3267. async def cleanup_checkpoints():
  3268. """
  3269. 清理checkpoint,保留最近N个
  3270. 请求参数:
  3271. - keep_count: 可选,保留数量,默认使用配置值
  3272. - user_id: 可选,指定用户ID
  3273. - thread_id: 可选,指定线程ID
  3274. 参数逻辑:
  3275. - 无任何参数:清理所有thread_id的checkpoint
  3276. - 只有user_id:清理指定用户的所有thread
  3277. - 只有thread_id:清理指定的thread
  3278. - user_id和thread_id同时存在:以thread_id为准
  3279. """
  3280. try:
  3281. # 获取请求参数
  3282. data = request.get_json() or {}
  3283. keep_count = data.get('keep_count', react_agent_config.CHECKPOINT_KEEP_COUNT)
  3284. user_id = data.get('user_id')
  3285. thread_id = data.get('thread_id')
  3286. logger.info(f"🧹 开始checkpoint清理 - keep_count: {keep_count}, user_id: {user_id}, thread_id: {thread_id}")
  3287. # 参数验证
  3288. if keep_count <= 0:
  3289. return jsonify(bad_request_response(
  3290. response_text="keep_count必须大于0"
  3291. )), 400
  3292. # 验证thread_id格式
  3293. if thread_id and ':' not in thread_id:
  3294. return jsonify(bad_request_response(
  3295. response_text="thread_id格式错误,期望格式: user_id:timestamp"
  3296. )), 400
  3297. # 创建Redis连接(异步版本)
  3298. redis_client = redis.Redis(
  3299. host=react_agent_config.REDIS_HOST,
  3300. port=react_agent_config.REDIS_PORT,
  3301. db=react_agent_config.REDIS_DB,
  3302. password=react_agent_config.REDIS_PASSWORD,
  3303. decode_responses=True
  3304. )
  3305. await redis_client.ping()
  3306. # 确定扫描模式和操作类型
  3307. if thread_id:
  3308. # 清理指定thread
  3309. pattern = f"checkpoint:{thread_id}:*"
  3310. operation_type = "cleanup_thread"
  3311. target = thread_id
  3312. elif user_id:
  3313. # 清理指定用户的所有thread
  3314. pattern = f"checkpoint:{user_id}:*"
  3315. operation_type = "cleanup_user"
  3316. target = user_id
  3317. else:
  3318. # 清理所有thread
  3319. pattern = "checkpoint:*"
  3320. operation_type = "cleanup_all"
  3321. target = "all"
  3322. logger.info(f" 扫描模式: {pattern}")
  3323. # 扫描匹配的keys
  3324. keys = []
  3325. cursor = 0
  3326. while True:
  3327. cursor, batch = await redis_client.scan(cursor=cursor, match=pattern, count=1000)
  3328. keys.extend(batch)
  3329. if cursor == 0:
  3330. break
  3331. logger.info(f" 找到 {len(keys)} 个checkpoint keys")
  3332. if not keys:
  3333. await redis_client.close()
  3334. return jsonify(success_response(
  3335. response_text="未找到需要清理的checkpoint",
  3336. data={
  3337. "operation_type": operation_type,
  3338. "target": target,
  3339. "keep_count": keep_count,
  3340. "total_processed": 0,
  3341. "total_deleted": 0,
  3342. "details": {}
  3343. }
  3344. )), 200
  3345. # 按thread_id分组
  3346. thread_groups = {}
  3347. for key in keys:
  3348. parts = key.split(':')
  3349. if len(parts) >= 3:
  3350. key_user_id = parts[1]
  3351. timestamp = parts[2]
  3352. key_thread_id = f"{key_user_id}:{timestamp}"
  3353. if key_thread_id not in thread_groups:
  3354. thread_groups[key_thread_id] = []
  3355. thread_groups[key_thread_id].append(key)
  3356. logger.info(f" 分组结果: {len(thread_groups)} 个threads")
  3357. # 清理每个thread的checkpoint
  3358. details = {}
  3359. total_deleted = 0
  3360. total_processed = 0
  3361. for tid, tid_keys in thread_groups.items():
  3362. original_count = len(tid_keys)
  3363. if original_count <= keep_count:
  3364. # 无需清理
  3365. details[tid] = {
  3366. "original_count": original_count,
  3367. "deleted_count": 0,
  3368. "remaining_count": original_count,
  3369. "status": "no_cleanup_needed"
  3370. }
  3371. total_processed += 1
  3372. continue
  3373. # 按key排序(key包含timestamp,天然有序)
  3374. tid_keys.sort()
  3375. keys_to_delete = tid_keys[:-keep_count]
  3376. # 使用Redis Pipeline批量删除
  3377. deleted_count = 0
  3378. if keys_to_delete:
  3379. try:
  3380. pipeline = redis_client.pipeline()
  3381. for key in keys_to_delete:
  3382. pipeline.delete(key)
  3383. await pipeline.execute()
  3384. deleted_count = len(keys_to_delete)
  3385. logger.info(f" Thread {tid}: 删除了 {deleted_count} 个checkpoint")
  3386. except Exception as e:
  3387. logger.error(f" Thread {tid}: 批量删除失败: {e}")
  3388. # 尝试逐个删除
  3389. for key in keys_to_delete:
  3390. try:
  3391. await redis_client.delete(key)
  3392. deleted_count += 1
  3393. except Exception as del_error:
  3394. logger.error(f" 删除key失败: {key}, 错误: {del_error}")
  3395. details[tid] = {
  3396. "original_count": original_count,
  3397. "deleted_count": deleted_count,
  3398. "remaining_count": original_count - deleted_count,
  3399. "status": "success" if deleted_count > 0 else "failed"
  3400. }
  3401. total_deleted += deleted_count
  3402. total_processed += 1
  3403. await redis_client.aclose()
  3404. logger.info(f"✅ Checkpoint清理完成 - 处理{total_processed}个threads,删除{total_deleted}个checkpoints")
  3405. return jsonify(success_response(
  3406. response_text=f"Checkpoint清理完成,删除了{total_deleted}个checkpoint",
  3407. data={
  3408. "operation_type": operation_type,
  3409. "target": target,
  3410. "keep_count": keep_count,
  3411. "total_processed": total_processed,
  3412. "total_deleted": total_deleted,
  3413. "details": details
  3414. }
  3415. )), 200
  3416. except redis.ConnectionError as e:
  3417. logger.error(f"❌ Redis连接失败: {e}")
  3418. return jsonify(internal_error_response(
  3419. response_text="Redis连接失败,请检查Redis服务状态"
  3420. )), 500
  3421. except Exception as e:
  3422. logger.error(f"❌ Checkpoint清理失败: {e}")
  3423. return jsonify(internal_error_response(
  3424. response_text=f"Checkpoint清理失败: {str(e)}"
  3425. )), 500
  3426. @app.route('/api/v0/checkpoint/direct/stats', methods=['GET'])
  3427. async def get_checkpoint_stats():
  3428. """
  3429. 获取checkpoint统计信息
  3430. 查询参数:
  3431. - user_id: 可选,指定用户ID
  3432. 调用方式:
  3433. GET /api/v0/checkpoint/direct/stats # 获取全部统计信息
  3434. GET /api/v0/checkpoint/direct/stats?user_id=wang1 # 获取指定用户统计信息
  3435. """
  3436. try:
  3437. user_id = request.args.get('user_id')
  3438. logger.info(f"📊 获取checkpoint统计 - user_id: {user_id}")
  3439. # 创建Redis连接(异步版本)
  3440. redis_client = redis.Redis(
  3441. host=react_agent_config.REDIS_HOST,
  3442. port=react_agent_config.REDIS_PORT,
  3443. db=react_agent_config.REDIS_DB,
  3444. password=react_agent_config.REDIS_PASSWORD,
  3445. decode_responses=True
  3446. )
  3447. await redis_client.ping()
  3448. # 确定扫描模式
  3449. if user_id:
  3450. pattern = f"checkpoint:{user_id}:*"
  3451. operation_type = "user_stats"
  3452. else:
  3453. pattern = "checkpoint:*"
  3454. operation_type = "system_stats"
  3455. logger.info(f" 扫描模式: {pattern}")
  3456. # 扫描匹配的keys
  3457. keys = []
  3458. cursor = 0
  3459. while True:
  3460. cursor, batch = await redis_client.scan(cursor=cursor, match=pattern, count=1000)
  3461. keys.extend(batch)
  3462. if cursor == 0:
  3463. break
  3464. logger.info(f" 找到 {len(keys)} 个checkpoint keys")
  3465. await redis_client.aclose()
  3466. if not keys:
  3467. if user_id:
  3468. return jsonify(not_found_response(
  3469. response_text=f"用户 {user_id} 没有任何checkpoint"
  3470. )), 404
  3471. else:
  3472. return jsonify(success_response(
  3473. response_text="系统中暂无checkpoint数据",
  3474. data={
  3475. "operation_type": operation_type,
  3476. "total_users": 0,
  3477. "total_threads": 0,
  3478. "total_checkpoints": 0,
  3479. "users": []
  3480. }
  3481. )), 200
  3482. # 按用户和thread分组统计
  3483. user_stats = {}
  3484. for key in keys:
  3485. parts = key.split(':')
  3486. if len(parts) >= 3:
  3487. key_user_id = parts[1]
  3488. timestamp = parts[2]
  3489. thread_id = f"{key_user_id}:{timestamp}"
  3490. if key_user_id not in user_stats:
  3491. user_stats[key_user_id] = {}
  3492. if thread_id not in user_stats[key_user_id]:
  3493. user_stats[key_user_id][thread_id] = 0
  3494. user_stats[key_user_id][thread_id] += 1
  3495. # 构建响应数据
  3496. if user_id:
  3497. # 返回指定用户的统计信息
  3498. if user_id not in user_stats:
  3499. return jsonify(not_found_response(
  3500. response_text=f"用户 {user_id} 没有任何checkpoint"
  3501. )), 404
  3502. threads = []
  3503. total_checkpoints = 0
  3504. for thread_id, count in user_stats[user_id].items():
  3505. threads.append({
  3506. "thread_id": thread_id,
  3507. "checkpoint_count": count
  3508. })
  3509. total_checkpoints += count
  3510. # 按checkpoint数量排序
  3511. threads.sort(key=lambda x: x["checkpoint_count"], reverse=True)
  3512. result_data = {
  3513. "operation_type": operation_type,
  3514. "user_id": user_id,
  3515. "thread_count": len(threads),
  3516. "total_checkpoints": total_checkpoints,
  3517. "threads": threads
  3518. }
  3519. logger.info(f"✅ 获取用户 {user_id} 统计完成 - {len(threads)} threads, {total_checkpoints} checkpoints")
  3520. return jsonify(success_response(
  3521. response_text=f"获取用户{user_id}统计成功",
  3522. data=result_data
  3523. )), 200
  3524. else:
  3525. # 返回系统全部统计信息
  3526. users = []
  3527. total_threads = 0
  3528. total_checkpoints = 0
  3529. for uid, threads_data in user_stats.items():
  3530. user_threads = []
  3531. user_total_checkpoints = 0
  3532. for thread_id, count in threads_data.items():
  3533. user_threads.append({
  3534. "thread_id": thread_id,
  3535. "checkpoint_count": count
  3536. })
  3537. user_total_checkpoints += count
  3538. # 按checkpoint数量排序
  3539. user_threads.sort(key=lambda x: x["checkpoint_count"], reverse=True)
  3540. users.append({
  3541. "user_id": uid,
  3542. "thread_count": len(user_threads),
  3543. "total_checkpoints": user_total_checkpoints,
  3544. "threads": user_threads
  3545. })
  3546. total_threads += len(user_threads)
  3547. total_checkpoints += user_total_checkpoints
  3548. # 按用户的checkpoint数量排序
  3549. users.sort(key=lambda x: x["total_checkpoints"], reverse=True)
  3550. result_data = {
  3551. "operation_type": operation_type,
  3552. "total_users": len(users),
  3553. "total_threads": total_threads,
  3554. "total_checkpoints": total_checkpoints,
  3555. "users": users
  3556. }
  3557. logger.info(f"✅ 获取系统统计完成 - {len(users)} users, {total_threads} threads, {total_checkpoints} checkpoints")
  3558. return jsonify(success_response(
  3559. response_text="获取系统checkpoint统计成功",
  3560. data=result_data
  3561. )), 200
  3562. except redis.ConnectionError as e:
  3563. logger.error(f"❌ Redis连接失败: {e}")
  3564. return jsonify(internal_error_response(
  3565. response_text="Redis连接失败,请检查Redis服务状态"
  3566. )), 500
  3567. except Exception as e:
  3568. logger.error(f"❌ 获取checkpoint统计失败: {e}")
  3569. return jsonify(internal_error_response(
  3570. response_text=f"获取checkpoint统计失败: {str(e)}"
  3571. )), 500
  3572. # Data Pipeline 全局变量 - 从 citu_app.py 迁移
  3573. data_pipeline_manager = None
  3574. data_pipeline_file_manager = None
  3575. def get_data_pipeline_manager():
  3576. """获取Data Pipeline管理器单例(从 citu_app.py 迁移)"""
  3577. global data_pipeline_manager
  3578. if data_pipeline_manager is None:
  3579. data_pipeline_manager = SimpleWorkflowManager()
  3580. return data_pipeline_manager
  3581. def get_data_pipeline_file_manager():
  3582. """获取Data Pipeline文件管理器单例(从 citu_app.py 迁移)"""
  3583. global data_pipeline_file_manager
  3584. if data_pipeline_file_manager is None:
  3585. data_pipeline_file_manager = SimpleFileManager()
  3586. return data_pipeline_file_manager
  3587. # ==================== QA缓存管理API (从 citu_app.py 迁移) ====================
  3588. @app.route('/api/v0/qa_cache_stats', methods=['GET'])
  3589. def qa_cache_stats():
  3590. """获取问答缓存统计信息(从 citu_app.py 迁移)"""
  3591. try:
  3592. stats = redis_conversation_manager.get_qa_cache_stats()
  3593. return jsonify(success_response(
  3594. response_text="获取问答缓存统计成功",
  3595. data=stats
  3596. ))
  3597. except Exception as e:
  3598. logger.error(f"获取问答缓存统计失败: {str(e)}")
  3599. return jsonify(internal_error_response(
  3600. response_text="获取问答缓存统计失败,请稍后重试"
  3601. )), 500
  3602. @app.route('/api/v0/qa_cache_cleanup', methods=['POST'])
  3603. def qa_cache_cleanup():
  3604. """清空所有问答缓存(从 citu_app.py 迁移)"""
  3605. try:
  3606. if not redis_conversation_manager.is_available():
  3607. return jsonify(internal_error_response(
  3608. response_text="Redis连接不可用,无法执行清理操作"
  3609. )), 500
  3610. deleted_count = redis_conversation_manager.clear_all_qa_cache()
  3611. return jsonify(success_response(
  3612. response_text="问答缓存清理完成",
  3613. data={
  3614. "deleted_count": deleted_count,
  3615. "cleared": deleted_count > 0,
  3616. "cleanup_time": datetime.now().isoformat()
  3617. }
  3618. ))
  3619. except Exception as e:
  3620. logger.error(f"清空问答缓存失败: {str(e)}")
  3621. return jsonify(internal_error_response(
  3622. response_text="清空问答缓存失败,请稍后重试"
  3623. )), 500
  3624. # ==================== Database API (从 citu_app.py 迁移) ====================
  3625. @app.route('/api/v0/database/tables', methods=['POST'])
  3626. def get_database_tables():
  3627. """
  3628. 获取数据库表列表(从 citu_app.py 迁移)
  3629. 请求体:
  3630. {
  3631. "db_connection": "postgresql://postgres:postgres@192.168.67.1:5432/highway_db", // 可选,不传则使用默认配置
  3632. "schema": "public,ods", // 可选,支持多个schema用逗号分隔,默认为public
  3633. "table_name_pattern": "ods_*" // 可选,表名模式匹配,支持通配符:ods_*、*_dim、*fact*、ods_%
  3634. }
  3635. 响应:
  3636. {
  3637. "success": true,
  3638. "code": 200,
  3639. "message": "获取表列表成功",
  3640. "data": {
  3641. "tables": ["public.table1", "public.table2", "ods.table3"],
  3642. "total": 3,
  3643. "schemas": ["public", "ods"],
  3644. "table_name_pattern": "ods_*"
  3645. }
  3646. }
  3647. """
  3648. try:
  3649. req = request.get_json(force=True)
  3650. # 处理数据库连接参数(可选)
  3651. db_connection = req.get('db_connection')
  3652. if not db_connection:
  3653. # 使用app_config的默认数据库配置
  3654. import app_config
  3655. db_params = app_config.APP_DB_CONFIG
  3656. db_connection = f"postgresql://{db_params['user']}:{db_params['password']}@{db_params['host']}:{db_params['port']}/{db_params['dbname']}"
  3657. logger.info("使用默认数据库配置获取表列表")
  3658. else:
  3659. logger.info("使用用户指定的数据库配置获取表列表")
  3660. # 可选参数
  3661. schema = req.get('schema', '')
  3662. table_name_pattern = req.get('table_name_pattern')
  3663. # 创建表检查API实例
  3664. table_inspector = TableInspectorAPI()
  3665. # 使用asyncio运行异步方法
  3666. async def get_tables():
  3667. return await table_inspector.get_tables_list(db_connection, schema, table_name_pattern)
  3668. # 在新的事件循环中运行异步方法
  3669. try:
  3670. loop = asyncio.new_event_loop()
  3671. asyncio.set_event_loop(loop)
  3672. tables = loop.run_until_complete(get_tables())
  3673. finally:
  3674. loop.close()
  3675. # 解析schema信息
  3676. parsed_schemas = table_inspector._parse_schemas(schema)
  3677. response_data = {
  3678. "tables": tables,
  3679. "total": len(tables),
  3680. "schemas": parsed_schemas,
  3681. "db_connection_info": {
  3682. "database": db_connection.split('/')[-1].split('?')[0] if '/' in db_connection else "unknown"
  3683. }
  3684. }
  3685. # 如果使用了表名模式,添加到响应中
  3686. if table_name_pattern:
  3687. response_data["table_name_pattern"] = table_name_pattern
  3688. return jsonify(success_response(
  3689. response_text="获取表列表成功",
  3690. data=response_data
  3691. )), 200
  3692. except Exception as e:
  3693. logger.error(f"获取数据库表列表失败: {str(e)}")
  3694. return jsonify(internal_error_response(
  3695. response_text=f"获取表列表失败: {str(e)}"
  3696. )), 500
  3697. @app.route('/api/v0/database/table/ddl', methods=['POST'])
  3698. def get_table_ddl():
  3699. """
  3700. 获取表的DDL语句或MD文档(从 citu_app.py 迁移)
  3701. 请求体:
  3702. {
  3703. "db_connection": "postgresql://postgres:postgres@192.168.67.1:5432/highway_db", // 可选,不传则使用默认配置
  3704. "table": "public.test",
  3705. "business_context": "这是高速公路服务区的相关数据", // 可选
  3706. "type": "ddl" // 可选,支持ddl/md/both,默认为ddl
  3707. }
  3708. 响应:
  3709. {
  3710. "success": true,
  3711. "code": 200,
  3712. "message": "获取表DDL成功",
  3713. "data": {
  3714. "ddl": "create table public.test (...);",
  3715. "md": "## test表...", // 仅当type为md或both时返回
  3716. "table_info": {
  3717. "table_name": "test",
  3718. "schema_name": "public",
  3719. "full_name": "public.test",
  3720. "comment": "测试表",
  3721. "field_count": 10,
  3722. "row_count": 1000
  3723. },
  3724. "fields": [...]
  3725. }
  3726. }
  3727. """
  3728. try:
  3729. req = request.get_json(force=True)
  3730. # 处理参数(table仍为必需,db_connection可选)
  3731. table = req.get('table')
  3732. db_connection = req.get('db_connection')
  3733. if not table:
  3734. return jsonify(bad_request_response(
  3735. response_text="缺少必需参数:table",
  3736. missing_params=['table']
  3737. )), 400
  3738. if not db_connection:
  3739. # 使用app_config的默认数据库配置
  3740. import app_config
  3741. db_params = app_config.APP_DB_CONFIG
  3742. db_connection = f"postgresql://{db_params['user']}:{db_params['password']}@{db_params['host']}:{db_params['port']}/{db_params['dbname']}"
  3743. logger.info("使用默认数据库配置获取表DDL")
  3744. else:
  3745. logger.info("使用用户指定的数据库配置获取表DDL")
  3746. # 可选参数
  3747. business_context = req.get('business_context', '')
  3748. output_type = req.get('type', 'ddl')
  3749. # 验证type参数
  3750. valid_types = ['ddl', 'md', 'both']
  3751. if output_type not in valid_types:
  3752. return jsonify(bad_request_response(
  3753. response_text=f"无效的type参数: {output_type},支持的值: {valid_types}",
  3754. invalid_params=['type']
  3755. )), 400
  3756. # 创建表检查API实例
  3757. table_inspector = TableInspectorAPI()
  3758. # 使用asyncio运行异步方法
  3759. async def get_ddl():
  3760. return await table_inspector.get_table_ddl(
  3761. db_connection=db_connection,
  3762. table=table,
  3763. business_context=business_context,
  3764. output_type=output_type
  3765. )
  3766. # 在新的事件循环中运行异步方法
  3767. try:
  3768. loop = asyncio.new_event_loop()
  3769. asyncio.set_event_loop(loop)
  3770. result = loop.run_until_complete(get_ddl())
  3771. finally:
  3772. loop.close()
  3773. response_data = {
  3774. **result,
  3775. "generation_info": {
  3776. "business_context": business_context,
  3777. "output_type": output_type,
  3778. "has_llm_comments": bool(business_context),
  3779. "database": db_connection.split('/')[-1].split('?')[0] if '/' in db_connection else "unknown"
  3780. }
  3781. }
  3782. return jsonify(success_response(
  3783. response_text=f"获取表{output_type.upper()}成功",
  3784. data=response_data
  3785. )), 200
  3786. except Exception as e:
  3787. logger.error(f"获取表DDL失败: {str(e)}")
  3788. return jsonify(internal_error_response(
  3789. response_text=f"获取表{output_type.upper() if 'output_type' in locals() else 'DDL'}失败: {str(e)}"
  3790. )), 500
  3791. # ==================== Data Pipeline API (从 citu_app.py 迁移) ====================
  3792. @app.route('/api/v0/data_pipeline/tasks', methods=['POST'])
  3793. def create_data_pipeline_task():
  3794. """创建数据管道任务(从 citu_app.py 迁移)"""
  3795. try:
  3796. req = request.get_json(force=True)
  3797. # table_list_file和business_context现在都是可选参数
  3798. # 如果未提供table_list_file,将使用文件上传模式
  3799. # 创建任务(支持可选的db_connection参数)
  3800. manager = get_data_pipeline_manager()
  3801. task_id = manager.create_task(
  3802. table_list_file=req.get('table_list_file'),
  3803. business_context=req.get('business_context'),
  3804. db_name=req.get('db_name'), # 可选参数,用于指定特定数据库名称
  3805. db_connection=req.get('db_connection'), # 可选参数,用于指定数据库连接字符串
  3806. task_name=req.get('task_name'), # 可选参数,用于指定任务名称
  3807. enable_sql_validation=req.get('enable_sql_validation', True),
  3808. enable_llm_repair=req.get('enable_llm_repair', True),
  3809. modify_original_file=req.get('modify_original_file', True),
  3810. enable_training_data_load=req.get('enable_training_data_load', True)
  3811. )
  3812. # 获取任务信息
  3813. task_info = manager.get_task_status(task_id)
  3814. response_data = {
  3815. "task_id": task_id,
  3816. "task_name": task_info.get('task_name'),
  3817. "status": task_info.get('status'),
  3818. "created_at": task_info.get('created_at').isoformat() if task_info.get('created_at') else None
  3819. }
  3820. # 检查是否为文件上传模式
  3821. file_upload_mode = not req.get('table_list_file')
  3822. response_message = "任务创建成功"
  3823. if file_upload_mode:
  3824. response_data["file_upload_mode"] = True
  3825. response_data["next_step"] = f"POST /api/v0/data_pipeline/tasks/{task_id}/upload-table-list"
  3826. response_message += ",请上传表清单文件后再执行任务"
  3827. return jsonify(success_response(
  3828. response_text=response_message,
  3829. data=response_data
  3830. )), 201
  3831. except Exception as e:
  3832. logger.error(f"创建数据管道任务失败: {str(e)}")
  3833. return jsonify(internal_error_response(
  3834. response_text="创建任务失败,请稍后重试"
  3835. )), 500
  3836. @app.route('/api/v0/data_pipeline/tasks/<task_id>/execute', methods=['POST'])
  3837. def execute_data_pipeline_task(task_id):
  3838. """执行数据管道任务(从 citu_app.py 迁移)"""
  3839. try:
  3840. req = request.get_json(force=True) if request.is_json else {}
  3841. execution_mode = req.get('execution_mode', 'complete')
  3842. step_name = req.get('step_name')
  3843. # 新增:Vector表管理参数
  3844. backup_vector_tables = req.get('backup_vector_tables', False)
  3845. truncate_vector_tables = req.get('truncate_vector_tables', False)
  3846. skip_training = req.get('skip_training', False)
  3847. # 验证执行模式
  3848. if execution_mode not in ['complete', 'step']:
  3849. return jsonify(bad_request_response(
  3850. response_text="无效的执行模式,必须是 'complete' 或 'step'",
  3851. invalid_params=['execution_mode']
  3852. )), 400
  3853. # 如果是步骤执行模式,验证步骤名称
  3854. if execution_mode == 'step':
  3855. if not step_name:
  3856. return jsonify(bad_request_response(
  3857. response_text="步骤执行模式需要指定step_name",
  3858. missing_params=['step_name']
  3859. )), 400
  3860. valid_steps = ['ddl_generation', 'qa_generation', 'sql_validation', 'training_load']
  3861. if step_name not in valid_steps:
  3862. return jsonify(bad_request_response(
  3863. response_text=f"无效的步骤名称,支持的步骤: {', '.join(valid_steps)}",
  3864. invalid_params=['step_name']
  3865. )), 400
  3866. # 新增:Vector表管理参数验证和警告
  3867. if execution_mode == 'step' and step_name != 'training_load':
  3868. if backup_vector_tables or truncate_vector_tables or skip_training:
  3869. logger.warning(
  3870. f"⚠️ Vector表管理参数仅在training_load步骤有效,当前步骤: {step_name},忽略参数"
  3871. )
  3872. backup_vector_tables = False
  3873. truncate_vector_tables = False
  3874. skip_training = False
  3875. # 检查任务是否存在
  3876. manager = get_data_pipeline_manager()
  3877. task_info = manager.get_task_status(task_id)
  3878. if not task_info:
  3879. return jsonify(not_found_response(
  3880. response_text=f"任务不存在: {task_id}"
  3881. )), 404
  3882. # 使用subprocess启动独立进程执行任务
  3883. def run_task_subprocess():
  3884. try:
  3885. import subprocess
  3886. import sys
  3887. from pathlib import Path
  3888. # 构建执行命令
  3889. python_executable = sys.executable
  3890. script_path = Path(__file__).parent / "data_pipeline" / "task_executor.py"
  3891. cmd = [
  3892. python_executable,
  3893. str(script_path),
  3894. "--task-id", task_id,
  3895. "--execution-mode", execution_mode
  3896. ]
  3897. if step_name:
  3898. cmd.extend(["--step-name", step_name])
  3899. # 新增:Vector表管理参数传递
  3900. if backup_vector_tables:
  3901. cmd.append("--backup-vector-tables")
  3902. if truncate_vector_tables:
  3903. cmd.append("--truncate-vector-tables")
  3904. if skip_training:
  3905. cmd.append("--skip-training")
  3906. logger.info(f"启动任务进程: {' '.join(cmd)}")
  3907. # 启动后台进程(不等待完成)
  3908. process = subprocess.Popen(
  3909. cmd,
  3910. stdout=subprocess.PIPE,
  3911. stderr=subprocess.PIPE,
  3912. text=True,
  3913. cwd=Path(__file__).parent
  3914. )
  3915. logger.info(f"任务进程已启动: PID={process.pid}, task_id={task_id}")
  3916. except Exception as e:
  3917. logger.error(f"启动任务进程失败: {task_id}, 错误: {str(e)}")
  3918. # 在新线程中启动subprocess(避免阻塞API响应)
  3919. thread = Thread(target=run_task_subprocess, daemon=True)
  3920. thread.start()
  3921. # 新增:记录Vector表管理参数到日志
  3922. if backup_vector_tables or truncate_vector_tables:
  3923. logger.info(f"📋 API请求包含Vector表管理参数: backup={backup_vector_tables}, truncate={truncate_vector_tables}")
  3924. response_data = {
  3925. "task_id": task_id,
  3926. "execution_mode": execution_mode,
  3927. "step_name": step_name if execution_mode == 'step' else None,
  3928. "message": "任务正在后台执行,请通过状态接口查询进度"
  3929. }
  3930. return jsonify(success_response(
  3931. response_text="任务执行已启动",
  3932. data=response_data
  3933. )), 202
  3934. except Exception as e:
  3935. logger.error(f"启动数据管道任务执行失败: {str(e)}")
  3936. return jsonify(internal_error_response(
  3937. response_text="启动任务执行失败,请稍后重试"
  3938. )), 500
  3939. @app.route('/api/v0/data_pipeline/tasks/<task_id>', methods=['GET'])
  3940. def get_data_pipeline_task_status(task_id):
  3941. """
  3942. 获取数据管道任务状态(从 citu_app.py 迁移)
  3943. 响应:
  3944. {
  3945. "success": true,
  3946. "code": 200,
  3947. "message": "获取任务状态成功",
  3948. "data": {
  3949. "task_id": "task_20250627_143052",
  3950. "status": "in_progress",
  3951. "step_status": {
  3952. "ddl_generation": "completed",
  3953. "qa_generation": "running",
  3954. "sql_validation": "pending",
  3955. "training_load": "pending"
  3956. },
  3957. "created_at": "2025-06-27T14:30:52",
  3958. "started_at": "2025-06-27T14:31:00",
  3959. "parameters": {...},
  3960. "current_execution": {...},
  3961. "total_executions": 2
  3962. }
  3963. }
  3964. """
  3965. try:
  3966. manager = get_data_pipeline_manager()
  3967. task_info = manager.get_task_status(task_id)
  3968. if not task_info:
  3969. return jsonify(not_found_response(
  3970. response_text=f"任务不存在: {task_id}"
  3971. )), 404
  3972. # 获取步骤状态
  3973. steps = manager.get_task_steps(task_id)
  3974. current_step = None
  3975. for step in steps:
  3976. if step['step_status'] == 'running':
  3977. current_step = step
  3978. break
  3979. # 构建步骤状态摘要
  3980. step_status_summary = {}
  3981. for step in steps:
  3982. step_status_summary[step['step_name']] = step['step_status']
  3983. response_data = {
  3984. "task_id": task_info['task_id'],
  3985. "task_name": task_info.get('task_name'),
  3986. "status": task_info['status'],
  3987. "step_status": step_status_summary,
  3988. "created_at": task_info['created_at'].isoformat() if task_info.get('created_at') else None,
  3989. "started_at": task_info['started_at'].isoformat() if task_info.get('started_at') else None,
  3990. "completed_at": task_info['completed_at'].isoformat() if task_info.get('completed_at') else None,
  3991. "parameters": task_info.get('parameters', {}),
  3992. "result": task_info.get('result'),
  3993. "error_message": task_info.get('error_message'),
  3994. "current_step": {
  3995. "execution_id": current_step['execution_id'],
  3996. "step": current_step['step_name'],
  3997. "status": current_step['step_status'],
  3998. "started_at": current_step['started_at'].isoformat() if current_step and current_step.get('started_at') else None
  3999. } if current_step else None,
  4000. "total_steps": len(steps),
  4001. "steps": [{
  4002. "step_name": step['step_name'],
  4003. "step_status": step['step_status'],
  4004. "started_at": step['started_at'].isoformat() if step.get('started_at') else None,
  4005. "completed_at": step['completed_at'].isoformat() if step.get('completed_at') else None,
  4006. "error_message": step.get('error_message')
  4007. } for step in steps]
  4008. }
  4009. return jsonify(success_response(
  4010. response_text="获取任务状态成功",
  4011. data=response_data
  4012. ))
  4013. except Exception as e:
  4014. logger.error(f"获取数据管道任务状态失败: {str(e)}")
  4015. return jsonify(internal_error_response(
  4016. response_text="获取任务状态失败,请稍后重试"
  4017. )), 500
  4018. @app.route('/api/v0/data_pipeline/tasks/<task_id>/logs', methods=['GET'])
  4019. def get_data_pipeline_task_logs(task_id):
  4020. """
  4021. 获取数据管道任务日志(从任务目录文件读取)(从 citu_app.py 迁移)
  4022. 查询参数:
  4023. - limit: 日志行数限制,默认100
  4024. - level: 日志级别过滤,可选
  4025. 响应:
  4026. {
  4027. "success": true,
  4028. "code": 200,
  4029. "message": "获取任务日志成功",
  4030. "data": {
  4031. "task_id": "task_20250627_143052",
  4032. "logs": [
  4033. {
  4034. "timestamp": "2025-06-27 14:30:52",
  4035. "level": "INFO",
  4036. "message": "任务开始执行"
  4037. }
  4038. ],
  4039. "total": 15,
  4040. "source": "file"
  4041. }
  4042. }
  4043. """
  4044. try:
  4045. limit = request.args.get('limit', 100, type=int)
  4046. level = request.args.get('level')
  4047. # 限制最大查询数量
  4048. limit = min(limit, 1000)
  4049. manager = get_data_pipeline_manager()
  4050. # 验证任务是否存在
  4051. task_info = manager.get_task_status(task_id)
  4052. if not task_info:
  4053. return jsonify(not_found_response(
  4054. response_text=f"任务不存在: {task_id}"
  4055. )), 404
  4056. # 获取任务目录下的日志文件
  4057. import os
  4058. from pathlib import Path
  4059. # 获取项目根目录的绝对路径
  4060. project_root = Path(__file__).parent.absolute()
  4061. task_dir = project_root / "data_pipeline" / "training_data" / task_id
  4062. log_file = task_dir / "data_pipeline.log"
  4063. logs = []
  4064. if log_file.exists():
  4065. try:
  4066. # 读取日志文件的最后N行
  4067. with open(log_file, 'r', encoding='utf-8') as f:
  4068. lines = f.readlines()
  4069. # 取最后limit行
  4070. recent_lines = lines[-limit:] if len(lines) > limit else lines
  4071. # 解析日志行
  4072. import re
  4073. log_pattern = r'^(\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}) \[(\w+)\] (.+?): (.+)$'
  4074. for line in recent_lines:
  4075. line = line.strip()
  4076. if not line:
  4077. continue
  4078. match = re.match(log_pattern, line)
  4079. if match:
  4080. timestamp, log_level, logger_name, message = match.groups()
  4081. # 级别过滤
  4082. if level and log_level != level.upper():
  4083. continue
  4084. logs.append({
  4085. "timestamp": timestamp,
  4086. "level": log_level,
  4087. "logger": logger_name,
  4088. "message": message
  4089. })
  4090. else:
  4091. # 处理多行日志(如异常堆栈)
  4092. if logs:
  4093. logs[-1]["message"] += f"\n{line}"
  4094. except Exception as e:
  4095. logger.error(f"读取日志文件失败: {e}")
  4096. response_data = {
  4097. "task_id": task_id,
  4098. "logs": logs,
  4099. "total": len(logs),
  4100. "source": "file",
  4101. "log_file": str(log_file) if log_file.exists() else None
  4102. }
  4103. return jsonify(success_response(
  4104. response_text="获取任务日志成功",
  4105. data=response_data
  4106. ))
  4107. except Exception as e:
  4108. logger.error(f"获取数据管道任务日志失败: {str(e)}")
  4109. return jsonify(internal_error_response(
  4110. response_text="获取任务日志失败,请稍后重试"
  4111. )), 500
  4112. @app.route('/api/v0/data_pipeline/tasks', methods=['GET'])
  4113. def list_data_pipeline_tasks():
  4114. """获取数据管道任务列表(从 citu_app.py 迁移)"""
  4115. try:
  4116. limit = request.args.get('limit', 50, type=int)
  4117. offset = request.args.get('offset', 0, type=int)
  4118. status_filter = request.args.get('status')
  4119. # 限制查询数量
  4120. limit = min(limit, 100)
  4121. manager = get_data_pipeline_manager()
  4122. tasks = manager.get_tasks_list(
  4123. limit=limit,
  4124. offset=offset,
  4125. status_filter=status_filter
  4126. )
  4127. # 格式化任务列表
  4128. formatted_tasks = []
  4129. for task in tasks:
  4130. formatted_tasks.append({
  4131. "task_id": task.get('task_id'),
  4132. "task_name": task.get('task_name'),
  4133. "status": task.get('status'),
  4134. "step_status": task.get('step_status'),
  4135. "created_at": task['created_at'].isoformat() if task.get('created_at') else None,
  4136. "started_at": task['started_at'].isoformat() if task.get('started_at') else None,
  4137. "completed_at": task['completed_at'].isoformat() if task.get('completed_at') else None,
  4138. "created_by": task.get('by_user'),
  4139. "db_name": task.get('db_name'),
  4140. "business_context": task.get('parameters', {}).get('business_context') if task.get('parameters') else None,
  4141. # 新增字段
  4142. "directory_exists": task.get('directory_exists', True), # 默认为True,兼容旧数据
  4143. "updated_at": task['updated_at'].isoformat() if task.get('updated_at') else None
  4144. })
  4145. response_data = {
  4146. "tasks": formatted_tasks,
  4147. "total": len(formatted_tasks),
  4148. "limit": limit,
  4149. "offset": offset
  4150. }
  4151. return jsonify(success_response(
  4152. response_text="获取任务列表成功",
  4153. data=response_data
  4154. ))
  4155. except Exception as e:
  4156. logger.error(f"获取数据管道任务列表失败: {str(e)}")
  4157. return jsonify(internal_error_response(
  4158. response_text="获取任务列表失败,请稍后重试"
  4159. )), 500
  4160. @app.route('/api/v0/data_pipeline/tasks/query', methods=['POST'])
  4161. def query_data_pipeline_tasks():
  4162. """
  4163. 高级查询数据管道任务列表(从 citu_app.py 迁移)
  4164. 支持复杂筛选、排序、分页功能
  4165. 请求体:
  4166. {
  4167. "page": 1, // 页码,必须大于0,默认1
  4168. "page_size": 20, // 每页大小,1-100之间,默认20
  4169. "status": "completed", // 可选,任务状态筛选:"pending"|"running"|"completed"|"failed"|"cancelled"
  4170. "task_name": "highway", // 可选,任务名称模糊搜索,最大100字符
  4171. "created_by": "user123", // 可选,创建者精确匹配
  4172. "db_name": "highway_db", // 可选,数据库名称精确匹配
  4173. "created_time_start": "2025-01-01T00:00:00", // 可选,创建时间范围开始
  4174. "created_time_end": "2025-12-31T23:59:59", // 可选,创建时间范围结束
  4175. "started_time_start": "2025-01-01T00:00:00", // 可选,开始时间范围开始
  4176. "started_time_end": "2025-12-31T23:59:59", // 可选,开始时间范围结束
  4177. "completed_time_start": "2025-01-01T00:00:00", // 可选,完成时间范围开始
  4178. "completed_time_end": "2025-12-31T23:59:59", // 可选,完成时间范围结束
  4179. "sort_by": "created_at", // 可选,排序字段:"created_at"|"started_at"|"completed_at"|"task_name"|"status",默认"created_at"
  4180. "sort_order": "desc" // 可选,排序方向:"asc"|"desc",默认"desc"
  4181. }
  4182. """
  4183. try:
  4184. # 获取请求数据
  4185. req = request.get_json(force=True) if request.is_json else {}
  4186. # 解析参数,设置默认值
  4187. page = req.get('page', 1)
  4188. page_size = req.get('page_size', 20)
  4189. status = req.get('status')
  4190. task_name = req.get('task_name')
  4191. created_by = req.get('created_by')
  4192. db_name = req.get('db_name')
  4193. created_time_start = req.get('created_time_start')
  4194. created_time_end = req.get('created_time_end')
  4195. started_time_start = req.get('started_time_start')
  4196. started_time_end = req.get('started_time_end')
  4197. completed_time_start = req.get('completed_time_start')
  4198. completed_time_end = req.get('completed_time_end')
  4199. sort_by = req.get('sort_by', 'created_at')
  4200. sort_order = req.get('sort_order', 'desc')
  4201. # 参数验证
  4202. # 验证分页参数
  4203. if page < 1:
  4204. return jsonify(bad_request_response(
  4205. response_text="页码必须大于0",
  4206. invalid_params=['page']
  4207. )), 400
  4208. if page_size < 1 or page_size > 100:
  4209. return jsonify(bad_request_response(
  4210. response_text="每页大小必须在1-100之间",
  4211. invalid_params=['page_size']
  4212. )), 400
  4213. # 验证任务名称长度
  4214. if task_name and len(task_name) > 100:
  4215. return jsonify(bad_request_response(
  4216. response_text="任务名称搜索关键词最大长度为100字符",
  4217. invalid_params=['task_name']
  4218. )), 400
  4219. # 验证排序参数
  4220. allowed_sort_fields = ['created_at', 'started_at', 'completed_at', 'task_name', 'status']
  4221. if sort_by not in allowed_sort_fields:
  4222. return jsonify(bad_request_response(
  4223. response_text=f"不支持的排序字段: {sort_by},支持的字段: {', '.join(allowed_sort_fields)}",
  4224. invalid_params=['sort_by']
  4225. )), 400
  4226. if sort_order.lower() not in ['asc', 'desc']:
  4227. return jsonify(bad_request_response(
  4228. response_text="排序方向必须是 'asc' 或 'desc'",
  4229. invalid_params=['sort_order']
  4230. )), 400
  4231. # 验证状态筛选
  4232. if status:
  4233. allowed_statuses = ['pending', 'running', 'completed', 'failed', 'cancelled']
  4234. if status not in allowed_statuses:
  4235. return jsonify(bad_request_response(
  4236. response_text=f"不支持的状态值: {status},支持的状态: {', '.join(allowed_statuses)}",
  4237. invalid_params=['status']
  4238. )), 400
  4239. # 调用管理器执行查询
  4240. manager = get_data_pipeline_manager()
  4241. result = manager.query_tasks_advanced(
  4242. page=page,
  4243. page_size=page_size,
  4244. status=status,
  4245. task_name=task_name,
  4246. created_by=created_by,
  4247. db_name=db_name,
  4248. created_time_start=created_time_start,
  4249. created_time_end=created_time_end,
  4250. started_time_start=started_time_start,
  4251. started_time_end=started_time_end,
  4252. completed_time_start=completed_time_start,
  4253. completed_time_end=completed_time_end,
  4254. sort_by=sort_by,
  4255. sort_order=sort_order
  4256. )
  4257. # 格式化任务列表
  4258. formatted_tasks = []
  4259. for task in result['tasks']:
  4260. formatted_tasks.append({
  4261. "task_id": task.get('task_id'),
  4262. "task_name": task.get('task_name'),
  4263. "status": task.get('status'),
  4264. "step_status": task.get('step_status'),
  4265. "created_at": task['created_at'].isoformat() if task.get('created_at') else None,
  4266. "started_at": task['started_at'].isoformat() if task.get('started_at') else None,
  4267. "completed_at": task['completed_at'].isoformat() if task.get('completed_at') else None,
  4268. "created_by": task.get('by_user'),
  4269. "db_name": task.get('db_name'),
  4270. "business_context": task.get('parameters', {}).get('business_context') if task.get('parameters') else None,
  4271. "directory_exists": task.get('directory_exists', True),
  4272. "updated_at": task['updated_at'].isoformat() if task.get('updated_at') else None
  4273. })
  4274. # 构建响应数据
  4275. response_data = {
  4276. "tasks": formatted_tasks,
  4277. "pagination": result['pagination'],
  4278. "filters_applied": {
  4279. k: v for k, v in {
  4280. "status": status,
  4281. "task_name": task_name,
  4282. "created_by": created_by,
  4283. "db_name": db_name,
  4284. "created_time_start": created_time_start,
  4285. "created_time_end": created_time_end,
  4286. "started_time_start": started_time_start,
  4287. "started_time_end": started_time_end,
  4288. "completed_time_start": completed_time_start,
  4289. "completed_time_end": completed_time_end
  4290. }.items() if v
  4291. },
  4292. "sort_applied": {
  4293. "sort_by": sort_by,
  4294. "sort_order": sort_order
  4295. },
  4296. "query_time": result.get('query_time', '0.000s')
  4297. }
  4298. return jsonify(success_response(
  4299. response_text="查询任务列表成功",
  4300. data=response_data
  4301. ))
  4302. except Exception as e:
  4303. logger.error(f"查询数据管道任务列表失败: {str(e)}")
  4304. return jsonify(internal_error_response(
  4305. response_text="查询任务列表失败,请稍后重试"
  4306. )), 500
  4307. @app.route('/api/v0/data_pipeline/tasks/<task_id>/files', methods=['GET'])
  4308. def get_data_pipeline_task_files(task_id):
  4309. """获取任务文件列表(从 citu_app.py 迁移)"""
  4310. try:
  4311. file_manager = get_data_pipeline_file_manager()
  4312. # 获取任务文件
  4313. files = file_manager.get_task_files(task_id)
  4314. directory_info = file_manager.get_directory_info(task_id)
  4315. # 格式化文件信息
  4316. formatted_files = []
  4317. for file_info in files:
  4318. formatted_files.append({
  4319. "file_name": file_info['file_name'],
  4320. "file_type": file_info['file_type'],
  4321. "file_size": file_info['file_size'],
  4322. "file_size_formatted": file_info['file_size_formatted'],
  4323. "created_at": file_info['created_at'].isoformat() if file_info.get('created_at') else None,
  4324. "modified_at": file_info['modified_at'].isoformat() if file_info.get('modified_at') else None,
  4325. "is_readable": file_info['is_readable']
  4326. })
  4327. response_data = {
  4328. "task_id": task_id,
  4329. "files": formatted_files,
  4330. "directory_info": directory_info
  4331. }
  4332. return jsonify(success_response(
  4333. response_text="获取任务文件列表成功",
  4334. data=response_data
  4335. ))
  4336. except Exception as e:
  4337. logger.error(f"获取任务文件列表失败: {str(e)}")
  4338. return jsonify(internal_error_response(
  4339. response_text="获取任务文件列表失败,请稍后重试"
  4340. )), 500
  4341. @app.route('/api/v0/data_pipeline/tasks/<task_id>/files/<file_name>', methods=['GET'])
  4342. def download_data_pipeline_task_file(task_id, file_name):
  4343. """下载任务文件(从 citu_app.py 迁移)"""
  4344. try:
  4345. logger.info(f"开始下载文件: task_id={task_id}, file_name={file_name}")
  4346. # 直接构建文件路径,避免依赖数据库
  4347. from pathlib import Path
  4348. import os
  4349. # 获取项目根目录的绝对路径
  4350. project_root = Path(__file__).parent.absolute()
  4351. task_dir = project_root / "data_pipeline" / "training_data" / task_id
  4352. file_path = task_dir / file_name
  4353. logger.info(f"文件路径: {file_path}")
  4354. # 检查文件是否存在
  4355. if not file_path.exists():
  4356. logger.warning(f"文件不存在: {file_path}")
  4357. return jsonify(not_found_response(
  4358. response_text=f"文件不存在: {file_name}"
  4359. )), 404
  4360. # 检查是否为文件(而不是目录)
  4361. if not file_path.is_file():
  4362. logger.warning(f"路径不是文件: {file_path}")
  4363. return jsonify(bad_request_response(
  4364. response_text=f"路径不是有效文件: {file_name}"
  4365. )), 400
  4366. # 安全检查:确保文件在允许的目录内
  4367. try:
  4368. file_path.resolve().relative_to(task_dir.resolve())
  4369. except ValueError:
  4370. logger.warning(f"文件路径不安全: {file_path}")
  4371. return jsonify(bad_request_response(
  4372. response_text="非法的文件路径"
  4373. )), 400
  4374. # 检查文件是否可读
  4375. if not os.access(file_path, os.R_OK):
  4376. logger.warning(f"文件不可读: {file_path}")
  4377. return jsonify(bad_request_response(
  4378. response_text="文件不可读"
  4379. )), 400
  4380. logger.info(f"开始发送文件: {file_path}")
  4381. return send_file(
  4382. file_path,
  4383. as_attachment=True,
  4384. download_name=file_name
  4385. )
  4386. except Exception as e:
  4387. logger.error(f"下载任务文件失败: task_id={task_id}, file_name={file_name}, 错误: {str(e)}", exc_info=True)
  4388. return jsonify(internal_error_response(
  4389. response_text="下载文件失败,请稍后重试"
  4390. )), 500
  4391. @app.route('/api/v0/data_pipeline/tasks/<task_id>/upload-table-list', methods=['POST'])
  4392. def upload_table_list_file(task_id):
  4393. """
  4394. 上传表清单文件(从 citu_app.py 迁移)
  4395. 表单参数:
  4396. - file: 要上传的表清单文件(multipart/form-data)
  4397. 响应:
  4398. {
  4399. "success": true,
  4400. "code": 200,
  4401. "message": "表清单文件上传成功",
  4402. "data": {
  4403. "task_id": "task_20250701_123456",
  4404. "filename": "table_list.txt",
  4405. "file_size": 1024,
  4406. "file_size_formatted": "1.0 KB"
  4407. }
  4408. }
  4409. """
  4410. try:
  4411. # 验证任务是否存在
  4412. manager = get_data_pipeline_manager()
  4413. task_info = manager.get_task_status(task_id)
  4414. if not task_info:
  4415. return jsonify(not_found_response(
  4416. response_text=f"任务不存在: {task_id}"
  4417. )), 404
  4418. # 检查是否有文件上传
  4419. if 'file' not in request.files:
  4420. return jsonify(bad_request_response(
  4421. response_text="请选择要上传的表清单文件",
  4422. missing_params=['file']
  4423. )), 400
  4424. file = request.files['file']
  4425. # 验证文件名
  4426. if file.filename == '':
  4427. return jsonify(bad_request_response(
  4428. response_text="请选择有效的文件"
  4429. )), 400
  4430. try:
  4431. # 使用文件管理器上传文件
  4432. file_manager = get_data_pipeline_file_manager()
  4433. result = file_manager.upload_table_list_file(task_id, file)
  4434. response_data = {
  4435. "task_id": task_id,
  4436. "filename": result["filename"],
  4437. "file_size": result["file_size"],
  4438. "file_size_formatted": result["file_size_formatted"],
  4439. "upload_time": result["upload_time"].isoformat() if result.get("upload_time") else None
  4440. }
  4441. return jsonify(success_response(
  4442. response_text="表清单文件上传成功",
  4443. data=response_data
  4444. )), 200
  4445. except ValueError as e:
  4446. # 文件验证错误(如文件太大、空文件等)
  4447. return jsonify(bad_request_response(
  4448. response_text=str(e)
  4449. )), 400
  4450. except Exception as e:
  4451. logger.error(f"上传表清单文件失败: {str(e)}")
  4452. return jsonify(internal_error_response(
  4453. response_text="文件上传失败,请稍后重试"
  4454. )), 500
  4455. except Exception as e:
  4456. logger.error(f"处理表清单文件上传请求失败: {str(e)}")
  4457. return jsonify(internal_error_response(
  4458. response_text="处理上传请求失败,请稍后重试"
  4459. )), 500
  4460. @app.route('/api/v0/data_pipeline/tasks/<task_id>/table-list-info', methods=['GET'])
  4461. def get_table_list_info(task_id):
  4462. """
  4463. 获取任务的表清单文件信息(从 citu_app.py 迁移)
  4464. 响应:
  4465. {
  4466. "success": true,
  4467. "code": 200,
  4468. "message": "获取表清单文件信息成功",
  4469. "data": {
  4470. "task_id": "task_20250701_123456",
  4471. "has_file": true,
  4472. "filename": "table_list.txt",
  4473. "file_path": "./data_pipeline/training_data/task_20250701_123456/table_list.txt",
  4474. "file_size": 1024,
  4475. "file_size_formatted": "1.0 KB",
  4476. "uploaded_at": "2025-07-01T12:34:56",
  4477. "table_count": 5,
  4478. "table_names": ["table_name_1", "table_name_2", "table_name_3", "table_name_4", "table_name_5"],
  4479. "is_readable": true
  4480. }
  4481. }
  4482. """
  4483. try:
  4484. file_manager = get_data_pipeline_file_manager()
  4485. # 获取表清单文件信息
  4486. table_list_info = file_manager.get_table_list_file_info(task_id)
  4487. response_data = {
  4488. "task_id": task_id,
  4489. "has_file": table_list_info.get("exists", False),
  4490. **table_list_info
  4491. }
  4492. return jsonify(success_response(
  4493. response_text="获取表清单文件信息成功",
  4494. data=response_data
  4495. ))
  4496. except Exception as e:
  4497. logger.error(f"获取表清单文件信息失败: {str(e)}")
  4498. return jsonify(internal_error_response(
  4499. response_text="获取表清单文件信息失败,请稍后重试"
  4500. )), 500
  4501. @app.route('/api/v0/data_pipeline/tasks/<task_id>/table-list', methods=['POST'])
  4502. def create_table_list_from_names(task_id):
  4503. """
  4504. 通过POST方式提交表名列表并创建table_list.txt文件(从 citu_app.py 迁移)
  4505. 请求体:
  4506. {
  4507. "tables": ["table1", "schema.table2", "table3"]
  4508. }
  4509. 或者:
  4510. {
  4511. "tables": "table1,schema.table2,table3"
  4512. }
  4513. 响应:
  4514. {
  4515. "success": true,
  4516. "code": 200,
  4517. "message": "表清单已成功创建",
  4518. "data": {
  4519. "task_id": "task_20250701_123456",
  4520. "filename": "table_list.txt",
  4521. "table_count": 3,
  4522. "file_size": 45,
  4523. "file_size_formatted": "45 B",
  4524. "created_time": "2025-07-01T12:34:56"
  4525. }
  4526. }
  4527. """
  4528. try:
  4529. # 验证任务是否存在
  4530. manager = get_data_pipeline_manager()
  4531. task_info = manager.get_task_status(task_id)
  4532. if not task_info:
  4533. return jsonify(not_found_response(
  4534. response_text=f"任务不存在: {task_id}"
  4535. )), 404
  4536. # 获取请求数据
  4537. req = request.get_json(force=True)
  4538. tables_param = req.get('tables')
  4539. if not tables_param:
  4540. return jsonify(bad_request_response(
  4541. response_text="缺少必需参数:tables",
  4542. missing_params=['tables']
  4543. )), 400
  4544. # 处理不同格式的表名参数
  4545. try:
  4546. if isinstance(tables_param, str):
  4547. # 逗号分隔的字符串格式
  4548. table_names = [name.strip() for name in tables_param.split(',') if name.strip()]
  4549. elif isinstance(tables_param, list):
  4550. # 数组格式
  4551. table_names = [str(name).strip() for name in tables_param if str(name).strip()]
  4552. else:
  4553. return jsonify(bad_request_response(
  4554. response_text="tables参数格式错误,应为字符串(逗号分隔)或数组"
  4555. )), 400
  4556. if not table_names:
  4557. return jsonify(bad_request_response(
  4558. response_text="表名列表不能为空"
  4559. )), 400
  4560. except Exception as e:
  4561. return jsonify(bad_request_response(
  4562. response_text=f"解析tables参数失败: {str(e)}"
  4563. )), 400
  4564. try:
  4565. # 使用文件管理器创建表清单文件
  4566. file_manager = get_data_pipeline_file_manager()
  4567. result = file_manager.create_table_list_from_names(task_id, table_names)
  4568. response_data = {
  4569. "task_id": task_id,
  4570. "filename": result["filename"],
  4571. "table_count": result["table_count"],
  4572. "unique_table_count": result["unique_table_count"],
  4573. "file_size": result["file_size"],
  4574. "file_size_formatted": result["file_size_formatted"],
  4575. "created_time": result["created_time"].isoformat() if result.get("created_time") else None,
  4576. "original_count": len(table_names) if isinstance(table_names, list) else len(tables_param.split(','))
  4577. }
  4578. return jsonify(success_response(
  4579. response_text=f"表清单已成功创建,包含 {result['table_count']} 个表",
  4580. data=response_data
  4581. )), 200
  4582. except ValueError as e:
  4583. # 表名验证错误(如格式错误、数量限制等)
  4584. return jsonify(bad_request_response(
  4585. response_text=str(e)
  4586. )), 400
  4587. except Exception as e:
  4588. logger.error(f"创建表清单文件失败: {str(e)}")
  4589. return jsonify(internal_error_response(
  4590. response_text="创建表清单文件失败,请稍后重试"
  4591. )), 500
  4592. except Exception as e:
  4593. logger.error(f"处理表清单创建请求失败: {str(e)}")
  4594. return jsonify(internal_error_response(
  4595. response_text="处理请求失败,请稍后重试"
  4596. )), 500
  4597. @app.route('/api/v0/data_pipeline/tasks/<task_id>/files', methods=['POST'])
  4598. def upload_file_to_task(task_id):
  4599. """
  4600. 上传文件到指定任务目录(从 citu_app.py 迁移)
  4601. 表单参数:
  4602. - file: 要上传的文件(multipart/form-data)
  4603. - overwrite_mode: 重名处理模式 (backup, replace, skip),默认为backup
  4604. 支持的文件类型:
  4605. - .ddl: DDL文件
  4606. - .md: Markdown文档
  4607. - .txt: 文本文件
  4608. - .json: JSON文件
  4609. - .sql: SQL文件
  4610. - .csv: CSV文件
  4611. 重名处理模式:
  4612. - backup: 备份原文件(默认)
  4613. - replace: 直接覆盖
  4614. - skip: 跳过上传
  4615. """
  4616. try:
  4617. # 验证任务是否存在
  4618. manager = get_data_pipeline_manager()
  4619. task_info = manager.get_task_status(task_id)
  4620. if not task_info:
  4621. return jsonify(not_found_response(
  4622. response_text=f"任务不存在: {task_id}"
  4623. )), 404
  4624. # 检查是否有文件上传
  4625. if 'file' not in request.files:
  4626. return jsonify(bad_request_response(
  4627. response_text="请选择要上传的文件",
  4628. missing_params=['file']
  4629. )), 400
  4630. file = request.files['file']
  4631. # 验证文件名
  4632. if file.filename == '':
  4633. return jsonify(bad_request_response(
  4634. response_text="请选择有效的文件"
  4635. )), 400
  4636. # 获取重名处理模式
  4637. overwrite_mode = request.form.get('overwrite_mode', 'backup')
  4638. # 验证重名处理模式
  4639. valid_modes = ['backup', 'replace', 'skip']
  4640. if overwrite_mode not in valid_modes:
  4641. return jsonify(bad_request_response(
  4642. response_text=f"无效的overwrite_mode参数: {overwrite_mode},支持的值: {valid_modes}",
  4643. invalid_params=['overwrite_mode']
  4644. )), 400
  4645. try:
  4646. # 使用文件管理器上传文件
  4647. file_manager = get_data_pipeline_file_manager()
  4648. result = file_manager.upload_file_to_task(task_id, file, file.filename, overwrite_mode)
  4649. # 检查是否跳过上传
  4650. if result.get('skipped'):
  4651. return jsonify(success_response(
  4652. response_text=result.get('message', '文件已存在,跳过上传'),
  4653. data=result
  4654. )), 200
  4655. return jsonify(success_response(
  4656. response_text="文件上传成功",
  4657. data=result
  4658. )), 200
  4659. except ValueError as e:
  4660. # 文件验证错误(如文件太大、空文件、不支持的类型等)
  4661. return jsonify(bad_request_response(
  4662. response_text=str(e)
  4663. )), 400
  4664. except Exception as e:
  4665. logger.error(f"上传文件失败: {str(e)}")
  4666. return jsonify(internal_error_response(
  4667. response_text="文件上传失败,请稍后重试"
  4668. )), 500
  4669. except Exception as e:
  4670. logger.error(f"处理文件上传请求失败: {str(e)}")
  4671. return jsonify(internal_error_response(
  4672. response_text="处理上传请求失败,请稍后重试"
  4673. )), 500
  4674. # 任务目录删除功能(从 citu_app.py 迁移)
  4675. import shutil
  4676. from pathlib import Path
  4677. import psycopg2
  4678. from app_config import PGVECTOR_CONFIG
  4679. def delete_task_directory_simple(task_id, delete_database_records=False):
  4680. """
  4681. 简单的任务目录删除功能(从 citu_app.py 迁移)
  4682. - 删除 data_pipeline/training_data/{task_id} 目录
  4683. - 更新数据库中的 directory_exists 字段
  4684. - 可选:删除数据库记录
  4685. """
  4686. try:
  4687. # 1. 删除目录
  4688. project_root = Path(__file__).parent.absolute()
  4689. task_dir = project_root / "data_pipeline" / "training_data" / task_id
  4690. deleted_files_count = 0
  4691. deleted_size = 0
  4692. if task_dir.exists():
  4693. # 计算删除前的统计信息
  4694. for file_path in task_dir.rglob('*'):
  4695. if file_path.is_file():
  4696. deleted_files_count += 1
  4697. deleted_size += file_path.stat().st_size
  4698. # 删除目录
  4699. shutil.rmtree(task_dir)
  4700. directory_deleted = True
  4701. operation_message = "目录删除成功"
  4702. else:
  4703. directory_deleted = False
  4704. operation_message = "目录不存在,无需删除"
  4705. # 2. 更新数据库
  4706. database_records_deleted = False
  4707. try:
  4708. conn = psycopg2.connect(**PGVECTOR_CONFIG)
  4709. cur = conn.cursor()
  4710. if delete_database_records:
  4711. # 删除任务步骤记录
  4712. cur.execute("DELETE FROM data_pipeline_task_steps WHERE task_id = %s", (task_id,))
  4713. # 删除任务主记录
  4714. cur.execute("DELETE FROM data_pipeline_tasks WHERE task_id = %s", (task_id,))
  4715. database_records_deleted = True
  4716. else:
  4717. # 只更新目录状态
  4718. cur.execute("""
  4719. UPDATE data_pipeline_tasks
  4720. SET directory_exists = FALSE, updated_at = CURRENT_TIMESTAMP
  4721. WHERE task_id = %s
  4722. """, (task_id,))
  4723. conn.commit()
  4724. cur.close()
  4725. conn.close()
  4726. except Exception as db_error:
  4727. logger.error(f"数据库操作失败: {db_error}")
  4728. # 数据库失败不影响文件删除的结果
  4729. # 3. 格式化文件大小
  4730. def format_size(size_bytes):
  4731. if size_bytes < 1024:
  4732. return f"{size_bytes} B"
  4733. elif size_bytes < 1024**2:
  4734. return f"{size_bytes/1024:.1f} KB"
  4735. elif size_bytes < 1024**3:
  4736. return f"{size_bytes/(1024**2):.1f} MB"
  4737. else:
  4738. return f"{size_bytes/(1024**3):.1f} GB"
  4739. return {
  4740. "success": True,
  4741. "task_id": task_id,
  4742. "directory_deleted": directory_deleted,
  4743. "database_records_deleted": database_records_deleted,
  4744. "deleted_files_count": deleted_files_count,
  4745. "deleted_size": format_size(deleted_size),
  4746. "deleted_at": datetime.now().isoformat(),
  4747. "operation_message": operation_message # 新增:具体的操作消息
  4748. }
  4749. except Exception as e:
  4750. logger.error(f"删除任务目录失败: {task_id}, 错误: {str(e)}")
  4751. return {
  4752. "success": False,
  4753. "task_id": task_id,
  4754. "error": str(e),
  4755. "error_code": "DELETE_FAILED",
  4756. "operation_message": f"删除操作失败: {str(e)}" # 新增:失败消息
  4757. }
  4758. @app.route('/api/v0/data_pipeline/tasks', methods=['DELETE'])
  4759. def delete_tasks():
  4760. """删除任务目录(支持单个和批量)(从 citu_app.py 迁移)"""
  4761. try:
  4762. # 智能获取参数:支持JSON body和URL查询参数两种方式
  4763. def get_request_parameter(param_name, array_param_name=None):
  4764. """从JSON body或URL查询参数中获取参数值"""
  4765. # 1. 优先从JSON body获取
  4766. if request.is_json:
  4767. try:
  4768. json_data = request.get_json()
  4769. if json_data and param_name in json_data:
  4770. return json_data[param_name]
  4771. except:
  4772. pass
  4773. # 2. 从URL查询参数获取
  4774. if param_name in request.args:
  4775. value = request.args.get(param_name)
  4776. # 处理布尔值
  4777. if value.lower() in ('true', '1', 'yes'):
  4778. return True
  4779. elif value.lower() in ('false', '0', 'no'):
  4780. return False
  4781. return value
  4782. # 3. 处理数组参数(如 task_ids[])
  4783. if array_param_name and array_param_name in request.args:
  4784. return request.args.getlist(array_param_name)
  4785. return None
  4786. # 获取参数
  4787. task_ids = get_request_parameter('task_ids', 'task_ids[]')
  4788. confirm = get_request_parameter('confirm')
  4789. if not task_ids:
  4790. return jsonify(bad_request_response(
  4791. response_text="缺少必需参数: task_ids",
  4792. missing_params=['task_ids']
  4793. )), 400
  4794. if not confirm:
  4795. return jsonify(bad_request_response(
  4796. response_text="缺少必需参数: confirm",
  4797. missing_params=['confirm']
  4798. )), 400
  4799. if confirm != True:
  4800. return jsonify(bad_request_response(
  4801. response_text="confirm参数必须为true以确认删除操作"
  4802. )), 400
  4803. if not isinstance(task_ids, list) or len(task_ids) == 0:
  4804. return jsonify(bad_request_response(
  4805. response_text="task_ids必须是非空的任务ID列表"
  4806. )), 400
  4807. # 获取可选参数
  4808. delete_database_records = get_request_parameter('delete_database_records') or False
  4809. continue_on_error = get_request_parameter('continue_on_error')
  4810. if continue_on_error is None:
  4811. continue_on_error = True
  4812. # 执行批量删除操作
  4813. deleted_tasks = []
  4814. failed_tasks = []
  4815. total_size_freed = 0
  4816. for task_id in task_ids:
  4817. result = delete_task_directory_simple(task_id, delete_database_records)
  4818. if result["success"]:
  4819. deleted_tasks.append(result)
  4820. # 累计释放的空间大小(这里简化处理,实际应该解析size字符串)
  4821. else:
  4822. failed_tasks.append({
  4823. "task_id": task_id,
  4824. "error": result["error"],
  4825. "error_code": result.get("error_code", "UNKNOWN")
  4826. })
  4827. if not continue_on_error:
  4828. break
  4829. # 构建响应
  4830. summary = {
  4831. "total_requested": len(task_ids),
  4832. "successfully_deleted": len(deleted_tasks),
  4833. "failed": len(failed_tasks)
  4834. }
  4835. batch_result = {
  4836. "deleted_tasks": deleted_tasks,
  4837. "failed_tasks": failed_tasks,
  4838. "summary": summary,
  4839. "deleted_at": datetime.now().isoformat()
  4840. }
  4841. # 构建智能响应消息
  4842. if len(task_ids) == 1:
  4843. # 单个删除:使用具体的操作消息
  4844. if summary["failed"] == 0:
  4845. # 从deleted_tasks中获取具体的操作消息
  4846. operation_msg = deleted_tasks[0].get('operation_message', '任务处理完成')
  4847. message = operation_msg
  4848. else:
  4849. # 从failed_tasks中获取错误消息
  4850. error_msg = failed_tasks[0].get('error', '删除失败')
  4851. message = f"任务删除失败: {error_msg}"
  4852. else:
  4853. # 批量删除:统计各种操作结果
  4854. directory_deleted_count = sum(1 for task in deleted_tasks if task.get('directory_deleted', False))
  4855. directory_not_exist_count = sum(1 for task in deleted_tasks if not task.get('directory_deleted', False))
  4856. if summary["failed"] == 0:
  4857. # 全部成功
  4858. if directory_deleted_count > 0 and directory_not_exist_count > 0:
  4859. message = f"批量操作完成:{directory_deleted_count}个目录已删除,{directory_not_exist_count}个目录不存在"
  4860. elif directory_deleted_count > 0:
  4861. message = f"批量删除完成:成功删除{directory_deleted_count}个目录"
  4862. elif directory_not_exist_count > 0:
  4863. message = f"批量操作完成:{directory_not_exist_count}个目录不存在,无需删除"
  4864. else:
  4865. message = "批量操作完成"
  4866. elif summary["successfully_deleted"] == 0:
  4867. message = f"批量删除失败:{summary['failed']}个任务处理失败"
  4868. else:
  4869. message = f"批量删除部分完成:成功{summary['successfully_deleted']}个,失败{summary['failed']}个"
  4870. return jsonify(success_response(
  4871. response_text=message,
  4872. data=batch_result
  4873. )), 200
  4874. except Exception as e:
  4875. logger.error(f"删除任务失败: 错误: {str(e)}")
  4876. return jsonify(internal_error_response(
  4877. response_text="删除任务失败,请稍后重试"
  4878. )), 500
  4879. @app.route('/api/v0/data_pipeline/tasks/<task_id>/logs/query', methods=['POST'])
  4880. def query_data_pipeline_task_logs(task_id):
  4881. """
  4882. 高级查询数据管道任务日志(从 citu_app.py 迁移)
  4883. 支持复杂筛选、排序、分页功能
  4884. 请求体:
  4885. {
  4886. "page": 1, // 页码,必须大于0,默认1
  4887. "page_size": 50, // 每页大小,1-500之间,默认50
  4888. "level": "ERROR", // 可选,日志级别筛选:"DEBUG"|"INFO"|"WARNING"|"ERROR"|"CRITICAL"
  4889. "start_time": "2025-01-01 00:00:00", // 可选,开始时间范围 (YYYY-MM-DD HH:MM:SS)
  4890. "end_time": "2025-01-02 23:59:59", // 可选,结束时间范围 (YYYY-MM-DD HH:MM:SS)
  4891. "keyword": "failed", // 可选,关键字搜索(消息内容模糊匹配)
  4892. "logger_name": "DDLGenerator", // 可选,日志记录器名称精确匹配
  4893. "step_name": "ddl_generation", // 可选,执行步骤名称精确匹配
  4894. "sort_by": "timestamp", // 可选,排序字段:"timestamp"|"level"|"logger"|"step"|"line_number",默认"timestamp"
  4895. "sort_order": "desc" // 可选,排序方向:"asc"|"desc",默认"desc"
  4896. }
  4897. """
  4898. try:
  4899. # 验证任务是否存在
  4900. manager = get_data_pipeline_manager()
  4901. task_info = manager.get_task_status(task_id)
  4902. if not task_info:
  4903. return jsonify(not_found_response(
  4904. response_text=f"任务不存在: {task_id}"
  4905. )), 404
  4906. # 解析请求数据
  4907. request_data = request.get_json() or {}
  4908. # 参数验证
  4909. def _is_valid_time_format(time_str):
  4910. """验证时间格式是否有效"""
  4911. if not time_str:
  4912. return True
  4913. # 支持的时间格式
  4914. time_formats = [
  4915. '%Y-%m-%d %H:%M:%S', # 2025-01-01 00:00:00
  4916. '%Y-%m-%d', # 2025-01-01
  4917. '%Y-%m-%dT%H:%M:%S', # 2025-01-01T00:00:00
  4918. '%Y-%m-%dT%H:%M:%S.%f', # 2025-01-01T00:00:00.123456
  4919. ]
  4920. for fmt in time_formats:
  4921. try:
  4922. from datetime import datetime
  4923. datetime.strptime(time_str, fmt)
  4924. return True
  4925. except ValueError:
  4926. continue
  4927. return False
  4928. # 提取和验证参数
  4929. page = request_data.get('page', 1)
  4930. page_size = request_data.get('page_size', 50)
  4931. level = request_data.get('level')
  4932. start_time = request_data.get('start_time')
  4933. end_time = request_data.get('end_time')
  4934. keyword = request_data.get('keyword')
  4935. logger_name = request_data.get('logger_name')
  4936. step_name = request_data.get('step_name')
  4937. sort_by = request_data.get('sort_by', 'timestamp')
  4938. sort_order = request_data.get('sort_order', 'desc')
  4939. # 参数验证
  4940. if not isinstance(page, int) or page < 1:
  4941. return jsonify(bad_request_response(
  4942. response_text="页码必须是大于0的整数"
  4943. )), 400
  4944. if not isinstance(page_size, int) or page_size < 1 or page_size > 500:
  4945. return jsonify(bad_request_response(
  4946. response_text="每页大小必须是1-500之间的整数"
  4947. )), 400
  4948. # 验证日志级别
  4949. if level and level.upper() not in ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']:
  4950. return jsonify(bad_request_response(
  4951. response_text="日志级别必须是DEBUG、INFO、WARNING、ERROR、CRITICAL之一"
  4952. )), 400
  4953. # 验证时间格式
  4954. if not _is_valid_time_format(start_time):
  4955. return jsonify(bad_request_response(
  4956. response_text="开始时间格式无效,支持格式:YYYY-MM-DD HH:MM:SS 或 YYYY-MM-DD"
  4957. )), 400
  4958. if not _is_valid_time_format(end_time):
  4959. return jsonify(bad_request_response(
  4960. response_text="结束时间格式无效,支持格式:YYYY-MM-DD HH:MM:SS 或 YYYY-MM-DD"
  4961. )), 400
  4962. # 验证关键字长度
  4963. if keyword and len(keyword) > 200:
  4964. return jsonify(bad_request_response(
  4965. response_text="关键字长度不能超过200个字符"
  4966. )), 400
  4967. # 验证排序字段
  4968. allowed_sort_fields = ['timestamp', 'level', 'logger', 'step', 'line_number']
  4969. if sort_by not in allowed_sort_fields:
  4970. return jsonify(bad_request_response(
  4971. response_text=f"排序字段必须是以下之一: {', '.join(allowed_sort_fields)}"
  4972. )), 400
  4973. # 验证排序方向
  4974. if sort_order.lower() not in ['asc', 'desc']:
  4975. return jsonify(bad_request_response(
  4976. response_text="排序方向必须是asc或desc"
  4977. )), 400
  4978. # 创建工作流执行器并查询日志
  4979. from data_pipeline.api.simple_workflow import SimpleWorkflowExecutor
  4980. executor = SimpleWorkflowExecutor(task_id)
  4981. try:
  4982. result = executor.query_logs_advanced(
  4983. page=page,
  4984. page_size=page_size,
  4985. level=level,
  4986. start_time=start_time,
  4987. end_time=end_time,
  4988. keyword=keyword,
  4989. logger_name=logger_name,
  4990. step_name=step_name,
  4991. sort_by=sort_by,
  4992. sort_order=sort_order
  4993. )
  4994. return jsonify(success_response(
  4995. response_text="查询任务日志成功",
  4996. data=result
  4997. ))
  4998. finally:
  4999. executor.cleanup()
  5000. except Exception as e:
  5001. logger.error(f"查询数据管道任务日志失败: {str(e)}")
  5002. return jsonify(internal_error_response(
  5003. response_text="查询任务日志失败,请稍后重试"
  5004. )), 500
  5005. # ==================== 启动逻辑 ====================
  5006. def signal_handler(signum, frame):
  5007. """信号处理器,优雅退出"""
  5008. logger.info(f"接收到信号 {signum},准备退出...")
  5009. cleanup_resources()
  5010. sys.exit(0)
  5011. @app.route('/api/v0/data_pipeline/vector/backup', methods=['POST'])
  5012. def backup_pgvector_tables():
  5013. """专用的pgvector表备份API - 纯备份功能,不执行truncate操作
  5014. 注意:truncate功能已移除,如需清空表请使用 /api/v0/data_pipeline/vector/restore 的 truncate_before_restore 参数
  5015. """
  5016. try:
  5017. # 支持空参数调用 {}
  5018. req = request.get_json(force=True) if request.is_json else {}
  5019. # 解析参数(全部可选)
  5020. task_id = req.get('task_id')
  5021. db_connection = req.get('db_connection')
  5022. # truncate_vector_tables = req.get('truncate_vector_tables', False) # 已注释:备份API不应执行truncate操作
  5023. backup_vector_tables = req.get('backup_vector_tables', True)
  5024. # 参数验证
  5025. if task_id and not re.match(r'^[a-zA-Z0-9_]+$', task_id):
  5026. return jsonify(bad_request_response(
  5027. "无效的task_id格式,只能包含字母、数字和下划线"
  5028. )), 400
  5029. # 确定备份目录
  5030. if task_id:
  5031. # 验证task_id目录是否存在
  5032. task_dir = Path(f"data_pipeline/training_data/{task_id}")
  5033. if not task_dir.exists():
  5034. return jsonify(not_found_response(
  5035. f"指定的任务目录不存在: {task_id}"
  5036. )), 404
  5037. backup_base_dir = str(task_dir)
  5038. else:
  5039. # 使用training_data根目录(支持空参数调用)
  5040. backup_base_dir = "data_pipeline/training_data"
  5041. # 直接使用现有的VectorTableManager
  5042. from data_pipeline.trainer.vector_table_manager import VectorTableManager
  5043. # 临时修改数据库连接配置(如果提供了自定义连接)
  5044. original_config = None
  5045. if db_connection:
  5046. from data_pipeline.config import SCHEMA_TOOLS_CONFIG
  5047. original_config = SCHEMA_TOOLS_CONFIG.get("default_db_connection")
  5048. SCHEMA_TOOLS_CONFIG["default_db_connection"] = db_connection
  5049. try:
  5050. # 使用现有的成熟管理器
  5051. vector_manager = VectorTableManager(
  5052. task_output_dir=backup_base_dir,
  5053. task_id=task_id or "vector_bak"
  5054. )
  5055. # 执行备份(纯备份操作,不执行truncate)
  5056. result = vector_manager.execute_vector_management(
  5057. backup=backup_vector_tables,
  5058. truncate=False # 强制设为False,备份API不执行truncate操作
  5059. )
  5060. # 使用 common/result.py 的标准格式
  5061. return jsonify(success_response(
  5062. response_text="Vector表备份完成",
  5063. data=result
  5064. )), 200
  5065. finally:
  5066. # 恢复原始配置
  5067. if original_config is not None:
  5068. SCHEMA_TOOLS_CONFIG["default_db_connection"] = original_config
  5069. except Exception as e:
  5070. logger.error(f"Vector表备份失败: {str(e)}")
  5071. return jsonify(internal_error_response(
  5072. "Vector表备份失败,请稍后重试"
  5073. )), 500
  5074. # ====================================================================
  5075. # Vector表恢复备份API
  5076. # ====================================================================
  5077. @app.route('/api/v0/data_pipeline/vector/restore/list', methods=['GET'])
  5078. def list_vector_backups():
  5079. """列出可用的vector表备份文件"""
  5080. try:
  5081. # 解析查询参数
  5082. global_only = request.args.get('global_only', 'false').lower() == 'true'
  5083. task_id = request.args.get('task_id')
  5084. # 参数验证
  5085. if task_id and not re.match(r'^[a-zA-Z0-9_]+$', task_id):
  5086. return jsonify(bad_request_response(
  5087. "无效的task_id格式,只能包含字母、数字和下划线"
  5088. )), 400
  5089. # 使用VectorRestoreManager扫描
  5090. from data_pipeline.api.vector_restore_manager import VectorRestoreManager
  5091. restore_manager = VectorRestoreManager()
  5092. result = restore_manager.scan_backup_files(global_only, task_id)
  5093. # 构建响应文本
  5094. total_locations = result['summary']['total_locations']
  5095. total_backup_sets = result['summary']['total_backup_sets']
  5096. if total_backup_sets == 0:
  5097. response_text = "未找到任何可用的备份文件"
  5098. else:
  5099. response_text = f"成功扫描到 {total_locations} 个备份位置,共 {total_backup_sets} 个备份集"
  5100. # 返回标准格式
  5101. return jsonify(success_response(
  5102. response_text=response_text,
  5103. data=result
  5104. )), 200
  5105. except Exception as e:
  5106. logger.error(f"扫描备份文件失败: {str(e)}")
  5107. return jsonify(internal_error_response(
  5108. "扫描备份文件失败,请稍后重试"
  5109. )), 500
  5110. @app.route('/api/v0/data_pipeline/vector/restore', methods=['POST'])
  5111. def restore_vector_tables():
  5112. """恢复vector表数据"""
  5113. try:
  5114. # 解析请求参数
  5115. req = request.get_json(force=True) if request.is_json else {}
  5116. # 必需参数验证
  5117. backup_path = req.get('backup_path')
  5118. timestamp = req.get('timestamp')
  5119. if not backup_path or not timestamp:
  5120. missing_params = []
  5121. if not backup_path:
  5122. missing_params.append('backup_path')
  5123. if not timestamp:
  5124. missing_params.append('timestamp')
  5125. return jsonify(bad_request_response(
  5126. f"缺少必需参数: {', '.join(missing_params)}",
  5127. missing_params
  5128. )), 400
  5129. # 可选参数
  5130. tables = req.get('tables')
  5131. db_connection = req.get('db_connection')
  5132. truncate_before_restore = req.get('truncate_before_restore', False)
  5133. # 参数验证
  5134. if tables is not None and not isinstance(tables, list):
  5135. return jsonify(bad_request_response(
  5136. "tables参数必须是数组格式"
  5137. )), 400
  5138. # 验证时间戳格式
  5139. if not re.match(r'^\d{8}_\d{6}$', timestamp):
  5140. return jsonify(bad_request_response(
  5141. "无效的timestamp格式,应为YYYYMMDD_HHMMSS"
  5142. )), 400
  5143. # 执行恢复
  5144. from data_pipeline.api.vector_restore_manager import VectorRestoreManager
  5145. restore_manager = VectorRestoreManager()
  5146. result = restore_manager.restore_from_backup(
  5147. backup_path=backup_path,
  5148. timestamp=timestamp,
  5149. tables=tables,
  5150. db_connection=db_connection,
  5151. truncate_before_restore=truncate_before_restore
  5152. )
  5153. # 构建响应文本
  5154. if result.get("errors"):
  5155. response_text = "Vector表恢复部分完成,部分表恢复失败"
  5156. else:
  5157. response_text = "Vector表恢复完成"
  5158. # 返回结果
  5159. return jsonify(success_response(
  5160. response_text=response_text,
  5161. data=result
  5162. )), 200
  5163. except FileNotFoundError as e:
  5164. return jsonify(not_found_response(str(e))), 404
  5165. except ValueError as e:
  5166. return jsonify(bad_request_response(str(e))), 400
  5167. except Exception as e:
  5168. logger.error(f"Vector表恢复失败: {str(e)}")
  5169. return jsonify(internal_error_response(
  5170. "Vector表恢复失败,请稍后重试"
  5171. )), 500
  5172. if __name__ == '__main__':
  5173. # 注册信号处理器
  5174. signal.signal(signal.SIGINT, signal_handler)
  5175. signal.signal(signal.SIGTERM, signal_handler)
  5176. logger.info("🚀 启动统一API服务...")
  5177. logger.info("📍 服务地址: http://localhost:8084")
  5178. logger.info("🔗 健康检查: http://localhost:8084/health")
  5179. logger.info("📘 React Agent API: http://localhost:8084/api/v0/ask_react_agent")
  5180. logger.info("📘 LangGraph Agent API: http://localhost:8084/api/v0/ask_agent")
  5181. logger.info("💾 Vector备份API: http://localhost:8084/api/v0/data_pipeline/vector/backup")
  5182. logger.info("📥 Vector恢复API: http://localhost:8084/api/v0/data_pipeline/vector/restore")
  5183. logger.info("📋 备份列表API: http://localhost:8084/api/v0/data_pipeline/vector/restore/list")
  5184. # 原生Flask单进程模式启动
  5185. # 如需多进程ASGI模式,请使用:uvicorn asgi_app:asgi_app --workers 4
  5186. logger.info("🚀 使用原生Flask单进程模式启动...")
  5187. logger.info(" 优点:避免WsgiToAsgi并发阻塞问题")
  5188. logger.info(" 多进程模式请使用:uvicorn asgi_app:asgi_app --workers 4")
  5189. # 启动标准Flask应用(支持异步路由)
  5190. app.run(host="0.0.0.0", port=8084, debug=False, threaded=True)