citu_app.py 66 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593
  1. # 给dataops 对话助手返回结果
  2. from vanna.flask import VannaFlaskApp
  3. from core.vanna_llm_factory import create_vanna_instance
  4. from flask import request, jsonify
  5. import pandas as pd
  6. import common.result as result
  7. from datetime import datetime, timedelta
  8. from common.session_aware_cache import WebSessionAwareMemoryCache
  9. from app_config import API_MAX_RETURN_ROWS, ENABLE_RESULT_SUMMARY
  10. import re
  11. import chainlit as cl
  12. import json
  13. from flask import session # 添加session导入
  14. from common.redis_conversation_manager import RedisConversationManager # 添加Redis对话管理器导入
  15. from common.result import ( # 统一导入所有需要的响应函数
  16. bad_request_response, service_unavailable_response,
  17. agent_success_response, agent_error_response,
  18. internal_error_response, success_response,
  19. validation_failed_response
  20. )
  21. from app_config import ( # 添加Redis相关配置导入
  22. USER_MAX_CONVERSATIONS,
  23. CONVERSATION_CONTEXT_COUNT
  24. )
  25. # 设置默认的最大返回行数
  26. DEFAULT_MAX_RETURN_ROWS = 200
  27. MAX_RETURN_ROWS = API_MAX_RETURN_ROWS if API_MAX_RETURN_ROWS is not None else DEFAULT_MAX_RETURN_ROWS
  28. vn = create_vanna_instance()
  29. # 创建带时间戳的缓存
  30. timestamped_cache = WebSessionAwareMemoryCache()
  31. # 实例化 VannaFlaskApp,使用自定义缓存
  32. app = VannaFlaskApp(
  33. vn,
  34. cache=timestamped_cache, # 使用带时间戳的缓存
  35. title="辞图智能数据问答平台",
  36. logo = "https://www.citupro.com/img/logo-black-2.png",
  37. subtitle="让 AI 为你写 SQL",
  38. chart=False,
  39. allow_llm_to_see_data=True,
  40. ask_results_correct=True,
  41. followup_questions=True,
  42. debug=True
  43. )
  44. # 创建Redis对话管理器实例
  45. redis_conversation_manager = RedisConversationManager()
  46. # 修改ask接口,支持前端传递session_id
  47. @app.flask_app.route('/api/v0/ask', methods=['POST'])
  48. def ask_full():
  49. req = request.get_json(force=True)
  50. question = req.get("question", None)
  51. browser_session_id = req.get("session_id", None) # 前端传递的会话ID
  52. if not question:
  53. from common.result import bad_request_response
  54. return jsonify(bad_request_response(
  55. response_text="缺少必需参数:question",
  56. missing_params=["question"]
  57. )), 400
  58. # 如果使用WebSessionAwareMemoryCache
  59. if hasattr(app.cache, 'generate_id_with_browser_session') and browser_session_id:
  60. # 这里需要修改vanna的ask方法来支持传递session_id
  61. # 或者预先调用generate_id来建立会话关联
  62. conversation_id = app.cache.generate_id_with_browser_session(
  63. question=question,
  64. browser_session_id=browser_session_id
  65. )
  66. try:
  67. sql, df, _ = vn.ask(
  68. question=question,
  69. print_results=False,
  70. visualize=False,
  71. allow_llm_to_see_data=True
  72. )
  73. # 关键:检查是否有LLM解释性文本(无法生成SQL的情况)
  74. if sql is None and hasattr(vn, 'last_llm_explanation') and vn.last_llm_explanation:
  75. # 在解释性文本末尾添加提示语
  76. explanation_message = vn.last_llm_explanation + "请尝试提问其它问题。"
  77. # 使用标准化错误响应
  78. from common.result import validation_failed_response
  79. return jsonify(validation_failed_response(
  80. response_text=explanation_message
  81. )), 422 # 修改HTTP状态码为422
  82. # 如果sql为None但没有解释性文本,返回通用错误
  83. if sql is None:
  84. from common.result import validation_failed_response
  85. return jsonify(validation_failed_response(
  86. response_text="无法生成SQL查询,请检查问题描述或数据表结构"
  87. )), 422
  88. # 处理返回数据 - 使用新的query_result结构
  89. query_result = {
  90. "rows": [],
  91. "columns": [],
  92. "row_count": 0,
  93. "is_limited": False,
  94. "total_row_count": 0
  95. }
  96. summary = None
  97. if isinstance(df, pd.DataFrame):
  98. query_result["columns"] = list(df.columns)
  99. if not df.empty:
  100. total_rows = len(df)
  101. limited_df = df.head(MAX_RETURN_ROWS)
  102. query_result["rows"] = limited_df.to_dict(orient="records")
  103. query_result["row_count"] = len(limited_df)
  104. query_result["total_row_count"] = total_rows
  105. query_result["is_limited"] = total_rows > MAX_RETURN_ROWS
  106. # 生成数据摘要(可通过配置控制,仅在有数据时生成)
  107. if ENABLE_RESULT_SUMMARY:
  108. try:
  109. summary = vn.generate_summary(question=question, df=df)
  110. print(f"[INFO] 成功生成摘要: {summary}")
  111. except Exception as e:
  112. print(f"[WARNING] 生成摘要失败: {str(e)}")
  113. summary = None
  114. # 构建返回数据
  115. response_data = {
  116. "sql": sql,
  117. "query_result": query_result,
  118. "conversation_id": conversation_id if 'conversation_id' in locals() else None,
  119. "session_id": browser_session_id
  120. }
  121. # 添加摘要(如果启用且生成成功)
  122. if ENABLE_RESULT_SUMMARY and summary is not None:
  123. response_data["summary"] = summary
  124. response_data["response"] = summary # 同时添加response字段
  125. from common.result import success_response
  126. return jsonify(success_response(
  127. response_text="查询执行完成" if summary is None else None,
  128. data=response_data
  129. ))
  130. except Exception as e:
  131. print(f"[ERROR] ask_full执行失败: {str(e)}")
  132. # 即使发生异常,也检查是否有业务层面的解释
  133. if hasattr(vn, 'last_llm_explanation') and vn.last_llm_explanation:
  134. # 在解释性文本末尾添加提示语
  135. explanation_message = vn.last_llm_explanation + "请尝试提问其它问题。"
  136. from common.result import validation_failed_response
  137. return jsonify(validation_failed_response(
  138. response_text=explanation_message
  139. )), 422
  140. else:
  141. # 技术错误,使用500错误码
  142. from common.result import internal_error_response
  143. return jsonify(internal_error_response(
  144. response_text="查询处理失败,请稍后重试"
  145. )), 500
  146. @app.flask_app.route('/api/v0/citu_run_sql', methods=['POST'])
  147. def citu_run_sql():
  148. req = request.get_json(force=True)
  149. sql = req.get('sql')
  150. if not sql:
  151. from common.result import bad_request_response
  152. return jsonify(bad_request_response(
  153. response_text="缺少必需参数:sql",
  154. missing_params=["sql"]
  155. )), 400
  156. try:
  157. df = vn.run_sql(sql)
  158. # 处理返回数据 - 使用新的query_result结构
  159. query_result = {
  160. "rows": [],
  161. "columns": [],
  162. "row_count": 0,
  163. "is_limited": False,
  164. "total_row_count": 0
  165. }
  166. if isinstance(df, pd.DataFrame):
  167. query_result["columns"] = list(df.columns)
  168. if not df.empty:
  169. total_rows = len(df)
  170. limited_df = df.head(MAX_RETURN_ROWS)
  171. query_result["rows"] = limited_df.to_dict(orient="records")
  172. query_result["row_count"] = len(limited_df)
  173. query_result["total_row_count"] = total_rows
  174. query_result["is_limited"] = total_rows > MAX_RETURN_ROWS
  175. from common.result import success_response
  176. return jsonify(success_response(
  177. response_text=f"SQL执行完成,共返回 {query_result['total_row_count']} 条记录" +
  178. (f",已限制显示前 {MAX_RETURN_ROWS} 条" if query_result["is_limited"] else ""),
  179. data={
  180. "sql": sql,
  181. "query_result": query_result
  182. }
  183. ))
  184. except Exception as e:
  185. print(f"[ERROR] citu_run_sql执行失败: {str(e)}")
  186. from common.result import internal_error_response
  187. return jsonify(internal_error_response(
  188. response_text=f"SQL执行失败,请检查SQL语句是否正确"
  189. )), 500
  190. @app.flask_app.route('/api/v0/ask_cached', methods=['POST'])
  191. def ask_cached():
  192. """
  193. 带缓存功能的智能查询接口
  194. 支持会话管理和结果缓存,提高查询效率
  195. """
  196. req = request.get_json(force=True)
  197. question = req.get("question", None)
  198. browser_session_id = req.get("session_id", None)
  199. if not question:
  200. from common.result import bad_request_response
  201. return jsonify(bad_request_response(
  202. response_text="缺少必需参数:question",
  203. missing_params=["question"]
  204. )), 400
  205. try:
  206. # 生成conversation_id
  207. # 调试:查看generate_id的实际行为
  208. print(f"[DEBUG] 输入问题: '{question}'")
  209. conversation_id = app.cache.generate_id(question=question)
  210. print(f"[DEBUG] 生成的conversation_id: {conversation_id}")
  211. # 再次用相同问题测试
  212. conversation_id2 = app.cache.generate_id(question=question)
  213. print(f"[DEBUG] 再次生成的conversation_id: {conversation_id2}")
  214. print(f"[DEBUG] 两次ID是否相同: {conversation_id == conversation_id2}")
  215. # 检查缓存
  216. cached_sql = app.cache.get(id=conversation_id, field="sql")
  217. if cached_sql is not None:
  218. # 缓存命中
  219. print(f"[CACHE HIT] 使用缓存结果: {conversation_id}")
  220. sql = cached_sql
  221. df = app.cache.get(id=conversation_id, field="df")
  222. summary = app.cache.get(id=conversation_id, field="summary")
  223. else:
  224. # 缓存未命中,执行新查询
  225. print(f"[CACHE MISS] 执行新查询: {conversation_id}")
  226. sql, df, _ = vn.ask(
  227. question=question,
  228. print_results=False,
  229. visualize=False,
  230. allow_llm_to_see_data=True
  231. )
  232. # 检查是否有LLM解释性文本(无法生成SQL的情况)
  233. if sql is None and hasattr(vn, 'last_llm_explanation') and vn.last_llm_explanation:
  234. # 在解释性文本末尾添加提示语
  235. explanation_message = vn.last_llm_explanation + "请尝试用其它方式提问。"
  236. from common.result import validation_failed_response
  237. return jsonify(validation_failed_response(
  238. response_text=explanation_message
  239. )), 422
  240. # 如果sql为None但没有解释性文本,返回通用错误
  241. if sql is None:
  242. from common.result import validation_failed_response
  243. return jsonify(validation_failed_response(
  244. response_text="无法生成SQL查询,请检查问题描述或数据表结构"
  245. )), 422
  246. # 缓存结果
  247. app.cache.set(id=conversation_id, field="question", value=question)
  248. app.cache.set(id=conversation_id, field="sql", value=sql)
  249. app.cache.set(id=conversation_id, field="df", value=df)
  250. # 生成并缓存摘要(可通过配置控制,仅在有数据时生成)
  251. summary = None
  252. if ENABLE_RESULT_SUMMARY and isinstance(df, pd.DataFrame) and not df.empty:
  253. try:
  254. summary = vn.generate_summary(question=question, df=df)
  255. print(f"[INFO] 成功生成摘要: {summary}")
  256. except Exception as e:
  257. print(f"[WARNING] 生成摘要失败: {str(e)}")
  258. summary = None
  259. app.cache.set(id=conversation_id, field="summary", value=summary)
  260. # 处理返回数据 - 使用新的query_result结构
  261. query_result = {
  262. "rows": [],
  263. "columns": [],
  264. "row_count": 0,
  265. "is_limited": False,
  266. "total_row_count": 0
  267. }
  268. if isinstance(df, pd.DataFrame):
  269. query_result["columns"] = list(df.columns)
  270. if not df.empty:
  271. total_rows = len(df)
  272. limited_df = df.head(MAX_RETURN_ROWS)
  273. query_result["rows"] = limited_df.to_dict(orient="records")
  274. query_result["row_count"] = len(limited_df)
  275. query_result["total_row_count"] = total_rows
  276. query_result["is_limited"] = total_rows > MAX_RETURN_ROWS
  277. # 构建返回数据
  278. response_data = {
  279. "sql": sql,
  280. "query_result": query_result,
  281. "conversation_id": conversation_id,
  282. "session_id": browser_session_id,
  283. "cached": cached_sql is not None # 标识是否来自缓存
  284. }
  285. # 添加摘要(如果启用且生成成功)
  286. if ENABLE_RESULT_SUMMARY and summary is not None:
  287. response_data["summary"] = summary
  288. response_data["response"] = summary # 同时添加response字段
  289. from common.result import success_response
  290. return jsonify(success_response(
  291. response_text="查询执行完成" if summary is None else None,
  292. data=response_data
  293. ))
  294. except Exception as e:
  295. print(f"[ERROR] ask_cached执行失败: {str(e)}")
  296. from common.result import internal_error_response
  297. return jsonify(internal_error_response(
  298. response_text="查询处理失败,请稍后重试"
  299. )), 500
  300. @app.flask_app.route('/api/v0/citu_train_question_sql', methods=['POST'])
  301. def citu_train_question_sql():
  302. """
  303. 训练问题-SQL对接口
  304. 此API将接收的question/sql pair写入到training库中,用于训练和改进AI模型。
  305. 支持仅传入SQL或同时传入问题和SQL进行训练。
  306. Args:
  307. question (str, optional): 用户问题
  308. sql (str, required): 对应的SQL查询语句
  309. Returns:
  310. JSON: 包含训练ID和成功消息的响应
  311. """
  312. try:
  313. req = request.get_json(force=True)
  314. question = req.get('question')
  315. sql = req.get('sql')
  316. if not sql:
  317. from common.result import bad_request_response
  318. return jsonify(bad_request_response(
  319. response_text="缺少必需参数:sql",
  320. missing_params=["sql"]
  321. )), 400
  322. # 正确的调用方式:同时传递question和sql
  323. if question:
  324. training_id = vn.train(question=question, sql=sql)
  325. print(f"训练成功,训练ID为:{training_id},问题:{question},SQL:{sql}")
  326. else:
  327. training_id = vn.train(sql=sql)
  328. print(f"训练成功,训练ID为:{training_id},SQL:{sql}")
  329. from common.result import success_response
  330. return jsonify(success_response(
  331. response_text="问题-SQL对训练成功",
  332. data={
  333. "training_id": training_id,
  334. "message": "Question-SQL pair trained successfully"
  335. }
  336. ))
  337. except Exception as e:
  338. from common.result import internal_error_response
  339. return jsonify(internal_error_response(
  340. response_text="训练失败,请稍后重试"
  341. )), 500
  342. # ============ LangGraph Agent 集成 ============
  343. # 全局Agent实例(单例模式)
  344. citu_langraph_agent = None
  345. def get_citu_langraph_agent():
  346. """获取LangGraph Agent实例(懒加载)"""
  347. global citu_langraph_agent
  348. if citu_langraph_agent is None:
  349. try:
  350. from agent.citu_agent import CituLangGraphAgent
  351. print("[CITU_APP] 开始创建LangGraph Agent实例...")
  352. citu_langraph_agent = CituLangGraphAgent()
  353. print("[CITU_APP] LangGraph Agent实例创建成功")
  354. except ImportError as e:
  355. print(f"[CRITICAL] Agent模块导入失败: {str(e)}")
  356. print("[CRITICAL] 请检查agent模块是否存在以及依赖是否正确安装")
  357. raise Exception(f"Agent模块导入失败: {str(e)}")
  358. except Exception as e:
  359. print(f"[CRITICAL] LangGraph Agent实例创建失败: {str(e)}")
  360. print(f"[CRITICAL] 错误类型: {type(e).__name__}")
  361. # 提供更有用的错误信息
  362. if "config" in str(e).lower():
  363. print("[CRITICAL] 可能是配置文件问题,请检查配置")
  364. elif "llm" in str(e).lower():
  365. print("[CRITICAL] 可能是LLM连接问题,请检查LLM配置")
  366. elif "tool" in str(e).lower():
  367. print("[CRITICAL] 可能是工具加载问题,请检查工具模块")
  368. raise Exception(f"Agent初始化失败: {str(e)}")
  369. return citu_langraph_agent
  370. @app.flask_app.route('/api/v0/ask_agent', methods=['POST'])
  371. def ask_agent():
  372. """
  373. 支持对话上下文的ask_agent API - 修正版
  374. """
  375. req = request.get_json(force=True)
  376. question = req.get("question", None)
  377. browser_session_id = req.get("session_id", None)
  378. # 新增参数解析
  379. user_id_input = req.get("user_id", None)
  380. conversation_id_input = req.get("conversation_id", None)
  381. continue_conversation = req.get("continue_conversation", False)
  382. if not question:
  383. return jsonify(bad_request_response(
  384. response_text="缺少必需参数:question",
  385. missing_params=["question"]
  386. )), 400
  387. try:
  388. # 1. 获取登录用户ID(修正:在函数中获取session信息)
  389. login_user_id = session.get('user_id') if 'user_id' in session else None
  390. # 2. 智能ID解析(修正:传入登录用户ID)
  391. user_id = redis_conversation_manager.resolve_user_id(
  392. user_id_input, browser_session_id, request.remote_addr, login_user_id
  393. )
  394. conversation_id, conversation_status = redis_conversation_manager.resolve_conversation_id(
  395. user_id, conversation_id_input, continue_conversation
  396. )
  397. # 3. 获取上下文(提前到缓存检查之前)
  398. context = redis_conversation_manager.get_context(conversation_id)
  399. # 4. 检查缓存(修正:传入context以实现真正的上下文感知)
  400. cached_answer = redis_conversation_manager.get_cached_answer(question, context)
  401. if cached_answer:
  402. print(f"[AGENT_API] 使用缓存答案")
  403. # 更新对话历史
  404. redis_conversation_manager.save_message(conversation_id, "user", question)
  405. redis_conversation_manager.save_message(
  406. conversation_id, "assistant",
  407. cached_answer.get("response", ""),
  408. metadata={"from_cache": True}
  409. )
  410. # 添加对话信息到缓存结果
  411. cached_answer["conversation_id"] = conversation_id
  412. cached_answer["user_id"] = user_id
  413. cached_answer["from_cache"] = True
  414. cached_answer.update(conversation_status)
  415. # 使用agent_success_response返回标准格式
  416. return jsonify(agent_success_response(
  417. response_type=cached_answer.get("type", "UNKNOWN"),
  418. response_text=cached_answer.get("response", ""),
  419. sql=cached_answer.get("sql"),
  420. query_result=cached_answer.get("query_result"),
  421. summary=cached_answer.get("summary"),
  422. session_id=browser_session_id,
  423. execution_path=cached_answer.get("execution_path", []),
  424. classification_info=cached_answer.get("classification_info", {}),
  425. conversation_id=conversation_id,
  426. user_id=user_id,
  427. is_guest_user=user_id.startswith("guest"),
  428. context_used=bool(context),
  429. from_cache=True,
  430. conversation_status=conversation_status["status"],
  431. conversation_message=conversation_status["message"],
  432. requested_conversation_id=conversation_status.get("requested_id")
  433. ))
  434. # 5. 保存用户消息
  435. redis_conversation_manager.save_message(conversation_id, "user", question)
  436. # 6. 构建带上下文的问题
  437. if context:
  438. enhanced_question = f"对话历史:\n{context}\n\n当前问题: {question}"
  439. print(f"[AGENT_API] 使用上下文,长度: {len(context)}字符")
  440. else:
  441. enhanced_question = question
  442. print(f"[AGENT_API] 新对话,无上下文")
  443. # 7. 现有Agent处理逻辑(保持不变)
  444. try:
  445. agent = get_citu_langraph_agent()
  446. except Exception as e:
  447. print(f"[CRITICAL] Agent初始化失败: {str(e)}")
  448. return jsonify(service_unavailable_response(
  449. response_text="AI服务暂时不可用,请稍后重试",
  450. can_retry=True
  451. )), 503
  452. agent_result = agent.process_question(
  453. question=enhanced_question, # 使用增强后的问题
  454. session_id=browser_session_id
  455. )
  456. # 8. 处理Agent结果
  457. if agent_result.get("success", False):
  458. # 修正:直接从agent_result获取字段,因为它就是final_response
  459. response_type = agent_result.get("type", "UNKNOWN")
  460. response_text = agent_result.get("response", "")
  461. sql = agent_result.get("sql")
  462. query_result = agent_result.get("query_result")
  463. summary = agent_result.get("summary")
  464. execution_path = agent_result.get("execution_path", [])
  465. classification_info = agent_result.get("classification_info", {})
  466. assistant_response = response_text # 使用response_text作为助手回复
  467. # 保存助手回复
  468. redis_conversation_manager.save_message(
  469. conversation_id, "assistant", assistant_response,
  470. metadata={
  471. "type": response_type,
  472. "sql": sql,
  473. "execution_path": execution_path
  474. }
  475. )
  476. # 缓存成功的答案(修正:缓存完整的agent_result)
  477. # 直接缓存agent_result,它已经包含所有需要的字段
  478. redis_conversation_manager.cache_answer(question, agent_result, context)
  479. # 使用agent_success_response的正确方式
  480. return jsonify(agent_success_response(
  481. response_type=response_type,
  482. response_text=response_text,
  483. sql=sql,
  484. query_result=query_result,
  485. summary=summary,
  486. session_id=browser_session_id,
  487. execution_path=execution_path,
  488. classification_info=classification_info,
  489. conversation_id=conversation_id,
  490. user_id=user_id,
  491. is_guest_user=user_id.startswith("guest"),
  492. context_used=bool(context),
  493. from_cache=False,
  494. conversation_status=conversation_status["status"],
  495. conversation_message=conversation_status["message"],
  496. requested_conversation_id=conversation_status.get("requested_id")
  497. ))
  498. else:
  499. # 错误处理(修正:确保使用现有的错误响应格式)
  500. error_message = agent_result.get("error", "Agent处理失败")
  501. error_code = agent_result.get("error_code", 500)
  502. return jsonify(agent_error_response(
  503. response_text=error_message,
  504. error_type="agent_processing_failed",
  505. code=error_code,
  506. session_id=browser_session_id,
  507. conversation_id=conversation_id,
  508. user_id=user_id
  509. )), error_code
  510. except Exception as e:
  511. print(f"[ERROR] ask_agent执行失败: {str(e)}")
  512. return jsonify(internal_error_response(
  513. response_text="查询处理失败,请稍后重试"
  514. )), 500
  515. @app.flask_app.route('/api/v0/agent_health', methods=['GET'])
  516. def agent_health():
  517. """
  518. Agent健康检查接口
  519. 响应格式:
  520. {
  521. "success": true/false,
  522. "code": 200/503,
  523. "message": "healthy/degraded/unhealthy",
  524. "data": {
  525. "status": "healthy/degraded/unhealthy",
  526. "test_result": true/false,
  527. "workflow_compiled": true/false,
  528. "tools_count": 4,
  529. "message": "详细信息",
  530. "timestamp": "2024-01-01T12:00:00",
  531. "checks": {
  532. "agent_creation": true/false,
  533. "tools_import": true/false,
  534. "llm_connection": true/false,
  535. "classifier_ready": true/false
  536. }
  537. }
  538. }
  539. """
  540. try:
  541. # 基础健康检查
  542. health_data = {
  543. "status": "unknown",
  544. "test_result": False,
  545. "workflow_compiled": False,
  546. "tools_count": 0,
  547. "message": "",
  548. "timestamp": datetime.now().isoformat(),
  549. "checks": {
  550. "agent_creation": False,
  551. "tools_import": False,
  552. "llm_connection": False,
  553. "classifier_ready": False
  554. }
  555. }
  556. # 检查1: Agent创建
  557. try:
  558. agent = get_citu_langraph_agent()
  559. health_data["checks"]["agent_creation"] = True
  560. health_data["workflow_compiled"] = agent.workflow is not None
  561. health_data["tools_count"] = len(agent.tools) if hasattr(agent, 'tools') else 0
  562. except Exception as e:
  563. health_data["message"] = f"Agent创建失败: {str(e)}"
  564. from common.result import health_error_response
  565. return jsonify(health_error_response(
  566. status="unhealthy",
  567. **health_data
  568. )), 503
  569. # 检查2: 工具导入
  570. try:
  571. from agent.tools import TOOLS
  572. health_data["checks"]["tools_import"] = len(TOOLS) > 0
  573. except Exception as e:
  574. health_data["message"] = f"工具导入失败: {str(e)}"
  575. # 检查3: LLM连接(简单测试)
  576. try:
  577. from agent.utils import get_compatible_llm
  578. llm = get_compatible_llm()
  579. health_data["checks"]["llm_connection"] = llm is not None
  580. except Exception as e:
  581. health_data["message"] = f"LLM连接失败: {str(e)}"
  582. # 检查4: 分类器准备
  583. try:
  584. from agent.classifier import QuestionClassifier
  585. classifier = QuestionClassifier()
  586. health_data["checks"]["classifier_ready"] = True
  587. except Exception as e:
  588. health_data["message"] = f"分类器失败: {str(e)}"
  589. # 检查5: 完整流程测试(可选)
  590. try:
  591. if all(health_data["checks"].values()):
  592. test_result = agent.health_check()
  593. health_data["test_result"] = test_result.get("status") == "healthy"
  594. health_data["status"] = test_result.get("status", "unknown")
  595. health_data["message"] = test_result.get("message", "健康检查完成")
  596. else:
  597. health_data["status"] = "degraded"
  598. health_data["message"] = "部分组件异常"
  599. except Exception as e:
  600. health_data["status"] = "degraded"
  601. health_data["message"] = f"完整测试失败: {str(e)}"
  602. # 根据状态返回相应的HTTP代码 - 使用标准化健康检查响应
  603. from common.result import health_success_response, health_error_response
  604. if health_data["status"] == "healthy":
  605. return jsonify(health_success_response(**health_data))
  606. elif health_data["status"] == "degraded":
  607. return jsonify(health_error_response(status="degraded", **health_data)), 503
  608. else:
  609. return jsonify(health_error_response(status="unhealthy", **health_data)), 503
  610. except Exception as e:
  611. print(f"[ERROR] 健康检查异常: {str(e)}")
  612. from common.result import internal_error_response
  613. return jsonify(internal_error_response(
  614. response_text="健康检查失败,请稍后重试"
  615. )), 500
  616. # ==================== 日常管理API ====================
  617. @app.flask_app.route('/api/v0/cache_overview', methods=['GET'])
  618. def cache_overview():
  619. """日常管理:轻量概览 - 合并原cache_inspect的核心功能"""
  620. try:
  621. cache = app.cache
  622. result_data = {
  623. 'overview_summary': {
  624. 'total_conversations': 0,
  625. 'total_sessions': 0,
  626. 'query_time': datetime.now().isoformat()
  627. },
  628. 'recent_conversations': [], # 最近的对话
  629. 'session_summary': [] # 会话摘要
  630. }
  631. if hasattr(cache, 'cache') and isinstance(cache.cache, dict):
  632. result_data['overview_summary']['total_conversations'] = len(cache.cache)
  633. # 获取会话信息
  634. if hasattr(cache, 'get_all_sessions'):
  635. all_sessions = cache.get_all_sessions()
  636. result_data['overview_summary']['total_sessions'] = len(all_sessions)
  637. # 会话摘要(按最近活动排序)
  638. session_list = []
  639. for session_id, session_data in all_sessions.items():
  640. session_summary = {
  641. 'session_id': session_id,
  642. 'start_time': session_data['start_time'].isoformat(),
  643. 'conversation_count': session_data.get('conversation_count', 0),
  644. 'duration_seconds': session_data.get('session_duration_seconds', 0),
  645. 'last_activity': session_data.get('last_activity', session_data['start_time']).isoformat(),
  646. 'is_active': (datetime.now() - session_data.get('last_activity', session_data['start_time'])).total_seconds() < 1800 # 30分钟内活跃
  647. }
  648. session_list.append(session_summary)
  649. # 按最后活动时间排序
  650. session_list.sort(key=lambda x: x['last_activity'], reverse=True)
  651. result_data['session_summary'] = session_list
  652. # 最近的对话(最多显示10个)
  653. conversation_list = []
  654. for conversation_id, conversation_data in cache.cache.items():
  655. conversation_start_time = cache.conversation_start_times.get(conversation_id)
  656. conversation_info = {
  657. 'conversation_id': conversation_id,
  658. 'conversation_start_time': conversation_start_time.isoformat() if conversation_start_time else None,
  659. 'session_id': cache.conversation_to_session.get(conversation_id),
  660. 'has_question': 'question' in conversation_data,
  661. 'has_sql': 'sql' in conversation_data,
  662. 'has_data': 'df' in conversation_data and conversation_data['df'] is not None,
  663. 'question_preview': conversation_data.get('question', '')[:80] + '...' if len(conversation_data.get('question', '')) > 80 else conversation_data.get('question', ''),
  664. }
  665. # 计算对话持续时间
  666. if conversation_start_time:
  667. duration = datetime.now() - conversation_start_time
  668. conversation_info['conversation_duration_seconds'] = duration.total_seconds()
  669. conversation_list.append(conversation_info)
  670. # 按对话开始时间排序,显示最新的10个
  671. conversation_list.sort(key=lambda x: x['conversation_start_time'] or '', reverse=True)
  672. result_data['recent_conversations'] = conversation_list[:10]
  673. from common.result import success_response
  674. return jsonify(success_response(
  675. response_text="缓存概览查询完成",
  676. data=result_data
  677. ))
  678. except Exception as e:
  679. from common.result import internal_error_response
  680. return jsonify(internal_error_response(
  681. response_text="获取缓存概览失败,请稍后重试"
  682. )), 500
  683. @app.flask_app.route('/api/v0/cache_stats', methods=['GET'])
  684. def cache_stats():
  685. """日常管理:统计信息 - 合并原session_stats和cache_stats功能"""
  686. try:
  687. cache = app.cache
  688. current_time = datetime.now()
  689. stats = {
  690. 'basic_stats': {
  691. 'total_sessions': len(getattr(cache, 'session_info', {})),
  692. 'total_conversations': len(getattr(cache, 'cache', {})),
  693. 'active_sessions': 0, # 最近30分钟有活动
  694. 'average_conversations_per_session': 0
  695. },
  696. 'time_distribution': {
  697. 'sessions': {
  698. 'last_1_hour': 0,
  699. 'last_6_hours': 0,
  700. 'last_24_hours': 0,
  701. 'last_7_days': 0,
  702. 'older': 0
  703. },
  704. 'conversations': {
  705. 'last_1_hour': 0,
  706. 'last_6_hours': 0,
  707. 'last_24_hours': 0,
  708. 'last_7_days': 0,
  709. 'older': 0
  710. }
  711. },
  712. 'session_details': [],
  713. 'time_ranges': {
  714. 'oldest_session': None,
  715. 'newest_session': None,
  716. 'oldest_conversation': None,
  717. 'newest_conversation': None
  718. }
  719. }
  720. # 会话统计
  721. if hasattr(cache, 'session_info'):
  722. session_times = []
  723. total_conversations = 0
  724. for session_id, session_data in cache.session_info.items():
  725. start_time = session_data['start_time']
  726. session_times.append(start_time)
  727. conversation_count = len(session_data.get('conversations', []))
  728. total_conversations += conversation_count
  729. # 检查活跃状态
  730. last_activity = session_data.get('last_activity', session_data['start_time'])
  731. if (current_time - last_activity).total_seconds() < 1800:
  732. stats['basic_stats']['active_sessions'] += 1
  733. # 时间分布统计
  734. age_hours = (current_time - start_time).total_seconds() / 3600
  735. if age_hours <= 1:
  736. stats['time_distribution']['sessions']['last_1_hour'] += 1
  737. elif age_hours <= 6:
  738. stats['time_distribution']['sessions']['last_6_hours'] += 1
  739. elif age_hours <= 24:
  740. stats['time_distribution']['sessions']['last_24_hours'] += 1
  741. elif age_hours <= 168: # 7 days
  742. stats['time_distribution']['sessions']['last_7_days'] += 1
  743. else:
  744. stats['time_distribution']['sessions']['older'] += 1
  745. # 会话详细信息
  746. session_duration = current_time - start_time
  747. stats['session_details'].append({
  748. 'session_id': session_id,
  749. 'start_time': start_time.isoformat(),
  750. 'last_activity': last_activity.isoformat(),
  751. 'conversation_count': conversation_count,
  752. 'duration_seconds': session_duration.total_seconds(),
  753. 'duration_formatted': str(session_duration),
  754. 'is_active': (current_time - last_activity).total_seconds() < 1800,
  755. 'browser_session_id': session_data.get('browser_session_id')
  756. })
  757. # 计算平均值
  758. if len(cache.session_info) > 0:
  759. stats['basic_stats']['average_conversations_per_session'] = total_conversations / len(cache.session_info)
  760. # 时间范围
  761. if session_times:
  762. stats['time_ranges']['oldest_session'] = min(session_times).isoformat()
  763. stats['time_ranges']['newest_session'] = max(session_times).isoformat()
  764. # 对话统计
  765. if hasattr(cache, 'conversation_start_times'):
  766. conversation_times = []
  767. for conv_time in cache.conversation_start_times.values():
  768. conversation_times.append(conv_time)
  769. age_hours = (current_time - conv_time).total_seconds() / 3600
  770. if age_hours <= 1:
  771. stats['time_distribution']['conversations']['last_1_hour'] += 1
  772. elif age_hours <= 6:
  773. stats['time_distribution']['conversations']['last_6_hours'] += 1
  774. elif age_hours <= 24:
  775. stats['time_distribution']['conversations']['last_24_hours'] += 1
  776. elif age_hours <= 168:
  777. stats['time_distribution']['conversations']['last_7_days'] += 1
  778. else:
  779. stats['time_distribution']['conversations']['older'] += 1
  780. if conversation_times:
  781. stats['time_ranges']['oldest_conversation'] = min(conversation_times).isoformat()
  782. stats['time_ranges']['newest_conversation'] = max(conversation_times).isoformat()
  783. # 按最近活动排序会话详情
  784. stats['session_details'].sort(key=lambda x: x['last_activity'], reverse=True)
  785. from common.result import success_response
  786. return jsonify(success_response(
  787. response_text="缓存统计信息查询完成",
  788. data=stats
  789. ))
  790. except Exception as e:
  791. from common.result import internal_error_response
  792. return jsonify(internal_error_response(
  793. response_text="获取缓存统计失败,请稍后重试"
  794. )), 500
  795. # ==================== 高级功能API ====================
  796. @app.flask_app.route('/api/v0/cache_export', methods=['GET'])
  797. def cache_export():
  798. """高级功能:完整导出 - 保持原cache_raw_export的完整功能"""
  799. try:
  800. cache = app.cache
  801. # 验证缓存的实际结构
  802. if not hasattr(cache, 'cache'):
  803. from common.result import internal_error_response
  804. return jsonify(internal_error_response(
  805. response_text="缓存对象结构异常,请联系系统管理员"
  806. )), 500
  807. if not isinstance(cache.cache, dict):
  808. from common.result import internal_error_response
  809. return jsonify(internal_error_response(
  810. response_text="缓存数据类型异常,请联系系统管理员"
  811. )), 500
  812. # 定义JSON序列化辅助函数
  813. def make_json_serializable(obj):
  814. """将对象转换为JSON可序列化的格式"""
  815. if obj is None:
  816. return None
  817. elif isinstance(obj, (str, int, float, bool)):
  818. return obj
  819. elif isinstance(obj, (list, tuple)):
  820. return [make_json_serializable(item) for item in obj]
  821. elif isinstance(obj, dict):
  822. return {str(k): make_json_serializable(v) for k, v in obj.items()}
  823. elif hasattr(obj, 'isoformat'): # datetime objects
  824. return obj.isoformat()
  825. elif hasattr(obj, 'item'): # numpy scalars
  826. return obj.item()
  827. elif hasattr(obj, 'tolist'): # numpy arrays
  828. return obj.tolist()
  829. elif hasattr(obj, '__dict__'): # pandas dtypes and other objects
  830. return str(obj)
  831. else:
  832. return str(obj)
  833. # 获取完整的原始缓存数据
  834. raw_cache = cache.cache
  835. # 获取会话和对话时间信息
  836. conversation_times = getattr(cache, 'conversation_start_times', {})
  837. session_info = getattr(cache, 'session_info', {})
  838. conversation_to_session = getattr(cache, 'conversation_to_session', {})
  839. export_data = {
  840. 'export_metadata': {
  841. 'export_time': datetime.now().isoformat(),
  842. 'total_conversations': len(raw_cache),
  843. 'total_sessions': len(session_info),
  844. 'cache_type': type(cache).__name__,
  845. 'cache_object_info': str(cache),
  846. 'has_session_times': bool(session_info),
  847. 'has_conversation_times': bool(conversation_times)
  848. },
  849. 'session_info': {
  850. session_id: {
  851. 'start_time': session_data['start_time'].isoformat(),
  852. 'last_activity': session_data.get('last_activity', session_data['start_time']).isoformat(),
  853. 'conversations': session_data['conversations'],
  854. 'conversation_count': len(session_data['conversations']),
  855. 'browser_session_id': session_data.get('browser_session_id'),
  856. 'user_info': session_data.get('user_info', {})
  857. }
  858. for session_id, session_data in session_info.items()
  859. },
  860. 'conversation_times': {
  861. conversation_id: start_time.isoformat()
  862. for conversation_id, start_time in conversation_times.items()
  863. },
  864. 'conversation_to_session_mapping': conversation_to_session,
  865. 'conversations': {}
  866. }
  867. # 处理每个对话的完整数据
  868. for conversation_id, conversation_data in raw_cache.items():
  869. # 获取时间信息
  870. conversation_start_time = conversation_times.get(conversation_id)
  871. session_id = conversation_to_session.get(conversation_id)
  872. session_start_time = None
  873. if session_id and session_id in session_info:
  874. session_start_time = session_info[session_id]['start_time']
  875. processed_conversation = {
  876. 'conversation_id': conversation_id,
  877. 'conversation_start_time': conversation_start_time.isoformat() if conversation_start_time else None,
  878. 'session_id': session_id,
  879. 'session_start_time': session_start_time.isoformat() if session_start_time else None,
  880. 'field_count': len(conversation_data),
  881. 'fields': {}
  882. }
  883. # 添加时间计算
  884. if conversation_start_time:
  885. conversation_duration = datetime.now() - conversation_start_time
  886. processed_conversation['conversation_duration_seconds'] = conversation_duration.total_seconds()
  887. processed_conversation['conversation_duration_formatted'] = str(conversation_duration)
  888. if session_start_time:
  889. session_duration = datetime.now() - session_start_time
  890. processed_conversation['session_duration_seconds'] = session_duration.total_seconds()
  891. processed_conversation['session_duration_formatted'] = str(session_duration)
  892. # 处理每个字段,确保JSON序列化安全
  893. for field_name, field_value in conversation_data.items():
  894. field_info = {
  895. 'field_name': field_name,
  896. 'data_type': type(field_value).__name__,
  897. 'is_none': field_value is None
  898. }
  899. try:
  900. if field_value is None:
  901. field_info['value'] = None
  902. elif field_name in ['conversation_start_time', 'session_start_time']:
  903. # 处理时间字段
  904. field_info['content'] = make_json_serializable(field_value)
  905. elif field_name == 'df' and field_value is not None:
  906. # DataFrame的安全处理
  907. if hasattr(field_value, 'to_dict'):
  908. # 安全地处理dtypes
  909. try:
  910. dtypes_dict = {}
  911. for col, dtype in field_value.dtypes.items():
  912. dtypes_dict[col] = str(dtype)
  913. except Exception:
  914. dtypes_dict = {"error": "无法序列化dtypes"}
  915. # 安全地处理内存使用
  916. try:
  917. memory_usage = field_value.memory_usage(deep=True)
  918. memory_dict = {}
  919. for idx, usage in memory_usage.items():
  920. memory_dict[str(idx)] = int(usage) if hasattr(usage, 'item') else int(usage)
  921. except Exception:
  922. memory_dict = {"error": "无法获取内存使用信息"}
  923. field_info.update({
  924. 'dataframe_info': {
  925. 'shape': list(field_value.shape),
  926. 'columns': list(field_value.columns),
  927. 'dtypes': dtypes_dict,
  928. 'index_info': {
  929. 'type': type(field_value.index).__name__,
  930. 'length': len(field_value.index)
  931. }
  932. },
  933. 'data': make_json_serializable(field_value.to_dict('records')),
  934. 'memory_usage': memory_dict
  935. })
  936. else:
  937. field_info['value'] = str(field_value)
  938. field_info['note'] = 'not_standard_dataframe'
  939. elif field_name == 'fig_json':
  940. # 图表JSON数据处理
  941. if isinstance(field_value, str):
  942. try:
  943. import json
  944. parsed_fig = json.loads(field_value)
  945. field_info.update({
  946. 'json_valid': True,
  947. 'json_size_bytes': len(field_value),
  948. 'plotly_structure': {
  949. 'has_data': 'data' in parsed_fig,
  950. 'has_layout': 'layout' in parsed_fig,
  951. 'data_traces_count': len(parsed_fig.get('data', [])),
  952. },
  953. 'raw_json': field_value
  954. })
  955. except json.JSONDecodeError:
  956. field_info.update({
  957. 'json_valid': False,
  958. 'raw_content': str(field_value)
  959. })
  960. else:
  961. field_info['value'] = make_json_serializable(field_value)
  962. elif field_name == 'followup_questions':
  963. # 后续问题列表
  964. field_info.update({
  965. 'content': make_json_serializable(field_value)
  966. })
  967. elif field_name in ['question', 'sql', 'summary']:
  968. # 文本字段
  969. if isinstance(field_value, str):
  970. field_info.update({
  971. 'text_length': len(field_value),
  972. 'content': field_value
  973. })
  974. else:
  975. field_info['value'] = make_json_serializable(field_value)
  976. else:
  977. # 未知字段的安全处理
  978. field_info['content'] = make_json_serializable(field_value)
  979. except Exception as e:
  980. field_info.update({
  981. 'processing_error': str(e),
  982. 'fallback_value': str(field_value)[:500] + '...' if len(str(field_value)) > 500 else str(field_value)
  983. })
  984. processed_conversation['fields'][field_name] = field_info
  985. export_data['conversations'][conversation_id] = processed_conversation
  986. # 添加缓存统计信息
  987. field_frequency = {}
  988. data_types_found = set()
  989. total_dataframes = 0
  990. total_questions = 0
  991. for conv_data in export_data['conversations'].values():
  992. for field_name, field_info in conv_data['fields'].items():
  993. field_frequency[field_name] = field_frequency.get(field_name, 0) + 1
  994. data_types_found.add(field_info['data_type'])
  995. if field_name == 'df' and not field_info['is_none']:
  996. total_dataframes += 1
  997. if field_name == 'question' and not field_info['is_none']:
  998. total_questions += 1
  999. export_data['cache_statistics'] = {
  1000. 'field_frequency': field_frequency,
  1001. 'data_types_found': list(data_types_found),
  1002. 'total_dataframes': total_dataframes,
  1003. 'total_questions': total_questions,
  1004. 'has_session_timing': 'session_start_time' in field_frequency,
  1005. 'has_conversation_timing': 'conversation_start_time' in field_frequency
  1006. }
  1007. from common.result import success_response
  1008. return jsonify(success_response(
  1009. response_text="缓存数据导出完成",
  1010. data=export_data
  1011. ))
  1012. except Exception as e:
  1013. import traceback
  1014. error_details = {
  1015. 'error_message': str(e),
  1016. 'error_type': type(e).__name__,
  1017. 'traceback': traceback.format_exc()
  1018. }
  1019. from common.result import internal_error_response
  1020. return jsonify(internal_error_response(
  1021. response_text="导出缓存失败,请稍后重试"
  1022. )), 500
  1023. # ==================== 清理功能API ====================
  1024. @app.flask_app.route('/api/v0/cache_preview_cleanup', methods=['POST'])
  1025. def cache_preview_cleanup():
  1026. """清理功能:预览删除操作 - 保持原功能"""
  1027. try:
  1028. req = request.get_json(force=True)
  1029. # 时间条件 - 支持三种方式
  1030. older_than_hours = req.get('older_than_hours')
  1031. older_than_days = req.get('older_than_days')
  1032. before_timestamp = req.get('before_timestamp') # YYYY-MM-DD HH:MM:SS 格式
  1033. cache = app.cache
  1034. # 计算截止时间
  1035. cutoff_time = None
  1036. time_condition = None
  1037. if older_than_hours:
  1038. cutoff_time = datetime.now() - timedelta(hours=older_than_hours)
  1039. time_condition = f"older_than_hours: {older_than_hours}"
  1040. elif older_than_days:
  1041. cutoff_time = datetime.now() - timedelta(days=older_than_days)
  1042. time_condition = f"older_than_days: {older_than_days}"
  1043. elif before_timestamp:
  1044. try:
  1045. # 支持 YYYY-MM-DD HH:MM:SS 格式
  1046. cutoff_time = datetime.strptime(before_timestamp, '%Y-%m-%d %H:%M:%S')
  1047. time_condition = f"before_timestamp: {before_timestamp}"
  1048. except ValueError:
  1049. from common.result import validation_failed_response
  1050. return jsonify(validation_failed_response(
  1051. response_text="before_timestamp格式错误,请使用 YYYY-MM-DD HH:MM:SS 格式"
  1052. )), 422
  1053. else:
  1054. from common.result import bad_request_response
  1055. return jsonify(bad_request_response(
  1056. response_text="必须提供时间条件:older_than_hours, older_than_days 或 before_timestamp (YYYY-MM-DD HH:MM:SS)",
  1057. missing_params=["older_than_hours", "older_than_days", "before_timestamp"]
  1058. )), 400
  1059. preview = {
  1060. 'time_condition': time_condition,
  1061. 'cutoff_time': cutoff_time.isoformat(),
  1062. 'will_be_removed': {
  1063. 'sessions': []
  1064. },
  1065. 'will_be_kept': {
  1066. 'sessions_count': 0,
  1067. 'conversations_count': 0
  1068. },
  1069. 'summary': {
  1070. 'sessions_to_remove': 0,
  1071. 'conversations_to_remove': 0,
  1072. 'sessions_to_keep': 0,
  1073. 'conversations_to_keep': 0
  1074. }
  1075. }
  1076. # 预览按session删除
  1077. sessions_to_remove_count = 0
  1078. conversations_to_remove_count = 0
  1079. for session_id, session_data in cache.session_info.items():
  1080. session_preview = {
  1081. 'session_id': session_id,
  1082. 'start_time': session_data['start_time'].isoformat(),
  1083. 'conversation_count': len(session_data['conversations']),
  1084. 'conversations': []
  1085. }
  1086. # 添加conversation详情
  1087. for conv_id in session_data['conversations']:
  1088. if conv_id in cache.cache:
  1089. conv_data = cache.cache[conv_id]
  1090. session_preview['conversations'].append({
  1091. 'conversation_id': conv_id,
  1092. 'question': conv_data.get('question', '')[:50] + '...' if conv_data.get('question') else '',
  1093. 'start_time': cache.conversation_start_times.get(conv_id, '').isoformat() if cache.conversation_start_times.get(conv_id) else ''
  1094. })
  1095. if session_data['start_time'] < cutoff_time:
  1096. preview['will_be_removed']['sessions'].append(session_preview)
  1097. sessions_to_remove_count += 1
  1098. conversations_to_remove_count += len(session_data['conversations'])
  1099. else:
  1100. preview['will_be_kept']['sessions_count'] += 1
  1101. preview['will_be_kept']['conversations_count'] += len(session_data['conversations'])
  1102. # 更新摘要统计
  1103. preview['summary'] = {
  1104. 'sessions_to_remove': sessions_to_remove_count,
  1105. 'conversations_to_remove': conversations_to_remove_count,
  1106. 'sessions_to_keep': preview['will_be_kept']['sessions_count'],
  1107. 'conversations_to_keep': preview['will_be_kept']['conversations_count']
  1108. }
  1109. from common.result import success_response
  1110. return jsonify(success_response(
  1111. response_text=f"清理预览完成,将删除 {sessions_to_remove_count} 个会话和 {conversations_to_remove_count} 个对话",
  1112. data=preview
  1113. ))
  1114. except Exception as e:
  1115. from common.result import internal_error_response
  1116. return jsonify(internal_error_response(
  1117. response_text="预览清理操作失败,请稍后重试"
  1118. )), 500
  1119. @app.flask_app.route('/api/v0/cache_cleanup', methods=['POST'])
  1120. def cache_cleanup():
  1121. """清理功能:实际删除缓存 - 保持原功能"""
  1122. try:
  1123. req = request.get_json(force=True)
  1124. # 时间条件 - 支持三种方式
  1125. older_than_hours = req.get('older_than_hours')
  1126. older_than_days = req.get('older_than_days')
  1127. before_timestamp = req.get('before_timestamp') # YYYY-MM-DD HH:MM:SS 格式
  1128. cache = app.cache
  1129. if not hasattr(cache, 'session_info'):
  1130. from common.result import service_unavailable_response
  1131. return jsonify(service_unavailable_response(
  1132. response_text="缓存不支持会话功能"
  1133. )), 503
  1134. # 计算截止时间
  1135. cutoff_time = None
  1136. time_condition = None
  1137. if older_than_hours:
  1138. cutoff_time = datetime.now() - timedelta(hours=older_than_hours)
  1139. time_condition = f"older_than_hours: {older_than_hours}"
  1140. elif older_than_days:
  1141. cutoff_time = datetime.now() - timedelta(days=older_than_days)
  1142. time_condition = f"older_than_days: {older_than_days}"
  1143. elif before_timestamp:
  1144. try:
  1145. # 支持 YYYY-MM-DD HH:MM:SS 格式
  1146. cutoff_time = datetime.strptime(before_timestamp, '%Y-%m-%d %H:%M:%S')
  1147. time_condition = f"before_timestamp: {before_timestamp}"
  1148. except ValueError:
  1149. from common.result import validation_failed_response
  1150. return jsonify(validation_failed_response(
  1151. response_text="before_timestamp格式错误,请使用 YYYY-MM-DD HH:MM:SS 格式"
  1152. )), 422
  1153. else:
  1154. from common.result import bad_request_response
  1155. return jsonify(bad_request_response(
  1156. response_text="必须提供时间条件:older_than_hours, older_than_days 或 before_timestamp (YYYY-MM-DD HH:MM:SS)",
  1157. missing_params=["older_than_hours", "older_than_days", "before_timestamp"]
  1158. )), 400
  1159. cleanup_stats = {
  1160. 'time_condition': time_condition,
  1161. 'cutoff_time': cutoff_time.isoformat(),
  1162. 'sessions_removed': 0,
  1163. 'conversations_removed': 0,
  1164. 'sessions_kept': 0,
  1165. 'conversations_kept': 0,
  1166. 'removed_session_ids': [],
  1167. 'removed_conversation_ids': []
  1168. }
  1169. # 按session删除
  1170. sessions_to_remove = []
  1171. for session_id, session_data in cache.session_info.items():
  1172. if session_data['start_time'] < cutoff_time:
  1173. sessions_to_remove.append(session_id)
  1174. # 删除符合条件的sessions及其所有conversations
  1175. for session_id in sessions_to_remove:
  1176. session_data = cache.session_info[session_id]
  1177. conversations_in_session = session_data['conversations'].copy()
  1178. # 删除session中的所有conversations
  1179. for conv_id in conversations_in_session:
  1180. if conv_id in cache.cache:
  1181. del cache.cache[conv_id]
  1182. cleanup_stats['conversations_removed'] += 1
  1183. cleanup_stats['removed_conversation_ids'].append(conv_id)
  1184. # 清理conversation相关的时间记录
  1185. if hasattr(cache, 'conversation_start_times') and conv_id in cache.conversation_start_times:
  1186. del cache.conversation_start_times[conv_id]
  1187. if hasattr(cache, 'conversation_to_session') and conv_id in cache.conversation_to_session:
  1188. del cache.conversation_to_session[conv_id]
  1189. # 删除session记录
  1190. del cache.session_info[session_id]
  1191. cleanup_stats['sessions_removed'] += 1
  1192. cleanup_stats['removed_session_ids'].append(session_id)
  1193. # 统计保留的sessions和conversations
  1194. cleanup_stats['sessions_kept'] = len(cache.session_info)
  1195. cleanup_stats['conversations_kept'] = len(cache.cache)
  1196. from common.result import success_response
  1197. return jsonify(success_response(
  1198. response_text=f"缓存清理完成,删除了 {cleanup_stats['sessions_removed']} 个会话和 {cleanup_stats['conversations_removed']} 个对话",
  1199. data=cleanup_stats
  1200. ))
  1201. except Exception as e:
  1202. from common.result import internal_error_response
  1203. return jsonify(internal_error_response(
  1204. response_text="缓存清理失败,请稍后重试"
  1205. )), 500
  1206. @app.flask_app.route('/api/v0/training_error_question_sql', methods=['POST'])
  1207. def training_error_question_sql():
  1208. """
  1209. 存储错误的question-sql对到error_sql集合中
  1210. 此API将接收的错误question/sql pair写入到error_sql集合中,用于记录和分析错误的SQL查询。
  1211. Args:
  1212. question (str, required): 用户问题
  1213. sql (str, required): 对应的错误SQL查询语句
  1214. Returns:
  1215. JSON: 包含训练ID和成功消息的响应
  1216. """
  1217. try:
  1218. data = request.get_json()
  1219. question = data.get('question')
  1220. sql = data.get('sql')
  1221. print(f"[DEBUG] 接收到错误SQL训练请求: question={question}, sql={sql}")
  1222. if not question or not sql:
  1223. from common.result import bad_request_response
  1224. missing_params = []
  1225. if not question:
  1226. missing_params.append("question")
  1227. if not sql:
  1228. missing_params.append("sql")
  1229. return jsonify(bad_request_response(
  1230. response_text="question和sql参数都是必需的",
  1231. missing_params=missing_params
  1232. )), 400
  1233. # 使用vn实例的train_error_sql方法存储错误SQL
  1234. id = vn.train_error_sql(question=question, sql=sql)
  1235. print(f"[INFO] 成功存储错误SQL,ID: {id}")
  1236. from common.result import success_response
  1237. return jsonify(success_response(
  1238. response_text="错误SQL对已成功存储",
  1239. data={
  1240. "id": id,
  1241. "message": "错误SQL对已成功存储到error_sql集合"
  1242. }
  1243. ))
  1244. except Exception as e:
  1245. print(f"[ERROR] 存储错误SQL失败: {str(e)}")
  1246. from common.result import internal_error_response
  1247. return jsonify(internal_error_response(
  1248. response_text="存储错误SQL失败,请稍后重试"
  1249. )), 500
  1250. # ==================== Redis对话管理API ====================
  1251. @app.flask_app.route('/api/v0/user/<user_id>/conversations', methods=['GET'])
  1252. def get_user_conversations(user_id: str):
  1253. """获取用户的对话列表(按时间倒序)"""
  1254. try:
  1255. limit = request.args.get('limit', USER_MAX_CONVERSATIONS, type=int)
  1256. conversations = redis_conversation_manager.get_conversations(user_id, limit)
  1257. return jsonify(success_response(
  1258. response_text="获取用户对话列表成功",
  1259. data={
  1260. "user_id": user_id,
  1261. "conversations": conversations,
  1262. "total_count": len(conversations)
  1263. }
  1264. ))
  1265. except Exception as e:
  1266. return jsonify(internal_error_response(
  1267. response_text="获取对话列表失败,请稍后重试"
  1268. )), 500
  1269. @app.flask_app.route('/api/v0/conversation/<conversation_id>/messages', methods=['GET'])
  1270. def get_conversation_messages(conversation_id: str):
  1271. """获取特定对话的消息历史"""
  1272. try:
  1273. limit = request.args.get('limit', type=int) # 可选参数
  1274. messages = redis_conversation_manager.get_conversation_messages(conversation_id, limit)
  1275. meta = redis_conversation_manager.get_conversation_meta(conversation_id)
  1276. return jsonify(success_response(
  1277. response_text="获取对话消息成功",
  1278. data={
  1279. "conversation_id": conversation_id,
  1280. "conversation_meta": meta,
  1281. "messages": messages,
  1282. "message_count": len(messages)
  1283. }
  1284. ))
  1285. except Exception as e:
  1286. return jsonify(internal_error_response(
  1287. response_text="获取对话消息失败"
  1288. )), 500
  1289. @app.flask_app.route('/api/v0/conversation/<conversation_id>/context', methods=['GET'])
  1290. def get_conversation_context(conversation_id: str):
  1291. """获取对话上下文(格式化用于LLM)"""
  1292. try:
  1293. count = request.args.get('count', CONVERSATION_CONTEXT_COUNT, type=int)
  1294. context = redis_conversation_manager.get_context(conversation_id, count)
  1295. return jsonify(success_response(
  1296. response_text="获取对话上下文成功",
  1297. data={
  1298. "conversation_id": conversation_id,
  1299. "context": context,
  1300. "context_message_count": count
  1301. }
  1302. ))
  1303. except Exception as e:
  1304. return jsonify(internal_error_response(
  1305. response_text="获取对话上下文失败"
  1306. )), 500
  1307. @app.flask_app.route('/api/v0/conversation_stats', methods=['GET'])
  1308. def conversation_stats():
  1309. """获取对话系统统计信息"""
  1310. try:
  1311. stats = redis_conversation_manager.get_stats()
  1312. return jsonify(success_response(
  1313. response_text="获取统计信息成功",
  1314. data=stats
  1315. ))
  1316. except Exception as e:
  1317. return jsonify(internal_error_response(
  1318. response_text="获取统计信息失败,请稍后重试"
  1319. )), 500
  1320. @app.flask_app.route('/api/v0/conversation_cleanup', methods=['POST'])
  1321. def conversation_cleanup():
  1322. """手动清理过期对话"""
  1323. try:
  1324. redis_conversation_manager.cleanup_expired_conversations()
  1325. return jsonify(success_response(
  1326. response_text="对话清理完成"
  1327. ))
  1328. except Exception as e:
  1329. return jsonify(internal_error_response(
  1330. response_text="对话清理失败,请稍后重试"
  1331. )), 500
