citu_app.py 157 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
7377837793780378137823783378437853786378737883789379037913792379337943795379637973798379938003801380238033804380538063807380838093810381138123813381438153816381738183819382038213822382338243825382638273828382938303831383238333834383538363837383838393840384138423843384438453846384738483849385038513852385338543855385638573858385938603861386238633864386538663867386838693870387138723873387438753876387738783879388038813882388338843885388638873888388938903891389238933894389538963897389838993900390139023903390439053906390739083909391039113912391339143915391639173918391939203921392239233924392539263927392839293930393139323933393439353936393739383939394039413942394339443945394639473948394939503951395239533954395539563957395839593960396139623963396439653966396739683969397039713972397339743975397639773978397939803981398239833984398539863987398839893990399139923993399439953996399739983999400040014002400340044005400640074008400940104011401240134014401540164017401840194020402140224023402440254026402740284029403040314032403340344035403640374038403940404041404240434044404540464047404840494050
  1. # 给dataops 对话助手返回结果
  2. # 初始化日志系统 - 必须在最前面
  3. from core.logging import initialize_logging, get_app_logger, set_log_context, clear_log_context
  4. initialize_logging()
  5. from vanna.flask import VannaFlaskApp
  6. from core.vanna_llm_factory import create_vanna_instance
  7. from flask import request, jsonify
  8. import pandas as pd
  9. import common.result as result
  10. from datetime import datetime, timedelta
  11. from common.session_aware_cache import WebSessionAwareMemoryCache
  12. from app_config import API_MAX_RETURN_ROWS, ENABLE_RESULT_SUMMARY
  13. import re
  14. import chainlit as cl
  15. import json
  16. from flask import session # 添加session导入
  17. import sqlparse # 用于SQL语法检查
  18. from common.redis_conversation_manager import RedisConversationManager # 添加Redis对话管理器导入
  19. from common.qa_feedback_manager import QAFeedbackManager
  20. from common.result import success_response, bad_request_response, not_found_response, internal_error_response
  21. from common.result import ( # 统一导入所有需要的响应函数
  22. bad_request_response, service_unavailable_response,
  23. agent_success_response, agent_error_response,
  24. internal_error_response, success_response,
  25. validation_failed_response
  26. )
  27. from app_config import ( # 添加Redis相关配置导入
  28. USER_MAX_CONVERSATIONS,
  29. CONVERSATION_CONTEXT_COUNT,
  30. DEFAULT_ANONYMOUS_USER,
  31. ENABLE_QUESTION_ANSWER_CACHE
  32. )
  33. # 创建app logger
  34. logger = get_app_logger("CituApp")
  35. # 设置默认的最大返回行数
  36. DEFAULT_MAX_RETURN_ROWS = 200
  37. MAX_RETURN_ROWS = API_MAX_RETURN_ROWS if API_MAX_RETURN_ROWS is not None else DEFAULT_MAX_RETURN_ROWS
  38. vn = create_vanna_instance()
  39. # 创建带时间戳的缓存
  40. timestamped_cache = WebSessionAwareMemoryCache()
  41. # 实例化 VannaFlaskApp,使用自定义缓存
  42. app = VannaFlaskApp(
  43. vn,
  44. cache=timestamped_cache, # 使用带时间戳的缓存
  45. title="辞图智能数据问答平台",
  46. logo = "https://www.citupro.com/img/logo-black-2.png",
  47. subtitle="让 AI 为你写 SQL",
  48. chart=False,
  49. allow_llm_to_see_data=True,
  50. ask_results_correct=True,
  51. followup_questions=True,
  52. debug=True
  53. )
  54. # 创建Redis对话管理器实例
  55. redis_conversation_manager = RedisConversationManager()
  56. # 修改ask接口,支持前端传递session_id
  57. @app.flask_app.route('/api/v0/ask', methods=['POST'])
  58. def ask_full():
  59. req = request.get_json(force=True)
  60. question = req.get("question", None)
  61. browser_session_id = req.get("session_id", None) # 前端传递的会话ID
  62. if not question:
  63. from common.result import bad_request_response
  64. return jsonify(bad_request_response(
  65. response_text="缺少必需参数:question",
  66. missing_params=["question"]
  67. )), 400
  68. # 如果使用WebSessionAwareMemoryCache
  69. if hasattr(app.cache, 'generate_id_with_browser_session') and browser_session_id:
  70. # 这里需要修改vanna的ask方法来支持传递session_id
  71. # 或者预先调用generate_id来建立会话关联
  72. conversation_id = app.cache.generate_id_with_browser_session(
  73. question=question,
  74. browser_session_id=browser_session_id
  75. )
  76. try:
  77. sql, df, _ = vn.ask(
  78. question=question,
  79. print_results=False,
  80. visualize=False,
  81. allow_llm_to_see_data=True
  82. )
  83. # 关键:检查是否有LLM解释性文本(无法生成SQL的情况)
  84. if sql is None and hasattr(vn, 'last_llm_explanation') and vn.last_llm_explanation:
  85. # 在解释性文本末尾添加提示语
  86. explanation_message = vn.last_llm_explanation + "请尝试提问其它问题。"
  87. # 使用标准化错误响应
  88. from common.result import validation_failed_response
  89. return jsonify(validation_failed_response(
  90. response_text=explanation_message
  91. )), 422 # 修改HTTP状态码为422
  92. # 如果sql为None但没有解释性文本,返回通用错误
  93. if sql is None:
  94. from common.result import validation_failed_response
  95. return jsonify(validation_failed_response(
  96. response_text="无法生成SQL查询,请检查问题描述或数据表结构"
  97. )), 422
  98. # 处理返回数据 - 使用新的query_result结构
  99. query_result = {
  100. "rows": [],
  101. "columns": [],
  102. "row_count": 0,
  103. "is_limited": False,
  104. "total_row_count": 0
  105. }
  106. summary = None
  107. if isinstance(df, pd.DataFrame):
  108. query_result["columns"] = list(df.columns)
  109. if not df.empty:
  110. total_rows = len(df)
  111. limited_df = df.head(MAX_RETURN_ROWS)
  112. query_result["rows"] = limited_df.to_dict(orient="records")
  113. query_result["row_count"] = len(limited_df)
  114. query_result["total_row_count"] = total_rows
  115. query_result["is_limited"] = total_rows > MAX_RETURN_ROWS
  116. # 生成数据摘要(可通过配置控制,仅在有数据时生成)
  117. if ENABLE_RESULT_SUMMARY:
  118. try:
  119. summary = vn.generate_summary(question=question, df=df)
  120. logger.info(f"成功生成摘要: {summary}")
  121. except Exception as e:
  122. logger.warning(f"生成摘要失败: {str(e)}")
  123. summary = None
  124. # 构建返回数据
  125. response_data = {
  126. "sql": sql,
  127. "query_result": query_result,
  128. "conversation_id": conversation_id if 'conversation_id' in locals() else None,
  129. "session_id": browser_session_id
  130. }
  131. # 添加摘要(如果启用且生成成功)
  132. if ENABLE_RESULT_SUMMARY and summary is not None:
  133. response_data["summary"] = summary
  134. response_data["response"] = summary # 同时添加response字段
  135. from common.result import success_response
  136. return jsonify(success_response(
  137. response_text="查询执行完成" if summary is None else None,
  138. data=response_data
  139. ))
  140. except Exception as e:
  141. logger.error(f"ask_full执行失败: {str(e)}")
  142. # 即使发生异常,也检查是否有业务层面的解释
  143. if hasattr(vn, 'last_llm_explanation') and vn.last_llm_explanation:
  144. # 在解释性文本末尾添加提示语
  145. explanation_message = vn.last_llm_explanation + "请尝试提问其它问题。"
  146. from common.result import validation_failed_response
  147. return jsonify(validation_failed_response(
  148. response_text=explanation_message
  149. )), 422
  150. else:
  151. # 技术错误,使用500错误码
  152. from common.result import internal_error_response
  153. return jsonify(internal_error_response(
  154. response_text="查询处理失败,请稍后重试"
  155. )), 500
  156. @app.flask_app.route('/api/v0/citu_run_sql', methods=['POST'])
  157. def citu_run_sql():
  158. req = request.get_json(force=True)
  159. sql = req.get('sql')
  160. if not sql:
  161. from common.result import bad_request_response
  162. return jsonify(bad_request_response(
  163. response_text="缺少必需参数:sql",
  164. missing_params=["sql"]
  165. )), 400
  166. try:
  167. df = vn.run_sql(sql)
  168. # 处理返回数据 - 使用新的query_result结构
  169. query_result = {
  170. "rows": [],
  171. "columns": [],
  172. "row_count": 0,
  173. "is_limited": False,
  174. "total_row_count": 0
  175. }
  176. if isinstance(df, pd.DataFrame):
  177. query_result["columns"] = list(df.columns)
  178. if not df.empty:
  179. total_rows = len(df)
  180. limited_df = df.head(MAX_RETURN_ROWS)
  181. query_result["rows"] = limited_df.to_dict(orient="records")
  182. query_result["row_count"] = len(limited_df)
  183. query_result["total_row_count"] = total_rows
  184. query_result["is_limited"] = total_rows > MAX_RETURN_ROWS
  185. from common.result import success_response
  186. return jsonify(success_response(
  187. response_text=f"SQL执行完成,共返回 {query_result['total_row_count']} 条记录" +
  188. (f",已限制显示前 {MAX_RETURN_ROWS} 条" if query_result["is_limited"] else ""),
  189. data={
  190. "sql": sql,
  191. "query_result": query_result
  192. }
  193. ))
  194. except Exception as e:
  195. logger.error(f"citu_run_sql执行失败: {str(e)}")
  196. from common.result import internal_error_response
  197. return jsonify(internal_error_response(
  198. response_text=f"SQL执行失败,请检查SQL语句是否正确"
  199. )), 500
  200. @app.flask_app.route('/api/v0/ask_cached', methods=['POST'])
  201. def ask_cached():
  202. """
  203. 带缓存功能的智能查询接口
  204. 支持会话管理和结果缓存,提高查询效率
  205. """
  206. req = request.get_json(force=True)
  207. question = req.get("question", None)
  208. browser_session_id = req.get("session_id", None)
  209. if not question:
  210. from common.result import bad_request_response
  211. return jsonify(bad_request_response(
  212. response_text="缺少必需参数:question",
  213. missing_params=["question"]
  214. )), 400
  215. try:
  216. # 生成conversation_id
  217. # 调试:查看generate_id的实际行为
  218. logger.debug(f"输入问题: '{question}'")
  219. conversation_id = app.cache.generate_id(question=question)
  220. logger.debug(f"生成的conversation_id: {conversation_id}")
  221. # 再次用相同问题测试
  222. conversation_id2 = app.cache.generate_id(question=question)
  223. logger.debug(f"再次生成的conversation_id: {conversation_id2}")
  224. logger.debug(f"两次ID是否相同: {conversation_id == conversation_id2}")
  225. # 检查缓存
  226. cached_sql = app.cache.get(id=conversation_id, field="sql")
  227. if cached_sql is not None:
  228. # 缓存命中
  229. logger.info(f"[CACHE HIT] 使用缓存结果: {conversation_id}")
  230. sql = cached_sql
  231. df = app.cache.get(id=conversation_id, field="df")
  232. summary = app.cache.get(id=conversation_id, field="summary")
  233. else:
  234. # 缓存未命中,执行新查询
  235. logger.info(f"[CACHE MISS] 执行新查询: {conversation_id}")
  236. sql, df, _ = vn.ask(
  237. question=question,
  238. print_results=False,
  239. visualize=False,
  240. allow_llm_to_see_data=True
  241. )
  242. # 检查是否有LLM解释性文本(无法生成SQL的情况)
  243. if sql is None and hasattr(vn, 'last_llm_explanation') and vn.last_llm_explanation:
  244. # 在解释性文本末尾添加提示语
  245. explanation_message = vn.last_llm_explanation + "请尝试用其它方式提问。"
  246. from common.result import validation_failed_response
  247. return jsonify(validation_failed_response(
  248. response_text=explanation_message
  249. )), 422
  250. # 如果sql为None但没有解释性文本,返回通用错误
  251. if sql is None:
  252. from common.result import validation_failed_response
  253. return jsonify(validation_failed_response(
  254. response_text="无法生成SQL查询,请检查问题描述或数据表结构"
  255. )), 422
  256. # 缓存结果
  257. app.cache.set(id=conversation_id, field="question", value=question)
  258. app.cache.set(id=conversation_id, field="sql", value=sql)
  259. app.cache.set(id=conversation_id, field="df", value=df)
  260. # 生成并缓存摘要(可通过配置控制,仅在有数据时生成)
  261. summary = None
  262. if ENABLE_RESULT_SUMMARY and isinstance(df, pd.DataFrame) and not df.empty:
  263. try:
  264. summary = vn.generate_summary(question=question, df=df)
  265. logger.info(f"成功生成摘要: {summary}")
  266. except Exception as e:
  267. logger.warning(f"生成摘要失败: {str(e)}")
  268. summary = None
  269. app.cache.set(id=conversation_id, field="summary", value=summary)
  270. # 处理返回数据 - 使用新的query_result结构
  271. query_result = {
  272. "rows": [],
  273. "columns": [],
  274. "row_count": 0,
  275. "is_limited": False,
  276. "total_row_count": 0
  277. }
  278. if isinstance(df, pd.DataFrame):
  279. query_result["columns"] = list(df.columns)
  280. if not df.empty:
  281. total_rows = len(df)
  282. limited_df = df.head(MAX_RETURN_ROWS)
  283. query_result["rows"] = limited_df.to_dict(orient="records")
  284. query_result["row_count"] = len(limited_df)
  285. query_result["total_row_count"] = total_rows
  286. query_result["is_limited"] = total_rows > MAX_RETURN_ROWS
  287. # 构建返回数据
  288. response_data = {
  289. "sql": sql,
  290. "query_result": query_result,
  291. "conversation_id": conversation_id,
  292. "session_id": browser_session_id,
  293. "cached": cached_sql is not None # 标识是否来自缓存
  294. }
  295. # 添加摘要(如果启用且生成成功)
  296. if ENABLE_RESULT_SUMMARY and summary is not None:
  297. response_data["summary"] = summary
  298. response_data["response"] = summary # 同时添加response字段
  299. from common.result import success_response
  300. return jsonify(success_response(
  301. response_text="查询执行完成" if summary is None else None,
  302. data=response_data
  303. ))
  304. except Exception as e:
  305. logger.error(f"ask_cached执行失败: {str(e)}")
  306. from common.result import internal_error_response
  307. return jsonify(internal_error_response(
  308. response_text="查询处理失败,请稍后重试"
  309. )), 500
  310. @app.flask_app.route('/api/v0/citu_train_question_sql', methods=['POST'])
  311. def citu_train_question_sql():
  312. """
  313. 训练问题-SQL对接口
  314. 此API将接收的question/sql pair写入到training库中,用于训练和改进AI模型。
  315. 支持仅传入SQL或同时传入问题和SQL进行训练。
  316. Args:
  317. question (str, optional): 用户问题
  318. sql (str, required): 对应的SQL查询语句
  319. Returns:
  320. JSON: 包含训练ID和成功消息的响应
  321. """
  322. try:
  323. req = request.get_json(force=True)
  324. question = req.get('question')
  325. sql = req.get('sql')
  326. if not sql:
  327. from common.result import bad_request_response
  328. return jsonify(bad_request_response(
  329. response_text="缺少必需参数:sql",
  330. missing_params=["sql"]
  331. )), 400
  332. # 正确的调用方式:同时传递question和sql
  333. if question:
  334. training_id = vn.train(question=question, sql=sql)
  335. logger.info(f"训练成功,训练ID为:{training_id},问题:{question},SQL:{sql}")
  336. else:
  337. training_id = vn.train(sql=sql)
  338. logger.info(f"训练成功,训练ID为:{training_id},SQL:{sql}")
  339. from common.result import success_response
  340. return jsonify(success_response(
  341. response_text="问题-SQL对训练成功",
  342. data={
  343. "training_id": training_id,
  344. "message": "Question-SQL pair trained successfully"
  345. }
  346. ))
  347. except Exception as e:
  348. from common.result import internal_error_response
  349. return jsonify(internal_error_response(
  350. response_text="训练失败,请稍后重试"
  351. )), 500
  352. # ============ LangGraph Agent 集成 ============
  353. # 全局Agent实例(单例模式)
  354. citu_langraph_agent = None
  355. def get_citu_langraph_agent():
  356. """获取LangGraph Agent实例(懒加载)"""
  357. global citu_langraph_agent
  358. if citu_langraph_agent is None:
  359. try:
  360. from agent.citu_agent import CituLangGraphAgent
  361. logger.info("开始创建LangGraph Agent实例...")
  362. citu_langraph_agent = CituLangGraphAgent()
  363. logger.info("LangGraph Agent实例创建成功")
  364. except ImportError as e:
  365. logger.critical(f"Agent模块导入失败: {str(e)}")
  366. logger.critical("请检查agent模块是否存在以及依赖是否正确安装")
  367. raise Exception(f"Agent模块导入失败: {str(e)}")
  368. except Exception as e:
  369. logger.critical(f"LangGraph Agent实例创建失败: {str(e)}")
  370. logger.critical(f"错误类型: {type(e).__name__}")
  371. # 提供更有用的错误信息
  372. if "config" in str(e).lower():
  373. logger.critical("可能是配置文件问题,请检查配置")
  374. elif "llm" in str(e).lower():
  375. logger.critical("可能是LLM连接问题,请检查LLM配置")
  376. elif "tool" in str(e).lower():
  377. logger.critical("可能是工具加载问题,请检查工具模块")
  378. raise Exception(f"Agent初始化失败: {str(e)}")
  379. return citu_langraph_agent
  380. @app.flask_app.route('/api/v0/ask_agent', methods=['POST'])
  381. def ask_agent():
  382. """
  383. 支持对话上下文的ask_agent API - 修正版
  384. """
  385. req = request.get_json(force=True)
  386. question = req.get("question", None)
  387. browser_session_id = req.get("session_id", None)
  388. # 新增参数解析
  389. user_id_input = req.get("user_id", None)
  390. conversation_id_input = req.get("conversation_id", None)
  391. continue_conversation = req.get("continue_conversation", False)
  392. # 新增:路由模式参数解析和验证
  393. api_routing_mode = req.get("routing_mode", None)
  394. VALID_ROUTING_MODES = ["database_direct", "chat_direct", "hybrid", "llm_only"]
  395. if not question:
  396. return jsonify(bad_request_response(
  397. response_text="缺少必需参数:question",
  398. missing_params=["question"]
  399. )), 400
  400. # 验证routing_mode参数
  401. if api_routing_mode and api_routing_mode not in VALID_ROUTING_MODES:
  402. return jsonify(bad_request_response(
  403. response_text=f"无效的routing_mode参数值: {api_routing_mode},支持的值: {VALID_ROUTING_MODES}",
  404. invalid_params=["routing_mode"]
  405. )), 400
  406. try:
  407. # 1. 获取登录用户ID(修正:在函数中获取session信息)
  408. login_user_id = session.get('user_id') if 'user_id' in session else None
  409. # 2. 智能ID解析(修正:传入登录用户ID)
  410. user_id = redis_conversation_manager.resolve_user_id(
  411. user_id_input, browser_session_id, request.remote_addr, login_user_id
  412. )
  413. conversation_id, conversation_status = redis_conversation_manager.resolve_conversation_id(
  414. user_id, conversation_id_input, continue_conversation
  415. )
  416. # 3. 获取上下文和上下文类型(提前到缓存检查之前)
  417. context = redis_conversation_manager.get_context(conversation_id)
  418. # 获取上下文类型:从最后一条助手消息的metadata中获取类型
  419. context_type = None
  420. if context:
  421. try:
  422. # 获取最后一条助手消息的metadata
  423. messages = redis_conversation_manager.get_messages(conversation_id, limit=10)
  424. for message in reversed(messages): # 从最新的开始找
  425. if message.get("role") == "assistant":
  426. metadata = message.get("metadata", {})
  427. context_type = metadata.get("type")
  428. if context_type:
  429. logger.info(f"[AGENT_API] 检测到上下文类型: {context_type}")
  430. break
  431. except Exception as e:
  432. logger.warning(f"获取上下文类型失败: {str(e)}")
  433. # 4. 检查缓存(新逻辑:放宽使用条件,严控存储条件)
  434. cached_answer = redis_conversation_manager.get_cached_answer(question, context)
  435. if cached_answer:
  436. logger.info(f"[AGENT_API] 使用缓存答案")
  437. # 确定缓存答案的助手回复内容(使用与非缓存相同的优先级逻辑)
  438. cached_response_type = cached_answer.get("type", "UNKNOWN")
  439. if cached_response_type == "DATABASE":
  440. # DATABASE类型:按优先级选择内容
  441. if cached_answer.get("response"):
  442. # 优先级1:错误或解释性回复(如SQL生成失败)
  443. assistant_response = cached_answer.get("response")
  444. elif cached_answer.get("summary"):
  445. # 优先级2:查询成功的摘要
  446. assistant_response = cached_answer.get("summary")
  447. elif cached_answer.get("query_result"):
  448. # 优先级3:构造简单描述
  449. query_result = cached_answer.get("query_result")
  450. row_count = query_result.get("row_count", 0)
  451. assistant_response = f"查询执行完成,共返回 {row_count} 条记录。"
  452. else:
  453. # 异常情况
  454. assistant_response = "数据库查询已处理。"
  455. else:
  456. # CHAT类型:直接使用response
  457. assistant_response = cached_answer.get("response", "")
  458. # 更新对话历史
  459. redis_conversation_manager.save_message(conversation_id, "user", question)
  460. redis_conversation_manager.save_message(
  461. conversation_id, "assistant",
  462. assistant_response,
  463. metadata={"from_cache": True}
  464. )
  465. # 添加对话信息到缓存结果
  466. cached_answer["conversation_id"] = conversation_id
  467. cached_answer["user_id"] = user_id
  468. cached_answer["from_cache"] = True
  469. cached_answer.update(conversation_status)
  470. # 使用agent_success_response返回标准格式
  471. return jsonify(agent_success_response(
  472. response_type=cached_answer.get("type", "UNKNOWN"),
  473. response=cached_answer.get("response", ""), # 修正:使用response而不是response_text
  474. sql=cached_answer.get("sql"),
  475. records=cached_answer.get("query_result"), # 修改:query_result改为records
  476. summary=cached_answer.get("summary"),
  477. session_id=browser_session_id,
  478. execution_path=cached_answer.get("execution_path", []),
  479. classification_info=cached_answer.get("classification_info", {}),
  480. conversation_id=conversation_id,
  481. user_id=user_id,
  482. is_guest_user=(user_id == DEFAULT_ANONYMOUS_USER),
  483. context_used=bool(context),
  484. from_cache=True,
  485. conversation_status=conversation_status["status"],
  486. conversation_message=conversation_status["message"],
  487. requested_conversation_id=conversation_status.get("requested_id")
  488. ))
  489. # 5. 保存用户消息
  490. redis_conversation_manager.save_message(conversation_id, "user", question)
  491. # 6. 构建带上下文的问题
  492. if context:
  493. enhanced_question = f"\n[CONTEXT]\n{context}\n\n[CURRENT]\n{question}"
  494. logger.info(f"[AGENT_API] 使用上下文,长度: {len(context)}字符")
  495. else:
  496. enhanced_question = question
  497. logger.info(f"[AGENT_API] 新对话,无上下文")
  498. # 7. 确定最终使用的路由模式(优先级逻辑)
  499. if api_routing_mode:
  500. # API传了参数,优先使用
  501. effective_routing_mode = api_routing_mode
  502. logger.info(f"[AGENT_API] 使用API指定的路由模式: {effective_routing_mode}")
  503. else:
  504. # API没传参数,使用配置文件
  505. try:
  506. from app_config import QUESTION_ROUTING_MODE
  507. effective_routing_mode = QUESTION_ROUTING_MODE
  508. logger.info(f"[AGENT_API] 使用配置文件路由模式: {effective_routing_mode}")
  509. except ImportError:
  510. effective_routing_mode = "hybrid"
  511. logger.info(f"[AGENT_API] 配置文件读取失败,使用默认路由模式: {effective_routing_mode}")
  512. # 8. 现有Agent处理逻辑(修改为传递路由模式)
  513. try:
  514. agent = get_citu_langraph_agent()
  515. except Exception as e:
  516. logger.critical(f"Agent初始化失败: {str(e)}")
  517. return jsonify(service_unavailable_response(
  518. response_text="AI服务暂时不可用,请稍后重试",
  519. can_retry=True
  520. )), 503
  521. # 异步调用Agent处理问题
  522. import asyncio
  523. agent_result = asyncio.run(agent.process_question(
  524. question=enhanced_question, # 使用增强后的问题
  525. session_id=browser_session_id,
  526. context_type=context_type, # 传递上下文类型
  527. routing_mode=effective_routing_mode # 新增:传递路由模式
  528. ))
  529. # 8. 处理Agent结果
  530. if agent_result.get("success", False):
  531. # 修正:直接从agent_result获取字段,因为它就是final_response
  532. response_type = agent_result.get("type", "UNKNOWN")
  533. response_text = agent_result.get("response", "")
  534. sql = agent_result.get("sql")
  535. query_result = agent_result.get("query_result")
  536. summary = agent_result.get("summary")
  537. execution_path = agent_result.get("execution_path", [])
  538. classification_info = agent_result.get("classification_info", {})
  539. # 确定助手回复内容的优先级
  540. if response_type == "DATABASE":
  541. # DATABASE类型:按优先级选择内容
  542. if response_text:
  543. # 优先级1:错误或解释性回复(如SQL生成失败)
  544. assistant_response = response_text
  545. elif summary:
  546. # 优先级2:查询成功的摘要
  547. assistant_response = summary
  548. elif query_result:
  549. # 优先级3:构造简单描述
  550. row_count = query_result.get("row_count", 0)
  551. assistant_response = f"查询执行完成,共返回 {row_count} 条记录。"
  552. else:
  553. # 异常情况
  554. assistant_response = "数据库查询已处理。"
  555. else:
  556. # CHAT类型:直接使用response
  557. assistant_response = response_text
  558. # 保存助手回复
  559. redis_conversation_manager.save_message(
  560. conversation_id, "assistant", assistant_response,
  561. metadata={
  562. "type": response_type,
  563. "sql": sql,
  564. "execution_path": execution_path
  565. }
  566. )
  567. # 缓存成功的答案(新逻辑:只缓存无上下文的问答)
  568. # 直接缓存agent_result,它已经包含所有需要的字段
  569. redis_conversation_manager.cache_answer(question, agent_result, context)
  570. # 使用agent_success_response的正确方式
  571. return jsonify(agent_success_response(
  572. response_type=response_type,
  573. response=response_text, # 修正:使用response而不是response_text
  574. sql=sql,
  575. records=query_result, # 修改:query_result改为records
  576. summary=summary,
  577. session_id=browser_session_id,
  578. execution_path=execution_path,
  579. classification_info=classification_info,
  580. conversation_id=conversation_id,
  581. user_id=user_id,
  582. is_guest_user=(user_id == DEFAULT_ANONYMOUS_USER),
  583. context_used=bool(context),
  584. from_cache=False,
  585. conversation_status=conversation_status["status"],
  586. conversation_message=conversation_status["message"],
  587. requested_conversation_id=conversation_status.get("requested_id"),
  588. routing_mode_used=effective_routing_mode, # 新增:实际使用的路由模式
  589. routing_mode_source="api" if api_routing_mode else "config" # 新增:路由模式来源
  590. ))
  591. else:
  592. # 错误处理(修正:确保使用现有的错误响应格式)
  593. error_message = agent_result.get("error", "Agent处理失败")
  594. error_code = agent_result.get("error_code", 500)
  595. return jsonify(agent_error_response(
  596. response_text=error_message,
  597. error_type="agent_processing_failed",
  598. code=error_code,
  599. session_id=browser_session_id,
  600. conversation_id=conversation_id,
  601. user_id=user_id
  602. )), error_code
  603. except Exception as e:
  604. logger.error(f"ask_agent执行失败: {str(e)}")
  605. return jsonify(internal_error_response(
  606. response_text="查询处理失败,请稍后重试"
  607. )), 500
  608. @app.flask_app.route('/api/v0/agent_health', methods=['GET'])
  609. def agent_health():
  610. """
  611. Agent健康检查接口
  612. 响应格式:
  613. {
  614. "success": true/false,
  615. "code": 200/503,
  616. "message": "healthy/degraded/unhealthy",
  617. "data": {
  618. "status": "healthy/degraded/unhealthy",
  619. "test_result": true/false,
  620. "workflow_compiled": true/false,
  621. "tools_count": 4,
  622. "message": "详细信息",
  623. "timestamp": "2024-01-01T12:00:00",
  624. "checks": {
  625. "agent_creation": true/false,
  626. "tools_import": true/false,
  627. "llm_connection": true/false,
  628. "classifier_ready": true/false
  629. }
  630. }
  631. }
  632. """
  633. try:
  634. # 基础健康检查
  635. health_data = {
  636. "status": "unknown",
  637. "test_result": False,
  638. "workflow_compiled": False,
  639. "tools_count": 0,
  640. "message": "",
  641. "timestamp": datetime.now().isoformat(),
  642. "checks": {
  643. "agent_creation": False,
  644. "tools_import": False,
  645. "llm_connection": False,
  646. "classifier_ready": False
  647. }
  648. }
  649. # 检查1: Agent创建
  650. try:
  651. agent = get_citu_langraph_agent()
  652. health_data["checks"]["agent_creation"] = True
  653. # 修正:Agent现在是动态创建workflow的,不再有预创建的workflow属性
  654. health_data["workflow_compiled"] = True # 动态创建,始终可用
  655. health_data["tools_count"] = len(agent.tools) if hasattr(agent, 'tools') else 0
  656. except Exception as e:
  657. health_data["message"] = f"Agent创建失败: {str(e)}"
  658. health_data["status"] = "unhealthy" # 设置状态
  659. from common.result import health_error_response
  660. return jsonify(health_error_response(**health_data)), 503
  661. # 检查2: 工具导入
  662. try:
  663. from agent.tools import TOOLS
  664. health_data["checks"]["tools_import"] = len(TOOLS) > 0
  665. except Exception as e:
  666. health_data["message"] = f"工具导入失败: {str(e)}"
  667. # 检查3: LLM连接(简单测试)
  668. try:
  669. from agent.tools.utils import get_compatible_llm
  670. llm = get_compatible_llm()
  671. health_data["checks"]["llm_connection"] = llm is not None
  672. except Exception as e:
  673. health_data["message"] = f"LLM连接失败: {str(e)}"
  674. # 检查4: 分类器准备
  675. try:
  676. from agent.classifier import QuestionClassifier
  677. classifier = QuestionClassifier()
  678. health_data["checks"]["classifier_ready"] = True
  679. except Exception as e:
  680. health_data["message"] = f"分类器失败: {str(e)}"
  681. # 检查5: 完整流程测试(可选)
  682. try:
  683. if all(health_data["checks"].values()):
  684. import asyncio
  685. # 异步调用健康检查
  686. test_result = asyncio.run(agent.health_check())
  687. health_data["test_result"] = test_result.get("status") == "healthy"
  688. health_data["status"] = test_result.get("status", "unknown")
  689. health_data["message"] = test_result.get("message", "健康检查完成")
  690. else:
  691. health_data["status"] = "degraded"
  692. health_data["message"] = "部分组件异常"
  693. except Exception as e:
  694. logger.error(f"健康检查异常: {str(e)}")
  695. import traceback
  696. logger.error(f"详细健康检查错误: {traceback.format_exc()}")
  697. health_data["status"] = "degraded"
  698. health_data["message"] = f"完整测试失败: {str(e)}"
  699. # 根据状态返回相应的HTTP代码 - 使用标准化健康检查响应
  700. from common.result import health_success_response, health_error_response
  701. if health_data["status"] == "healthy":
  702. return jsonify(health_success_response(**health_data))
  703. elif health_data["status"] == "degraded":
  704. return jsonify(health_error_response(**health_data)), 503
  705. else:
  706. # 确保状态设置为unhealthy
  707. health_data["status"] = "unhealthy"
  708. return jsonify(health_error_response(**health_data)), 503
  709. except Exception as e:
  710. logger.error(f"顶层健康检查异常: {str(e)}")
  711. import traceback
  712. logger.error(f"详细错误信息: {traceback.format_exc()}")
  713. from common.result import internal_error_response
  714. return jsonify(internal_error_response(
  715. response_text="健康检查失败,请稍后重试"
  716. )), 500
  717. # ==================== 日常管理API ====================
  718. @app.flask_app.route('/api/v0/cache_overview', methods=['GET'])
  719. def cache_overview():
  720. """日常管理:轻量概览 - 合并原cache_inspect的核心功能"""
  721. try:
  722. cache = app.cache
  723. result_data = {
  724. 'overview_summary': {
  725. 'total_conversations': 0,
  726. 'total_sessions': 0,
  727. 'query_time': datetime.now().isoformat()
  728. },
  729. 'recent_conversations': [], # 最近的对话
  730. 'session_summary': [] # 会话摘要
  731. }
  732. if hasattr(cache, 'cache') and isinstance(cache.cache, dict):
  733. result_data['overview_summary']['total_conversations'] = len(cache.cache)
  734. # 获取会话信息
  735. if hasattr(cache, 'get_all_sessions'):
  736. all_sessions = cache.get_all_sessions()
  737. result_data['overview_summary']['total_sessions'] = len(all_sessions)
  738. # 会话摘要(按最近活动排序)
  739. session_list = []
  740. for session_id, session_data in all_sessions.items():
  741. session_summary = {
  742. 'session_id': session_id,
  743. 'start_time': session_data['start_time'].isoformat(),
  744. 'conversation_count': session_data.get('conversation_count', 0),
  745. 'duration_seconds': session_data.get('session_duration_seconds', 0),
  746. 'last_activity': session_data.get('last_activity', session_data['start_time']).isoformat(),
  747. 'is_active': (datetime.now() - session_data.get('last_activity', session_data['start_time'])).total_seconds() < 1800 # 30分钟内活跃
  748. }
  749. session_list.append(session_summary)
  750. # 按最后活动时间排序
  751. session_list.sort(key=lambda x: x['last_activity'], reverse=True)
  752. result_data['session_summary'] = session_list
  753. # 最近的对话(最多显示10个)
  754. conversation_list = []
  755. for conversation_id, conversation_data in cache.cache.items():
  756. conversation_start_time = cache.conversation_start_times.get(conversation_id)
  757. conversation_info = {
  758. 'conversation_id': conversation_id,
  759. 'conversation_start_time': conversation_start_time.isoformat() if conversation_start_time else None,
  760. 'session_id': cache.conversation_to_session.get(conversation_id),
  761. 'has_question': 'question' in conversation_data,
  762. 'has_sql': 'sql' in conversation_data,
  763. 'has_data': 'df' in conversation_data and conversation_data['df'] is not None,
  764. 'question_preview': conversation_data.get('question', '')[:80] + '...' if len(conversation_data.get('question', '')) > 80 else conversation_data.get('question', ''),
  765. }
  766. # 计算对话持续时间
  767. if conversation_start_time:
  768. duration = datetime.now() - conversation_start_time
  769. conversation_info['conversation_duration_seconds'] = duration.total_seconds()
  770. conversation_list.append(conversation_info)
  771. # 按对话开始时间排序,显示最新的10个
  772. conversation_list.sort(key=lambda x: x['conversation_start_time'] or '', reverse=True)
  773. result_data['recent_conversations'] = conversation_list[:10]
  774. from common.result import success_response
  775. return jsonify(success_response(
  776. response_text="缓存概览查询完成",
  777. data=result_data
  778. ))
  779. except Exception as e:
  780. from common.result import internal_error_response
  781. return jsonify(internal_error_response(
  782. response_text="获取缓存概览失败,请稍后重试"
  783. )), 500
  784. @app.flask_app.route('/api/v0/cache_stats', methods=['GET'])
  785. def cache_stats():
  786. """日常管理:统计信息 - 合并原session_stats和cache_stats功能"""
  787. try:
  788. cache = app.cache
  789. current_time = datetime.now()
  790. stats = {
  791. 'basic_stats': {
  792. 'total_sessions': len(getattr(cache, 'session_info', {})),
  793. 'total_conversations': len(getattr(cache, 'cache', {})),
  794. 'active_sessions': 0, # 最近30分钟有活动
  795. 'average_conversations_per_session': 0
  796. },
  797. 'time_distribution': {
  798. 'sessions': {
  799. 'last_1_hour': 0,
  800. 'last_6_hours': 0,
  801. 'last_24_hours': 0,
  802. 'last_7_days': 0,
  803. 'older': 0
  804. },
  805. 'conversations': {
  806. 'last_1_hour': 0,
  807. 'last_6_hours': 0,
  808. 'last_24_hours': 0,
  809. 'last_7_days': 0,
  810. 'older': 0
  811. }
  812. },
  813. 'session_details': [],
  814. 'time_ranges': {
  815. 'oldest_session': None,
  816. 'newest_session': None,
  817. 'oldest_conversation': None,
  818. 'newest_conversation': None
  819. }
  820. }
  821. # 会话统计
  822. if hasattr(cache, 'session_info'):
  823. session_times = []
  824. total_conversations = 0
  825. for session_id, session_data in cache.session_info.items():
  826. start_time = session_data['start_time']
  827. session_times.append(start_time)
  828. conversation_count = len(session_data.get('conversations', []))
  829. total_conversations += conversation_count
  830. # 检查活跃状态
  831. last_activity = session_data.get('last_activity', session_data['start_time'])
  832. if (current_time - last_activity).total_seconds() < 1800:
  833. stats['basic_stats']['active_sessions'] += 1
  834. # 时间分布统计
  835. age_hours = (current_time - start_time).total_seconds() / 3600
  836. if age_hours <= 1:
  837. stats['time_distribution']['sessions']['last_1_hour'] += 1
  838. elif age_hours <= 6:
  839. stats['time_distribution']['sessions']['last_6_hours'] += 1
  840. elif age_hours <= 24:
  841. stats['time_distribution']['sessions']['last_24_hours'] += 1
  842. elif age_hours <= 168: # 7 days
  843. stats['time_distribution']['sessions']['last_7_days'] += 1
  844. else:
  845. stats['time_distribution']['sessions']['older'] += 1
  846. # 会话详细信息
  847. session_duration = current_time - start_time
  848. stats['session_details'].append({
  849. 'session_id': session_id,
  850. 'start_time': start_time.isoformat(),
  851. 'last_activity': last_activity.isoformat(),
  852. 'conversation_count': conversation_count,
  853. 'duration_seconds': session_duration.total_seconds(),
  854. 'duration_formatted': str(session_duration),
  855. 'is_active': (current_time - last_activity).total_seconds() < 1800,
  856. 'browser_session_id': session_data.get('browser_session_id')
  857. })
  858. # 计算平均值
  859. if len(cache.session_info) > 0:
  860. stats['basic_stats']['average_conversations_per_session'] = total_conversations / len(cache.session_info)
  861. # 时间范围
  862. if session_times:
  863. stats['time_ranges']['oldest_session'] = min(session_times).isoformat()
  864. stats['time_ranges']['newest_session'] = max(session_times).isoformat()
  865. # 对话统计
  866. if hasattr(cache, 'conversation_start_times'):
  867. conversation_times = []
  868. for conv_time in cache.conversation_start_times.values():
  869. conversation_times.append(conv_time)
  870. age_hours = (current_time - conv_time).total_seconds() / 3600
  871. if age_hours <= 1:
  872. stats['time_distribution']['conversations']['last_1_hour'] += 1
  873. elif age_hours <= 6:
  874. stats['time_distribution']['conversations']['last_6_hours'] += 1
  875. elif age_hours <= 24:
  876. stats['time_distribution']['conversations']['last_24_hours'] += 1
  877. elif age_hours <= 168:
  878. stats['time_distribution']['conversations']['last_7_days'] += 1
  879. else:
  880. stats['time_distribution']['conversations']['older'] += 1
  881. if conversation_times:
  882. stats['time_ranges']['oldest_conversation'] = min(conversation_times).isoformat()
  883. stats['time_ranges']['newest_conversation'] = max(conversation_times).isoformat()
  884. # 按最近活动排序会话详情
  885. stats['session_details'].sort(key=lambda x: x['last_activity'], reverse=True)
  886. from common.result import success_response
  887. return jsonify(success_response(
  888. response_text="缓存统计信息查询完成",
  889. data=stats
  890. ))
  891. except Exception as e:
  892. from common.result import internal_error_response
  893. return jsonify(internal_error_response(
  894. response_text="获取缓存统计失败,请稍后重试"
  895. )), 500
  896. # ==================== 高级功能API ====================
  897. @app.flask_app.route('/api/v0/cache_export', methods=['GET'])
  898. def cache_export():
  899. """高级功能:完整导出 - 保持原cache_raw_export的完整功能"""
  900. try:
  901. cache = app.cache
  902. # 验证缓存的实际结构
  903. if not hasattr(cache, 'cache'):
  904. from common.result import internal_error_response
  905. return jsonify(internal_error_response(
  906. response_text="缓存对象结构异常,请联系系统管理员"
  907. )), 500
  908. if not isinstance(cache.cache, dict):
  909. from common.result import internal_error_response
  910. return jsonify(internal_error_response(
  911. response_text="缓存数据类型异常,请联系系统管理员"
  912. )), 500
  913. # 定义JSON序列化辅助函数
  914. def make_json_serializable(obj):
  915. """将对象转换为JSON可序列化的格式"""
  916. if obj is None:
  917. return None
  918. elif isinstance(obj, (str, int, float, bool)):
  919. return obj
  920. elif isinstance(obj, (list, tuple)):
  921. return [make_json_serializable(item) for item in obj]
  922. elif isinstance(obj, dict):
  923. return {str(k): make_json_serializable(v) for k, v in obj.items()}
  924. elif hasattr(obj, 'isoformat'): # datetime objects
  925. return obj.isoformat()
  926. elif hasattr(obj, 'item'): # numpy scalars
  927. return obj.item()
  928. elif hasattr(obj, 'tolist'): # numpy arrays
  929. return obj.tolist()
  930. elif hasattr(obj, '__dict__'): # pandas dtypes and other objects
  931. return str(obj)
  932. else:
  933. return str(obj)
  934. # 获取完整的原始缓存数据
  935. raw_cache = cache.cache
  936. # 获取会话和对话时间信息
  937. conversation_times = getattr(cache, 'conversation_start_times', {})
  938. session_info = getattr(cache, 'session_info', {})
  939. conversation_to_session = getattr(cache, 'conversation_to_session', {})
  940. export_data = {
  941. 'export_metadata': {
  942. 'export_time': datetime.now().isoformat(),
  943. 'total_conversations': len(raw_cache),
  944. 'total_sessions': len(session_info),
  945. 'cache_type': type(cache).__name__,
  946. 'cache_object_info': str(cache),
  947. 'has_session_times': bool(session_info),
  948. 'has_conversation_times': bool(conversation_times)
  949. },
  950. 'session_info': {
  951. session_id: {
  952. 'start_time': session_data['start_time'].isoformat(),
  953. 'last_activity': session_data.get('last_activity', session_data['start_time']).isoformat(),
  954. 'conversations': session_data['conversations'],
  955. 'conversation_count': len(session_data['conversations']),
  956. 'browser_session_id': session_data.get('browser_session_id'),
  957. 'user_info': session_data.get('user_info', {})
  958. }
  959. for session_id, session_data in session_info.items()
  960. },
  961. 'conversation_times': {
  962. conversation_id: start_time.isoformat()
  963. for conversation_id, start_time in conversation_times.items()
  964. },
  965. 'conversation_to_session_mapping': conversation_to_session,
  966. 'conversations': {}
  967. }
  968. # 处理每个对话的完整数据
  969. for conversation_id, conversation_data in raw_cache.items():
  970. # 获取时间信息
  971. conversation_start_time = conversation_times.get(conversation_id)
  972. session_id = conversation_to_session.get(conversation_id)
  973. session_start_time = None
  974. if session_id and session_id in session_info:
  975. session_start_time = session_info[session_id]['start_time']
  976. processed_conversation = {
  977. 'conversation_id': conversation_id,
  978. 'conversation_start_time': conversation_start_time.isoformat() if conversation_start_time else None,
  979. 'session_id': session_id,
  980. 'session_start_time': session_start_time.isoformat() if session_start_time else None,
  981. 'field_count': len(conversation_data),
  982. 'fields': {}
  983. }
  984. # 添加时间计算
  985. if conversation_start_time:
  986. conversation_duration = datetime.now() - conversation_start_time
  987. processed_conversation['conversation_duration_seconds'] = conversation_duration.total_seconds()
  988. processed_conversation['conversation_duration_formatted'] = str(conversation_duration)
  989. if session_start_time:
  990. session_duration = datetime.now() - session_start_time
  991. processed_conversation['session_duration_seconds'] = session_duration.total_seconds()
  992. processed_conversation['session_duration_formatted'] = str(session_duration)
  993. # 处理每个字段,确保JSON序列化安全
  994. for field_name, field_value in conversation_data.items():
  995. field_info = {
  996. 'field_name': field_name,
  997. 'data_type': type(field_value).__name__,
  998. 'is_none': field_value is None
  999. }
  1000. try:
  1001. if field_value is None:
  1002. field_info['value'] = None
  1003. elif field_name in ['conversation_start_time', 'session_start_time']:
  1004. # 处理时间字段
  1005. field_info['content'] = make_json_serializable(field_value)
  1006. elif field_name == 'df' and field_value is not None:
  1007. # DataFrame的安全处理
  1008. if hasattr(field_value, 'to_dict'):
  1009. # 安全地处理dtypes
  1010. try:
  1011. dtypes_dict = {}
  1012. for col, dtype in field_value.dtypes.items():
  1013. dtypes_dict[col] = str(dtype)
  1014. except Exception:
  1015. dtypes_dict = {"error": "无法序列化dtypes"}
  1016. # 安全地处理内存使用
  1017. try:
  1018. memory_usage = field_value.memory_usage(deep=True)
  1019. memory_dict = {}
  1020. for idx, usage in memory_usage.items():
  1021. memory_dict[str(idx)] = int(usage) if hasattr(usage, 'item') else int(usage)
  1022. except Exception:
  1023. memory_dict = {"error": "无法获取内存使用信息"}
  1024. field_info.update({
  1025. 'dataframe_info': {
  1026. 'shape': list(field_value.shape),
  1027. 'columns': list(field_value.columns),
  1028. 'dtypes': dtypes_dict,
  1029. 'index_info': {
  1030. 'type': type(field_value.index).__name__,
  1031. 'length': len(field_value.index)
  1032. }
  1033. },
  1034. 'data': make_json_serializable(field_value.to_dict('records')),
  1035. 'memory_usage': memory_dict
  1036. })
  1037. else:
  1038. field_info['value'] = str(field_value)
  1039. field_info['note'] = 'not_standard_dataframe'
  1040. elif field_name == 'fig_json':
  1041. # 图表JSON数据处理
  1042. if isinstance(field_value, str):
  1043. try:
  1044. import json
  1045. parsed_fig = json.loads(field_value)
  1046. field_info.update({
  1047. 'json_valid': True,
  1048. 'json_size_bytes': len(field_value),
  1049. 'plotly_structure': {
  1050. 'has_data': 'data' in parsed_fig,
  1051. 'has_layout': 'layout' in parsed_fig,
  1052. 'data_traces_count': len(parsed_fig.get('data', [])),
  1053. },
  1054. 'raw_json': field_value
  1055. })
  1056. except json.JSONDecodeError:
  1057. field_info.update({
  1058. 'json_valid': False,
  1059. 'raw_content': str(field_value)
  1060. })
  1061. else:
  1062. field_info['value'] = make_json_serializable(field_value)
  1063. elif field_name == 'followup_questions':
  1064. # 后续问题列表
  1065. field_info.update({
  1066. 'content': make_json_serializable(field_value)
  1067. })
  1068. elif field_name in ['question', 'sql', 'summary']:
  1069. # 文本字段
  1070. if isinstance(field_value, str):
  1071. field_info.update({
  1072. 'text_length': len(field_value),
  1073. 'content': field_value
  1074. })
  1075. else:
  1076. field_info['value'] = make_json_serializable(field_value)
  1077. else:
  1078. # 未知字段的安全处理
  1079. field_info['content'] = make_json_serializable(field_value)
  1080. except Exception as e:
  1081. field_info.update({
  1082. 'processing_error': str(e),
  1083. 'fallback_value': str(field_value)[:500] + '...' if len(str(field_value)) > 500 else str(field_value)
  1084. })
  1085. processed_conversation['fields'][field_name] = field_info
  1086. export_data['conversations'][conversation_id] = processed_conversation
  1087. # 添加缓存统计信息
  1088. field_frequency = {}
  1089. data_types_found = set()
  1090. total_dataframes = 0
  1091. total_questions = 0
  1092. for conv_data in export_data['conversations'].values():
  1093. for field_name, field_info in conv_data['fields'].items():
  1094. field_frequency[field_name] = field_frequency.get(field_name, 0) + 1
  1095. data_types_found.add(field_info['data_type'])
  1096. if field_name == 'df' and not field_info['is_none']:
  1097. total_dataframes += 1
  1098. if field_name == 'question' and not field_info['is_none']:
  1099. total_questions += 1
  1100. export_data['cache_statistics'] = {
  1101. 'field_frequency': field_frequency,
  1102. 'data_types_found': list(data_types_found),
  1103. 'total_dataframes': total_dataframes,
  1104. 'total_questions': total_questions,
  1105. 'has_session_timing': 'session_start_time' in field_frequency,
  1106. 'has_conversation_timing': 'conversation_start_time' in field_frequency
  1107. }
  1108. from common.result import success_response
  1109. return jsonify(success_response(
  1110. response_text="缓存数据导出完成",
  1111. data=export_data
  1112. ))
  1113. except Exception as e:
  1114. import traceback
  1115. error_details = {
  1116. 'error_message': str(e),
  1117. 'error_type': type(e).__name__,
  1118. 'traceback': traceback.format_exc()
  1119. }
  1120. from common.result import internal_error_response
  1121. return jsonify(internal_error_response(
  1122. response_text="导出缓存失败,请稍后重试"
  1123. )), 500
  1124. # ==================== 清理功能API ====================
  1125. @app.flask_app.route('/api/v0/cache_preview_cleanup', methods=['POST'])
  1126. def cache_preview_cleanup():
  1127. """清理功能:预览删除操作 - 保持原功能"""
  1128. try:
  1129. req = request.get_json(force=True)
  1130. # 时间条件 - 支持三种方式
  1131. older_than_hours = req.get('older_than_hours')
  1132. older_than_days = req.get('older_than_days')
  1133. before_timestamp = req.get('before_timestamp') # YYYY-MM-DD HH:MM:SS 格式
  1134. cache = app.cache
  1135. # 计算截止时间
  1136. cutoff_time = None
  1137. time_condition = None
  1138. if older_than_hours:
  1139. cutoff_time = datetime.now() - timedelta(hours=older_than_hours)
  1140. time_condition = f"older_than_hours: {older_than_hours}"
  1141. elif older_than_days:
  1142. cutoff_time = datetime.now() - timedelta(days=older_than_days)
  1143. time_condition = f"older_than_days: {older_than_days}"
  1144. elif before_timestamp:
  1145. try:
  1146. # 支持 YYYY-MM-DD HH:MM:SS 格式
  1147. cutoff_time = datetime.strptime(before_timestamp, '%Y-%m-%d %H:%M:%S')
  1148. time_condition = f"before_timestamp: {before_timestamp}"
  1149. except ValueError:
  1150. from common.result import validation_failed_response
  1151. return jsonify(validation_failed_response(
  1152. response_text="before_timestamp格式错误,请使用 YYYY-MM-DD HH:MM:SS 格式"
  1153. )), 422
  1154. else:
  1155. from common.result import bad_request_response
  1156. return jsonify(bad_request_response(
  1157. response_text="必须提供时间条件:older_than_hours, older_than_days 或 before_timestamp (YYYY-MM-DD HH:MM:SS)",
  1158. missing_params=["older_than_hours", "older_than_days", "before_timestamp"]
  1159. )), 400
  1160. preview = {
  1161. 'time_condition': time_condition,
  1162. 'cutoff_time': cutoff_time.isoformat(),
  1163. 'will_be_removed': {
  1164. 'sessions': []
  1165. },
  1166. 'will_be_kept': {
  1167. 'sessions_count': 0,
  1168. 'conversations_count': 0
  1169. },
  1170. 'summary': {
  1171. 'sessions_to_remove': 0,
  1172. 'conversations_to_remove': 0,
  1173. 'sessions_to_keep': 0,
  1174. 'conversations_to_keep': 0
  1175. }
  1176. }
  1177. # 预览按session删除
  1178. sessions_to_remove_count = 0
  1179. conversations_to_remove_count = 0
  1180. for session_id, session_data in cache.session_info.items():
  1181. session_preview = {
  1182. 'session_id': session_id,
  1183. 'start_time': session_data['start_time'].isoformat(),
  1184. 'conversation_count': len(session_data['conversations']),
  1185. 'conversations': []
  1186. }
  1187. # 添加conversation详情
  1188. for conv_id in session_data['conversations']:
  1189. if conv_id in cache.cache:
  1190. conv_data = cache.cache[conv_id]
  1191. session_preview['conversations'].append({
  1192. 'conversation_id': conv_id,
  1193. 'question': conv_data.get('question', '')[:50] + '...' if conv_data.get('question') else '',
  1194. 'start_time': cache.conversation_start_times.get(conv_id, '').isoformat() if cache.conversation_start_times.get(conv_id) else ''
  1195. })
  1196. if session_data['start_time'] < cutoff_time:
  1197. preview['will_be_removed']['sessions'].append(session_preview)
  1198. sessions_to_remove_count += 1
  1199. conversations_to_remove_count += len(session_data['conversations'])
  1200. else:
  1201. preview['will_be_kept']['sessions_count'] += 1
  1202. preview['will_be_kept']['conversations_count'] += len(session_data['conversations'])
  1203. # 更新摘要统计
  1204. preview['summary'] = {
  1205. 'sessions_to_remove': sessions_to_remove_count,
  1206. 'conversations_to_remove': conversations_to_remove_count,
  1207. 'sessions_to_keep': preview['will_be_kept']['sessions_count'],
  1208. 'conversations_to_keep': preview['will_be_kept']['conversations_count']
  1209. }
  1210. from common.result import success_response
  1211. return jsonify(success_response(
  1212. response_text=f"清理预览完成,将删除 {sessions_to_remove_count} 个会话和 {conversations_to_remove_count} 个对话",
  1213. data=preview
  1214. ))
  1215. except Exception as e:
  1216. from common.result import internal_error_response
  1217. return jsonify(internal_error_response(
  1218. response_text="预览清理操作失败,请稍后重试"
  1219. )), 500
  1220. @app.flask_app.route('/api/v0/cache_cleanup', methods=['POST'])
  1221. def cache_cleanup():
  1222. """清理功能:实际删除缓存 - 保持原功能"""
  1223. try:
  1224. req = request.get_json(force=True)
  1225. # 时间条件 - 支持三种方式
  1226. older_than_hours = req.get('older_than_hours')
  1227. older_than_days = req.get('older_than_days')
  1228. before_timestamp = req.get('before_timestamp') # YYYY-MM-DD HH:MM:SS 格式
  1229. cache = app.cache
  1230. if not hasattr(cache, 'session_info'):
  1231. from common.result import service_unavailable_response
  1232. return jsonify(service_unavailable_response(
  1233. response_text="缓存不支持会话功能"
  1234. )), 503
  1235. # 计算截止时间
  1236. cutoff_time = None
  1237. time_condition = None
  1238. if older_than_hours:
  1239. cutoff_time = datetime.now() - timedelta(hours=older_than_hours)
  1240. time_condition = f"older_than_hours: {older_than_hours}"
  1241. elif older_than_days:
  1242. cutoff_time = datetime.now() - timedelta(days=older_than_days)
  1243. time_condition = f"older_than_days: {older_than_days}"
  1244. elif before_timestamp:
  1245. try:
  1246. # 支持 YYYY-MM-DD HH:MM:SS 格式
  1247. cutoff_time = datetime.strptime(before_timestamp, '%Y-%m-%d %H:%M:%S')
  1248. time_condition = f"before_timestamp: {before_timestamp}"
  1249. except ValueError:
  1250. from common.result import validation_failed_response
  1251. return jsonify(validation_failed_response(
  1252. response_text="before_timestamp格式错误,请使用 YYYY-MM-DD HH:MM:SS 格式"
  1253. )), 422
  1254. else:
  1255. from common.result import bad_request_response
  1256. return jsonify(bad_request_response(
  1257. response_text="必须提供时间条件:older_than_hours, older_than_days 或 before_timestamp (YYYY-MM-DD HH:MM:SS)",
  1258. missing_params=["older_than_hours", "older_than_days", "before_timestamp"]
  1259. )), 400
  1260. cleanup_stats = {
  1261. 'time_condition': time_condition,
  1262. 'cutoff_time': cutoff_time.isoformat(),
  1263. 'sessions_removed': 0,
  1264. 'conversations_removed': 0,
  1265. 'sessions_kept': 0,
  1266. 'conversations_kept': 0,
  1267. 'removed_session_ids': [],
  1268. 'removed_conversation_ids': []
  1269. }
  1270. # 按session删除
  1271. sessions_to_remove = []
  1272. for session_id, session_data in cache.session_info.items():
  1273. if session_data['start_time'] < cutoff_time:
  1274. sessions_to_remove.append(session_id)
  1275. # 删除符合条件的sessions及其所有conversations
  1276. for session_id in sessions_to_remove:
  1277. session_data = cache.session_info[session_id]
  1278. conversations_in_session = session_data['conversations'].copy()
  1279. # 删除session中的所有conversations
  1280. for conv_id in conversations_in_session:
  1281. if conv_id in cache.cache:
  1282. del cache.cache[conv_id]
  1283. cleanup_stats['conversations_removed'] += 1
  1284. cleanup_stats['removed_conversation_ids'].append(conv_id)
  1285. # 清理conversation相关的时间记录
  1286. if hasattr(cache, 'conversation_start_times') and conv_id in cache.conversation_start_times:
  1287. del cache.conversation_start_times[conv_id]
  1288. if hasattr(cache, 'conversation_to_session') and conv_id in cache.conversation_to_session:
  1289. del cache.conversation_to_session[conv_id]
  1290. # 删除session记录
  1291. del cache.session_info[session_id]
  1292. cleanup_stats['sessions_removed'] += 1
  1293. cleanup_stats['removed_session_ids'].append(session_id)
  1294. # 统计保留的sessions和conversations
  1295. cleanup_stats['sessions_kept'] = len(cache.session_info)
  1296. cleanup_stats['conversations_kept'] = len(cache.cache)
  1297. from common.result import success_response
  1298. return jsonify(success_response(
  1299. response_text=f"缓存清理完成,删除了 {cleanup_stats['sessions_removed']} 个会话和 {cleanup_stats['conversations_removed']} 个对话",
  1300. data=cleanup_stats
  1301. ))
  1302. except Exception as e:
  1303. from common.result import internal_error_response
  1304. return jsonify(internal_error_response(
  1305. response_text="缓存清理失败,请稍后重试"
  1306. )), 500
  1307. @app.flask_app.route('/api/v0/training_error_question_sql', methods=['POST'])
  1308. def training_error_question_sql():
  1309. """
  1310. 存储错误的question-sql对到error_sql集合中
  1311. 此API将接收的错误question/sql pair写入到error_sql集合中,用于记录和分析错误的SQL查询。
  1312. Args:
  1313. question (str, required): 用户问题
  1314. sql (str, required): 对应的错误SQL查询语句
  1315. Returns:
  1316. JSON: 包含训练ID和成功消息的响应
  1317. """
  1318. try:
  1319. data = request.get_json()
  1320. question = data.get('question')
  1321. sql = data.get('sql')
  1322. logger.debug(f"接收到错误SQL训练请求: question={question}, sql={sql}")
  1323. if not question or not sql:
  1324. from common.result import bad_request_response
  1325. missing_params = []
  1326. if not question:
  1327. missing_params.append("question")
  1328. if not sql:
  1329. missing_params.append("sql")
  1330. return jsonify(bad_request_response(
  1331. response_text="question和sql参数都是必需的",
  1332. missing_params=missing_params
  1333. )), 400
  1334. # 使用vn实例的train_error_sql方法存储错误SQL
  1335. id = vn.train_error_sql(question=question, sql=sql)
  1336. logger.info(f"成功存储错误SQL,ID: {id}")
  1337. from common.result import success_response
  1338. return jsonify(success_response(
  1339. response_text="错误SQL对已成功存储",
  1340. data={
  1341. "id": id,
  1342. "message": "错误SQL对已成功存储到error_sql集合"
  1343. }
  1344. ))
  1345. except Exception as e:
  1346. logger.error(f"存储错误SQL失败: {str(e)}")
  1347. from common.result import internal_error_response
  1348. return jsonify(internal_error_response(
  1349. response_text="存储错误SQL失败,请稍后重试"
  1350. )), 500
  1351. # ==================== Redis对话管理API ====================
  1352. @app.flask_app.route('/api/v0/user/<user_id>/conversations', methods=['GET'])
  1353. def get_user_conversations(user_id: str):
  1354. """获取用户的对话列表(按时间倒序)"""
  1355. try:
  1356. limit = request.args.get('limit', USER_MAX_CONVERSATIONS, type=int)
  1357. conversations = redis_conversation_manager.get_conversations(user_id, limit)
  1358. # 为每个对话动态获取标题(第一条用户消息)
  1359. for conversation in conversations:
  1360. conversation_id = conversation['conversation_id']
  1361. try:
  1362. # 获取所有消息,然后取第一条用户消息作为标题
  1363. messages = redis_conversation_manager.get_conversation_messages(conversation_id)
  1364. if messages and len(messages) > 0:
  1365. # 找到第一条用户消息(按时间顺序)
  1366. first_user_message = None
  1367. for message in messages:
  1368. if message.get('role') == 'user':
  1369. first_user_message = message
  1370. break
  1371. if first_user_message:
  1372. title = first_user_message.get('content', '对话').strip()
  1373. # 限制标题长度,保持整洁
  1374. if len(title) > 50:
  1375. conversation['conversation_title'] = title[:47] + "..."
  1376. else:
  1377. conversation['conversation_title'] = title
  1378. else:
  1379. conversation['conversation_title'] = "对话"
  1380. else:
  1381. conversation['conversation_title'] = "空对话"
  1382. except Exception as e:
  1383. logger.warning(f"获取对话标题失败 {conversation_id}: {str(e)}")
  1384. conversation['conversation_title'] = "对话"
  1385. return jsonify(success_response(
  1386. response_text="获取用户对话列表成功",
  1387. data={
  1388. "user_id": user_id,
  1389. "conversations": conversations,
  1390. "total_count": len(conversations)
  1391. }
  1392. ))
  1393. except Exception as e:
  1394. return jsonify(internal_error_response(
  1395. response_text="获取对话列表失败,请稍后重试"
  1396. )), 500
  1397. @app.flask_app.route('/api/v0/conversation/<conversation_id>/messages', methods=['GET'])
  1398. def get_conversation_messages(conversation_id: str):
  1399. """获取特定对话的消息历史"""
  1400. try:
  1401. limit = request.args.get('limit', type=int) # 可选参数
  1402. messages = redis_conversation_manager.get_conversation_messages(conversation_id, limit)
  1403. meta = redis_conversation_manager.get_conversation_meta(conversation_id)
  1404. return jsonify(success_response(
  1405. response_text="获取对话消息成功",
  1406. data={
  1407. "conversation_id": conversation_id,
  1408. "conversation_meta": meta,
  1409. "messages": messages,
  1410. "message_count": len(messages)
  1411. }
  1412. ))
  1413. except Exception as e:
  1414. return jsonify(internal_error_response(
  1415. response_text="获取对话消息失败"
  1416. )), 500
  1417. @app.flask_app.route('/api/v0/conversation/<conversation_id>/context', methods=['GET'])
  1418. def get_conversation_context(conversation_id: str):
  1419. """获取对话上下文(格式化用于LLM)"""
  1420. try:
  1421. count = request.args.get('count', CONVERSATION_CONTEXT_COUNT, type=int)
  1422. context = redis_conversation_manager.get_context_for_display(conversation_id, count)
  1423. return jsonify(success_response(
  1424. response_text="获取对话上下文成功",
  1425. data={
  1426. "conversation_id": conversation_id,
  1427. "context": context,
  1428. "context_message_count": count
  1429. }
  1430. ))
  1431. except Exception as e:
  1432. return jsonify(internal_error_response(
  1433. response_text="获取对话上下文失败"
  1434. )), 500
  1435. @app.flask_app.route('/api/v0/conversation_stats', methods=['GET'])
  1436. def conversation_stats():
  1437. """获取对话系统统计信息"""
  1438. try:
  1439. stats = redis_conversation_manager.get_stats()
  1440. return jsonify(success_response(
  1441. response_text="获取统计信息成功",
  1442. data=stats
  1443. ))
  1444. except Exception as e:
  1445. return jsonify(internal_error_response(
  1446. response_text="获取统计信息失败,请稍后重试"
  1447. )), 500
  1448. @app.flask_app.route('/api/v0/conversation_cleanup', methods=['POST'])
  1449. def conversation_cleanup():
  1450. """手动清理过期对话"""
  1451. try:
  1452. redis_conversation_manager.cleanup_expired_conversations()
  1453. return jsonify(success_response(
  1454. response_text="对话清理完成"
  1455. ))
  1456. except Exception as e:
  1457. return jsonify(internal_error_response(
  1458. response_text="对话清理失败,请稍后重试"
  1459. )), 500
  1460. @app.flask_app.route('/api/v0/user/<user_id>/conversations/full', methods=['GET'])
  1461. def get_user_conversations_with_messages(user_id: str):
  1462. """
  1463. 获取用户的完整对话数据(包含所有消息)
  1464. 一次性返回用户的所有对话和每个对话下的消息历史
  1465. Args:
  1466. user_id: 用户ID(路径参数)
  1467. conversation_limit: 对话数量限制(查询参数,可选,不传则返回所有对话)
  1468. message_limit: 每个对话的消息数限制(查询参数,可选,不传则返回所有消息)
  1469. Returns:
  1470. 包含用户所有对话和消息的完整数据
  1471. """
  1472. try:
  1473. # 获取可选参数,不传递时使用None(返回所有记录)
  1474. conversation_limit = request.args.get('conversation_limit', type=int)
  1475. message_limit = request.args.get('message_limit', type=int)
  1476. # 获取用户的对话列表
  1477. conversations = redis_conversation_manager.get_conversations(user_id, conversation_limit)
  1478. # 为每个对话获取消息历史
  1479. full_conversations = []
  1480. total_messages = 0
  1481. for conversation in conversations:
  1482. conversation_id = conversation['conversation_id']
  1483. # 获取对话消息
  1484. messages = redis_conversation_manager.get_conversation_messages(
  1485. conversation_id, message_limit
  1486. )
  1487. # 获取对话元数据
  1488. meta = redis_conversation_manager.get_conversation_meta(conversation_id)
  1489. # 组合完整数据
  1490. full_conversation = {
  1491. **conversation, # 基础对话信息
  1492. 'meta': meta, # 对话元数据
  1493. 'messages': messages, # 消息列表
  1494. 'message_count': len(messages)
  1495. }
  1496. full_conversations.append(full_conversation)
  1497. total_messages += len(messages)
  1498. return jsonify(success_response(
  1499. response_text="获取用户完整对话数据成功",
  1500. data={
  1501. "user_id": user_id,
  1502. "conversations": full_conversations,
  1503. "total_conversations": len(full_conversations),
  1504. "total_messages": total_messages,
  1505. "conversation_limit_applied": conversation_limit,
  1506. "message_limit_applied": message_limit,
  1507. "query_time": datetime.now().isoformat()
  1508. }
  1509. ))
  1510. except Exception as e:
  1511. logger.error(f"获取用户完整对话数据失败: {str(e)}")
  1512. return jsonify(internal_error_response(
  1513. response_text="获取用户对话数据失败,请稍后重试"
  1514. )), 500
  1515. # ==================== Embedding缓存管理接口 ====================
  1516. @app.flask_app.route('/api/v0/embedding_cache_stats', methods=['GET'])
  1517. def embedding_cache_stats():
  1518. """获取embedding缓存统计信息"""
  1519. try:
  1520. from common.embedding_cache_manager import get_embedding_cache_manager
  1521. cache_manager = get_embedding_cache_manager()
  1522. stats = cache_manager.get_cache_stats()
  1523. return jsonify(success_response(
  1524. response_text="获取embedding缓存统计成功",
  1525. data=stats
  1526. ))
  1527. except Exception as e:
  1528. logger.error(f"获取embedding缓存统计失败: {str(e)}")
  1529. return jsonify(internal_error_response(
  1530. response_text="获取embedding缓存统计失败,请稍后重试"
  1531. )), 500
  1532. @app.flask_app.route('/api/v0/embedding_cache_cleanup', methods=['POST'])
  1533. def embedding_cache_cleanup():
  1534. """清空所有embedding缓存"""
  1535. try:
  1536. from common.embedding_cache_manager import get_embedding_cache_manager
  1537. cache_manager = get_embedding_cache_manager()
  1538. if not cache_manager.is_available():
  1539. return jsonify(internal_error_response(
  1540. response_text="Embedding缓存功能未启用或不可用"
  1541. )), 400
  1542. success = cache_manager.clear_all_cache()
  1543. if success:
  1544. return jsonify(success_response(
  1545. response_text="所有embedding缓存已清空",
  1546. data={"cleared": True}
  1547. ))
  1548. else:
  1549. return jsonify(internal_error_response(
  1550. response_text="清空embedding缓存失败"
  1551. )), 500
  1552. except Exception as e:
  1553. logger.error(f"清空embedding缓存失败: {str(e)}")
  1554. return jsonify(internal_error_response(
  1555. response_text="清空embedding缓存失败,请稍后重试"
  1556. )), 500
