citu_app.py 57 KB


  1. # 给dataops 对话助手返回结果
  2. from vanna.flask import VannaFlaskApp
  3. from core.vanna_llm_factory import create_vanna_instance
  4. from flask import request, jsonify
  5. import pandas as pd
  6. import common.result as result
  7. from datetime import datetime, timedelta
  8. from common.session_aware_cache import WebSessionAwareMemoryCache
  9. from app_config import API_MAX_RETURN_ROWS, DISPLAY_SUMMARY_THINKING
  10. import re
  # Row cap for API responses; app_config's API_MAX_RETURN_ROWS wins when it is set.
  DEFAULT_MAX_RETURN_ROWS = 200
  if API_MAX_RETURN_ROWS is not None:
      MAX_RETURN_ROWS = API_MAX_RETURN_ROWS
  else:
      MAX_RETURN_ROWS = DEFAULT_MAX_RETURN_ROWS

  # Single Vanna instance shared by every endpoint in this module.
  vn = create_vanna_instance()

  # Timestamp-aware custom cache, passed to the Flask app below instead of the default one.
  timestamped_cache = WebSessionAwareMemoryCache()

  # Vanna's Flask app wired to the shared instance and the custom cache.
  app = VannaFlaskApp(
      vn,
      cache=timestamped_cache,
      title="辞图智能数据问答平台",
      logo="https://www.citupro.com/img/logo-black-2.png",
      subtitle="让 AI 为你写 SQL",
      chart=False,
      allow_llm_to_see_data=True,
      ask_results_correct=True,
      followup_questions=True,
      debug=True,
  )
  30. def _remove_thinking_content(text: str) -> str:
  31. """
  32. 移除文本中的 <think></think> 标签及其内容
  33. 复用自 base_llm_chat.py 中的同名方法
  34. Args:
  35. text (str): 包含可能的 thinking 标签的文本
  36. Returns:
  37. str: 移除 thinking 内容后的文本
  38. """
  39. if not text:
  40. return text
  41. # 移除 <think>...</think> 标签及其内容(支持多行)
  42. # 使用 re.DOTALL 标志使 . 匹配包括换行符在内的任何字符
  43. cleaned_text = re.sub(r'<think>.*?</think>\s*', '', text, flags=re.DOTALL | re.IGNORECASE)
  44. # 移除可能的多余空行
  45. cleaned_text = re.sub(r'\n\s*\n\s*\n', '\n\n', cleaned_text)
  46. # 去除开头和结尾的空白字符
  47. cleaned_text = cleaned_text.strip()
  48. return cleaned_text
  def _ask_full_explanation_message(log_context: str) -> str:
      """Prepare ``vn.last_llm_explanation`` for an ask_full failure response.

      Strips ``<think>`` sections unless DISPLAY_SUMMARY_THINKING is enabled,
      then appends the "try another question" hint.

      Args:
          log_context: Text inserted into the debug log line (e.g. "" or "异常处理中")
              so both call sites keep their original log wording.
      """
      explanation_message = vn.last_llm_explanation
      if not DISPLAY_SUMMARY_THINKING:
          explanation_message = _remove_thinking_content(explanation_message)
          print(f"[DEBUG] {log_context}隐藏thinking内容 - 原始长度: {len(vn.last_llm_explanation)}, 处理后长度: {len(explanation_message)}")
      return explanation_message + "请尝试提问其它问题。"


  def _ask_full_empty_data(conversation_id, browser_session_id) -> dict:
      """Build the ``data`` payload ask_full returns when there is no SQL result."""
      return {
          "sql": None,
          "rows": [],
          "columns": [],
          "summary": None,
          "conversation_id": conversation_id,
          "session_id": browser_session_id
      }


  # Ask endpoint; accepts an optional session_id from the frontend.
  @app.flask_app.route('/api/v0/ask', methods=['POST'])
  def ask_full():
      """Answer a natural-language question with generated SQL, rows and a summary.

      Request JSON: {"question": str, "session_id": str (optional)}.

      Responses:
          * 400 when ``question`` is missing.
          * 200 with success=false (code 400) when no SQL could be generated; the
            message is the LLM's explanation when one is available.
          * 200 with success=true carrying sql/rows/columns/summary otherwise.
          * 500 on unexpected errors without an LLM explanation.
      """
      req = request.get_json(force=True)
      question = req.get("question", None)
      browser_session_id = req.get("session_id", None)  # session id sent by the frontend
      if not question:
          return jsonify(result.failed(message="未提供问题", code=400)), 400

      # Only assigned when the session-aware cache is active and the client sent a
      # session id; otherwise the response reports None.
      conversation_id = None
      if hasattr(app.cache, 'generate_id_with_browser_session') and browser_session_id:
          # Called up front so the conversation is linked to the browser session.
          conversation_id = app.cache.generate_id_with_browser_session(
              question=question,
              browser_session_id=browser_session_id
          )

      try:
          sql, df, _ = vn.ask(
              question=question,
              print_results=False,
              visualize=False,
              allow_llm_to_see_data=True
          )

          # NOTE(review): last_llm_explanation is an attribute of the module-wide `vn`;
          # whether vn.ask() resets it on each call is not visible here. If it does not,
          # an explanation from an earlier request could surface. TODO confirm.
          if sql is None and getattr(vn, 'last_llm_explanation', None):
              # The LLM explained why it could not write SQL: business failure, HTTP 200.
              return jsonify(result.failed(
                  message=_ask_full_explanation_message(""),
                  code=400,
                  data=_ask_full_empty_data(conversation_id, browser_session_id)
              )), 200

          if sql is None:
              # No SQL and no explanation: generic message.
              return jsonify(result.failed(
                  message="无法生成SQL查询,请检查问题描述或数据表结构",
                  code=400,
                  data=_ask_full_empty_data(conversation_id, browser_session_id)
              )), 200

          # Normal SQL path.
          rows, columns = [], []
          summary = None
          if isinstance(df, pd.DataFrame) and not df.empty:
              rows = df.head(MAX_RETURN_ROWS).to_dict(orient="records")
              columns = list(df.columns)
              try:
                  summary = vn.generate_summary(question=question, df=df)
                  print(f"[INFO] 成功生成摘要: {summary}")
              except Exception as e:
                  # The summary is best-effort; the query result is still returned.
                  print(f"[WARNING] 生成摘要失败: {str(e)}")
                  summary = None

          return jsonify(result.success(data={
              "sql": sql,
              "rows": rows,
              "columns": columns,
              "summary": summary,
              "conversation_id": conversation_id,
              "session_id": browser_session_id
          }))

      except Exception as e:
          print(f"[ERROR] ask_full执行失败: {str(e)}")
          # Even after an exception, prefer a business-level explanation if one exists.
          if getattr(vn, 'last_llm_explanation', None):
              return jsonify(result.failed(
                  message=_ask_full_explanation_message("异常处理中"),
                  code=400,
                  data=_ask_full_empty_data(conversation_id, browser_session_id)
              )), 200
          # Technical error.
          return jsonify(result.failed(
              message=f"查询处理失败: {str(e)}",
              code=500
          )), 500
  @app.flask_app.route('/api/v0/citu_run_sql', methods=['POST'])
  def citu_run_sql():
      """Execute a caller-supplied SQL statement and return up to MAX_RETURN_ROWS rows.

      Request JSON: {"sql": str}.
      Responds 400 when ``sql`` is missing and 500 when execution raises.
      """
      # SECURITY NOTE(review): the statement comes straight from the request body and is
      # executed verbatim by vn.run_sql. Make sure only trusted callers can reach this
      # endpoint and/or that the database account is read-only.
      payload = request.get_json(force=True)
      statement = payload.get('sql')
      if not statement:
          return jsonify(result.failed(message="未提供SQL查询", code=400)), 400

      try:
          frame = vn.run_sql(statement)
          has_rows = isinstance(frame, pd.DataFrame) and not frame.empty
          response_data = {
              "sql": statement,
              "rows": frame.head(MAX_RETURN_ROWS).to_dict(orient="records") if has_rows else [],
              "columns": list(frame.columns) if has_rows else [],
          }
          return jsonify(result.success(data=response_data))
      except Exception as e:
          print(f"[ERROR] citu_run_sql执行失败: {str(e)}")
          return jsonify(result.failed(message=f"SQL执行失败: {str(e)}", code=500)), 500
  @app.flask_app.route('/api/v0/ask_cached', methods=['POST'])
  def ask_cached():
      """Smart-query endpoint with result caching.

      Request JSON: {"question": str, "session_id": str (optional)}.

      On a cache miss the question is answered with ``vn.ask``. The question, sql,
      df and summary are then stored in ``app.cache`` under the conversation id.
      When a cached ``sql`` already exists for that id, the stored values are reused.
      The ``cached`` flag in the response shows which path was taken.
      """
      req = request.get_json(force=True)
      question = req.get("question", None)
      browser_session_id = req.get("session_id", None)
      if not question:
          return jsonify(result.failed(message="未提供问题", code=400)), 400

      try:
          # generate_id is called exactly once. The former debug code called it a second
          # time only to compare ids; if the cache records every generated id, that left
          # an orphan conversation behind on each request.
          # NOTE(review): cache hits require generate_id to return the same id for the
          # same question. Vanna's stock MemoryCache.generate_id returns a fresh uuid4
          # (verify for the installed version); unless WebSessionAwareMemoryCache
          # overrides it, every request here is a cache miss.
          # NOTE(review): browser_session_id is only echoed back; unlike ask_full this
          # endpoint does not associate it with the conversation.
          conversation_id = app.cache.generate_id(question=question)
          print(f"[DEBUG] 输入问题: '{question}', 生成的conversation_id: {conversation_id}")

          cached_sql = app.cache.get(id=conversation_id, field="sql")
          is_cached = cached_sql is not None
          if is_cached:
              print(f"[CACHE HIT] 使用缓存结果: {conversation_id}")
              sql = cached_sql
              df = app.cache.get(id=conversation_id, field="df")
              summary = app.cache.get(id=conversation_id, field="summary")
          else:
              print(f"[CACHE MISS] 执行新查询: {conversation_id}")
              sql, df, _ = vn.ask(
                  question=question,
                  print_results=False,
                  visualize=False,
                  allow_llm_to_see_data=True
              )

              # Payload shared by both "no SQL could be generated" responses.
              no_sql_data = {
                  "sql": None,
                  "rows": [],
                  "columns": [],
                  "summary": None,
                  "conversation_id": conversation_id,
                  "session_id": browser_session_id,
                  "cached": False
              }

              # The LLM explained why it could not write SQL (business failure, HTTP 200).
              if sql is None and getattr(vn, 'last_llm_explanation', None):
                  explanation_message = vn.last_llm_explanation
                  if not DISPLAY_SUMMARY_THINKING:
                      explanation_message = _remove_thinking_content(explanation_message)
                      print(f"[DEBUG] ask_cached中隐藏thinking内容 - 原始长度: {len(vn.last_llm_explanation)}, 处理后长度: {len(explanation_message)}")
                  explanation_message = explanation_message + "请尝试用其它方式提问。"
                  return jsonify(result.failed(
                      message=explanation_message,
                      code=400,
                      data=no_sql_data
                  )), 200

              if sql is None:
                  return jsonify(result.failed(
                      message="无法生成SQL查询,请检查问题描述或数据表结构",
                      code=400,
                      data=no_sql_data
                  )), 200

              # Successful queries only are cached.
              app.cache.set(id=conversation_id, field="question", value=question)
              app.cache.set(id=conversation_id, field="sql", value=sql)
              app.cache.set(id=conversation_id, field="df", value=df)

              summary = None
              if isinstance(df, pd.DataFrame) and not df.empty:
                  try:
                      summary = vn.generate_summary(question=question, df=df)
                      print(f"[INFO] 成功生成摘要: {summary}")
                  except Exception as e:
                      # The summary is best-effort; the result is still returned and cached.
                      print(f"[WARNING] 生成摘要失败: {str(e)}")
                      summary = None
              app.cache.set(id=conversation_id, field="summary", value=summary)

          rows, columns = [], []
          if isinstance(df, pd.DataFrame) and not df.empty:
              rows = df.head(MAX_RETURN_ROWS).to_dict(orient="records")
              columns = list(df.columns)

          return jsonify(result.success(data={
              "sql": sql,
              "rows": rows,
              "columns": columns,
              "summary": summary,
              "conversation_id": conversation_id,
              "session_id": browser_session_id,
              "cached": is_cached  # whether the result came from the cache
          }))

      except Exception as e:
          print(f"[ERROR] ask_cached执行失败: {str(e)}")
          return jsonify(result.failed(
              message=f"查询处理失败: {str(e)}",
              code=500
          )), 500
  @app.flask_app.route('/api/v0/citu_train_question_sql', methods=['POST'])
  def citu_train_question_sql():
      """Train a question/SQL pair into the Vanna training store.

      The received pair is passed to ``vn.train`` so it can be used to improve
      SQL generation. Either SQL alone or a question plus SQL may be supplied.

      Request JSON:
          question (str, optional): The user question.
          sql (str, required): The SQL statement for that question.

      Returns:
          JSON with the training id on success; 400 when ``sql`` is missing;
          500 when training (or request parsing) raises.
      """
      try:
          req = request.get_json(force=True)
          question = req.get('question')
          sql = req.get('sql')

          if not sql:
              return jsonify(result.failed(
                  message="'sql' is required",
                  code=400
              )), 400

          if question:
              training_id = vn.train(question=question, sql=sql)
              print(f"训练成功,训练ID为:{training_id},问题:{question},SQL:{sql}")
          else:
              # Without a question, only the SQL is passed to vn.train.
              training_id = vn.train(sql=sql)
              print(f"训练成功,训练ID为:{training_id},SQL:{sql}")

          return jsonify(result.success(data={
              "training_id": training_id,
              "message": "Question-SQL pair trained successfully"
          }))

      except Exception as e:
          # Log like the other endpoints so failures are visible server-side.
          print(f"[ERROR] citu_train_question_sql执行失败: {str(e)}")
          return jsonify(result.failed(
              message=f"Training failed: {str(e)}",
              code=500
          )), 500
  # ============ LangGraph Agent integration ============
  # Module-wide Agent instance, created lazily on first use (singleton).
  citu_langraph_agent = None


  def get_citu_langraph_agent():
      """Return the shared CituLangGraphAgent, creating it on the first call.

      Raises:
          Exception: if ``agent.citu_agent`` cannot be imported or the agent
              constructor fails. The original error is chained as ``__cause__``
              so its traceback is preserved.
      """
      global citu_langraph_agent
      if citu_langraph_agent is None:
          # NOTE(review): no lock guards this lazy initialization; two concurrent first
          # requests could each construct an agent. Confirm the deployment's threading model.
          try:
              from agent.citu_agent import CituLangGraphAgent
              print("[CITU_APP] 开始创建LangGraph Agent实例...")
              citu_langraph_agent = CituLangGraphAgent()
              print("[CITU_APP] LangGraph Agent实例创建成功")
          except ImportError as e:
              print(f"[CRITICAL] Agent模块导入失败: {str(e)}")
              print("[CRITICAL] 请检查agent模块是否存在以及依赖是否正确安装")
              raise Exception(f"Agent模块导入失败: {str(e)}") from e
          except Exception as e:
              print(f"[CRITICAL] LangGraph Agent实例创建失败: {str(e)}")
              print(f"[CRITICAL] 错误类型: {type(e).__name__}")
              # Hint at the likely failing area based on keywords in the error text.
              if "config" in str(e).lower():
                  print("[CRITICAL] 可能是配置文件问题,请检查配置")
              elif "llm" in str(e).lower():
                  print("[CRITICAL] 可能是LLM连接问题,请检查LLM配置")
              elif "tool" in str(e).lower():
                  print("[CRITICAL] 可能是工具加载问题,请检查工具模块")
              raise Exception(f"Agent初始化失败: {str(e)}") from e
      return citu_langraph_agent
  @app.flask_app.route('/api/v0/ask_agent', methods=['POST'])
  def ask_agent():
      """LangGraph Agent endpoint.

      Request format:
      {
          "question": "user question",
          "session_id": "session id (optional)"
      }

      Response format:
      {
          "success": true/false,
          "code": 200,
          "message": "success" or an error message,
          "data": {
              "response": "final answer",
              "type": "DATABASE/CHAT",
              "sql": "generated SQL (database queries only)",
              "data_result": {
                  "rows": [...],
                  "columns": [...],
                  "row_count": number
              },
              "summary": "data summary (database queries only)",
              "session_id": "session id",
              "execution_path": ["classify", "agent_database", "format_response"],
              "classification_info": {
                  "confidence": 0.95,
                  "reason": "classification reason",
                  "method": "rule_based/llm_based"
              },
              "agent_version": "langgraph_v1"
          }
      }
      """
      payload = request.get_json(force=True)
      question = payload.get("question", None)
      browser_session_id = payload.get("session_id", None)
      if not question:
          return jsonify(result.failed(message="未提供问题", code=400)), 400

      try:
          # Agent construction failures get their own 503 response.
          try:
              agent = get_citu_langraph_agent()
          except Exception as e:
              print(f"[CRITICAL] Agent初始化失败: {str(e)}")
              init_error_data = {
                  "session_id": browser_session_id,
                  "execution_path": ["agent_init_error"],
                  "agent_version": "langgraph_v1",
                  "timestamp": datetime.now().isoformat(),
                  "error_type": "agent_initialization_failed"
              }
              return jsonify(result.failed(
                  message="AI服务暂时不可用,请稍后重试",
                  code=503,
                  data=init_error_data
              )), 503

          agent_result = agent.process_question(
              question=question,
              session_id=browser_session_id
          )
          execution_path = agent_result.get("execution_path", [])
          classification_info = agent_result.get("classification_info", {})

          if not agent_result.get("success", False):
              # Business-level failure: HTTP 200, success=false in the envelope.
              failure_data = {
                  "session_id": browser_session_id,
                  "execution_path": execution_path,
                  "classification_info": classification_info,
                  "agent_version": "langgraph_v1",
                  "timestamp": datetime.now().isoformat()
              }
              return jsonify(result.failed(
                  message=agent_result.get("error", "Agent处理失败"),
                  code=agent_result.get("error_code", 500),
                  data=failure_data
              )), 200

          success_data = {
              "response": agent_result.get("response", ""),
              "type": agent_result.get("type", "UNKNOWN"),
              "sql": agent_result.get("sql"),
              "data_result": agent_result.get("data_result"),
              "summary": agent_result.get("summary"),
              "session_id": browser_session_id,
              "execution_path": execution_path,
              "classification_info": classification_info,
              "agent_version": "langgraph_v1",
              "timestamp": datetime.now().isoformat()
          }
          return jsonify(result.success(data=success_data))

      except Exception as e:
          print(f"[ERROR] ask_agent执行失败: {str(e)}")
          general_error_data = {
              "session_id": browser_session_id,
              "execution_path": ["general_error"],
              "agent_version": "langgraph_v1",
              "timestamp": datetime.now().isoformat(),
              "error_type": "request_processing_failed"
          }
          return jsonify(result.failed(
              message="请求处理异常,请稍后重试",
              code=500,
              data=general_error_data
          )), 500
  @app.flask_app.route('/api/v0/agent_health', methods=['GET'])
  def agent_health():
      """Agent health-check endpoint.

      Runs component checks (agent creation, tool import, LLM client, classifier).
      When all of them pass, it also runs the agent's own ``health_check()``.

      Response format:
      {
          "success": true/false,
          "code": 200/503,
          "message": "healthy/degraded/unhealthy",
          "data": {
              "status": "healthy/degraded/unhealthy",
              "test_result": true/false,
              "workflow_compiled": true/false,
              "tools_count": 4,
              "message": "details",
              "timestamp": "2024-01-01T12:00:00",
              "checks": {
                  "agent_creation": true/false,
                  "tools_import": true/false,
                  "llm_connection": true/false,
                  "classifier_ready": true/false
              }
          }
      }
      HTTP 200 only when status is "healthy"; 503 for degraded/unhealthy; 500
      if the check itself crashes.
      """
      try:
          health_data = {
              "status": "unknown",
              "test_result": False,
              "workflow_compiled": False,
              "tools_count": 0,
              "message": "",
              "timestamp": datetime.now().isoformat(),
              "checks": {
                  "agent_creation": False,
                  "tools_import": False,
                  "llm_connection": False,
                  "classifier_ready": False
              }
          }
          # Error details from checks 2-4. They are joined into the final "degraded"
          # message; previously each was written to health_data["message"] and then
          # always overwritten in check 5, so the diagnostics never reached the caller.
          problems = []

          # Check 1: agent creation (fatal; nothing else can run without it).
          try:
              agent = get_citu_langraph_agent()
              health_data["checks"]["agent_creation"] = True
              # getattr so a missing attribute is not misreported as a creation failure.
              health_data["workflow_compiled"] = getattr(agent, 'workflow', None) is not None
              health_data["tools_count"] = len(agent.tools) if hasattr(agent, 'tools') else 0
          except Exception as e:
              health_data["message"] = f"Agent创建失败: {str(e)}"
              return jsonify(result.failed(
                  message="Agent状态: unhealthy",
                  data=health_data,
                  code=503
              )), 503

          # Check 2: tool import.
          try:
              from agent.tools import TOOLS
              health_data["checks"]["tools_import"] = len(TOOLS) > 0
          except Exception as e:
              problems.append(f"工具导入失败: {str(e)}")

          # Check 3: LLM client (only checks that one can be obtained).
          try:
              from agent.utils import get_compatible_llm
              llm = get_compatible_llm()
              health_data["checks"]["llm_connection"] = llm is not None
          except Exception as e:
              problems.append(f"LLM连接失败: {str(e)}")

          # Check 4: classifier can be constructed.
          try:
              from agent.classifier import QuestionClassifier
              QuestionClassifier()  # constructing it is the check
              health_data["checks"]["classifier_ready"] = True
          except Exception as e:
              problems.append(f"分类器失败: {str(e)}")

          # Check 5: full agent self-test, only when every component check passed.
          try:
              if all(health_data["checks"].values()):
                  test_result = agent.health_check()
                  health_data["test_result"] = test_result.get("status") == "healthy"
                  health_data["status"] = test_result.get("status", "unknown")
                  health_data["message"] = test_result.get("message", "健康检查完成")
              else:
                  failed_checks = [name for name, ok in health_data["checks"].items() if not ok]
                  health_data["status"] = "degraded"
                  health_data["message"] = "部分组件异常: " + "; ".join(problems or failed_checks)
          except Exception as e:
              health_data["status"] = "degraded"
              health_data["message"] = f"完整测试失败: {str(e)}"

          if health_data["status"] == "healthy":
              return jsonify(result.success(data=health_data))
          elif health_data["status"] == "degraded":
              return jsonify(result.failed(
                  message="Agent状态: degraded",
                  data=health_data,
                  code=503
              )), 503
          else:
              return jsonify(result.failed(
                  message="Agent状态: unhealthy",
                  data=health_data,
                  code=503
              )), 503

      except Exception as e:
          print(f"[ERROR] 健康检查异常: {str(e)}")
          return jsonify(result.failed(
              message=f"健康检查失败: {str(e)}",
              code=500,
              data={
                  "status": "error",
                  "timestamp": datetime.now().isoformat()
              }
          )), 500
  # ==================== Routine management APIs ====================
  @app.flask_app.route('/api/v0/cache_overview', methods=['GET'])
  def cache_overview():
      """Routine management: lightweight cache overview (core of the former cache_inspect).

      Returns conversation/session totals, per-session summaries sorted by most
      recent activity, and the 10 most recently started conversations.
      """
      try:
          cache = app.cache
          result_data = {
              'overview_summary': {
                  'total_conversations': 0,
                  'total_sessions': 0,
                  'query_time': datetime.now().isoformat()
              },
              'recent_conversations': [],  # most recent conversations
              'session_summary': []        # per-session summary
          }
          now = datetime.now()

          if hasattr(cache, 'cache') and isinstance(cache.cache, dict):
              result_data['overview_summary']['total_conversations'] = len(cache.cache)

              if hasattr(cache, 'get_all_sessions'):
                  all_sessions = cache.get_all_sessions()
                  result_data['overview_summary']['total_sessions'] = len(all_sessions)

                  session_list = []
                  for session_id, session_data in all_sessions.items():
                      start_time = session_data['start_time']
                      last_activity = session_data.get('last_activity', start_time)
                      session_list.append({
                          'session_id': session_id,
                          'start_time': start_time.isoformat(),
                          'conversation_count': session_data.get('conversation_count', 0),
                          'duration_seconds': session_data.get('session_duration_seconds', 0),
                          'last_activity': last_activity.isoformat(),
                          'is_active': (now - last_activity).total_seconds() < 1800  # active within 30 minutes
                      })
                  # Most recent activity first.
                  session_list.sort(key=lambda x: x['last_activity'], reverse=True)
                  result_data['session_summary'] = session_list

              # getattr defaults (as cache_export uses) so a cache without the
              # session-aware bookkeeping does not raise AttributeError and turn the
              # whole overview into a 500.
              start_times = getattr(cache, 'conversation_start_times', {})
              conversation_to_session = getattr(cache, 'conversation_to_session', {})

              conversation_list = []
              for conversation_id, conversation_data in cache.cache.items():
                  conversation_start_time = start_times.get(conversation_id)
                  # `or ''` also covers a question stored as None.
                  question = conversation_data.get('question') or ''
                  conversation_info = {
                      'conversation_id': conversation_id,
                      'conversation_start_time': conversation_start_time.isoformat() if conversation_start_time else None,
                      'session_id': conversation_to_session.get(conversation_id),
                      'has_question': 'question' in conversation_data,
                      'has_sql': 'sql' in conversation_data,
                      'has_data': conversation_data.get('df') is not None,
                      'question_preview': question[:80] + '...' if len(question) > 80 else question,
                  }
                  if conversation_start_time:
                      conversation_info['conversation_duration_seconds'] = (now - conversation_start_time).total_seconds()
                  conversation_list.append(conversation_info)

              # Newest conversations first; show at most 10.
              conversation_list.sort(key=lambda x: x['conversation_start_time'] or '', reverse=True)
              result_data['recent_conversations'] = conversation_list[:10]

          return jsonify(result.success(data=result_data))

      except Exception as e:
          return jsonify(result.failed(
              message=f"获取缓存概览失败: {str(e)}",
              code=500
          )), 500
  @app.flask_app.route('/api/v0/cache_stats', methods=['GET'])
  def cache_stats():
      """Routine management: cache statistics (merges the former session_stats and cache_stats).

      Reports totals, active sessions (activity within the last 30 minutes),
      age distributions of sessions and conversations, per-session details
      (most recent activity first), and the oldest/newest timestamps.
      """
      bucket_names = ('last_1_hour', 'last_6_hours', 'last_24_hours', 'last_7_days', 'older')
      # Upper age limit in hours for every bucket except 'older'.
      bucket_limits = ((1, 'last_1_hour'), (6, 'last_6_hours'), (24, 'last_24_hours'), (168, 'last_7_days'))

      def age_bucket(moment):
          """Return the time_distribution bucket for a timestamp, relative to current_time."""
          age_hours = (current_time - moment).total_seconds() / 3600
          for limit, name in bucket_limits:
              if age_hours <= limit:
                  return name
          return 'older'

      try:
          cache = app.cache
          current_time = datetime.now()
          stats = {
              'basic_stats': {
                  'total_sessions': len(getattr(cache, 'session_info', {})),
                  'total_conversations': len(getattr(cache, 'cache', {})),
                  'active_sessions': 0,  # activity within the last 30 minutes
                  'average_conversations_per_session': 0
              },
              'time_distribution': {
                  'sessions': dict.fromkeys(bucket_names, 0),
                  'conversations': dict.fromkeys(bucket_names, 0)
              },
              'session_details': [],
              'time_ranges': dict.fromkeys(
                  ('oldest_session', 'newest_session', 'oldest_conversation', 'newest_conversation')
              )
          }

          # Sessions
          if hasattr(cache, 'session_info'):
              started_at = []
              conversation_total = 0
              for session_id, session_data in cache.session_info.items():
                  start_time = session_data['start_time']
                  started_at.append(start_time)
                  conversation_count = len(session_data.get('conversations', []))
                  conversation_total += conversation_count

                  last_activity = session_data.get('last_activity', session_data['start_time'])
                  is_active = (current_time - last_activity).total_seconds() < 1800
                  if is_active:
                      stats['basic_stats']['active_sessions'] += 1

                  stats['time_distribution']['sessions'][age_bucket(start_time)] += 1

                  session_duration = current_time - start_time
                  stats['session_details'].append({
                      'session_id': session_id,
                      'start_time': start_time.isoformat(),
                      'last_activity': last_activity.isoformat(),
                      'conversation_count': conversation_count,
                      'duration_seconds': session_duration.total_seconds(),
                      'duration_formatted': str(session_duration),
                      'is_active': is_active,
                      'browser_session_id': session_data.get('browser_session_id')
                  })

              session_count = len(cache.session_info)
              if session_count > 0:
                  stats['basic_stats']['average_conversations_per_session'] = conversation_total / session_count

              if started_at:
                  stats['time_ranges']['oldest_session'] = min(started_at).isoformat()
                  stats['time_ranges']['newest_session'] = max(started_at).isoformat()

          # Conversations
          if hasattr(cache, 'conversation_start_times'):
              conversation_times = []
              for conv_time in cache.conversation_start_times.values():
                  conversation_times.append(conv_time)
                  stats['time_distribution']['conversations'][age_bucket(conv_time)] += 1

              if conversation_times:
                  stats['time_ranges']['oldest_conversation'] = min(conversation_times).isoformat()
                  stats['time_ranges']['newest_conversation'] = max(conversation_times).isoformat()

          # Most recent activity first.
          stats['session_details'].sort(key=lambda x: x['last_activity'], reverse=True)
          return jsonify(result.success(data=stats))

      except Exception as e:
          return jsonify(result.failed(
              message=f"获取缓存统计失败: {str(e)}",
              code=500
          )), 500
  739. # ==================== 高级功能API ====================
  740. @app.flask_app.route('/api/v0/cache_export', methods=['GET'])
  741. def cache_export():
  742. """高级功能:完整导出 - 保持原cache_raw_export的完整功能"""
  743. try:
  744. cache = app.cache
  745. # 验证缓存的实际结构
  746. if not hasattr(cache, 'cache'):
  747. return jsonify(result.failed(message="缓存对象没有cache属性", code=500)), 500
  748. if not isinstance(cache.cache, dict):
  749. return jsonify(result.failed(message="缓存不是字典类型", code=500)), 500
  750. # 定义JSON序列化辅助函数
  751. def make_json_serializable(obj):
  752. """将对象转换为JSON可序列化的格式"""
  753. if obj is None:
  754. return None
  755. elif isinstance(obj, (str, int, float, bool)):
  756. return obj
  757. elif isinstance(obj, (list, tuple)):
  758. return [make_json_serializable(item) for item in obj]
  759. elif isinstance(obj, dict):
  760. return {str(k): make_json_serializable(v) for k, v in obj.items()}
  761. elif hasattr(obj, 'isoformat'): # datetime objects
  762. return obj.isoformat()
  763. elif hasattr(obj, 'item'): # numpy scalars
  764. return obj.item()
  765. elif hasattr(obj, 'tolist'): # numpy arrays
  766. return obj.tolist()
  767. elif hasattr(obj, '__dict__'): # pandas dtypes and other objects
  768. return str(obj)
  769. else:
  770. return str(obj)
  771. # 获取完整的原始缓存数据
  772. raw_cache = cache.cache
  773. # 获取会话和对话时间信息
  774. conversation_times = getattr(cache, 'conversation_start_times', {})
  775. session_info = getattr(cache, 'session_info', {})
  776. conversation_to_session = getattr(cache, 'conversation_to_session', {})
  777. export_data = {
  778. 'export_metadata': {
  779. 'export_time': datetime.now().isoformat(),
  780. 'total_conversations': len(raw_cache),
  781. 'total_sessions': len(session_info),
  782. 'cache_type': type(cache).__name__,
  783. 'cache_object_info': str(cache),
  784. 'has_session_times': bool(session_info),
  785. 'has_conversation_times': bool(conversation_times)
  786. },
  787. 'session_info': {
  788. session_id: {
  789. 'start_time': session_data['start_time'].isoformat(),
  790. 'last_activity': session_data.get('last_activity', session_data['start_time']).isoformat(),
  791. 'conversations': session_data['conversations'],
  792. 'conversation_count': len(session_data['conversations']),
  793. 'browser_session_id': session_data.get('browser_session_id'),
  794. 'user_info': session_data.get('user_info', {})
  795. }
  796. for session_id, session_data in session_info.items()
  797. },
  798. 'conversation_times': {
  799. conversation_id: start_time.isoformat()
  800. for conversation_id, start_time in conversation_times.items()
  801. },
  802. 'conversation_to_session_mapping': conversation_to_session,
  803. 'conversations': {}
  804. }
  805. # 处理每个对话的完整数据
  806. for conversation_id, conversation_data in raw_cache.items():
  807. # 获取时间信息
  808. conversation_start_time = conversation_times.get(conversation_id)
  809. session_id = conversation_to_session.get(conversation_id)
  810. session_start_time = None
  811. if session_id and session_id in session_info:
  812. session_start_time = session_info[session_id]['start_time']
  813. processed_conversation = {
  814. 'conversation_id': conversation_id,
  815. 'conversation_start_time': conversation_start_time.isoformat() if conversation_start_time else None,
  816. 'session_id': session_id,
  817. 'session_start_time': session_start_time.isoformat() if session_start_time else None,
  818. 'field_count': len(conversation_data),
  819. 'fields': {}
  820. }
  821. # 添加时间计算
  822. if conversation_start_time:
  823. conversation_duration = datetime.now() - conversation_start_time
  824. processed_conversation['conversation_duration_seconds'] = conversation_duration.total_seconds()
  825. processed_conversation['conversation_duration_formatted'] = str(conversation_duration)
  826. if session_start_time:
  827. session_duration = datetime.now() - session_start_time
  828. processed_conversation['session_duration_seconds'] = session_duration.total_seconds()
  829. processed_conversation['session_duration_formatted'] = str(session_duration)
  830. # 处理每个字段,确保JSON序列化安全
  831. for field_name, field_value in conversation_data.items():
  832. field_info = {
  833. 'field_name': field_name,
  834. 'data_type': type(field_value).__name__,
  835. 'is_none': field_value is None
  836. }
  837. try:
  838. if field_value is None:
  839. field_info['value'] = None
  840. elif field_name in ['conversation_start_time', 'session_start_time']:
  841. # 处理时间字段
  842. field_info['content'] = make_json_serializable(field_value)
  843. elif field_name == 'df' and field_value is not None:
  844. # DataFrame的安全处理
  845. if hasattr(field_value, 'to_dict'):
  846. # 安全地处理dtypes
  847. try:
  848. dtypes_dict = {}
  849. for col, dtype in field_value.dtypes.items():
  850. dtypes_dict[col] = str(dtype)
  851. except Exception:
  852. dtypes_dict = {"error": "无法序列化dtypes"}
  853. # 安全地处理内存使用
  854. try:
  855. memory_usage = field_value.memory_usage(deep=True)
  856. memory_dict = {}
  857. for idx, usage in memory_usage.items():
  858. memory_dict[str(idx)] = int(usage) if hasattr(usage, 'item') else int(usage)
  859. except Exception:
  860. memory_dict = {"error": "无法获取内存使用信息"}
  861. field_info.update({
  862. 'dataframe_info': {
  863. 'shape': list(field_value.shape),
  864. 'columns': list(field_value.columns),
  865. 'dtypes': dtypes_dict,
  866. 'index_info': {
  867. 'type': type(field_value.index).__name__,
  868. 'length': len(field_value.index)
  869. }
  870. },
  871. 'data': make_json_serializable(field_value.to_dict('records')),
  872. 'memory_usage': memory_dict
  873. })
  874. else:
  875. field_info['value'] = str(field_value)
  876. field_info['note'] = 'not_standard_dataframe'
  877. elif field_name == 'fig_json':
  878. # 图表JSON数据处理
  879. if isinstance(field_value, str):
  880. try:
  881. import json
  882. parsed_fig = json.loads(field_value)
  883. field_info.update({
  884. 'json_valid': True,
  885. 'json_size_bytes': len(field_value),
  886. 'plotly_structure': {
  887. 'has_data': 'data' in parsed_fig,
  888. 'has_layout': 'layout' in parsed_fig,
  889. 'data_traces_count': len(parsed_fig.get('data', [])),
  890. },
  891. 'raw_json': field_value
  892. })
  893. except json.JSONDecodeError:
  894. field_info.update({
  895. 'json_valid': False,
  896. 'raw_content': str(field_value)
  897. })
  898. else:
  899. field_info['value'] = make_json_serializable(field_value)
  900. elif field_name == 'followup_questions':
  901. # 后续问题列表
  902. field_info.update({
  903. 'content': make_json_serializable(field_value)
  904. })
  905. elif field_name in ['question', 'sql', 'summary']:
  906. # 文本字段
  907. if isinstance(field_value, str):
  908. field_info.update({
  909. 'text_length': len(field_value),
  910. 'content': field_value
  911. })
  912. else:
  913. field_info['value'] = make_json_serializable(field_value)
  914. else:
  915. # 未知字段的安全处理
  916. field_info['content'] = make_json_serializable(field_value)
  917. except Exception as e:
  918. field_info.update({
  919. 'processing_error': str(e),
  920. 'fallback_value': str(field_value)[:500] + '...' if len(str(field_value)) > 500 else str(field_value)
  921. })
  922. processed_conversation['fields'][field_name] = field_info
  923. export_data['conversations'][conversation_id] = processed_conversation
  924. # 添加缓存统计信息
  925. field_frequency = {}
  926. data_types_found = set()
  927. total_dataframes = 0
  928. total_questions = 0
  929. for conv_data in export_data['conversations'].values():
  930. for field_name, field_info in conv_data['fields'].items():
  931. field_frequency[field_name] = field_frequency.get(field_name, 0) + 1
  932. data_types_found.add(field_info['data_type'])
  933. if field_name == 'df' and not field_info['is_none']:
  934. total_dataframes += 1
  935. if field_name == 'question' and not field_info['is_none']:
  936. total_questions += 1
  937. export_data['cache_statistics'] = {
  938. 'field_frequency': field_frequency,
  939. 'data_types_found': list(data_types_found),
  940. 'total_dataframes': total_dataframes,
  941. 'total_questions': total_questions,
  942. 'has_session_timing': 'session_start_time' in field_frequency,
  943. 'has_conversation_timing': 'conversation_start_time' in field_frequency
  944. }
  945. return jsonify(result.success(data=export_data))
  946. except Exception as e:
  947. import traceback
  948. error_details = {
  949. 'error_message': str(e),
  950. 'error_type': type(e).__name__,
  951. 'traceback': traceback.format_exc()
  952. }
  953. return jsonify(result.failed(
  954. message=f"导出缓存失败: {str(e)}",
  955. code=500,
  956. data=error_details
  957. )), 500
  958. # ==================== 清理功能API ====================
  959. @app.flask_app.route('/api/v0/cache_preview_cleanup', methods=['POST'])
  960. def cache_preview_cleanup():
  961. """清理功能:预览删除操作 - 保持原功能"""
  962. try:
  963. req = request.get_json(force=True)
  964. # 时间条件 - 支持三种方式
  965. older_than_hours = req.get('older_than_hours')
  966. older_than_days = req.get('older_than_days')
  967. before_timestamp = req.get('before_timestamp') # YYYY-MM-DD HH:MM:SS 格式
  968. cache = app.cache
  969. # 计算截止时间
  970. cutoff_time = None
  971. time_condition = None
  972. if older_than_hours:
  973. cutoff_time = datetime.now() - timedelta(hours=older_than_hours)
  974. time_condition = f"older_than_hours: {older_than_hours}"
  975. elif older_than_days:
  976. cutoff_time = datetime.now() - timedelta(days=older_than_days)
  977. time_condition = f"older_than_days: {older_than_days}"
  978. elif before_timestamp:
  979. try:
  980. # 支持 YYYY-MM-DD HH:MM:SS 格式
  981. cutoff_time = datetime.strptime(before_timestamp, '%Y-%m-%d %H:%M:%S')
  982. time_condition = f"before_timestamp: {before_timestamp}"
  983. except ValueError:
  984. return jsonify(result.failed(
  985. message="before_timestamp格式错误,请使用 YYYY-MM-DD HH:MM:SS 格式",
  986. code=400
  987. )), 400
  988. else:
  989. return jsonify(result.failed(
  990. message="必须提供时间条件:older_than_hours, older_than_days 或 before_timestamp (YYYY-MM-DD HH:MM:SS)",
  991. code=400
  992. )), 400
  993. preview = {
  994. 'time_condition': time_condition,
  995. 'cutoff_time': cutoff_time.isoformat(),
  996. 'will_be_removed': {
  997. 'sessions': []
  998. },
  999. 'will_be_kept': {
  1000. 'sessions_count': 0,
  1001. 'conversations_count': 0
  1002. },
  1003. 'summary': {
  1004. 'sessions_to_remove': 0,
  1005. 'conversations_to_remove': 0,
  1006. 'sessions_to_keep': 0,
  1007. 'conversations_to_keep': 0
  1008. }
  1009. }
  1010. # 预览按session删除
  1011. sessions_to_remove_count = 0
  1012. conversations_to_remove_count = 0
  1013. for session_id, session_data in cache.session_info.items():
  1014. session_preview = {
  1015. 'session_id': session_id,
  1016. 'start_time': session_data['start_time'].isoformat(),
  1017. 'conversation_count': len(session_data['conversations']),
  1018. 'conversations': []
  1019. }
  1020. # 添加conversation详情
  1021. for conv_id in session_data['conversations']:
  1022. if conv_id in cache.cache:
  1023. conv_data = cache.cache[conv_id]
  1024. session_preview['conversations'].append({
  1025. 'conversation_id': conv_id,
  1026. 'question': conv_data.get('question', '')[:50] + '...' if conv_data.get('question') else '',
  1027. 'start_time': cache.conversation_start_times.get(conv_id, '').isoformat() if cache.conversation_start_times.get(conv_id) else ''
  1028. })
  1029. if session_data['start_time'] < cutoff_time:
  1030. preview['will_be_removed']['sessions'].append(session_preview)
  1031. sessions_to_remove_count += 1
  1032. conversations_to_remove_count += len(session_data['conversations'])
  1033. else:
  1034. preview['will_be_kept']['sessions_count'] += 1
  1035. preview['will_be_kept']['conversations_count'] += len(session_data['conversations'])
  1036. # 更新摘要统计
  1037. preview['summary'] = {
  1038. 'sessions_to_remove': sessions_to_remove_count,
  1039. 'conversations_to_remove': conversations_to_remove_count,
  1040. 'sessions_to_keep': preview['will_be_kept']['sessions_count'],
  1041. 'conversations_to_keep': preview['will_be_kept']['conversations_count']
  1042. }
  1043. return jsonify(result.success(data=preview))
  1044. except Exception as e:
  1045. return jsonify(result.failed(
  1046. message=f"预览清理操作失败: {str(e)}",
  1047. code=500
  1048. )), 500
  1049. @app.flask_app.route('/api/v0/cache_cleanup', methods=['POST'])
  1050. def cache_cleanup():
  1051. """清理功能:实际删除缓存 - 保持原功能"""
  1052. try:
  1053. req = request.get_json(force=True)
  1054. # 时间条件 - 支持三种方式
  1055. older_than_hours = req.get('older_than_hours')
  1056. older_than_days = req.get('older_than_days')
  1057. before_timestamp = req.get('before_timestamp') # YYYY-MM-DD HH:MM:SS 格式
  1058. cache = app.cache
  1059. if not hasattr(cache, 'session_info'):
  1060. return jsonify(result.failed(
  1061. message="缓存不支持会话功能",
  1062. code=400
  1063. )), 400
  1064. # 计算截止时间
  1065. cutoff_time = None
  1066. time_condition = None
  1067. if older_than_hours:
  1068. cutoff_time = datetime.now() - timedelta(hours=older_than_hours)
  1069. time_condition = f"older_than_hours: {older_than_hours}"
  1070. elif older_than_days:
  1071. cutoff_time = datetime.now() - timedelta(days=older_than_days)
  1072. time_condition = f"older_than_days: {older_than_days}"
  1073. elif before_timestamp:
  1074. try:
  1075. # 支持 YYYY-MM-DD HH:MM:SS 格式
  1076. cutoff_time = datetime.strptime(before_timestamp, '%Y-%m-%d %H:%M:%S')
  1077. time_condition = f"before_timestamp: {before_timestamp}"
  1078. except ValueError:
  1079. return jsonify(result.failed(
  1080. message="before_timestamp格式错误,请使用 YYYY-MM-DD HH:MM:SS 格式",
  1081. code=400
  1082. )), 400
  1083. else:
  1084. return jsonify(result.failed(
  1085. message="必须提供时间条件:older_than_hours, older_than_days 或 before_timestamp (YYYY-MM-DD HH:MM:SS)",
  1086. code=400
  1087. )), 400
  1088. cleanup_stats = {
  1089. 'time_condition': time_condition,
  1090. 'cutoff_time': cutoff_time.isoformat(),
  1091. 'sessions_removed': 0,
  1092. 'conversations_removed': 0,
  1093. 'sessions_kept': 0,
  1094. 'conversations_kept': 0,
  1095. 'removed_session_ids': [],
  1096. 'removed_conversation_ids': []
  1097. }
  1098. # 按session删除
  1099. sessions_to_remove = []
  1100. for session_id, session_data in cache.session_info.items():
  1101. if session_data['start_time'] < cutoff_time:
  1102. sessions_to_remove.append(session_id)
  1103. # 删除符合条件的sessions及其所有conversations
  1104. for session_id in sessions_to_remove:
  1105. session_data = cache.session_info[session_id]
  1106. conversations_in_session = session_data['conversations'].copy()
  1107. # 删除session中的所有conversations
  1108. for conv_id in conversations_in_session:
  1109. if conv_id in cache.cache:
  1110. del cache.cache[conv_id]
  1111. cleanup_stats['conversations_removed'] += 1
  1112. cleanup_stats['removed_conversation_ids'].append(conv_id)
  1113. # 清理conversation相关的时间记录
  1114. if hasattr(cache, 'conversation_start_times') and conv_id in cache.conversation_start_times:
  1115. del cache.conversation_start_times[conv_id]
  1116. if hasattr(cache, 'conversation_to_session') and conv_id in cache.conversation_to_session:
  1117. del cache.conversation_to_session[conv_id]
  1118. # 删除session记录
  1119. del cache.session_info[session_id]
  1120. cleanup_stats['sessions_removed'] += 1
  1121. cleanup_stats['removed_session_ids'].append(session_id)
  1122. # 统计保留的sessions和conversations
  1123. cleanup_stats['sessions_kept'] = len(cache.session_info)
  1124. cleanup_stats['conversations_kept'] = len(cache.cache)
  1125. return jsonify(result.success(data=cleanup_stats))
  1126. except Exception as e:
  1127. return jsonify(result.failed(
  1128. message=f"清理缓存失败: {str(e)}",
  1129. code=500
  1130. )), 500
  1131. @app.flask_app.route('/api/v0/training_error_question_sql', methods=['POST'])
  1132. def training_error_question_sql():
  1133. """
  1134. 存储错误的question-sql对到error_sql集合中
  1135. 此API将接收的错误question/sql pair写入到error_sql集合中,用于记录和分析错误的SQL查询。
  1136. Args:
  1137. question (str, required): 用户问题
  1138. sql (str, required): 对应的错误SQL查询语句
  1139. Returns:
  1140. JSON: 包含训练ID和成功消息的响应
  1141. """
  1142. try:
  1143. data = request.get_json()
  1144. question = data.get('question')
  1145. sql = data.get('sql')
  1146. print(f"[DEBUG] 接收到错误SQL训练请求: question={question}, sql={sql}")
  1147. if not question or not sql:
  1148. return jsonify(result.failed(
  1149. message="question和sql参数都是必需的",
  1150. code=400
  1151. )), 400
  1152. # 使用vn实例的train_error_sql方法存储错误SQL
  1153. id = vn.train_error_sql(question=question, sql=sql)
  1154. print(f"[INFO] 成功存储错误SQL,ID: {id}")
  1155. return jsonify(result.success(data={
  1156. "id": id,
  1157. "message": "错误SQL对已成功存储到error_sql集合"
  1158. }))
  1159. except Exception as e:
  1160. print(f"[ERROR] 存储错误SQL失败: {str(e)}")
  1161. return jsonify(result.failed(
  1162. message=f"存储错误SQL失败: {str(e)}",
  1163. code=500
  1164. )), 500
  1165. # 前端JavaScript示例 - 如何维持会话
  1166. """
  1167. // 前端需要维护一个会话ID
  1168. class ChatSession {
  1169. constructor() {
  1170. // 从localStorage获取或创建新的会话ID
  1171. this.sessionId = localStorage.getItem('chat_session_id') || this.generateSessionId();
  1172. localStorage.setItem('chat_session_id', this.sessionId);
  1173. }
  1174. generateSessionId() {
  1175. return 'session_' + Date.now() + '_' + Math.random().toString(36).substr(2, 9);
  1176. }
  1177. async askQuestion(question) {
  1178. const response = await fetch('/api/v0/ask', {
  1179. method: 'POST',
  1180. headers: {
  1181. 'Content-Type': 'application/json',
  1182. },
  1183. body: JSON.stringify({
  1184. question: question,
  1185. session_id: this.sessionId // 关键:传递会话ID
  1186. })
  1187. });
  1188. return await response.json();
  1189. }
  1190. // 开始新会话
  1191. startNewSession() {
  1192. this.sessionId = this.generateSessionId();
  1193. localStorage.setItem('chat_session_id', this.sessionId);
  1194. }
  1195. }
  1196. // 使用示例
  1197. const chatSession = new ChatSession();
  1198. chatSession.askQuestion("各年龄段客户的流失率如何?");
  1199. """
  1200. print("正在启动Flask应用: http://localhost:8084")
  1201. app.run(host="0.0.0.0", port=8084, debug=True)