# Frontend JavaScript example: how a client keeps a session alive across requests.
# The triple-quoted string below is a bare expression that is evaluated and discarded;
# it exists only as reference documentation for frontend developers.
"""
// 前端需要维护一个会话ID
class ChatSession {
constructor() {
// 从localStorage获取或创建新的会话ID
this.sessionId = localStorage.getItem('chat_session_id') || this.generateSessionId();
localStorage.setItem('chat_session_id', this.sessionId);
}
generateSessionId() {
return 'session_' + Date.now() + '_' + Math.random().toString(36).substr(2, 9);
}
async askQuestion(question) {
const response = await fetch('/api/v0/ask', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
question: question,
session_id: this.sessionId // 关键:传递会话ID
})
});
return await response.json();
}
// 开始新会话
startNewSession() {
this.sessionId = this.generateSessionId();
localStorage.setItem('chat_session_id', this.sessionId);
}
}
// 使用示例
const chatSession = new ChatSession();
chatSession.askQuestion("各年龄段客户的流失率如何?");
"""
# Start the server on port 8084, listening on all interfaces.
# NOTE(review): there is no `if __name__ == "__main__":` guard here, so importing this
# module also starts the server. Also, debug=True is combined with host="0.0.0.0": if
# app.run forwards these to Flask's run(), the Werkzeug interactive debugger would be
# reachable from the network. Confirm this is meant for local development only.
print("正在启动Flask应用: http://localhost:8084")
app.run(host="0.0.0.0", port=8084, debug=True)