# citu_app.py (72 KB)
  1. # 给dataops 对话助手返回结果
  2. from vanna.flask import VannaFlaskApp
  3. from core.vanna_llm_factory import create_vanna_instance
  4. from flask import request, jsonify
  5. import pandas as pd
  6. import common.result as result
  7. from datetime import datetime, timedelta
  8. from common.session_aware_cache import WebSessionAwareMemoryCache
  9. from app_config import API_MAX_RETURN_ROWS, ENABLE_RESULT_SUMMARY
  10. import re
  11. import chainlit as cl
  12. import json
  13. from flask import session # 添加session导入
  14. from common.redis_conversation_manager import RedisConversationManager # 添加Redis对话管理器导入
  15. from common.result import ( # 统一导入所有需要的响应函数
  16. bad_request_response, service_unavailable_response,
  17. agent_success_response, agent_error_response,
  18. internal_error_response, success_response,
  19. validation_failed_response
  20. )
  21. from app_config import ( # 添加Redis相关配置导入
  22. USER_MAX_CONVERSATIONS,
  23. CONVERSATION_CONTEXT_COUNT,
  24. DEFAULT_ANONYMOUS_USER
  25. )
  26. # 设置默认的最大返回行数
  27. DEFAULT_MAX_RETURN_ROWS = 200
  28. MAX_RETURN_ROWS = API_MAX_RETURN_ROWS if API_MAX_RETURN_ROWS is not None else DEFAULT_MAX_RETURN_ROWS
  29. vn = create_vanna_instance()
  30. # 创建带时间戳的缓存
  31. timestamped_cache = WebSessionAwareMemoryCache()
  32. # 实例化 VannaFlaskApp,使用自定义缓存
  33. app = VannaFlaskApp(
  34. vn,
  35. cache=timestamped_cache, # 使用带时间戳的缓存
  36. title="辞图智能数据问答平台",
  37. logo = "https://www.citupro.com/img/logo-black-2.png",
  38. subtitle="让 AI 为你写 SQL",
  39. chart=False,
  40. allow_llm_to_see_data=True,
  41. ask_results_correct=True,
  42. followup_questions=True,
  43. debug=True
  44. )
  45. # 创建Redis对话管理器实例
  46. redis_conversation_manager = RedisConversationManager()
  47. # 修改ask接口,支持前端传递session_id
  48. @app.flask_app.route('/api/v0/ask', methods=['POST'])
  49. def ask_full():
  50. req = request.get_json(force=True)
  51. question = req.get("question", None)
  52. browser_session_id = req.get("session_id", None) # 前端传递的会话ID
  53. if not question:
  54. from common.result import bad_request_response
  55. return jsonify(bad_request_response(
  56. response_text="缺少必需参数:question",
  57. missing_params=["question"]
  58. )), 400
  59. # 如果使用WebSessionAwareMemoryCache
  60. if hasattr(app.cache, 'generate_id_with_browser_session') and browser_session_id:
  61. # 这里需要修改vanna的ask方法来支持传递session_id
  62. # 或者预先调用generate_id来建立会话关联
  63. conversation_id = app.cache.generate_id_with_browser_session(
  64. question=question,
  65. browser_session_id=browser_session_id
  66. )
  67. try:
  68. sql, df, _ = vn.ask(
  69. question=question,
  70. print_results=False,
  71. visualize=False,
  72. allow_llm_to_see_data=True
  73. )
  74. # 关键:检查是否有LLM解释性文本(无法生成SQL的情况)
  75. if sql is None and hasattr(vn, 'last_llm_explanation') and vn.last_llm_explanation:
  76. # 在解释性文本末尾添加提示语
  77. explanation_message = vn.last_llm_explanation + "请尝试提问其它问题。"
  78. # 使用标准化错误响应
  79. from common.result import validation_failed_response
  80. return jsonify(validation_failed_response(
  81. response_text=explanation_message
  82. )), 422 # 修改HTTP状态码为422
  83. # 如果sql为None但没有解释性文本,返回通用错误
  84. if sql is None:
  85. from common.result import validation_failed_response
  86. return jsonify(validation_failed_response(
  87. response_text="无法生成SQL查询,请检查问题描述或数据表结构"
  88. )), 422
  89. # 处理返回数据 - 使用新的query_result结构
  90. query_result = {
  91. "rows": [],
  92. "columns": [],
  93. "row_count": 0,
  94. "is_limited": False,
  95. "total_row_count": 0
  96. }
  97. summary = None
  98. if isinstance(df, pd.DataFrame):
  99. query_result["columns"] = list(df.columns)
  100. if not df.empty:
  101. total_rows = len(df)
  102. limited_df = df.head(MAX_RETURN_ROWS)
  103. query_result["rows"] = limited_df.to_dict(orient="records")
  104. query_result["row_count"] = len(limited_df)
  105. query_result["total_row_count"] = total_rows
  106. query_result["is_limited"] = total_rows > MAX_RETURN_ROWS
  107. # 生成数据摘要(可通过配置控制,仅在有数据时生成)
  108. if ENABLE_RESULT_SUMMARY:
  109. try:
  110. summary = vn.generate_summary(question=question, df=df)
  111. print(f"[INFO] 成功生成摘要: {summary}")
  112. except Exception as e:
  113. print(f"[WARNING] 生成摘要失败: {str(e)}")
  114. summary = None
  115. # 构建返回数据
  116. response_data = {
  117. "sql": sql,
  118. "query_result": query_result,
  119. "conversation_id": conversation_id if 'conversation_id' in locals() else None,
  120. "session_id": browser_session_id
  121. }
  122. # 添加摘要(如果启用且生成成功)
  123. if ENABLE_RESULT_SUMMARY and summary is not None:
  124. response_data["summary"] = summary
  125. response_data["response"] = summary # 同时添加response字段
  126. from common.result import success_response
  127. return jsonify(success_response(
  128. response_text="查询执行完成" if summary is None else None,
  129. data=response_data
  130. ))
  131. except Exception as e:
  132. print(f"[ERROR] ask_full执行失败: {str(e)}")
  133. # 即使发生异常,也检查是否有业务层面的解释
  134. if hasattr(vn, 'last_llm_explanation') and vn.last_llm_explanation:
  135. # 在解释性文本末尾添加提示语
  136. explanation_message = vn.last_llm_explanation + "请尝试提问其它问题。"
  137. from common.result import validation_failed_response
  138. return jsonify(validation_failed_response(
  139. response_text=explanation_message
  140. )), 422
  141. else:
  142. # 技术错误,使用500错误码
  143. from common.result import internal_error_response
  144. return jsonify(internal_error_response(
  145. response_text="查询处理失败,请稍后重试"
  146. )), 500
  147. @app.flask_app.route('/api/v0/citu_run_sql', methods=['POST'])
  148. def citu_run_sql():
  149. req = request.get_json(force=True)
  150. sql = req.get('sql')
  151. if not sql:
  152. from common.result import bad_request_response
  153. return jsonify(bad_request_response(
  154. response_text="缺少必需参数:sql",
  155. missing_params=["sql"]
  156. )), 400
  157. try:
  158. df = vn.run_sql(sql)
  159. # 处理返回数据 - 使用新的query_result结构
  160. query_result = {
  161. "rows": [],
  162. "columns": [],
  163. "row_count": 0,
  164. "is_limited": False,
  165. "total_row_count": 0
  166. }
  167. if isinstance(df, pd.DataFrame):
  168. query_result["columns"] = list(df.columns)
  169. if not df.empty:
  170. total_rows = len(df)
  171. limited_df = df.head(MAX_RETURN_ROWS)
  172. query_result["rows"] = limited_df.to_dict(orient="records")
  173. query_result["row_count"] = len(limited_df)
  174. query_result["total_row_count"] = total_rows
  175. query_result["is_limited"] = total_rows > MAX_RETURN_ROWS
  176. from common.result import success_response
  177. return jsonify(success_response(
  178. response_text=f"SQL执行完成,共返回 {query_result['total_row_count']} 条记录" +
  179. (f",已限制显示前 {MAX_RETURN_ROWS} 条" if query_result["is_limited"] else ""),
  180. data={
  181. "sql": sql,
  182. "query_result": query_result
  183. }
  184. ))
  185. except Exception as e:
  186. print(f"[ERROR] citu_run_sql执行失败: {str(e)}")
  187. from common.result import internal_error_response
  188. return jsonify(internal_error_response(
  189. response_text=f"SQL执行失败,请检查SQL语句是否正确"
  190. )), 500
  191. @app.flask_app.route('/api/v0/ask_cached', methods=['POST'])
  192. def ask_cached():
  193. """
  194. 带缓存功能的智能查询接口
  195. 支持会话管理和结果缓存,提高查询效率
  196. """
  197. req = request.get_json(force=True)
  198. question = req.get("question", None)
  199. browser_session_id = req.get("session_id", None)
  200. if not question:
  201. from common.result import bad_request_response
  202. return jsonify(bad_request_response(
  203. response_text="缺少必需参数:question",
  204. missing_params=["question"]
  205. )), 400
  206. try:
  207. # 生成conversation_id
  208. # 调试:查看generate_id的实际行为
  209. print(f"[DEBUG] 输入问题: '{question}'")
  210. conversation_id = app.cache.generate_id(question=question)
  211. print(f"[DEBUG] 生成的conversation_id: {conversation_id}")
  212. # 再次用相同问题测试
  213. conversation_id2 = app.cache.generate_id(question=question)
  214. print(f"[DEBUG] 再次生成的conversation_id: {conversation_id2}")
  215. print(f"[DEBUG] 两次ID是否相同: {conversation_id == conversation_id2}")
  216. # 检查缓存
  217. cached_sql = app.cache.get(id=conversation_id, field="sql")
  218. if cached_sql is not None:
  219. # 缓存命中
  220. print(f"[CACHE HIT] 使用缓存结果: {conversation_id}")
  221. sql = cached_sql
  222. df = app.cache.get(id=conversation_id, field="df")
  223. summary = app.cache.get(id=conversation_id, field="summary")
  224. else:
  225. # 缓存未命中,执行新查询
  226. print(f"[CACHE MISS] 执行新查询: {conversation_id}")
  227. sql, df, _ = vn.ask(
  228. question=question,
  229. print_results=False,
  230. visualize=False,
  231. allow_llm_to_see_data=True
  232. )
  233. # 检查是否有LLM解释性文本(无法生成SQL的情况)
  234. if sql is None and hasattr(vn, 'last_llm_explanation') and vn.last_llm_explanation:
  235. # 在解释性文本末尾添加提示语
  236. explanation_message = vn.last_llm_explanation + "请尝试用其它方式提问。"
  237. from common.result import validation_failed_response
  238. return jsonify(validation_failed_response(
  239. response_text=explanation_message
  240. )), 422
  241. # 如果sql为None但没有解释性文本,返回通用错误
  242. if sql is None:
  243. from common.result import validation_failed_response
  244. return jsonify(validation_failed_response(
  245. response_text="无法生成SQL查询,请检查问题描述或数据表结构"
  246. )), 422
  247. # 缓存结果
  248. app.cache.set(id=conversation_id, field="question", value=question)
  249. app.cache.set(id=conversation_id, field="sql", value=sql)
  250. app.cache.set(id=conversation_id, field="df", value=df)
  251. # 生成并缓存摘要(可通过配置控制,仅在有数据时生成)
  252. summary = None
  253. if ENABLE_RESULT_SUMMARY and isinstance(df, pd.DataFrame) and not df.empty:
  254. try:
  255. summary = vn.generate_summary(question=question, df=df)
  256. print(f"[INFO] 成功生成摘要: {summary}")
  257. except Exception as e:
  258. print(f"[WARNING] 生成摘要失败: {str(e)}")
  259. summary = None
  260. app.cache.set(id=conversation_id, field="summary", value=summary)
  261. # 处理返回数据 - 使用新的query_result结构
  262. query_result = {
  263. "rows": [],
  264. "columns": [],
  265. "row_count": 0,
  266. "is_limited": False,
  267. "total_row_count": 0
  268. }
  269. if isinstance(df, pd.DataFrame):
  270. query_result["columns"] = list(df.columns)
  271. if not df.empty:
  272. total_rows = len(df)
  273. limited_df = df.head(MAX_RETURN_ROWS)
  274. query_result["rows"] = limited_df.to_dict(orient="records")
  275. query_result["row_count"] = len(limited_df)
  276. query_result["total_row_count"] = total_rows
  277. query_result["is_limited"] = total_rows > MAX_RETURN_ROWS
  278. # 构建返回数据
  279. response_data = {
  280. "sql": sql,
  281. "query_result": query_result,
  282. "conversation_id": conversation_id,
  283. "session_id": browser_session_id,
  284. "cached": cached_sql is not None # 标识是否来自缓存
  285. }
  286. # 添加摘要(如果启用且生成成功)
  287. if ENABLE_RESULT_SUMMARY and summary is not None:
  288. response_data["summary"] = summary
  289. response_data["response"] = summary # 同时添加response字段
  290. from common.result import success_response
  291. return jsonify(success_response(
  292. response_text="查询执行完成" if summary is None else None,
  293. data=response_data
  294. ))
  295. except Exception as e:
  296. print(f"[ERROR] ask_cached执行失败: {str(e)}")
  297. from common.result import internal_error_response
  298. return jsonify(internal_error_response(
  299. response_text="查询处理失败,请稍后重试"
  300. )), 500
  301. @app.flask_app.route('/api/v0/citu_train_question_sql', methods=['POST'])
  302. def citu_train_question_sql():
  303. """
  304. 训练问题-SQL对接口
  305. 此API将接收的question/sql pair写入到training库中,用于训练和改进AI模型。
  306. 支持仅传入SQL或同时传入问题和SQL进行训练。
  307. Args:
  308. question (str, optional): 用户问题
  309. sql (str, required): 对应的SQL查询语句
  310. Returns:
  311. JSON: 包含训练ID和成功消息的响应
  312. """
  313. try:
  314. req = request.get_json(force=True)
  315. question = req.get('question')
  316. sql = req.get('sql')
  317. if not sql:
  318. from common.result import bad_request_response
  319. return jsonify(bad_request_response(
  320. response_text="缺少必需参数:sql",
  321. missing_params=["sql"]
  322. )), 400
  323. # 正确的调用方式:同时传递question和sql
  324. if question:
  325. training_id = vn.train(question=question, sql=sql)
  326. print(f"训练成功,训练ID为:{training_id},问题:{question},SQL:{sql}")
  327. else:
  328. training_id = vn.train(sql=sql)
  329. print(f"训练成功,训练ID为:{training_id},SQL:{sql}")
  330. from common.result import success_response
  331. return jsonify(success_response(
  332. response_text="问题-SQL对训练成功",
  333. data={
  334. "training_id": training_id,
  335. "message": "Question-SQL pair trained successfully"
  336. }
  337. ))
  338. except Exception as e:
  339. from common.result import internal_error_response
  340. return jsonify(internal_error_response(
  341. response_text="训练失败,请稍后重试"
  342. )), 500
  343. # ============ LangGraph Agent 集成 ============
  344. # 全局Agent实例(单例模式)
  345. citu_langraph_agent = None
  346. def get_citu_langraph_agent():
  347. """获取LangGraph Agent实例(懒加载)"""
  348. global citu_langraph_agent
  349. if citu_langraph_agent is None:
  350. try:
  351. from agent.citu_agent import CituLangGraphAgent
  352. print("[CITU_APP] 开始创建LangGraph Agent实例...")
  353. citu_langraph_agent = CituLangGraphAgent()
  354. print("[CITU_APP] LangGraph Agent实例创建成功")
  355. except ImportError as e:
  356. print(f"[CRITICAL] Agent模块导入失败: {str(e)}")
  357. print("[CRITICAL] 请检查agent模块是否存在以及依赖是否正确安装")
  358. raise Exception(f"Agent模块导入失败: {str(e)}")
  359. except Exception as e:
  360. print(f"[CRITICAL] LangGraph Agent实例创建失败: {str(e)}")
  361. print(f"[CRITICAL] 错误类型: {type(e).__name__}")
  362. # 提供更有用的错误信息
  363. if "config" in str(e).lower():
  364. print("[CRITICAL] 可能是配置文件问题,请检查配置")
  365. elif "llm" in str(e).lower():
  366. print("[CRITICAL] 可能是LLM连接问题,请检查LLM配置")
  367. elif "tool" in str(e).lower():
  368. print("[CRITICAL] 可能是工具加载问题,请检查工具模块")
  369. raise Exception(f"Agent初始化失败: {str(e)}")
  370. return citu_langraph_agent
  371. @app.flask_app.route('/api/v0/ask_agent', methods=['POST'])
  372. def ask_agent():
  373. """
  374. 支持对话上下文的ask_agent API - 修正版
  375. """
  376. req = request.get_json(force=True)
  377. question = req.get("question", None)
  378. browser_session_id = req.get("session_id", None)
  379. # 新增参数解析
  380. user_id_input = req.get("user_id", None)
  381. conversation_id_input = req.get("conversation_id", None)
  382. continue_conversation = req.get("continue_conversation", False)
  383. if not question:
  384. return jsonify(bad_request_response(
  385. response_text="缺少必需参数:question",
  386. missing_params=["question"]
  387. )), 400
  388. try:
  389. # 1. 获取登录用户ID(修正:在函数中获取session信息)
  390. login_user_id = session.get('user_id') if 'user_id' in session else None
  391. # 2. 智能ID解析(修正:传入登录用户ID)
  392. user_id = redis_conversation_manager.resolve_user_id(
  393. user_id_input, browser_session_id, request.remote_addr, login_user_id
  394. )
  395. conversation_id, conversation_status = redis_conversation_manager.resolve_conversation_id(
  396. user_id, conversation_id_input, continue_conversation
  397. )
  398. # 3. 获取上下文和上下文类型(提前到缓存检查之前)
  399. context = redis_conversation_manager.get_context(conversation_id)
  400. # 获取上下文类型:从最后一条助手消息的metadata中获取类型
  401. context_type = None
  402. if context:
  403. try:
  404. # 获取最后一条助手消息的metadata
  405. messages = redis_conversation_manager.get_messages(conversation_id, limit=10)
  406. for message in reversed(messages): # 从最新的开始找
  407. if message.get("role") == "assistant":
  408. metadata = message.get("metadata", {})
  409. context_type = metadata.get("type")
  410. if context_type:
  411. print(f"[AGENT_API] 检测到上下文类型: {context_type}")
  412. break
  413. except Exception as e:
  414. print(f"[WARNING] 获取上下文类型失败: {str(e)}")
  415. # 4. 检查缓存(新逻辑:放宽使用条件,严控存储条件)
  416. cached_answer = redis_conversation_manager.get_cached_answer(question, context)
  417. if cached_answer:
  418. print(f"[AGENT_API] 使用缓存答案")
  419. # 确定缓存答案的助手回复内容(使用与非缓存相同的优先级逻辑)
  420. cached_response_type = cached_answer.get("type", "UNKNOWN")
  421. if cached_response_type == "DATABASE":
  422. # DATABASE类型:按优先级选择内容
  423. if cached_answer.get("response"):
  424. # 优先级1:错误或解释性回复(如SQL生成失败)
  425. assistant_response = cached_answer.get("response")
  426. elif cached_answer.get("summary"):
  427. # 优先级2:查询成功的摘要
  428. assistant_response = cached_answer.get("summary")
  429. elif cached_answer.get("query_result"):
  430. # 优先级3:构造简单描述
  431. query_result = cached_answer.get("query_result")
  432. row_count = query_result.get("row_count", 0)
  433. assistant_response = f"查询执行完成,共返回 {row_count} 条记录。"
  434. else:
  435. # 异常情况
  436. assistant_response = "数据库查询已处理。"
  437. else:
  438. # CHAT类型:直接使用response
  439. assistant_response = cached_answer.get("response", "")
  440. # 更新对话历史
  441. redis_conversation_manager.save_message(conversation_id, "user", question)
  442. redis_conversation_manager.save_message(
  443. conversation_id, "assistant",
  444. assistant_response,
  445. metadata={"from_cache": True}
  446. )
  447. # 添加对话信息到缓存结果
  448. cached_answer["conversation_id"] = conversation_id
  449. cached_answer["user_id"] = user_id
  450. cached_answer["from_cache"] = True
  451. cached_answer.update(conversation_status)
  452. # 使用agent_success_response返回标准格式
  453. return jsonify(agent_success_response(
  454. response_type=cached_answer.get("type", "UNKNOWN"),
  455. response=cached_answer.get("response", ""), # 修正:使用response而不是response_text
  456. sql=cached_answer.get("sql"),
  457. query_result=cached_answer.get("query_result"),
  458. summary=cached_answer.get("summary"),
  459. session_id=browser_session_id,
  460. execution_path=cached_answer.get("execution_path", []),
  461. classification_info=cached_answer.get("classification_info", {}),
  462. conversation_id=conversation_id,
  463. user_id=user_id,
  464. is_guest_user=(user_id == DEFAULT_ANONYMOUS_USER),
  465. context_used=bool(context),
  466. from_cache=True,
  467. conversation_status=conversation_status["status"],
  468. conversation_message=conversation_status["message"],
  469. requested_conversation_id=conversation_status.get("requested_id")
  470. ))
  471. # 5. 保存用户消息
  472. redis_conversation_manager.save_message(conversation_id, "user", question)
  473. # 6. 构建带上下文的问题
  474. if context:
  475. enhanced_question = f"\n[CONTEXT]\n{context}\n\n[CURRENT]\n{question}"
  476. print(f"[AGENT_API] 使用上下文,长度: {len(context)}字符")
  477. else:
  478. enhanced_question = question
  479. print(f"[AGENT_API] 新对话,无上下文")
  480. # 7. 现有Agent处理逻辑(保持不变)
  481. try:
  482. agent = get_citu_langraph_agent()
  483. except Exception as e:
  484. print(f"[CRITICAL] Agent初始化失败: {str(e)}")
  485. return jsonify(service_unavailable_response(
  486. response_text="AI服务暂时不可用,请稍后重试",
  487. can_retry=True
  488. )), 503
  489. agent_result = agent.process_question(
  490. question=enhanced_question, # 使用增强后的问题
  491. session_id=browser_session_id,
  492. context_type=context_type # 传递上下文类型
  493. )
  494. # 8. 处理Agent结果
  495. if agent_result.get("success", False):
  496. # 修正:直接从agent_result获取字段,因为它就是final_response
  497. response_type = agent_result.get("type", "UNKNOWN")
  498. response_text = agent_result.get("response", "")
  499. sql = agent_result.get("sql")
  500. query_result = agent_result.get("query_result")
  501. summary = agent_result.get("summary")
  502. execution_path = agent_result.get("execution_path", [])
  503. classification_info = agent_result.get("classification_info", {})
  504. # 确定助手回复内容的优先级
  505. if response_type == "DATABASE":
  506. # DATABASE类型:按优先级选择内容
  507. if response_text:
  508. # 优先级1:错误或解释性回复(如SQL生成失败)
  509. assistant_response = response_text
  510. elif summary:
  511. # 优先级2:查询成功的摘要
  512. assistant_response = summary
  513. elif query_result:
  514. # 优先级3:构造简单描述
  515. row_count = query_result.get("row_count", 0)
  516. assistant_response = f"查询执行完成,共返回 {row_count} 条记录。"
  517. else:
  518. # 异常情况
  519. assistant_response = "数据库查询已处理。"
  520. else:
  521. # CHAT类型:直接使用response
  522. assistant_response = response_text
  523. # 保存助手回复
  524. redis_conversation_manager.save_message(
  525. conversation_id, "assistant", assistant_response,
  526. metadata={
  527. "type": response_type,
  528. "sql": sql,
  529. "execution_path": execution_path
  530. }
  531. )
  532. # 缓存成功的答案(新逻辑:只缓存无上下文的问答)
  533. # 直接缓存agent_result,它已经包含所有需要的字段
  534. redis_conversation_manager.cache_answer(question, agent_result, context)
  535. # 使用agent_success_response的正确方式
  536. return jsonify(agent_success_response(
  537. response_type=response_type,
  538. response=response_text, # 修正:使用response而不是response_text
  539. sql=sql,
  540. query_result=query_result,
  541. summary=summary,
  542. session_id=browser_session_id,
  543. execution_path=execution_path,
  544. classification_info=classification_info,
  545. conversation_id=conversation_id,
  546. user_id=user_id,
  547. is_guest_user=(user_id == DEFAULT_ANONYMOUS_USER),
  548. context_used=bool(context),
  549. from_cache=False,
  550. conversation_status=conversation_status["status"],
  551. conversation_message=conversation_status["message"],
  552. requested_conversation_id=conversation_status.get("requested_id")
  553. ))
  554. else:
  555. # 错误处理(修正:确保使用现有的错误响应格式)
  556. error_message = agent_result.get("error", "Agent处理失败")
  557. error_code = agent_result.get("error_code", 500)
  558. return jsonify(agent_error_response(
  559. response_text=error_message,
  560. error_type="agent_processing_failed",
  561. code=error_code,
  562. session_id=browser_session_id,
  563. conversation_id=conversation_id,
  564. user_id=user_id
  565. )), error_code
  566. except Exception as e:
  567. print(f"[ERROR] ask_agent执行失败: {str(e)}")
  568. return jsonify(internal_error_response(
  569. response_text="查询处理失败,请稍后重试"
  570. )), 500
  571. @app.flask_app.route('/api/v0/agent_health', methods=['GET'])
  572. def agent_health():
  573. """
  574. Agent健康检查接口
  575. 响应格式:
  576. {
  577. "success": true/false,
  578. "code": 200/503,
  579. "message": "healthy/degraded/unhealthy",
  580. "data": {
  581. "status": "healthy/degraded/unhealthy",
  582. "test_result": true/false,
  583. "workflow_compiled": true/false,
  584. "tools_count": 4,
  585. "message": "详细信息",
  586. "timestamp": "2024-01-01T12:00:00",
  587. "checks": {
  588. "agent_creation": true/false,
  589. "tools_import": true/false,
  590. "llm_connection": true/false,
  591. "classifier_ready": true/false
  592. }
  593. }
  594. }
  595. """
  596. try:
  597. # 基础健康检查
  598. health_data = {
  599. "status": "unknown",
  600. "test_result": False,
  601. "workflow_compiled": False,
  602. "tools_count": 0,
  603. "message": "",
  604. "timestamp": datetime.now().isoformat(),
  605. "checks": {
  606. "agent_creation": False,
  607. "tools_import": False,
  608. "llm_connection": False,
  609. "classifier_ready": False
  610. }
  611. }
  612. # 检查1: Agent创建
  613. try:
  614. agent = get_citu_langraph_agent()
  615. health_data["checks"]["agent_creation"] = True
  616. health_data["workflow_compiled"] = agent.workflow is not None
  617. health_data["tools_count"] = len(agent.tools) if hasattr(agent, 'tools') else 0
  618. except Exception as e:
  619. health_data["message"] = f"Agent创建失败: {str(e)}"
  620. from common.result import health_error_response
  621. return jsonify(health_error_response(
  622. status="unhealthy",
  623. **health_data
  624. )), 503
  625. # 检查2: 工具导入
  626. try:
  627. from agent.tools import TOOLS
  628. health_data["checks"]["tools_import"] = len(TOOLS) > 0
  629. except Exception as e:
  630. health_data["message"] = f"工具导入失败: {str(e)}"
  631. # 检查3: LLM连接(简单测试)
  632. try:
  633. from agent.utils import get_compatible_llm
  634. llm = get_compatible_llm()
  635. health_data["checks"]["llm_connection"] = llm is not None
  636. except Exception as e:
  637. health_data["message"] = f"LLM连接失败: {str(e)}"
  638. # 检查4: 分类器准备
  639. try:
  640. from agent.classifier import QuestionClassifier
  641. classifier = QuestionClassifier()
  642. health_data["checks"]["classifier_ready"] = True
  643. except Exception as e:
  644. health_data["message"] = f"分类器失败: {str(e)}"
  645. # 检查5: 完整流程测试(可选)
  646. try:
  647. if all(health_data["checks"].values()):
  648. test_result = agent.health_check()
  649. health_data["test_result"] = test_result.get("status") == "healthy"
  650. health_data["status"] = test_result.get("status", "unknown")
  651. health_data["message"] = test_result.get("message", "健康检查完成")
  652. else:
  653. health_data["status"] = "degraded"
  654. health_data["message"] = "部分组件异常"
  655. except Exception as e:
  656. health_data["status"] = "degraded"
  657. health_data["message"] = f"完整测试失败: {str(e)}"
  658. # 根据状态返回相应的HTTP代码 - 使用标准化健康检查响应
  659. from common.result import health_success_response, health_error_response
  660. if health_data["status"] == "healthy":
  661. return jsonify(health_success_response(**health_data))
  662. elif health_data["status"] == "degraded":
  663. return jsonify(health_error_response(status="degraded", **health_data)), 503
  664. else:
  665. return jsonify(health_error_response(status="unhealthy", **health_data)), 503
  666. except Exception as e:
  667. print(f"[ERROR] 健康检查异常: {str(e)}")
  668. from common.result import internal_error_response
  669. return jsonify(internal_error_response(
  670. response_text="健康检查失败,请稍后重试"
  671. )), 500
  672. # ==================== 日常管理API ====================
  673. @app.flask_app.route('/api/v0/cache_overview', methods=['GET'])
  674. def cache_overview():
  675. """日常管理:轻量概览 - 合并原cache_inspect的核心功能"""
  676. try:
  677. cache = app.cache
  678. result_data = {
  679. 'overview_summary': {
  680. 'total_conversations': 0,
  681. 'total_sessions': 0,
  682. 'query_time': datetime.now().isoformat()
  683. },
  684. 'recent_conversations': [], # 最近的对话
  685. 'session_summary': [] # 会话摘要
  686. }
  687. if hasattr(cache, 'cache') and isinstance(cache.cache, dict):
  688. result_data['overview_summary']['total_conversations'] = len(cache.cache)
  689. # 获取会话信息
  690. if hasattr(cache, 'get_all_sessions'):
  691. all_sessions = cache.get_all_sessions()
  692. result_data['overview_summary']['total_sessions'] = len(all_sessions)
  693. # 会话摘要(按最近活动排序)
  694. session_list = []
  695. for session_id, session_data in all_sessions.items():
  696. session_summary = {
  697. 'session_id': session_id,
  698. 'start_time': session_data['start_time'].isoformat(),
  699. 'conversation_count': session_data.get('conversation_count', 0),
  700. 'duration_seconds': session_data.get('session_duration_seconds', 0),
  701. 'last_activity': session_data.get('last_activity', session_data['start_time']).isoformat(),
  702. 'is_active': (datetime.now() - session_data.get('last_activity', session_data['start_time'])).total_seconds() < 1800 # 30分钟内活跃
  703. }
  704. session_list.append(session_summary)
  705. # 按最后活动时间排序
  706. session_list.sort(key=lambda x: x['last_activity'], reverse=True)
  707. result_data['session_summary'] = session_list
  708. # 最近的对话(最多显示10个)
  709. conversation_list = []
  710. for conversation_id, conversation_data in cache.cache.items():
  711. conversation_start_time = cache.conversation_start_times.get(conversation_id)
  712. conversation_info = {
  713. 'conversation_id': conversation_id,
  714. 'conversation_start_time': conversation_start_time.isoformat() if conversation_start_time else None,
  715. 'session_id': cache.conversation_to_session.get(conversation_id),
  716. 'has_question': 'question' in conversation_data,
  717. 'has_sql': 'sql' in conversation_data,
  718. 'has_data': 'df' in conversation_data and conversation_data['df'] is not None,
  719. 'question_preview': conversation_data.get('question', '')[:80] + '...' if len(conversation_data.get('question', '')) > 80 else conversation_data.get('question', ''),
  720. }
  721. # 计算对话持续时间
  722. if conversation_start_time:
  723. duration = datetime.now() - conversation_start_time
  724. conversation_info['conversation_duration_seconds'] = duration.total_seconds()
  725. conversation_list.append(conversation_info)
  726. # 按对话开始时间排序,显示最新的10个
  727. conversation_list.sort(key=lambda x: x['conversation_start_time'] or '', reverse=True)
  728. result_data['recent_conversations'] = conversation_list[:10]
  729. from common.result import success_response
  730. return jsonify(success_response(
  731. response_text="缓存概览查询完成",
  732. data=result_data
  733. ))
  734. except Exception as e:
  735. from common.result import internal_error_response
  736. return jsonify(internal_error_response(
  737. response_text="获取缓存概览失败,请稍后重试"
  738. )), 500
  739. @app.flask_app.route('/api/v0/cache_stats', methods=['GET'])
  740. def cache_stats():
  741. """日常管理:统计信息 - 合并原session_stats和cache_stats功能"""
  742. try:
  743. cache = app.cache
  744. current_time = datetime.now()
  745. stats = {
  746. 'basic_stats': {
  747. 'total_sessions': len(getattr(cache, 'session_info', {})),
  748. 'total_conversations': len(getattr(cache, 'cache', {})),
  749. 'active_sessions': 0, # 最近30分钟有活动
  750. 'average_conversations_per_session': 0
  751. },
  752. 'time_distribution': {
  753. 'sessions': {
  754. 'last_1_hour': 0,
  755. 'last_6_hours': 0,
  756. 'last_24_hours': 0,
  757. 'last_7_days': 0,
  758. 'older': 0
  759. },
  760. 'conversations': {
  761. 'last_1_hour': 0,
  762. 'last_6_hours': 0,
  763. 'last_24_hours': 0,
  764. 'last_7_days': 0,
  765. 'older': 0
  766. }
  767. },
  768. 'session_details': [],
  769. 'time_ranges': {
  770. 'oldest_session': None,
  771. 'newest_session': None,
  772. 'oldest_conversation': None,
  773. 'newest_conversation': None
  774. }
  775. }
  776. # 会话统计
  777. if hasattr(cache, 'session_info'):
  778. session_times = []
  779. total_conversations = 0
  780. for session_id, session_data in cache.session_info.items():
  781. start_time = session_data['start_time']
  782. session_times.append(start_time)
  783. conversation_count = len(session_data.get('conversations', []))
  784. total_conversations += conversation_count
  785. # 检查活跃状态
  786. last_activity = session_data.get('last_activity', session_data['start_time'])
  787. if (current_time - last_activity).total_seconds() < 1800:
  788. stats['basic_stats']['active_sessions'] += 1
  789. # 时间分布统计
  790. age_hours = (current_time - start_time).total_seconds() / 3600
  791. if age_hours <= 1:
  792. stats['time_distribution']['sessions']['last_1_hour'] += 1
  793. elif age_hours <= 6:
  794. stats['time_distribution']['sessions']['last_6_hours'] += 1
  795. elif age_hours <= 24:
  796. stats['time_distribution']['sessions']['last_24_hours'] += 1
  797. elif age_hours <= 168: # 7 days
  798. stats['time_distribution']['sessions']['last_7_days'] += 1
  799. else:
  800. stats['time_distribution']['sessions']['older'] += 1
  801. # 会话详细信息
  802. session_duration = current_time - start_time
  803. stats['session_details'].append({
  804. 'session_id': session_id,
  805. 'start_time': start_time.isoformat(),
  806. 'last_activity': last_activity.isoformat(),
  807. 'conversation_count': conversation_count,
  808. 'duration_seconds': session_duration.total_seconds(),
  809. 'duration_formatted': str(session_duration),
  810. 'is_active': (current_time - last_activity).total_seconds() < 1800,
  811. 'browser_session_id': session_data.get('browser_session_id')
  812. })
  813. # 计算平均值
  814. if len(cache.session_info) > 0:
  815. stats['basic_stats']['average_conversations_per_session'] = total_conversations / len(cache.session_info)
  816. # 时间范围
  817. if session_times:
  818. stats['time_ranges']['oldest_session'] = min(session_times).isoformat()
  819. stats['time_ranges']['newest_session'] = max(session_times).isoformat()
  820. # 对话统计
  821. if hasattr(cache, 'conversation_start_times'):
  822. conversation_times = []
  823. for conv_time in cache.conversation_start_times.values():
  824. conversation_times.append(conv_time)
  825. age_hours = (current_time - conv_time).total_seconds() / 3600
  826. if age_hours <= 1:
  827. stats['time_distribution']['conversations']['last_1_hour'] += 1
  828. elif age_hours <= 6:
  829. stats['time_distribution']['conversations']['last_6_hours'] += 1
  830. elif age_hours <= 24:
  831. stats['time_distribution']['conversations']['last_24_hours'] += 1
  832. elif age_hours <= 168:
  833. stats['time_distribution']['conversations']['last_7_days'] += 1
  834. else:
  835. stats['time_distribution']['conversations']['older'] += 1
  836. if conversation_times:
  837. stats['time_ranges']['oldest_conversation'] = min(conversation_times).isoformat()
  838. stats['time_ranges']['newest_conversation'] = max(conversation_times).isoformat()
  839. # 按最近活动排序会话详情
  840. stats['session_details'].sort(key=lambda x: x['last_activity'], reverse=True)
  841. from common.result import success_response
  842. return jsonify(success_response(
  843. response_text="缓存统计信息查询完成",
  844. data=stats
  845. ))
  846. except Exception as e:
  847. from common.result import internal_error_response
  848. return jsonify(internal_error_response(
  849. response_text="获取缓存统计失败,请稍后重试"
  850. )), 500
  851. # ==================== 高级功能API ====================
  852. @app.flask_app.route('/api/v0/cache_export', methods=['GET'])
  853. def cache_export():
  854. """高级功能:完整导出 - 保持原cache_raw_export的完整功能"""
  855. try:
  856. cache = app.cache
  857. # 验证缓存的实际结构
  858. if not hasattr(cache, 'cache'):
  859. from common.result import internal_error_response
  860. return jsonify(internal_error_response(
  861. response_text="缓存对象结构异常,请联系系统管理员"
  862. )), 500
  863. if not isinstance(cache.cache, dict):
  864. from common.result import internal_error_response
  865. return jsonify(internal_error_response(
  866. response_text="缓存数据类型异常,请联系系统管理员"
  867. )), 500
  868. # 定义JSON序列化辅助函数
  869. def make_json_serializable(obj):
  870. """将对象转换为JSON可序列化的格式"""
  871. if obj is None:
  872. return None
  873. elif isinstance(obj, (str, int, float, bool)):
  874. return obj
  875. elif isinstance(obj, (list, tuple)):
  876. return [make_json_serializable(item) for item in obj]
  877. elif isinstance(obj, dict):
  878. return {str(k): make_json_serializable(v) for k, v in obj.items()}
  879. elif hasattr(obj, 'isoformat'): # datetime objects
  880. return obj.isoformat()
  881. elif hasattr(obj, 'item'): # numpy scalars
  882. return obj.item()
  883. elif hasattr(obj, 'tolist'): # numpy arrays
  884. return obj.tolist()
  885. elif hasattr(obj, '__dict__'): # pandas dtypes and other objects
  886. return str(obj)
  887. else:
  888. return str(obj)
  889. # 获取完整的原始缓存数据
  890. raw_cache = cache.cache
  891. # 获取会话和对话时间信息
  892. conversation_times = getattr(cache, 'conversation_start_times', {})
  893. session_info = getattr(cache, 'session_info', {})
  894. conversation_to_session = getattr(cache, 'conversation_to_session', {})
  895. export_data = {
  896. 'export_metadata': {
  897. 'export_time': datetime.now().isoformat(),
  898. 'total_conversations': len(raw_cache),
  899. 'total_sessions': len(session_info),
  900. 'cache_type': type(cache).__name__,
  901. 'cache_object_info': str(cache),
  902. 'has_session_times': bool(session_info),
  903. 'has_conversation_times': bool(conversation_times)
  904. },
  905. 'session_info': {
  906. session_id: {
  907. 'start_time': session_data['start_time'].isoformat(),
  908. 'last_activity': session_data.get('last_activity', session_data['start_time']).isoformat(),
  909. 'conversations': session_data['conversations'],
  910. 'conversation_count': len(session_data['conversations']),
  911. 'browser_session_id': session_data.get('browser_session_id'),
  912. 'user_info': session_data.get('user_info', {})
  913. }
  914. for session_id, session_data in session_info.items()
  915. },
  916. 'conversation_times': {
  917. conversation_id: start_time.isoformat()
  918. for conversation_id, start_time in conversation_times.items()
  919. },
  920. 'conversation_to_session_mapping': conversation_to_session,
  921. 'conversations': {}
  922. }
  923. # 处理每个对话的完整数据
  924. for conversation_id, conversation_data in raw_cache.items():
  925. # 获取时间信息
  926. conversation_start_time = conversation_times.get(conversation_id)
  927. session_id = conversation_to_session.get(conversation_id)
  928. session_start_time = None
  929. if session_id and session_id in session_info:
  930. session_start_time = session_info[session_id]['start_time']
  931. processed_conversation = {
  932. 'conversation_id': conversation_id,
  933. 'conversation_start_time': conversation_start_time.isoformat() if conversation_start_time else None,
  934. 'session_id': session_id,
  935. 'session_start_time': session_start_time.isoformat() if session_start_time else None,
  936. 'field_count': len(conversation_data),
  937. 'fields': {}
  938. }
  939. # 添加时间计算
  940. if conversation_start_time:
  941. conversation_duration = datetime.now() - conversation_start_time
  942. processed_conversation['conversation_duration_seconds'] = conversation_duration.total_seconds()
  943. processed_conversation['conversation_duration_formatted'] = str(conversation_duration)
  944. if session_start_time:
  945. session_duration = datetime.now() - session_start_time
  946. processed_conversation['session_duration_seconds'] = session_duration.total_seconds()
  947. processed_conversation['session_duration_formatted'] = str(session_duration)
  948. # 处理每个字段,确保JSON序列化安全
  949. for field_name, field_value in conversation_data.items():
  950. field_info = {
  951. 'field_name': field_name,
  952. 'data_type': type(field_value).__name__,
  953. 'is_none': field_value is None
  954. }
  955. try:
  956. if field_value is None:
  957. field_info['value'] = None
  958. elif field_name in ['conversation_start_time', 'session_start_time']:
  959. # 处理时间字段
  960. field_info['content'] = make_json_serializable(field_value)
  961. elif field_name == 'df' and field_value is not None:
  962. # DataFrame的安全处理
  963. if hasattr(field_value, 'to_dict'):
  964. # 安全地处理dtypes
  965. try:
  966. dtypes_dict = {}
  967. for col, dtype in field_value.dtypes.items():
  968. dtypes_dict[col] = str(dtype)
  969. except Exception:
  970. dtypes_dict = {"error": "无法序列化dtypes"}
  971. # 安全地处理内存使用
  972. try:
  973. memory_usage = field_value.memory_usage(deep=True)
  974. memory_dict = {}
  975. for idx, usage in memory_usage.items():
  976. memory_dict[str(idx)] = int(usage) if hasattr(usage, 'item') else int(usage)
  977. except Exception:
  978. memory_dict = {"error": "无法获取内存使用信息"}
  979. field_info.update({
  980. 'dataframe_info': {
  981. 'shape': list(field_value.shape),
  982. 'columns': list(field_value.columns),
  983. 'dtypes': dtypes_dict,
  984. 'index_info': {
  985. 'type': type(field_value.index).__name__,
  986. 'length': len(field_value.index)
  987. }
  988. },
  989. 'data': make_json_serializable(field_value.to_dict('records')),
  990. 'memory_usage': memory_dict
  991. })
  992. else:
  993. field_info['value'] = str(field_value)
  994. field_info['note'] = 'not_standard_dataframe'
  995. elif field_name == 'fig_json':
  996. # 图表JSON数据处理
  997. if isinstance(field_value, str):
  998. try:
  999. import json
  1000. parsed_fig = json.loads(field_value)
  1001. field_info.update({
  1002. 'json_valid': True,
  1003. 'json_size_bytes': len(field_value),
  1004. 'plotly_structure': {
  1005. 'has_data': 'data' in parsed_fig,
  1006. 'has_layout': 'layout' in parsed_fig,
  1007. 'data_traces_count': len(parsed_fig.get('data', [])),
  1008. },
  1009. 'raw_json': field_value
  1010. })
  1011. except json.JSONDecodeError:
  1012. field_info.update({
  1013. 'json_valid': False,
  1014. 'raw_content': str(field_value)
  1015. })
  1016. else:
  1017. field_info['value'] = make_json_serializable(field_value)
  1018. elif field_name == 'followup_questions':
  1019. # 后续问题列表
  1020. field_info.update({
  1021. 'content': make_json_serializable(field_value)
  1022. })
  1023. elif field_name in ['question', 'sql', 'summary']:
  1024. # 文本字段
  1025. if isinstance(field_value, str):
  1026. field_info.update({
  1027. 'text_length': len(field_value),
  1028. 'content': field_value
  1029. })
  1030. else:
  1031. field_info['value'] = make_json_serializable(field_value)
  1032. else:
  1033. # 未知字段的安全处理
  1034. field_info['content'] = make_json_serializable(field_value)
  1035. except Exception as e:
  1036. field_info.update({
  1037. 'processing_error': str(e),
  1038. 'fallback_value': str(field_value)[:500] + '...' if len(str(field_value)) > 500 else str(field_value)
  1039. })
  1040. processed_conversation['fields'][field_name] = field_info
  1041. export_data['conversations'][conversation_id] = processed_conversation
  1042. # 添加缓存统计信息
  1043. field_frequency = {}
  1044. data_types_found = set()
  1045. total_dataframes = 0
  1046. total_questions = 0
  1047. for conv_data in export_data['conversations'].values():
  1048. for field_name, field_info in conv_data['fields'].items():
  1049. field_frequency[field_name] = field_frequency.get(field_name, 0) + 1
  1050. data_types_found.add(field_info['data_type'])
  1051. if field_name == 'df' and not field_info['is_none']:
  1052. total_dataframes += 1
  1053. if field_name == 'question' and not field_info['is_none']:
  1054. total_questions += 1
  1055. export_data['cache_statistics'] = {
  1056. 'field_frequency': field_frequency,
  1057. 'data_types_found': list(data_types_found),
  1058. 'total_dataframes': total_dataframes,
  1059. 'total_questions': total_questions,
  1060. 'has_session_timing': 'session_start_time' in field_frequency,
  1061. 'has_conversation_timing': 'conversation_start_time' in field_frequency
  1062. }
  1063. from common.result import success_response
  1064. return jsonify(success_response(
  1065. response_text="缓存数据导出完成",
  1066. data=export_data
  1067. ))
  1068. except Exception as e:
  1069. import traceback
  1070. error_details = {
  1071. 'error_message': str(e),
  1072. 'error_type': type(e).__name__,
  1073. 'traceback': traceback.format_exc()
  1074. }
  1075. from common.result import internal_error_response
  1076. return jsonify(internal_error_response(
  1077. response_text="导出缓存失败,请稍后重试"
  1078. )), 500
  1079. # ==================== 清理功能API ====================
  1080. @app.flask_app.route('/api/v0/cache_preview_cleanup', methods=['POST'])
  1081. def cache_preview_cleanup():
  1082. """清理功能:预览删除操作 - 保持原功能"""
  1083. try:
  1084. req = request.get_json(force=True)
  1085. # 时间条件 - 支持三种方式
  1086. older_than_hours = req.get('older_than_hours')
  1087. older_than_days = req.get('older_than_days')
  1088. before_timestamp = req.get('before_timestamp') # YYYY-MM-DD HH:MM:SS 格式
  1089. cache = app.cache
  1090. # 计算截止时间
  1091. cutoff_time = None
  1092. time_condition = None
  1093. if older_than_hours:
  1094. cutoff_time = datetime.now() - timedelta(hours=older_than_hours)
  1095. time_condition = f"older_than_hours: {older_than_hours}"
  1096. elif older_than_days:
  1097. cutoff_time = datetime.now() - timedelta(days=older_than_days)
  1098. time_condition = f"older_than_days: {older_than_days}"
  1099. elif before_timestamp:
  1100. try:
  1101. # 支持 YYYY-MM-DD HH:MM:SS 格式
  1102. cutoff_time = datetime.strptime(before_timestamp, '%Y-%m-%d %H:%M:%S')
  1103. time_condition = f"before_timestamp: {before_timestamp}"
  1104. except ValueError:
  1105. from common.result import validation_failed_response
  1106. return jsonify(validation_failed_response(
  1107. response_text="before_timestamp格式错误,请使用 YYYY-MM-DD HH:MM:SS 格式"
  1108. )), 422
  1109. else:
  1110. from common.result import bad_request_response
  1111. return jsonify(bad_request_response(
  1112. response_text="必须提供时间条件:older_than_hours, older_than_days 或 before_timestamp (YYYY-MM-DD HH:MM:SS)",
  1113. missing_params=["older_than_hours", "older_than_days", "before_timestamp"]
  1114. )), 400
  1115. preview = {
  1116. 'time_condition': time_condition,
  1117. 'cutoff_time': cutoff_time.isoformat(),
  1118. 'will_be_removed': {
  1119. 'sessions': []
  1120. },
  1121. 'will_be_kept': {
  1122. 'sessions_count': 0,
  1123. 'conversations_count': 0
  1124. },
  1125. 'summary': {
  1126. 'sessions_to_remove': 0,
  1127. 'conversations_to_remove': 0,
  1128. 'sessions_to_keep': 0,
  1129. 'conversations_to_keep': 0
  1130. }
  1131. }
  1132. # 预览按session删除
  1133. sessions_to_remove_count = 0
  1134. conversations_to_remove_count = 0
  1135. for session_id, session_data in cache.session_info.items():
  1136. session_preview = {
  1137. 'session_id': session_id,
  1138. 'start_time': session_data['start_time'].isoformat(),
  1139. 'conversation_count': len(session_data['conversations']),
  1140. 'conversations': []
  1141. }
  1142. # 添加conversation详情
  1143. for conv_id in session_data['conversations']:
  1144. if conv_id in cache.cache:
  1145. conv_data = cache.cache[conv_id]
  1146. session_preview['conversations'].append({
  1147. 'conversation_id': conv_id,
  1148. 'question': conv_data.get('question', '')[:50] + '...' if conv_data.get('question') else '',
  1149. 'start_time': cache.conversation_start_times.get(conv_id, '').isoformat() if cache.conversation_start_times.get(conv_id) else ''
  1150. })
  1151. if session_data['start_time'] < cutoff_time:
  1152. preview['will_be_removed']['sessions'].append(session_preview)
  1153. sessions_to_remove_count += 1
  1154. conversations_to_remove_count += len(session_data['conversations'])
  1155. else:
  1156. preview['will_be_kept']['sessions_count'] += 1
  1157. preview['will_be_kept']['conversations_count'] += len(session_data['conversations'])
  1158. # 更新摘要统计
  1159. preview['summary'] = {
  1160. 'sessions_to_remove': sessions_to_remove_count,
  1161. 'conversations_to_remove': conversations_to_remove_count,
  1162. 'sessions_to_keep': preview['will_be_kept']['sessions_count'],
  1163. 'conversations_to_keep': preview['will_be_kept']['conversations_count']
  1164. }
  1165. from common.result import success_response
  1166. return jsonify(success_response(
  1167. response_text=f"清理预览完成,将删除 {sessions_to_remove_count} 个会话和 {conversations_to_remove_count} 个对话",
  1168. data=preview
  1169. ))
  1170. except Exception as e:
  1171. from common.result import internal_error_response
  1172. return jsonify(internal_error_response(
  1173. response_text="预览清理操作失败,请稍后重试"
  1174. )), 500
  1175. @app.flask_app.route('/api/v0/cache_cleanup', methods=['POST'])
  1176. def cache_cleanup():
  1177. """清理功能:实际删除缓存 - 保持原功能"""
  1178. try:
  1179. req = request.get_json(force=True)
  1180. # 时间条件 - 支持三种方式
  1181. older_than_hours = req.get('older_than_hours')
  1182. older_than_days = req.get('older_than_days')
  1183. before_timestamp = req.get('before_timestamp') # YYYY-MM-DD HH:MM:SS 格式
  1184. cache = app.cache
  1185. if not hasattr(cache, 'session_info'):
  1186. from common.result import service_unavailable_response
  1187. return jsonify(service_unavailable_response(
  1188. response_text="缓存不支持会话功能"
  1189. )), 503
  1190. # 计算截止时间
  1191. cutoff_time = None
  1192. time_condition = None
  1193. if older_than_hours:
  1194. cutoff_time = datetime.now() - timedelta(hours=older_than_hours)
  1195. time_condition = f"older_than_hours: {older_than_hours}"
  1196. elif older_than_days:
  1197. cutoff_time = datetime.now() - timedelta(days=older_than_days)
  1198. time_condition = f"older_than_days: {older_than_days}"
  1199. elif before_timestamp:
  1200. try:
  1201. # 支持 YYYY-MM-DD HH:MM:SS 格式
  1202. cutoff_time = datetime.strptime(before_timestamp, '%Y-%m-%d %H:%M:%S')
  1203. time_condition = f"before_timestamp: {before_timestamp}"
  1204. except ValueError:
  1205. from common.result import validation_failed_response
  1206. return jsonify(validation_failed_response(
  1207. response_text="before_timestamp格式错误,请使用 YYYY-MM-DD HH:MM:SS 格式"
  1208. )), 422
  1209. else:
  1210. from common.result import bad_request_response
  1211. return jsonify(bad_request_response(
  1212. response_text="必须提供时间条件:older_than_hours, older_than_days 或 before_timestamp (YYYY-MM-DD HH:MM:SS)",
  1213. missing_params=["older_than_hours", "older_than_days", "before_timestamp"]
  1214. )), 400
  1215. cleanup_stats = {
  1216. 'time_condition': time_condition,
  1217. 'cutoff_time': cutoff_time.isoformat(),
  1218. 'sessions_removed': 0,
  1219. 'conversations_removed': 0,
  1220. 'sessions_kept': 0,
  1221. 'conversations_kept': 0,
  1222. 'removed_session_ids': [],
  1223. 'removed_conversation_ids': []
  1224. }
  1225. # 按session删除
  1226. sessions_to_remove = []
  1227. for session_id, session_data in cache.session_info.items():
  1228. if session_data['start_time'] < cutoff_time:
  1229. sessions_to_remove.append(session_id)
  1230. # 删除符合条件的sessions及其所有conversations
  1231. for session_id in sessions_to_remove:
  1232. session_data = cache.session_info[session_id]
  1233. conversations_in_session = session_data['conversations'].copy()
  1234. # 删除session中的所有conversations
  1235. for conv_id in conversations_in_session:
  1236. if conv_id in cache.cache:
  1237. del cache.cache[conv_id]
  1238. cleanup_stats['conversations_removed'] += 1
  1239. cleanup_stats['removed_conversation_ids'].append(conv_id)
  1240. # 清理conversation相关的时间记录
  1241. if hasattr(cache, 'conversation_start_times') and conv_id in cache.conversation_start_times:
  1242. del cache.conversation_start_times[conv_id]
  1243. if hasattr(cache, 'conversation_to_session') and conv_id in cache.conversation_to_session:
  1244. del cache.conversation_to_session[conv_id]
  1245. # 删除session记录
  1246. del cache.session_info[session_id]
  1247. cleanup_stats['sessions_removed'] += 1
  1248. cleanup_stats['removed_session_ids'].append(session_id)
  1249. # 统计保留的sessions和conversations
  1250. cleanup_stats['sessions_kept'] = len(cache.session_info)
  1251. cleanup_stats['conversations_kept'] = len(cache.cache)
  1252. from common.result import success_response
  1253. return jsonify(success_response(
  1254. response_text=f"缓存清理完成,删除了 {cleanup_stats['sessions_removed']} 个会话和 {cleanup_stats['conversations_removed']} 个对话",
  1255. data=cleanup_stats
  1256. ))
  1257. except Exception as e:
  1258. from common.result import internal_error_response
  1259. return jsonify(internal_error_response(
  1260. response_text="缓存清理失败,请稍后重试"
  1261. )), 500
  1262. @app.flask_app.route('/api/v0/training_error_question_sql', methods=['POST'])
  1263. def training_error_question_sql():
  1264. """
  1265. 存储错误的question-sql对到error_sql集合中
  1266. 此API将接收的错误question/sql pair写入到error_sql集合中,用于记录和分析错误的SQL查询。
  1267. Args:
  1268. question (str, required): 用户问题
  1269. sql (str, required): 对应的错误SQL查询语句
  1270. Returns:
  1271. JSON: 包含训练ID和成功消息的响应
  1272. """
  1273. try:
  1274. data = request.get_json()
  1275. question = data.get('question')
  1276. sql = data.get('sql')
  1277. print(f"[DEBUG] 接收到错误SQL训练请求: question={question}, sql={sql}")
  1278. if not question or not sql:
  1279. from common.result import bad_request_response
  1280. missing_params = []
  1281. if not question:
  1282. missing_params.append("question")
  1283. if not sql:
  1284. missing_params.append("sql")
  1285. return jsonify(bad_request_response(
  1286. response_text="question和sql参数都是必需的",
  1287. missing_params=missing_params
  1288. )), 400
  1289. # 使用vn实例的train_error_sql方法存储错误SQL
  1290. id = vn.train_error_sql(question=question, sql=sql)
  1291. print(f"[INFO] 成功存储错误SQL,ID: {id}")
  1292. from common.result import success_response
  1293. return jsonify(success_response(
  1294. response_text="错误SQL对已成功存储",
  1295. data={
  1296. "id": id,
  1297. "message": "错误SQL对已成功存储到error_sql集合"
  1298. }
  1299. ))
  1300. except Exception as e:
  1301. print(f"[ERROR] 存储错误SQL失败: {str(e)}")
  1302. from common.result import internal_error_response
  1303. return jsonify(internal_error_response(
  1304. response_text="存储错误SQL失败,请稍后重试"
  1305. )), 500
  1306. # ==================== Redis对话管理API ====================
  1307. @app.flask_app.route('/api/v0/user/<user_id>/conversations', methods=['GET'])
  1308. def get_user_conversations(user_id: str):
  1309. """获取用户的对话列表(按时间倒序)"""
  1310. try:
  1311. limit = request.args.get('limit', USER_MAX_CONVERSATIONS, type=int)
  1312. conversations = redis_conversation_manager.get_conversations(user_id, limit)
  1313. return jsonify(success_response(
  1314. response_text="获取用户对话列表成功",
  1315. data={
  1316. "user_id": user_id,
  1317. "conversations": conversations,
  1318. "total_count": len(conversations)
  1319. }
  1320. ))
  1321. except Exception as e:
  1322. return jsonify(internal_error_response(
  1323. response_text="获取对话列表失败,请稍后重试"
  1324. )), 500
  1325. @app.flask_app.route('/api/v0/conversation/<conversation_id>/messages', methods=['GET'])
  1326. def get_conversation_messages(conversation_id: str):
  1327. """获取特定对话的消息历史"""
  1328. try:
  1329. limit = request.args.get('limit', type=int) # 可选参数
  1330. messages = redis_conversation_manager.get_conversation_messages(conversation_id, limit)
  1331. meta = redis_conversation_manager.get_conversation_meta(conversation_id)
  1332. return jsonify(success_response(
  1333. response_text="获取对话消息成功",
  1334. data={
  1335. "conversation_id": conversation_id,
  1336. "conversation_meta": meta,
  1337. "messages": messages,
  1338. "message_count": len(messages)
  1339. }
  1340. ))
  1341. except Exception as e:
  1342. return jsonify(internal_error_response(
  1343. response_text="获取对话消息失败"
  1344. )), 500
  1345. @app.flask_app.route('/api/v0/conversation/<conversation_id>/context', methods=['GET'])
  1346. def get_conversation_context(conversation_id: str):
  1347. """获取对话上下文(格式化用于LLM)"""
  1348. try:
  1349. count = request.args.get('count', CONVERSATION_CONTEXT_COUNT, type=int)
  1350. context = redis_conversation_manager.get_context(conversation_id, count)
  1351. return jsonify(success_response(
  1352. response_text="获取对话上下文成功",
  1353. data={
  1354. "conversation_id": conversation_id,
  1355. "context": context,
  1356. "context_message_count": count
  1357. }
  1358. ))
  1359. except Exception as e:
  1360. return jsonify(internal_error_response(
  1361. response_text="获取对话上下文失败"
  1362. )), 500
  1363. @app.flask_app.route('/api/v0/conversation_stats', methods=['GET'])
  1364. def conversation_stats():
  1365. """获取对话系统统计信息"""
  1366. try:
  1367. stats = redis_conversation_manager.get_stats()
  1368. return jsonify(success_response(
  1369. response_text="获取统计信息成功",
  1370. data=stats
  1371. ))
  1372. except Exception as e:
  1373. return jsonify(internal_error_response(
  1374. response_text="获取统计信息失败,请稍后重试"
  1375. )), 500
  1376. @app.flask_app.route('/api/v0/conversation_cleanup', methods=['POST'])
  1377. def conversation_cleanup():
  1378. """手动清理过期对话"""
  1379. try:
  1380. redis_conversation_manager.cleanup_expired_conversations()
  1381. return jsonify(success_response(
  1382. response_text="对话清理完成"
  1383. ))
  1384. except Exception as e:
  1385. return jsonify(internal_error_response(
  1386. response_text="对话清理失败,请稍后重试"
  1387. )), 500
  1388. @app.flask_app.route('/api/v0/user/<user_id>/conversations/full', methods=['GET'])
  1389. def get_user_conversations_with_messages(user_id: str):
  1390. """
  1391. 获取用户的完整对话数据(包含所有消息)
  1392. 一次性返回用户的所有对话和每个对话下的消息历史
  1393. Args:
  1394. user_id: 用户ID(路径参数)
  1395. conversation_limit: 对话数量限制(查询参数,可选,不传则返回所有对话)
  1396. message_limit: 每个对话的消息数限制(查询参数,可选,不传则返回所有消息)
  1397. Returns:
  1398. 包含用户所有对话和消息的完整数据
  1399. """
  1400. try:
  1401. # 获取可选参数,不传递时使用None(返回所有记录)
  1402. conversation_limit = request.args.get('conversation_limit', type=int)
  1403. message_limit = request.args.get('message_limit', type=int)
  1404. # 获取用户的对话列表
  1405. conversations = redis_conversation_manager.get_conversations(user_id, conversation_limit)
  1406. # 为每个对话获取消息历史
  1407. full_conversations = []
  1408. total_messages = 0
  1409. for conversation in conversations:
  1410. conversation_id = conversation['conversation_id']
  1411. # 获取对话消息
  1412. messages = redis_conversation_manager.get_conversation_messages(
  1413. conversation_id, message_limit
  1414. )
  1415. # 获取对话元数据
  1416. meta = redis_conversation_manager.get_conversation_meta(conversation_id)
  1417. # 组合完整数据
  1418. full_conversation = {
  1419. **conversation, # 基础对话信息
  1420. 'meta': meta, # 对话元数据
  1421. 'messages': messages, # 消息列表
  1422. 'message_count': len(messages)
  1423. }
  1424. full_conversations.append(full_conversation)
  1425. total_messages += len(messages)
  1426. return jsonify(success_response(
  1427. response_text="获取用户完整对话数据成功",
  1428. data={
  1429. "user_id": user_id,
  1430. "conversations": full_conversations,
  1431. "total_conversations": len(full_conversations),
  1432. "total_messages": total_messages,
  1433. "conversation_limit_applied": conversation_limit,
  1434. "message_limit_applied": message_limit,
  1435. "query_time": datetime.now().isoformat()
  1436. }
  1437. ))
  1438. except Exception as e:
  1439. print(f"[ERROR] 获取用户完整对话数据失败: {str(e)}")
  1440. return jsonify(internal_error_response(
  1441. response_text="获取用户对话数据失败,请稍后重试"
  1442. )), 500
  1443. # 前端JavaScript示例 - 如何维持会话
  1444. """
  1445. // 前端需要维护一个会话ID
  1446. class ChatSession {
  1447. constructor() {
  1448. // 从localStorage获取或创建新的会话ID
  1449. this.sessionId = localStorage.getItem('chat_session_id') || this.generateSessionId();
  1450. localStorage.setItem('chat_session_id', this.sessionId);
  1451. }
  1452. generateSessionId() {
  1453. return 'session_' + Date.now() + '_' + Math.random().toString(36).substr(2, 9);
  1454. }
  1455. async askQuestion(question) {
  1456. const response = await fetch('/api/v0/ask', {
  1457. method: 'POST',
  1458. headers: {
  1459. 'Content-Type': 'application/json',
  1460. },
  1461. body: JSON.stringify({
  1462. question: question,
  1463. session_id: this.sessionId // 关键:传递会话ID
  1464. })
  1465. });
  1466. return await response.json();
  1467. }
  1468. // 开始新会话
  1469. startNewSession() {
  1470. this.sessionId = this.generateSessionId();
  1471. localStorage.setItem('chat_session_id', this.sessionId);
  1472. }
  1473. }
  1474. // 使用示例
  1475. const chatSession = new ChatSession();
  1476. chatSession.askQuestion("各年龄段客户的流失率如何?");
  1477. """
# Module-level startup: announce the local URL, then call app.run()
# (presumably starting the underlying Flask server — confirm against the
# definition of `app` earlier in the file).
# NOTE(review): debug=True together with host="0.0.0.0" would make the
# interactive debugger reachable from other machines if this is a standard
# Flask/Werkzeug run — confirm this only runs in a trusted development
# environment.
print("正在启动Flask应用: http://localhost:8084")
app.run(host="0.0.0.0", port=8084, debug=True)