citu_app.py 74 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724172517261727172817291730173117321733173417351736173717381739174017411742174317441745174617471748
  1. # 给dataops 对话助手返回结果
  2. from vanna.flask import VannaFlaskApp
  3. from core.vanna_llm_factory import create_vanna_instance
  4. from flask import request, jsonify
  5. import pandas as pd
  6. import common.result as result
  7. from datetime import datetime, timedelta
  8. from common.session_aware_cache import WebSessionAwareMemoryCache
  9. from app_config import API_MAX_RETURN_ROWS, ENABLE_RESULT_SUMMARY
  10. import re
  11. import chainlit as cl
  12. import json
  13. from flask import session # 添加session导入
  14. from common.redis_conversation_manager import RedisConversationManager # 添加Redis对话管理器导入
  15. from common.result import ( # 统一导入所有需要的响应函数
  16. bad_request_response, service_unavailable_response,
  17. agent_success_response, agent_error_response,
  18. internal_error_response, success_response,
  19. validation_failed_response
  20. )
  21. from app_config import ( # 添加Redis相关配置导入
  22. USER_MAX_CONVERSATIONS,
  23. CONVERSATION_CONTEXT_COUNT,
  24. DEFAULT_ANONYMOUS_USER
  25. )
  26. # 设置默认的最大返回行数
  27. DEFAULT_MAX_RETURN_ROWS = 200
  28. MAX_RETURN_ROWS = API_MAX_RETURN_ROWS if API_MAX_RETURN_ROWS is not None else DEFAULT_MAX_RETURN_ROWS
  29. vn = create_vanna_instance()
  30. # 创建带时间戳的缓存
  31. timestamped_cache = WebSessionAwareMemoryCache()
  32. # 实例化 VannaFlaskApp,使用自定义缓存
  33. app = VannaFlaskApp(
  34. vn,
  35. cache=timestamped_cache, # 使用带时间戳的缓存
  36. title="辞图智能数据问答平台",
  37. logo = "https://www.citupro.com/img/logo-black-2.png",
  38. subtitle="让 AI 为你写 SQL",
  39. chart=False,
  40. allow_llm_to_see_data=True,
  41. ask_results_correct=True,
  42. followup_questions=True,
  43. debug=True
  44. )
  45. # 创建Redis对话管理器实例
  46. redis_conversation_manager = RedisConversationManager()
  47. # 修改ask接口,支持前端传递session_id
  48. @app.flask_app.route('/api/v0/ask', methods=['POST'])
  49. def ask_full():
  50. req = request.get_json(force=True)
  51. question = req.get("question", None)
  52. browser_session_id = req.get("session_id", None) # 前端传递的会话ID
  53. if not question:
  54. from common.result import bad_request_response
  55. return jsonify(bad_request_response(
  56. response_text="缺少必需参数:question",
  57. missing_params=["question"]
  58. )), 400
  59. # 如果使用WebSessionAwareMemoryCache
  60. if hasattr(app.cache, 'generate_id_with_browser_session') and browser_session_id:
  61. # 这里需要修改vanna的ask方法来支持传递session_id
  62. # 或者预先调用generate_id来建立会话关联
  63. conversation_id = app.cache.generate_id_with_browser_session(
  64. question=question,
  65. browser_session_id=browser_session_id
  66. )
  67. try:
  68. sql, df, _ = vn.ask(
  69. question=question,
  70. print_results=False,
  71. visualize=False,
  72. allow_llm_to_see_data=True
  73. )
  74. # 关键:检查是否有LLM解释性文本(无法生成SQL的情况)
  75. if sql is None and hasattr(vn, 'last_llm_explanation') and vn.last_llm_explanation:
  76. # 在解释性文本末尾添加提示语
  77. explanation_message = vn.last_llm_explanation + "请尝试提问其它问题。"
  78. # 使用标准化错误响应
  79. from common.result import validation_failed_response
  80. return jsonify(validation_failed_response(
  81. response_text=explanation_message
  82. )), 422 # 修改HTTP状态码为422
  83. # 如果sql为None但没有解释性文本,返回通用错误
  84. if sql is None:
  85. from common.result import validation_failed_response
  86. return jsonify(validation_failed_response(
  87. response_text="无法生成SQL查询,请检查问题描述或数据表结构"
  88. )), 422
  89. # 处理返回数据 - 使用新的query_result结构
  90. query_result = {
  91. "rows": [],
  92. "columns": [],
  93. "row_count": 0,
  94. "is_limited": False,
  95. "total_row_count": 0
  96. }
  97. summary = None
  98. if isinstance(df, pd.DataFrame):
  99. query_result["columns"] = list(df.columns)
  100. if not df.empty:
  101. total_rows = len(df)
  102. limited_df = df.head(MAX_RETURN_ROWS)
  103. query_result["rows"] = limited_df.to_dict(orient="records")
  104. query_result["row_count"] = len(limited_df)
  105. query_result["total_row_count"] = total_rows
  106. query_result["is_limited"] = total_rows > MAX_RETURN_ROWS
  107. # 生成数据摘要(可通过配置控制,仅在有数据时生成)
  108. if ENABLE_RESULT_SUMMARY:
  109. try:
  110. summary = vn.generate_summary(question=question, df=df)
  111. print(f"[INFO] 成功生成摘要: {summary}")
  112. except Exception as e:
  113. print(f"[WARNING] 生成摘要失败: {str(e)}")
  114. summary = None
  115. # 构建返回数据
  116. response_data = {
  117. "sql": sql,
  118. "query_result": query_result,
  119. "conversation_id": conversation_id if 'conversation_id' in locals() else None,
  120. "session_id": browser_session_id
  121. }
  122. # 添加摘要(如果启用且生成成功)
  123. if ENABLE_RESULT_SUMMARY and summary is not None:
  124. response_data["summary"] = summary
  125. response_data["response"] = summary # 同时添加response字段
  126. from common.result import success_response
  127. return jsonify(success_response(
  128. response_text="查询执行完成" if summary is None else None,
  129. data=response_data
  130. ))
  131. except Exception as e:
  132. print(f"[ERROR] ask_full执行失败: {str(e)}")
  133. # 即使发生异常,也检查是否有业务层面的解释
  134. if hasattr(vn, 'last_llm_explanation') and vn.last_llm_explanation:
  135. # 在解释性文本末尾添加提示语
  136. explanation_message = vn.last_llm_explanation + "请尝试提问其它问题。"
  137. from common.result import validation_failed_response
  138. return jsonify(validation_failed_response(
  139. response_text=explanation_message
  140. )), 422
  141. else:
  142. # 技术错误,使用500错误码
  143. from common.result import internal_error_response
  144. return jsonify(internal_error_response(
  145. response_text="查询处理失败,请稍后重试"
  146. )), 500
  147. @app.flask_app.route('/api/v0/citu_run_sql', methods=['POST'])
  148. def citu_run_sql():
  149. req = request.get_json(force=True)
  150. sql = req.get('sql')
  151. if not sql:
  152. from common.result import bad_request_response
  153. return jsonify(bad_request_response(
  154. response_text="缺少必需参数:sql",
  155. missing_params=["sql"]
  156. )), 400
  157. try:
  158. df = vn.run_sql(sql)
  159. # 处理返回数据 - 使用新的query_result结构
  160. query_result = {
  161. "rows": [],
  162. "columns": [],
  163. "row_count": 0,
  164. "is_limited": False,
  165. "total_row_count": 0
  166. }
  167. if isinstance(df, pd.DataFrame):
  168. query_result["columns"] = list(df.columns)
  169. if not df.empty:
  170. total_rows = len(df)
  171. limited_df = df.head(MAX_RETURN_ROWS)
  172. query_result["rows"] = limited_df.to_dict(orient="records")
  173. query_result["row_count"] = len(limited_df)
  174. query_result["total_row_count"] = total_rows
  175. query_result["is_limited"] = total_rows > MAX_RETURN_ROWS
  176. from common.result import success_response
  177. return jsonify(success_response(
  178. response_text=f"SQL执行完成,共返回 {query_result['total_row_count']} 条记录" +
  179. (f",已限制显示前 {MAX_RETURN_ROWS} 条" if query_result["is_limited"] else ""),
  180. data={
  181. "sql": sql,
  182. "query_result": query_result
  183. }
  184. ))
  185. except Exception as e:
  186. print(f"[ERROR] citu_run_sql执行失败: {str(e)}")
  187. from common.result import internal_error_response
  188. return jsonify(internal_error_response(
  189. response_text=f"SQL执行失败,请检查SQL语句是否正确"
  190. )), 500
  191. @app.flask_app.route('/api/v0/ask_cached', methods=['POST'])
  192. def ask_cached():
  193. """
  194. 带缓存功能的智能查询接口
  195. 支持会话管理和结果缓存,提高查询效率
  196. """
  197. req = request.get_json(force=True)
  198. question = req.get("question", None)
  199. browser_session_id = req.get("session_id", None)
  200. if not question:
  201. from common.result import bad_request_response
  202. return jsonify(bad_request_response(
  203. response_text="缺少必需参数:question",
  204. missing_params=["question"]
  205. )), 400
  206. try:
  207. # 生成conversation_id
  208. # 调试:查看generate_id的实际行为
  209. print(f"[DEBUG] 输入问题: '{question}'")
  210. conversation_id = app.cache.generate_id(question=question)
  211. print(f"[DEBUG] 生成的conversation_id: {conversation_id}")
  212. # 再次用相同问题测试
  213. conversation_id2 = app.cache.generate_id(question=question)
  214. print(f"[DEBUG] 再次生成的conversation_id: {conversation_id2}")
  215. print(f"[DEBUG] 两次ID是否相同: {conversation_id == conversation_id2}")
  216. # 检查缓存
  217. cached_sql = app.cache.get(id=conversation_id, field="sql")
  218. if cached_sql is not None:
  219. # 缓存命中
  220. print(f"[CACHE HIT] 使用缓存结果: {conversation_id}")
  221. sql = cached_sql
  222. df = app.cache.get(id=conversation_id, field="df")
  223. summary = app.cache.get(id=conversation_id, field="summary")
  224. else:
  225. # 缓存未命中,执行新查询
  226. print(f"[CACHE MISS] 执行新查询: {conversation_id}")
  227. sql, df, _ = vn.ask(
  228. question=question,
  229. print_results=False,
  230. visualize=False,
  231. allow_llm_to_see_data=True
  232. )
  233. # 检查是否有LLM解释性文本(无法生成SQL的情况)
  234. if sql is None and hasattr(vn, 'last_llm_explanation') and vn.last_llm_explanation:
  235. # 在解释性文本末尾添加提示语
  236. explanation_message = vn.last_llm_explanation + "请尝试用其它方式提问。"
  237. from common.result import validation_failed_response
  238. return jsonify(validation_failed_response(
  239. response_text=explanation_message
  240. )), 422
  241. # 如果sql为None但没有解释性文本,返回通用错误
  242. if sql is None:
  243. from common.result import validation_failed_response
  244. return jsonify(validation_failed_response(
  245. response_text="无法生成SQL查询,请检查问题描述或数据表结构"
  246. )), 422
  247. # 缓存结果
  248. app.cache.set(id=conversation_id, field="question", value=question)
  249. app.cache.set(id=conversation_id, field="sql", value=sql)
  250. app.cache.set(id=conversation_id, field="df", value=df)
  251. # 生成并缓存摘要(可通过配置控制,仅在有数据时生成)
  252. summary = None
  253. if ENABLE_RESULT_SUMMARY and isinstance(df, pd.DataFrame) and not df.empty:
  254. try:
  255. summary = vn.generate_summary(question=question, df=df)
  256. print(f"[INFO] 成功生成摘要: {summary}")
  257. except Exception as e:
  258. print(f"[WARNING] 生成摘要失败: {str(e)}")
  259. summary = None
  260. app.cache.set(id=conversation_id, field="summary", value=summary)
  261. # 处理返回数据 - 使用新的query_result结构
  262. query_result = {
  263. "rows": [],
  264. "columns": [],
  265. "row_count": 0,
  266. "is_limited": False,
  267. "total_row_count": 0
  268. }
  269. if isinstance(df, pd.DataFrame):
  270. query_result["columns"] = list(df.columns)
  271. if not df.empty:
  272. total_rows = len(df)
  273. limited_df = df.head(MAX_RETURN_ROWS)
  274. query_result["rows"] = limited_df.to_dict(orient="records")
  275. query_result["row_count"] = len(limited_df)
  276. query_result["total_row_count"] = total_rows
  277. query_result["is_limited"] = total_rows > MAX_RETURN_ROWS
  278. # 构建返回数据
  279. response_data = {
  280. "sql": sql,
  281. "query_result": query_result,
  282. "conversation_id": conversation_id,
  283. "session_id": browser_session_id,
  284. "cached": cached_sql is not None # 标识是否来自缓存
  285. }
  286. # 添加摘要(如果启用且生成成功)
  287. if ENABLE_RESULT_SUMMARY and summary is not None:
  288. response_data["summary"] = summary
  289. response_data["response"] = summary # 同时添加response字段
  290. from common.result import success_response
  291. return jsonify(success_response(
  292. response_text="查询执行完成" if summary is None else None,
  293. data=response_data
  294. ))
  295. except Exception as e:
  296. print(f"[ERROR] ask_cached执行失败: {str(e)}")
  297. from common.result import internal_error_response
  298. return jsonify(internal_error_response(
  299. response_text="查询处理失败,请稍后重试"
  300. )), 500
  301. @app.flask_app.route('/api/v0/citu_train_question_sql', methods=['POST'])
  302. def citu_train_question_sql():
  303. """
  304. 训练问题-SQL对接口
  305. 此API将接收的question/sql pair写入到training库中,用于训练和改进AI模型。
  306. 支持仅传入SQL或同时传入问题和SQL进行训练。
  307. Args:
  308. question (str, optional): 用户问题
  309. sql (str, required): 对应的SQL查询语句
  310. Returns:
  311. JSON: 包含训练ID和成功消息的响应
  312. """
  313. try:
  314. req = request.get_json(force=True)
  315. question = req.get('question')
  316. sql = req.get('sql')
  317. if not sql:
  318. from common.result import bad_request_response
  319. return jsonify(bad_request_response(
  320. response_text="缺少必需参数:sql",
  321. missing_params=["sql"]
  322. )), 400
  323. # 正确的调用方式:同时传递question和sql
  324. if question:
  325. training_id = vn.train(question=question, sql=sql)
  326. print(f"训练成功,训练ID为:{training_id},问题:{question},SQL:{sql}")
  327. else:
  328. training_id = vn.train(sql=sql)
  329. print(f"训练成功,训练ID为:{training_id},SQL:{sql}")
  330. from common.result import success_response
  331. return jsonify(success_response(
  332. response_text="问题-SQL对训练成功",
  333. data={
  334. "training_id": training_id,
  335. "message": "Question-SQL pair trained successfully"
  336. }
  337. ))
  338. except Exception as e:
  339. from common.result import internal_error_response
  340. return jsonify(internal_error_response(
  341. response_text="训练失败,请稍后重试"
  342. )), 500
  343. # ============ LangGraph Agent 集成 ============
  344. # 全局Agent实例(单例模式)
  345. citu_langraph_agent = None
  346. def get_citu_langraph_agent():
  347. """获取LangGraph Agent实例(懒加载)"""
  348. global citu_langraph_agent
  349. if citu_langraph_agent is None:
  350. try:
  351. from agent.citu_agent import CituLangGraphAgent
  352. print("[CITU_APP] 开始创建LangGraph Agent实例...")
  353. citu_langraph_agent = CituLangGraphAgent()
  354. print("[CITU_APP] LangGraph Agent实例创建成功")
  355. except ImportError as e:
  356. print(f"[CRITICAL] Agent模块导入失败: {str(e)}")
  357. print("[CRITICAL] 请检查agent模块是否存在以及依赖是否正确安装")
  358. raise Exception(f"Agent模块导入失败: {str(e)}")
  359. except Exception as e:
  360. print(f"[CRITICAL] LangGraph Agent实例创建失败: {str(e)}")
  361. print(f"[CRITICAL] 错误类型: {type(e).__name__}")
  362. # 提供更有用的错误信息
  363. if "config" in str(e).lower():
  364. print("[CRITICAL] 可能是配置文件问题,请检查配置")
  365. elif "llm" in str(e).lower():
  366. print("[CRITICAL] 可能是LLM连接问题,请检查LLM配置")
  367. elif "tool" in str(e).lower():
  368. print("[CRITICAL] 可能是工具加载问题,请检查工具模块")
  369. raise Exception(f"Agent初始化失败: {str(e)}")
  370. return citu_langraph_agent
  371. @app.flask_app.route('/api/v0/ask_agent', methods=['POST'])
  372. def ask_agent():
  373. """
  374. 支持对话上下文的ask_agent API - 修正版
  375. """
  376. req = request.get_json(force=True)
  377. question = req.get("question", None)
  378. browser_session_id = req.get("session_id", None)
  379. # 新增参数解析
  380. user_id_input = req.get("user_id", None)
  381. conversation_id_input = req.get("conversation_id", None)
  382. continue_conversation = req.get("continue_conversation", False)
  383. # 新增:路由模式参数解析和验证
  384. api_routing_mode = req.get("routing_mode", None)
  385. VALID_ROUTING_MODES = ["database_direct", "chat_direct", "hybrid", "llm_only"]
  386. if not question:
  387. return jsonify(bad_request_response(
  388. response_text="缺少必需参数:question",
  389. missing_params=["question"]
  390. )), 400
  391. # 验证routing_mode参数
  392. if api_routing_mode and api_routing_mode not in VALID_ROUTING_MODES:
  393. return jsonify(bad_request_response(
  394. response_text=f"无效的routing_mode参数值: {api_routing_mode},支持的值: {VALID_ROUTING_MODES}",
  395. invalid_params=["routing_mode"]
  396. )), 400
  397. try:
  398. # 1. 获取登录用户ID(修正:在函数中获取session信息)
  399. login_user_id = session.get('user_id') if 'user_id' in session else None
  400. # 2. 智能ID解析(修正:传入登录用户ID)
  401. user_id = redis_conversation_manager.resolve_user_id(
  402. user_id_input, browser_session_id, request.remote_addr, login_user_id
  403. )
  404. conversation_id, conversation_status = redis_conversation_manager.resolve_conversation_id(
  405. user_id, conversation_id_input, continue_conversation
  406. )
  407. # 3. 获取上下文和上下文类型(提前到缓存检查之前)
  408. context = redis_conversation_manager.get_context(conversation_id)
  409. # 获取上下文类型:从最后一条助手消息的metadata中获取类型
  410. context_type = None
  411. if context:
  412. try:
  413. # 获取最后一条助手消息的metadata
  414. messages = redis_conversation_manager.get_messages(conversation_id, limit=10)
  415. for message in reversed(messages): # 从最新的开始找
  416. if message.get("role") == "assistant":
  417. metadata = message.get("metadata", {})
  418. context_type = metadata.get("type")
  419. if context_type:
  420. print(f"[AGENT_API] 检测到上下文类型: {context_type}")
  421. break
  422. except Exception as e:
  423. print(f"[WARNING] 获取上下文类型失败: {str(e)}")
  424. # 4. 检查缓存(新逻辑:放宽使用条件,严控存储条件)
  425. cached_answer = redis_conversation_manager.get_cached_answer(question, context)
  426. if cached_answer:
  427. print(f"[AGENT_API] 使用缓存答案")
  428. # 确定缓存答案的助手回复内容(使用与非缓存相同的优先级逻辑)
  429. cached_response_type = cached_answer.get("type", "UNKNOWN")
  430. if cached_response_type == "DATABASE":
  431. # DATABASE类型:按优先级选择内容
  432. if cached_answer.get("response"):
  433. # 优先级1:错误或解释性回复(如SQL生成失败)
  434. assistant_response = cached_answer.get("response")
  435. elif cached_answer.get("summary"):
  436. # 优先级2:查询成功的摘要
  437. assistant_response = cached_answer.get("summary")
  438. elif cached_answer.get("query_result"):
  439. # 优先级3:构造简单描述
  440. query_result = cached_answer.get("query_result")
  441. row_count = query_result.get("row_count", 0)
  442. assistant_response = f"查询执行完成,共返回 {row_count} 条记录。"
  443. else:
  444. # 异常情况
  445. assistant_response = "数据库查询已处理。"
  446. else:
  447. # CHAT类型:直接使用response
  448. assistant_response = cached_answer.get("response", "")
  449. # 更新对话历史
  450. redis_conversation_manager.save_message(conversation_id, "user", question)
  451. redis_conversation_manager.save_message(
  452. conversation_id, "assistant",
  453. assistant_response,
  454. metadata={"from_cache": True}
  455. )
  456. # 添加对话信息到缓存结果
  457. cached_answer["conversation_id"] = conversation_id
  458. cached_answer["user_id"] = user_id
  459. cached_answer["from_cache"] = True
  460. cached_answer.update(conversation_status)
  461. # 使用agent_success_response返回标准格式
  462. return jsonify(agent_success_response(
  463. response_type=cached_answer.get("type", "UNKNOWN"),
  464. response=cached_answer.get("response", ""), # 修正:使用response而不是response_text
  465. sql=cached_answer.get("sql"),
  466. query_result=cached_answer.get("query_result"),
  467. summary=cached_answer.get("summary"),
  468. session_id=browser_session_id,
  469. execution_path=cached_answer.get("execution_path", []),
  470. classification_info=cached_answer.get("classification_info", {}),
  471. conversation_id=conversation_id,
  472. user_id=user_id,
  473. is_guest_user=(user_id == DEFAULT_ANONYMOUS_USER),
  474. context_used=bool(context),
  475. from_cache=True,
  476. conversation_status=conversation_status["status"],
  477. conversation_message=conversation_status["message"],
  478. requested_conversation_id=conversation_status.get("requested_id")
  479. ))
  480. # 5. 保存用户消息
  481. redis_conversation_manager.save_message(conversation_id, "user", question)
  482. # 6. 构建带上下文的问题
  483. if context:
  484. enhanced_question = f"\n[CONTEXT]\n{context}\n\n[CURRENT]\n{question}"
  485. print(f"[AGENT_API] 使用上下文,长度: {len(context)}字符")
  486. else:
  487. enhanced_question = question
  488. print(f"[AGENT_API] 新对话,无上下文")
  489. # 7. 确定最终使用的路由模式(优先级逻辑)
  490. if api_routing_mode:
  491. # API传了参数,优先使用
  492. effective_routing_mode = api_routing_mode
  493. print(f"[AGENT_API] 使用API指定的路由模式: {effective_routing_mode}")
  494. else:
  495. # API没传参数,使用配置文件
  496. try:
  497. from app_config import QUESTION_ROUTING_MODE
  498. effective_routing_mode = QUESTION_ROUTING_MODE
  499. print(f"[AGENT_API] 使用配置文件路由模式: {effective_routing_mode}")
  500. except ImportError:
  501. effective_routing_mode = "hybrid"
  502. print(f"[AGENT_API] 配置文件读取失败,使用默认路由模式: {effective_routing_mode}")
  503. # 8. 现有Agent处理逻辑(修改为传递路由模式)
  504. try:
  505. agent = get_citu_langraph_agent()
  506. except Exception as e:
  507. print(f"[CRITICAL] Agent初始化失败: {str(e)}")
  508. return jsonify(service_unavailable_response(
  509. response_text="AI服务暂时不可用,请稍后重试",
  510. can_retry=True
  511. )), 503
  512. agent_result = agent.process_question(
  513. question=enhanced_question, # 使用增强后的问题
  514. session_id=browser_session_id,
  515. context_type=context_type, # 传递上下文类型
  516. routing_mode=effective_routing_mode # 新增:传递路由模式
  517. )
  518. # 8. 处理Agent结果
  519. if agent_result.get("success", False):
  520. # 修正:直接从agent_result获取字段,因为它就是final_response
  521. response_type = agent_result.get("type", "UNKNOWN")
  522. response_text = agent_result.get("response", "")
  523. sql = agent_result.get("sql")
  524. query_result = agent_result.get("query_result")
  525. summary = agent_result.get("summary")
  526. execution_path = agent_result.get("execution_path", [])
  527. classification_info = agent_result.get("classification_info", {})
  528. # 确定助手回复内容的优先级
  529. if response_type == "DATABASE":
  530. # DATABASE类型:按优先级选择内容
  531. if response_text:
  532. # 优先级1:错误或解释性回复(如SQL生成失败)
  533. assistant_response = response_text
  534. elif summary:
  535. # 优先级2:查询成功的摘要
  536. assistant_response = summary
  537. elif query_result:
  538. # 优先级3:构造简单描述
  539. row_count = query_result.get("row_count", 0)
  540. assistant_response = f"查询执行完成,共返回 {row_count} 条记录。"
  541. else:
  542. # 异常情况
  543. assistant_response = "数据库查询已处理。"
  544. else:
  545. # CHAT类型:直接使用response
  546. assistant_response = response_text
  547. # 保存助手回复
  548. redis_conversation_manager.save_message(
  549. conversation_id, "assistant", assistant_response,
  550. metadata={
  551. "type": response_type,
  552. "sql": sql,
  553. "execution_path": execution_path
  554. }
  555. )
  556. # 缓存成功的答案(新逻辑:只缓存无上下文的问答)
  557. # 直接缓存agent_result,它已经包含所有需要的字段
  558. redis_conversation_manager.cache_answer(question, agent_result, context)
  559. # 使用agent_success_response的正确方式
  560. return jsonify(agent_success_response(
  561. response_type=response_type,
  562. response=response_text, # 修正:使用response而不是response_text
  563. sql=sql,
  564. query_result=query_result,
  565. summary=summary,
  566. session_id=browser_session_id,
  567. execution_path=execution_path,
  568. classification_info=classification_info,
  569. conversation_id=conversation_id,
  570. user_id=user_id,
  571. is_guest_user=(user_id == DEFAULT_ANONYMOUS_USER),
  572. context_used=bool(context),
  573. from_cache=False,
  574. conversation_status=conversation_status["status"],
  575. conversation_message=conversation_status["message"],
  576. requested_conversation_id=conversation_status.get("requested_id"),
  577. routing_mode_used=effective_routing_mode, # 新增:实际使用的路由模式
  578. routing_mode_source="api" if api_routing_mode else "config" # 新增:路由模式来源
  579. ))
  580. else:
  581. # 错误处理(修正:确保使用现有的错误响应格式)
  582. error_message = agent_result.get("error", "Agent处理失败")
  583. error_code = agent_result.get("error_code", 500)
  584. return jsonify(agent_error_response(
  585. response_text=error_message,
  586. error_type="agent_processing_failed",
  587. code=error_code,
  588. session_id=browser_session_id,
  589. conversation_id=conversation_id,
  590. user_id=user_id
  591. )), error_code
  592. except Exception as e:
  593. print(f"[ERROR] ask_agent执行失败: {str(e)}")
  594. return jsonify(internal_error_response(
  595. response_text="查询处理失败,请稍后重试"
  596. )), 500
  597. @app.flask_app.route('/api/v0/agent_health', methods=['GET'])
  598. def agent_health():
  599. """
  600. Agent健康检查接口
  601. 响应格式:
  602. {
  603. "success": true/false,
  604. "code": 200/503,
  605. "message": "healthy/degraded/unhealthy",
  606. "data": {
  607. "status": "healthy/degraded/unhealthy",
  608. "test_result": true/false,
  609. "workflow_compiled": true/false,
  610. "tools_count": 4,
  611. "message": "详细信息",
  612. "timestamp": "2024-01-01T12:00:00",
  613. "checks": {
  614. "agent_creation": true/false,
  615. "tools_import": true/false,
  616. "llm_connection": true/false,
  617. "classifier_ready": true/false
  618. }
  619. }
  620. }
  621. """
  622. try:
  623. # 基础健康检查
  624. health_data = {
  625. "status": "unknown",
  626. "test_result": False,
  627. "workflow_compiled": False,
  628. "tools_count": 0,
  629. "message": "",
  630. "timestamp": datetime.now().isoformat(),
  631. "checks": {
  632. "agent_creation": False,
  633. "tools_import": False,
  634. "llm_connection": False,
  635. "classifier_ready": False
  636. }
  637. }
  638. # 检查1: Agent创建
  639. try:
  640. agent = get_citu_langraph_agent()
  641. health_data["checks"]["agent_creation"] = True
  642. health_data["workflow_compiled"] = agent.workflow is not None
  643. health_data["tools_count"] = len(agent.tools) if hasattr(agent, 'tools') else 0
  644. except Exception as e:
  645. health_data["message"] = f"Agent创建失败: {str(e)}"
  646. from common.result import health_error_response
  647. return jsonify(health_error_response(
  648. status="unhealthy",
  649. **health_data
  650. )), 503
  651. # 检查2: 工具导入
  652. try:
  653. from agent.tools import TOOLS
  654. health_data["checks"]["tools_import"] = len(TOOLS) > 0
  655. except Exception as e:
  656. health_data["message"] = f"工具导入失败: {str(e)}"
  657. # 检查3: LLM连接(简单测试)
  658. try:
  659. from agent.utils import get_compatible_llm
  660. llm = get_compatible_llm()
  661. health_data["checks"]["llm_connection"] = llm is not None
  662. except Exception as e:
  663. health_data["message"] = f"LLM连接失败: {str(e)}"
  664. # 检查4: 分类器准备
  665. try:
  666. from agent.classifier import QuestionClassifier
  667. classifier = QuestionClassifier()
  668. health_data["checks"]["classifier_ready"] = True
  669. except Exception as e:
  670. health_data["message"] = f"分类器失败: {str(e)}"
  671. # 检查5: 完整流程测试(可选)
  672. try:
  673. if all(health_data["checks"].values()):
  674. test_result = agent.health_check()
  675. health_data["test_result"] = test_result.get("status") == "healthy"
  676. health_data["status"] = test_result.get("status", "unknown")
  677. health_data["message"] = test_result.get("message", "健康检查完成")
  678. else:
  679. health_data["status"] = "degraded"
  680. health_data["message"] = "部分组件异常"
  681. except Exception as e:
  682. health_data["status"] = "degraded"
  683. health_data["message"] = f"完整测试失败: {str(e)}"
  684. # 根据状态返回相应的HTTP代码 - 使用标准化健康检查响应
  685. from common.result import health_success_response, health_error_response
  686. if health_data["status"] == "healthy":
  687. return jsonify(health_success_response(**health_data))
  688. elif health_data["status"] == "degraded":
  689. return jsonify(health_error_response(status="degraded", **health_data)), 503
  690. else:
  691. return jsonify(health_error_response(status="unhealthy", **health_data)), 503
  692. except Exception as e:
  693. print(f"[ERROR] 健康检查异常: {str(e)}")
  694. from common.result import internal_error_response
  695. return jsonify(internal_error_response(
  696. response_text="健康检查失败,请稍后重试"
  697. )), 500
  698. # ==================== 日常管理API ====================
  699. @app.flask_app.route('/api/v0/cache_overview', methods=['GET'])
  700. def cache_overview():
  701. """日常管理:轻量概览 - 合并原cache_inspect的核心功能"""
  702. try:
  703. cache = app.cache
  704. result_data = {
  705. 'overview_summary': {
  706. 'total_conversations': 0,
  707. 'total_sessions': 0,
  708. 'query_time': datetime.now().isoformat()
  709. },
  710. 'recent_conversations': [], # 最近的对话
  711. 'session_summary': [] # 会话摘要
  712. }
  713. if hasattr(cache, 'cache') and isinstance(cache.cache, dict):
  714. result_data['overview_summary']['total_conversations'] = len(cache.cache)
  715. # 获取会话信息
  716. if hasattr(cache, 'get_all_sessions'):
  717. all_sessions = cache.get_all_sessions()
  718. result_data['overview_summary']['total_sessions'] = len(all_sessions)
  719. # 会话摘要(按最近活动排序)
  720. session_list = []
  721. for session_id, session_data in all_sessions.items():
  722. session_summary = {
  723. 'session_id': session_id,
  724. 'start_time': session_data['start_time'].isoformat(),
  725. 'conversation_count': session_data.get('conversation_count', 0),
  726. 'duration_seconds': session_data.get('session_duration_seconds', 0),
  727. 'last_activity': session_data.get('last_activity', session_data['start_time']).isoformat(),
  728. 'is_active': (datetime.now() - session_data.get('last_activity', session_data['start_time'])).total_seconds() < 1800 # 30分钟内活跃
  729. }
  730. session_list.append(session_summary)
  731. # 按最后活动时间排序
  732. session_list.sort(key=lambda x: x['last_activity'], reverse=True)
  733. result_data['session_summary'] = session_list
  734. # 最近的对话(最多显示10个)
  735. conversation_list = []
  736. for conversation_id, conversation_data in cache.cache.items():
  737. conversation_start_time = cache.conversation_start_times.get(conversation_id)
  738. conversation_info = {
  739. 'conversation_id': conversation_id,
  740. 'conversation_start_time': conversation_start_time.isoformat() if conversation_start_time else None,
  741. 'session_id': cache.conversation_to_session.get(conversation_id),
  742. 'has_question': 'question' in conversation_data,
  743. 'has_sql': 'sql' in conversation_data,
  744. 'has_data': 'df' in conversation_data and conversation_data['df'] is not None,
  745. 'question_preview': conversation_data.get('question', '')[:80] + '...' if len(conversation_data.get('question', '')) > 80 else conversation_data.get('question', ''),
  746. }
  747. # 计算对话持续时间
  748. if conversation_start_time:
  749. duration = datetime.now() - conversation_start_time
  750. conversation_info['conversation_duration_seconds'] = duration.total_seconds()
  751. conversation_list.append(conversation_info)
  752. # 按对话开始时间排序,显示最新的10个
  753. conversation_list.sort(key=lambda x: x['conversation_start_time'] or '', reverse=True)
  754. result_data['recent_conversations'] = conversation_list[:10]
  755. from common.result import success_response
  756. return jsonify(success_response(
  757. response_text="缓存概览查询完成",
  758. data=result_data
  759. ))
  760. except Exception as e:
  761. from common.result import internal_error_response
  762. return jsonify(internal_error_response(
  763. response_text="获取缓存概览失败,请稍后重试"
  764. )), 500
  765. @app.flask_app.route('/api/v0/cache_stats', methods=['GET'])
  766. def cache_stats():
  767. """日常管理:统计信息 - 合并原session_stats和cache_stats功能"""
  768. try:
  769. cache = app.cache
  770. current_time = datetime.now()
  771. stats = {
  772. 'basic_stats': {
  773. 'total_sessions': len(getattr(cache, 'session_info', {})),
  774. 'total_conversations': len(getattr(cache, 'cache', {})),
  775. 'active_sessions': 0, # 最近30分钟有活动
  776. 'average_conversations_per_session': 0
  777. },
  778. 'time_distribution': {
  779. 'sessions': {
  780. 'last_1_hour': 0,
  781. 'last_6_hours': 0,
  782. 'last_24_hours': 0,
  783. 'last_7_days': 0,
  784. 'older': 0
  785. },
  786. 'conversations': {
  787. 'last_1_hour': 0,
  788. 'last_6_hours': 0,
  789. 'last_24_hours': 0,
  790. 'last_7_days': 0,
  791. 'older': 0
  792. }
  793. },
  794. 'session_details': [],
  795. 'time_ranges': {
  796. 'oldest_session': None,
  797. 'newest_session': None,
  798. 'oldest_conversation': None,
  799. 'newest_conversation': None
  800. }
  801. }
  802. # 会话统计
  803. if hasattr(cache, 'session_info'):
  804. session_times = []
  805. total_conversations = 0
  806. for session_id, session_data in cache.session_info.items():
  807. start_time = session_data['start_time']
  808. session_times.append(start_time)
  809. conversation_count = len(session_data.get('conversations', []))
  810. total_conversations += conversation_count
  811. # 检查活跃状态
  812. last_activity = session_data.get('last_activity', session_data['start_time'])
  813. if (current_time - last_activity).total_seconds() < 1800:
  814. stats['basic_stats']['active_sessions'] += 1
  815. # 时间分布统计
  816. age_hours = (current_time - start_time).total_seconds() / 3600
  817. if age_hours <= 1:
  818. stats['time_distribution']['sessions']['last_1_hour'] += 1
  819. elif age_hours <= 6:
  820. stats['time_distribution']['sessions']['last_6_hours'] += 1
  821. elif age_hours <= 24:
  822. stats['time_distribution']['sessions']['last_24_hours'] += 1
  823. elif age_hours <= 168: # 7 days
  824. stats['time_distribution']['sessions']['last_7_days'] += 1
  825. else:
  826. stats['time_distribution']['sessions']['older'] += 1
  827. # 会话详细信息
  828. session_duration = current_time - start_time
  829. stats['session_details'].append({
  830. 'session_id': session_id,
  831. 'start_time': start_time.isoformat(),
  832. 'last_activity': last_activity.isoformat(),
  833. 'conversation_count': conversation_count,
  834. 'duration_seconds': session_duration.total_seconds(),
  835. 'duration_formatted': str(session_duration),
  836. 'is_active': (current_time - last_activity).total_seconds() < 1800,
  837. 'browser_session_id': session_data.get('browser_session_id')
  838. })
  839. # 计算平均值
  840. if len(cache.session_info) > 0:
  841. stats['basic_stats']['average_conversations_per_session'] = total_conversations / len(cache.session_info)
  842. # 时间范围
  843. if session_times:
  844. stats['time_ranges']['oldest_session'] = min(session_times).isoformat()
  845. stats['time_ranges']['newest_session'] = max(session_times).isoformat()
  846. # 对话统计
  847. if hasattr(cache, 'conversation_start_times'):
  848. conversation_times = []
  849. for conv_time in cache.conversation_start_times.values():
  850. conversation_times.append(conv_time)
  851. age_hours = (current_time - conv_time).total_seconds() / 3600
  852. if age_hours <= 1:
  853. stats['time_distribution']['conversations']['last_1_hour'] += 1
  854. elif age_hours <= 6:
  855. stats['time_distribution']['conversations']['last_6_hours'] += 1
  856. elif age_hours <= 24:
  857. stats['time_distribution']['conversations']['last_24_hours'] += 1
  858. elif age_hours <= 168:
  859. stats['time_distribution']['conversations']['last_7_days'] += 1
  860. else:
  861. stats['time_distribution']['conversations']['older'] += 1
  862. if conversation_times:
  863. stats['time_ranges']['oldest_conversation'] = min(conversation_times).isoformat()
  864. stats['time_ranges']['newest_conversation'] = max(conversation_times).isoformat()
  865. # 按最近活动排序会话详情
  866. stats['session_details'].sort(key=lambda x: x['last_activity'], reverse=True)
  867. from common.result import success_response
  868. return jsonify(success_response(
  869. response_text="缓存统计信息查询完成",
  870. data=stats
  871. ))
  872. except Exception as e:
  873. from common.result import internal_error_response
  874. return jsonify(internal_error_response(
  875. response_text="获取缓存统计失败,请稍后重试"
  876. )), 500
  877. # ==================== 高级功能API ====================
  878. @app.flask_app.route('/api/v0/cache_export', methods=['GET'])
  879. def cache_export():
  880. """高级功能:完整导出 - 保持原cache_raw_export的完整功能"""
  881. try:
  882. cache = app.cache
  883. # 验证缓存的实际结构
  884. if not hasattr(cache, 'cache'):
  885. from common.result import internal_error_response
  886. return jsonify(internal_error_response(
  887. response_text="缓存对象结构异常,请联系系统管理员"
  888. )), 500
  889. if not isinstance(cache.cache, dict):
  890. from common.result import internal_error_response
  891. return jsonify(internal_error_response(
  892. response_text="缓存数据类型异常,请联系系统管理员"
  893. )), 500
  894. # 定义JSON序列化辅助函数
  895. def make_json_serializable(obj):
  896. """将对象转换为JSON可序列化的格式"""
  897. if obj is None:
  898. return None
  899. elif isinstance(obj, (str, int, float, bool)):
  900. return obj
  901. elif isinstance(obj, (list, tuple)):
  902. return [make_json_serializable(item) for item in obj]
  903. elif isinstance(obj, dict):
  904. return {str(k): make_json_serializable(v) for k, v in obj.items()}
  905. elif hasattr(obj, 'isoformat'): # datetime objects
  906. return obj.isoformat()
  907. elif hasattr(obj, 'item'): # numpy scalars
  908. return obj.item()
  909. elif hasattr(obj, 'tolist'): # numpy arrays
  910. return obj.tolist()
  911. elif hasattr(obj, '__dict__'): # pandas dtypes and other objects
  912. return str(obj)
  913. else:
  914. return str(obj)
  915. # 获取完整的原始缓存数据
  916. raw_cache = cache.cache
  917. # 获取会话和对话时间信息
  918. conversation_times = getattr(cache, 'conversation_start_times', {})
  919. session_info = getattr(cache, 'session_info', {})
  920. conversation_to_session = getattr(cache, 'conversation_to_session', {})
  921. export_data = {
  922. 'export_metadata': {
  923. 'export_time': datetime.now().isoformat(),
  924. 'total_conversations': len(raw_cache),
  925. 'total_sessions': len(session_info),
  926. 'cache_type': type(cache).__name__,
  927. 'cache_object_info': str(cache),
  928. 'has_session_times': bool(session_info),
  929. 'has_conversation_times': bool(conversation_times)
  930. },
  931. 'session_info': {
  932. session_id: {
  933. 'start_time': session_data['start_time'].isoformat(),
  934. 'last_activity': session_data.get('last_activity', session_data['start_time']).isoformat(),
  935. 'conversations': session_data['conversations'],
  936. 'conversation_count': len(session_data['conversations']),
  937. 'browser_session_id': session_data.get('browser_session_id'),
  938. 'user_info': session_data.get('user_info', {})
  939. }
  940. for session_id, session_data in session_info.items()
  941. },
  942. 'conversation_times': {
  943. conversation_id: start_time.isoformat()
  944. for conversation_id, start_time in conversation_times.items()
  945. },
  946. 'conversation_to_session_mapping': conversation_to_session,
  947. 'conversations': {}
  948. }
  949. # 处理每个对话的完整数据
  950. for conversation_id, conversation_data in raw_cache.items():
  951. # 获取时间信息
  952. conversation_start_time = conversation_times.get(conversation_id)
  953. session_id = conversation_to_session.get(conversation_id)
  954. session_start_time = None
  955. if session_id and session_id in session_info:
  956. session_start_time = session_info[session_id]['start_time']
  957. processed_conversation = {
  958. 'conversation_id': conversation_id,
  959. 'conversation_start_time': conversation_start_time.isoformat() if conversation_start_time else None,
  960. 'session_id': session_id,
  961. 'session_start_time': session_start_time.isoformat() if session_start_time else None,
  962. 'field_count': len(conversation_data),
  963. 'fields': {}
  964. }
  965. # 添加时间计算
  966. if conversation_start_time:
  967. conversation_duration = datetime.now() - conversation_start_time
  968. processed_conversation['conversation_duration_seconds'] = conversation_duration.total_seconds()
  969. processed_conversation['conversation_duration_formatted'] = str(conversation_duration)
  970. if session_start_time:
  971. session_duration = datetime.now() - session_start_time
  972. processed_conversation['session_duration_seconds'] = session_duration.total_seconds()
  973. processed_conversation['session_duration_formatted'] = str(session_duration)
  974. # 处理每个字段,确保JSON序列化安全
  975. for field_name, field_value in conversation_data.items():
  976. field_info = {
  977. 'field_name': field_name,
  978. 'data_type': type(field_value).__name__,
  979. 'is_none': field_value is None
  980. }
  981. try:
  982. if field_value is None:
  983. field_info['value'] = None
  984. elif field_name in ['conversation_start_time', 'session_start_time']:
  985. # 处理时间字段
  986. field_info['content'] = make_json_serializable(field_value)
  987. elif field_name == 'df' and field_value is not None:
  988. # DataFrame的安全处理
  989. if hasattr(field_value, 'to_dict'):
  990. # 安全地处理dtypes
  991. try:
  992. dtypes_dict = {}
  993. for col, dtype in field_value.dtypes.items():
  994. dtypes_dict[col] = str(dtype)
  995. except Exception:
  996. dtypes_dict = {"error": "无法序列化dtypes"}
  997. # 安全地处理内存使用
  998. try:
  999. memory_usage = field_value.memory_usage(deep=True)
  1000. memory_dict = {}
  1001. for idx, usage in memory_usage.items():
  1002. memory_dict[str(idx)] = int(usage) if hasattr(usage, 'item') else int(usage)
  1003. except Exception:
  1004. memory_dict = {"error": "无法获取内存使用信息"}
  1005. field_info.update({
  1006. 'dataframe_info': {
  1007. 'shape': list(field_value.shape),
  1008. 'columns': list(field_value.columns),
  1009. 'dtypes': dtypes_dict,
  1010. 'index_info': {
  1011. 'type': type(field_value.index).__name__,
  1012. 'length': len(field_value.index)
  1013. }
  1014. },
  1015. 'data': make_json_serializable(field_value.to_dict('records')),
  1016. 'memory_usage': memory_dict
  1017. })
  1018. else:
  1019. field_info['value'] = str(field_value)
  1020. field_info['note'] = 'not_standard_dataframe'
  1021. elif field_name == 'fig_json':
  1022. # 图表JSON数据处理
  1023. if isinstance(field_value, str):
  1024. try:
  1025. import json
  1026. parsed_fig = json.loads(field_value)
  1027. field_info.update({
  1028. 'json_valid': True,
  1029. 'json_size_bytes': len(field_value),
  1030. 'plotly_structure': {
  1031. 'has_data': 'data' in parsed_fig,
  1032. 'has_layout': 'layout' in parsed_fig,
  1033. 'data_traces_count': len(parsed_fig.get('data', [])),
  1034. },
  1035. 'raw_json': field_value
  1036. })
  1037. except json.JSONDecodeError:
  1038. field_info.update({
  1039. 'json_valid': False,
  1040. 'raw_content': str(field_value)
  1041. })
  1042. else:
  1043. field_info['value'] = make_json_serializable(field_value)
  1044. elif field_name == 'followup_questions':
  1045. # 后续问题列表
  1046. field_info.update({
  1047. 'content': make_json_serializable(field_value)
  1048. })
  1049. elif field_name in ['question', 'sql', 'summary']:
  1050. # 文本字段
  1051. if isinstance(field_value, str):
  1052. field_info.update({
  1053. 'text_length': len(field_value),
  1054. 'content': field_value
  1055. })
  1056. else:
  1057. field_info['value'] = make_json_serializable(field_value)
  1058. else:
  1059. # 未知字段的安全处理
  1060. field_info['content'] = make_json_serializable(field_value)
  1061. except Exception as e:
  1062. field_info.update({
  1063. 'processing_error': str(e),
  1064. 'fallback_value': str(field_value)[:500] + '...' if len(str(field_value)) > 500 else str(field_value)
  1065. })
  1066. processed_conversation['fields'][field_name] = field_info
  1067. export_data['conversations'][conversation_id] = processed_conversation
  1068. # 添加缓存统计信息
  1069. field_frequency = {}
  1070. data_types_found = set()
  1071. total_dataframes = 0
  1072. total_questions = 0
  1073. for conv_data in export_data['conversations'].values():
  1074. for field_name, field_info in conv_data['fields'].items():
  1075. field_frequency[field_name] = field_frequency.get(field_name, 0) + 1
  1076. data_types_found.add(field_info['data_type'])
  1077. if field_name == 'df' and not field_info['is_none']:
  1078. total_dataframes += 1
  1079. if field_name == 'question' and not field_info['is_none']:
  1080. total_questions += 1
  1081. export_data['cache_statistics'] = {
  1082. 'field_frequency': field_frequency,
  1083. 'data_types_found': list(data_types_found),
  1084. 'total_dataframes': total_dataframes,
  1085. 'total_questions': total_questions,
  1086. 'has_session_timing': 'session_start_time' in field_frequency,
  1087. 'has_conversation_timing': 'conversation_start_time' in field_frequency
  1088. }
  1089. from common.result import success_response
  1090. return jsonify(success_response(
  1091. response_text="缓存数据导出完成",
  1092. data=export_data
  1093. ))
  1094. except Exception as e:
  1095. import traceback
  1096. error_details = {
  1097. 'error_message': str(e),
  1098. 'error_type': type(e).__name__,
  1099. 'traceback': traceback.format_exc()
  1100. }
  1101. from common.result import internal_error_response
  1102. return jsonify(internal_error_response(
  1103. response_text="导出缓存失败,请稍后重试"
  1104. )), 500
  1105. # ==================== 清理功能API ====================
  1106. @app.flask_app.route('/api/v0/cache_preview_cleanup', methods=['POST'])
  1107. def cache_preview_cleanup():
  1108. """清理功能:预览删除操作 - 保持原功能"""
  1109. try:
  1110. req = request.get_json(force=True)
  1111. # 时间条件 - 支持三种方式
  1112. older_than_hours = req.get('older_than_hours')
  1113. older_than_days = req.get('older_than_days')
  1114. before_timestamp = req.get('before_timestamp') # YYYY-MM-DD HH:MM:SS 格式
  1115. cache = app.cache
  1116. # 计算截止时间
  1117. cutoff_time = None
  1118. time_condition = None
  1119. if older_than_hours:
  1120. cutoff_time = datetime.now() - timedelta(hours=older_than_hours)
  1121. time_condition = f"older_than_hours: {older_than_hours}"
  1122. elif older_than_days:
  1123. cutoff_time = datetime.now() - timedelta(days=older_than_days)
  1124. time_condition = f"older_than_days: {older_than_days}"
  1125. elif before_timestamp:
  1126. try:
  1127. # 支持 YYYY-MM-DD HH:MM:SS 格式
  1128. cutoff_time = datetime.strptime(before_timestamp, '%Y-%m-%d %H:%M:%S')
  1129. time_condition = f"before_timestamp: {before_timestamp}"
  1130. except ValueError:
  1131. from common.result import validation_failed_response
  1132. return jsonify(validation_failed_response(
  1133. response_text="before_timestamp格式错误,请使用 YYYY-MM-DD HH:MM:SS 格式"
  1134. )), 422
  1135. else:
  1136. from common.result import bad_request_response
  1137. return jsonify(bad_request_response(
  1138. response_text="必须提供时间条件:older_than_hours, older_than_days 或 before_timestamp (YYYY-MM-DD HH:MM:SS)",
  1139. missing_params=["older_than_hours", "older_than_days", "before_timestamp"]
  1140. )), 400
  1141. preview = {
  1142. 'time_condition': time_condition,
  1143. 'cutoff_time': cutoff_time.isoformat(),
  1144. 'will_be_removed': {
  1145. 'sessions': []
  1146. },
  1147. 'will_be_kept': {
  1148. 'sessions_count': 0,
  1149. 'conversations_count': 0
  1150. },
  1151. 'summary': {
  1152. 'sessions_to_remove': 0,
  1153. 'conversations_to_remove': 0,
  1154. 'sessions_to_keep': 0,
  1155. 'conversations_to_keep': 0
  1156. }
  1157. }
  1158. # 预览按session删除
  1159. sessions_to_remove_count = 0
  1160. conversations_to_remove_count = 0
  1161. for session_id, session_data in cache.session_info.items():
  1162. session_preview = {
  1163. 'session_id': session_id,
  1164. 'start_time': session_data['start_time'].isoformat(),
  1165. 'conversation_count': len(session_data['conversations']),
  1166. 'conversations': []
  1167. }
  1168. # 添加conversation详情
  1169. for conv_id in session_data['conversations']:
  1170. if conv_id in cache.cache:
  1171. conv_data = cache.cache[conv_id]
  1172. session_preview['conversations'].append({
  1173. 'conversation_id': conv_id,
  1174. 'question': conv_data.get('question', '')[:50] + '...' if conv_data.get('question') else '',
  1175. 'start_time': cache.conversation_start_times.get(conv_id, '').isoformat() if cache.conversation_start_times.get(conv_id) else ''
  1176. })
  1177. if session_data['start_time'] < cutoff_time:
  1178. preview['will_be_removed']['sessions'].append(session_preview)
  1179. sessions_to_remove_count += 1
  1180. conversations_to_remove_count += len(session_data['conversations'])
  1181. else:
  1182. preview['will_be_kept']['sessions_count'] += 1
  1183. preview['will_be_kept']['conversations_count'] += len(session_data['conversations'])
  1184. # 更新摘要统计
  1185. preview['summary'] = {
  1186. 'sessions_to_remove': sessions_to_remove_count,
  1187. 'conversations_to_remove': conversations_to_remove_count,
  1188. 'sessions_to_keep': preview['will_be_kept']['sessions_count'],
  1189. 'conversations_to_keep': preview['will_be_kept']['conversations_count']
  1190. }
  1191. from common.result import success_response
  1192. return jsonify(success_response(
  1193. response_text=f"清理预览完成,将删除 {sessions_to_remove_count} 个会话和 {conversations_to_remove_count} 个对话",
  1194. data=preview
  1195. ))
  1196. except Exception as e:
  1197. from common.result import internal_error_response
  1198. return jsonify(internal_error_response(
  1199. response_text="预览清理操作失败,请稍后重试"
  1200. )), 500
  1201. @app.flask_app.route('/api/v0/cache_cleanup', methods=['POST'])
  1202. def cache_cleanup():
  1203. """清理功能:实际删除缓存 - 保持原功能"""
  1204. try:
  1205. req = request.get_json(force=True)
  1206. # 时间条件 - 支持三种方式
  1207. older_than_hours = req.get('older_than_hours')
  1208. older_than_days = req.get('older_than_days')
  1209. before_timestamp = req.get('before_timestamp') # YYYY-MM-DD HH:MM:SS 格式
  1210. cache = app.cache
  1211. if not hasattr(cache, 'session_info'):
  1212. from common.result import service_unavailable_response
  1213. return jsonify(service_unavailable_response(
  1214. response_text="缓存不支持会话功能"
  1215. )), 503
  1216. # 计算截止时间
  1217. cutoff_time = None
  1218. time_condition = None
  1219. if older_than_hours:
  1220. cutoff_time = datetime.now() - timedelta(hours=older_than_hours)
  1221. time_condition = f"older_than_hours: {older_than_hours}"
  1222. elif older_than_days:
  1223. cutoff_time = datetime.now() - timedelta(days=older_than_days)
  1224. time_condition = f"older_than_days: {older_than_days}"
  1225. elif before_timestamp:
  1226. try:
  1227. # 支持 YYYY-MM-DD HH:MM:SS 格式
  1228. cutoff_time = datetime.strptime(before_timestamp, '%Y-%m-%d %H:%M:%S')
  1229. time_condition = f"before_timestamp: {before_timestamp}"
  1230. except ValueError:
  1231. from common.result import validation_failed_response
  1232. return jsonify(validation_failed_response(
  1233. response_text="before_timestamp格式错误,请使用 YYYY-MM-DD HH:MM:SS 格式"
  1234. )), 422
  1235. else:
  1236. from common.result import bad_request_response
  1237. return jsonify(bad_request_response(
  1238. response_text="必须提供时间条件:older_than_hours, older_than_days 或 before_timestamp (YYYY-MM-DD HH:MM:SS)",
  1239. missing_params=["older_than_hours", "older_than_days", "before_timestamp"]
  1240. )), 400
  1241. cleanup_stats = {
  1242. 'time_condition': time_condition,
  1243. 'cutoff_time': cutoff_time.isoformat(),
  1244. 'sessions_removed': 0,
  1245. 'conversations_removed': 0,
  1246. 'sessions_kept': 0,
  1247. 'conversations_kept': 0,
  1248. 'removed_session_ids': [],
  1249. 'removed_conversation_ids': []
  1250. }
  1251. # 按session删除
  1252. sessions_to_remove = []
  1253. for session_id, session_data in cache.session_info.items():
  1254. if session_data['start_time'] < cutoff_time:
  1255. sessions_to_remove.append(session_id)
  1256. # 删除符合条件的sessions及其所有conversations
  1257. for session_id in sessions_to_remove:
  1258. session_data = cache.session_info[session_id]
  1259. conversations_in_session = session_data['conversations'].copy()
  1260. # 删除session中的所有conversations
  1261. for conv_id in conversations_in_session:
  1262. if conv_id in cache.cache:
  1263. del cache.cache[conv_id]
  1264. cleanup_stats['conversations_removed'] += 1
  1265. cleanup_stats['removed_conversation_ids'].append(conv_id)
  1266. # 清理conversation相关的时间记录
  1267. if hasattr(cache, 'conversation_start_times') and conv_id in cache.conversation_start_times:
  1268. del cache.conversation_start_times[conv_id]
  1269. if hasattr(cache, 'conversation_to_session') and conv_id in cache.conversation_to_session:
  1270. del cache.conversation_to_session[conv_id]
  1271. # 删除session记录
  1272. del cache.session_info[session_id]
  1273. cleanup_stats['sessions_removed'] += 1
  1274. cleanup_stats['removed_session_ids'].append(session_id)
  1275. # 统计保留的sessions和conversations
  1276. cleanup_stats['sessions_kept'] = len(cache.session_info)
  1277. cleanup_stats['conversations_kept'] = len(cache.cache)
  1278. from common.result import success_response
  1279. return jsonify(success_response(
  1280. response_text=f"缓存清理完成,删除了 {cleanup_stats['sessions_removed']} 个会话和 {cleanup_stats['conversations_removed']} 个对话",
  1281. data=cleanup_stats
  1282. ))
  1283. except Exception as e:
  1284. from common.result import internal_error_response
  1285. return jsonify(internal_error_response(
  1286. response_text="缓存清理失败,请稍后重试"
  1287. )), 500
  1288. @app.flask_app.route('/api/v0/training_error_question_sql', methods=['POST'])
  1289. def training_error_question_sql():
  1290. """
  1291. 存储错误的question-sql对到error_sql集合中
  1292. 此API将接收的错误question/sql pair写入到error_sql集合中,用于记录和分析错误的SQL查询。
  1293. Args:
  1294. question (str, required): 用户问题
  1295. sql (str, required): 对应的错误SQL查询语句
  1296. Returns:
  1297. JSON: 包含训练ID和成功消息的响应
  1298. """
  1299. try:
  1300. data = request.get_json()
  1301. question = data.get('question')
  1302. sql = data.get('sql')
  1303. print(f"[DEBUG] 接收到错误SQL训练请求: question={question}, sql={sql}")
  1304. if not question or not sql:
  1305. from common.result import bad_request_response
  1306. missing_params = []
  1307. if not question:
  1308. missing_params.append("question")
  1309. if not sql:
  1310. missing_params.append("sql")
  1311. return jsonify(bad_request_response(
  1312. response_text="question和sql参数都是必需的",
  1313. missing_params=missing_params
  1314. )), 400
  1315. # 使用vn实例的train_error_sql方法存储错误SQL
  1316. id = vn.train_error_sql(question=question, sql=sql)
  1317. print(f"[INFO] 成功存储错误SQL,ID: {id}")
  1318. from common.result import success_response
  1319. return jsonify(success_response(
  1320. response_text="错误SQL对已成功存储",
  1321. data={
  1322. "id": id,
  1323. "message": "错误SQL对已成功存储到error_sql集合"
  1324. }
  1325. ))
  1326. except Exception as e:
  1327. print(f"[ERROR] 存储错误SQL失败: {str(e)}")
  1328. from common.result import internal_error_response
  1329. return jsonify(internal_error_response(
  1330. response_text="存储错误SQL失败,请稍后重试"
  1331. )), 500
  1332. # ==================== Redis对话管理API ====================
  1333. @app.flask_app.route('/api/v0/user/<user_id>/conversations', methods=['GET'])
  1334. def get_user_conversations(user_id: str):
  1335. """获取用户的对话列表(按时间倒序)"""
  1336. try:
  1337. limit = request.args.get('limit', USER_MAX_CONVERSATIONS, type=int)
  1338. conversations = redis_conversation_manager.get_conversations(user_id, limit)
  1339. return jsonify(success_response(
  1340. response_text="获取用户对话列表成功",
  1341. data={
  1342. "user_id": user_id,
  1343. "conversations": conversations,
  1344. "total_count": len(conversations)
  1345. }
  1346. ))
  1347. except Exception as e:
  1348. return jsonify(internal_error_response(
  1349. response_text="获取对话列表失败,请稍后重试"
  1350. )), 500
  1351. @app.flask_app.route('/api/v0/conversation/<conversation_id>/messages', methods=['GET'])
  1352. def get_conversation_messages(conversation_id: str):
  1353. """获取特定对话的消息历史"""
  1354. try:
  1355. limit = request.args.get('limit', type=int) # 可选参数
  1356. messages = redis_conversation_manager.get_conversation_messages(conversation_id, limit)
  1357. meta = redis_conversation_manager.get_conversation_meta(conversation_id)
  1358. return jsonify(success_response(
  1359. response_text="获取对话消息成功",
  1360. data={
  1361. "conversation_id": conversation_id,
  1362. "conversation_meta": meta,
  1363. "messages": messages,
  1364. "message_count": len(messages)
  1365. }
  1366. ))
  1367. except Exception as e:
  1368. return jsonify(internal_error_response(
  1369. response_text="获取对话消息失败"
  1370. )), 500
  1371. @app.flask_app.route('/api/v0/conversation/<conversation_id>/context', methods=['GET'])
  1372. def get_conversation_context(conversation_id: str):
  1373. """获取对话上下文(格式化用于LLM)"""
  1374. try:
  1375. count = request.args.get('count', CONVERSATION_CONTEXT_COUNT, type=int)
  1376. context = redis_conversation_manager.get_context(conversation_id, count)
  1377. return jsonify(success_response(
  1378. response_text="获取对话上下文成功",
  1379. data={
  1380. "conversation_id": conversation_id,
  1381. "context": context,
  1382. "context_message_count": count
  1383. }
  1384. ))
  1385. except Exception as e:
  1386. return jsonify(internal_error_response(
  1387. response_text="获取对话上下文失败"
  1388. )), 500
  1389. @app.flask_app.route('/api/v0/conversation_stats', methods=['GET'])
  1390. def conversation_stats():
  1391. """获取对话系统统计信息"""
  1392. try:
  1393. stats = redis_conversation_manager.get_stats()
  1394. return jsonify(success_response(
  1395. response_text="获取统计信息成功",
  1396. data=stats
  1397. ))
  1398. except Exception as e:
  1399. return jsonify(internal_error_response(
  1400. response_text="获取统计信息失败,请稍后重试"
  1401. )), 500
  1402. @app.flask_app.route('/api/v0/conversation_cleanup', methods=['POST'])
  1403. def conversation_cleanup():
  1404. """手动清理过期对话"""
  1405. try:
  1406. redis_conversation_manager.cleanup_expired_conversations()
  1407. return jsonify(success_response(
  1408. response_text="对话清理完成"
  1409. ))
  1410. except Exception as e:
  1411. return jsonify(internal_error_response(
  1412. response_text="对话清理失败,请稍后重试"
  1413. )), 500
  1414. @app.flask_app.route('/api/v0/user/<user_id>/conversations/full', methods=['GET'])
  1415. def get_user_conversations_with_messages(user_id: str):
  1416. """
  1417. 获取用户的完整对话数据(包含所有消息)
  1418. 一次性返回用户的所有对话和每个对话下的消息历史
  1419. Args:
  1420. user_id: 用户ID(路径参数)
  1421. conversation_limit: 对话数量限制(查询参数,可选,不传则返回所有对话)
  1422. message_limit: 每个对话的消息数限制(查询参数,可选,不传则返回所有消息)
  1423. Returns:
  1424. 包含用户所有对话和消息的完整数据
  1425. """
  1426. try:
  1427. # 获取可选参数,不传递时使用None(返回所有记录)
  1428. conversation_limit = request.args.get('conversation_limit', type=int)
  1429. message_limit = request.args.get('message_limit', type=int)
  1430. # 获取用户的对话列表
  1431. conversations = redis_conversation_manager.get_conversations(user_id, conversation_limit)
  1432. # 为每个对话获取消息历史
  1433. full_conversations = []
  1434. total_messages = 0
  1435. for conversation in conversations:
  1436. conversation_id = conversation['conversation_id']
  1437. # 获取对话消息
  1438. messages = redis_conversation_manager.get_conversation_messages(
  1439. conversation_id, message_limit
  1440. )
  1441. # 获取对话元数据
  1442. meta = redis_conversation_manager.get_conversation_meta(conversation_id)
  1443. # 组合完整数据
  1444. full_conversation = {
  1445. **conversation, # 基础对话信息
  1446. 'meta': meta, # 对话元数据
  1447. 'messages': messages, # 消息列表
  1448. 'message_count': len(messages)
  1449. }
  1450. full_conversations.append(full_conversation)
  1451. total_messages += len(messages)
  1452. return jsonify(success_response(
  1453. response_text="获取用户完整对话数据成功",
  1454. data={
  1455. "user_id": user_id,
  1456. "conversations": full_conversations,
  1457. "total_conversations": len(full_conversations),
  1458. "total_messages": total_messages,
  1459. "conversation_limit_applied": conversation_limit,
  1460. "message_limit_applied": message_limit,
  1461. "query_time": datetime.now().isoformat()
  1462. }
  1463. ))
  1464. except Exception as e:
  1465. print(f"[ERROR] 获取用户完整对话数据失败: {str(e)}")
  1466. return jsonify(internal_error_response(
  1467. response_text="获取用户对话数据失败,请稍后重试"
  1468. )), 500
  1469. # 前端JavaScript示例 - 如何维持会话
  1470. """
  1471. // 前端需要维护一个会话ID
  1472. class ChatSession {
  1473. constructor() {
  1474. // 从localStorage获取或创建新的会话ID
  1475. this.sessionId = localStorage.getItem('chat_session_id') || this.generateSessionId();
  1476. localStorage.setItem('chat_session_id', this.sessionId);
  1477. }
  1478. generateSessionId() {
  1479. return 'session_' + Date.now() + '_' + Math.random().toString(36).substr(2, 9);
  1480. }
  1481. async askQuestion(question) {
  1482. const response = await fetch('/api/v0/ask', {
  1483. method: 'POST',
  1484. headers: {
  1485. 'Content-Type': 'application/json',
  1486. },
  1487. body: JSON.stringify({
  1488. question: question,
  1489. session_id: this.sessionId // 关键:传递会话ID
  1490. })
  1491. });
  1492. return await response.json();
  1493. }
  1494. // 开始新会话
  1495. startNewSession() {
  1496. this.sessionId = this.generateSessionId();
  1497. localStorage.setItem('chat_session_id', this.sessionId);
  1498. }
  1499. }
  1500. // 使用示例
  1501. const chatSession = new ChatSession();
  1502. chatSession.askQuestion("各年龄段客户的流失率如何?");
  1503. """
  1504. print("正在启动Flask应用: http://localhost:8084")
  1505. app.run(host="0.0.0.0", port=8084, debug=True)