# citu_app.py
  1. # 给dataops 对话助手返回结果
  2. from vanna.flask import VannaFlaskApp
  3. from core.vanna_llm_factory import create_vanna_instance
  4. from flask import request, jsonify
  5. import pandas as pd
  6. import common.result as result
  7. from datetime import datetime, timedelta
  8. from common.session_aware_cache import WebSessionAwareMemoryCache
  9. from app_config import API_MAX_RETURN_ROWS
  10. import re
  11. import chainlit as cl
  12. import json
  13. # 设置默认的最大返回行数
  14. DEFAULT_MAX_RETURN_ROWS = 200
  15. MAX_RETURN_ROWS = API_MAX_RETURN_ROWS if API_MAX_RETURN_ROWS is not None else DEFAULT_MAX_RETURN_ROWS
  16. vn = create_vanna_instance()
  17. # 创建带时间戳的缓存
  18. timestamped_cache = WebSessionAwareMemoryCache()
  19. # 实例化 VannaFlaskApp,使用自定义缓存
  20. app = VannaFlaskApp(
  21. vn,
  22. cache=timestamped_cache, # 使用带时间戳的缓存
  23. title="辞图智能数据问答平台",
  24. logo = "https://www.citupro.com/img/logo-black-2.png",
  25. subtitle="让 AI 为你写 SQL",
  26. chart=False,
  27. allow_llm_to_see_data=True,
  28. ask_results_correct=True,
  29. followup_questions=True,
  30. debug=True
  31. )
  32. # 修改ask接口,支持前端传递session_id
  33. @app.flask_app.route('/api/v0/ask', methods=['POST'])
  34. def ask_full():
  35. req = request.get_json(force=True)
  36. question = req.get("question", None)
  37. browser_session_id = req.get("session_id", None) # 前端传递的会话ID
  38. if not question:
  39. return jsonify(result.failed(message="未提供问题", code=400)), 400
  40. # 如果使用WebSessionAwareMemoryCache
  41. if hasattr(app.cache, 'generate_id_with_browser_session') and browser_session_id:
  42. # 这里需要修改vanna的ask方法来支持传递session_id
  43. # 或者预先调用generate_id来建立会话关联
  44. conversation_id = app.cache.generate_id_with_browser_session(
  45. question=question,
  46. browser_session_id=browser_session_id
  47. )
  48. try:
  49. sql, df, _ = vn.ask(
  50. question=question,
  51. print_results=False,
  52. visualize=False,
  53. allow_llm_to_see_data=True
  54. )
  55. # 关键:检查是否有LLM解释性文本(无法生成SQL的情况)
  56. if sql is None and hasattr(vn, 'last_llm_explanation') and vn.last_llm_explanation:
  57. # 在解释性文本末尾添加提示语
  58. explanation_message = vn.last_llm_explanation + "请尝试提问其它问题。"
  59. # 使用 result.failed 返回,success为false,但在message中包含LLM友好的解释
  60. return jsonify(result.failed(
  61. message=explanation_message, # 已处理的解释性文本
  62. code=400, # 业务逻辑错误,使用400
  63. data={
  64. "sql": None,
  65. "rows": [],
  66. "columns": [],
  67. "summary": None,
  68. "conversation_id": conversation_id if 'conversation_id' in locals() else None,
  69. "session_id": browser_session_id
  70. }
  71. )), 200 # HTTP状态码仍为200,因为请求本身成功处理了
  72. # 如果sql为None但没有解释性文本,返回通用错误
  73. if sql is None:
  74. return jsonify(result.failed(
  75. message="无法生成SQL查询,请检查问题描述或数据表结构",
  76. code=400,
  77. data={
  78. "sql": None,
  79. "rows": [],
  80. "columns": [],
  81. "summary": None,
  82. "conversation_id": conversation_id if 'conversation_id' in locals() else None,
  83. "session_id": browser_session_id
  84. }
  85. )), 200
  86. # 正常SQL流程
  87. rows, columns = [], []
  88. summary = None
  89. if isinstance(df, pd.DataFrame) and not df.empty:
  90. rows = df.head(MAX_RETURN_ROWS).to_dict(orient="records")
  91. columns = list(df.columns)
  92. # 生成数据摘要(thinking内容已在base_llm_chat.py中统一处理)
  93. try:
  94. summary = vn.generate_summary(question=question, df=df)
  95. print(f"[INFO] 成功生成摘要: {summary}")
  96. except Exception as e:
  97. print(f"[WARNING] 生成摘要失败: {str(e)}")
  98. summary = None
  99. return jsonify(result.success(data={
  100. "sql": sql,
  101. "rows": rows,
  102. "columns": columns,
  103. "summary": summary, # 添加摘要到返回结果
  104. "conversation_id": conversation_id if 'conversation_id' in locals() else None,
  105. "session_id": browser_session_id
  106. }))
  107. except Exception as e:
  108. print(f"[ERROR] ask_full执行失败: {str(e)}")
  109. # 即使发生异常,也检查是否有业务层面的解释
  110. if hasattr(vn, 'last_llm_explanation') and vn.last_llm_explanation:
  111. # 在解释性文本末尾添加提示语
  112. explanation_message = vn.last_llm_explanation + "请尝试提问其它问题。"
  113. return jsonify(result.failed(
  114. message=explanation_message,
  115. code=400,
  116. data={
  117. "sql": None,
  118. "rows": [],
  119. "columns": [],
  120. "summary": None,
  121. "conversation_id": conversation_id if 'conversation_id' in locals() else None,
  122. "session_id": browser_session_id
  123. }
  124. )), 200
  125. else:
  126. # 技术错误,使用500错误码
  127. return jsonify(result.failed(
  128. message=f"查询处理失败: {str(e)}",
  129. code=500
  130. )), 500
  131. @app.flask_app.route('/api/v0/citu_run_sql', methods=['POST'])
  132. def citu_run_sql():
  133. req = request.get_json(force=True)
  134. sql = req.get('sql')
  135. if not sql:
  136. return jsonify(result.failed(message="未提供SQL查询", code=400)), 400
  137. try:
  138. df = vn.run_sql(sql)
  139. rows, columns = [], []
  140. if isinstance(df, pd.DataFrame) and not df.empty:
  141. rows = df.head(MAX_RETURN_ROWS).to_dict(orient="records")
  142. columns = list(df.columns)
  143. return jsonify(result.success(data={
  144. "sql": sql,
  145. "rows": rows,
  146. "columns": columns
  147. }))
  148. except Exception as e:
  149. print(f"[ERROR] citu_run_sql执行失败: {str(e)}")
  150. return jsonify(result.failed(
  151. message=f"SQL执行失败: {str(e)}",
  152. code=500
  153. )), 500
  154. @app.flask_app.route('/api/v0/ask_cached', methods=['POST'])
  155. def ask_cached():
  156. """
  157. 带缓存功能的智能查询接口
  158. 支持会话管理和结果缓存,提高查询效率
  159. """
  160. req = request.get_json(force=True)
  161. question = req.get("question", None)
  162. browser_session_id = req.get("session_id", None)
  163. if not question:
  164. return jsonify(result.failed(message="未提供问题", code=400)), 400
  165. try:
  166. # 生成conversation_id
  167. # 调试:查看generate_id的实际行为
  168. print(f"[DEBUG] 输入问题: '{question}'")
  169. conversation_id = app.cache.generate_id(question=question)
  170. print(f"[DEBUG] 生成的conversation_id: {conversation_id}")
  171. # 再次用相同问题测试
  172. conversation_id2 = app.cache.generate_id(question=question)
  173. print(f"[DEBUG] 再次生成的conversation_id: {conversation_id2}")
  174. print(f"[DEBUG] 两次ID是否相同: {conversation_id == conversation_id2}")
  175. # 检查缓存
  176. cached_sql = app.cache.get(id=conversation_id, field="sql")
  177. if cached_sql is not None:
  178. # 缓存命中
  179. print(f"[CACHE HIT] 使用缓存结果: {conversation_id}")
  180. sql = cached_sql
  181. df = app.cache.get(id=conversation_id, field="df")
  182. summary = app.cache.get(id=conversation_id, field="summary")
  183. else:
  184. # 缓存未命中,执行新查询
  185. print(f"[CACHE MISS] 执行新查询: {conversation_id}")
  186. sql, df, _ = vn.ask(
  187. question=question,
  188. print_results=False,
  189. visualize=False,
  190. allow_llm_to_see_data=True
  191. )
  192. # 检查是否有LLM解释性文本(无法生成SQL的情况)
  193. if sql is None and hasattr(vn, 'last_llm_explanation') and vn.last_llm_explanation:
  194. # 在解释性文本末尾添加提示语
  195. explanation_message = vn.last_llm_explanation + "请尝试用其它方式提问。"
  196. return jsonify(result.failed(
  197. message=explanation_message,
  198. code=400,
  199. data={
  200. "sql": None,
  201. "rows": [],
  202. "columns": [],
  203. "summary": None,
  204. "conversation_id": conversation_id,
  205. "session_id": browser_session_id,
  206. "cached": False
  207. }
  208. )), 200
  209. # 如果sql为None但没有解释性文本,返回通用错误
  210. if sql is None:
  211. return jsonify(result.failed(
  212. message="无法生成SQL查询,请检查问题描述或数据表结构",
  213. code=400,
  214. data={
  215. "sql": None,
  216. "rows": [],
  217. "columns": [],
  218. "summary": None,
  219. "conversation_id": conversation_id,
  220. "session_id": browser_session_id,
  221. "cached": False
  222. }
  223. )), 200
  224. # 缓存结果
  225. app.cache.set(id=conversation_id, field="question", value=question)
  226. app.cache.set(id=conversation_id, field="sql", value=sql)
  227. app.cache.set(id=conversation_id, field="df", value=df)
  228. # 生成并缓存摘要
  229. summary = None
  230. if isinstance(df, pd.DataFrame) and not df.empty:
  231. try:
  232. summary = vn.generate_summary(question=question, df=df)
  233. print(f"[INFO] 成功生成摘要: {summary}")
  234. except Exception as e:
  235. print(f"[WARNING] 生成摘要失败: {str(e)}")
  236. summary = None
  237. app.cache.set(id=conversation_id, field="summary", value=summary)
  238. # 处理返回数据
  239. rows, columns = [], []
  240. if isinstance(df, pd.DataFrame) and not df.empty:
  241. rows = df.head(MAX_RETURN_ROWS).to_dict(orient="records")
  242. columns = list(df.columns)
  243. return jsonify(result.success(data={
  244. "sql": sql,
  245. "rows": rows,
  246. "columns": columns,
  247. "summary": summary,
  248. "conversation_id": conversation_id,
  249. "session_id": browser_session_id,
  250. "cached": cached_sql is not None # 标识是否来自缓存
  251. }))
  252. except Exception as e:
  253. print(f"[ERROR] ask_cached执行失败: {str(e)}")
  254. return jsonify(result.failed(
  255. message=f"查询处理失败: {str(e)}",
  256. code=500
  257. )), 500
  258. @app.flask_app.route('/api/v0/citu_train_question_sql', methods=['POST'])
  259. def citu_train_question_sql():
  260. """
  261. 训练问题-SQL对接口
  262. 此API将接收的question/sql pair写入到training库中,用于训练和改进AI模型。
  263. 支持仅传入SQL或同时传入问题和SQL进行训练。
  264. Args:
  265. question (str, optional): 用户问题
  266. sql (str, required): 对应的SQL查询语句
  267. Returns:
  268. JSON: 包含训练ID和成功消息的响应
  269. """
  270. try:
  271. req = request.get_json(force=True)
  272. question = req.get('question')
  273. sql = req.get('sql')
  274. if not sql:
  275. return jsonify(result.failed(
  276. message="'sql' are required",
  277. code=400
  278. )), 400
  279. # 正确的调用方式:同时传递question和sql
  280. if question:
  281. training_id = vn.train(question=question, sql=sql)
  282. print(f"训练成功,训练ID为:{training_id},问题:{question},SQL:{sql}")
  283. else:
  284. training_id = vn.train(sql=sql)
  285. print(f"训练成功,训练ID为:{training_id},SQL:{sql}")
  286. return jsonify(result.success(data={
  287. "training_id": training_id,
  288. "message": "Question-SQL pair trained successfully"
  289. }))
  290. except Exception as e:
  291. return jsonify(result.failed(
  292. message=f"Training failed: {str(e)}",
  293. code=500
  294. )), 500
  295. # ============ LangGraph Agent 集成 ============
  296. # 全局Agent实例(单例模式)
  297. citu_langraph_agent = None
  298. def get_citu_langraph_agent():
  299. """获取LangGraph Agent实例(懒加载)"""
  300. global citu_langraph_agent
  301. if citu_langraph_agent is None:
  302. try:
  303. from agent.citu_agent import CituLangGraphAgent
  304. print("[CITU_APP] 开始创建LangGraph Agent实例...")
  305. citu_langraph_agent = CituLangGraphAgent()
  306. print("[CITU_APP] LangGraph Agent实例创建成功")
  307. except ImportError as e:
  308. print(f"[CRITICAL] Agent模块导入失败: {str(e)}")
  309. print("[CRITICAL] 请检查agent模块是否存在以及依赖是否正确安装")
  310. raise Exception(f"Agent模块导入失败: {str(e)}")
  311. except Exception as e:
  312. print(f"[CRITICAL] LangGraph Agent实例创建失败: {str(e)}")
  313. print(f"[CRITICAL] 错误类型: {type(e).__name__}")
  314. # 提供更有用的错误信息
  315. if "config" in str(e).lower():
  316. print("[CRITICAL] 可能是配置文件问题,请检查配置")
  317. elif "llm" in str(e).lower():
  318. print("[CRITICAL] 可能是LLM连接问题,请检查LLM配置")
  319. elif "tool" in str(e).lower():
  320. print("[CRITICAL] 可能是工具加载问题,请检查工具模块")
  321. raise Exception(f"Agent初始化失败: {str(e)}")
  322. return citu_langraph_agent
  323. @app.flask_app.route('/api/v0/ask_agent', methods=['POST'])
  324. def ask_agent():
  325. """
  326. 新的LangGraph Agent接口
  327. 请求格式:
  328. {
  329. "question": "用户问题",
  330. "session_id": "会话ID(可选)"
  331. }
  332. 响应格式:
  333. {
  334. "success": true/false,
  335. "code": 200,
  336. "message": "success" 或错误信息,
  337. "data": {
  338. "response": "最终回答",
  339. "type": "DATABASE/CHAT",
  340. "sql": "生成的SQL(如果是数据库查询)",
  341. "data_result": {
  342. "rows": [...],
  343. "columns": [...],
  344. "row_count": 数字
  345. },
  346. "summary": "数据摘要(如果是数据库查询)",
  347. "session_id": "会话ID",
  348. "execution_path": ["classify", "agent_database", "format_response"],
  349. "classification_info": {
  350. "confidence": 0.95,
  351. "reason": "分类原因",
  352. "method": "rule_based/llm_based"
  353. },
  354. "agent_version": "langgraph_v1"
  355. }
  356. }
  357. """
  358. req = request.get_json(force=True)
  359. question = req.get("question", None)
  360. browser_session_id = req.get("session_id", None)
  361. if not question:
  362. return jsonify(result.failed(message="未提供问题", code=400)), 400
  363. try:
  364. # 专门处理Agent初始化异常
  365. try:
  366. agent = get_citu_langraph_agent()
  367. except Exception as e:
  368. print(f"[CRITICAL] Agent初始化失败: {str(e)}")
  369. return jsonify(result.failed(
  370. message="AI服务暂时不可用,请稍后重试",
  371. code=503,
  372. data={
  373. "session_id": browser_session_id,
  374. "execution_path": ["agent_init_error"],
  375. "agent_version": "langgraph_v1",
  376. "timestamp": datetime.now().isoformat(),
  377. "error_type": "agent_initialization_failed"
  378. }
  379. )), 503
  380. # 调用Agent处理问题
  381. agent_result = agent.process_question(
  382. question=question,
  383. session_id=browser_session_id
  384. )
  385. # 统一返回格式
  386. if agent_result.get("success", False):
  387. return jsonify(result.success(data={
  388. "response": agent_result.get("response", ""),
  389. "type": agent_result.get("type", "UNKNOWN"),
  390. "sql": agent_result.get("sql"),
  391. "data_result": agent_result.get("data_result"),
  392. "summary": agent_result.get("summary"),
  393. "session_id": browser_session_id,
  394. "execution_path": agent_result.get("execution_path", []),
  395. "classification_info": agent_result.get("classification_info", {}),
  396. "agent_version": "langgraph_v1",
  397. "timestamp": datetime.now().isoformat()
  398. }))
  399. else:
  400. return jsonify(result.failed(
  401. message=agent_result.get("error", "Agent处理失败"),
  402. code=agent_result.get("error_code", 500),
  403. data={
  404. "session_id": browser_session_id,
  405. "execution_path": agent_result.get("execution_path", []),
  406. "classification_info": agent_result.get("classification_info", {}),
  407. "agent_version": "langgraph_v1",
  408. "timestamp": datetime.now().isoformat()
  409. }
  410. )), 200 # HTTP 200但业务失败
  411. except Exception as e:
  412. print(f"[ERROR] ask_agent执行失败: {str(e)}")
  413. return jsonify(result.failed(
  414. message="请求处理异常,请稍后重试",
  415. code=500,
  416. data={
  417. "session_id": browser_session_id,
  418. "execution_path": ["general_error"],
  419. "agent_version": "langgraph_v1",
  420. "timestamp": datetime.now().isoformat(),
  421. "error_type": "request_processing_failed"
  422. }
  423. )), 500
  424. @app.flask_app.route('/api/v0/agent_health', methods=['GET'])
  425. def agent_health():
  426. """
  427. Agent健康检查接口
  428. 响应格式:
  429. {
  430. "success": true/false,
  431. "code": 200/503,
  432. "message": "healthy/degraded/unhealthy",
  433. "data": {
  434. "status": "healthy/degraded/unhealthy",
  435. "test_result": true/false,
  436. "workflow_compiled": true/false,
  437. "tools_count": 4,
  438. "message": "详细信息",
  439. "timestamp": "2024-01-01T12:00:00",
  440. "checks": {
  441. "agent_creation": true/false,
  442. "tools_import": true/false,
  443. "llm_connection": true/false,
  444. "classifier_ready": true/false
  445. }
  446. }
  447. }
  448. """
  449. try:
  450. # 基础健康检查
  451. health_data = {
  452. "status": "unknown",
  453. "test_result": False,
  454. "workflow_compiled": False,
  455. "tools_count": 0,
  456. "message": "",
  457. "timestamp": datetime.now().isoformat(),
  458. "checks": {
  459. "agent_creation": False,
  460. "tools_import": False,
  461. "llm_connection": False,
  462. "classifier_ready": False
  463. }
  464. }
  465. # 检查1: Agent创建
  466. try:
  467. agent = get_citu_langraph_agent()
  468. health_data["checks"]["agent_creation"] = True
  469. health_data["workflow_compiled"] = agent.workflow is not None
  470. health_data["tools_count"] = len(agent.tools) if hasattr(agent, 'tools') else 0
  471. except Exception as e:
  472. health_data["message"] = f"Agent创建失败: {str(e)}"
  473. return jsonify(result.failed(
  474. message="Agent状态: unhealthy",
  475. data=health_data,
  476. code=503
  477. )), 503
  478. # 检查2: 工具导入
  479. try:
  480. from agent.tools import TOOLS
  481. health_data["checks"]["tools_import"] = len(TOOLS) > 0
  482. except Exception as e:
  483. health_data["message"] = f"工具导入失败: {str(e)}"
  484. # 检查3: LLM连接(简单测试)
  485. try:
  486. from agent.utils import get_compatible_llm
  487. llm = get_compatible_llm()
  488. health_data["checks"]["llm_connection"] = llm is not None
  489. except Exception as e:
  490. health_data["message"] = f"LLM连接失败: {str(e)}"
  491. # 检查4: 分类器准备
  492. try:
  493. from agent.classifier import QuestionClassifier
  494. classifier = QuestionClassifier()
  495. health_data["checks"]["classifier_ready"] = True
  496. except Exception as e:
  497. health_data["message"] = f"分类器失败: {str(e)}"
  498. # 检查5: 完整流程测试(可选)
  499. try:
  500. if all(health_data["checks"].values()):
  501. test_result = agent.health_check()
  502. health_data["test_result"] = test_result.get("status") == "healthy"
  503. health_data["status"] = test_result.get("status", "unknown")
  504. health_data["message"] = test_result.get("message", "健康检查完成")
  505. else:
  506. health_data["status"] = "degraded"
  507. health_data["message"] = "部分组件异常"
  508. except Exception as e:
  509. health_data["status"] = "degraded"
  510. health_data["message"] = f"完整测试失败: {str(e)}"
  511. # 根据状态返回相应的HTTP代码
  512. if health_data["status"] == "healthy":
  513. return jsonify(result.success(data=health_data))
  514. elif health_data["status"] == "degraded":
  515. return jsonify(result.failed(
  516. message="Agent状态: degraded",
  517. data=health_data,
  518. code=503
  519. )), 503
  520. else:
  521. return jsonify(result.failed(
  522. message="Agent状态: unhealthy",
  523. data=health_data,
  524. code=503
  525. )), 503
  526. except Exception as e:
  527. print(f"[ERROR] 健康检查异常: {str(e)}")
  528. return jsonify(result.failed(
  529. message=f"健康检查失败: {str(e)}",
  530. code=500,
  531. data={
  532. "status": "error",
  533. "timestamp": datetime.now().isoformat()
  534. }
  535. )), 500
  536. # ==================== 日常管理API ====================
  537. @app.flask_app.route('/api/v0/cache_overview', methods=['GET'])
  538. def cache_overview():
  539. """日常管理:轻量概览 - 合并原cache_inspect的核心功能"""
  540. try:
  541. cache = app.cache
  542. result_data = {
  543. 'overview_summary': {
  544. 'total_conversations': 0,
  545. 'total_sessions': 0,
  546. 'query_time': datetime.now().isoformat()
  547. },
  548. 'recent_conversations': [], # 最近的对话
  549. 'session_summary': [] # 会话摘要
  550. }
  551. if hasattr(cache, 'cache') and isinstance(cache.cache, dict):
  552. result_data['overview_summary']['total_conversations'] = len(cache.cache)
  553. # 获取会话信息
  554. if hasattr(cache, 'get_all_sessions'):
  555. all_sessions = cache.get_all_sessions()
  556. result_data['overview_summary']['total_sessions'] = len(all_sessions)
  557. # 会话摘要(按最近活动排序)
  558. session_list = []
  559. for session_id, session_data in all_sessions.items():
  560. session_summary = {
  561. 'session_id': session_id,
  562. 'start_time': session_data['start_time'].isoformat(),
  563. 'conversation_count': session_data.get('conversation_count', 0),
  564. 'duration_seconds': session_data.get('session_duration_seconds', 0),
  565. 'last_activity': session_data.get('last_activity', session_data['start_time']).isoformat(),
  566. 'is_active': (datetime.now() - session_data.get('last_activity', session_data['start_time'])).total_seconds() < 1800 # 30分钟内活跃
  567. }
  568. session_list.append(session_summary)
  569. # 按最后活动时间排序
  570. session_list.sort(key=lambda x: x['last_activity'], reverse=True)
  571. result_data['session_summary'] = session_list
  572. # 最近的对话(最多显示10个)
  573. conversation_list = []
  574. for conversation_id, conversation_data in cache.cache.items():
  575. conversation_start_time = cache.conversation_start_times.get(conversation_id)
  576. conversation_info = {
  577. 'conversation_id': conversation_id,
  578. 'conversation_start_time': conversation_start_time.isoformat() if conversation_start_time else None,
  579. 'session_id': cache.conversation_to_session.get(conversation_id),
  580. 'has_question': 'question' in conversation_data,
  581. 'has_sql': 'sql' in conversation_data,
  582. 'has_data': 'df' in conversation_data and conversation_data['df'] is not None,
  583. 'question_preview': conversation_data.get('question', '')[:80] + '...' if len(conversation_data.get('question', '')) > 80 else conversation_data.get('question', ''),
  584. }
  585. # 计算对话持续时间
  586. if conversation_start_time:
  587. duration = datetime.now() - conversation_start_time
  588. conversation_info['conversation_duration_seconds'] = duration.total_seconds()
  589. conversation_list.append(conversation_info)
  590. # 按对话开始时间排序,显示最新的10个
  591. conversation_list.sort(key=lambda x: x['conversation_start_time'] or '', reverse=True)
  592. result_data['recent_conversations'] = conversation_list[:10]
  593. return jsonify(result.success(data=result_data))
  594. except Exception as e:
  595. return jsonify(result.failed(
  596. message=f"获取缓存概览失败: {str(e)}",
  597. code=500
  598. )), 500
  599. @app.flask_app.route('/api/v0/cache_stats', methods=['GET'])
  600. def cache_stats():
  601. """日常管理:统计信息 - 合并原session_stats和cache_stats功能"""
  602. try:
  603. cache = app.cache
  604. current_time = datetime.now()
  605. stats = {
  606. 'basic_stats': {
  607. 'total_sessions': len(getattr(cache, 'session_info', {})),
  608. 'total_conversations': len(getattr(cache, 'cache', {})),
  609. 'active_sessions': 0, # 最近30分钟有活动
  610. 'average_conversations_per_session': 0
  611. },
  612. 'time_distribution': {
  613. 'sessions': {
  614. 'last_1_hour': 0,
  615. 'last_6_hours': 0,
  616. 'last_24_hours': 0,
  617. 'last_7_days': 0,
  618. 'older': 0
  619. },
  620. 'conversations': {
  621. 'last_1_hour': 0,
  622. 'last_6_hours': 0,
  623. 'last_24_hours': 0,
  624. 'last_7_days': 0,
  625. 'older': 0
  626. }
  627. },
  628. 'session_details': [],
  629. 'time_ranges': {
  630. 'oldest_session': None,
  631. 'newest_session': None,
  632. 'oldest_conversation': None,
  633. 'newest_conversation': None
  634. }
  635. }
  636. # 会话统计
  637. if hasattr(cache, 'session_info'):
  638. session_times = []
  639. total_conversations = 0
  640. for session_id, session_data in cache.session_info.items():
  641. start_time = session_data['start_time']
  642. session_times.append(start_time)
  643. conversation_count = len(session_data.get('conversations', []))
  644. total_conversations += conversation_count
  645. # 检查活跃状态
  646. last_activity = session_data.get('last_activity', session_data['start_time'])
  647. if (current_time - last_activity).total_seconds() < 1800:
  648. stats['basic_stats']['active_sessions'] += 1
  649. # 时间分布统计
  650. age_hours = (current_time - start_time).total_seconds() / 3600
  651. if age_hours <= 1:
  652. stats['time_distribution']['sessions']['last_1_hour'] += 1
  653. elif age_hours <= 6:
  654. stats['time_distribution']['sessions']['last_6_hours'] += 1
  655. elif age_hours <= 24:
  656. stats['time_distribution']['sessions']['last_24_hours'] += 1
  657. elif age_hours <= 168: # 7 days
  658. stats['time_distribution']['sessions']['last_7_days'] += 1
  659. else:
  660. stats['time_distribution']['sessions']['older'] += 1
  661. # 会话详细信息
  662. session_duration = current_time - start_time
  663. stats['session_details'].append({
  664. 'session_id': session_id,
  665. 'start_time': start_time.isoformat(),
  666. 'last_activity': last_activity.isoformat(),
  667. 'conversation_count': conversation_count,
  668. 'duration_seconds': session_duration.total_seconds(),
  669. 'duration_formatted': str(session_duration),
  670. 'is_active': (current_time - last_activity).total_seconds() < 1800,
  671. 'browser_session_id': session_data.get('browser_session_id')
  672. })
  673. # 计算平均值
  674. if len(cache.session_info) > 0:
  675. stats['basic_stats']['average_conversations_per_session'] = total_conversations / len(cache.session_info)
  676. # 时间范围
  677. if session_times:
  678. stats['time_ranges']['oldest_session'] = min(session_times).isoformat()
  679. stats['time_ranges']['newest_session'] = max(session_times).isoformat()
  680. # 对话统计
  681. if hasattr(cache, 'conversation_start_times'):
  682. conversation_times = []
  683. for conv_time in cache.conversation_start_times.values():
  684. conversation_times.append(conv_time)
  685. age_hours = (current_time - conv_time).total_seconds() / 3600
  686. if age_hours <= 1:
  687. stats['time_distribution']['conversations']['last_1_hour'] += 1
  688. elif age_hours <= 6:
  689. stats['time_distribution']['conversations']['last_6_hours'] += 1
  690. elif age_hours <= 24:
  691. stats['time_distribution']['conversations']['last_24_hours'] += 1
  692. elif age_hours <= 168:
  693. stats['time_distribution']['conversations']['last_7_days'] += 1
  694. else:
  695. stats['time_distribution']['conversations']['older'] += 1
  696. if conversation_times:
  697. stats['time_ranges']['oldest_conversation'] = min(conversation_times).isoformat()
  698. stats['time_ranges']['newest_conversation'] = max(conversation_times).isoformat()
  699. # 按最近活动排序会话详情
  700. stats['session_details'].sort(key=lambda x: x['last_activity'], reverse=True)
  701. return jsonify(result.success(data=stats))
  702. except Exception as e:
  703. return jsonify(result.failed(
  704. message=f"获取缓存统计失败: {str(e)}",
  705. code=500
  706. )), 500
  707. # ==================== 高级功能API ====================
  @app.flask_app.route('/api/v0/cache_export', methods=['GET'])
  def cache_export():
      """Export the complete conversation cache as a JSON document.

      The response payload contains:

      * ``export_metadata`` - export time, totals and cache type information.
      * ``session_info`` / ``conversation_times`` /
        ``conversation_to_session_mapping`` - the session bookkeeping read
        from the cache object (each defaults to empty when the attribute is
        missing).
      * ``conversations`` - one entry per cached conversation with timing
        information and a per-field description under ``fields``.
      * ``cache_statistics`` - aggregate counts over the exported fields.

      Returns:
          A Flask JSON response. On any unexpected error a 500 response is
          returned whose ``data`` carries the error message, type and traceback.
      """
      import json
      import math

      def make_json_serializable(obj):
          """Recursively convert ``obj`` into something ``jsonify`` can emit."""
          if obj is None or isinstance(obj, (str, bool, int)):
              return obj
          if isinstance(obj, float):
              # NaN / Infinity are not valid JSON tokens; export them as null.
              return obj if math.isfinite(obj) else None
          if isinstance(obj, (list, tuple)):
              return [make_json_serializable(item) for item in obj]
          if isinstance(obj, dict):
              return {str(k): make_json_serializable(v) for k, v in obj.items()}
          if hasattr(obj, 'isoformat'):  # datetime-like objects
              return obj.isoformat()
          if hasattr(obj, 'tolist'):
              # Checked before ``item``: NumPy arrays expose both, and
              # ``item()`` raises for anything that is not a single element.
              # ``tolist()`` also turns NumPy scalars into plain Python values.
              return make_json_serializable(obj.tolist())
          if hasattr(obj, 'item'):  # scalar wrappers without ``tolist``
              return obj.item()
          # dtypes and any other object: fall back to the string form.
          return str(obj)

      def describe_dataframe(frame):
          """Summarise a DataFrame-like value: structure, records, memory use."""
          try:
              dtypes_dict = {str(col): str(dtype) for col, dtype in frame.dtypes.items()}
          except Exception:
              dtypes_dict = {"error": "无法序列化dtypes"}

          try:
              memory_dict = {
                  str(idx): int(usage)
                  for idx, usage in frame.memory_usage(deep=True).items()
              }
          except Exception:
              memory_dict = {"error": "无法获取内存使用信息"}

          return {
              'dataframe_info': {
                  'shape': list(frame.shape),
                  # Column labels may be non-string objects; make them JSON-safe.
                  'columns': make_json_serializable(list(frame.columns)),
                  'dtypes': dtypes_dict,
                  'index_info': {
                      'type': type(frame.index).__name__,
                      'length': len(frame.index),
                  },
              },
              'data': make_json_serializable(frame.to_dict('records')),
              'memory_usage': memory_dict,
          }

      def describe_field(field_name, field_value):
          """Build the export entry for one cached field.

          A failure while processing a field is recorded on that field
          (``processing_error`` / ``fallback_value``) instead of aborting
          the whole export.
          """
          field_info = {
              'field_name': field_name,
              'data_type': type(field_value).__name__,
              'is_none': field_value is None,
          }
          try:
              if field_value is None:
                  field_info['value'] = None
              elif field_name in ('conversation_start_time', 'session_start_time'):
                  # Time fields
                  field_info['content'] = make_json_serializable(field_value)
              elif field_name == 'df':
                  if hasattr(field_value, 'to_dict'):
                      field_info.update(describe_dataframe(field_value))
                  else:
                      field_info['value'] = str(field_value)
                      field_info['note'] = 'not_standard_dataframe'
              elif field_name == 'fig_json':
                  # Chart JSON
                  if isinstance(field_value, str):
                      try:
                          parsed_fig = json.loads(field_value)
                          field_info.update({
                              'json_valid': True,
                              'json_size_bytes': len(field_value),
                              'plotly_structure': {
                                  'has_data': 'data' in parsed_fig,
                                  'has_layout': 'layout' in parsed_fig,
                                  'data_traces_count': len(parsed_fig.get('data', [])),
                              },
                              'raw_json': field_value,
                          })
                      except json.JSONDecodeError:
                          field_info.update({
                              'json_valid': False,
                              'raw_content': str(field_value),
                          })
                  else:
                      field_info['value'] = make_json_serializable(field_value)
              elif field_name == 'followup_questions':
                  # Follow-up question list
                  field_info['content'] = make_json_serializable(field_value)
              elif field_name in ('question', 'sql', 'summary'):
                  # Text fields
                  if isinstance(field_value, str):
                      field_info.update({
                          'text_length': len(field_value),
                          'content': field_value,
                      })
                  else:
                      field_info['value'] = make_json_serializable(field_value)
              else:
                  # Unknown fields
                  field_info['content'] = make_json_serializable(field_value)
          except Exception as e:
              text = str(field_value)
              field_info.update({
                  'processing_error': str(e),
                  'fallback_value': text[:500] + '...' if len(text) > 500 else text,
              })
          return field_info

      try:
          cache = app.cache

          # Validate the actual structure of the cache before reading it.
          if not hasattr(cache, 'cache'):
              return jsonify(result.failed(message="缓存对象没有cache属性", code=500)), 500
          if not isinstance(cache.cache, dict):
              return jsonify(result.failed(message="缓存不是字典类型", code=500)), 500

          raw_cache = cache.cache

          # Session / conversation timing bookkeeping (optional attributes).
          conversation_times = getattr(cache, 'conversation_start_times', {})
          session_info = getattr(cache, 'session_info', {})
          conversation_to_session = getattr(cache, 'conversation_to_session', {})

          # One snapshot of "now" so every duration in the export is consistent.
          now = datetime.now()

          export_data = {
              'export_metadata': {
                  'export_time': now.isoformat(),
                  'total_conversations': len(raw_cache),
                  'total_sessions': len(session_info),
                  'cache_type': type(cache).__name__,
                  'cache_object_info': str(cache),
                  'has_session_times': bool(session_info),
                  'has_conversation_times': bool(conversation_times),
              },
              'session_info': {
                  session_id: {
                      'start_time': session_data['start_time'].isoformat(),
                      'last_activity': session_data.get(
                          'last_activity', session_data['start_time']
                      ).isoformat(),
                      'conversations': session_data['conversations'],
                      'conversation_count': len(session_data['conversations']),
                      'browser_session_id': session_data.get('browser_session_id'),
                      'user_info': session_data.get('user_info', {}),
                  }
                  for session_id, session_data in session_info.items()
              },
              'conversation_times': {
                  conversation_id: start_time.isoformat()
                  for conversation_id, start_time in conversation_times.items()
              },
              'conversation_to_session_mapping': conversation_to_session,
              'conversations': {},
          }

          # Full data for every conversation.
          for conversation_id, conversation_data in raw_cache.items():
              conversation_start_time = conversation_times.get(conversation_id)
              session_id = conversation_to_session.get(conversation_id)
              session_start_time = None
              if session_id and session_id in session_info:
                  session_start_time = session_info[session_id]['start_time']

              processed_conversation = {
                  'conversation_id': conversation_id,
                  'conversation_start_time': (
                      conversation_start_time.isoformat() if conversation_start_time else None
                  ),
                  'session_id': session_id,
                  'session_start_time': (
                      session_start_time.isoformat() if session_start_time else None
                  ),
                  'field_count': len(conversation_data),
                  'fields': {},
              }

              # Elapsed time since the conversation / session started.
              if conversation_start_time:
                  conversation_duration = now - conversation_start_time
                  processed_conversation['conversation_duration_seconds'] = (
                      conversation_duration.total_seconds()
                  )
                  processed_conversation['conversation_duration_formatted'] = str(
                      conversation_duration
                  )
              if session_start_time:
                  session_duration = now - session_start_time
                  processed_conversation['session_duration_seconds'] = (
                      session_duration.total_seconds()
                  )
                  processed_conversation['session_duration_formatted'] = str(session_duration)

              for field_name, field_value in conversation_data.items():
                  processed_conversation['fields'][field_name] = describe_field(
                      field_name, field_value
                  )

              export_data['conversations'][conversation_id] = processed_conversation

          # Aggregate statistics over the exported fields.
          field_frequency = {}
          data_types_found = set()
          total_dataframes = 0
          total_questions = 0
          for conv_data in export_data['conversations'].values():
              for field_name, field_info in conv_data['fields'].items():
                  field_frequency[field_name] = field_frequency.get(field_name, 0) + 1
                  data_types_found.add(field_info['data_type'])
                  if field_info['is_none']:
                      continue
                  if field_name == 'df':
                      total_dataframes += 1
                  elif field_name == 'question':
                      total_questions += 1

          export_data['cache_statistics'] = {
              'field_frequency': field_frequency,
              # Sorted so repeated exports of the same cache are stable.
              'data_types_found': sorted(data_types_found),
              'total_dataframes': total_dataframes,
              'total_questions': total_questions,
              'has_session_timing': 'session_start_time' in field_frequency,
              'has_conversation_timing': 'conversation_start_time' in field_frequency,
          }

          return jsonify(result.success(data=export_data))

      except Exception as e:
          import traceback
          # NOTE(review): the traceback is returned to the client; confirm this
          # endpoint is only reachable by trusted operators.
          error_details = {
              'error_message': str(e),
              'error_type': type(e).__name__,
              'traceback': traceback.format_exc(),
          }
          return jsonify(result.failed(
              message=f"导出缓存失败: {str(e)}",
              code=500,
              data=error_details
          )), 500
  926. # ==================== 清理功能API ====================
  @app.flask_app.route('/api/v0/cache_preview_cleanup', methods=['POST'])
  def cache_preview_cleanup():
      """Preview which sessions a time-based cache cleanup would delete.

      Nothing is modified. The JSON request body must supply one time
      condition; the first truthy one wins, in this order:

      * ``older_than_hours`` (number)
      * ``older_than_days`` (number)
      * ``before_timestamp`` (string, ``YYYY-MM-DD HH:MM:SS``)

      A session is selected for removal when its ``start_time`` is earlier
      than the resulting cutoff.

      Returns:
          A Flask JSON response with the sessions (and their conversations)
          that would be removed plus keep/remove counts. Invalid input yields
          a 400 response; unexpected errors yield a 500 response.
      """
      try:
          req = request.get_json(force=True)
          if not isinstance(req, dict):
              return jsonify(result.failed(
                  message="请求体必须是JSON对象",
                  code=400
              )), 400

          # Time condition - three alternative forms are supported.
          older_than_hours = req.get('older_than_hours')
          older_than_days = req.get('older_than_days')
          before_timestamp = req.get('before_timestamp')  # YYYY-MM-DD HH:MM:SS

          cache = app.cache
          # Same guard as the real cleanup endpoint: without session tracking
          # there is nothing to preview.
          if not hasattr(cache, 'session_info'):
              return jsonify(result.failed(
                  message="缓存不支持会话功能",
                  code=400
              )), 400

          # Work out the cutoff time.
          if older_than_hours:
              if not isinstance(older_than_hours, (int, float)):
                  return jsonify(result.failed(
                      message="older_than_hours必须是数字",
                      code=400
                  )), 400
              cutoff_time = datetime.now() - timedelta(hours=older_than_hours)
              time_condition = f"older_than_hours: {older_than_hours}"
          elif older_than_days:
              if not isinstance(older_than_days, (int, float)):
                  return jsonify(result.failed(
                      message="older_than_days必须是数字",
                      code=400
                  )), 400
              cutoff_time = datetime.now() - timedelta(days=older_than_days)
              time_condition = f"older_than_days: {older_than_days}"
          elif before_timestamp:
              try:
                  cutoff_time = datetime.strptime(before_timestamp, '%Y-%m-%d %H:%M:%S')
              except (ValueError, TypeError):
                  # TypeError covers a non-string value such as a number.
                  return jsonify(result.failed(
                      message="before_timestamp格式错误,请使用 YYYY-MM-DD HH:MM:SS 格式",
                      code=400
                  )), 400
              time_condition = f"before_timestamp: {before_timestamp}"
          else:
              return jsonify(result.failed(
                  message="必须提供时间条件:older_than_hours, older_than_days 或 before_timestamp (YYYY-MM-DD HH:MM:SS)",
                  code=400
              )), 400

          conversation_times = getattr(cache, 'conversation_start_times', {})

          removed_sessions = []
          conversations_to_remove_count = 0
          sessions_kept_count = 0
          conversations_kept_count = 0

          for session_id, session_data in cache.session_info.items():
              conversation_ids = session_data['conversations']

              if not session_data['start_time'] < cutoff_time:
                  # Kept sessions only contribute to the counters.
                  sessions_kept_count += 1
                  conversations_kept_count += len(conversation_ids)
                  continue

              conversation_previews = []
              for conv_id in conversation_ids:
                  if conv_id not in cache.cache:
                      continue
                  question = cache.cache[conv_id].get('question')
                  if question:
                      question_text = str(question)
                      # Only mark the text as truncated when it really was.
                      if len(question_text) > 50:
                          question_text = question_text[:50] + '...'
                  else:
                      question_text = ''
                  started_at = conversation_times.get(conv_id)
                  conversation_previews.append({
                      'conversation_id': conv_id,
                      'question': question_text,
                      'start_time': started_at.isoformat() if started_at else '',
                  })

              removed_sessions.append({
                  'session_id': session_id,
                  'start_time': session_data['start_time'].isoformat(),
                  'conversation_count': len(conversation_ids),
                  'conversations': conversation_previews,
              })
              conversations_to_remove_count += len(conversation_ids)

          preview = {
              'time_condition': time_condition,
              'cutoff_time': cutoff_time.isoformat(),
              'will_be_removed': {
                  'sessions': removed_sessions,
              },
              'will_be_kept': {
                  'sessions_count': sessions_kept_count,
                  'conversations_count': conversations_kept_count,
              },
              'summary': {
                  'sessions_to_remove': len(removed_sessions),
                  'conversations_to_remove': conversations_to_remove_count,
                  'sessions_to_keep': sessions_kept_count,
                  'conversations_to_keep': conversations_kept_count,
              },
          }

          return jsonify(result.success(data=preview))

      except Exception as e:
          return jsonify(result.failed(
              message=f"预览清理操作失败: {str(e)}",
              code=500
          )), 500
  @app.flask_app.route('/api/v0/cache_cleanup', methods=['POST'])
  def cache_cleanup():
      """Delete cached sessions, and their conversations, older than a cutoff.

      The JSON request body must supply one time condition; the first truthy
      one wins, in this order:

      * ``older_than_hours`` (number)
      * ``older_than_days`` (number)
      * ``before_timestamp`` (string, ``YYYY-MM-DD HH:MM:SS``)

      Every session whose ``start_time`` is earlier than the cutoff is removed
      from ``cache.session_info`` together with its conversations in
      ``cache.cache`` and their entries in ``conversation_start_times`` and
      ``conversation_to_session`` (when the cache has those attributes).

      Returns:
          A Flask JSON response with removal statistics and the removed IDs.
          Invalid input yields a 400 response; unexpected errors yield a 500
          response.
      """
      try:
          req = request.get_json(force=True)
          if not isinstance(req, dict):
              return jsonify(result.failed(
                  message="请求体必须是JSON对象",
                  code=400
              )), 400

          # Time condition - three alternative forms are supported.
          older_than_hours = req.get('older_than_hours')
          older_than_days = req.get('older_than_days')
          before_timestamp = req.get('before_timestamp')  # YYYY-MM-DD HH:MM:SS

          cache = app.cache
          if not hasattr(cache, 'session_info'):
              return jsonify(result.failed(
                  message="缓存不支持会话功能",
                  code=400
              )), 400

          # Work out the cutoff time.
          if older_than_hours:
              if not isinstance(older_than_hours, (int, float)):
                  return jsonify(result.failed(
                      message="older_than_hours必须是数字",
                      code=400
                  )), 400
              cutoff_time = datetime.now() - timedelta(hours=older_than_hours)
              time_condition = f"older_than_hours: {older_than_hours}"
          elif older_than_days:
              if not isinstance(older_than_days, (int, float)):
                  return jsonify(result.failed(
                      message="older_than_days必须是数字",
                      code=400
                  )), 400
              cutoff_time = datetime.now() - timedelta(days=older_than_days)
              time_condition = f"older_than_days: {older_than_days}"
          elif before_timestamp:
              try:
                  cutoff_time = datetime.strptime(before_timestamp, '%Y-%m-%d %H:%M:%S')
              except (ValueError, TypeError):
                  # TypeError covers a non-string value such as a number.
                  return jsonify(result.failed(
                      message="before_timestamp格式错误,请使用 YYYY-MM-DD HH:MM:SS 格式",
                      code=400
                  )), 400
              time_condition = f"before_timestamp: {before_timestamp}"
          else:
              return jsonify(result.failed(
                  message="必须提供时间条件:older_than_hours, older_than_days 或 before_timestamp (YYYY-MM-DD HH:MM:SS)",
                  code=400
              )), 400

          cleanup_stats = {
              'time_condition': time_condition,
              'cutoff_time': cutoff_time.isoformat(),
              'sessions_removed': 0,
              'conversations_removed': 0,
              'sessions_kept': 0,
              'conversations_kept': 0,
              'removed_session_ids': [],
              'removed_conversation_ids': [],
          }

          # Collect the IDs first so session_info is not mutated while iterating.
          expired_session_ids = [
              session_id
              for session_id, session_data in cache.session_info.items()
              if session_data['start_time'] < cutoff_time
          ]

          conversation_times = getattr(cache, 'conversation_start_times', None)
          conversation_sessions = getattr(cache, 'conversation_to_session', None)

          for session_id in expired_session_ids:
              # Iterate over a copy of the session's conversation list.
              for conv_id in list(cache.session_info[session_id]['conversations']):
                  if conv_id in cache.cache:
                      del cache.cache[conv_id]
                      cleanup_stats['conversations_removed'] += 1
                      cleanup_stats['removed_conversation_ids'].append(conv_id)

                  # Drop the timing / mapping records for this conversation
                  # whether or not its cached body still existed.
                  if conversation_times is not None:
                      conversation_times.pop(conv_id, None)
                  if conversation_sessions is not None:
                      conversation_sessions.pop(conv_id, None)

              del cache.session_info[session_id]
              cleanup_stats['sessions_removed'] += 1
              cleanup_stats['removed_session_ids'].append(session_id)

          # Whatever is left in the cache was kept.
          cleanup_stats['sessions_kept'] = len(cache.session_info)
          cleanup_stats['conversations_kept'] = len(cache.cache)

          return jsonify(result.success(data=cleanup_stats))

      except Exception as e:
          return jsonify(result.failed(
              message=f"清理缓存失败: {str(e)}",
              code=500
          )), 500
  @app.flask_app.route('/api/v0/training_error_question_sql', methods=['POST'])
  def training_error_question_sql():
      """Store an erroneous question/SQL pair in the error_sql collection.

      The received pair is passed to ``vn.train_error_sql`` so that incorrect
      SQL queries can be recorded and analysed.

      Request JSON:
          question (str, required): the user's question.
          sql (str, required): the incorrect SQL statement for that question.

      Returns:
          A Flask JSON response containing the training ID and a success
          message; 400 when either field is missing or empty, 500 when
          storing fails.
      """
      try:
          data = request.get_json()
          # A missing or non-object body is treated as "no fields supplied" so
          # the caller gets the 400 below instead of an AttributeError / 500.
          if not isinstance(data, dict):
              data = {}

          question = data.get('question')
          sql = data.get('sql')
          print(f"[DEBUG] 接收到错误SQL训练请求: question={question}, sql={sql}")

          if not question or not sql:
              return jsonify(result.failed(
                  message="question和sql参数都是必需的",
                  code=400
              )), 400

          # Store the pair through the vn instance's train_error_sql method.
          training_id = vn.train_error_sql(question=question, sql=sql)
          print(f"[INFO] 成功存储错误SQL,ID: {training_id}")

          return jsonify(result.success(data={
              "id": training_id,
              "message": "错误SQL对已成功存储到error_sql集合"
          }))

      except Exception as e:
          print(f"[ERROR] 存储错误SQL失败: {str(e)}")
          return jsonify(result.failed(
              message=f"存储错误SQL失败: {str(e)}",
              code=500
          )), 500
  1133. # 前端JavaScript示例 - 如何维持会话
  1134. """
  1135. // 前端需要维护一个会话ID
  1136. class ChatSession {
  1137. constructor() {
  1138. // 从localStorage获取或创建新的会话ID
  1139. this.sessionId = localStorage.getItem('chat_session_id') || this.generateSessionId();
  1140. localStorage.setItem('chat_session_id', this.sessionId);
  1141. }
  1142. generateSessionId() {
  1143. return 'session_' + Date.now() + '_' + Math.random().toString(36).substr(2, 9);
  1144. }
  1145. async askQuestion(question) {
  1146. const response = await fetch('/api/v0/ask', {
  1147. method: 'POST',
  1148. headers: {
  1149. 'Content-Type': 'application/json',
  1150. },
  1151. body: JSON.stringify({
  1152. question: question,
  1153. session_id: this.sessionId // 关键:传递会话ID
  1154. })
  1155. });
  1156. return await response.json();
  1157. }
  1158. // 开始新会话
  1159. startNewSession() {
  1160. this.sessionId = this.generateSessionId();
  1161. localStorage.setItem('chat_session_id', this.sessionId);
  1162. }
  1163. }
  1164. // 使用示例
  1165. const chatSession = new ChatSession();
  1166. chatSession.askQuestion("各年龄段客户的流失率如何?");
  1167. """
  # Announce the local URL on stdout, then start the application.
  print("正在启动Flask应用: http://localhost:8084")
  # NOTE(review): host "0.0.0.0" together with debug=True - if this reaches
  # Flask's own run(), the interactive debugger would be reachable from other
  # machines. Confirm this launch path is used for local development only.
  app.run(
      debug=True,
      port=8084,
      host="0.0.0.0",
  )