# ==================== QA feedback system endpoints ====================
# Process-wide QA feedback manager. It stays None until get_qa_feedback_manager()
# creates it on first use; later calls reuse that same instance.
qa_feedback_manager: "QAFeedbackManager | None" = None
  1560. def get_qa_feedback_manager():
  1561. """获取QA反馈管理器实例(懒加载)- 复用Vanna连接版本"""
  1562. global qa_feedback_manager
  1563. if qa_feedback_manager is None:
  1564. try:
  1565. # 优先尝试复用vanna连接
  1566. vanna_instance = None
  1567. try:
  1568. # 尝试获取现有的vanna实例
  1569. if 'get_citu_langraph_agent' in globals():
  1570. agent = get_citu_langraph_agent()
  1571. if hasattr(agent, 'vn'):
  1572. vanna_instance = agent.vn
  1573. elif 'vn' in globals():
  1574. vanna_instance = vn
  1575. else:
  1576. logger.info("未找到可用的vanna实例,将创建新的数据库连接")
  1577. except Exception as e:
  1578. logger.info(f"获取vanna实例失败: {e},将创建新的数据库连接")
  1579. vanna_instance = None
  1580. qa_feedback_manager = QAFeedbackManager(vanna_instance=vanna_instance)
  1581. logger.info("QA反馈管理器实例创建成功")
  1582. except Exception as e:
  1583. logger.critical(f"QA反馈管理器创建失败: {str(e)}")
  1584. raise Exception(f"QA反馈管理器初始化失败: {str(e)}")
  1585. return qa_feedback_manager
  1586. @app.flask_app.route('/api/v0/qa_feedback/query', methods=['POST'])
  1587. def qa_feedback_query():
  1588. """
  1589. 查询反馈记录API
  1590. 支持分页、筛选和排序功能
  1591. """
  1592. try:
  1593. req = request.get_json(force=True)
  1594. # 解析参数,设置默认值
  1595. page = req.get('page', 1)
  1596. page_size = req.get('page_size', 20)
  1597. is_thumb_up = req.get('is_thumb_up')
  1598. create_time_start = req.get('create_time_start')
  1599. create_time_end = req.get('create_time_end')
  1600. is_in_training_data = req.get('is_in_training_data')
  1601. sort_by = req.get('sort_by', 'create_time')
  1602. sort_order = req.get('sort_order', 'desc')
  1603. # 参数验证
  1604. if page < 1:
  1605. return jsonify(bad_request_response(
  1606. response_text="页码必须大于0",
  1607. invalid_params=["page"]
  1608. )), 400
  1609. if page_size < 1 or page_size > 100:
  1610. return jsonify(bad_request_response(
  1611. response_text="每页大小必须在1-100之间",
  1612. invalid_params=["page_size"]
  1613. )), 400
  1614. # 获取反馈管理器并查询
  1615. manager = get_qa_feedback_manager()
  1616. records, total = manager.query_feedback(
  1617. page=page,
  1618. page_size=page_size,
  1619. is_thumb_up=is_thumb_up,
  1620. create_time_start=create_time_start,
  1621. create_time_end=create_time_end,
  1622. is_in_training_data=is_in_training_data,
  1623. sort_by=sort_by,
  1624. sort_order=sort_order
  1625. )
  1626. # 计算分页信息
  1627. total_pages = (total + page_size - 1) // page_size
  1628. return jsonify(success_response(
  1629. response_text=f"查询成功,共找到 {total} 条记录",
  1630. data={
  1631. "records": records,
  1632. "pagination": {
  1633. "page": page,
  1634. "page_size": page_size,
  1635. "total": total,
  1636. "total_pages": total_pages,
  1637. "has_next": page < total_pages,
  1638. "has_prev": page > 1
  1639. }
  1640. }
  1641. ))
  1642. except Exception as e:
  1643. logger.error(f"qa_feedback_query执行失败: {str(e)}")
  1644. return jsonify(internal_error_response(
  1645. response_text="查询反馈记录失败,请稍后重试"
  1646. )), 500
  1647. @app.flask_app.route('/api/v0/qa_feedback/delete/<int:feedback_id>', methods=['DELETE'])
  1648. def qa_feedback_delete(feedback_id):
  1649. """
  1650. 删除反馈记录API
  1651. """
  1652. try:
  1653. manager = get_qa_feedback_manager()
  1654. success = manager.delete_feedback(feedback_id)
  1655. if success:
  1656. return jsonify(success_response(
  1657. response_text=f"反馈记录删除成功",
  1658. data={"deleted_id": feedback_id}
  1659. ))
  1660. else:
  1661. return jsonify(not_found_response(
  1662. response_text=f"反馈记录不存在 (ID: {feedback_id})"
  1663. )), 404
  1664. except Exception as e:
  1665. logger.error(f"qa_feedback_delete执行失败: {str(e)}")
  1666. return jsonify(internal_error_response(
  1667. response_text="删除反馈记录失败,请稍后重试"
  1668. )), 500
  1669. @app.flask_app.route('/api/v0/qa_feedback/update/<int:feedback_id>', methods=['PUT'])
  1670. def qa_feedback_update(feedback_id):
  1671. """
  1672. 更新反馈记录API
  1673. """
  1674. try:
  1675. req = request.get_json(force=True)
  1676. # 提取允许更新的字段
  1677. allowed_fields = ['question', 'sql', 'is_thumb_up', 'user_id', 'is_in_training_data']
  1678. update_data = {}
  1679. for field in allowed_fields:
  1680. if field in req:
  1681. update_data[field] = req[field]
  1682. if not update_data:
  1683. return jsonify(bad_request_response(
  1684. response_text="没有提供有效的更新字段",
  1685. missing_params=allowed_fields
  1686. )), 400
  1687. manager = get_qa_feedback_manager()
  1688. success = manager.update_feedback(feedback_id, **update_data)
  1689. if success:
  1690. return jsonify(success_response(
  1691. response_text="反馈记录更新成功",
  1692. data={
  1693. "updated_id": feedback_id,
  1694. "updated_fields": list(update_data.keys())
  1695. }
  1696. ))
  1697. else:
  1698. return jsonify(not_found_response(
  1699. response_text=f"反馈记录不存在或无变化 (ID: {feedback_id})"
  1700. )), 404
  1701. except Exception as e:
  1702. logger.error(f"qa_feedback_update执行失败: {str(e)}")
  1703. return jsonify(internal_error_response(
  1704. response_text="更新反馈记录失败,请稍后重试"
  1705. )), 500
  1706. @app.flask_app.route('/api/v0/qa_feedback/add_to_training', methods=['POST'])
  1707. def qa_feedback_add_to_training():
  1708. """
  1709. 将反馈记录添加到训练数据集API
  1710. 支持混合批量处理:正向反馈加入SQL训练集,负向反馈加入error_sql训练集
  1711. """
  1712. try:
  1713. req = request.get_json(force=True)
  1714. feedback_ids = req.get('feedback_ids', [])
  1715. if not feedback_ids or not isinstance(feedback_ids, list):
  1716. return jsonify(bad_request_response(
  1717. response_text="缺少有效的反馈ID列表",
  1718. missing_params=["feedback_ids"]
  1719. )), 400
  1720. manager = get_qa_feedback_manager()
  1721. # 获取反馈记录
  1722. records = manager.get_feedback_by_ids(feedback_ids)
  1723. if not records:
  1724. return jsonify(not_found_response(
  1725. response_text="未找到任何有效的反馈记录"
  1726. )), 404
  1727. # 分别处理正向和负向反馈
  1728. positive_count = 0 # 正向训练计数
  1729. negative_count = 0 # 负向训练计数
  1730. already_trained_count = 0 # 已训练计数
  1731. error_count = 0 # 错误计数
  1732. successfully_trained_ids = [] # 成功训练的ID列表
  1733. for record in records:
  1734. try:
  1735. # 检查是否已经在训练数据中
  1736. if record['is_in_training_data']:
  1737. already_trained_count += 1
  1738. continue
  1739. if record['is_thumb_up']:
  1740. # 正向反馈 - 加入标准SQL训练集
  1741. training_id = vn.train(
  1742. question=record['question'],
  1743. sql=record['sql']
  1744. )
  1745. positive_count += 1
  1746. logger.info(f"正向训练成功 - ID: {record['id']}, TrainingID: {training_id}")
  1747. else:
  1748. # 负向反馈 - 加入错误SQL训练集
  1749. training_id = vn.train_error_sql(
  1750. question=record['question'],
  1751. sql=record['sql']
  1752. )
  1753. negative_count += 1
  1754. logger.info(f"负向训练成功 - ID: {record['id']}, TrainingID: {training_id}")
  1755. successfully_trained_ids.append(record['id'])
  1756. except Exception as e:
  1757. logger.error(f"训练失败 - 反馈ID: {record['id']}, 错误: {e}")
  1758. error_count += 1
  1759. # 更新训练状态
  1760. if successfully_trained_ids:
  1761. updated_count = manager.mark_training_status(successfully_trained_ids, True)
  1762. logger.info(f"批量更新训练状态完成,影响 {updated_count} 条记录")
  1763. # 构建响应
  1764. total_processed = positive_count + negative_count + already_trained_count + error_count
  1765. return jsonify(success_response(
  1766. response_text=f"训练数据添加完成,成功处理 {positive_count + negative_count} 条记录",
  1767. data={
  1768. "summary": {
  1769. "total_requested": len(feedback_ids),
  1770. "total_processed": total_processed,
  1771. "positive_trained": positive_count,
  1772. "negative_trained": negative_count,
  1773. "already_trained": already_trained_count,
  1774. "errors": error_count
  1775. },
  1776. "successfully_trained_ids": successfully_trained_ids,
  1777. "training_details": {
  1778. "sql_training_count": positive_count,
  1779. "error_sql_training_count": negative_count
  1780. }
  1781. }
  1782. ))
  1783. except Exception as e:
  1784. logger.error(f"qa_feedback_add_to_training执行失败: {str(e)}")
  1785. return jsonify(internal_error_response(
  1786. response_text="添加训练数据失败,请稍后重试"
  1787. )), 500
  1788. @app.flask_app.route('/api/v0/qa_feedback/add', methods=['POST'])
  1789. def qa_feedback_add():
  1790. """
  1791. 添加反馈记录API
  1792. 用于前端直接创建反馈记录
  1793. """
  1794. try:
  1795. req = request.get_json(force=True)
  1796. question = req.get('question')
  1797. sql = req.get('sql')
  1798. is_thumb_up = req.get('is_thumb_up')
  1799. user_id = req.get('user_id', 'guest')
  1800. # 参数验证
  1801. if not question:
  1802. return jsonify(bad_request_response(
  1803. response_text="缺少必需参数:question",
  1804. missing_params=["question"]
  1805. )), 400
  1806. if not sql:
  1807. return jsonify(bad_request_response(
  1808. response_text="缺少必需参数:sql",
  1809. missing_params=["sql"]
  1810. )), 400
  1811. if is_thumb_up is None:
  1812. return jsonify(bad_request_response(
  1813. response_text="缺少必需参数:is_thumb_up",
  1814. missing_params=["is_thumb_up"]
  1815. )), 400
  1816. manager = get_qa_feedback_manager()
  1817. feedback_id = manager.add_feedback(
  1818. question=question,
  1819. sql=sql,
  1820. is_thumb_up=bool(is_thumb_up),
  1821. user_id=user_id
  1822. )
  1823. return jsonify(success_response(
  1824. response_text="反馈记录创建成功",
  1825. data={
  1826. "feedback_id": feedback_id
  1827. }
  1828. ))
  1829. except Exception as e:
  1830. logger.error(f"qa_feedback_add执行失败: {str(e)}")
  1831. return jsonify(internal_error_response(
  1832. response_text="创建反馈记录失败,请稍后重试"
  1833. )), 500
  1834. @app.flask_app.route('/api/v0/qa_feedback/stats', methods=['GET'])
  1835. def qa_feedback_stats():
  1836. """
  1837. 反馈统计API
  1838. 返回反馈数据的统计信息
  1839. """
  1840. try:
  1841. manager = get_qa_feedback_manager()
  1842. # 查询各种统计数据
  1843. all_records, total_count = manager.query_feedback(page=1, page_size=1)
  1844. positive_records, positive_count = manager.query_feedback(page=1, page_size=1, is_thumb_up=True)
  1845. negative_records, negative_count = manager.query_feedback(page=1, page_size=1, is_thumb_up=False)
  1846. trained_records, trained_count = manager.query_feedback(page=1, page_size=1, is_in_training_data=True)
  1847. untrained_records, untrained_count = manager.query_feedback(page=1, page_size=1, is_in_training_data=False)
  1848. return jsonify(success_response(
  1849. response_text="统计信息获取成功",
  1850. data={
  1851. "total_feedback": total_count,
  1852. "positive_feedback": positive_count,
  1853. "negative_feedback": negative_count,
  1854. "trained_feedback": trained_count,
  1855. "untrained_feedback": untrained_count,
  1856. "positive_rate": round(positive_count / max(total_count, 1) * 100, 2),
  1857. "training_rate": round(trained_count / max(total_count, 1) * 100, 2)
  1858. }
  1859. ))
  1860. except Exception as e:
  1861. logger.error(f"qa_feedback_stats执行失败: {str(e)}")
  1862. return jsonify(internal_error_response(
  1863. response_text="获取统计信息失败,请稍后重试"
  1864. )), 500
  1865. # ==================== 问答缓存管理接口 ====================
  1866. @app.flask_app.route('/api/v0/qa_cache_stats', methods=['GET'])
  1867. def qa_cache_stats():
  1868. """获取问答缓存统计信息"""
  1869. try:
  1870. stats = redis_conversation_manager.get_qa_cache_stats()
  1871. return jsonify(success_response(
  1872. response_text="获取问答缓存统计成功",
  1873. data=stats
  1874. ))
  1875. except Exception as e:
  1876. logger.error(f"获取问答缓存统计失败: {str(e)}")
  1877. return jsonify(internal_error_response(
  1878. response_text="获取问答缓存统计失败,请稍后重试"
  1879. )), 500
  1880. @app.flask_app.route('/api/v0/qa_cache_list', methods=['GET'])
  1881. def qa_cache_list():
  1882. """获取问答缓存列表(支持分页)"""
  1883. try:
  1884. # 获取分页参数,默认限制50条
  1885. limit = request.args.get('limit', 50, type=int)
  1886. # 限制最大返回数量,防止一次性返回过多数据
  1887. if limit > 500:
  1888. limit = 500
  1889. elif limit <= 0:
  1890. limit = 50
  1891. cache_list = redis_conversation_manager.get_qa_cache_list(limit)
  1892. return jsonify(success_response(
  1893. response_text="获取问答缓存列表成功",
  1894. data={
  1895. "cache_list": cache_list,
  1896. "total_returned": len(cache_list),
  1897. "limit_applied": limit,
  1898. "note": "按缓存时间倒序排列,最新的在前面"
  1899. }
  1900. ))
  1901. except Exception as e:
  1902. logger.error(f"获取问答缓存列表失败: {str(e)}")
  1903. return jsonify(internal_error_response(
  1904. response_text="获取问答缓存列表失败,请稍后重试"
  1905. )), 500
  1906. @app.flask_app.route('/api/v0/qa_cache_cleanup', methods=['POST'])
  1907. def qa_cache_cleanup():
  1908. """清空所有问答缓存"""
  1909. try:
  1910. if not redis_conversation_manager.is_available():
  1911. return jsonify(internal_error_response(
  1912. response_text="Redis连接不可用,无法执行清理操作"
  1913. )), 500
  1914. deleted_count = redis_conversation_manager.clear_all_qa_cache()
  1915. return jsonify(success_response(
  1916. response_text="问答缓存清理完成",
  1917. data={
  1918. "deleted_count": deleted_count,
  1919. "cleared": deleted_count > 0,
  1920. "cleanup_time": datetime.now().isoformat()
  1921. }
  1922. ))
  1923. except Exception as e:
  1924. logger.error(f"清空问答缓存失败: {str(e)}")
  1925. return jsonify(internal_error_response(
  1926. response_text="清空问答缓存失败,请稍后重试"
  1927. )), 500
  1928. # ==================== 训练数据管理接口 ====================
  1929. def validate_sql_syntax(sql: str) -> tuple[bool, str]:
  1930. """SQL语法检查(仅对sql类型)"""
  1931. try:
  1932. parsed = sqlparse.parse(sql.strip())
  1933. if not parsed or not parsed[0].tokens:
  1934. return False, "SQL语法错误:空语句"
  1935. # 基本语法检查
  1936. sql_upper = sql.strip().upper()
  1937. if not any(sql_upper.startswith(keyword) for keyword in
  1938. ['SELECT', 'INSERT', 'UPDATE', 'DELETE', 'CREATE', 'ALTER', 'DROP']):
  1939. return False, "SQL语法错误:不是有效的SQL语句"
  1940. # 安全检查:禁止危险的SQL操作
  1941. dangerous_operations = ['UPDATE', 'DELETE', 'ALERT', 'DROP']
  1942. for operation in dangerous_operations:
  1943. if sql_upper.startswith(operation):
  1944. return False, f'在训练集中禁止使用"{",".join(dangerous_operations)}"'
  1945. return True, ""
  1946. except Exception as e:
  1947. return False, f"SQL语法错误:{str(e)}"
  1948. def paginate_data(data_list: list, page: int, page_size: int):
  1949. """分页处理算法"""
  1950. total = len(data_list)
  1951. start_idx = (page - 1) * page_size
  1952. end_idx = start_idx + page_size
  1953. page_data = data_list[start_idx:end_idx]
  1954. return {
  1955. "data": page_data,
  1956. "pagination": {
  1957. "page": page,
  1958. "page_size": page_size,
  1959. "total": total,
  1960. "total_pages": (total + page_size - 1) // page_size,
  1961. "has_next": end_idx < total,
  1962. "has_prev": page > 1
  1963. }
  1964. }
  1965. def filter_by_type(data_list: list, training_data_type: str):
  1966. """按类型筛选算法"""
  1967. if not training_data_type:
  1968. return data_list
  1969. return [
  1970. record for record in data_list
  1971. if record.get('training_data_type') == training_data_type
  1972. ]
  1973. def search_in_data(data_list: list, search_keyword: str):
  1974. """在数据中搜索关键词"""
  1975. if not search_keyword:
  1976. return data_list
  1977. keyword_lower = search_keyword.lower()
  1978. return [
  1979. record for record in data_list
  1980. if (record.get('question') and keyword_lower in record['question'].lower()) or
  1981. (record.get('content') and keyword_lower in record['content'].lower())
  1982. ]
  1983. def process_single_training_item(item: dict, index: int) -> dict:
  1984. """处理单个训练数据项"""
  1985. training_type = item.get('training_data_type')
  1986. if training_type == 'sql':
  1987. sql = item.get('sql')
  1988. if not sql:
  1989. raise ValueError("SQL字段是必需的")
  1990. # SQL语法检查
  1991. is_valid, error_msg = validate_sql_syntax(sql)
  1992. if not is_valid:
  1993. raise ValueError(error_msg)
  1994. question = item.get('question')
  1995. if question:
  1996. training_id = vn.train(question=question, sql=sql)
  1997. else:
  1998. training_id = vn.train(sql=sql)
  1999. elif training_type == 'error_sql':
  2000. # error_sql不需要语法检查
  2001. question = item.get('question')
  2002. sql = item.get('sql')
  2003. if not question or not sql:
  2004. raise ValueError("question和sql字段都是必需的")
  2005. training_id = vn.train_error_sql(question=question, sql=sql)
  2006. elif training_type == 'documentation':
  2007. content = item.get('content')
  2008. if not content:
  2009. raise ValueError("content字段是必需的")
  2010. training_id = vn.train(documentation=content)
  2011. elif training_type == 'ddl':
  2012. ddl = item.get('ddl')
  2013. if not ddl:
  2014. raise ValueError("ddl字段是必需的")
  2015. training_id = vn.train(ddl=ddl)
  2016. else:
  2017. raise ValueError(f"不支持的训练数据类型: {training_type}")
  2018. return {
  2019. "index": index,
  2020. "success": True,
  2021. "training_id": training_id,
  2022. "type": training_type,
  2023. "message": f"{training_type}训练数据创建成功"
  2024. }
  2025. def get_total_training_count():
  2026. """获取当前训练数据总数"""
  2027. try:
  2028. training_data = vn.get_training_data()
  2029. if training_data is not None and not training_data.empty:
  2030. return len(training_data)
  2031. return 0
  2032. except Exception as e:
  2033. logger.warning(f"获取训练数据总数失败: {e}")
  2034. return 0
  2035. @app.flask_app.route('/api/v0/training_data/query', methods=['POST'])
  2036. def training_data_query():
  2037. """
  2038. 分页查询训练数据API
  2039. 支持类型筛选、搜索和排序功能
  2040. """
  2041. try:
  2042. req = request.get_json(force=True)
  2043. # 解析参数,设置默认值
  2044. page = req.get('page', 1)
  2045. page_size = req.get('page_size', 20)
  2046. training_data_type = req.get('training_data_type')
  2047. sort_by = req.get('sort_by', 'id')
  2048. sort_order = req.get('sort_order', 'desc')
  2049. search_keyword = req.get('search_keyword')
  2050. # 参数验证
  2051. if page < 1:
  2052. return jsonify(bad_request_response(
  2053. response_text="页码必须大于0",
  2054. missing_params=["page"]
  2055. )), 400
  2056. if page_size < 1 or page_size > 100:
  2057. return jsonify(bad_request_response(
  2058. response_text="每页大小必须在1-100之间",
  2059. missing_params=["page_size"]
  2060. )), 400
  2061. if search_keyword and len(search_keyword) > 100:
  2062. return jsonify(bad_request_response(
  2063. response_text="搜索关键词最大长度为100字符",
  2064. missing_params=["search_keyword"]
  2065. )), 400
  2066. # 获取训练数据
  2067. training_data = vn.get_training_data()
  2068. if training_data is None or training_data.empty:
  2069. return jsonify(success_response(
  2070. response_text="查询成功,暂无训练数据",
  2071. data={
  2072. "records": [],
  2073. "pagination": {
  2074. "page": page,
  2075. "page_size": page_size,
  2076. "total": 0,
  2077. "total_pages": 0,
  2078. "has_next": False,
  2079. "has_prev": False
  2080. },
  2081. "filters_applied": {
  2082. "training_data_type": training_data_type,
  2083. "search_keyword": search_keyword
  2084. }
  2085. }
  2086. ))
  2087. # 转换为列表格式
  2088. records = training_data.to_dict(orient="records")
  2089. # 应用筛选条件
  2090. if training_data_type:
  2091. records = filter_by_type(records, training_data_type)
  2092. if search_keyword:
  2093. records = search_in_data(records, search_keyword)
  2094. # 排序
  2095. if sort_by in ['id', 'training_data_type']:
  2096. reverse = (sort_order.lower() == 'desc')
  2097. records.sort(key=lambda x: x.get(sort_by, ''), reverse=reverse)
  2098. # 分页
  2099. paginated_result = paginate_data(records, page, page_size)
  2100. return jsonify(success_response(
  2101. response_text=f"查询成功,共找到 {paginated_result['pagination']['total']} 条记录",
  2102. data={
  2103. "records": paginated_result["data"],
  2104. "pagination": paginated_result["pagination"],
  2105. "filters_applied": {
  2106. "training_data_type": training_data_type,
  2107. "search_keyword": search_keyword
  2108. }
  2109. }
  2110. ))
  2111. except Exception as e:
  2112. logger.error(f"training_data_query执行失败: {str(e)}")
  2113. return jsonify(internal_error_response(
  2114. response_text="查询训练数据失败,请稍后重试"
  2115. )), 500
  2116. @app.flask_app.route('/api/v0/training_data/create', methods=['POST'])
  2117. def training_data_create():
  2118. """
  2119. 创建训练数据API
  2120. 支持单条和批量创建,支持四种数据类型
  2121. """
  2122. try:
  2123. req = request.get_json(force=True)
  2124. data = req.get('data')
  2125. if not data:
  2126. return jsonify(bad_request_response(
  2127. response_text="缺少必需参数:data",
  2128. missing_params=["data"]
  2129. )), 400
  2130. # 统一处理为列表格式
  2131. if isinstance(data, dict):
  2132. data_list = [data]
  2133. elif isinstance(data, list):
  2134. data_list = data
  2135. else:
  2136. return jsonify(bad_request_response(
  2137. response_text="data字段格式错误,应为对象或数组"
  2138. )), 400
  2139. # 批量操作限制
  2140. if len(data_list) > 50:
  2141. return jsonify(bad_request_response(
  2142. response_text="批量操作最大支持50条记录"
  2143. )), 400
  2144. results = []
  2145. successful_count = 0
  2146. type_summary = {"sql": 0, "documentation": 0, "ddl": 0, "error_sql": 0}
  2147. for index, item in enumerate(data_list):
  2148. try:
  2149. result = process_single_training_item(item, index)
  2150. results.append(result)
  2151. if result['success']:
  2152. successful_count += 1
  2153. type_summary[result['type']] += 1
  2154. except Exception as e:
  2155. results.append({
  2156. "index": index,
  2157. "success": False,
  2158. "type": item.get('training_data_type', 'unknown'),
  2159. "error": str(e),
  2160. "message": "创建失败"
  2161. })
  2162. # 获取创建后的总记录数
  2163. current_total = get_total_training_count()
  2164. return jsonify(success_response(
  2165. response_text="训练数据创建完成",
  2166. data={
  2167. "total_requested": len(data_list),
  2168. "successfully_created": successful_count,
  2169. "failed_count": len(data_list) - successful_count,
  2170. "results": results,
  2171. "summary": type_summary,
  2172. "current_total_count": current_total
  2173. }
  2174. ))
  2175. except Exception as e:
  2176. logger.error(f"training_data_create执行失败: {str(e)}")
  2177. return jsonify(internal_error_response(
  2178. response_text="创建训练数据失败,请稍后重试"
  2179. )), 500
  2180. @app.flask_app.route('/api/v0/training_data/delete', methods=['POST'])
  2181. def training_data_delete():
  2182. """
  2183. 删除训练数据API
  2184. 支持批量删除
  2185. """
  2186. try:
  2187. req = request.get_json(force=True)
  2188. ids = req.get('ids', [])
  2189. confirm = req.get('confirm', False)
  2190. if not ids or not isinstance(ids, list):
  2191. return jsonify(bad_request_response(
  2192. response_text="缺少有效的ID列表",
  2193. missing_params=["ids"]
  2194. )), 400
  2195. if not confirm:
  2196. return jsonify(bad_request_response(
  2197. response_text="删除操作需要确认,请设置confirm为true"
  2198. )), 400
  2199. # 批量操作限制
  2200. if len(ids) > 50:
  2201. return jsonify(bad_request_response(
  2202. response_text="批量删除最大支持50条记录"
  2203. )), 400
  2204. deleted_ids = []
  2205. failed_ids = []
  2206. failed_details = []
  2207. for training_id in ids:
  2208. try:
  2209. success = vn.remove_training_data(training_id)
  2210. if success:
  2211. deleted_ids.append(training_id)
  2212. else:
  2213. failed_ids.append(training_id)
  2214. failed_details.append({
  2215. "id": training_id,
  2216. "error": "记录不存在或删除失败"
  2217. })
  2218. except Exception as e:
  2219. failed_ids.append(training_id)
  2220. failed_details.append({
  2221. "id": training_id,
  2222. "error": str(e)
  2223. })
  2224. # 获取删除后的总记录数
  2225. current_total = get_total_training_count()
  2226. return jsonify(success_response(
  2227. response_text="训练数据删除完成",
  2228. data={
  2229. "total_requested": len(ids),
  2230. "successfully_deleted": len(deleted_ids),
  2231. "failed_count": len(failed_ids),
  2232. "deleted_ids": deleted_ids,
  2233. "failed_ids": failed_ids,
  2234. "failed_details": failed_details,
  2235. "current_total_count": current_total
  2236. }
  2237. ))
  2238. except Exception as e:
  2239. logger.error(f"training_data_delete执行失败: {str(e)}")
  2240. return jsonify(internal_error_response(
  2241. response_text="删除训练数据失败,请稍后重试"
  2242. )), 500
  2243. @app.flask_app.route('/api/v0/training_data/stats', methods=['GET'])
  2244. def training_data_stats():
  2245. """
  2246. 获取训练数据统计信息API
  2247. """
  2248. try:
  2249. training_data = vn.get_training_data()
  2250. if training_data is None or training_data.empty:
  2251. return jsonify(success_response(
  2252. response_text="统计信息获取成功",
  2253. data={
  2254. "total_count": 0,
  2255. "type_breakdown": {
  2256. "sql": 0,
  2257. "documentation": 0,
  2258. "ddl": 0,
  2259. "error_sql": 0
  2260. },
  2261. "type_percentages": {
  2262. "sql": 0.0,
  2263. "documentation": 0.0,
  2264. "ddl": 0.0,
  2265. "error_sql": 0.0
  2266. },
  2267. "last_updated": datetime.now().isoformat()
  2268. }
  2269. ))
  2270. total_count = len(training_data)
  2271. # 统计各类型数量
  2272. type_breakdown = {"sql": 0, "documentation": 0, "ddl": 0, "error_sql": 0}
  2273. if 'training_data_type' in training_data.columns:
  2274. type_counts = training_data['training_data_type'].value_counts()
  2275. for data_type, count in type_counts.items():
  2276. if data_type in type_breakdown:
  2277. type_breakdown[data_type] = int(count)
  2278. # 计算百分比
  2279. type_percentages = {}
  2280. for data_type, count in type_breakdown.items():
  2281. type_percentages[data_type] = round(count / max(total_count, 1) * 100, 2)
  2282. return jsonify(success_response(
  2283. response_text="统计信息获取成功",
  2284. data={
  2285. "total_count": total_count,
  2286. "type_breakdown": type_breakdown,
  2287. "type_percentages": type_percentages,
  2288. "last_updated": datetime.now().isoformat()
  2289. }
  2290. ))
  2291. except Exception as e:
  2292. logger.error(f"training_data_stats执行失败: {str(e)}")
  2293. return jsonify(internal_error_response(
  2294. response_text="获取统计信息失败,请稍后重试"
  2295. )), 500
  2296. @app.flask_app.route('/api/v0/cache_overview_full', methods=['GET'])
  2297. def cache_overview_full():
  2298. """获取所有缓存系统的综合概览"""
  2299. try:
  2300. from common.embedding_cache_manager import get_embedding_cache_manager
  2301. from common.vanna_instance import get_vanna_instance
  2302. # 获取现有的缓存统计
  2303. vanna_cache = get_vanna_instance()
  2304. # 直接使用应用中的缓存实例
  2305. cache = app.cache
  2306. cache_overview = {
  2307. "conversation_aware_cache": {
  2308. "enabled": True,
  2309. "total_items": len(cache.cache) if hasattr(cache, 'cache') else 0,
  2310. "sessions": list(cache.cache.keys()) if hasattr(cache, 'cache') else [],
  2311. "cache_type": type(cache).__name__
  2312. },
  2313. "question_answer_cache": redis_conversation_manager.get_qa_cache_stats() if redis_conversation_manager.is_available() else {"available": False},
  2314. "embedding_cache": get_embedding_cache_manager().get_cache_stats(),
  2315. "redis_conversation_stats": redis_conversation_manager.get_stats() if redis_conversation_manager.is_available() else None
  2316. }
  2317. return jsonify(success_response(
  2318. response_text="获取综合缓存概览成功",
  2319. data=cache_overview
  2320. ))
  2321. except Exception as e:
  2322. logger.error(f"获取综合缓存概览失败: {str(e)}")
  2323. return jsonify(internal_error_response(
  2324. response_text="获取缓存概览失败,请稍后重试"
  2325. )), 500
  2326. # 前端JavaScript示例 - 如何维持会话
  2327. """
  2328. // 前端需要维护一个会话ID
  2329. class ChatSession {
  2330. constructor() {
  2331. // 从localStorage获取或创建新的会话ID
  2332. this.sessionId = localStorage.getItem('chat_session_id') || this.generateSessionId();
  2333. localStorage.setItem('chat_session_id', this.sessionId);
  2334. }
  2335. generateSessionId() {
  2336. return 'session_' + Date.now() + '_' + Math.random().toString(36).substr(2, 9);
  2337. }
  2338. async askQuestion(question) {
  2339. const response = await fetch('/api/v0/ask', {
  2340. method: 'POST',
  2341. headers: {
  2342. 'Content-Type': 'application/json',
  2343. },
  2344. body: JSON.stringify({
  2345. question: question,
  2346. session_id: this.sessionId // 关键:传递会话ID
  2347. })
  2348. });
  2349. return await response.json();
  2350. }
  2351. // 开始新会话
  2352. startNewSession() {
  2353. this.sessionId = this.generateSessionId();
  2354. localStorage.setItem('chat_session_id', this.sessionId);
  2355. }
  2356. }
  2357. // 使用示例
  2358. const chatSession = new ChatSession();
  2359. chatSession.askQuestion("各年龄段客户的流失率如何?");
  2360. """
  2361. # ==================== Data Pipeline API ====================
  2362. # 导入简化的Data Pipeline模块
  2363. import asyncio
  2364. import os
  2365. from threading import Thread
  2366. from flask import send_file
  2367. from data_pipeline.api.simple_workflow import SimpleWorkflowManager
  2368. from data_pipeline.api.simple_file_manager import SimpleFileManager
  2369. # 创建简化的管理器
  2370. data_pipeline_manager = None
  2371. data_pipeline_file_manager = None
  2372. def get_data_pipeline_manager():
  2373. """获取Data Pipeline管理器单例"""
  2374. global data_pipeline_manager
  2375. if data_pipeline_manager is None:
  2376. data_pipeline_manager = SimpleWorkflowManager()
  2377. return data_pipeline_manager
  2378. def get_data_pipeline_file_manager():
  2379. """获取Data Pipeline文件管理器单例"""
  2380. global data_pipeline_file_manager
  2381. if data_pipeline_file_manager is None:
  2382. data_pipeline_file_manager = SimpleFileManager()
  2383. return data_pipeline_file_manager
  2384. # ==================== 简化的Data Pipeline API端点 ====================
  2385. @app.flask_app.route('/api/v0/data_pipeline/tasks', methods=['POST'])
  2386. def create_data_pipeline_task():
  2387. """创建数据管道任务"""
  2388. try:
  2389. req = request.get_json(force=True)
  2390. # table_list_file和business_context现在都是可选参数
  2391. # 如果未提供table_list_file,将使用文件上传模式
  2392. # 创建任务(支持可选的db_connection参数)
  2393. manager = get_data_pipeline_manager()
  2394. task_id = manager.create_task(
  2395. table_list_file=req.get('table_list_file'),
  2396. business_context=req.get('business_context'),
  2397. db_name=req.get('db_name'), # 可选参数,用于指定特定数据库名称
  2398. db_connection=req.get('db_connection'), # 可选参数,用于指定数据库连接字符串
  2399. task_name=req.get('task_name'), # 可选参数,用于指定任务名称
  2400. enable_sql_validation=req.get('enable_sql_validation', True),
  2401. enable_llm_repair=req.get('enable_llm_repair', True),
  2402. modify_original_file=req.get('modify_original_file', True),
  2403. enable_training_data_load=req.get('enable_training_data_load', True)
  2404. )
  2405. # 获取任务信息
  2406. task_info = manager.get_task_status(task_id)
  2407. response_data = {
  2408. "task_id": task_id,
  2409. "task_name": task_info.get('task_name'),
  2410. "status": task_info.get('status'),
  2411. "created_at": task_info.get('created_at').isoformat() if task_info.get('created_at') else None
  2412. }
  2413. # 检查是否为文件上传模式
  2414. file_upload_mode = not req.get('table_list_file')
  2415. response_message = "任务创建成功"
  2416. if file_upload_mode:
  2417. response_data["file_upload_mode"] = True
  2418. response_data["next_step"] = f"POST /api/v0/data_pipeline/tasks/{task_id}/upload-table-list"
  2419. response_message += ",请上传表清单文件后再执行任务"
  2420. return jsonify(success_response(
  2421. response_text=response_message,
  2422. data=response_data
  2423. )), 201
  2424. except Exception as e:
  2425. logger.error(f"创建数据管道任务失败: {str(e)}")
  2426. return jsonify(internal_error_response(
  2427. response_text="创建任务失败,请稍后重试"
  2428. )), 500
  2429. @app.flask_app.route('/api/v0/data_pipeline/tasks/<task_id>/execute', methods=['POST'])
  2430. def execute_data_pipeline_task(task_id):
  2431. """执行数据管道任务"""
  2432. try:
  2433. req = request.get_json(force=True) if request.is_json else {}
  2434. execution_mode = req.get('execution_mode', 'complete')
  2435. step_name = req.get('step_name')
  2436. # 验证执行模式
  2437. if execution_mode not in ['complete', 'step']:
  2438. return jsonify(bad_request_response(
  2439. response_text="无效的执行模式,必须是 'complete' 或 'step'",
  2440. invalid_params=['execution_mode']
  2441. )), 400
  2442. # 如果是步骤执行模式,验证步骤名称
  2443. if execution_mode == 'step':
  2444. if not step_name:
  2445. return jsonify(bad_request_response(
  2446. response_text="步骤执行模式需要指定step_name",
  2447. missing_params=['step_name']
  2448. )), 400
  2449. valid_steps = ['ddl_generation', 'qa_generation', 'sql_validation', 'training_load']
  2450. if step_name not in valid_steps:
  2451. return jsonify(bad_request_response(
  2452. response_text=f"无效的步骤名称,支持的步骤: {', '.join(valid_steps)}",
  2453. invalid_params=['step_name']
  2454. )), 400
  2455. # 检查任务是否存在
  2456. manager = get_data_pipeline_manager()
  2457. task_info = manager.get_task_status(task_id)
  2458. if not task_info:
  2459. return jsonify(not_found_response(
  2460. response_text=f"任务不存在: {task_id}"
  2461. )), 404
  2462. # 使用subprocess启动独立进程执行任务
  2463. def run_task_subprocess():
  2464. try:
  2465. import subprocess
  2466. import sys
  2467. from pathlib import Path
  2468. # 构建执行命令
  2469. python_executable = sys.executable
  2470. script_path = Path(__file__).parent / "data_pipeline" / "task_executor.py"
  2471. cmd = [
  2472. python_executable,
  2473. str(script_path),
  2474. "--task-id", task_id,
  2475. "--execution-mode", execution_mode
  2476. ]
  2477. if step_name:
  2478. cmd.extend(["--step-name", step_name])
  2479. logger.info(f"启动任务进程: {' '.join(cmd)}")
  2480. # 启动后台进程(不等待完成)
  2481. process = subprocess.Popen(
  2482. cmd,
  2483. stdout=subprocess.PIPE,
  2484. stderr=subprocess.PIPE,
  2485. text=True,
  2486. cwd=Path(__file__).parent
  2487. )
  2488. logger.info(f"任务进程已启动: PID={process.pid}, task_id={task_id}")
  2489. except Exception as e:
  2490. logger.error(f"启动任务进程失败: {task_id}, 错误: {str(e)}")
  2491. # 在新线程中启动subprocess(避免阻塞API响应)
  2492. thread = Thread(target=run_task_subprocess, daemon=True)
  2493. thread.start()
  2494. response_data = {
  2495. "task_id": task_id,
  2496. "execution_mode": execution_mode,
  2497. "step_name": step_name if execution_mode == 'step' else None,
  2498. "message": "任务正在后台执行,请通过状态接口查询进度"
  2499. }
  2500. return jsonify(success_response(
  2501. response_text="任务执行已启动",
  2502. data=response_data
  2503. )), 202
  2504. except Exception as e:
  2505. logger.error(f"启动数据管道任务执行失败: {str(e)}")
  2506. return jsonify(internal_error_response(
  2507. response_text="启动任务执行失败,请稍后重试"
  2508. )), 500
  2509. @app.flask_app.route('/api/v0/data_pipeline/tasks/<task_id>', methods=['GET'])
  2510. def get_data_pipeline_task_status(task_id):
  2511. """
  2512. 获取数据管道任务状态
  2513. 响应:
  2514. {
  2515. "success": true,
  2516. "code": 200,
  2517. "message": "获取任务状态成功",
  2518. "data": {
  2519. "task_id": "task_20250627_143052",
  2520. "status": "in_progress",
  2521. "step_status": {
  2522. "ddl_generation": "completed",
  2523. "qa_generation": "running",
  2524. "sql_validation": "pending",
  2525. "training_load": "pending"
  2526. },
  2527. "created_at": "2025-06-27T14:30:52",
  2528. "started_at": "2025-06-27T14:31:00",
  2529. "parameters": {...},
  2530. "current_execution": {...},
  2531. "total_executions": 2
  2532. }
  2533. }
  2534. """
  2535. try:
  2536. manager = get_data_pipeline_manager()
  2537. task_info = manager.get_task_status(task_id)
  2538. if not task_info:
  2539. return jsonify(not_found_response(
  2540. response_text=f"任务不存在: {task_id}"
  2541. )), 404
  2542. # 获取步骤状态
  2543. steps = manager.get_task_steps(task_id)
  2544. current_step = None
  2545. for step in steps:
  2546. if step['step_status'] == 'running':
  2547. current_step = step
  2548. break
  2549. # 构建步骤状态摘要
  2550. step_status_summary = {}
  2551. for step in steps:
  2552. step_status_summary[step['step_name']] = step['step_status']
  2553. response_data = {
  2554. "task_id": task_info['task_id'],
  2555. "task_name": task_info.get('task_name'),
  2556. "status": task_info['status'],
  2557. "step_status": step_status_summary,
  2558. "created_at": task_info['created_at'].isoformat() if task_info.get('created_at') else None,
  2559. "started_at": task_info['started_at'].isoformat() if task_info.get('started_at') else None,
  2560. "completed_at": task_info['completed_at'].isoformat() if task_info.get('completed_at') else None,
  2561. "parameters": task_info.get('parameters', {}),
  2562. "result": task_info.get('result'),
  2563. "error_message": task_info.get('error_message'),
  2564. "current_step": {
  2565. "execution_id": current_step['execution_id'],
  2566. "step": current_step['step_name'],
  2567. "status": current_step['step_status'],
  2568. "started_at": current_step['started_at'].isoformat() if current_step and current_step.get('started_at') else None
  2569. } if current_step else None,
  2570. "total_steps": len(steps),
  2571. "steps": [{
  2572. "step_name": step['step_name'],
  2573. "step_status": step['step_status'],
  2574. "started_at": step['started_at'].isoformat() if step.get('started_at') else None,
  2575. "completed_at": step['completed_at'].isoformat() if step.get('completed_at') else None,
  2576. "error_message": step.get('error_message')
  2577. } for step in steps]
  2578. }
  2579. return jsonify(success_response(
  2580. response_text="获取任务状态成功",
  2581. data=response_data
  2582. ))
  2583. except Exception as e:
  2584. logger.error(f"获取数据管道任务状态失败: {str(e)}")
  2585. return jsonify(internal_error_response(
  2586. response_text="获取任务状态失败,请稍后重试"
  2587. )), 500
  2588. @app.flask_app.route('/api/v0/data_pipeline/tasks/<task_id>/logs', methods=['GET'])
  2589. def get_data_pipeline_task_logs(task_id):
  2590. """
  2591. 获取数据管道任务日志(从任务目录文件读取)
  2592. 查询参数:
  2593. - limit: 日志行数限制,默认100
  2594. - level: 日志级别过滤,可选
  2595. 响应:
  2596. {
  2597. "success": true,
  2598. "code": 200,
  2599. "message": "获取任务日志成功",
  2600. "data": {
  2601. "task_id": "task_20250627_143052",
  2602. "logs": [
  2603. {
  2604. "timestamp": "2025-06-27 14:30:52",
  2605. "level": "INFO",
  2606. "message": "任务开始执行"
  2607. }
  2608. ],
  2609. "total": 15,
  2610. "source": "file"
  2611. }
  2612. }
  2613. """
  2614. try:
  2615. limit = request.args.get('limit', 100, type=int)
  2616. level = request.args.get('level')
  2617. # 限制最大查询数量
  2618. limit = min(limit, 1000)
  2619. manager = get_data_pipeline_manager()
  2620. # 验证任务是否存在
  2621. task_info = manager.get_task_status(task_id)
  2622. if not task_info:
  2623. return jsonify(not_found_response(
  2624. response_text=f"任务不存在: {task_id}"
  2625. )), 404
  2626. # 获取任务目录下的日志文件
  2627. import os
  2628. from pathlib import Path
  2629. # 获取项目根目录的绝对路径
  2630. project_root = Path(__file__).parent.absolute()
  2631. task_dir = project_root / "data_pipeline" / "training_data" / task_id
  2632. log_file = task_dir / "data_pipeline.log"
  2633. logs = []
  2634. if log_file.exists():
  2635. try:
  2636. # 读取日志文件的最后N行
  2637. with open(log_file, 'r', encoding='utf-8') as f:
  2638. lines = f.readlines()
  2639. # 取最后limit行
  2640. recent_lines = lines[-limit:] if len(lines) > limit else lines
  2641. # 解析日志行
  2642. import re
  2643. log_pattern = r'^(\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}) \[(\w+)\] (.+?): (.+)$'
  2644. for line in recent_lines:
  2645. line = line.strip()
  2646. if not line:
  2647. continue
  2648. match = re.match(log_pattern, line)
  2649. if match:
  2650. timestamp, log_level, logger_name, message = match.groups()
  2651. # 级别过滤
  2652. if level and log_level != level.upper():
  2653. continue
  2654. logs.append({
  2655. "timestamp": timestamp,
  2656. "level": log_level,
  2657. "logger": logger_name,
  2658. "message": message
  2659. })
  2660. else:
  2661. # 处理多行日志(如异常堆栈)
  2662. if logs:
  2663. logs[-1]["message"] += f"\n{line}"
  2664. except Exception as e:
  2665. logger.error(f"读取日志文件失败: {e}")
  2666. response_data = {
  2667. "task_id": task_id,
  2668. "logs": logs,
  2669. "total": len(logs),
  2670. "source": "file",
  2671. "log_file": str(log_file) if log_file.exists() else None
  2672. }
  2673. return jsonify(success_response(
  2674. response_text="获取任务日志成功",
  2675. data=response_data
  2676. ))
  2677. except Exception as e:
  2678. logger.error(f"获取数据管道任务日志失败: {str(e)}")
  2679. return jsonify(internal_error_response(
  2680. response_text="获取任务日志失败,请稍后重试"
  2681. )), 500
  2682. @app.flask_app.route('/api/v0/data_pipeline/tasks', methods=['GET'])
  2683. def list_data_pipeline_tasks():
  2684. """获取数据管道任务列表"""
  2685. try:
  2686. limit = request.args.get('limit', 50, type=int)
  2687. offset = request.args.get('offset', 0, type=int)
  2688. status_filter = request.args.get('status')
  2689. # 限制查询数量
  2690. limit = min(limit, 100)
  2691. manager = get_data_pipeline_manager()
  2692. tasks = manager.get_tasks_list(
  2693. limit=limit,
  2694. offset=offset,
  2695. status_filter=status_filter
  2696. )
  2697. # 格式化任务列表
  2698. formatted_tasks = []
  2699. for task in tasks:
  2700. formatted_tasks.append({
  2701. "task_id": task.get('task_id'),
  2702. "task_name": task.get('task_name'),
  2703. "status": task.get('status'),
  2704. "step_status": task.get('step_status'),
  2705. "created_at": task['created_at'].isoformat() if task.get('created_at') else None,
  2706. "started_at": task['started_at'].isoformat() if task.get('started_at') else None,
  2707. "completed_at": task['completed_at'].isoformat() if task.get('completed_at') else None,
  2708. "created_by": task.get('by_user'),
  2709. "db_name": task.get('db_name'),
  2710. "business_context": task.get('parameters', {}).get('business_context') if task.get('parameters') else None,
  2711. # 新增字段
  2712. "directory_exists": task.get('directory_exists', True), # 默认为True,兼容旧数据
  2713. "updated_at": task['updated_at'].isoformat() if task.get('updated_at') else None
  2714. })
  2715. response_data = {
  2716. "tasks": formatted_tasks,
  2717. "total": len(formatted_tasks),
  2718. "limit": limit,
  2719. "offset": offset
  2720. }
  2721. return jsonify(success_response(
  2722. response_text="获取任务列表成功",
  2723. data=response_data
  2724. ))
  2725. except Exception as e:
  2726. logger.error(f"获取数据管道任务列表失败: {str(e)}")
  2727. return jsonify(internal_error_response(
  2728. response_text="获取任务列表失败,请稍后重试"
  2729. )), 500
  2730. # ==================== 表检查API端点 ====================
  2731. import asyncio
  2732. from data_pipeline.api.table_inspector_api import TableInspectorAPI
  2733. @app.flask_app.route('/api/v0/database/tables', methods=['POST'])
  2734. def get_database_tables():
  2735. """
  2736. 获取数据库表列表
  2737. 请求体:
  2738. {
  2739. "db_connection": "postgresql://postgres:postgres@192.168.67.1:5432/highway_db", // 可选,不传则使用默认配置
  2740. "schema": "public,ods", // 可选,支持多个schema用逗号分隔,默认为public
  2741. "table_name_pattern": "ods_*" // 可选,表名模式匹配,支持通配符:ods_*、*_dim、*fact*、ods_%
  2742. }
  2743. 响应:
  2744. {
  2745. "success": true,
  2746. "code": 200,
  2747. "message": "获取表列表成功",
  2748. "data": {
  2749. "tables": ["public.table1", "public.table2", "ods.table3"],
  2750. "total": 3,
  2751. "schemas": ["public", "ods"],
  2752. "table_name_pattern": "ods_*"
  2753. }
  2754. }
  2755. """
  2756. try:
  2757. req = request.get_json(force=True)
  2758. # 处理数据库连接参数(可选)
  2759. db_connection = req.get('db_connection')
  2760. if not db_connection:
  2761. # 使用app_config的默认数据库配置
  2762. import app_config
  2763. db_params = app_config.APP_DB_CONFIG
  2764. db_connection = f"postgresql://{db_params['user']}:{db_params['password']}@{db_params['host']}:{db_params['port']}/{db_params['dbname']}"
  2765. logger.info("使用默认数据库配置获取表列表")
  2766. else:
  2767. logger.info("使用用户指定的数据库配置获取表列表")
  2768. # 可选参数
  2769. schema = req.get('schema', '')
  2770. table_name_pattern = req.get('table_name_pattern')
  2771. # 创建表检查API实例
  2772. table_inspector = TableInspectorAPI()
  2773. # 使用asyncio运行异步方法
  2774. async def get_tables():
  2775. return await table_inspector.get_tables_list(db_connection, schema, table_name_pattern)
  2776. # 在新的事件循环中运行异步方法
  2777. try:
  2778. loop = asyncio.new_event_loop()
  2779. asyncio.set_event_loop(loop)
  2780. tables = loop.run_until_complete(get_tables())
  2781. finally:
  2782. loop.close()
  2783. # 解析schema信息
  2784. parsed_schemas = table_inspector._parse_schemas(schema)
  2785. response_data = {
  2786. "tables": tables,
  2787. "total": len(tables),
  2788. "schemas": parsed_schemas,
  2789. "db_connection_info": {
  2790. "database": db_connection.split('/')[-1].split('?')[0] if '/' in db_connection else "unknown"
  2791. }
  2792. }
  2793. # 如果使用了表名模式,添加到响应中
  2794. if table_name_pattern:
  2795. response_data["table_name_pattern"] = table_name_pattern
  2796. return jsonify(success_response(
  2797. response_text="获取表列表成功",
  2798. data=response_data
  2799. )), 200
  2800. except Exception as e:
  2801. logger.error(f"获取数据库表列表失败: {str(e)}")
  2802. return jsonify(internal_error_response(
  2803. response_text=f"获取表列表失败: {str(e)}"
  2804. )), 500
  2805. @app.flask_app.route('/api/v0/database/table/ddl', methods=['POST'])
  2806. def get_table_ddl():
  2807. """
  2808. 获取表的DDL语句或MD文档
  2809. 请求体:
  2810. {
  2811. "db_connection": "postgresql://postgres:postgres@192.168.67.1:5432/highway_db", // 可选,不传则使用默认配置
  2812. "table": "public.test",
  2813. "business_context": "这是高速公路服务区的相关数据", // 可选
  2814. "type": "ddl" // 可选,支持ddl/md/both,默认为ddl
  2815. }
  2816. 响应:
  2817. {
  2818. "success": true,
  2819. "code": 200,
  2820. "message": "获取表DDL成功",
  2821. "data": {
  2822. "ddl": "create table public.test (...);",
  2823. "md": "## test表...", // 仅当type为md或both时返回
  2824. "table_info": {
  2825. "table_name": "test",
  2826. "schema_name": "public",
  2827. "full_name": "public.test",
  2828. "comment": "测试表",
  2829. "field_count": 10,
  2830. "row_count": 1000
  2831. },
  2832. "fields": [...]
  2833. }
  2834. }
  2835. """
  2836. try:
  2837. req = request.get_json(force=True)
  2838. # 处理参数(table仍为必需,db_connection可选)
  2839. table = req.get('table')
  2840. db_connection = req.get('db_connection')
  2841. if not table:
  2842. return jsonify(bad_request_response(
  2843. response_text="缺少必需参数:table",
  2844. missing_params=['table']
  2845. )), 400
  2846. if not db_connection:
  2847. # 使用app_config的默认数据库配置
  2848. import app_config
  2849. db_params = app_config.APP_DB_CONFIG
  2850. db_connection = f"postgresql://{db_params['user']}:{db_params['password']}@{db_params['host']}:{db_params['port']}/{db_params['dbname']}"
  2851. logger.info("使用默认数据库配置获取表DDL")
  2852. else:
  2853. logger.info("使用用户指定的数据库配置获取表DDL")
  2854. # 可选参数
  2855. business_context = req.get('business_context', '')
  2856. output_type = req.get('type', 'ddl')
  2857. # 验证type参数
  2858. valid_types = ['ddl', 'md', 'both']
  2859. if output_type not in valid_types:
  2860. return jsonify(bad_request_response(
  2861. response_text=f"无效的type参数: {output_type},支持的值: {valid_types}",
  2862. invalid_params=['type']
  2863. )), 400
  2864. # 创建表检查API实例
  2865. table_inspector = TableInspectorAPI()
  2866. # 使用asyncio运行异步方法
  2867. async def get_ddl():
  2868. return await table_inspector.get_table_ddl(
  2869. db_connection=db_connection,
  2870. table=table,
  2871. business_context=business_context,
  2872. output_type=output_type
  2873. )
  2874. # 在新的事件循环中运行异步方法
  2875. try:
  2876. loop = asyncio.new_event_loop()
  2877. asyncio.set_event_loop(loop)
  2878. result = loop.run_until_complete(get_ddl())
  2879. finally:
  2880. loop.close()
  2881. response_data = {
  2882. **result,
  2883. "generation_info": {
  2884. "business_context": business_context,
  2885. "output_type": output_type,
  2886. "has_llm_comments": bool(business_context),
  2887. "database": db_connection.split('/')[-1].split('?')[0] if '/' in db_connection else "unknown"
  2888. }
  2889. }
  2890. return jsonify(success_response(
  2891. response_text=f"获取表{output_type.upper()}成功",
  2892. data=response_data
  2893. )), 200
  2894. except Exception as e:
  2895. logger.error(f"获取表DDL失败: {str(e)}")
  2896. return jsonify(internal_error_response(
  2897. response_text=f"获取表{output_type.upper() if 'output_type' in locals() else 'DDL'}失败: {str(e)}"
  2898. )), 500
  2899. # ==================== Data Pipeline 文件管理 API ====================
  2900. from flask import send_file
  2901. # 创建文件管理器
  2902. data_pipeline_file_manager = None
  2903. def get_data_pipeline_file_manager():
  2904. """获取Data Pipeline文件管理器单例"""
  2905. global data_pipeline_file_manager
  2906. if data_pipeline_file_manager is None:
  2907. data_pipeline_file_manager = SimpleFileManager()
  2908. return data_pipeline_file_manager
  2909. @app.flask_app.route('/api/v0/data_pipeline/tasks/<task_id>/files', methods=['GET'])
  2910. def get_data_pipeline_task_files(task_id):
  2911. """获取任务文件列表"""
  2912. try:
  2913. file_manager = get_data_pipeline_file_manager()
  2914. # 获取任务文件
  2915. files = file_manager.get_task_files(task_id)
  2916. directory_info = file_manager.get_directory_info(task_id)
  2917. # 格式化文件信息
  2918. formatted_files = []
  2919. for file_info in files:
  2920. formatted_files.append({
  2921. "file_name": file_info['file_name'],
  2922. "file_type": file_info['file_type'],
  2923. "file_size": file_info['file_size'],
  2924. "file_size_formatted": file_info['file_size_formatted'],
  2925. "created_at": file_info['created_at'].isoformat() if file_info.get('created_at') else None,
  2926. "modified_at": file_info['modified_at'].isoformat() if file_info.get('modified_at') else None,
  2927. "is_readable": file_info['is_readable']
  2928. })
  2929. response_data = {
  2930. "task_id": task_id,
  2931. "files": formatted_files,
  2932. "directory_info": directory_info
  2933. }
  2934. return jsonify(success_response(
  2935. response_text="获取任务文件列表成功",
  2936. data=response_data
  2937. ))
  2938. except Exception as e:
  2939. logger.error(f"获取任务文件列表失败: {str(e)}")
  2940. return jsonify(internal_error_response(
  2941. response_text="获取任务文件列表失败,请稍后重试"
  2942. )), 500
  2943. @app.flask_app.route('/api/v0/data_pipeline/tasks/<task_id>/files/<file_name>', methods=['GET'])
  2944. def download_data_pipeline_task_file(task_id, file_name):
  2945. """下载任务文件"""
  2946. try:
  2947. logger.info(f"开始下载文件: task_id={task_id}, file_name={file_name}")
  2948. # 直接构建文件路径,避免依赖数据库
  2949. from pathlib import Path
  2950. import os
  2951. # 获取项目根目录的绝对路径
  2952. project_root = Path(__file__).parent.absolute()
  2953. task_dir = project_root / "data_pipeline" / "training_data" / task_id
  2954. file_path = task_dir / file_name
  2955. logger.info(f"文件路径: {file_path}")
  2956. # 检查文件是否存在
  2957. if not file_path.exists():
  2958. logger.warning(f"文件不存在: {file_path}")
  2959. return jsonify(not_found_response(
  2960. response_text=f"文件不存在: {file_name}"
  2961. )), 404
  2962. # 检查是否为文件(而不是目录)
  2963. if not file_path.is_file():
  2964. logger.warning(f"路径不是文件: {file_path}")
  2965. return jsonify(bad_request_response(
  2966. response_text=f"路径不是有效文件: {file_name}"
  2967. )), 400
  2968. # 安全检查:确保文件在允许的目录内
  2969. try:
  2970. file_path.resolve().relative_to(task_dir.resolve())
  2971. except ValueError:
  2972. logger.warning(f"文件路径不安全: {file_path}")
  2973. return jsonify(bad_request_response(
  2974. response_text="非法的文件路径"
  2975. )), 400
  2976. # 检查文件是否可读
  2977. if not os.access(file_path, os.R_OK):
  2978. logger.warning(f"文件不可读: {file_path}")
  2979. return jsonify(bad_request_response(
  2980. response_text="文件不可读"
  2981. )), 400
  2982. logger.info(f"开始发送文件: {file_path}")
  2983. return send_file(
  2984. file_path,
  2985. as_attachment=True,
  2986. download_name=file_name
  2987. )
  2988. except Exception as e:
  2989. logger.error(f"下载任务文件失败: task_id={task_id}, file_name={file_name}, 错误: {str(e)}", exc_info=True)
  2990. return jsonify(internal_error_response(
  2991. response_text="下载文件失败,请稍后重试"
  2992. )), 500
  2993. @app.flask_app.route('/api/v0/data_pipeline/tasks/<task_id>/upload-table-list', methods=['POST'])
  2994. def upload_table_list_file(task_id):
  2995. """
  2996. 上传表清单文件
  2997. 表单参数:
  2998. - file: 要上传的表清单文件(multipart/form-data)
  2999. 响应:
  3000. {
  3001. "success": true,
  3002. "code": 200,
  3003. "message": "表清单文件上传成功",
  3004. "data": {
  3005. "task_id": "task_20250701_123456",
  3006. "filename": "table_list.txt",
  3007. "file_size": 1024,
  3008. "file_size_formatted": "1.0 KB"
  3009. }
  3010. }
  3011. """
  3012. try:
  3013. # 验证任务是否存在
  3014. manager = get_data_pipeline_manager()
  3015. task_info = manager.get_task_status(task_id)
  3016. if not task_info:
  3017. return jsonify(not_found_response(
  3018. response_text=f"任务不存在: {task_id}"
  3019. )), 404
  3020. # 检查是否有文件上传
  3021. if 'file' not in request.files:
  3022. return jsonify(bad_request_response(
  3023. response_text="请选择要上传的表清单文件",
  3024. missing_params=['file']
  3025. )), 400
  3026. file = request.files['file']
  3027. # 验证文件名
  3028. if file.filename == '':
  3029. return jsonify(bad_request_response(
  3030. response_text="请选择有效的文件"
  3031. )), 400
  3032. try:
  3033. # 使用文件管理器上传文件
  3034. file_manager = get_data_pipeline_file_manager()
  3035. result = file_manager.upload_table_list_file(task_id, file)
  3036. response_data = {
  3037. "task_id": task_id,
  3038. "filename": result["filename"],
  3039. "file_size": result["file_size"],
  3040. "file_size_formatted": result["file_size_formatted"],
  3041. "upload_time": result["upload_time"].isoformat() if result.get("upload_time") else None
  3042. }
  3043. return jsonify(success_response(
  3044. response_text="表清单文件上传成功",
  3045. data=response_data
  3046. )), 200
  3047. except ValueError as e:
  3048. # 文件验证错误(如文件太大、空文件等)
  3049. return jsonify(bad_request_response(
  3050. response_text=str(e)
  3051. )), 400
  3052. except Exception as e:
  3053. logger.error(f"上传表清单文件失败: {str(e)}")
  3054. return jsonify(internal_error_response(
  3055. response_text="文件上传失败,请稍后重试"
  3056. )), 500
  3057. except Exception as e:
  3058. logger.error(f"处理表清单文件上传请求失败: {str(e)}")
  3059. return jsonify(internal_error_response(
  3060. response_text="处理上传请求失败,请稍后重试"
  3061. )), 500
  3062. @app.flask_app.route('/api/v0/data_pipeline/tasks/<task_id>/table-list-info', methods=['GET'])
  3063. def get_table_list_info(task_id):
  3064. """
  3065. 获取任务的表清单文件信息
  3066. 响应:
  3067. {
  3068. "success": true,
  3069. "code": 200,
  3070. "message": "获取表清单文件信息成功",
  3071. "data": {
  3072. "task_id": "task_20250701_123456",
  3073. "has_file": true,
  3074. "filename": "table_list.txt",
  3075. "file_path": "./data_pipeline/training_data/task_20250701_123456/table_list.txt",
  3076. "file_size": 1024,
  3077. "file_size_formatted": "1.0 KB",
  3078. "uploaded_at": "2025-07-01T12:34:56",
  3079. "table_count": 5,
  3080. "is_readable": true
  3081. }
  3082. }
  3083. """
  3084. try:
  3085. file_manager = get_data_pipeline_file_manager()
  3086. # 获取表清单文件信息
  3087. table_list_info = file_manager.get_table_list_file_info(task_id)
  3088. response_data = {
  3089. "task_id": task_id,
  3090. "has_file": table_list_info.get("exists", False),
  3091. **table_list_info
  3092. }
  3093. return jsonify(success_response(
  3094. response_text="获取表清单文件信息成功",
  3095. data=response_data
  3096. ))
  3097. except Exception as e:
  3098. logger.error(f"获取表清单文件信息失败: {str(e)}")
  3099. return jsonify(internal_error_response(
  3100. response_text="获取表清单文件信息失败,请稍后重试"
  3101. )), 500
  3102. @app.flask_app.route('/api/v0/data_pipeline/tasks/<task_id>/table-list', methods=['POST'])
  3103. def create_table_list_from_names(task_id):
  3104. """
  3105. 通过POST方式提交表名列表并创建table_list.txt文件
  3106. 请求体:
  3107. {
  3108. "tables": ["table1", "schema.table2", "table3"]
  3109. }
  3110. 或者:
  3111. {
  3112. "tables": "table1,schema.table2,table3"
  3113. }
  3114. 响应:
  3115. {
  3116. "success": true,
  3117. "code": 200,
  3118. "message": "表清单已成功创建",
  3119. "data": {
  3120. "task_id": "task_20250701_123456",
  3121. "filename": "table_list.txt",
  3122. "table_count": 3,
  3123. "file_size": 45,
  3124. "file_size_formatted": "45 B",
  3125. "created_time": "2025-07-01T12:34:56"
  3126. }
  3127. }
  3128. """
  3129. try:
  3130. # 验证任务是否存在
  3131. manager = get_data_pipeline_manager()
  3132. task_info = manager.get_task_status(task_id)
  3133. if not task_info:
  3134. return jsonify(not_found_response(
  3135. response_text=f"任务不存在: {task_id}"
  3136. )), 404
  3137. # 获取请求数据
  3138. req = request.get_json(force=True)
  3139. tables_param = req.get('tables')
  3140. if not tables_param:
  3141. return jsonify(bad_request_response(
  3142. response_text="缺少必需参数:tables",
  3143. missing_params=['tables']
  3144. )), 400
  3145. # 处理不同格式的表名参数
  3146. try:
  3147. if isinstance(tables_param, str):
  3148. # 逗号分隔的字符串格式
  3149. table_names = [name.strip() for name in tables_param.split(',') if name.strip()]
  3150. elif isinstance(tables_param, list):
  3151. # 数组格式
  3152. table_names = [str(name).strip() for name in tables_param if str(name).strip()]
  3153. else:
  3154. return jsonify(bad_request_response(
  3155. response_text="tables参数格式错误,应为字符串(逗号分隔)或数组"
  3156. )), 400
  3157. if not table_names:
  3158. return jsonify(bad_request_response(
  3159. response_text="表名列表不能为空"
  3160. )), 400
  3161. except Exception as e:
  3162. return jsonify(bad_request_response(
  3163. response_text=f"解析tables参数失败: {str(e)}"
  3164. )), 400
  3165. try:
  3166. # 使用文件管理器创建表清单文件
  3167. file_manager = get_data_pipeline_file_manager()
  3168. result = file_manager.create_table_list_from_names(task_id, table_names)
  3169. response_data = {
  3170. "task_id": task_id,
  3171. "filename": result["filename"],
  3172. "table_count": result["table_count"],
  3173. "unique_table_count": result["unique_table_count"],
  3174. "file_size": result["file_size"],
  3175. "file_size_formatted": result["file_size_formatted"],
  3176. "created_time": result["created_time"].isoformat() if result.get("created_time") else None,
  3177. "original_count": len(table_names) if isinstance(table_names, list) else len(tables_param.split(','))
  3178. }
  3179. return jsonify(success_response(
  3180. response_text=f"表清单已成功创建,包含 {result['table_count']} 个表",
  3181. data=response_data
  3182. )), 200
  3183. except ValueError as e:
  3184. # 表名验证错误(如格式错误、数量限制等)
  3185. return jsonify(bad_request_response(
  3186. response_text=str(e)
  3187. )), 400
  3188. except Exception as e:
  3189. logger.error(f"创建表清单文件失败: {str(e)}")
  3190. return jsonify(internal_error_response(
  3191. response_text="创建表清单文件失败,请稍后重试"
  3192. )), 500
  3193. except Exception as e:
  3194. logger.error(f"处理表清单创建请求失败: {str(e)}")
  3195. return jsonify(internal_error_response(
  3196. response_text="处理请求失败,请稍后重试"
  3197. )), 500
  3198. @app.flask_app.route('/api/v0/data_pipeline/tasks/<task_id>/files', methods=['POST'])
  3199. def upload_file_to_task(task_id):
  3200. """
  3201. 上传文件到指定任务目录
  3202. 表单参数:
  3203. - file: 要上传的文件(multipart/form-data)
  3204. - overwrite_mode: 重名处理模式 (backup, replace, skip),默认为backup
  3205. 支持的文件类型:
  3206. - .ddl: DDL文件
  3207. - .md: Markdown文档
  3208. - .txt: 文本文件
  3209. - .json: JSON文件
  3210. - .sql: SQL文件
  3211. - .csv: CSV文件
  3212. 重名处理模式:
  3213. - backup: 备份原文件(默认)
  3214. - replace: 直接覆盖
  3215. - skip: 跳过上传
  3216. 响应:
  3217. {
  3218. "success": true,
  3219. "code": 200,
  3220. "message": "文件上传成功",
  3221. "data": {
  3222. "task_id": "task_20250701_123456",
  3223. "uploaded_file": {
  3224. "filename": "test.ddl",
  3225. "size": 1024,
  3226. "size_formatted": "1.0 KB",
  3227. "uploaded_at": "2025-07-01T12:34:56",
  3228. "overwrite_mode": "backup"
  3229. },
  3230. "backup_info": { // 仅当overwrite_mode为backup且文件已存在时返回
  3231. "had_existing_file": true,
  3232. "backup_filename": "test.ddl_bak1",
  3233. "backup_version": 1,
  3234. "backup_created_at": "2025-07-01T12:34:56"
  3235. }
  3236. }
  3237. }
  3238. """
  3239. try:
  3240. # 验证任务是否存在
  3241. manager = get_data_pipeline_manager()
  3242. task_info = manager.get_task_status(task_id)
  3243. if not task_info:
  3244. return jsonify(not_found_response(
  3245. response_text=f"任务不存在: {task_id}"
  3246. )), 404
  3247. # 检查是否有文件上传
  3248. if 'file' not in request.files:
  3249. return jsonify(bad_request_response(
  3250. response_text="请选择要上传的文件",
  3251. missing_params=['file']
  3252. )), 400
  3253. file = request.files['file']
  3254. # 验证文件名
  3255. if file.filename == '':
  3256. return jsonify(bad_request_response(
  3257. response_text="请选择有效的文件"
  3258. )), 400
  3259. # 获取重名处理模式
  3260. overwrite_mode = request.form.get('overwrite_mode', 'backup')
  3261. # 验证重名处理模式
  3262. valid_modes = ['backup', 'replace', 'skip']
  3263. if overwrite_mode not in valid_modes:
  3264. return jsonify(bad_request_response(
  3265. response_text=f"无效的overwrite_mode参数: {overwrite_mode},支持的值: {valid_modes}",
  3266. invalid_params=['overwrite_mode']
  3267. )), 400
  3268. try:
  3269. # 使用文件管理器上传文件
  3270. file_manager = get_data_pipeline_file_manager()
  3271. result = file_manager.upload_file_to_task(task_id, file, file.filename, overwrite_mode)
  3272. # 检查是否跳过上传
  3273. if result.get('skipped'):
  3274. return jsonify(success_response(
  3275. response_text=result.get('message', '文件已存在,跳过上传'),
  3276. data=result
  3277. )), 200
  3278. return jsonify(success_response(
  3279. response_text="文件上传成功",
  3280. data=result
  3281. )), 200
  3282. except ValueError as e:
  3283. # 文件验证错误(如文件太大、空文件、不支持的类型等)
  3284. return jsonify(bad_request_response(
  3285. response_text=str(e)
  3286. )), 400
  3287. except Exception as e:
  3288. logger.error(f"上传文件失败: {str(e)}")
  3289. return jsonify(internal_error_response(
  3290. response_text="文件上传失败,请稍后重试"
  3291. )), 500
  3292. except Exception as e:
  3293. logger.error(f"处理文件上传请求失败: {str(e)}")
  3294. return jsonify(internal_error_response(
  3295. response_text="处理上传请求失败,请稍后重试"
  3296. )), 500
  3297. # ==================== 任务目录删除API ====================
  3298. import shutil
  3299. from pathlib import Path
  3300. from datetime import datetime
  3301. import psycopg2
  3302. from app_config import PGVECTOR_CONFIG
  3303. def delete_task_directory_simple(task_id, delete_database_records=False):
  3304. """
  3305. 简单的任务目录删除功能
  3306. - 删除 data_pipeline/training_data/{task_id} 目录
  3307. - 更新数据库中的 directory_exists 字段
  3308. - 可选:删除数据库记录
  3309. """
  3310. try:
  3311. # 1. 删除目录
  3312. project_root = Path(__file__).parent.absolute()
  3313. task_dir = project_root / "data_pipeline" / "training_data" / task_id
  3314. deleted_files_count = 0
  3315. deleted_size = 0
  3316. if task_dir.exists():
  3317. # 计算删除前的统计信息
  3318. for file_path in task_dir.rglob('*'):
  3319. if file_path.is_file():
  3320. deleted_files_count += 1
  3321. deleted_size += file_path.stat().st_size
  3322. # 删除目录
  3323. shutil.rmtree(task_dir)
  3324. directory_deleted = True
  3325. else:
  3326. directory_deleted = False
  3327. # 2. 更新数据库
  3328. database_records_deleted = False
  3329. try:
  3330. conn = psycopg2.connect(**PGVECTOR_CONFIG)
  3331. cur = conn.cursor()
  3332. if delete_database_records:
  3333. # 删除任务步骤记录
  3334. cur.execute("DELETE FROM data_pipeline_task_steps WHERE task_id = %s", (task_id,))
  3335. # 删除任务主记录
  3336. cur.execute("DELETE FROM data_pipeline_tasks WHERE task_id = %s", (task_id,))
  3337. database_records_deleted = True
  3338. else:
  3339. # 只更新目录状态
  3340. cur.execute("""
  3341. UPDATE data_pipeline_tasks
  3342. SET directory_exists = FALSE, updated_at = CURRENT_TIMESTAMP
  3343. WHERE task_id = %s
  3344. """, (task_id,))
  3345. conn.commit()
  3346. cur.close()
  3347. conn.close()
  3348. except Exception as db_error:
  3349. logger.error(f"数据库操作失败: {db_error}")
  3350. # 数据库失败不影响文件删除的结果
  3351. # 3. 格式化文件大小
  3352. def format_size(size_bytes):
  3353. if size_bytes < 1024:
  3354. return f"{size_bytes} B"
  3355. elif size_bytes < 1024**2:
  3356. return f"{size_bytes/1024:.1f} KB"
  3357. elif size_bytes < 1024**3:
  3358. return f"{size_bytes/(1024**2):.1f} MB"
  3359. else:
  3360. return f"{size_bytes/(1024**3):.1f} GB"
  3361. return {
  3362. "success": True,
  3363. "task_id": task_id,
  3364. "directory_deleted": directory_deleted,
  3365. "database_records_deleted": database_records_deleted,
  3366. "deleted_files_count": deleted_files_count,
  3367. "deleted_size": format_size(deleted_size),
  3368. "deleted_at": datetime.now().isoformat()
  3369. }
  3370. except Exception as e:
  3371. logger.error(f"删除任务目录失败: {task_id}, 错误: {str(e)}")
  3372. return {
  3373. "success": False,
  3374. "task_id": task_id,
  3375. "error": str(e),
  3376. "error_code": "DELETE_FAILED"
  3377. }
  3378. @app.flask_app.route('/api/v0/data_pipeline/tasks', methods=['DELETE'])
  3379. def delete_tasks():
  3380. """删除任务目录(支持单个和批量)"""
  3381. try:
  3382. # 获取请求参数
  3383. req = request.get_json(force=True)
  3384. # 验证必需参数
  3385. task_ids = req.get('task_ids')
  3386. confirm = req.get('confirm')
  3387. if not task_ids:
  3388. return jsonify(bad_request_response(
  3389. response_text="缺少必需参数: task_ids",
  3390. missing_params=['task_ids']
  3391. )), 400
  3392. if not confirm:
  3393. return jsonify(bad_request_response(
  3394. response_text="缺少必需参数: confirm",
  3395. missing_params=['confirm']
  3396. )), 400
  3397. if confirm != True:
  3398. return jsonify(bad_request_response(
  3399. response_text="confirm参数必须为true以确认删除操作"
  3400. )), 400
  3401. if not isinstance(task_ids, list) or len(task_ids) == 0:
  3402. return jsonify(bad_request_response(
  3403. response_text="task_ids必须是非空的任务ID列表"
  3404. )), 400
  3405. # 获取可选参数
  3406. delete_database_records = req.get('delete_database_records', False)
  3407. continue_on_error = req.get('continue_on_error', True)
  3408. # 执行批量删除操作
  3409. deleted_tasks = []
  3410. failed_tasks = []
  3411. total_size_freed = 0
  3412. for task_id in task_ids:
  3413. result = delete_task_directory_simple(task_id, delete_database_records)
  3414. if result["success"]:
  3415. deleted_tasks.append(result)
  3416. # 累计释放的空间大小(这里简化处理,实际应该解析size字符串)
  3417. else:
  3418. failed_tasks.append({
  3419. "task_id": task_id,
  3420. "error": result["error"],
  3421. "error_code": result.get("error_code", "UNKNOWN")
  3422. })
  3423. if not continue_on_error:
  3424. break
  3425. # 构建响应
  3426. summary = {
  3427. "total_requested": len(task_ids),
  3428. "successfully_deleted": len(deleted_tasks),
  3429. "failed": len(failed_tasks)
  3430. }
  3431. batch_result = {
  3432. "deleted_tasks": deleted_tasks,
  3433. "failed_tasks": failed_tasks,
  3434. "summary": summary,
  3435. "deleted_at": datetime.now().isoformat()
  3436. }
  3437. if len(task_ids) == 1:
  3438. # 单个删除
  3439. if summary["failed"] == 0:
  3440. message = "任务目录删除成功"
  3441. else:
  3442. message = "任务目录删除失败"
  3443. else:
  3444. # 批量删除
  3445. if summary["failed"] == 0:
  3446. message = "批量删除完成"
  3447. elif summary["successfully_deleted"] == 0:
  3448. message = "批量删除失败"
  3449. else:
  3450. message = "批量删除部分完成"
  3451. return jsonify(success_response(
  3452. response_text=message,
  3453. data=batch_result
  3454. )), 200
  3455. except Exception as e:
  3456. logger.error(f"删除任务失败: 错误: {str(e)}")
  3457. return jsonify(internal_error_response(
  3458. response_text="删除任务失败,请稍后重试"
  3459. )), 500
  3460. logger.info("启动Flask应用: http://localhost:8084")
  3461. app.run(host="0.0.0.0", port=8084, debug=True)