citu_app.py 111 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761777177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035203620372038203920402041204220432044204520462047204820492050205120522053205420552056205720582059206020612062206320642065206620672068206920702071207220732074207520762077207820792080208120822083208420852086208720882089209020912092209320942095209620972098209921002101210221032104210521062107210821092110211121122113211421152116211721182119212021212122212321242125212621272128212921302131213221332134213521362137213821392140214121422143214421452146214721482149215021512152215321542155215621572158215921602161216221632164216521662167216821692170217121722173217421752176217721782179218021812182218321842185218621872188218921902191219221932194219521962197219821992200220122022203220422052206220722082209221022112212221322142215221622172218221922202221222222232224222522262227222822292230223122322233223422352236223722382239224022412242224322442245224622472248224922502251225222532254225522562257225822592260226122622263226422652266226722682269227022712272227322742275227622772278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752
  1. # 给dataops 对话助手返回结果
  2. from vanna.flask import VannaFlaskApp
  3. from core.vanna_llm_factory import create_vanna_instance
  4. from flask import request, jsonify
  5. import pandas as pd
  6. import common.result as result
  7. from datetime import datetime, timedelta
  8. from common.session_aware_cache import WebSessionAwareMemoryCache
  9. from app_config import API_MAX_RETURN_ROWS, ENABLE_RESULT_SUMMARY
  10. import re
  11. import chainlit as cl
  12. import json
  13. from flask import session # 添加session导入
  14. import sqlparse # 用于SQL语法检查
  15. from common.redis_conversation_manager import RedisConversationManager # 添加Redis对话管理器导入
  16. from common.qa_feedback_manager import QAFeedbackManager
  17. from common.result import success_response, bad_request_response, not_found_response, internal_error_response
  18. from common.result import ( # 统一导入所有需要的响应函数
  19. bad_request_response, service_unavailable_response,
  20. agent_success_response, agent_error_response,
  21. internal_error_response, success_response,
  22. validation_failed_response
  23. )
  24. from app_config import ( # 添加Redis相关配置导入
  25. USER_MAX_CONVERSATIONS,
  26. CONVERSATION_CONTEXT_COUNT,
  27. DEFAULT_ANONYMOUS_USER,
  28. ENABLE_QUESTION_ANSWER_CACHE
  29. )
  30. # 设置默认的最大返回行数
  31. DEFAULT_MAX_RETURN_ROWS = 200
  32. MAX_RETURN_ROWS = API_MAX_RETURN_ROWS if API_MAX_RETURN_ROWS is not None else DEFAULT_MAX_RETURN_ROWS
  33. vn = create_vanna_instance()
  34. # 创建带时间戳的缓存
  35. timestamped_cache = WebSessionAwareMemoryCache()
  36. # 实例化 VannaFlaskApp,使用自定义缓存
  37. app = VannaFlaskApp(
  38. vn,
  39. cache=timestamped_cache, # 使用带时间戳的缓存
  40. title="辞图智能数据问答平台",
  41. logo = "https://www.citupro.com/img/logo-black-2.png",
  42. subtitle="让 AI 为你写 SQL",
  43. chart=False,
  44. allow_llm_to_see_data=True,
  45. ask_results_correct=True,
  46. followup_questions=True,
  47. debug=True
  48. )
  49. # 创建Redis对话管理器实例
  50. redis_conversation_manager = RedisConversationManager()
  51. # 修改ask接口,支持前端传递session_id
  52. @app.flask_app.route('/api/v0/ask', methods=['POST'])
  53. def ask_full():
  54. req = request.get_json(force=True)
  55. question = req.get("question", None)
  56. browser_session_id = req.get("session_id", None) # 前端传递的会话ID
  57. if not question:
  58. from common.result import bad_request_response
  59. return jsonify(bad_request_response(
  60. response_text="缺少必需参数:question",
  61. missing_params=["question"]
  62. )), 400
  63. # 如果使用WebSessionAwareMemoryCache
  64. if hasattr(app.cache, 'generate_id_with_browser_session') and browser_session_id:
  65. # 这里需要修改vanna的ask方法来支持传递session_id
  66. # 或者预先调用generate_id来建立会话关联
  67. conversation_id = app.cache.generate_id_with_browser_session(
  68. question=question,
  69. browser_session_id=browser_session_id
  70. )
  71. try:
  72. sql, df, _ = vn.ask(
  73. question=question,
  74. print_results=False,
  75. visualize=False,
  76. allow_llm_to_see_data=True
  77. )
  78. # 关键:检查是否有LLM解释性文本(无法生成SQL的情况)
  79. if sql is None and hasattr(vn, 'last_llm_explanation') and vn.last_llm_explanation:
  80. # 在解释性文本末尾添加提示语
  81. explanation_message = vn.last_llm_explanation + "请尝试提问其它问题。"
  82. # 使用标准化错误响应
  83. from common.result import validation_failed_response
  84. return jsonify(validation_failed_response(
  85. response_text=explanation_message
  86. )), 422 # 修改HTTP状态码为422
  87. # 如果sql为None但没有解释性文本,返回通用错误
  88. if sql is None:
  89. from common.result import validation_failed_response
  90. return jsonify(validation_failed_response(
  91. response_text="无法生成SQL查询,请检查问题描述或数据表结构"
  92. )), 422
  93. # 处理返回数据 - 使用新的query_result结构
  94. query_result = {
  95. "rows": [],
  96. "columns": [],
  97. "row_count": 0,
  98. "is_limited": False,
  99. "total_row_count": 0
  100. }
  101. summary = None
  102. if isinstance(df, pd.DataFrame):
  103. query_result["columns"] = list(df.columns)
  104. if not df.empty:
  105. total_rows = len(df)
  106. limited_df = df.head(MAX_RETURN_ROWS)
  107. query_result["rows"] = limited_df.to_dict(orient="records")
  108. query_result["row_count"] = len(limited_df)
  109. query_result["total_row_count"] = total_rows
  110. query_result["is_limited"] = total_rows > MAX_RETURN_ROWS
  111. # 生成数据摘要(可通过配置控制,仅在有数据时生成)
  112. if ENABLE_RESULT_SUMMARY:
  113. try:
  114. summary = vn.generate_summary(question=question, df=df)
  115. print(f"[INFO] 成功生成摘要: {summary}")
  116. except Exception as e:
  117. print(f"[WARNING] 生成摘要失败: {str(e)}")
  118. summary = None
  119. # 构建返回数据
  120. response_data = {
  121. "sql": sql,
  122. "query_result": query_result,
  123. "conversation_id": conversation_id if 'conversation_id' in locals() else None,
  124. "session_id": browser_session_id
  125. }
  126. # 添加摘要(如果启用且生成成功)
  127. if ENABLE_RESULT_SUMMARY and summary is not None:
  128. response_data["summary"] = summary
  129. response_data["response"] = summary # 同时添加response字段
  130. from common.result import success_response
  131. return jsonify(success_response(
  132. response_text="查询执行完成" if summary is None else None,
  133. data=response_data
  134. ))
  135. except Exception as e:
  136. print(f"[ERROR] ask_full执行失败: {str(e)}")
  137. # 即使发生异常,也检查是否有业务层面的解释
  138. if hasattr(vn, 'last_llm_explanation') and vn.last_llm_explanation:
  139. # 在解释性文本末尾添加提示语
  140. explanation_message = vn.last_llm_explanation + "请尝试提问其它问题。"
  141. from common.result import validation_failed_response
  142. return jsonify(validation_failed_response(
  143. response_text=explanation_message
  144. )), 422
  145. else:
  146. # 技术错误,使用500错误码
  147. from common.result import internal_error_response
  148. return jsonify(internal_error_response(
  149. response_text="查询处理失败,请稍后重试"
  150. )), 500
  151. @app.flask_app.route('/api/v0/citu_run_sql', methods=['POST'])
  152. def citu_run_sql():
  153. req = request.get_json(force=True)
  154. sql = req.get('sql')
  155. if not sql:
  156. from common.result import bad_request_response
  157. return jsonify(bad_request_response(
  158. response_text="缺少必需参数:sql",
  159. missing_params=["sql"]
  160. )), 400
  161. try:
  162. df = vn.run_sql(sql)
  163. # 处理返回数据 - 使用新的query_result结构
  164. query_result = {
  165. "rows": [],
  166. "columns": [],
  167. "row_count": 0,
  168. "is_limited": False,
  169. "total_row_count": 0
  170. }
  171. if isinstance(df, pd.DataFrame):
  172. query_result["columns"] = list(df.columns)
  173. if not df.empty:
  174. total_rows = len(df)
  175. limited_df = df.head(MAX_RETURN_ROWS)
  176. query_result["rows"] = limited_df.to_dict(orient="records")
  177. query_result["row_count"] = len(limited_df)
  178. query_result["total_row_count"] = total_rows
  179. query_result["is_limited"] = total_rows > MAX_RETURN_ROWS
  180. from common.result import success_response
  181. return jsonify(success_response(
  182. response_text=f"SQL执行完成,共返回 {query_result['total_row_count']} 条记录" +
  183. (f",已限制显示前 {MAX_RETURN_ROWS} 条" if query_result["is_limited"] else ""),
  184. data={
  185. "sql": sql,
  186. "query_result": query_result
  187. }
  188. ))
  189. except Exception as e:
  190. print(f"[ERROR] citu_run_sql执行失败: {str(e)}")
  191. from common.result import internal_error_response
  192. return jsonify(internal_error_response(
  193. response_text=f"SQL执行失败,请检查SQL语句是否正确"
  194. )), 500
  195. @app.flask_app.route('/api/v0/ask_cached', methods=['POST'])
  196. def ask_cached():
  197. """
  198. 带缓存功能的智能查询接口
  199. 支持会话管理和结果缓存,提高查询效率
  200. """
  201. req = request.get_json(force=True)
  202. question = req.get("question", None)
  203. browser_session_id = req.get("session_id", None)
  204. if not question:
  205. from common.result import bad_request_response
  206. return jsonify(bad_request_response(
  207. response_text="缺少必需参数:question",
  208. missing_params=["question"]
  209. )), 400
  210. try:
  211. # 生成conversation_id
  212. # 调试:查看generate_id的实际行为
  213. print(f"[DEBUG] 输入问题: '{question}'")
  214. conversation_id = app.cache.generate_id(question=question)
  215. print(f"[DEBUG] 生成的conversation_id: {conversation_id}")
  216. # 再次用相同问题测试
  217. conversation_id2 = app.cache.generate_id(question=question)
  218. print(f"[DEBUG] 再次生成的conversation_id: {conversation_id2}")
  219. print(f"[DEBUG] 两次ID是否相同: {conversation_id == conversation_id2}")
  220. # 检查缓存
  221. cached_sql = app.cache.get(id=conversation_id, field="sql")
  222. if cached_sql is not None:
  223. # 缓存命中
  224. print(f"[CACHE HIT] 使用缓存结果: {conversation_id}")
  225. sql = cached_sql
  226. df = app.cache.get(id=conversation_id, field="df")
  227. summary = app.cache.get(id=conversation_id, field="summary")
  228. else:
  229. # 缓存未命中,执行新查询
  230. print(f"[CACHE MISS] 执行新查询: {conversation_id}")
  231. sql, df, _ = vn.ask(
  232. question=question,
  233. print_results=False,
  234. visualize=False,
  235. allow_llm_to_see_data=True
  236. )
  237. # 检查是否有LLM解释性文本(无法生成SQL的情况)
  238. if sql is None and hasattr(vn, 'last_llm_explanation') and vn.last_llm_explanation:
  239. # 在解释性文本末尾添加提示语
  240. explanation_message = vn.last_llm_explanation + "请尝试用其它方式提问。"
  241. from common.result import validation_failed_response
  242. return jsonify(validation_failed_response(
  243. response_text=explanation_message
  244. )), 422
  245. # 如果sql为None但没有解释性文本,返回通用错误
  246. if sql is None:
  247. from common.result import validation_failed_response
  248. return jsonify(validation_failed_response(
  249. response_text="无法生成SQL查询,请检查问题描述或数据表结构"
  250. )), 422
  251. # 缓存结果
  252. app.cache.set(id=conversation_id, field="question", value=question)
  253. app.cache.set(id=conversation_id, field="sql", value=sql)
  254. app.cache.set(id=conversation_id, field="df", value=df)
  255. # 生成并缓存摘要(可通过配置控制,仅在有数据时生成)
  256. summary = None
  257. if ENABLE_RESULT_SUMMARY and isinstance(df, pd.DataFrame) and not df.empty:
  258. try:
  259. summary = vn.generate_summary(question=question, df=df)
  260. print(f"[INFO] 成功生成摘要: {summary}")
  261. except Exception as e:
  262. print(f"[WARNING] 生成摘要失败: {str(e)}")
  263. summary = None
  264. app.cache.set(id=conversation_id, field="summary", value=summary)
  265. # 处理返回数据 - 使用新的query_result结构
  266. query_result = {
  267. "rows": [],
  268. "columns": [],
  269. "row_count": 0,
  270. "is_limited": False,
  271. "total_row_count": 0
  272. }
  273. if isinstance(df, pd.DataFrame):
  274. query_result["columns"] = list(df.columns)
  275. if not df.empty:
  276. total_rows = len(df)
  277. limited_df = df.head(MAX_RETURN_ROWS)
  278. query_result["rows"] = limited_df.to_dict(orient="records")
  279. query_result["row_count"] = len(limited_df)
  280. query_result["total_row_count"] = total_rows
  281. query_result["is_limited"] = total_rows > MAX_RETURN_ROWS
  282. # 构建返回数据
  283. response_data = {
  284. "sql": sql,
  285. "query_result": query_result,
  286. "conversation_id": conversation_id,
  287. "session_id": browser_session_id,
  288. "cached": cached_sql is not None # 标识是否来自缓存
  289. }
  290. # 添加摘要(如果启用且生成成功)
  291. if ENABLE_RESULT_SUMMARY and summary is not None:
  292. response_data["summary"] = summary
  293. response_data["response"] = summary # 同时添加response字段
  294. from common.result import success_response
  295. return jsonify(success_response(
  296. response_text="查询执行完成" if summary is None else None,
  297. data=response_data
  298. ))
  299. except Exception as e:
  300. print(f"[ERROR] ask_cached执行失败: {str(e)}")
  301. from common.result import internal_error_response
  302. return jsonify(internal_error_response(
  303. response_text="查询处理失败,请稍后重试"
  304. )), 500
  305. @app.flask_app.route('/api/v0/citu_train_question_sql', methods=['POST'])
  306. def citu_train_question_sql():
  307. """
  308. 训练问题-SQL对接口
  309. 此API将接收的question/sql pair写入到training库中,用于训练和改进AI模型。
  310. 支持仅传入SQL或同时传入问题和SQL进行训练。
  311. Args:
  312. question (str, optional): 用户问题
  313. sql (str, required): 对应的SQL查询语句
  314. Returns:
  315. JSON: 包含训练ID和成功消息的响应
  316. """
  317. try:
  318. req = request.get_json(force=True)
  319. question = req.get('question')
  320. sql = req.get('sql')
  321. if not sql:
  322. from common.result import bad_request_response
  323. return jsonify(bad_request_response(
  324. response_text="缺少必需参数:sql",
  325. missing_params=["sql"]
  326. )), 400
  327. # 正确的调用方式:同时传递question和sql
  328. if question:
  329. training_id = vn.train(question=question, sql=sql)
  330. print(f"训练成功,训练ID为:{training_id},问题:{question},SQL:{sql}")
  331. else:
  332. training_id = vn.train(sql=sql)
  333. print(f"训练成功,训练ID为:{training_id},SQL:{sql}")
  334. from common.result import success_response
  335. return jsonify(success_response(
  336. response_text="问题-SQL对训练成功",
  337. data={
  338. "training_id": training_id,
  339. "message": "Question-SQL pair trained successfully"
  340. }
  341. ))
  342. except Exception as e:
  343. from common.result import internal_error_response
  344. return jsonify(internal_error_response(
  345. response_text="训练失败,请稍后重试"
  346. )), 500
  347. # ============ LangGraph Agent 集成 ============
  348. # 全局Agent实例(单例模式)
  349. citu_langraph_agent = None
  350. def get_citu_langraph_agent():
  351. """获取LangGraph Agent实例(懒加载)"""
  352. global citu_langraph_agent
  353. if citu_langraph_agent is None:
  354. try:
  355. from agent.citu_agent import CituLangGraphAgent
  356. print("[CITU_APP] 开始创建LangGraph Agent实例...")
  357. citu_langraph_agent = CituLangGraphAgent()
  358. print("[CITU_APP] LangGraph Agent实例创建成功")
  359. except ImportError as e:
  360. print(f"[CRITICAL] Agent模块导入失败: {str(e)}")
  361. print("[CRITICAL] 请检查agent模块是否存在以及依赖是否正确安装")
  362. raise Exception(f"Agent模块导入失败: {str(e)}")
  363. except Exception as e:
  364. print(f"[CRITICAL] LangGraph Agent实例创建失败: {str(e)}")
  365. print(f"[CRITICAL] 错误类型: {type(e).__name__}")
  366. # 提供更有用的错误信息
  367. if "config" in str(e).lower():
  368. print("[CRITICAL] 可能是配置文件问题,请检查配置")
  369. elif "llm" in str(e).lower():
  370. print("[CRITICAL] 可能是LLM连接问题,请检查LLM配置")
  371. elif "tool" in str(e).lower():
  372. print("[CRITICAL] 可能是工具加载问题,请检查工具模块")
  373. raise Exception(f"Agent初始化失败: {str(e)}")
  374. return citu_langraph_agent
  375. @app.flask_app.route('/api/v0/ask_agent', methods=['POST'])
  376. def ask_agent():
  377. """
  378. 支持对话上下文的ask_agent API - 修正版
  379. """
  380. req = request.get_json(force=True)
  381. question = req.get("question", None)
  382. browser_session_id = req.get("session_id", None)
  383. # 新增参数解析
  384. user_id_input = req.get("user_id", None)
  385. conversation_id_input = req.get("conversation_id", None)
  386. continue_conversation = req.get("continue_conversation", False)
  387. # 新增:路由模式参数解析和验证
  388. api_routing_mode = req.get("routing_mode", None)
  389. VALID_ROUTING_MODES = ["database_direct", "chat_direct", "hybrid", "llm_only"]
  390. if not question:
  391. return jsonify(bad_request_response(
  392. response_text="缺少必需参数:question",
  393. missing_params=["question"]
  394. )), 400
  395. # 验证routing_mode参数
  396. if api_routing_mode and api_routing_mode not in VALID_ROUTING_MODES:
  397. return jsonify(bad_request_response(
  398. response_text=f"无效的routing_mode参数值: {api_routing_mode},支持的值: {VALID_ROUTING_MODES}",
  399. invalid_params=["routing_mode"]
  400. )), 400
  401. try:
  402. # 1. 获取登录用户ID(修正:在函数中获取session信息)
  403. login_user_id = session.get('user_id') if 'user_id' in session else None
  404. # 2. 智能ID解析(修正:传入登录用户ID)
  405. user_id = redis_conversation_manager.resolve_user_id(
  406. user_id_input, browser_session_id, request.remote_addr, login_user_id
  407. )
  408. conversation_id, conversation_status = redis_conversation_manager.resolve_conversation_id(
  409. user_id, conversation_id_input, continue_conversation
  410. )
  411. # 3. 获取上下文和上下文类型(提前到缓存检查之前)
  412. context = redis_conversation_manager.get_context(conversation_id)
  413. # 获取上下文类型:从最后一条助手消息的metadata中获取类型
  414. context_type = None
  415. if context:
  416. try:
  417. # 获取最后一条助手消息的metadata
  418. messages = redis_conversation_manager.get_messages(conversation_id, limit=10)
  419. for message in reversed(messages): # 从最新的开始找
  420. if message.get("role") == "assistant":
  421. metadata = message.get("metadata", {})
  422. context_type = metadata.get("type")
  423. if context_type:
  424. print(f"[AGENT_API] 检测到上下文类型: {context_type}")
  425. break
  426. except Exception as e:
  427. print(f"[WARNING] 获取上下文类型失败: {str(e)}")
  428. # 4. 检查缓存(新逻辑:放宽使用条件,严控存储条件)
  429. cached_answer = redis_conversation_manager.get_cached_answer(question, context)
  430. if cached_answer:
  431. print(f"[AGENT_API] 使用缓存答案")
  432. # 确定缓存答案的助手回复内容(使用与非缓存相同的优先级逻辑)
  433. cached_response_type = cached_answer.get("type", "UNKNOWN")
  434. if cached_response_type == "DATABASE":
  435. # DATABASE类型:按优先级选择内容
  436. if cached_answer.get("response"):
  437. # 优先级1:错误或解释性回复(如SQL生成失败)
  438. assistant_response = cached_answer.get("response")
  439. elif cached_answer.get("summary"):
  440. # 优先级2:查询成功的摘要
  441. assistant_response = cached_answer.get("summary")
  442. elif cached_answer.get("query_result"):
  443. # 优先级3:构造简单描述
  444. query_result = cached_answer.get("query_result")
  445. row_count = query_result.get("row_count", 0)
  446. assistant_response = f"查询执行完成,共返回 {row_count} 条记录。"
  447. else:
  448. # 异常情况
  449. assistant_response = "数据库查询已处理。"
  450. else:
  451. # CHAT类型:直接使用response
  452. assistant_response = cached_answer.get("response", "")
  453. # 更新对话历史
  454. redis_conversation_manager.save_message(conversation_id, "user", question)
  455. redis_conversation_manager.save_message(
  456. conversation_id, "assistant",
  457. assistant_response,
  458. metadata={"from_cache": True}
  459. )
  460. # 添加对话信息到缓存结果
  461. cached_answer["conversation_id"] = conversation_id
  462. cached_answer["user_id"] = user_id
  463. cached_answer["from_cache"] = True
  464. cached_answer.update(conversation_status)
  465. # 使用agent_success_response返回标准格式
  466. return jsonify(agent_success_response(
  467. response_type=cached_answer.get("type", "UNKNOWN"),
  468. response=cached_answer.get("response", ""), # 修正:使用response而不是response_text
  469. sql=cached_answer.get("sql"),
  470. records=cached_answer.get("query_result"), # 修改:query_result改为records
  471. summary=cached_answer.get("summary"),
  472. session_id=browser_session_id,
  473. execution_path=cached_answer.get("execution_path", []),
  474. classification_info=cached_answer.get("classification_info", {}),
  475. conversation_id=conversation_id,
  476. user_id=user_id,
  477. is_guest_user=(user_id == DEFAULT_ANONYMOUS_USER),
  478. context_used=bool(context),
  479. from_cache=True,
  480. conversation_status=conversation_status["status"],
  481. conversation_message=conversation_status["message"],
  482. requested_conversation_id=conversation_status.get("requested_id")
  483. ))
  484. # 5. 保存用户消息
  485. redis_conversation_manager.save_message(conversation_id, "user", question)
  486. # 6. 构建带上下文的问题
  487. if context:
  488. enhanced_question = f"\n[CONTEXT]\n{context}\n\n[CURRENT]\n{question}"
  489. print(f"[AGENT_API] 使用上下文,长度: {len(context)}字符")
  490. else:
  491. enhanced_question = question
  492. print(f"[AGENT_API] 新对话,无上下文")
  493. # 7. 确定最终使用的路由模式(优先级逻辑)
  494. if api_routing_mode:
  495. # API传了参数,优先使用
  496. effective_routing_mode = api_routing_mode
  497. print(f"[AGENT_API] 使用API指定的路由模式: {effective_routing_mode}")
  498. else:
  499. # API没传参数,使用配置文件
  500. try:
  501. from app_config import QUESTION_ROUTING_MODE
  502. effective_routing_mode = QUESTION_ROUTING_MODE
  503. print(f"[AGENT_API] 使用配置文件路由模式: {effective_routing_mode}")
  504. except ImportError:
  505. effective_routing_mode = "hybrid"
  506. print(f"[AGENT_API] 配置文件读取失败,使用默认路由模式: {effective_routing_mode}")
  507. # 8. 现有Agent处理逻辑(修改为传递路由模式)
  508. try:
  509. agent = get_citu_langraph_agent()
  510. except Exception as e:
  511. print(f"[CRITICAL] Agent初始化失败: {str(e)}")
  512. return jsonify(service_unavailable_response(
  513. response_text="AI服务暂时不可用,请稍后重试",
  514. can_retry=True
  515. )), 503
  516. # 异步调用Agent处理问题
  517. import asyncio
  518. agent_result = asyncio.run(agent.process_question(
  519. question=enhanced_question, # 使用增强后的问题
  520. session_id=browser_session_id,
  521. context_type=context_type, # 传递上下文类型
  522. routing_mode=effective_routing_mode # 新增:传递路由模式
  523. ))
  524. # 8. 处理Agent结果
  525. if agent_result.get("success", False):
  526. # 修正:直接从agent_result获取字段,因为它就是final_response
  527. response_type = agent_result.get("type", "UNKNOWN")
  528. response_text = agent_result.get("response", "")
  529. sql = agent_result.get("sql")
  530. query_result = agent_result.get("query_result")
  531. summary = agent_result.get("summary")
  532. execution_path = agent_result.get("execution_path", [])
  533. classification_info = agent_result.get("classification_info", {})
  534. # 确定助手回复内容的优先级
  535. if response_type == "DATABASE":
  536. # DATABASE类型:按优先级选择内容
  537. if response_text:
  538. # 优先级1:错误或解释性回复(如SQL生成失败)
  539. assistant_response = response_text
  540. elif summary:
  541. # 优先级2:查询成功的摘要
  542. assistant_response = summary
  543. elif query_result:
  544. # 优先级3:构造简单描述
  545. row_count = query_result.get("row_count", 0)
  546. assistant_response = f"查询执行完成,共返回 {row_count} 条记录。"
  547. else:
  548. # 异常情况
  549. assistant_response = "数据库查询已处理。"
  550. else:
  551. # CHAT类型:直接使用response
  552. assistant_response = response_text
  553. # 保存助手回复
  554. redis_conversation_manager.save_message(
  555. conversation_id, "assistant", assistant_response,
  556. metadata={
  557. "type": response_type,
  558. "sql": sql,
  559. "execution_path": execution_path
  560. }
  561. )
  562. # 缓存成功的答案(新逻辑:只缓存无上下文的问答)
  563. # 直接缓存agent_result,它已经包含所有需要的字段
  564. redis_conversation_manager.cache_answer(question, agent_result, context)
  565. # 使用agent_success_response的正确方式
  566. return jsonify(agent_success_response(
  567. response_type=response_type,
  568. response=response_text, # 修正:使用response而不是response_text
  569. sql=sql,
  570. records=query_result, # 修改:query_result改为records
  571. summary=summary,
  572. session_id=browser_session_id,
  573. execution_path=execution_path,
  574. classification_info=classification_info,
  575. conversation_id=conversation_id,
  576. user_id=user_id,
  577. is_guest_user=(user_id == DEFAULT_ANONYMOUS_USER),
  578. context_used=bool(context),
  579. from_cache=False,
  580. conversation_status=conversation_status["status"],
  581. conversation_message=conversation_status["message"],
  582. requested_conversation_id=conversation_status.get("requested_id"),
  583. routing_mode_used=effective_routing_mode, # 新增:实际使用的路由模式
  584. routing_mode_source="api" if api_routing_mode else "config" # 新增:路由模式来源
  585. ))
  586. else:
  587. # 错误处理(修正:确保使用现有的错误响应格式)
  588. error_message = agent_result.get("error", "Agent处理失败")
  589. error_code = agent_result.get("error_code", 500)
  590. return jsonify(agent_error_response(
  591. response_text=error_message,
  592. error_type="agent_processing_failed",
  593. code=error_code,
  594. session_id=browser_session_id,
  595. conversation_id=conversation_id,
  596. user_id=user_id
  597. )), error_code
  598. except Exception as e:
  599. print(f"[ERROR] ask_agent执行失败: {str(e)}")
  600. return jsonify(internal_error_response(
  601. response_text="查询处理失败,请稍后重试"
  602. )), 500
  603. @app.flask_app.route('/api/v0/agent_health', methods=['GET'])
  604. def agent_health():
  605. """
  606. Agent健康检查接口
  607. 响应格式:
  608. {
  609. "success": true/false,
  610. "code": 200/503,
  611. "message": "healthy/degraded/unhealthy",
  612. "data": {
  613. "status": "healthy/degraded/unhealthy",
  614. "test_result": true/false,
  615. "workflow_compiled": true/false,
  616. "tools_count": 4,
  617. "message": "详细信息",
  618. "timestamp": "2024-01-01T12:00:00",
  619. "checks": {
  620. "agent_creation": true/false,
  621. "tools_import": true/false,
  622. "llm_connection": true/false,
  623. "classifier_ready": true/false
  624. }
  625. }
  626. }
  627. """
  628. try:
  629. # 基础健康检查
  630. health_data = {
  631. "status": "unknown",
  632. "test_result": False,
  633. "workflow_compiled": False,
  634. "tools_count": 0,
  635. "message": "",
  636. "timestamp": datetime.now().isoformat(),
  637. "checks": {
  638. "agent_creation": False,
  639. "tools_import": False,
  640. "llm_connection": False,
  641. "classifier_ready": False
  642. }
  643. }
  644. # 检查1: Agent创建
  645. try:
  646. agent = get_citu_langraph_agent()
  647. health_data["checks"]["agent_creation"] = True
  648. # 修正:Agent现在是动态创建workflow的,不再有预创建的workflow属性
  649. health_data["workflow_compiled"] = True # 动态创建,始终可用
  650. health_data["tools_count"] = len(agent.tools) if hasattr(agent, 'tools') else 0
  651. except Exception as e:
  652. health_data["message"] = f"Agent创建失败: {str(e)}"
  653. health_data["status"] = "unhealthy" # 设置状态
  654. from common.result import health_error_response
  655. return jsonify(health_error_response(**health_data)), 503
  656. # 检查2: 工具导入
  657. try:
  658. from agent.tools import TOOLS
  659. health_data["checks"]["tools_import"] = len(TOOLS) > 0
  660. except Exception as e:
  661. health_data["message"] = f"工具导入失败: {str(e)}"
  662. # 检查3: LLM连接(简单测试)
  663. try:
  664. from agent.utils import get_compatible_llm
  665. llm = get_compatible_llm()
  666. health_data["checks"]["llm_connection"] = llm is not None
  667. except Exception as e:
  668. health_data["message"] = f"LLM连接失败: {str(e)}"
  669. # 检查4: 分类器准备
  670. try:
  671. from agent.classifier import QuestionClassifier
  672. classifier = QuestionClassifier()
  673. health_data["checks"]["classifier_ready"] = True
  674. except Exception as e:
  675. health_data["message"] = f"分类器失败: {str(e)}"
  676. # 检查5: 完整流程测试(可选)
  677. try:
  678. if all(health_data["checks"].values()):
  679. import asyncio
  680. # 异步调用健康检查
  681. test_result = asyncio.run(agent.health_check())
  682. health_data["test_result"] = test_result.get("status") == "healthy"
  683. health_data["status"] = test_result.get("status", "unknown")
  684. health_data["message"] = test_result.get("message", "健康检查完成")
  685. else:
  686. health_data["status"] = "degraded"
  687. health_data["message"] = "部分组件异常"
  688. except Exception as e:
  689. print(f"[ERROR] 健康检查异常: {str(e)}")
  690. import traceback
  691. print(f"[ERROR] 详细健康检查错误: {traceback.format_exc()}")
  692. health_data["status"] = "degraded"
  693. health_data["message"] = f"完整测试失败: {str(e)}"
  694. # 根据状态返回相应的HTTP代码 - 使用标准化健康检查响应
  695. from common.result import health_success_response, health_error_response
  696. if health_data["status"] == "healthy":
  697. return jsonify(health_success_response(**health_data))
  698. elif health_data["status"] == "degraded":
  699. return jsonify(health_error_response(**health_data)), 503
  700. else:
  701. # 确保状态设置为unhealthy
  702. health_data["status"] = "unhealthy"
  703. return jsonify(health_error_response(**health_data)), 503
  704. except Exception as e:
  705. print(f"[ERROR] 顶层健康检查异常: {str(e)}")
  706. import traceback
  707. print(f"[ERROR] 详细错误信息: {traceback.format_exc()}")
  708. from common.result import internal_error_response
  709. return jsonify(internal_error_response(
  710. response_text="健康检查失败,请稍后重试"
  711. )), 500
  712. # ==================== 日常管理API ====================
  713. @app.flask_app.route('/api/v0/cache_overview', methods=['GET'])
  714. def cache_overview():
  715. """日常管理:轻量概览 - 合并原cache_inspect的核心功能"""
  716. try:
  717. cache = app.cache
  718. result_data = {
  719. 'overview_summary': {
  720. 'total_conversations': 0,
  721. 'total_sessions': 0,
  722. 'query_time': datetime.now().isoformat()
  723. },
  724. 'recent_conversations': [], # 最近的对话
  725. 'session_summary': [] # 会话摘要
  726. }
  727. if hasattr(cache, 'cache') and isinstance(cache.cache, dict):
  728. result_data['overview_summary']['total_conversations'] = len(cache.cache)
  729. # 获取会话信息
  730. if hasattr(cache, 'get_all_sessions'):
  731. all_sessions = cache.get_all_sessions()
  732. result_data['overview_summary']['total_sessions'] = len(all_sessions)
  733. # 会话摘要(按最近活动排序)
  734. session_list = []
  735. for session_id, session_data in all_sessions.items():
  736. session_summary = {
  737. 'session_id': session_id,
  738. 'start_time': session_data['start_time'].isoformat(),
  739. 'conversation_count': session_data.get('conversation_count', 0),
  740. 'duration_seconds': session_data.get('session_duration_seconds', 0),
  741. 'last_activity': session_data.get('last_activity', session_data['start_time']).isoformat(),
  742. 'is_active': (datetime.now() - session_data.get('last_activity', session_data['start_time'])).total_seconds() < 1800 # 30分钟内活跃
  743. }
  744. session_list.append(session_summary)
  745. # 按最后活动时间排序
  746. session_list.sort(key=lambda x: x['last_activity'], reverse=True)
  747. result_data['session_summary'] = session_list
  748. # 最近的对话(最多显示10个)
  749. conversation_list = []
  750. for conversation_id, conversation_data in cache.cache.items():
  751. conversation_start_time = cache.conversation_start_times.get(conversation_id)
  752. conversation_info = {
  753. 'conversation_id': conversation_id,
  754. 'conversation_start_time': conversation_start_time.isoformat() if conversation_start_time else None,
  755. 'session_id': cache.conversation_to_session.get(conversation_id),
  756. 'has_question': 'question' in conversation_data,
  757. 'has_sql': 'sql' in conversation_data,
  758. 'has_data': 'df' in conversation_data and conversation_data['df'] is not None,
  759. 'question_preview': conversation_data.get('question', '')[:80] + '...' if len(conversation_data.get('question', '')) > 80 else conversation_data.get('question', ''),
  760. }
  761. # 计算对话持续时间
  762. if conversation_start_time:
  763. duration = datetime.now() - conversation_start_time
  764. conversation_info['conversation_duration_seconds'] = duration.total_seconds()
  765. conversation_list.append(conversation_info)
  766. # 按对话开始时间排序,显示最新的10个
  767. conversation_list.sort(key=lambda x: x['conversation_start_time'] or '', reverse=True)
  768. result_data['recent_conversations'] = conversation_list[:10]
  769. from common.result import success_response
  770. return jsonify(success_response(
  771. response_text="缓存概览查询完成",
  772. data=result_data
  773. ))
  774. except Exception as e:
  775. from common.result import internal_error_response
  776. return jsonify(internal_error_response(
  777. response_text="获取缓存概览失败,请稍后重试"
  778. )), 500
  779. @app.flask_app.route('/api/v0/cache_stats', methods=['GET'])
  780. def cache_stats():
  781. """日常管理:统计信息 - 合并原session_stats和cache_stats功能"""
  782. try:
  783. cache = app.cache
  784. current_time = datetime.now()
  785. stats = {
  786. 'basic_stats': {
  787. 'total_sessions': len(getattr(cache, 'session_info', {})),
  788. 'total_conversations': len(getattr(cache, 'cache', {})),
  789. 'active_sessions': 0, # 最近30分钟有活动
  790. 'average_conversations_per_session': 0
  791. },
  792. 'time_distribution': {
  793. 'sessions': {
  794. 'last_1_hour': 0,
  795. 'last_6_hours': 0,
  796. 'last_24_hours': 0,
  797. 'last_7_days': 0,
  798. 'older': 0
  799. },
  800. 'conversations': {
  801. 'last_1_hour': 0,
  802. 'last_6_hours': 0,
  803. 'last_24_hours': 0,
  804. 'last_7_days': 0,
  805. 'older': 0
  806. }
  807. },
  808. 'session_details': [],
  809. 'time_ranges': {
  810. 'oldest_session': None,
  811. 'newest_session': None,
  812. 'oldest_conversation': None,
  813. 'newest_conversation': None
  814. }
  815. }
  816. # 会话统计
  817. if hasattr(cache, 'session_info'):
  818. session_times = []
  819. total_conversations = 0
  820. for session_id, session_data in cache.session_info.items():
  821. start_time = session_data['start_time']
  822. session_times.append(start_time)
  823. conversation_count = len(session_data.get('conversations', []))
  824. total_conversations += conversation_count
  825. # 检查活跃状态
  826. last_activity = session_data.get('last_activity', session_data['start_time'])
  827. if (current_time - last_activity).total_seconds() < 1800:
  828. stats['basic_stats']['active_sessions'] += 1
  829. # 时间分布统计
  830. age_hours = (current_time - start_time).total_seconds() / 3600
  831. if age_hours <= 1:
  832. stats['time_distribution']['sessions']['last_1_hour'] += 1
  833. elif age_hours <= 6:
  834. stats['time_distribution']['sessions']['last_6_hours'] += 1
  835. elif age_hours <= 24:
  836. stats['time_distribution']['sessions']['last_24_hours'] += 1
  837. elif age_hours <= 168: # 7 days
  838. stats['time_distribution']['sessions']['last_7_days'] += 1
  839. else:
  840. stats['time_distribution']['sessions']['older'] += 1
  841. # 会话详细信息
  842. session_duration = current_time - start_time
  843. stats['session_details'].append({
  844. 'session_id': session_id,
  845. 'start_time': start_time.isoformat(),
  846. 'last_activity': last_activity.isoformat(),
  847. 'conversation_count': conversation_count,
  848. 'duration_seconds': session_duration.total_seconds(),
  849. 'duration_formatted': str(session_duration),
  850. 'is_active': (current_time - last_activity).total_seconds() < 1800,
  851. 'browser_session_id': session_data.get('browser_session_id')
  852. })
  853. # 计算平均值
  854. if len(cache.session_info) > 0:
  855. stats['basic_stats']['average_conversations_per_session'] = total_conversations / len(cache.session_info)
  856. # 时间范围
  857. if session_times:
  858. stats['time_ranges']['oldest_session'] = min(session_times).isoformat()
  859. stats['time_ranges']['newest_session'] = max(session_times).isoformat()
  860. # 对话统计
  861. if hasattr(cache, 'conversation_start_times'):
  862. conversation_times = []
  863. for conv_time in cache.conversation_start_times.values():
  864. conversation_times.append(conv_time)
  865. age_hours = (current_time - conv_time).total_seconds() / 3600
  866. if age_hours <= 1:
  867. stats['time_distribution']['conversations']['last_1_hour'] += 1
  868. elif age_hours <= 6:
  869. stats['time_distribution']['conversations']['last_6_hours'] += 1
  870. elif age_hours <= 24:
  871. stats['time_distribution']['conversations']['last_24_hours'] += 1
  872. elif age_hours <= 168:
  873. stats['time_distribution']['conversations']['last_7_days'] += 1
  874. else:
  875. stats['time_distribution']['conversations']['older'] += 1
  876. if conversation_times:
  877. stats['time_ranges']['oldest_conversation'] = min(conversation_times).isoformat()
  878. stats['time_ranges']['newest_conversation'] = max(conversation_times).isoformat()
  879. # 按最近活动排序会话详情
  880. stats['session_details'].sort(key=lambda x: x['last_activity'], reverse=True)
  881. from common.result import success_response
  882. return jsonify(success_response(
  883. response_text="缓存统计信息查询完成",
  884. data=stats
  885. ))
  886. except Exception as e:
  887. from common.result import internal_error_response
  888. return jsonify(internal_error_response(
  889. response_text="获取缓存统计失败,请稍后重试"
  890. )), 500
  891. # ==================== 高级功能API ====================
  892. @app.flask_app.route('/api/v0/cache_export', methods=['GET'])
  893. def cache_export():
  894. """高级功能:完整导出 - 保持原cache_raw_export的完整功能"""
  895. try:
  896. cache = app.cache
  897. # 验证缓存的实际结构
  898. if not hasattr(cache, 'cache'):
  899. from common.result import internal_error_response
  900. return jsonify(internal_error_response(
  901. response_text="缓存对象结构异常,请联系系统管理员"
  902. )), 500
  903. if not isinstance(cache.cache, dict):
  904. from common.result import internal_error_response
  905. return jsonify(internal_error_response(
  906. response_text="缓存数据类型异常,请联系系统管理员"
  907. )), 500
  908. # 定义JSON序列化辅助函数
  909. def make_json_serializable(obj):
  910. """将对象转换为JSON可序列化的格式"""
  911. if obj is None:
  912. return None
  913. elif isinstance(obj, (str, int, float, bool)):
  914. return obj
  915. elif isinstance(obj, (list, tuple)):
  916. return [make_json_serializable(item) for item in obj]
  917. elif isinstance(obj, dict):
  918. return {str(k): make_json_serializable(v) for k, v in obj.items()}
  919. elif hasattr(obj, 'isoformat'): # datetime objects
  920. return obj.isoformat()
  921. elif hasattr(obj, 'item'): # numpy scalars
  922. return obj.item()
  923. elif hasattr(obj, 'tolist'): # numpy arrays
  924. return obj.tolist()
  925. elif hasattr(obj, '__dict__'): # pandas dtypes and other objects
  926. return str(obj)
  927. else:
  928. return str(obj)
  929. # 获取完整的原始缓存数据
  930. raw_cache = cache.cache
  931. # 获取会话和对话时间信息
  932. conversation_times = getattr(cache, 'conversation_start_times', {})
  933. session_info = getattr(cache, 'session_info', {})
  934. conversation_to_session = getattr(cache, 'conversation_to_session', {})
  935. export_data = {
  936. 'export_metadata': {
  937. 'export_time': datetime.now().isoformat(),
  938. 'total_conversations': len(raw_cache),
  939. 'total_sessions': len(session_info),
  940. 'cache_type': type(cache).__name__,
  941. 'cache_object_info': str(cache),
  942. 'has_session_times': bool(session_info),
  943. 'has_conversation_times': bool(conversation_times)
  944. },
  945. 'session_info': {
  946. session_id: {
  947. 'start_time': session_data['start_time'].isoformat(),
  948. 'last_activity': session_data.get('last_activity', session_data['start_time']).isoformat(),
  949. 'conversations': session_data['conversations'],
  950. 'conversation_count': len(session_data['conversations']),
  951. 'browser_session_id': session_data.get('browser_session_id'),
  952. 'user_info': session_data.get('user_info', {})
  953. }
  954. for session_id, session_data in session_info.items()
  955. },
  956. 'conversation_times': {
  957. conversation_id: start_time.isoformat()
  958. for conversation_id, start_time in conversation_times.items()
  959. },
  960. 'conversation_to_session_mapping': conversation_to_session,
  961. 'conversations': {}
  962. }
  963. # 处理每个对话的完整数据
  964. for conversation_id, conversation_data in raw_cache.items():
  965. # 获取时间信息
  966. conversation_start_time = conversation_times.get(conversation_id)
  967. session_id = conversation_to_session.get(conversation_id)
  968. session_start_time = None
  969. if session_id and session_id in session_info:
  970. session_start_time = session_info[session_id]['start_time']
  971. processed_conversation = {
  972. 'conversation_id': conversation_id,
  973. 'conversation_start_time': conversation_start_time.isoformat() if conversation_start_time else None,
  974. 'session_id': session_id,
  975. 'session_start_time': session_start_time.isoformat() if session_start_time else None,
  976. 'field_count': len(conversation_data),
  977. 'fields': {}
  978. }
  979. # 添加时间计算
  980. if conversation_start_time:
  981. conversation_duration = datetime.now() - conversation_start_time
  982. processed_conversation['conversation_duration_seconds'] = conversation_duration.total_seconds()
  983. processed_conversation['conversation_duration_formatted'] = str(conversation_duration)
  984. if session_start_time:
  985. session_duration = datetime.now() - session_start_time
  986. processed_conversation['session_duration_seconds'] = session_duration.total_seconds()
  987. processed_conversation['session_duration_formatted'] = str(session_duration)
  988. # 处理每个字段,确保JSON序列化安全
  989. for field_name, field_value in conversation_data.items():
  990. field_info = {
  991. 'field_name': field_name,
  992. 'data_type': type(field_value).__name__,
  993. 'is_none': field_value is None
  994. }
  995. try:
  996. if field_value is None:
  997. field_info['value'] = None
  998. elif field_name in ['conversation_start_time', 'session_start_time']:
  999. # 处理时间字段
  1000. field_info['content'] = make_json_serializable(field_value)
  1001. elif field_name == 'df' and field_value is not None:
  1002. # DataFrame的安全处理
  1003. if hasattr(field_value, 'to_dict'):
  1004. # 安全地处理dtypes
  1005. try:
  1006. dtypes_dict = {}
  1007. for col, dtype in field_value.dtypes.items():
  1008. dtypes_dict[col] = str(dtype)
  1009. except Exception:
  1010. dtypes_dict = {"error": "无法序列化dtypes"}
  1011. # 安全地处理内存使用
  1012. try:
  1013. memory_usage = field_value.memory_usage(deep=True)
  1014. memory_dict = {}
  1015. for idx, usage in memory_usage.items():
  1016. memory_dict[str(idx)] = int(usage) if hasattr(usage, 'item') else int(usage)
  1017. except Exception:
  1018. memory_dict = {"error": "无法获取内存使用信息"}
  1019. field_info.update({
  1020. 'dataframe_info': {
  1021. 'shape': list(field_value.shape),
  1022. 'columns': list(field_value.columns),
  1023. 'dtypes': dtypes_dict,
  1024. 'index_info': {
  1025. 'type': type(field_value.index).__name__,
  1026. 'length': len(field_value.index)
  1027. }
  1028. },
  1029. 'data': make_json_serializable(field_value.to_dict('records')),
  1030. 'memory_usage': memory_dict
  1031. })
  1032. else:
  1033. field_info['value'] = str(field_value)
  1034. field_info['note'] = 'not_standard_dataframe'
  1035. elif field_name == 'fig_json':
  1036. # 图表JSON数据处理
  1037. if isinstance(field_value, str):
  1038. try:
  1039. import json
  1040. parsed_fig = json.loads(field_value)
  1041. field_info.update({
  1042. 'json_valid': True,
  1043. 'json_size_bytes': len(field_value),
  1044. 'plotly_structure': {
  1045. 'has_data': 'data' in parsed_fig,
  1046. 'has_layout': 'layout' in parsed_fig,
  1047. 'data_traces_count': len(parsed_fig.get('data', [])),
  1048. },
  1049. 'raw_json': field_value
  1050. })
  1051. except json.JSONDecodeError:
  1052. field_info.update({
  1053. 'json_valid': False,
  1054. 'raw_content': str(field_value)
  1055. })
  1056. else:
  1057. field_info['value'] = make_json_serializable(field_value)
  1058. elif field_name == 'followup_questions':
  1059. # 后续问题列表
  1060. field_info.update({
  1061. 'content': make_json_serializable(field_value)
  1062. })
  1063. elif field_name in ['question', 'sql', 'summary']:
  1064. # 文本字段
  1065. if isinstance(field_value, str):
  1066. field_info.update({
  1067. 'text_length': len(field_value),
  1068. 'content': field_value
  1069. })
  1070. else:
  1071. field_info['value'] = make_json_serializable(field_value)
  1072. else:
  1073. # 未知字段的安全处理
  1074. field_info['content'] = make_json_serializable(field_value)
  1075. except Exception as e:
  1076. field_info.update({
  1077. 'processing_error': str(e),
  1078. 'fallback_value': str(field_value)[:500] + '...' if len(str(field_value)) > 500 else str(field_value)
  1079. })
  1080. processed_conversation['fields'][field_name] = field_info
  1081. export_data['conversations'][conversation_id] = processed_conversation
  1082. # 添加缓存统计信息
  1083. field_frequency = {}
  1084. data_types_found = set()
  1085. total_dataframes = 0
  1086. total_questions = 0
  1087. for conv_data in export_data['conversations'].values():
  1088. for field_name, field_info in conv_data['fields'].items():
  1089. field_frequency[field_name] = field_frequency.get(field_name, 0) + 1
  1090. data_types_found.add(field_info['data_type'])
  1091. if field_name == 'df' and not field_info['is_none']:
  1092. total_dataframes += 1
  1093. if field_name == 'question' and not field_info['is_none']:
  1094. total_questions += 1
  1095. export_data['cache_statistics'] = {
  1096. 'field_frequency': field_frequency,
  1097. 'data_types_found': list(data_types_found),
  1098. 'total_dataframes': total_dataframes,
  1099. 'total_questions': total_questions,
  1100. 'has_session_timing': 'session_start_time' in field_frequency,
  1101. 'has_conversation_timing': 'conversation_start_time' in field_frequency
  1102. }
  1103. from common.result import success_response
  1104. return jsonify(success_response(
  1105. response_text="缓存数据导出完成",
  1106. data=export_data
  1107. ))
  1108. except Exception as e:
  1109. import traceback
  1110. error_details = {
  1111. 'error_message': str(e),
  1112. 'error_type': type(e).__name__,
  1113. 'traceback': traceback.format_exc()
  1114. }
  1115. from common.result import internal_error_response
  1116. return jsonify(internal_error_response(
  1117. response_text="导出缓存失败,请稍后重试"
  1118. )), 500
  1119. # ==================== 清理功能API ====================
  1120. @app.flask_app.route('/api/v0/cache_preview_cleanup', methods=['POST'])
  1121. def cache_preview_cleanup():
  1122. """清理功能:预览删除操作 - 保持原功能"""
  1123. try:
  1124. req = request.get_json(force=True)
  1125. # 时间条件 - 支持三种方式
  1126. older_than_hours = req.get('older_than_hours')
  1127. older_than_days = req.get('older_than_days')
  1128. before_timestamp = req.get('before_timestamp') # YYYY-MM-DD HH:MM:SS 格式
  1129. cache = app.cache
  1130. # 计算截止时间
  1131. cutoff_time = None
  1132. time_condition = None
  1133. if older_than_hours:
  1134. cutoff_time = datetime.now() - timedelta(hours=older_than_hours)
  1135. time_condition = f"older_than_hours: {older_than_hours}"
  1136. elif older_than_days:
  1137. cutoff_time = datetime.now() - timedelta(days=older_than_days)
  1138. time_condition = f"older_than_days: {older_than_days}"
  1139. elif before_timestamp:
  1140. try:
  1141. # 支持 YYYY-MM-DD HH:MM:SS 格式
  1142. cutoff_time = datetime.strptime(before_timestamp, '%Y-%m-%d %H:%M:%S')
  1143. time_condition = f"before_timestamp: {before_timestamp}"
  1144. except ValueError:
  1145. from common.result import validation_failed_response
  1146. return jsonify(validation_failed_response(
  1147. response_text="before_timestamp格式错误,请使用 YYYY-MM-DD HH:MM:SS 格式"
  1148. )), 422
  1149. else:
  1150. from common.result import bad_request_response
  1151. return jsonify(bad_request_response(
  1152. response_text="必须提供时间条件:older_than_hours, older_than_days 或 before_timestamp (YYYY-MM-DD HH:MM:SS)",
  1153. missing_params=["older_than_hours", "older_than_days", "before_timestamp"]
  1154. )), 400
  1155. preview = {
  1156. 'time_condition': time_condition,
  1157. 'cutoff_time': cutoff_time.isoformat(),
  1158. 'will_be_removed': {
  1159. 'sessions': []
  1160. },
  1161. 'will_be_kept': {
  1162. 'sessions_count': 0,
  1163. 'conversations_count': 0
  1164. },
  1165. 'summary': {
  1166. 'sessions_to_remove': 0,
  1167. 'conversations_to_remove': 0,
  1168. 'sessions_to_keep': 0,
  1169. 'conversations_to_keep': 0
  1170. }
  1171. }
  1172. # 预览按session删除
  1173. sessions_to_remove_count = 0
  1174. conversations_to_remove_count = 0
  1175. for session_id, session_data in cache.session_info.items():
  1176. session_preview = {
  1177. 'session_id': session_id,
  1178. 'start_time': session_data['start_time'].isoformat(),
  1179. 'conversation_count': len(session_data['conversations']),
  1180. 'conversations': []
  1181. }
  1182. # 添加conversation详情
  1183. for conv_id in session_data['conversations']:
  1184. if conv_id in cache.cache:
  1185. conv_data = cache.cache[conv_id]
  1186. session_preview['conversations'].append({
  1187. 'conversation_id': conv_id,
  1188. 'question': conv_data.get('question', '')[:50] + '...' if conv_data.get('question') else '',
  1189. 'start_time': cache.conversation_start_times.get(conv_id, '').isoformat() if cache.conversation_start_times.get(conv_id) else ''
  1190. })
  1191. if session_data['start_time'] < cutoff_time:
  1192. preview['will_be_removed']['sessions'].append(session_preview)
  1193. sessions_to_remove_count += 1
  1194. conversations_to_remove_count += len(session_data['conversations'])
  1195. else:
  1196. preview['will_be_kept']['sessions_count'] += 1
  1197. preview['will_be_kept']['conversations_count'] += len(session_data['conversations'])
  1198. # 更新摘要统计
  1199. preview['summary'] = {
  1200. 'sessions_to_remove': sessions_to_remove_count,
  1201. 'conversations_to_remove': conversations_to_remove_count,
  1202. 'sessions_to_keep': preview['will_be_kept']['sessions_count'],
  1203. 'conversations_to_keep': preview['will_be_kept']['conversations_count']
  1204. }
  1205. from common.result import success_response
  1206. return jsonify(success_response(
  1207. response_text=f"清理预览完成,将删除 {sessions_to_remove_count} 个会话和 {conversations_to_remove_count} 个对话",
  1208. data=preview
  1209. ))
  1210. except Exception as e:
  1211. from common.result import internal_error_response
  1212. return jsonify(internal_error_response(
  1213. response_text="预览清理操作失败,请稍后重试"
  1214. )), 500
  1215. @app.flask_app.route('/api/v0/cache_cleanup', methods=['POST'])
  1216. def cache_cleanup():
  1217. """清理功能:实际删除缓存 - 保持原功能"""
  1218. try:
  1219. req = request.get_json(force=True)
  1220. # 时间条件 - 支持三种方式
  1221. older_than_hours = req.get('older_than_hours')
  1222. older_than_days = req.get('older_than_days')
  1223. before_timestamp = req.get('before_timestamp') # YYYY-MM-DD HH:MM:SS 格式
  1224. cache = app.cache
  1225. if not hasattr(cache, 'session_info'):
  1226. from common.result import service_unavailable_response
  1227. return jsonify(service_unavailable_response(
  1228. response_text="缓存不支持会话功能"
  1229. )), 503
  1230. # 计算截止时间
  1231. cutoff_time = None
  1232. time_condition = None
  1233. if older_than_hours:
  1234. cutoff_time = datetime.now() - timedelta(hours=older_than_hours)
  1235. time_condition = f"older_than_hours: {older_than_hours}"
  1236. elif older_than_days:
  1237. cutoff_time = datetime.now() - timedelta(days=older_than_days)
  1238. time_condition = f"older_than_days: {older_than_days}"
  1239. elif before_timestamp:
  1240. try:
  1241. # 支持 YYYY-MM-DD HH:MM:SS 格式
  1242. cutoff_time = datetime.strptime(before_timestamp, '%Y-%m-%d %H:%M:%S')
  1243. time_condition = f"before_timestamp: {before_timestamp}"
  1244. except ValueError:
  1245. from common.result import validation_failed_response
  1246. return jsonify(validation_failed_response(
  1247. response_text="before_timestamp格式错误,请使用 YYYY-MM-DD HH:MM:SS 格式"
  1248. )), 422
  1249. else:
  1250. from common.result import bad_request_response
  1251. return jsonify(bad_request_response(
  1252. response_text="必须提供时间条件:older_than_hours, older_than_days 或 before_timestamp (YYYY-MM-DD HH:MM:SS)",
  1253. missing_params=["older_than_hours", "older_than_days", "before_timestamp"]
  1254. )), 400
  1255. cleanup_stats = {
  1256. 'time_condition': time_condition,
  1257. 'cutoff_time': cutoff_time.isoformat(),
  1258. 'sessions_removed': 0,
  1259. 'conversations_removed': 0,
  1260. 'sessions_kept': 0,
  1261. 'conversations_kept': 0,
  1262. 'removed_session_ids': [],
  1263. 'removed_conversation_ids': []
  1264. }
  1265. # 按session删除
  1266. sessions_to_remove = []
  1267. for session_id, session_data in cache.session_info.items():
  1268. if session_data['start_time'] < cutoff_time:
  1269. sessions_to_remove.append(session_id)
  1270. # 删除符合条件的sessions及其所有conversations
  1271. for session_id in sessions_to_remove:
  1272. session_data = cache.session_info[session_id]
  1273. conversations_in_session = session_data['conversations'].copy()
  1274. # 删除session中的所有conversations
  1275. for conv_id in conversations_in_session:
  1276. if conv_id in cache.cache:
  1277. del cache.cache[conv_id]
  1278. cleanup_stats['conversations_removed'] += 1
  1279. cleanup_stats['removed_conversation_ids'].append(conv_id)
  1280. # 清理conversation相关的时间记录
  1281. if hasattr(cache, 'conversation_start_times') and conv_id in cache.conversation_start_times:
  1282. del cache.conversation_start_times[conv_id]
  1283. if hasattr(cache, 'conversation_to_session') and conv_id in cache.conversation_to_session:
  1284. del cache.conversation_to_session[conv_id]
  1285. # 删除session记录
  1286. del cache.session_info[session_id]
  1287. cleanup_stats['sessions_removed'] += 1
  1288. cleanup_stats['removed_session_ids'].append(session_id)
  1289. # 统计保留的sessions和conversations
  1290. cleanup_stats['sessions_kept'] = len(cache.session_info)
  1291. cleanup_stats['conversations_kept'] = len(cache.cache)
  1292. from common.result import success_response
  1293. return jsonify(success_response(
  1294. response_text=f"缓存清理完成,删除了 {cleanup_stats['sessions_removed']} 个会话和 {cleanup_stats['conversations_removed']} 个对话",
  1295. data=cleanup_stats
  1296. ))
  1297. except Exception as e:
  1298. from common.result import internal_error_response
  1299. return jsonify(internal_error_response(
  1300. response_text="缓存清理失败,请稍后重试"
  1301. )), 500
  1302. @app.flask_app.route('/api/v0/training_error_question_sql', methods=['POST'])
  1303. def training_error_question_sql():
  1304. """
  1305. 存储错误的question-sql对到error_sql集合中
  1306. 此API将接收的错误question/sql pair写入到error_sql集合中,用于记录和分析错误的SQL查询。
  1307. Args:
  1308. question (str, required): 用户问题
  1309. sql (str, required): 对应的错误SQL查询语句
  1310. Returns:
  1311. JSON: 包含训练ID和成功消息的响应
  1312. """
  1313. try:
  1314. data = request.get_json()
  1315. question = data.get('question')
  1316. sql = data.get('sql')
  1317. print(f"[DEBUG] 接收到错误SQL训练请求: question={question}, sql={sql}")
  1318. if not question or not sql:
  1319. from common.result import bad_request_response
  1320. missing_params = []
  1321. if not question:
  1322. missing_params.append("question")
  1323. if not sql:
  1324. missing_params.append("sql")
  1325. return jsonify(bad_request_response(
  1326. response_text="question和sql参数都是必需的",
  1327. missing_params=missing_params
  1328. )), 400
  1329. # 使用vn实例的train_error_sql方法存储错误SQL
  1330. id = vn.train_error_sql(question=question, sql=sql)
  1331. print(f"[INFO] 成功存储错误SQL,ID: {id}")
  1332. from common.result import success_response
  1333. return jsonify(success_response(
  1334. response_text="错误SQL对已成功存储",
  1335. data={
  1336. "id": id,
  1337. "message": "错误SQL对已成功存储到error_sql集合"
  1338. }
  1339. ))
  1340. except Exception as e:
  1341. print(f"[ERROR] 存储错误SQL失败: {str(e)}")
  1342. from common.result import internal_error_response
  1343. return jsonify(internal_error_response(
  1344. response_text="存储错误SQL失败,请稍后重试"
  1345. )), 500
  1346. # ==================== Redis对话管理API ====================
  1347. @app.flask_app.route('/api/v0/user/<user_id>/conversations', methods=['GET'])
  1348. def get_user_conversations(user_id: str):
  1349. """获取用户的对话列表(按时间倒序)"""
  1350. try:
  1351. limit = request.args.get('limit', USER_MAX_CONVERSATIONS, type=int)
  1352. conversations = redis_conversation_manager.get_conversations(user_id, limit)
  1353. # 为每个对话动态获取标题(第一条用户消息)
  1354. for conversation in conversations:
  1355. conversation_id = conversation['conversation_id']
  1356. try:
  1357. # 获取所有消息,然后取第一条用户消息作为标题
  1358. messages = redis_conversation_manager.get_conversation_messages(conversation_id)
  1359. if messages and len(messages) > 0:
  1360. # 找到第一条用户消息(按时间顺序)
  1361. first_user_message = None
  1362. for message in messages:
  1363. if message.get('role') == 'user':
  1364. first_user_message = message
  1365. break
  1366. if first_user_message:
  1367. title = first_user_message.get('content', '对话').strip()
  1368. # 限制标题长度,保持整洁
  1369. if len(title) > 50:
  1370. conversation['conversation_title'] = title[:47] + "..."
  1371. else:
  1372. conversation['conversation_title'] = title
  1373. else:
  1374. conversation['conversation_title'] = "对话"
  1375. else:
  1376. conversation['conversation_title'] = "空对话"
  1377. except Exception as e:
  1378. print(f"[WARNING] 获取对话标题失败 {conversation_id}: {str(e)}")
  1379. conversation['conversation_title'] = "对话"
  1380. return jsonify(success_response(
  1381. response_text="获取用户对话列表成功",
  1382. data={
  1383. "user_id": user_id,
  1384. "conversations": conversations,
  1385. "total_count": len(conversations)
  1386. }
  1387. ))
  1388. except Exception as e:
  1389. return jsonify(internal_error_response(
  1390. response_text="获取对话列表失败,请稍后重试"
  1391. )), 500
  1392. @app.flask_app.route('/api/v0/conversation/<conversation_id>/messages', methods=['GET'])
  1393. def get_conversation_messages(conversation_id: str):
  1394. """获取特定对话的消息历史"""
  1395. try:
  1396. limit = request.args.get('limit', type=int) # 可选参数
  1397. messages = redis_conversation_manager.get_conversation_messages(conversation_id, limit)
  1398. meta = redis_conversation_manager.get_conversation_meta(conversation_id)
  1399. return jsonify(success_response(
  1400. response_text="获取对话消息成功",
  1401. data={
  1402. "conversation_id": conversation_id,
  1403. "conversation_meta": meta,
  1404. "messages": messages,
  1405. "message_count": len(messages)
  1406. }
  1407. ))
  1408. except Exception as e:
  1409. return jsonify(internal_error_response(
  1410. response_text="获取对话消息失败"
  1411. )), 500
  1412. @app.flask_app.route('/api/v0/conversation/<conversation_id>/context', methods=['GET'])
  1413. def get_conversation_context(conversation_id: str):
  1414. """获取对话上下文(格式化用于LLM)"""
  1415. try:
  1416. count = request.args.get('count', CONVERSATION_CONTEXT_COUNT, type=int)
  1417. context = redis_conversation_manager.get_context_for_display(conversation_id, count)
  1418. return jsonify(success_response(
  1419. response_text="获取对话上下文成功",
  1420. data={
  1421. "conversation_id": conversation_id,
  1422. "context": context,
  1423. "context_message_count": count
  1424. }
  1425. ))
  1426. except Exception as e:
  1427. return jsonify(internal_error_response(
  1428. response_text="获取对话上下文失败"
  1429. )), 500
  1430. @app.flask_app.route('/api/v0/conversation_stats', methods=['GET'])
  1431. def conversation_stats():
  1432. """获取对话系统统计信息"""
  1433. try:
  1434. stats = redis_conversation_manager.get_stats()
  1435. return jsonify(success_response(
  1436. response_text="获取统计信息成功",
  1437. data=stats
  1438. ))
  1439. except Exception as e:
  1440. return jsonify(internal_error_response(
  1441. response_text="获取统计信息失败,请稍后重试"
  1442. )), 500
  1443. @app.flask_app.route('/api/v0/conversation_cleanup', methods=['POST'])
  1444. def conversation_cleanup():
  1445. """手动清理过期对话"""
  1446. try:
  1447. redis_conversation_manager.cleanup_expired_conversations()
  1448. return jsonify(success_response(
  1449. response_text="对话清理完成"
  1450. ))
  1451. except Exception as e:
  1452. return jsonify(internal_error_response(
  1453. response_text="对话清理失败,请稍后重试"
  1454. )), 500
  1455. @app.flask_app.route('/api/v0/user/<user_id>/conversations/full', methods=['GET'])
  1456. def get_user_conversations_with_messages(user_id: str):
  1457. """
  1458. 获取用户的完整对话数据(包含所有消息)
  1459. 一次性返回用户的所有对话和每个对话下的消息历史
  1460. Args:
  1461. user_id: 用户ID(路径参数)
  1462. conversation_limit: 对话数量限制(查询参数,可选,不传则返回所有对话)
  1463. message_limit: 每个对话的消息数限制(查询参数,可选,不传则返回所有消息)
  1464. Returns:
  1465. 包含用户所有对话和消息的完整数据
  1466. """
  1467. try:
  1468. # 获取可选参数,不传递时使用None(返回所有记录)
  1469. conversation_limit = request.args.get('conversation_limit', type=int)
  1470. message_limit = request.args.get('message_limit', type=int)
  1471. # 获取用户的对话列表
  1472. conversations = redis_conversation_manager.get_conversations(user_id, conversation_limit)
  1473. # 为每个对话获取消息历史
  1474. full_conversations = []
  1475. total_messages = 0
  1476. for conversation in conversations:
  1477. conversation_id = conversation['conversation_id']
  1478. # 获取对话消息
  1479. messages = redis_conversation_manager.get_conversation_messages(
  1480. conversation_id, message_limit
  1481. )
  1482. # 获取对话元数据
  1483. meta = redis_conversation_manager.get_conversation_meta(conversation_id)
  1484. # 组合完整数据
  1485. full_conversation = {
  1486. **conversation, # 基础对话信息
  1487. 'meta': meta, # 对话元数据
  1488. 'messages': messages, # 消息列表
  1489. 'message_count': len(messages)
  1490. }
  1491. full_conversations.append(full_conversation)
  1492. total_messages += len(messages)
  1493. return jsonify(success_response(
  1494. response_text="获取用户完整对话数据成功",
  1495. data={
  1496. "user_id": user_id,
  1497. "conversations": full_conversations,
  1498. "total_conversations": len(full_conversations),
  1499. "total_messages": total_messages,
  1500. "conversation_limit_applied": conversation_limit,
  1501. "message_limit_applied": message_limit,
  1502. "query_time": datetime.now().isoformat()
  1503. }
  1504. ))
  1505. except Exception as e:
  1506. print(f"[ERROR] 获取用户完整对话数据失败: {str(e)}")
  1507. return jsonify(internal_error_response(
  1508. response_text="获取用户对话数据失败,请稍后重试"
  1509. )), 500
  1510. # ==================== Embedding缓存管理接口 ====================
  1511. @app.flask_app.route('/api/v0/embedding_cache_stats', methods=['GET'])
  1512. def embedding_cache_stats():
  1513. """获取embedding缓存统计信息"""
  1514. try:
  1515. from common.embedding_cache_manager import get_embedding_cache_manager
  1516. cache_manager = get_embedding_cache_manager()
  1517. stats = cache_manager.get_cache_stats()
  1518. return jsonify(success_response(
  1519. response_text="获取embedding缓存统计成功",
  1520. data=stats
  1521. ))
  1522. except Exception as e:
  1523. print(f"[ERROR] 获取embedding缓存统计失败: {str(e)}")
  1524. return jsonify(internal_error_response(
  1525. response_text="获取embedding缓存统计失败,请稍后重试"
  1526. )), 500
  1527. @app.flask_app.route('/api/v0/embedding_cache_cleanup', methods=['POST'])
  1528. def embedding_cache_cleanup():
  1529. """清空所有embedding缓存"""
  1530. try:
  1531. from common.embedding_cache_manager import get_embedding_cache_manager
  1532. cache_manager = get_embedding_cache_manager()
  1533. if not cache_manager.is_available():
  1534. return jsonify(internal_error_response(
  1535. response_text="Embedding缓存功能未启用或不可用"
  1536. )), 400
  1537. success = cache_manager.clear_all_cache()
  1538. if success:
  1539. return jsonify(success_response(
  1540. response_text="所有embedding缓存已清空",
  1541. data={"cleared": True}
  1542. ))
  1543. else:
  1544. return jsonify(internal_error_response(
  1545. response_text="清空embedding缓存失败"
  1546. )), 500
  1547. except Exception as e:
  1548. print(f"[ERROR] 清空embedding缓存失败: {str(e)}")
  1549. return jsonify(internal_error_response(
  1550. response_text="清空embedding缓存失败,请稍后重试"
  1551. )), 500
  1552. # ==================== QA反馈系统接口 ====================
  1553. # 全局反馈管理器实例
  1554. qa_feedback_manager = None
  1555. def get_qa_feedback_manager():
  1556. """获取QA反馈管理器实例(懒加载)- 复用Vanna连接版本"""
  1557. global qa_feedback_manager
  1558. if qa_feedback_manager is None:
  1559. try:
  1560. # 优先尝试复用vanna连接
  1561. vanna_instance = None
  1562. try:
  1563. # 尝试获取现有的vanna实例
  1564. if 'get_citu_langraph_agent' in globals():
  1565. agent = get_citu_langraph_agent()
  1566. if hasattr(agent, 'vn'):
  1567. vanna_instance = agent.vn
  1568. elif 'vn' in globals():
  1569. vanna_instance = vn
  1570. else:
  1571. print("[INFO] 未找到可用的vanna实例,将创建新的数据库连接")
  1572. except Exception as e:
  1573. print(f"[INFO] 获取vanna实例失败: {e},将创建新的数据库连接")
  1574. vanna_instance = None
  1575. qa_feedback_manager = QAFeedbackManager(vanna_instance=vanna_instance)
  1576. print("[CITU_APP] QA反馈管理器实例创建成功")
  1577. except Exception as e:
  1578. print(f"[CRITICAL] QA反馈管理器创建失败: {str(e)}")
  1579. raise Exception(f"QA反馈管理器初始化失败: {str(e)}")
  1580. return qa_feedback_manager
  1581. @app.flask_app.route('/api/v0/qa_feedback/query', methods=['POST'])
  1582. def qa_feedback_query():
  1583. """
  1584. 查询反馈记录API
  1585. 支持分页、筛选和排序功能
  1586. """
  1587. try:
  1588. req = request.get_json(force=True)
  1589. # 解析参数,设置默认值
  1590. page = req.get('page', 1)
  1591. page_size = req.get('page_size', 20)
  1592. is_thumb_up = req.get('is_thumb_up')
  1593. create_time_start = req.get('create_time_start')
  1594. create_time_end = req.get('create_time_end')
  1595. is_in_training_data = req.get('is_in_training_data')
  1596. sort_by = req.get('sort_by', 'create_time')
  1597. sort_order = req.get('sort_order', 'desc')
  1598. # 参数验证
  1599. if page < 1:
  1600. return jsonify(bad_request_response(
  1601. response_text="页码必须大于0",
  1602. invalid_params=["page"]
  1603. )), 400
  1604. if page_size < 1 or page_size > 100:
  1605. return jsonify(bad_request_response(
  1606. response_text="每页大小必须在1-100之间",
  1607. invalid_params=["page_size"]
  1608. )), 400
  1609. # 获取反馈管理器并查询
  1610. manager = get_qa_feedback_manager()
  1611. records, total = manager.query_feedback(
  1612. page=page,
  1613. page_size=page_size,
  1614. is_thumb_up=is_thumb_up,
  1615. create_time_start=create_time_start,
  1616. create_time_end=create_time_end,
  1617. is_in_training_data=is_in_training_data,
  1618. sort_by=sort_by,
  1619. sort_order=sort_order
  1620. )
  1621. # 计算分页信息
  1622. total_pages = (total + page_size - 1) // page_size
  1623. return jsonify(success_response(
  1624. response_text=f"查询成功,共找到 {total} 条记录",
  1625. data={
  1626. "records": records,
  1627. "pagination": {
  1628. "page": page,
  1629. "page_size": page_size,
  1630. "total": total,
  1631. "total_pages": total_pages,
  1632. "has_next": page < total_pages,
  1633. "has_prev": page > 1
  1634. }
  1635. }
  1636. ))
  1637. except Exception as e:
  1638. print(f"[ERROR] qa_feedback_query执行失败: {str(e)}")
  1639. return jsonify(internal_error_response(
  1640. response_text="查询反馈记录失败,请稍后重试"
  1641. )), 500
  1642. @app.flask_app.route('/api/v0/qa_feedback/delete/<int:feedback_id>', methods=['DELETE'])
  1643. def qa_feedback_delete(feedback_id):
  1644. """
  1645. 删除反馈记录API
  1646. """
  1647. try:
  1648. manager = get_qa_feedback_manager()
  1649. success = manager.delete_feedback(feedback_id)
  1650. if success:
  1651. return jsonify(success_response(
  1652. response_text=f"反馈记录删除成功",
  1653. data={"deleted_id": feedback_id}
  1654. ))
  1655. else:
  1656. return jsonify(not_found_response(
  1657. response_text=f"反馈记录不存在 (ID: {feedback_id})"
  1658. )), 404
  1659. except Exception as e:
  1660. print(f"[ERROR] qa_feedback_delete执行失败: {str(e)}")
  1661. return jsonify(internal_error_response(
  1662. response_text="删除反馈记录失败,请稍后重试"
  1663. )), 500
  1664. @app.flask_app.route('/api/v0/qa_feedback/update/<int:feedback_id>', methods=['PUT'])
  1665. def qa_feedback_update(feedback_id):
  1666. """
  1667. 更新反馈记录API
  1668. """
  1669. try:
  1670. req = request.get_json(force=True)
  1671. # 提取允许更新的字段
  1672. allowed_fields = ['question', 'sql', 'is_thumb_up', 'user_id', 'is_in_training_data']
  1673. update_data = {}
  1674. for field in allowed_fields:
  1675. if field in req:
  1676. update_data[field] = req[field]
  1677. if not update_data:
  1678. return jsonify(bad_request_response(
  1679. response_text="没有提供有效的更新字段",
  1680. missing_params=allowed_fields
  1681. )), 400
  1682. manager = get_qa_feedback_manager()
  1683. success = manager.update_feedback(feedback_id, **update_data)
  1684. if success:
  1685. return jsonify(success_response(
  1686. response_text="反馈记录更新成功",
  1687. data={
  1688. "updated_id": feedback_id,
  1689. "updated_fields": list(update_data.keys())
  1690. }
  1691. ))
  1692. else:
  1693. return jsonify(not_found_response(
  1694. response_text=f"反馈记录不存在或无变化 (ID: {feedback_id})"
  1695. )), 404
  1696. except Exception as e:
  1697. print(f"[ERROR] qa_feedback_update执行失败: {str(e)}")
  1698. return jsonify(internal_error_response(
  1699. response_text="更新反馈记录失败,请稍后重试"
  1700. )), 500
  1701. @app.flask_app.route('/api/v0/qa_feedback/add_to_training', methods=['POST'])
  1702. def qa_feedback_add_to_training():
  1703. """
  1704. 将反馈记录添加到训练数据集API
  1705. 支持混合批量处理:正向反馈加入SQL训练集,负向反馈加入error_sql训练集
  1706. """
  1707. try:
  1708. req = request.get_json(force=True)
  1709. feedback_ids = req.get('feedback_ids', [])
  1710. if not feedback_ids or not isinstance(feedback_ids, list):
  1711. return jsonify(bad_request_response(
  1712. response_text="缺少有效的反馈ID列表",
  1713. missing_params=["feedback_ids"]
  1714. )), 400
  1715. manager = get_qa_feedback_manager()
  1716. # 获取反馈记录
  1717. records = manager.get_feedback_by_ids(feedback_ids)
  1718. if not records:
  1719. return jsonify(not_found_response(
  1720. response_text="未找到任何有效的反馈记录"
  1721. )), 404
  1722. # 分别处理正向和负向反馈
  1723. positive_count = 0 # 正向训练计数
  1724. negative_count = 0 # 负向训练计数
  1725. already_trained_count = 0 # 已训练计数
  1726. error_count = 0 # 错误计数
  1727. successfully_trained_ids = [] # 成功训练的ID列表
  1728. for record in records:
  1729. try:
  1730. # 检查是否已经在训练数据中
  1731. if record['is_in_training_data']:
  1732. already_trained_count += 1
  1733. continue
  1734. if record['is_thumb_up']:
  1735. # 正向反馈 - 加入标准SQL训练集
  1736. training_id = vn.train(
  1737. question=record['question'],
  1738. sql=record['sql']
  1739. )
  1740. positive_count += 1
  1741. print(f"[TRAINING] 正向训练成功 - ID: {record['id']}, TrainingID: {training_id}")
  1742. else:
  1743. # 负向反馈 - 加入错误SQL训练集
  1744. training_id = vn.train_error_sql(
  1745. question=record['question'],
  1746. sql=record['sql']
  1747. )
  1748. negative_count += 1
  1749. print(f"[TRAINING] 负向训练成功 - ID: {record['id']}, TrainingID: {training_id}")
  1750. successfully_trained_ids.append(record['id'])
  1751. except Exception as e:
  1752. print(f"[ERROR] 训练失败 - 反馈ID: {record['id']}, 错误: {e}")
  1753. error_count += 1
  1754. # 更新训练状态
  1755. if successfully_trained_ids:
  1756. updated_count = manager.mark_training_status(successfully_trained_ids, True)
  1757. print(f"[TRAINING] 批量更新训练状态完成,影响 {updated_count} 条记录")
  1758. # 构建响应
  1759. total_processed = positive_count + negative_count + already_trained_count + error_count
  1760. return jsonify(success_response(
  1761. response_text=f"训练数据添加完成,成功处理 {positive_count + negative_count} 条记录",
  1762. data={
  1763. "summary": {
  1764. "total_requested": len(feedback_ids),
  1765. "total_processed": total_processed,
  1766. "positive_trained": positive_count,
  1767. "negative_trained": negative_count,
  1768. "already_trained": already_trained_count,
  1769. "errors": error_count
  1770. },
  1771. "successfully_trained_ids": successfully_trained_ids,
  1772. "training_details": {
  1773. "sql_training_count": positive_count,
  1774. "error_sql_training_count": negative_count
  1775. }
  1776. }
  1777. ))
  1778. except Exception as e:
  1779. print(f"[ERROR] qa_feedback_add_to_training执行失败: {str(e)}")
  1780. return jsonify(internal_error_response(
  1781. response_text="添加训练数据失败,请稍后重试"
  1782. )), 500
  1783. @app.flask_app.route('/api/v0/qa_feedback/add', methods=['POST'])
  1784. def qa_feedback_add():
  1785. """
  1786. 添加反馈记录API
  1787. 用于前端直接创建反馈记录
  1788. """
  1789. try:
  1790. req = request.get_json(force=True)
  1791. question = req.get('question')
  1792. sql = req.get('sql')
  1793. is_thumb_up = req.get('is_thumb_up')
  1794. user_id = req.get('user_id', 'guest')
  1795. # 参数验证
  1796. if not question:
  1797. return jsonify(bad_request_response(
  1798. response_text="缺少必需参数:question",
  1799. missing_params=["question"]
  1800. )), 400
  1801. if not sql:
  1802. return jsonify(bad_request_response(
  1803. response_text="缺少必需参数:sql",
  1804. missing_params=["sql"]
  1805. )), 400
  1806. if is_thumb_up is None:
  1807. return jsonify(bad_request_response(
  1808. response_text="缺少必需参数:is_thumb_up",
  1809. missing_params=["is_thumb_up"]
  1810. )), 400
  1811. manager = get_qa_feedback_manager()
  1812. feedback_id = manager.add_feedback(
  1813. question=question,
  1814. sql=sql,
  1815. is_thumb_up=bool(is_thumb_up),
  1816. user_id=user_id
  1817. )
  1818. return jsonify(success_response(
  1819. response_text="反馈记录创建成功",
  1820. data={
  1821. "feedback_id": feedback_id
  1822. }
  1823. ))
  1824. except Exception as e:
  1825. print(f"[ERROR] qa_feedback_add执行失败: {str(e)}")
  1826. return jsonify(internal_error_response(
  1827. response_text="创建反馈记录失败,请稍后重试"
  1828. )), 500
  1829. @app.flask_app.route('/api/v0/qa_feedback/stats', methods=['GET'])
  1830. def qa_feedback_stats():
  1831. """
  1832. 反馈统计API
  1833. 返回反馈数据的统计信息
  1834. """
  1835. try:
  1836. manager = get_qa_feedback_manager()
  1837. # 查询各种统计数据
  1838. all_records, total_count = manager.query_feedback(page=1, page_size=1)
  1839. positive_records, positive_count = manager.query_feedback(page=1, page_size=1, is_thumb_up=True)
  1840. negative_records, negative_count = manager.query_feedback(page=1, page_size=1, is_thumb_up=False)
  1841. trained_records, trained_count = manager.query_feedback(page=1, page_size=1, is_in_training_data=True)
  1842. untrained_records, untrained_count = manager.query_feedback(page=1, page_size=1, is_in_training_data=False)
  1843. return jsonify(success_response(
  1844. response_text="统计信息获取成功",
  1845. data={
  1846. "total_feedback": total_count,
  1847. "positive_feedback": positive_count,
  1848. "negative_feedback": negative_count,
  1849. "trained_feedback": trained_count,
  1850. "untrained_feedback": untrained_count,
  1851. "positive_rate": round(positive_count / max(total_count, 1) * 100, 2),
  1852. "training_rate": round(trained_count / max(total_count, 1) * 100, 2)
  1853. }
  1854. ))
  1855. except Exception as e:
  1856. print(f"[ERROR] qa_feedback_stats执行失败: {str(e)}")
  1857. return jsonify(internal_error_response(
  1858. response_text="获取统计信息失败,请稍后重试"
  1859. )), 500
  1860. # ==================== 问答缓存管理接口 ====================
  1861. @app.flask_app.route('/api/v0/qa_cache_stats', methods=['GET'])
  1862. def qa_cache_stats():
  1863. """获取问答缓存统计信息"""
  1864. try:
  1865. stats = redis_conversation_manager.get_qa_cache_stats()
  1866. return jsonify(success_response(
  1867. response_text="获取问答缓存统计成功",
  1868. data=stats
  1869. ))
  1870. except Exception as e:
  1871. print(f"[ERROR] 获取问答缓存统计失败: {str(e)}")
  1872. return jsonify(internal_error_response(
  1873. response_text="获取问答缓存统计失败,请稍后重试"
  1874. )), 500
  1875. @app.flask_app.route('/api/v0/qa_cache_list', methods=['GET'])
  1876. def qa_cache_list():
  1877. """获取问答缓存列表(支持分页)"""
  1878. try:
  1879. # 获取分页参数,默认限制50条
  1880. limit = request.args.get('limit', 50, type=int)
  1881. # 限制最大返回数量,防止一次性返回过多数据
  1882. if limit > 500:
  1883. limit = 500
  1884. elif limit <= 0:
  1885. limit = 50
  1886. cache_list = redis_conversation_manager.get_qa_cache_list(limit)
  1887. return jsonify(success_response(
  1888. response_text="获取问答缓存列表成功",
  1889. data={
  1890. "cache_list": cache_list,
  1891. "total_returned": len(cache_list),
  1892. "limit_applied": limit,
  1893. "note": "按缓存时间倒序排列,最新的在前面"
  1894. }
  1895. ))
  1896. except Exception as e:
  1897. print(f"[ERROR] 获取问答缓存列表失败: {str(e)}")
  1898. return jsonify(internal_error_response(
  1899. response_text="获取问答缓存列表失败,请稍后重试"
  1900. )), 500
  1901. @app.flask_app.route('/api/v0/qa_cache_cleanup', methods=['POST'])
  1902. def qa_cache_cleanup():
  1903. """清空所有问答缓存"""
  1904. try:
  1905. if not redis_conversation_manager.is_available():
  1906. return jsonify(internal_error_response(
  1907. response_text="Redis连接不可用,无法执行清理操作"
  1908. )), 500
  1909. deleted_count = redis_conversation_manager.clear_all_qa_cache()
  1910. return jsonify(success_response(
  1911. response_text="问答缓存清理完成",
  1912. data={
  1913. "deleted_count": deleted_count,
  1914. "cleared": deleted_count > 0,
  1915. "cleanup_time": datetime.now().isoformat()
  1916. }
  1917. ))
  1918. except Exception as e:
  1919. print(f"[ERROR] 清空问答缓存失败: {str(e)}")
  1920. return jsonify(internal_error_response(
  1921. response_text="清空问答缓存失败,请稍后重试"
  1922. )), 500
  1923. # ==================== 训练数据管理接口 ====================
  1924. def validate_sql_syntax(sql: str) -> tuple[bool, str]:
  1925. """SQL语法检查(仅对sql类型)"""
  1926. try:
  1927. parsed = sqlparse.parse(sql.strip())
  1928. if not parsed or not parsed[0].tokens:
  1929. return False, "SQL语法错误:空语句"
  1930. # 基本语法检查
  1931. sql_upper = sql.strip().upper()
  1932. if not any(sql_upper.startswith(keyword) for keyword in
  1933. ['SELECT', 'INSERT', 'UPDATE', 'DELETE', 'CREATE', 'ALTER', 'DROP']):
  1934. return False, "SQL语法错误:不是有效的SQL语句"
  1935. # 安全检查:禁止危险的SQL操作
  1936. dangerous_operations = ['UPDATE', 'DELETE', 'ALERT', 'DROP']
  1937. for operation in dangerous_operations:
  1938. if sql_upper.startswith(operation):
  1939. return False, f'在训练集中禁止使用"{",".join(dangerous_operations)}"'
  1940. return True, ""
  1941. except Exception as e:
  1942. return False, f"SQL语法错误:{str(e)}"
  1943. def paginate_data(data_list: list, page: int, page_size: int):
  1944. """分页处理算法"""
  1945. total = len(data_list)
  1946. start_idx = (page - 1) * page_size
  1947. end_idx = start_idx + page_size
  1948. page_data = data_list[start_idx:end_idx]
  1949. return {
  1950. "data": page_data,
  1951. "pagination": {
  1952. "page": page,
  1953. "page_size": page_size,
  1954. "total": total,
  1955. "total_pages": (total + page_size - 1) // page_size,
  1956. "has_next": end_idx < total,
  1957. "has_prev": page > 1
  1958. }
  1959. }
  1960. def filter_by_type(data_list: list, training_data_type: str):
  1961. """按类型筛选算法"""
  1962. if not training_data_type:
  1963. return data_list
  1964. return [
  1965. record for record in data_list
  1966. if record.get('training_data_type') == training_data_type
  1967. ]
  1968. def search_in_data(data_list: list, search_keyword: str):
  1969. """在数据中搜索关键词"""
  1970. if not search_keyword:
  1971. return data_list
  1972. keyword_lower = search_keyword.lower()
  1973. return [
  1974. record for record in data_list
  1975. if (record.get('question') and keyword_lower in record['question'].lower()) or
  1976. (record.get('content') and keyword_lower in record['content'].lower())
  1977. ]
  1978. def process_single_training_item(item: dict, index: int) -> dict:
  1979. """处理单个训练数据项"""
  1980. training_type = item.get('training_data_type')
  1981. if training_type == 'sql':
  1982. sql = item.get('sql')
  1983. if not sql:
  1984. raise ValueError("SQL字段是必需的")
  1985. # SQL语法检查
  1986. is_valid, error_msg = validate_sql_syntax(sql)
  1987. if not is_valid:
  1988. raise ValueError(error_msg)
  1989. question = item.get('question')
  1990. if question:
  1991. training_id = vn.train(question=question, sql=sql)
  1992. else:
  1993. training_id = vn.train(sql=sql)
  1994. elif training_type == 'error_sql':
  1995. # error_sql不需要语法检查
  1996. question = item.get('question')
  1997. sql = item.get('sql')
  1998. if not question or not sql:
  1999. raise ValueError("question和sql字段都是必需的")
  2000. training_id = vn.train_error_sql(question=question, sql=sql)
  2001. elif training_type == 'documentation':
  2002. content = item.get('content')
  2003. if not content:
  2004. raise ValueError("content字段是必需的")
  2005. training_id = vn.train(documentation=content)
  2006. elif training_type == 'ddl':
  2007. ddl = item.get('ddl')
  2008. if not ddl:
  2009. raise ValueError("ddl字段是必需的")
  2010. training_id = vn.train(ddl=ddl)
  2011. else:
  2012. raise ValueError(f"不支持的训练数据类型: {training_type}")
  2013. return {
  2014. "index": index,
  2015. "success": True,
  2016. "training_id": training_id,
  2017. "type": training_type,
  2018. "message": f"{training_type}训练数据创建成功"
  2019. }
  2020. def get_total_training_count():
  2021. """获取当前训练数据总数"""
  2022. try:
  2023. training_data = vn.get_training_data()
  2024. if training_data is not None and not training_data.empty:
  2025. return len(training_data)
  2026. return 0
  2027. except Exception as e:
  2028. print(f"[WARNING] 获取训练数据总数失败: {e}")
  2029. return 0
  2030. @app.flask_app.route('/api/v0/training_data/query', methods=['POST'])
  2031. def training_data_query():
  2032. """
  2033. 分页查询训练数据API
  2034. 支持类型筛选、搜索和排序功能
  2035. """
  2036. try:
  2037. req = request.get_json(force=True)
  2038. # 解析参数,设置默认值
  2039. page = req.get('page', 1)
  2040. page_size = req.get('page_size', 20)
  2041. training_data_type = req.get('training_data_type')
  2042. sort_by = req.get('sort_by', 'id')
  2043. sort_order = req.get('sort_order', 'desc')
  2044. search_keyword = req.get('search_keyword')
  2045. # 参数验证
  2046. if page < 1:
  2047. return jsonify(bad_request_response(
  2048. response_text="页码必须大于0",
  2049. missing_params=["page"]
  2050. )), 400
  2051. if page_size < 1 or page_size > 100:
  2052. return jsonify(bad_request_response(
  2053. response_text="每页大小必须在1-100之间",
  2054. missing_params=["page_size"]
  2055. )), 400
  2056. if search_keyword and len(search_keyword) > 100:
  2057. return jsonify(bad_request_response(
  2058. response_text="搜索关键词最大长度为100字符",
  2059. missing_params=["search_keyword"]
  2060. )), 400
  2061. # 获取训练数据
  2062. training_data = vn.get_training_data()
  2063. if training_data is None or training_data.empty:
  2064. return jsonify(success_response(
  2065. response_text="查询成功,暂无训练数据",
  2066. data={
  2067. "records": [],
  2068. "pagination": {
  2069. "page": page,
  2070. "page_size": page_size,
  2071. "total": 0,
  2072. "total_pages": 0,
  2073. "has_next": False,
  2074. "has_prev": False
  2075. },
  2076. "filters_applied": {
  2077. "training_data_type": training_data_type,
  2078. "search_keyword": search_keyword
  2079. }
  2080. }
  2081. ))
  2082. # 转换为列表格式
  2083. records = training_data.to_dict(orient="records")
  2084. # 应用筛选条件
  2085. if training_data_type:
  2086. records = filter_by_type(records, training_data_type)
  2087. if search_keyword:
  2088. records = search_in_data(records, search_keyword)
  2089. # 排序
  2090. if sort_by in ['id', 'training_data_type']:
  2091. reverse = (sort_order.lower() == 'desc')
  2092. records.sort(key=lambda x: x.get(sort_by, ''), reverse=reverse)
  2093. # 分页
  2094. paginated_result = paginate_data(records, page, page_size)
  2095. return jsonify(success_response(
  2096. response_text=f"查询成功,共找到 {paginated_result['pagination']['total']} 条记录",
  2097. data={
  2098. "records": paginated_result["data"],
  2099. "pagination": paginated_result["pagination"],
  2100. "filters_applied": {
  2101. "training_data_type": training_data_type,
  2102. "search_keyword": search_keyword
  2103. }
  2104. }
  2105. ))
  2106. except Exception as e:
  2107. print(f"[ERROR] training_data_query执行失败: {str(e)}")
  2108. return jsonify(internal_error_response(
  2109. response_text="查询训练数据失败,请稍后重试"
  2110. )), 500
  2111. @app.flask_app.route('/api/v0/training_data/create', methods=['POST'])
  2112. def training_data_create():
  2113. """
  2114. 创建训练数据API
  2115. 支持单条和批量创建,支持四种数据类型
  2116. """
  2117. try:
  2118. req = request.get_json(force=True)
  2119. data = req.get('data')
  2120. if not data:
  2121. return jsonify(bad_request_response(
  2122. response_text="缺少必需参数:data",
  2123. missing_params=["data"]
  2124. )), 400
  2125. # 统一处理为列表格式
  2126. if isinstance(data, dict):
  2127. data_list = [data]
  2128. elif isinstance(data, list):
  2129. data_list = data
  2130. else:
  2131. return jsonify(bad_request_response(
  2132. response_text="data字段格式错误,应为对象或数组"
  2133. )), 400
  2134. # 批量操作限制
  2135. if len(data_list) > 50:
  2136. return jsonify(bad_request_response(
  2137. response_text="批量操作最大支持50条记录"
  2138. )), 400
  2139. results = []
  2140. successful_count = 0
  2141. type_summary = {"sql": 0, "documentation": 0, "ddl": 0, "error_sql": 0}
  2142. for index, item in enumerate(data_list):
  2143. try:
  2144. result = process_single_training_item(item, index)
  2145. results.append(result)
  2146. if result['success']:
  2147. successful_count += 1
  2148. type_summary[result['type']] += 1
  2149. except Exception as e:
  2150. results.append({
  2151. "index": index,
  2152. "success": False,
  2153. "type": item.get('training_data_type', 'unknown'),
  2154. "error": str(e),
  2155. "message": "创建失败"
  2156. })
  2157. # 获取创建后的总记录数
  2158. current_total = get_total_training_count()
  2159. return jsonify(success_response(
  2160. response_text="训练数据创建完成",
  2161. data={
  2162. "total_requested": len(data_list),
  2163. "successfully_created": successful_count,
  2164. "failed_count": len(data_list) - successful_count,
  2165. "results": results,
  2166. "summary": type_summary,
  2167. "current_total_count": current_total
  2168. }
  2169. ))
  2170. except Exception as e:
  2171. print(f"[ERROR] training_data_create执行失败: {str(e)}")
  2172. return jsonify(internal_error_response(
  2173. response_text="创建训练数据失败,请稍后重试"
  2174. )), 500
  2175. @app.flask_app.route('/api/v0/training_data/delete', methods=['POST'])
  2176. def training_data_delete():
  2177. """
  2178. 删除训练数据API
  2179. 支持批量删除
  2180. """
  2181. try:
  2182. req = request.get_json(force=True)
  2183. ids = req.get('ids', [])
  2184. confirm = req.get('confirm', False)
  2185. if not ids or not isinstance(ids, list):
  2186. return jsonify(bad_request_response(
  2187. response_text="缺少有效的ID列表",
  2188. missing_params=["ids"]
  2189. )), 400
  2190. if not confirm:
  2191. return jsonify(bad_request_response(
  2192. response_text="删除操作需要确认,请设置confirm为true"
  2193. )), 400
  2194. # 批量操作限制
  2195. if len(ids) > 50:
  2196. return jsonify(bad_request_response(
  2197. response_text="批量删除最大支持50条记录"
  2198. )), 400
  2199. deleted_ids = []
  2200. failed_ids = []
  2201. failed_details = []
  2202. for training_id in ids:
  2203. try:
  2204. success = vn.remove_training_data(training_id)
  2205. if success:
  2206. deleted_ids.append(training_id)
  2207. else:
  2208. failed_ids.append(training_id)
  2209. failed_details.append({
  2210. "id": training_id,
  2211. "error": "记录不存在或删除失败"
  2212. })
  2213. except Exception as e:
  2214. failed_ids.append(training_id)
  2215. failed_details.append({
  2216. "id": training_id,
  2217. "error": str(e)
  2218. })
  2219. # 获取删除后的总记录数
  2220. current_total = get_total_training_count()
  2221. return jsonify(success_response(
  2222. response_text="训练数据删除完成",
  2223. data={
  2224. "total_requested": len(ids),
  2225. "successfully_deleted": len(deleted_ids),
  2226. "failed_count": len(failed_ids),
  2227. "deleted_ids": deleted_ids,
  2228. "failed_ids": failed_ids,
  2229. "failed_details": failed_details,
  2230. "current_total_count": current_total
  2231. }
  2232. ))
  2233. except Exception as e:
  2234. print(f"[ERROR] training_data_delete执行失败: {str(e)}")
  2235. return jsonify(internal_error_response(
  2236. response_text="删除训练数据失败,请稍后重试"
  2237. )), 500
  2238. @app.flask_app.route('/api/v0/training_data/stats', methods=['GET'])
  2239. def training_data_stats():
  2240. """
  2241. 获取训练数据统计信息API
  2242. """
  2243. try:
  2244. training_data = vn.get_training_data()
  2245. if training_data is None or training_data.empty:
  2246. return jsonify(success_response(
  2247. response_text="统计信息获取成功",
  2248. data={
  2249. "total_count": 0,
  2250. "type_breakdown": {
  2251. "sql": 0,
  2252. "documentation": 0,
  2253. "ddl": 0,
  2254. "error_sql": 0
  2255. },
  2256. "type_percentages": {
  2257. "sql": 0.0,
  2258. "documentation": 0.0,
  2259. "ddl": 0.0,
  2260. "error_sql": 0.0
  2261. },
  2262. "last_updated": datetime.now().isoformat()
  2263. }
  2264. ))
  2265. total_count = len(training_data)
  2266. # 统计各类型数量
  2267. type_breakdown = {"sql": 0, "documentation": 0, "ddl": 0, "error_sql": 0}
  2268. if 'training_data_type' in training_data.columns:
  2269. type_counts = training_data['training_data_type'].value_counts()
  2270. for data_type, count in type_counts.items():
  2271. if data_type in type_breakdown:
  2272. type_breakdown[data_type] = int(count)
  2273. # 计算百分比
  2274. type_percentages = {}
  2275. for data_type, count in type_breakdown.items():
  2276. type_percentages[data_type] = round(count / max(total_count, 1) * 100, 2)
  2277. return jsonify(success_response(
  2278. response_text="统计信息获取成功",
  2279. data={
  2280. "total_count": total_count,
  2281. "type_breakdown": type_breakdown,
  2282. "type_percentages": type_percentages,
  2283. "last_updated": datetime.now().isoformat()
  2284. }
  2285. ))
  2286. except Exception as e:
  2287. print(f"[ERROR] training_data_stats执行失败: {str(e)}")
  2288. return jsonify(internal_error_response(
  2289. response_text="获取统计信息失败,请稍后重试"
  2290. )), 500
  2291. @app.flask_app.route('/api/v0/cache_overview_full', methods=['GET'])
  2292. def cache_overview_full():
  2293. """获取所有缓存系统的综合概览"""
  2294. try:
  2295. from common.embedding_cache_manager import get_embedding_cache_manager
  2296. from common.vanna_instance import get_vanna_instance
  2297. # 获取现有的缓存统计
  2298. vanna_cache = get_vanna_instance()
  2299. # 直接使用应用中的缓存实例
  2300. cache = app.cache
  2301. cache_overview = {
  2302. "conversation_aware_cache": {
  2303. "enabled": True,
  2304. "total_items": len(cache.cache) if hasattr(cache, 'cache') else 0,
  2305. "sessions": list(cache.cache.keys()) if hasattr(cache, 'cache') else [],
  2306. "cache_type": type(cache).__name__
  2307. },
  2308. "question_answer_cache": redis_conversation_manager.get_qa_cache_stats() if redis_conversation_manager.is_available() else {"available": False},
  2309. "embedding_cache": get_embedding_cache_manager().get_cache_stats(),
  2310. "redis_conversation_stats": redis_conversation_manager.get_stats() if redis_conversation_manager.is_available() else None
  2311. }
  2312. return jsonify(success_response(
  2313. response_text="获取综合缓存概览成功",
  2314. data=cache_overview
  2315. ))
  2316. except Exception as e:
  2317. print(f"[ERROR] 获取综合缓存概览失败: {str(e)}")
  2318. return jsonify(internal_error_response(
  2319. response_text="获取缓存概览失败,请稍后重试"
  2320. )), 500
  2321. # 前端JavaScript示例 - 如何维持会话
  2322. """
  2323. // 前端需要维护一个会话ID
  2324. class ChatSession {
  2325. constructor() {
  2326. // 从localStorage获取或创建新的会话ID
  2327. this.sessionId = localStorage.getItem('chat_session_id') || this.generateSessionId();
  2328. localStorage.setItem('chat_session_id', this.sessionId);
  2329. }
  2330. generateSessionId() {
  2331. return 'session_' + Date.now() + '_' + Math.random().toString(36).substr(2, 9);
  2332. }
  2333. async askQuestion(question) {
  2334. const response = await fetch('/api/v0/ask', {
  2335. method: 'POST',
  2336. headers: {
  2337. 'Content-Type': 'application/json',
  2338. },
  2339. body: JSON.stringify({
  2340. question: question,
  2341. session_id: this.sessionId // 关键:传递会话ID
  2342. })
  2343. });
  2344. return await response.json();
  2345. }
  2346. // 开始新会话
  2347. startNewSession() {
  2348. this.sessionId = this.generateSessionId();
  2349. localStorage.setItem('chat_session_id', this.sessionId);
  2350. }
  2351. }
  2352. // 使用示例
  2353. const chatSession = new ChatSession();
  2354. chatSession.askQuestion("各年龄段客户的流失率如何?");
  2355. """
  2356. print("正在启动Flask应用: http://localhost:8084")
  2357. app.run(host="0.0.0.0", port=8084, debug=True)