citu_app.py 71 KB


  1. # 给dataops 对话助手返回结果
  2. from vanna.flask import VannaFlaskApp
  3. from core.vanna_llm_factory import create_vanna_instance
  4. from flask import request, jsonify
  5. import pandas as pd
  6. import common.result as result
  7. from datetime import datetime, timedelta
  8. from common.session_aware_cache import WebSessionAwareMemoryCache
  9. from app_config import API_MAX_RETURN_ROWS, ENABLE_RESULT_SUMMARY
  10. import re
  11. import chainlit as cl
  12. import json
  13. from flask import session # 添加session导入
  14. from common.redis_conversation_manager import RedisConversationManager # 添加Redis对话管理器导入
  15. from common.result import ( # 统一导入所有需要的响应函数
  16. bad_request_response, service_unavailable_response,
  17. agent_success_response, agent_error_response,
  18. internal_error_response, success_response,
  19. validation_failed_response
  20. )
  21. from app_config import ( # 添加Redis相关配置导入
  22. USER_MAX_CONVERSATIONS,
  23. CONVERSATION_CONTEXT_COUNT,
  24. DEFAULT_ANONYMOUS_USER
  25. )
  26. # 设置默认的最大返回行数
  27. DEFAULT_MAX_RETURN_ROWS = 200
  28. MAX_RETURN_ROWS = API_MAX_RETURN_ROWS if API_MAX_RETURN_ROWS is not None else DEFAULT_MAX_RETURN_ROWS
  29. vn = create_vanna_instance()
  30. # 创建带时间戳的缓存
  31. timestamped_cache = WebSessionAwareMemoryCache()
  32. # 实例化 VannaFlaskApp,使用自定义缓存
  33. app = VannaFlaskApp(
  34. vn,
  35. cache=timestamped_cache, # 使用带时间戳的缓存
  36. title="辞图智能数据问答平台",
  37. logo = "https://www.citupro.com/img/logo-black-2.png",
  38. subtitle="让 AI 为你写 SQL",
  39. chart=False,
  40. allow_llm_to_see_data=True,
  41. ask_results_correct=True,
  42. followup_questions=True,
  43. debug=True
  44. )
  45. # 创建Redis对话管理器实例
  46. redis_conversation_manager = RedisConversationManager()
  47. # 修改ask接口,支持前端传递session_id
  48. @app.flask_app.route('/api/v0/ask', methods=['POST'])
  49. def ask_full():
  50. req = request.get_json(force=True)
  51. question = req.get("question", None)
  52. browser_session_id = req.get("session_id", None) # 前端传递的会话ID
  53. if not question:
  54. from common.result import bad_request_response
  55. return jsonify(bad_request_response(
  56. response_text="缺少必需参数:question",
  57. missing_params=["question"]
  58. )), 400
  59. # 如果使用WebSessionAwareMemoryCache
  60. if hasattr(app.cache, 'generate_id_with_browser_session') and browser_session_id:
  61. # 这里需要修改vanna的ask方法来支持传递session_id
  62. # 或者预先调用generate_id来建立会话关联
  63. conversation_id = app.cache.generate_id_with_browser_session(
  64. question=question,
  65. browser_session_id=browser_session_id
  66. )
  67. try:
  68. sql, df, _ = vn.ask(
  69. question=question,
  70. print_results=False,
  71. visualize=False,
  72. allow_llm_to_see_data=True
  73. )
  74. # 关键:检查是否有LLM解释性文本(无法生成SQL的情况)
  75. if sql is None and hasattr(vn, 'last_llm_explanation') and vn.last_llm_explanation:
  76. # 在解释性文本末尾添加提示语
  77. explanation_message = vn.last_llm_explanation + "请尝试提问其它问题。"
  78. # 使用标准化错误响应
  79. from common.result import validation_failed_response
  80. return jsonify(validation_failed_response(
  81. response_text=explanation_message
  82. )), 422 # 修改HTTP状态码为422
  83. # 如果sql为None但没有解释性文本,返回通用错误
  84. if sql is None:
  85. from common.result import validation_failed_response
  86. return jsonify(validation_failed_response(
  87. response_text="无法生成SQL查询,请检查问题描述或数据表结构"
  88. )), 422
  89. # 处理返回数据 - 使用新的query_result结构
  90. query_result = {
  91. "rows": [],
  92. "columns": [],
  93. "row_count": 0,
  94. "is_limited": False,
  95. "total_row_count": 0
  96. }
  97. summary = None
  98. if isinstance(df, pd.DataFrame):
  99. query_result["columns"] = list(df.columns)
  100. if not df.empty:
  101. total_rows = len(df)
  102. limited_df = df.head(MAX_RETURN_ROWS)
  103. query_result["rows"] = limited_df.to_dict(orient="records")
  104. query_result["row_count"] = len(limited_df)
  105. query_result["total_row_count"] = total_rows
  106. query_result["is_limited"] = total_rows > MAX_RETURN_ROWS
  107. # 生成数据摘要(可通过配置控制,仅在有数据时生成)
  108. if ENABLE_RESULT_SUMMARY:
  109. try:
  110. summary = vn.generate_summary(question=question, df=df)
  111. print(f"[INFO] 成功生成摘要: {summary}")
  112. except Exception as e:
  113. print(f"[WARNING] 生成摘要失败: {str(e)}")
  114. summary = None
  115. # 构建返回数据
  116. response_data = {
  117. "sql": sql,
  118. "query_result": query_result,
  119. "conversation_id": conversation_id if 'conversation_id' in locals() else None,
  120. "session_id": browser_session_id
  121. }
  122. # 添加摘要(如果启用且生成成功)
  123. if ENABLE_RESULT_SUMMARY and summary is not None:
  124. response_data["summary"] = summary
  125. response_data["response"] = summary # 同时添加response字段
  126. from common.result import success_response
  127. return jsonify(success_response(
  128. response_text="查询执行完成" if summary is None else None,
  129. data=response_data
  130. ))
  131. except Exception as e:
  132. print(f"[ERROR] ask_full执行失败: {str(e)}")
  133. # 即使发生异常,也检查是否有业务层面的解释
  134. if hasattr(vn, 'last_llm_explanation') and vn.last_llm_explanation:
  135. # 在解释性文本末尾添加提示语
  136. explanation_message = vn.last_llm_explanation + "请尝试提问其它问题。"
  137. from common.result import validation_failed_response
  138. return jsonify(validation_failed_response(
  139. response_text=explanation_message
  140. )), 422
  141. else:
  142. # 技术错误,使用500错误码
  143. from common.result import internal_error_response
  144. return jsonify(internal_error_response(
  145. response_text="查询处理失败,请稍后重试"
  146. )), 500
  147. @app.flask_app.route('/api/v0/citu_run_sql', methods=['POST'])
  148. def citu_run_sql():
  149. req = request.get_json(force=True)
  150. sql = req.get('sql')
  151. if not sql:
  152. from common.result import bad_request_response
  153. return jsonify(bad_request_response(
  154. response_text="缺少必需参数:sql",
  155. missing_params=["sql"]
  156. )), 400
  157. try:
  158. df = vn.run_sql(sql)
  159. # 处理返回数据 - 使用新的query_result结构
  160. query_result = {
  161. "rows": [],
  162. "columns": [],
  163. "row_count": 0,
  164. "is_limited": False,
  165. "total_row_count": 0
  166. }
  167. if isinstance(df, pd.DataFrame):
  168. query_result["columns"] = list(df.columns)
  169. if not df.empty:
  170. total_rows = len(df)
  171. limited_df = df.head(MAX_RETURN_ROWS)
  172. query_result["rows"] = limited_df.to_dict(orient="records")
  173. query_result["row_count"] = len(limited_df)
  174. query_result["total_row_count"] = total_rows
  175. query_result["is_limited"] = total_rows > MAX_RETURN_ROWS
  176. from common.result import success_response
  177. return jsonify(success_response(
  178. response_text=f"SQL执行完成,共返回 {query_result['total_row_count']} 条记录" +
  179. (f",已限制显示前 {MAX_RETURN_ROWS} 条" if query_result["is_limited"] else ""),
  180. data={
  181. "sql": sql,
  182. "query_result": query_result
  183. }
  184. ))
  185. except Exception as e:
  186. print(f"[ERROR] citu_run_sql执行失败: {str(e)}")
  187. from common.result import internal_error_response
  188. return jsonify(internal_error_response(
  189. response_text=f"SQL执行失败,请检查SQL语句是否正确"
  190. )), 500
  191. @app.flask_app.route('/api/v0/ask_cached', methods=['POST'])
  192. def ask_cached():
  193. """
  194. 带缓存功能的智能查询接口
  195. 支持会话管理和结果缓存,提高查询效率
  196. """
  197. req = request.get_json(force=True)
  198. question = req.get("question", None)
  199. browser_session_id = req.get("session_id", None)
  200. if not question:
  201. from common.result import bad_request_response
  202. return jsonify(bad_request_response(
  203. response_text="缺少必需参数:question",
  204. missing_params=["question"]
  205. )), 400
  206. try:
  207. # 生成conversation_id
  208. # 调试:查看generate_id的实际行为
  209. print(f"[DEBUG] 输入问题: '{question}'")
  210. conversation_id = app.cache.generate_id(question=question)
  211. print(f"[DEBUG] 生成的conversation_id: {conversation_id}")
  212. # 再次用相同问题测试
  213. conversation_id2 = app.cache.generate_id(question=question)
  214. print(f"[DEBUG] 再次生成的conversation_id: {conversation_id2}")
  215. print(f"[DEBUG] 两次ID是否相同: {conversation_id == conversation_id2}")
  216. # 检查缓存
  217. cached_sql = app.cache.get(id=conversation_id, field="sql")
  218. if cached_sql is not None:
  219. # 缓存命中
  220. print(f"[CACHE HIT] 使用缓存结果: {conversation_id}")
  221. sql = cached_sql
  222. df = app.cache.get(id=conversation_id, field="df")
  223. summary = app.cache.get(id=conversation_id, field="summary")
  224. else:
  225. # 缓存未命中,执行新查询
  226. print(f"[CACHE MISS] 执行新查询: {conversation_id}")
  227. sql, df, _ = vn.ask(
  228. question=question,
  229. print_results=False,
  230. visualize=False,
  231. allow_llm_to_see_data=True
  232. )
  233. # 检查是否有LLM解释性文本(无法生成SQL的情况)
  234. if sql is None and hasattr(vn, 'last_llm_explanation') and vn.last_llm_explanation:
  235. # 在解释性文本末尾添加提示语
  236. explanation_message = vn.last_llm_explanation + "请尝试用其它方式提问。"
  237. from common.result import validation_failed_response
  238. return jsonify(validation_failed_response(
  239. response_text=explanation_message
  240. )), 422
  241. # 如果sql为None但没有解释性文本,返回通用错误
  242. if sql is None:
  243. from common.result import validation_failed_response
  244. return jsonify(validation_failed_response(
  245. response_text="无法生成SQL查询,请检查问题描述或数据表结构"
  246. )), 422
  247. # 缓存结果
  248. app.cache.set(id=conversation_id, field="question", value=question)
  249. app.cache.set(id=conversation_id, field="sql", value=sql)
  250. app.cache.set(id=conversation_id, field="df", value=df)
  251. # 生成并缓存摘要(可通过配置控制,仅在有数据时生成)
  252. summary = None
  253. if ENABLE_RESULT_SUMMARY and isinstance(df, pd.DataFrame) and not df.empty:
  254. try:
  255. summary = vn.generate_summary(question=question, df=df)
  256. print(f"[INFO] 成功生成摘要: {summary}")
  257. except Exception as e:
  258. print(f"[WARNING] 生成摘要失败: {str(e)}")
  259. summary = None
  260. app.cache.set(id=conversation_id, field="summary", value=summary)
  261. # 处理返回数据 - 使用新的query_result结构
  262. query_result = {
  263. "rows": [],
  264. "columns": [],
  265. "row_count": 0,
  266. "is_limited": False,
  267. "total_row_count": 0
  268. }
  269. if isinstance(df, pd.DataFrame):
  270. query_result["columns"] = list(df.columns)
  271. if not df.empty:
  272. total_rows = len(df)
  273. limited_df = df.head(MAX_RETURN_ROWS)
  274. query_result["rows"] = limited_df.to_dict(orient="records")
  275. query_result["row_count"] = len(limited_df)
  276. query_result["total_row_count"] = total_rows
  277. query_result["is_limited"] = total_rows > MAX_RETURN_ROWS
  278. # 构建返回数据
  279. response_data = {
  280. "sql": sql,
  281. "query_result": query_result,
  282. "conversation_id": conversation_id,
  283. "session_id": browser_session_id,
  284. "cached": cached_sql is not None # 标识是否来自缓存
  285. }
  286. # 添加摘要(如果启用且生成成功)
  287. if ENABLE_RESULT_SUMMARY and summary is not None:
  288. response_data["summary"] = summary
  289. response_data["response"] = summary # 同时添加response字段
  290. from common.result import success_response
  291. return jsonify(success_response(
  292. response_text="查询执行完成" if summary is None else None,
  293. data=response_data
  294. ))
  295. except Exception as e:
  296. print(f"[ERROR] ask_cached执行失败: {str(e)}")
  297. from common.result import internal_error_response
  298. return jsonify(internal_error_response(
  299. response_text="查询处理失败,请稍后重试"
  300. )), 500
  301. @app.flask_app.route('/api/v0/citu_train_question_sql', methods=['POST'])
  302. def citu_train_question_sql():
  303. """
  304. 训练问题-SQL对接口
  305. 此API将接收的question/sql pair写入到training库中,用于训练和改进AI模型。
  306. 支持仅传入SQL或同时传入问题和SQL进行训练。
  307. Args:
  308. question (str, optional): 用户问题
  309. sql (str, required): 对应的SQL查询语句
  310. Returns:
  311. JSON: 包含训练ID和成功消息的响应
  312. """
  313. try:
  314. req = request.get_json(force=True)
  315. question = req.get('question')
  316. sql = req.get('sql')
  317. if not sql:
  318. from common.result import bad_request_response
  319. return jsonify(bad_request_response(
  320. response_text="缺少必需参数:sql",
  321. missing_params=["sql"]
  322. )), 400
  323. # 正确的调用方式:同时传递question和sql
  324. if question:
  325. training_id = vn.train(question=question, sql=sql)
  326. print(f"训练成功,训练ID为:{training_id},问题:{question},SQL:{sql}")
  327. else:
  328. training_id = vn.train(sql=sql)
  329. print(f"训练成功,训练ID为:{training_id},SQL:{sql}")
  330. from common.result import success_response
  331. return jsonify(success_response(
  332. response_text="问题-SQL对训练成功",
  333. data={
  334. "training_id": training_id,
  335. "message": "Question-SQL pair trained successfully"
  336. }
  337. ))
  338. except Exception as e:
  339. from common.result import internal_error_response
  340. return jsonify(internal_error_response(
  341. response_text="训练失败,请稍后重试"
  342. )), 500
  343. # ============ LangGraph Agent 集成 ============
  344. # 全局Agent实例(单例模式)
  345. citu_langraph_agent = None
  346. def get_citu_langraph_agent():
  347. """获取LangGraph Agent实例(懒加载)"""
  348. global citu_langraph_agent
  349. if citu_langraph_agent is None:
  350. try:
  351. from agent.citu_agent import CituLangGraphAgent
  352. print("[CITU_APP] 开始创建LangGraph Agent实例...")
  353. citu_langraph_agent = CituLangGraphAgent()
  354. print("[CITU_APP] LangGraph Agent实例创建成功")
  355. except ImportError as e:
  356. print(f"[CRITICAL] Agent模块导入失败: {str(e)}")
  357. print("[CRITICAL] 请检查agent模块是否存在以及依赖是否正确安装")
  358. raise Exception(f"Agent模块导入失败: {str(e)}")
  359. except Exception as e:
  360. print(f"[CRITICAL] LangGraph Agent实例创建失败: {str(e)}")
  361. print(f"[CRITICAL] 错误类型: {type(e).__name__}")
  362. # 提供更有用的错误信息
  363. if "config" in str(e).lower():
  364. print("[CRITICAL] 可能是配置文件问题,请检查配置")
  365. elif "llm" in str(e).lower():
  366. print("[CRITICAL] 可能是LLM连接问题,请检查LLM配置")
  367. elif "tool" in str(e).lower():
  368. print("[CRITICAL] 可能是工具加载问题,请检查工具模块")
  369. raise Exception(f"Agent初始化失败: {str(e)}")
  370. return citu_langraph_agent
  371. @app.flask_app.route('/api/v0/ask_agent', methods=['POST'])
  372. def ask_agent():
  373. """
  374. 支持对话上下文的ask_agent API - 修正版
  375. """
  376. req = request.get_json(force=True)
  377. question = req.get("question", None)
  378. browser_session_id = req.get("session_id", None)
  379. # 新增参数解析
  380. user_id_input = req.get("user_id", None)
  381. conversation_id_input = req.get("conversation_id", None)
  382. continue_conversation = req.get("continue_conversation", False)
  383. if not question:
  384. return jsonify(bad_request_response(
  385. response_text="缺少必需参数:question",
  386. missing_params=["question"]
  387. )), 400
  388. try:
  389. # 1. 获取登录用户ID(修正:在函数中获取session信息)
  390. login_user_id = session.get('user_id') if 'user_id' in session else None
  391. # 2. 智能ID解析(修正:传入登录用户ID)
  392. user_id = redis_conversation_manager.resolve_user_id(
  393. user_id_input, browser_session_id, request.remote_addr, login_user_id
  394. )
  395. conversation_id, conversation_status = redis_conversation_manager.resolve_conversation_id(
  396. user_id, conversation_id_input, continue_conversation
  397. )
  398. # 3. 获取上下文(提前到缓存检查之前)
  399. context = redis_conversation_manager.get_context(conversation_id)
  400. # 4. 检查缓存(新逻辑:放宽使用条件,严控存储条件)
  401. cached_answer = redis_conversation_manager.get_cached_answer(question, context)
  402. if cached_answer:
  403. print(f"[AGENT_API] 使用缓存答案")
  404. # 确定缓存答案的助手回复内容(使用与非缓存相同的优先级逻辑)
  405. cached_response_type = cached_answer.get("type", "UNKNOWN")
  406. if cached_response_type == "DATABASE":
  407. # DATABASE类型:按优先级选择内容
  408. if cached_answer.get("response"):
  409. # 优先级1:错误或解释性回复(如SQL生成失败)
  410. assistant_response = cached_answer.get("response")
  411. elif cached_answer.get("summary"):
  412. # 优先级2:查询成功的摘要
  413. assistant_response = cached_answer.get("summary")
  414. elif cached_answer.get("query_result"):
  415. # 优先级3:构造简单描述
  416. query_result = cached_answer.get("query_result")
  417. row_count = query_result.get("row_count", 0)
  418. assistant_response = f"查询执行完成,共返回 {row_count} 条记录。"
  419. else:
  420. # 异常情况
  421. assistant_response = "数据库查询已处理。"
  422. else:
  423. # CHAT类型:直接使用response
  424. assistant_response = cached_answer.get("response", "")
  425. # 更新对话历史
  426. redis_conversation_manager.save_message(conversation_id, "user", question)
  427. redis_conversation_manager.save_message(
  428. conversation_id, "assistant",
  429. assistant_response,
  430. metadata={"from_cache": True}
  431. )
  432. # 添加对话信息到缓存结果
  433. cached_answer["conversation_id"] = conversation_id
  434. cached_answer["user_id"] = user_id
  435. cached_answer["from_cache"] = True
  436. cached_answer.update(conversation_status)
  437. # 使用agent_success_response返回标准格式
  438. return jsonify(agent_success_response(
  439. response_type=cached_answer.get("type", "UNKNOWN"),
  440. response=cached_answer.get("response", ""), # 修正:使用response而不是response_text
  441. sql=cached_answer.get("sql"),
  442. query_result=cached_answer.get("query_result"),
  443. summary=cached_answer.get("summary"),
  444. session_id=browser_session_id,
  445. execution_path=cached_answer.get("execution_path", []),
  446. classification_info=cached_answer.get("classification_info", {}),
  447. conversation_id=conversation_id,
  448. user_id=user_id,
  449. is_guest_user=(user_id == DEFAULT_ANONYMOUS_USER),
  450. context_used=bool(context),
  451. from_cache=True,
  452. conversation_status=conversation_status["status"],
  453. conversation_message=conversation_status["message"],
  454. requested_conversation_id=conversation_status.get("requested_id")
  455. ))
  456. # 5. 保存用户消息
  457. redis_conversation_manager.save_message(conversation_id, "user", question)
  458. # 6. 构建带上下文的问题
  459. if context:
  460. enhanced_question = f"\n[CONTEXT]\n{context}\n\n[CURRENT]\n{question}"
  461. print(f"[AGENT_API] 使用上下文,长度: {len(context)}字符")
  462. else:
  463. enhanced_question = question
  464. print(f"[AGENT_API] 新对话,无上下文")
  465. # 7. 现有Agent处理逻辑(保持不变)
  466. try:
  467. agent = get_citu_langraph_agent()
  468. except Exception as e:
  469. print(f"[CRITICAL] Agent初始化失败: {str(e)}")
  470. return jsonify(service_unavailable_response(
  471. response_text="AI服务暂时不可用,请稍后重试",
  472. can_retry=True
  473. )), 503
  474. agent_result = agent.process_question(
  475. question=enhanced_question, # 使用增强后的问题
  476. session_id=browser_session_id
  477. )
  478. # 8. 处理Agent结果
  479. if agent_result.get("success", False):
  480. # 修正:直接从agent_result获取字段,因为它就是final_response
  481. response_type = agent_result.get("type", "UNKNOWN")
  482. response_text = agent_result.get("response", "")
  483. sql = agent_result.get("sql")
  484. query_result = agent_result.get("query_result")
  485. summary = agent_result.get("summary")
  486. execution_path = agent_result.get("execution_path", [])
  487. classification_info = agent_result.get("classification_info", {})
  488. # 确定助手回复内容的优先级
  489. if response_type == "DATABASE":
  490. # DATABASE类型:按优先级选择内容
  491. if response_text:
  492. # 优先级1:错误或解释性回复(如SQL生成失败)
  493. assistant_response = response_text
  494. elif summary:
  495. # 优先级2:查询成功的摘要
  496. assistant_response = summary
  497. elif query_result:
  498. # 优先级3:构造简单描述
  499. row_count = query_result.get("row_count", 0)
  500. assistant_response = f"查询执行完成,共返回 {row_count} 条记录。"
  501. else:
  502. # 异常情况
  503. assistant_response = "数据库查询已处理。"
  504. else:
  505. # CHAT类型:直接使用response
  506. assistant_response = response_text
  507. # 保存助手回复
  508. redis_conversation_manager.save_message(
  509. conversation_id, "assistant", assistant_response,
  510. metadata={
  511. "type": response_type,
  512. "sql": sql,
  513. "execution_path": execution_path
  514. }
  515. )
  516. # 缓存成功的答案(新逻辑:只缓存无上下文的问答)
  517. # 直接缓存agent_result,它已经包含所有需要的字段
  518. redis_conversation_manager.cache_answer(question, agent_result, context)
  519. # 使用agent_success_response的正确方式
  520. return jsonify(agent_success_response(
  521. response_type=response_type,
  522. response=response_text, # 修正:使用response而不是response_text
  523. sql=sql,
  524. query_result=query_result,
  525. summary=summary,
  526. session_id=browser_session_id,
  527. execution_path=execution_path,
  528. classification_info=classification_info,
  529. conversation_id=conversation_id,
  530. user_id=user_id,
  531. is_guest_user=(user_id == DEFAULT_ANONYMOUS_USER),
  532. context_used=bool(context),
  533. from_cache=False,
  534. conversation_status=conversation_status["status"],
  535. conversation_message=conversation_status["message"],
  536. requested_conversation_id=conversation_status.get("requested_id")
  537. ))
  538. else:
  539. # 错误处理(修正:确保使用现有的错误响应格式)
  540. error_message = agent_result.get("error", "Agent处理失败")
  541. error_code = agent_result.get("error_code", 500)
  542. return jsonify(agent_error_response(
  543. response_text=error_message,
  544. error_type="agent_processing_failed",
  545. code=error_code,
  546. session_id=browser_session_id,
  547. conversation_id=conversation_id,
  548. user_id=user_id
  549. )), error_code
  550. except Exception as e:
  551. print(f"[ERROR] ask_agent执行失败: {str(e)}")
  552. return jsonify(internal_error_response(
  553. response_text="查询处理失败,请稍后重试"
  554. )), 500
  555. @app.flask_app.route('/api/v0/agent_health', methods=['GET'])
  556. def agent_health():
  557. """
  558. Agent健康检查接口
  559. 响应格式:
  560. {
  561. "success": true/false,
  562. "code": 200/503,
  563. "message": "healthy/degraded/unhealthy",
  564. "data": {
  565. "status": "healthy/degraded/unhealthy",
  566. "test_result": true/false,
  567. "workflow_compiled": true/false,
  568. "tools_count": 4,
  569. "message": "详细信息",
  570. "timestamp": "2024-01-01T12:00:00",
  571. "checks": {
  572. "agent_creation": true/false,
  573. "tools_import": true/false,
  574. "llm_connection": true/false,
  575. "classifier_ready": true/false
  576. }
  577. }
  578. }
  579. """
  580. try:
  581. # 基础健康检查
  582. health_data = {
  583. "status": "unknown",
  584. "test_result": False,
  585. "workflow_compiled": False,
  586. "tools_count": 0,
  587. "message": "",
  588. "timestamp": datetime.now().isoformat(),
  589. "checks": {
  590. "agent_creation": False,
  591. "tools_import": False,
  592. "llm_connection": False,
  593. "classifier_ready": False
  594. }
  595. }
  596. # 检查1: Agent创建
  597. try:
  598. agent = get_citu_langraph_agent()
  599. health_data["checks"]["agent_creation"] = True
  600. health_data["workflow_compiled"] = agent.workflow is not None
  601. health_data["tools_count"] = len(agent.tools) if hasattr(agent, 'tools') else 0
  602. except Exception as e:
  603. health_data["message"] = f"Agent创建失败: {str(e)}"
  604. from common.result import health_error_response
  605. return jsonify(health_error_response(
  606. status="unhealthy",
  607. **health_data
  608. )), 503
  609. # 检查2: 工具导入
  610. try:
  611. from agent.tools import TOOLS
  612. health_data["checks"]["tools_import"] = len(TOOLS) > 0
  613. except Exception as e:
  614. health_data["message"] = f"工具导入失败: {str(e)}"
  615. # 检查3: LLM连接(简单测试)
  616. try:
  617. from agent.utils import get_compatible_llm
  618. llm = get_compatible_llm()
  619. health_data["checks"]["llm_connection"] = llm is not None
  620. except Exception as e:
  621. health_data["message"] = f"LLM连接失败: {str(e)}"
  622. # 检查4: 分类器准备
  623. try:
  624. from agent.classifier import QuestionClassifier
  625. classifier = QuestionClassifier()
  626. health_data["checks"]["classifier_ready"] = True
  627. except Exception as e:
  628. health_data["message"] = f"分类器失败: {str(e)}"
  629. # 检查5: 完整流程测试(可选)
  630. try:
  631. if all(health_data["checks"].values()):
  632. test_result = agent.health_check()
  633. health_data["test_result"] = test_result.get("status") == "healthy"
  634. health_data["status"] = test_result.get("status", "unknown")
  635. health_data["message"] = test_result.get("message", "健康检查完成")
  636. else:
  637. health_data["status"] = "degraded"
  638. health_data["message"] = "部分组件异常"
  639. except Exception as e:
  640. health_data["status"] = "degraded"
  641. health_data["message"] = f"完整测试失败: {str(e)}"
  642. # 根据状态返回相应的HTTP代码 - 使用标准化健康检查响应
  643. from common.result import health_success_response, health_error_response
  644. if health_data["status"] == "healthy":
  645. return jsonify(health_success_response(**health_data))
  646. elif health_data["status"] == "degraded":
  647. return jsonify(health_error_response(status="degraded", **health_data)), 503
  648. else:
  649. return jsonify(health_error_response(status="unhealthy", **health_data)), 503
  650. except Exception as e:
  651. print(f"[ERROR] 健康检查异常: {str(e)}")
  652. from common.result import internal_error_response
  653. return jsonify(internal_error_response(
  654. response_text="健康检查失败,请稍后重试"
  655. )), 500
  656. # ==================== 日常管理API ====================
  657. @app.flask_app.route('/api/v0/cache_overview', methods=['GET'])
  658. def cache_overview():
  659. """日常管理:轻量概览 - 合并原cache_inspect的核心功能"""
  660. try:
  661. cache = app.cache
  662. result_data = {
  663. 'overview_summary': {
  664. 'total_conversations': 0,
  665. 'total_sessions': 0,
  666. 'query_time': datetime.now().isoformat()
  667. },
  668. 'recent_conversations': [], # 最近的对话
  669. 'session_summary': [] # 会话摘要
  670. }
  671. if hasattr(cache, 'cache') and isinstance(cache.cache, dict):
  672. result_data['overview_summary']['total_conversations'] = len(cache.cache)
  673. # 获取会话信息
  674. if hasattr(cache, 'get_all_sessions'):
  675. all_sessions = cache.get_all_sessions()
  676. result_data['overview_summary']['total_sessions'] = len(all_sessions)
  677. # 会话摘要(按最近活动排序)
  678. session_list = []
  679. for session_id, session_data in all_sessions.items():
  680. session_summary = {
  681. 'session_id': session_id,
  682. 'start_time': session_data['start_time'].isoformat(),
  683. 'conversation_count': session_data.get('conversation_count', 0),
  684. 'duration_seconds': session_data.get('session_duration_seconds', 0),
  685. 'last_activity': session_data.get('last_activity', session_data['start_time']).isoformat(),
  686. 'is_active': (datetime.now() - session_data.get('last_activity', session_data['start_time'])).total_seconds() < 1800 # 30分钟内活跃
  687. }
  688. session_list.append(session_summary)
  689. # 按最后活动时间排序
  690. session_list.sort(key=lambda x: x['last_activity'], reverse=True)
  691. result_data['session_summary'] = session_list
  692. # 最近的对话(最多显示10个)
  693. conversation_list = []
  694. for conversation_id, conversation_data in cache.cache.items():
  695. conversation_start_time = cache.conversation_start_times.get(conversation_id)
  696. conversation_info = {
  697. 'conversation_id': conversation_id,
  698. 'conversation_start_time': conversation_start_time.isoformat() if conversation_start_time else None,
  699. 'session_id': cache.conversation_to_session.get(conversation_id),
  700. 'has_question': 'question' in conversation_data,
  701. 'has_sql': 'sql' in conversation_data,
  702. 'has_data': 'df' in conversation_data and conversation_data['df'] is not None,
  703. 'question_preview': conversation_data.get('question', '')[:80] + '...' if len(conversation_data.get('question', '')) > 80 else conversation_data.get('question', ''),
  704. }
  705. # 计算对话持续时间
  706. if conversation_start_time:
  707. duration = datetime.now() - conversation_start_time
  708. conversation_info['conversation_duration_seconds'] = duration.total_seconds()
  709. conversation_list.append(conversation_info)
  710. # 按对话开始时间排序,显示最新的10个
  711. conversation_list.sort(key=lambda x: x['conversation_start_time'] or '', reverse=True)
  712. result_data['recent_conversations'] = conversation_list[:10]
  713. from common.result import success_response
  714. return jsonify(success_response(
  715. response_text="缓存概览查询完成",
  716. data=result_data
  717. ))
  718. except Exception as e:
  719. from common.result import internal_error_response
  720. return jsonify(internal_error_response(
  721. response_text="获取缓存概览失败,请稍后重试"
  722. )), 500
  723. @app.flask_app.route('/api/v0/cache_stats', methods=['GET'])
  724. def cache_stats():
  725. """日常管理:统计信息 - 合并原session_stats和cache_stats功能"""
  726. try:
  727. cache = app.cache
  728. current_time = datetime.now()
  729. stats = {
  730. 'basic_stats': {
  731. 'total_sessions': len(getattr(cache, 'session_info', {})),
  732. 'total_conversations': len(getattr(cache, 'cache', {})),
  733. 'active_sessions': 0, # 最近30分钟有活动
  734. 'average_conversations_per_session': 0
  735. },
  736. 'time_distribution': {
  737. 'sessions': {
  738. 'last_1_hour': 0,
  739. 'last_6_hours': 0,
  740. 'last_24_hours': 0,
  741. 'last_7_days': 0,
  742. 'older': 0
  743. },
  744. 'conversations': {
  745. 'last_1_hour': 0,
  746. 'last_6_hours': 0,
  747. 'last_24_hours': 0,
  748. 'last_7_days': 0,
  749. 'older': 0
  750. }
  751. },
  752. 'session_details': [],
  753. 'time_ranges': {
  754. 'oldest_session': None,
  755. 'newest_session': None,
  756. 'oldest_conversation': None,
  757. 'newest_conversation': None
  758. }
  759. }
  760. # 会话统计
  761. if hasattr(cache, 'session_info'):
  762. session_times = []
  763. total_conversations = 0
  764. for session_id, session_data in cache.session_info.items():
  765. start_time = session_data['start_time']
  766. session_times.append(start_time)
  767. conversation_count = len(session_data.get('conversations', []))
  768. total_conversations += conversation_count
  769. # 检查活跃状态
  770. last_activity = session_data.get('last_activity', session_data['start_time'])
  771. if (current_time - last_activity).total_seconds() < 1800:
  772. stats['basic_stats']['active_sessions'] += 1
  773. # 时间分布统计
  774. age_hours = (current_time - start_time).total_seconds() / 3600
  775. if age_hours <= 1:
  776. stats['time_distribution']['sessions']['last_1_hour'] += 1
  777. elif age_hours <= 6:
  778. stats['time_distribution']['sessions']['last_6_hours'] += 1
  779. elif age_hours <= 24:
  780. stats['time_distribution']['sessions']['last_24_hours'] += 1
  781. elif age_hours <= 168: # 7 days
  782. stats['time_distribution']['sessions']['last_7_days'] += 1
  783. else:
  784. stats['time_distribution']['sessions']['older'] += 1
  785. # 会话详细信息
  786. session_duration = current_time - start_time
  787. stats['session_details'].append({
  788. 'session_id': session_id,
  789. 'start_time': start_time.isoformat(),
  790. 'last_activity': last_activity.isoformat(),
  791. 'conversation_count': conversation_count,
  792. 'duration_seconds': session_duration.total_seconds(),
  793. 'duration_formatted': str(session_duration),
  794. 'is_active': (current_time - last_activity).total_seconds() < 1800,
  795. 'browser_session_id': session_data.get('browser_session_id')
  796. })
  797. # 计算平均值
  798. if len(cache.session_info) > 0:
  799. stats['basic_stats']['average_conversations_per_session'] = total_conversations / len(cache.session_info)
  800. # 时间范围
  801. if session_times:
  802. stats['time_ranges']['oldest_session'] = min(session_times).isoformat()
  803. stats['time_ranges']['newest_session'] = max(session_times).isoformat()
  804. # 对话统计
  805. if hasattr(cache, 'conversation_start_times'):
  806. conversation_times = []
  807. for conv_time in cache.conversation_start_times.values():
  808. conversation_times.append(conv_time)
  809. age_hours = (current_time - conv_time).total_seconds() / 3600
  810. if age_hours <= 1:
  811. stats['time_distribution']['conversations']['last_1_hour'] += 1
  812. elif age_hours <= 6:
  813. stats['time_distribution']['conversations']['last_6_hours'] += 1
  814. elif age_hours <= 24:
  815. stats['time_distribution']['conversations']['last_24_hours'] += 1
  816. elif age_hours <= 168:
  817. stats['time_distribution']['conversations']['last_7_days'] += 1
  818. else:
  819. stats['time_distribution']['conversations']['older'] += 1
  820. if conversation_times:
  821. stats['time_ranges']['oldest_conversation'] = min(conversation_times).isoformat()
  822. stats['time_ranges']['newest_conversation'] = max(conversation_times).isoformat()
  823. # 按最近活动排序会话详情
  824. stats['session_details'].sort(key=lambda x: x['last_activity'], reverse=True)
  825. from common.result import success_response
  826. return jsonify(success_response(
  827. response_text="缓存统计信息查询完成",
  828. data=stats
  829. ))
  830. except Exception as e:
  831. from common.result import internal_error_response
  832. return jsonify(internal_error_response(
  833. response_text="获取缓存统计失败,请稍后重试"
  834. )), 500
  835. # ==================== 高级功能API ====================
  836. @app.flask_app.route('/api/v0/cache_export', methods=['GET'])
  837. def cache_export():
  838. """高级功能:完整导出 - 保持原cache_raw_export的完整功能"""
  839. try:
  840. cache = app.cache
  841. # 验证缓存的实际结构
  842. if not hasattr(cache, 'cache'):
  843. from common.result import internal_error_response
  844. return jsonify(internal_error_response(
  845. response_text="缓存对象结构异常,请联系系统管理员"
  846. )), 500
  847. if not isinstance(cache.cache, dict):
  848. from common.result import internal_error_response
  849. return jsonify(internal_error_response(
  850. response_text="缓存数据类型异常,请联系系统管理员"
  851. )), 500
  852. # 定义JSON序列化辅助函数
  853. def make_json_serializable(obj):
  854. """将对象转换为JSON可序列化的格式"""
  855. if obj is None:
  856. return None
  857. elif isinstance(obj, (str, int, float, bool)):
  858. return obj
  859. elif isinstance(obj, (list, tuple)):
  860. return [make_json_serializable(item) for item in obj]
  861. elif isinstance(obj, dict):
  862. return {str(k): make_json_serializable(v) for k, v in obj.items()}
  863. elif hasattr(obj, 'isoformat'): # datetime objects
  864. return obj.isoformat()
  865. elif hasattr(obj, 'item'): # numpy scalars
  866. return obj.item()
  867. elif hasattr(obj, 'tolist'): # numpy arrays
  868. return obj.tolist()
  869. elif hasattr(obj, '__dict__'): # pandas dtypes and other objects
  870. return str(obj)
  871. else:
  872. return str(obj)
  873. # 获取完整的原始缓存数据
  874. raw_cache = cache.cache
  875. # 获取会话和对话时间信息
  876. conversation_times = getattr(cache, 'conversation_start_times', {})
  877. session_info = getattr(cache, 'session_info', {})
  878. conversation_to_session = getattr(cache, 'conversation_to_session', {})
  879. export_data = {
  880. 'export_metadata': {
  881. 'export_time': datetime.now().isoformat(),
  882. 'total_conversations': len(raw_cache),
  883. 'total_sessions': len(session_info),
  884. 'cache_type': type(cache).__name__,
  885. 'cache_object_info': str(cache),
  886. 'has_session_times': bool(session_info),
  887. 'has_conversation_times': bool(conversation_times)
  888. },
  889. 'session_info': {
  890. session_id: {
  891. 'start_time': session_data['start_time'].isoformat(),
  892. 'last_activity': session_data.get('last_activity', session_data['start_time']).isoformat(),
  893. 'conversations': session_data['conversations'],
  894. 'conversation_count': len(session_data['conversations']),
  895. 'browser_session_id': session_data.get('browser_session_id'),
  896. 'user_info': session_data.get('user_info', {})
  897. }
  898. for session_id, session_data in session_info.items()
  899. },
  900. 'conversation_times': {
  901. conversation_id: start_time.isoformat()
  902. for conversation_id, start_time in conversation_times.items()
  903. },
  904. 'conversation_to_session_mapping': conversation_to_session,
  905. 'conversations': {}
  906. }
  907. # 处理每个对话的完整数据
  908. for conversation_id, conversation_data in raw_cache.items():
  909. # 获取时间信息
  910. conversation_start_time = conversation_times.get(conversation_id)
  911. session_id = conversation_to_session.get(conversation_id)
  912. session_start_time = None
  913. if session_id and session_id in session_info:
  914. session_start_time = session_info[session_id]['start_time']
  915. processed_conversation = {
  916. 'conversation_id': conversation_id,
  917. 'conversation_start_time': conversation_start_time.isoformat() if conversation_start_time else None,
  918. 'session_id': session_id,
  919. 'session_start_time': session_start_time.isoformat() if session_start_time else None,
  920. 'field_count': len(conversation_data),
  921. 'fields': {}
  922. }
  923. # 添加时间计算
  924. if conversation_start_time:
  925. conversation_duration = datetime.now() - conversation_start_time
  926. processed_conversation['conversation_duration_seconds'] = conversation_duration.total_seconds()
  927. processed_conversation['conversation_duration_formatted'] = str(conversation_duration)
  928. if session_start_time:
  929. session_duration = datetime.now() - session_start_time
  930. processed_conversation['session_duration_seconds'] = session_duration.total_seconds()
  931. processed_conversation['session_duration_formatted'] = str(session_duration)
  932. # 处理每个字段,确保JSON序列化安全
  933. for field_name, field_value in conversation_data.items():
  934. field_info = {
  935. 'field_name': field_name,
  936. 'data_type': type(field_value).__name__,
  937. 'is_none': field_value is None
  938. }
  939. try:
  940. if field_value is None:
  941. field_info['value'] = None
  942. elif field_name in ['conversation_start_time', 'session_start_time']:
  943. # 处理时间字段
  944. field_info['content'] = make_json_serializable(field_value)
  945. elif field_name == 'df' and field_value is not None:
  946. # DataFrame的安全处理
  947. if hasattr(field_value, 'to_dict'):
  948. # 安全地处理dtypes
  949. try:
  950. dtypes_dict = {}
  951. for col, dtype in field_value.dtypes.items():
  952. dtypes_dict[col] = str(dtype)
  953. except Exception:
  954. dtypes_dict = {"error": "无法序列化dtypes"}
  955. # 安全地处理内存使用
  956. try:
  957. memory_usage = field_value.memory_usage(deep=True)
  958. memory_dict = {}
  959. for idx, usage in memory_usage.items():
  960. memory_dict[str(idx)] = int(usage) if hasattr(usage, 'item') else int(usage)
  961. except Exception:
  962. memory_dict = {"error": "无法获取内存使用信息"}
  963. field_info.update({
  964. 'dataframe_info': {
  965. 'shape': list(field_value.shape),
  966. 'columns': list(field_value.columns),
  967. 'dtypes': dtypes_dict,
  968. 'index_info': {
  969. 'type': type(field_value.index).__name__,
  970. 'length': len(field_value.index)
  971. }
  972. },
  973. 'data': make_json_serializable(field_value.to_dict('records')),
  974. 'memory_usage': memory_dict
  975. })
  976. else:
  977. field_info['value'] = str(field_value)
  978. field_info['note'] = 'not_standard_dataframe'
  979. elif field_name == 'fig_json':
  980. # 图表JSON数据处理
  981. if isinstance(field_value, str):
  982. try:
  983. import json
  984. parsed_fig = json.loads(field_value)
  985. field_info.update({
  986. 'json_valid': True,
  987. 'json_size_bytes': len(field_value),
  988. 'plotly_structure': {
  989. 'has_data': 'data' in parsed_fig,
  990. 'has_layout': 'layout' in parsed_fig,
  991. 'data_traces_count': len(parsed_fig.get('data', [])),
  992. },
  993. 'raw_json': field_value
  994. })
  995. except json.JSONDecodeError:
  996. field_info.update({
  997. 'json_valid': False,
  998. 'raw_content': str(field_value)
  999. })
  1000. else:
  1001. field_info['value'] = make_json_serializable(field_value)
  1002. elif field_name == 'followup_questions':
  1003. # 后续问题列表
  1004. field_info.update({
  1005. 'content': make_json_serializable(field_value)
  1006. })
  1007. elif field_name in ['question', 'sql', 'summary']:
  1008. # 文本字段
  1009. if isinstance(field_value, str):
  1010. field_info.update({
  1011. 'text_length': len(field_value),
  1012. 'content': field_value
  1013. })
  1014. else:
  1015. field_info['value'] = make_json_serializable(field_value)
  1016. else:
  1017. # 未知字段的安全处理
  1018. field_info['content'] = make_json_serializable(field_value)
  1019. except Exception as e:
  1020. field_info.update({
  1021. 'processing_error': str(e),
  1022. 'fallback_value': str(field_value)[:500] + '...' if len(str(field_value)) > 500 else str(field_value)
  1023. })
  1024. processed_conversation['fields'][field_name] = field_info
  1025. export_data['conversations'][conversation_id] = processed_conversation
  1026. # 添加缓存统计信息
  1027. field_frequency = {}
  1028. data_types_found = set()
  1029. total_dataframes = 0
  1030. total_questions = 0
  1031. for conv_data in export_data['conversations'].values():
  1032. for field_name, field_info in conv_data['fields'].items():
  1033. field_frequency[field_name] = field_frequency.get(field_name, 0) + 1
  1034. data_types_found.add(field_info['data_type'])
  1035. if field_name == 'df' and not field_info['is_none']:
  1036. total_dataframes += 1
  1037. if field_name == 'question' and not field_info['is_none']:
  1038. total_questions += 1
  1039. export_data['cache_statistics'] = {
  1040. 'field_frequency': field_frequency,
  1041. 'data_types_found': list(data_types_found),
  1042. 'total_dataframes': total_dataframes,
  1043. 'total_questions': total_questions,
  1044. 'has_session_timing': 'session_start_time' in field_frequency,
  1045. 'has_conversation_timing': 'conversation_start_time' in field_frequency
  1046. }
  1047. from common.result import success_response
  1048. return jsonify(success_response(
  1049. response_text="缓存数据导出完成",
  1050. data=export_data
  1051. ))
  1052. except Exception as e:
  1053. import traceback
  1054. error_details = {
  1055. 'error_message': str(e),
  1056. 'error_type': type(e).__name__,
  1057. 'traceback': traceback.format_exc()
  1058. }
  1059. from common.result import internal_error_response
  1060. return jsonify(internal_error_response(
  1061. response_text="导出缓存失败,请稍后重试"
  1062. )), 500
  1063. # ==================== 清理功能API ====================
  1064. @app.flask_app.route('/api/v0/cache_preview_cleanup', methods=['POST'])
  1065. def cache_preview_cleanup():
  1066. """清理功能:预览删除操作 - 保持原功能"""
  1067. try:
  1068. req = request.get_json(force=True)
  1069. # 时间条件 - 支持三种方式
  1070. older_than_hours = req.get('older_than_hours')
  1071. older_than_days = req.get('older_than_days')
  1072. before_timestamp = req.get('before_timestamp') # YYYY-MM-DD HH:MM:SS 格式
  1073. cache = app.cache
  1074. # 计算截止时间
  1075. cutoff_time = None
  1076. time_condition = None
  1077. if older_than_hours:
  1078. cutoff_time = datetime.now() - timedelta(hours=older_than_hours)
  1079. time_condition = f"older_than_hours: {older_than_hours}"
  1080. elif older_than_days:
  1081. cutoff_time = datetime.now() - timedelta(days=older_than_days)
  1082. time_condition = f"older_than_days: {older_than_days}"
  1083. elif before_timestamp:
  1084. try:
  1085. # 支持 YYYY-MM-DD HH:MM:SS 格式
  1086. cutoff_time = datetime.strptime(before_timestamp, '%Y-%m-%d %H:%M:%S')
  1087. time_condition = f"before_timestamp: {before_timestamp}"
  1088. except ValueError:
  1089. from common.result import validation_failed_response
  1090. return jsonify(validation_failed_response(
  1091. response_text="before_timestamp格式错误,请使用 YYYY-MM-DD HH:MM:SS 格式"
  1092. )), 422
  1093. else:
  1094. from common.result import bad_request_response
  1095. return jsonify(bad_request_response(
  1096. response_text="必须提供时间条件:older_than_hours, older_than_days 或 before_timestamp (YYYY-MM-DD HH:MM:SS)",
  1097. missing_params=["older_than_hours", "older_than_days", "before_timestamp"]
  1098. )), 400
  1099. preview = {
  1100. 'time_condition': time_condition,
  1101. 'cutoff_time': cutoff_time.isoformat(),
  1102. 'will_be_removed': {
  1103. 'sessions': []
  1104. },
  1105. 'will_be_kept': {
  1106. 'sessions_count': 0,
  1107. 'conversations_count': 0
  1108. },
  1109. 'summary': {
  1110. 'sessions_to_remove': 0,
  1111. 'conversations_to_remove': 0,
  1112. 'sessions_to_keep': 0,
  1113. 'conversations_to_keep': 0
  1114. }
  1115. }
  1116. # 预览按session删除
  1117. sessions_to_remove_count = 0
  1118. conversations_to_remove_count = 0
  1119. for session_id, session_data in cache.session_info.items():
  1120. session_preview = {
  1121. 'session_id': session_id,
  1122. 'start_time': session_data['start_time'].isoformat(),
  1123. 'conversation_count': len(session_data['conversations']),
  1124. 'conversations': []
  1125. }
  1126. # 添加conversation详情
  1127. for conv_id in session_data['conversations']:
  1128. if conv_id in cache.cache:
  1129. conv_data = cache.cache[conv_id]
  1130. session_preview['conversations'].append({
  1131. 'conversation_id': conv_id,
  1132. 'question': conv_data.get('question', '')[:50] + '...' if conv_data.get('question') else '',
  1133. 'start_time': cache.conversation_start_times.get(conv_id, '').isoformat() if cache.conversation_start_times.get(conv_id) else ''
  1134. })
  1135. if session_data['start_time'] < cutoff_time:
  1136. preview['will_be_removed']['sessions'].append(session_preview)
  1137. sessions_to_remove_count += 1
  1138. conversations_to_remove_count += len(session_data['conversations'])
  1139. else:
  1140. preview['will_be_kept']['sessions_count'] += 1
  1141. preview['will_be_kept']['conversations_count'] += len(session_data['conversations'])
  1142. # 更新摘要统计
  1143. preview['summary'] = {
  1144. 'sessions_to_remove': sessions_to_remove_count,
  1145. 'conversations_to_remove': conversations_to_remove_count,
  1146. 'sessions_to_keep': preview['will_be_kept']['sessions_count'],
  1147. 'conversations_to_keep': preview['will_be_kept']['conversations_count']
  1148. }
  1149. from common.result import success_response
  1150. return jsonify(success_response(
  1151. response_text=f"清理预览完成,将删除 {sessions_to_remove_count} 个会话和 {conversations_to_remove_count} 个对话",
  1152. data=preview
  1153. ))
  1154. except Exception as e:
  1155. from common.result import internal_error_response
  1156. return jsonify(internal_error_response(
  1157. response_text="预览清理操作失败,请稍后重试"
  1158. )), 500
  1159. @app.flask_app.route('/api/v0/cache_cleanup', methods=['POST'])
  1160. def cache_cleanup():
  1161. """清理功能:实际删除缓存 - 保持原功能"""
  1162. try:
  1163. req = request.get_json(force=True)
  1164. # 时间条件 - 支持三种方式
  1165. older_than_hours = req.get('older_than_hours')
  1166. older_than_days = req.get('older_than_days')
  1167. before_timestamp = req.get('before_timestamp') # YYYY-MM-DD HH:MM:SS 格式
  1168. cache = app.cache
  1169. if not hasattr(cache, 'session_info'):
  1170. from common.result import service_unavailable_response
  1171. return jsonify(service_unavailable_response(
  1172. response_text="缓存不支持会话功能"
  1173. )), 503
  1174. # 计算截止时间
  1175. cutoff_time = None
  1176. time_condition = None
  1177. if older_than_hours:
  1178. cutoff_time = datetime.now() - timedelta(hours=older_than_hours)
  1179. time_condition = f"older_than_hours: {older_than_hours}"
  1180. elif older_than_days:
  1181. cutoff_time = datetime.now() - timedelta(days=older_than_days)
  1182. time_condition = f"older_than_days: {older_than_days}"
  1183. elif before_timestamp:
  1184. try:
  1185. # 支持 YYYY-MM-DD HH:MM:SS 格式
  1186. cutoff_time = datetime.strptime(before_timestamp, '%Y-%m-%d %H:%M:%S')
  1187. time_condition = f"before_timestamp: {before_timestamp}"
  1188. except ValueError:
  1189. from common.result import validation_failed_response
  1190. return jsonify(validation_failed_response(
  1191. response_text="before_timestamp格式错误,请使用 YYYY-MM-DD HH:MM:SS 格式"
  1192. )), 422
  1193. else:
  1194. from common.result import bad_request_response
  1195. return jsonify(bad_request_response(
  1196. response_text="必须提供时间条件:older_than_hours, older_than_days 或 before_timestamp (YYYY-MM-DD HH:MM:SS)",
  1197. missing_params=["older_than_hours", "older_than_days", "before_timestamp"]
  1198. )), 400
  1199. cleanup_stats = {
  1200. 'time_condition': time_condition,
  1201. 'cutoff_time': cutoff_time.isoformat(),
  1202. 'sessions_removed': 0,
  1203. 'conversations_removed': 0,
  1204. 'sessions_kept': 0,
  1205. 'conversations_kept': 0,
  1206. 'removed_session_ids': [],
  1207. 'removed_conversation_ids': []
  1208. }
  1209. # 按session删除
  1210. sessions_to_remove = []
  1211. for session_id, session_data in cache.session_info.items():
  1212. if session_data['start_time'] < cutoff_time:
  1213. sessions_to_remove.append(session_id)
  1214. # 删除符合条件的sessions及其所有conversations
  1215. for session_id in sessions_to_remove:
  1216. session_data = cache.session_info[session_id]
  1217. conversations_in_session = session_data['conversations'].copy()
  1218. # 删除session中的所有conversations
  1219. for conv_id in conversations_in_session:
  1220. if conv_id in cache.cache:
  1221. del cache.cache[conv_id]
  1222. cleanup_stats['conversations_removed'] += 1
  1223. cleanup_stats['removed_conversation_ids'].append(conv_id)
  1224. # 清理conversation相关的时间记录
  1225. if hasattr(cache, 'conversation_start_times') and conv_id in cache.conversation_start_times:
  1226. del cache.conversation_start_times[conv_id]
  1227. if hasattr(cache, 'conversation_to_session') and conv_id in cache.conversation_to_session:
  1228. del cache.conversation_to_session[conv_id]
  1229. # 删除session记录
  1230. del cache.session_info[session_id]
  1231. cleanup_stats['sessions_removed'] += 1
  1232. cleanup_stats['removed_session_ids'].append(session_id)
  1233. # 统计保留的sessions和conversations
  1234. cleanup_stats['sessions_kept'] = len(cache.session_info)
  1235. cleanup_stats['conversations_kept'] = len(cache.cache)
  1236. from common.result import success_response
  1237. return jsonify(success_response(
  1238. response_text=f"缓存清理完成,删除了 {cleanup_stats['sessions_removed']} 个会话和 {cleanup_stats['conversations_removed']} 个对话",
  1239. data=cleanup_stats
  1240. ))
  1241. except Exception as e:
  1242. from common.result import internal_error_response
  1243. return jsonify(internal_error_response(
  1244. response_text="缓存清理失败,请稍后重试"
  1245. )), 500
  1246. @app.flask_app.route('/api/v0/training_error_question_sql', methods=['POST'])
  1247. def training_error_question_sql():
  1248. """
  1249. 存储错误的question-sql对到error_sql集合中
  1250. 此API将接收的错误question/sql pair写入到error_sql集合中,用于记录和分析错误的SQL查询。
  1251. Args:
  1252. question (str, required): 用户问题
  1253. sql (str, required): 对应的错误SQL查询语句
  1254. Returns:
  1255. JSON: 包含训练ID和成功消息的响应
  1256. """
  1257. try:
  1258. data = request.get_json()
  1259. question = data.get('question')
  1260. sql = data.get('sql')
  1261. print(f"[DEBUG] 接收到错误SQL训练请求: question={question}, sql={sql}")
  1262. if not question or not sql:
  1263. from common.result import bad_request_response
  1264. missing_params = []
  1265. if not question:
  1266. missing_params.append("question")
  1267. if not sql:
  1268. missing_params.append("sql")
  1269. return jsonify(bad_request_response(
  1270. response_text="question和sql参数都是必需的",
  1271. missing_params=missing_params
  1272. )), 400
  1273. # 使用vn实例的train_error_sql方法存储错误SQL
  1274. id = vn.train_error_sql(question=question, sql=sql)
  1275. print(f"[INFO] 成功存储错误SQL,ID: {id}")
  1276. from common.result import success_response
  1277. return jsonify(success_response(
  1278. response_text="错误SQL对已成功存储",
  1279. data={
  1280. "id": id,
  1281. "message": "错误SQL对已成功存储到error_sql集合"
  1282. }
  1283. ))
  1284. except Exception as e:
  1285. print(f"[ERROR] 存储错误SQL失败: {str(e)}")
  1286. from common.result import internal_error_response
  1287. return jsonify(internal_error_response(
  1288. response_text="存储错误SQL失败,请稍后重试"
  1289. )), 500
  1290. # ==================== Redis对话管理API ====================
  1291. @app.flask_app.route('/api/v0/user/<user_id>/conversations', methods=['GET'])
  1292. def get_user_conversations(user_id: str):
  1293. """获取用户的对话列表(按时间倒序)"""
  1294. try:
  1295. limit = request.args.get('limit', USER_MAX_CONVERSATIONS, type=int)
  1296. conversations = redis_conversation_manager.get_conversations(user_id, limit)
  1297. return jsonify(success_response(
  1298. response_text="获取用户对话列表成功",
  1299. data={
  1300. "user_id": user_id,
  1301. "conversations": conversations,
  1302. "total_count": len(conversations)
  1303. }
  1304. ))
  1305. except Exception as e:
  1306. return jsonify(internal_error_response(
  1307. response_text="获取对话列表失败,请稍后重试"
  1308. )), 500
  1309. @app.flask_app.route('/api/v0/conversation/<conversation_id>/messages', methods=['GET'])
  1310. def get_conversation_messages(conversation_id: str):
  1311. """获取特定对话的消息历史"""
  1312. try:
  1313. limit = request.args.get('limit', type=int) # 可选参数
  1314. messages = redis_conversation_manager.get_conversation_messages(conversation_id, limit)
  1315. meta = redis_conversation_manager.get_conversation_meta(conversation_id)
  1316. return jsonify(success_response(
  1317. response_text="获取对话消息成功",
  1318. data={
  1319. "conversation_id": conversation_id,
  1320. "conversation_meta": meta,
  1321. "messages": messages,
  1322. "message_count": len(messages)
  1323. }
  1324. ))
  1325. except Exception as e:
  1326. return jsonify(internal_error_response(
  1327. response_text="获取对话消息失败"
  1328. )), 500
  1329. @app.flask_app.route('/api/v0/conversation/<conversation_id>/context', methods=['GET'])
  1330. def get_conversation_context(conversation_id: str):
  1331. """获取对话上下文(格式化用于LLM)"""
  1332. try:
  1333. count = request.args.get('count', CONVERSATION_CONTEXT_COUNT, type=int)
  1334. context = redis_conversation_manager.get_context(conversation_id, count)
  1335. return jsonify(success_response(
  1336. response_text="获取对话上下文成功",
  1337. data={
  1338. "conversation_id": conversation_id,
  1339. "context": context,
  1340. "context_message_count": count
  1341. }
  1342. ))
  1343. except Exception as e:
  1344. return jsonify(internal_error_response(
  1345. response_text="获取对话上下文失败"
  1346. )), 500
  1347. @app.flask_app.route('/api/v0/conversation_stats', methods=['GET'])
  1348. def conversation_stats():
  1349. """获取对话系统统计信息"""
  1350. try:
  1351. stats = redis_conversation_manager.get_stats()
  1352. return jsonify(success_response(
  1353. response_text="获取统计信息成功",
  1354. data=stats
  1355. ))
  1356. except Exception as e:
  1357. return jsonify(internal_error_response(
  1358. response_text="获取统计信息失败,请稍后重试"
  1359. )), 500
  1360. @app.flask_app.route('/api/v0/conversation_cleanup', methods=['POST'])
  1361. def conversation_cleanup():
  1362. """手动清理过期对话"""
  1363. try:
  1364. redis_conversation_manager.cleanup_expired_conversations()
  1365. return jsonify(success_response(
  1366. response_text="对话清理完成"
  1367. ))
  1368. except Exception as e:
  1369. return jsonify(internal_error_response(
  1370. response_text="对话清理失败,请稍后重试"
  1371. )), 500
  1372. @app.flask_app.route('/api/v0/user/<user_id>/conversations/full', methods=['GET'])
  1373. def get_user_conversations_with_messages(user_id: str):
  1374. """
  1375. 获取用户的完整对话数据(包含所有消息)
  1376. 一次性返回用户的所有对话和每个对话下的消息历史
  1377. Args:
  1378. user_id: 用户ID(路径参数)
  1379. conversation_limit: 对话数量限制(查询参数,可选,不传则返回所有对话)
  1380. message_limit: 每个对话的消息数限制(查询参数,可选,不传则返回所有消息)
  1381. Returns:
  1382. 包含用户所有对话和消息的完整数据
  1383. """
  1384. try:
  1385. # 获取可选参数,不传递时使用None(返回所有记录)
  1386. conversation_limit = request.args.get('conversation_limit', type=int)
  1387. message_limit = request.args.get('message_limit', type=int)
  1388. # 获取用户的对话列表
  1389. conversations = redis_conversation_manager.get_conversations(user_id, conversation_limit)
  1390. # 为每个对话获取消息历史
  1391. full_conversations = []
  1392. total_messages = 0
  1393. for conversation in conversations:
  1394. conversation_id = conversation['conversation_id']
  1395. # 获取对话消息
  1396. messages = redis_conversation_manager.get_conversation_messages(
  1397. conversation_id, message_limit
  1398. )
  1399. # 获取对话元数据
  1400. meta = redis_conversation_manager.get_conversation_meta(conversation_id)
  1401. # 组合完整数据
  1402. full_conversation = {
  1403. **conversation, # 基础对话信息
  1404. 'meta': meta, # 对话元数据
  1405. 'messages': messages, # 消息列表
  1406. 'message_count': len(messages)
  1407. }
  1408. full_conversations.append(full_conversation)
  1409. total_messages += len(messages)
  1410. return jsonify(success_response(
  1411. response_text="获取用户完整对话数据成功",
  1412. data={
  1413. "user_id": user_id,
  1414. "conversations": full_conversations,
  1415. "total_conversations": len(full_conversations),
  1416. "total_messages": total_messages,
  1417. "conversation_limit_applied": conversation_limit,
  1418. "message_limit_applied": message_limit,
  1419. "query_time": datetime.now().isoformat()
  1420. }
  1421. ))
  1422. except Exception as e:
  1423. print(f"[ERROR] 获取用户完整对话数据失败: {str(e)}")
  1424. return jsonify(internal_error_response(
  1425. response_text="获取用户对话数据失败,请稍后重试"
  1426. )), 500
  1427. # 前端JavaScript示例 - 如何维持会话
  1428. """
  1429. // 前端需要维护一个会话ID
  1430. class ChatSession {
  1431. constructor() {
  1432. // 从localStorage获取或创建新的会话ID
  1433. this.sessionId = localStorage.getItem('chat_session_id') || this.generateSessionId();
  1434. localStorage.setItem('chat_session_id', this.sessionId);
  1435. }
  1436. generateSessionId() {
  1437. return 'session_' + Date.now() + '_' + Math.random().toString(36).substr(2, 9);
  1438. }
  1439. async askQuestion(question) {
  1440. const response = await fetch('/api/v0/ask', {
  1441. method: 'POST',
  1442. headers: {
  1443. 'Content-Type': 'application/json',
  1444. },
  1445. body: JSON.stringify({
  1446. question: question,
  1447. session_id: this.sessionId // 关键:传递会话ID
  1448. })
  1449. });
  1450. return await response.json();
  1451. }
  1452. // 开始新会话
  1453. startNewSession() {
  1454. this.sessionId = this.generateSessionId();
  1455. localStorage.setItem('chat_session_id', this.sessionId);
  1456. }
  1457. }
  1458. // 使用示例
  1459. const chatSession = new ChatSession();
  1460. chatSession.askQuestion("各年龄段客户的流失率如何?");
  1461. """
# Start the application server. These statements sit at module top level with
# no `if __name__ == "__main__":` guard, so they also execute on import.
# NOTE(review): host="0.0.0.0" listens on all interfaces while debug=True is
# set; presumably this forwards to Flask's development server, where debug
# mode must not be exposed to untrusted networks - confirm before deploying.
print("正在启动Flask应用: http://localhost:8084")
app.run(host="0.0.0.0", port=8084, debug=True)