citu_app.py 41 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997
  1. # 给dataops 对话助手返回结果
  2. from vanna.flask import VannaFlaskApp
  3. from core.vanna_llm_factory import create_vanna_instance
  4. from flask import request, jsonify
  5. import pandas as pd
  6. import common.result as result
  7. from datetime import datetime, timedelta
  8. from common.session_aware_cache import WebSessionAwareMemoryCache
  9. vn = create_vanna_instance()
  10. # 创建带时间戳的缓存
  11. timestamped_cache = WebSessionAwareMemoryCache()
  12. # 实例化 VannaFlaskApp,使用自定义缓存
  13. app = VannaFlaskApp(
  14. vn,
  15. cache=timestamped_cache, # 使用带时间戳的缓存
  16. title="辞图智能数据问答平台",
  17. logo = "https://www.citupro.com/img/logo-black-2.png",
  18. subtitle="让 AI 为你写 SQL",
  19. chart=False,
  20. allow_llm_to_see_data=True,
  21. ask_results_correct=True,
  22. followup_questions=True,
  23. debug=True
  24. )
  25. # 修改ask接口,支持前端传递session_id
  26. @app.flask_app.route('/api/v0/ask', methods=['POST'])
  27. def ask_full():
  28. req = request.get_json(force=True)
  29. question = req.get("question", None)
  30. browser_session_id = req.get("session_id", None) # 前端传递的会话ID
  31. if not question:
  32. return jsonify(result.failed(message="未提供问题", code=400)), 400
  33. # 如果使用WebSessionAwareMemoryCache
  34. if hasattr(app.cache, 'generate_id_with_browser_session') and browser_session_id:
  35. # 这里需要修改vanna的ask方法来支持传递session_id
  36. # 或者预先调用generate_id来建立会话关联
  37. conversation_id = app.cache.generate_id_with_browser_session(
  38. question=question,
  39. browser_session_id=browser_session_id
  40. )
  41. try:
  42. sql, df, _ = vn.ask(
  43. question=question,
  44. print_results=False,
  45. visualize=False,
  46. allow_llm_to_see_data=True
  47. )
  48. rows, columns = [], []
  49. summary = None
  50. if isinstance(df, pd.DataFrame) and not df.empty:
  51. rows = df.head(1000).to_dict(orient="records")
  52. columns = list(df.columns)
  53. # 生成数据摘要
  54. try:
  55. summary = vn.generate_summary(question=question, df=df)
  56. print(f"[INFO] 成功生成摘要: {summary}")
  57. except Exception as e:
  58. print(f"[WARNING] 生成摘要失败: {str(e)}")
  59. summary = None
  60. return jsonify(result.success(data={
  61. "sql": sql,
  62. "rows": rows,
  63. "columns": columns,
  64. "summary": summary, # 添加摘要到返回结果
  65. "conversation_id": conversation_id if 'conversation_id' in locals() else None,
  66. "session_id": browser_session_id
  67. }))
  68. except Exception as e:
  69. print(f"[ERROR] ask_full执行失败: {str(e)}")
  70. return jsonify(result.failed(
  71. message=f"查询处理失败: {str(e)}",
  72. code=500
  73. )), 500
  74. @app.flask_app.route('/api/v0/citu_run_sql', methods=['POST'])
  75. def citu_run_sql():
  76. req = request.get_json(force=True)
  77. sql = req.get('sql')
  78. if not sql:
  79. return jsonify(result.failed(message="未提供SQL查询", code=400)), 400
  80. try:
  81. df = vn.run_sql(sql)
  82. rows, columns = [], []
  83. if isinstance(df, pd.DataFrame) and not df.empty:
  84. rows = df.head(1000).to_dict(orient="records")
  85. columns = list(df.columns)
  86. return jsonify(result.success(data={
  87. "sql": sql,
  88. "rows": rows,
  89. "columns": columns
  90. }))
  91. except Exception as e:
  92. print(f"[ERROR] citu_run_sql执行失败: {str(e)}")
  93. return jsonify(result.failed(
  94. message=f"SQL执行失败: {str(e)}",
  95. code=500
  96. )), 500
  97. @app.flask_app.route('/api/v0/ask_cached', methods=['POST'])
  98. def ask_cached():
  99. """
  100. 带缓存功能的智能查询接口
  101. 支持会话管理和结果缓存,提高查询效率
  102. """
  103. req = request.get_json(force=True)
  104. question = req.get("question", None)
  105. browser_session_id = req.get("session_id", None)
  106. if not question:
  107. return jsonify(result.failed(message="未提供问题", code=400)), 400
  108. try:
  109. # 生成conversation_id
  110. # 调试:查看generate_id的实际行为
  111. print(f"[DEBUG] 输入问题: '{question}'")
  112. conversation_id = app.cache.generate_id(question=question)
  113. print(f"[DEBUG] 生成的conversation_id: {conversation_id}")
  114. # 再次用相同问题测试
  115. conversation_id2 = app.cache.generate_id(question=question)
  116. print(f"[DEBUG] 再次生成的conversation_id: {conversation_id2}")
  117. print(f"[DEBUG] 两次ID是否相同: {conversation_id == conversation_id2}")
  118. # 检查缓存
  119. cached_sql = app.cache.get(id=conversation_id, field="sql")
  120. if cached_sql is not None:
  121. # 缓存命中
  122. print(f"[CACHE HIT] 使用缓存结果: {conversation_id}")
  123. sql = cached_sql
  124. df = app.cache.get(id=conversation_id, field="df")
  125. summary = app.cache.get(id=conversation_id, field="summary")
  126. else:
  127. # 缓存未命中,执行新查询
  128. print(f"[CACHE MISS] 执行新查询: {conversation_id}")
  129. sql, df, _ = vn.ask(
  130. question=question,
  131. print_results=False,
  132. visualize=False,
  133. allow_llm_to_see_data=True
  134. )
  135. # 缓存结果
  136. app.cache.set(id=conversation_id, field="question", value=question)
  137. app.cache.set(id=conversation_id, field="sql", value=sql)
  138. app.cache.set(id=conversation_id, field="df", value=df)
  139. # 生成并缓存摘要
  140. summary = None
  141. if isinstance(df, pd.DataFrame) and not df.empty:
  142. try:
  143. summary = vn.generate_summary(question=question, df=df)
  144. print(f"[INFO] 成功生成摘要: {summary}")
  145. except Exception as e:
  146. print(f"[WARNING] 生成摘要失败: {str(e)}")
  147. summary = None
  148. app.cache.set(id=conversation_id, field="summary", value=summary)
  149. # 处理返回数据
  150. rows, columns = [], []
  151. if isinstance(df, pd.DataFrame) and not df.empty:
  152. rows = df.head(1000).to_dict(orient="records")
  153. columns = list(df.columns)
  154. return jsonify(result.success(data={
  155. "sql": sql,
  156. "rows": rows,
  157. "columns": columns,
  158. "summary": summary,
  159. "conversation_id": conversation_id,
  160. "session_id": browser_session_id,
  161. "cached": cached_sql is not None # 标识是否来自缓存
  162. }))
  163. except Exception as e:
  164. print(f"[ERROR] ask_cached执行失败: {str(e)}")
  165. return jsonify(result.failed(
  166. message=f"查询处理失败: {str(e)}",
  167. code=500
  168. )), 500
  169. @app.flask_app.route('/api/v1/citu_train_question_sql', methods=['POST'])
  170. def citu_train_question_sql():
  171. """
  172. 训练问题-SQL对接口
  173. 此API将接收的question/sql pair写入到training库中,用于训练和改进AI模型。
  174. 支持仅传入SQL或同时传入问题和SQL进行训练。
  175. Args:
  176. question (str, optional): 用户问题
  177. sql (str, required): 对应的SQL查询语句
  178. Returns:
  179. JSON: 包含训练ID和成功消息的响应
  180. """
  181. try:
  182. req = request.get_json(force=True)
  183. question = req.get('question')
  184. sql = req.get('sql')
  185. if not sql:
  186. return jsonify(result.failed(
  187. message="'sql' are required",
  188. code=400
  189. )), 400
  190. # 正确的调用方式:同时传递question和sql
  191. if question:
  192. training_id = vn.train(question=question, sql=sql)
  193. print(f"训练成功,训练ID为:{training_id},问题:{question},SQL:{sql}")
  194. else:
  195. training_id = vn.train(sql=sql)
  196. print(f"训练成功,训练ID为:{training_id},SQL:{sql}")
  197. return jsonify(result.success(data={
  198. "training_id": training_id,
  199. "message": "Question-SQL pair trained successfully"
  200. }))
  201. except Exception as e:
  202. return jsonify(result.failed(
  203. message=f"Training failed: {str(e)}",
  204. code=500
  205. )), 500
  206. # ==================== 日常管理API ====================
  207. @app.flask_app.route('/api/v0/cache_overview', methods=['GET'])
  208. def cache_overview():
  209. """日常管理:轻量概览 - 合并原cache_inspect的核心功能"""
  210. try:
  211. cache = app.cache
  212. result_data = {
  213. 'overview_summary': {
  214. 'total_conversations': 0,
  215. 'total_sessions': 0,
  216. 'query_time': datetime.now().isoformat()
  217. },
  218. 'recent_conversations': [], # 最近的对话
  219. 'session_summary': [] # 会话摘要
  220. }
  221. if hasattr(cache, 'cache') and isinstance(cache.cache, dict):
  222. result_data['overview_summary']['total_conversations'] = len(cache.cache)
  223. # 获取会话信息
  224. if hasattr(cache, 'get_all_sessions'):
  225. all_sessions = cache.get_all_sessions()
  226. result_data['overview_summary']['total_sessions'] = len(all_sessions)
  227. # 会话摘要(按最近活动排序)
  228. session_list = []
  229. for session_id, session_data in all_sessions.items():
  230. session_summary = {
  231. 'session_id': session_id,
  232. 'start_time': session_data['start_time'].isoformat(),
  233. 'conversation_count': session_data.get('conversation_count', 0),
  234. 'duration_seconds': session_data.get('session_duration_seconds', 0),
  235. 'last_activity': session_data.get('last_activity', session_data['start_time']).isoformat(),
  236. 'is_active': (datetime.now() - session_data.get('last_activity', session_data['start_time'])).total_seconds() < 1800 # 30分钟内活跃
  237. }
  238. session_list.append(session_summary)
  239. # 按最后活动时间排序
  240. session_list.sort(key=lambda x: x['last_activity'], reverse=True)
  241. result_data['session_summary'] = session_list
  242. # 最近的对话(最多显示10个)
  243. conversation_list = []
  244. for conversation_id, conversation_data in cache.cache.items():
  245. conversation_start_time = cache.conversation_start_times.get(conversation_id)
  246. conversation_info = {
  247. 'conversation_id': conversation_id,
  248. 'conversation_start_time': conversation_start_time.isoformat() if conversation_start_time else None,
  249. 'session_id': cache.conversation_to_session.get(conversation_id),
  250. 'has_question': 'question' in conversation_data,
  251. 'has_sql': 'sql' in conversation_data,
  252. 'has_data': 'df' in conversation_data and conversation_data['df'] is not None,
  253. 'question_preview': conversation_data.get('question', '')[:80] + '...' if len(conversation_data.get('question', '')) > 80 else conversation_data.get('question', ''),
  254. }
  255. # 计算对话持续时间
  256. if conversation_start_time:
  257. duration = datetime.now() - conversation_start_time
  258. conversation_info['conversation_duration_seconds'] = duration.total_seconds()
  259. conversation_list.append(conversation_info)
  260. # 按对话开始时间排序,显示最新的10个
  261. conversation_list.sort(key=lambda x: x['conversation_start_time'] or '', reverse=True)
  262. result_data['recent_conversations'] = conversation_list[:10]
  263. return jsonify(result.success(data=result_data))
  264. except Exception as e:
  265. return jsonify(result.failed(
  266. message=f"获取缓存概览失败: {str(e)}",
  267. code=500
  268. )), 500
  269. @app.flask_app.route('/api/v0/cache_stats', methods=['GET'])
  270. def cache_stats():
  271. """日常管理:统计信息 - 合并原session_stats和cache_stats功能"""
  272. try:
  273. cache = app.cache
  274. current_time = datetime.now()
  275. stats = {
  276. 'basic_stats': {
  277. 'total_sessions': len(getattr(cache, 'session_info', {})),
  278. 'total_conversations': len(getattr(cache, 'cache', {})),
  279. 'active_sessions': 0, # 最近30分钟有活动
  280. 'average_conversations_per_session': 0
  281. },
  282. 'time_distribution': {
  283. 'sessions': {
  284. 'last_1_hour': 0,
  285. 'last_6_hours': 0,
  286. 'last_24_hours': 0,
  287. 'last_7_days': 0,
  288. 'older': 0
  289. },
  290. 'conversations': {
  291. 'last_1_hour': 0,
  292. 'last_6_hours': 0,
  293. 'last_24_hours': 0,
  294. 'last_7_days': 0,
  295. 'older': 0
  296. }
  297. },
  298. 'session_details': [],
  299. 'time_ranges': {
  300. 'oldest_session': None,
  301. 'newest_session': None,
  302. 'oldest_conversation': None,
  303. 'newest_conversation': None
  304. }
  305. }
  306. # 会话统计
  307. if hasattr(cache, 'session_info'):
  308. session_times = []
  309. total_conversations = 0
  310. for session_id, session_data in cache.session_info.items():
  311. start_time = session_data['start_time']
  312. session_times.append(start_time)
  313. conversation_count = len(session_data.get('conversations', []))
  314. total_conversations += conversation_count
  315. # 检查活跃状态
  316. last_activity = session_data.get('last_activity', session_data['start_time'])
  317. if (current_time - last_activity).total_seconds() < 1800:
  318. stats['basic_stats']['active_sessions'] += 1
  319. # 时间分布统计
  320. age_hours = (current_time - start_time).total_seconds() / 3600
  321. if age_hours <= 1:
  322. stats['time_distribution']['sessions']['last_1_hour'] += 1
  323. elif age_hours <= 6:
  324. stats['time_distribution']['sessions']['last_6_hours'] += 1
  325. elif age_hours <= 24:
  326. stats['time_distribution']['sessions']['last_24_hours'] += 1
  327. elif age_hours <= 168: # 7 days
  328. stats['time_distribution']['sessions']['last_7_days'] += 1
  329. else:
  330. stats['time_distribution']['sessions']['older'] += 1
  331. # 会话详细信息
  332. session_duration = current_time - start_time
  333. stats['session_details'].append({
  334. 'session_id': session_id,
  335. 'start_time': start_time.isoformat(),
  336. 'last_activity': last_activity.isoformat(),
  337. 'conversation_count': conversation_count,
  338. 'duration_seconds': session_duration.total_seconds(),
  339. 'duration_formatted': str(session_duration),
  340. 'is_active': (current_time - last_activity).total_seconds() < 1800,
  341. 'browser_session_id': session_data.get('browser_session_id')
  342. })
  343. # 计算平均值
  344. if len(cache.session_info) > 0:
  345. stats['basic_stats']['average_conversations_per_session'] = total_conversations / len(cache.session_info)
  346. # 时间范围
  347. if session_times:
  348. stats['time_ranges']['oldest_session'] = min(session_times).isoformat()
  349. stats['time_ranges']['newest_session'] = max(session_times).isoformat()
  350. # 对话统计
  351. if hasattr(cache, 'conversation_start_times'):
  352. conversation_times = []
  353. for conv_time in cache.conversation_start_times.values():
  354. conversation_times.append(conv_time)
  355. age_hours = (current_time - conv_time).total_seconds() / 3600
  356. if age_hours <= 1:
  357. stats['time_distribution']['conversations']['last_1_hour'] += 1
  358. elif age_hours <= 6:
  359. stats['time_distribution']['conversations']['last_6_hours'] += 1
  360. elif age_hours <= 24:
  361. stats['time_distribution']['conversations']['last_24_hours'] += 1
  362. elif age_hours <= 168:
  363. stats['time_distribution']['conversations']['last_7_days'] += 1
  364. else:
  365. stats['time_distribution']['conversations']['older'] += 1
  366. if conversation_times:
  367. stats['time_ranges']['oldest_conversation'] = min(conversation_times).isoformat()
  368. stats['time_ranges']['newest_conversation'] = max(conversation_times).isoformat()
  369. # 按最近活动排序会话详情
  370. stats['session_details'].sort(key=lambda x: x['last_activity'], reverse=True)
  371. return jsonify(result.success(data=stats))
  372. except Exception as e:
  373. return jsonify(result.failed(
  374. message=f"获取缓存统计失败: {str(e)}",
  375. code=500
  376. )), 500
  377. # ==================== 高级功能API ====================
  378. @app.flask_app.route('/api/v0/cache_export', methods=['GET'])
  379. def cache_export():
  380. """高级功能:完整导出 - 保持原cache_raw_export的完整功能"""
  381. try:
  382. cache = app.cache
  383. # 验证缓存的实际结构
  384. if not hasattr(cache, 'cache'):
  385. return jsonify(result.failed(message="缓存对象没有cache属性", code=500)), 500
  386. if not isinstance(cache.cache, dict):
  387. return jsonify(result.failed(message="缓存不是字典类型", code=500)), 500
  388. # 定义JSON序列化辅助函数
  389. def make_json_serializable(obj):
  390. """将对象转换为JSON可序列化的格式"""
  391. if obj is None:
  392. return None
  393. elif isinstance(obj, (str, int, float, bool)):
  394. return obj
  395. elif isinstance(obj, (list, tuple)):
  396. return [make_json_serializable(item) for item in obj]
  397. elif isinstance(obj, dict):
  398. return {str(k): make_json_serializable(v) for k, v in obj.items()}
  399. elif hasattr(obj, 'isoformat'): # datetime objects
  400. return obj.isoformat()
  401. elif hasattr(obj, 'item'): # numpy scalars
  402. return obj.item()
  403. elif hasattr(obj, 'tolist'): # numpy arrays
  404. return obj.tolist()
  405. elif hasattr(obj, '__dict__'): # pandas dtypes and other objects
  406. return str(obj)
  407. else:
  408. return str(obj)
  409. # 获取完整的原始缓存数据
  410. raw_cache = cache.cache
  411. # 获取会话和对话时间信息
  412. conversation_times = getattr(cache, 'conversation_start_times', {})
  413. session_info = getattr(cache, 'session_info', {})
  414. conversation_to_session = getattr(cache, 'conversation_to_session', {})
  415. export_data = {
  416. 'export_metadata': {
  417. 'export_time': datetime.now().isoformat(),
  418. 'total_conversations': len(raw_cache),
  419. 'total_sessions': len(session_info),
  420. 'cache_type': type(cache).__name__,
  421. 'cache_object_info': str(cache),
  422. 'has_session_times': bool(session_info),
  423. 'has_conversation_times': bool(conversation_times)
  424. },
  425. 'session_info': {
  426. session_id: {
  427. 'start_time': session_data['start_time'].isoformat(),
  428. 'last_activity': session_data.get('last_activity', session_data['start_time']).isoformat(),
  429. 'conversations': session_data['conversations'],
  430. 'conversation_count': len(session_data['conversations']),
  431. 'browser_session_id': session_data.get('browser_session_id'),
  432. 'user_info': session_data.get('user_info', {})
  433. }
  434. for session_id, session_data in session_info.items()
  435. },
  436. 'conversation_times': {
  437. conversation_id: start_time.isoformat()
  438. for conversation_id, start_time in conversation_times.items()
  439. },
  440. 'conversation_to_session_mapping': conversation_to_session,
  441. 'conversations': {}
  442. }
  443. # 处理每个对话的完整数据
  444. for conversation_id, conversation_data in raw_cache.items():
  445. # 获取时间信息
  446. conversation_start_time = conversation_times.get(conversation_id)
  447. session_id = conversation_to_session.get(conversation_id)
  448. session_start_time = None
  449. if session_id and session_id in session_info:
  450. session_start_time = session_info[session_id]['start_time']
  451. processed_conversation = {
  452. 'conversation_id': conversation_id,
  453. 'conversation_start_time': conversation_start_time.isoformat() if conversation_start_time else None,
  454. 'session_id': session_id,
  455. 'session_start_time': session_start_time.isoformat() if session_start_time else None,
  456. 'field_count': len(conversation_data),
  457. 'fields': {}
  458. }
  459. # 添加时间计算
  460. if conversation_start_time:
  461. conversation_duration = datetime.now() - conversation_start_time
  462. processed_conversation['conversation_duration_seconds'] = conversation_duration.total_seconds()
  463. processed_conversation['conversation_duration_formatted'] = str(conversation_duration)
  464. if session_start_time:
  465. session_duration = datetime.now() - session_start_time
  466. processed_conversation['session_duration_seconds'] = session_duration.total_seconds()
  467. processed_conversation['session_duration_formatted'] = str(session_duration)
  468. # 处理每个字段,确保JSON序列化安全
  469. for field_name, field_value in conversation_data.items():
  470. field_info = {
  471. 'field_name': field_name,
  472. 'data_type': type(field_value).__name__,
  473. 'is_none': field_value is None
  474. }
  475. try:
  476. if field_value is None:
  477. field_info['value'] = None
  478. elif field_name in ['conversation_start_time', 'session_start_time']:
  479. # 处理时间字段
  480. field_info['content'] = make_json_serializable(field_value)
  481. elif field_name == 'df' and field_value is not None:
  482. # DataFrame的安全处理
  483. if hasattr(field_value, 'to_dict'):
  484. # 安全地处理dtypes
  485. try:
  486. dtypes_dict = {}
  487. for col, dtype in field_value.dtypes.items():
  488. dtypes_dict[col] = str(dtype)
  489. except Exception:
  490. dtypes_dict = {"error": "无法序列化dtypes"}
  491. # 安全地处理内存使用
  492. try:
  493. memory_usage = field_value.memory_usage(deep=True)
  494. memory_dict = {}
  495. for idx, usage in memory_usage.items():
  496. memory_dict[str(idx)] = int(usage) if hasattr(usage, 'item') else int(usage)
  497. except Exception:
  498. memory_dict = {"error": "无法获取内存使用信息"}
  499. field_info.update({
  500. 'dataframe_info': {
  501. 'shape': list(field_value.shape),
  502. 'columns': list(field_value.columns),
  503. 'dtypes': dtypes_dict,
  504. 'index_info': {
  505. 'type': type(field_value.index).__name__,
  506. 'length': len(field_value.index)
  507. }
  508. },
  509. 'data': make_json_serializable(field_value.to_dict('records')),
  510. 'memory_usage': memory_dict
  511. })
  512. else:
  513. field_info['value'] = str(field_value)
  514. field_info['note'] = 'not_standard_dataframe'
  515. elif field_name == 'fig_json':
  516. # 图表JSON数据处理
  517. if isinstance(field_value, str):
  518. try:
  519. import json
  520. parsed_fig = json.loads(field_value)
  521. field_info.update({
  522. 'json_valid': True,
  523. 'json_size_bytes': len(field_value),
  524. 'plotly_structure': {
  525. 'has_data': 'data' in parsed_fig,
  526. 'has_layout': 'layout' in parsed_fig,
  527. 'data_traces_count': len(parsed_fig.get('data', [])),
  528. },
  529. 'raw_json': field_value
  530. })
  531. except json.JSONDecodeError:
  532. field_info.update({
  533. 'json_valid': False,
  534. 'raw_content': str(field_value)
  535. })
  536. else:
  537. field_info['value'] = make_json_serializable(field_value)
  538. elif field_name == 'followup_questions':
  539. # 后续问题列表
  540. field_info.update({
  541. 'content': make_json_serializable(field_value)
  542. })
  543. elif field_name in ['question', 'sql', 'summary']:
  544. # 文本字段
  545. if isinstance(field_value, str):
  546. field_info.update({
  547. 'text_length': len(field_value),
  548. 'content': field_value
  549. })
  550. else:
  551. field_info['value'] = make_json_serializable(field_value)
  552. else:
  553. # 未知字段的安全处理
  554. field_info['content'] = make_json_serializable(field_value)
  555. except Exception as e:
  556. field_info.update({
  557. 'processing_error': str(e),
  558. 'fallback_value': str(field_value)[:500] + '...' if len(str(field_value)) > 500 else str(field_value)
  559. })
  560. processed_conversation['fields'][field_name] = field_info
  561. export_data['conversations'][conversation_id] = processed_conversation
  562. # 添加缓存统计信息
  563. field_frequency = {}
  564. data_types_found = set()
  565. total_dataframes = 0
  566. total_questions = 0
  567. for conv_data in export_data['conversations'].values():
  568. for field_name, field_info in conv_data['fields'].items():
  569. field_frequency[field_name] = field_frequency.get(field_name, 0) + 1
  570. data_types_found.add(field_info['data_type'])
  571. if field_name == 'df' and not field_info['is_none']:
  572. total_dataframes += 1
  573. if field_name == 'question' and not field_info['is_none']:
  574. total_questions += 1
  575. export_data['cache_statistics'] = {
  576. 'field_frequency': field_frequency,
  577. 'data_types_found': list(data_types_found),
  578. 'total_dataframes': total_dataframes,
  579. 'total_questions': total_questions,
  580. 'has_session_timing': 'session_start_time' in field_frequency,
  581. 'has_conversation_timing': 'conversation_start_time' in field_frequency
  582. }
  583. return jsonify(result.success(data=export_data))
  584. except Exception as e:
  585. import traceback
  586. error_details = {
  587. 'error_message': str(e),
  588. 'error_type': type(e).__name__,
  589. 'traceback': traceback.format_exc()
  590. }
  591. return jsonify(result.failed(
  592. message=f"导出缓存失败: {str(e)}",
  593. code=500,
  594. data=error_details
  595. )), 500
  596. # ==================== 清理功能API ====================
  597. @app.flask_app.route('/api/v0/cache_preview_cleanup', methods=['POST'])
  598. def cache_preview_cleanup():
  599. """清理功能:预览删除操作 - 保持原功能"""
  600. try:
  601. req = request.get_json(force=True)
  602. # 时间条件 - 支持三种方式
  603. older_than_hours = req.get('older_than_hours')
  604. older_than_days = req.get('older_than_days')
  605. before_timestamp = req.get('before_timestamp') # YYYY-MM-DD HH:MM:SS 格式
  606. cache = app.cache
  607. # 计算截止时间
  608. cutoff_time = None
  609. time_condition = None
  610. if older_than_hours:
  611. cutoff_time = datetime.now() - timedelta(hours=older_than_hours)
  612. time_condition = f"older_than_hours: {older_than_hours}"
  613. elif older_than_days:
  614. cutoff_time = datetime.now() - timedelta(days=older_than_days)
  615. time_condition = f"older_than_days: {older_than_days}"
  616. elif before_timestamp:
  617. try:
  618. # 支持 YYYY-MM-DD HH:MM:SS 格式
  619. cutoff_time = datetime.strptime(before_timestamp, '%Y-%m-%d %H:%M:%S')
  620. time_condition = f"before_timestamp: {before_timestamp}"
  621. except ValueError:
  622. return jsonify(result.failed(
  623. message="before_timestamp格式错误,请使用 YYYY-MM-DD HH:MM:SS 格式",
  624. code=400
  625. )), 400
  626. else:
  627. return jsonify(result.failed(
  628. message="必须提供时间条件:older_than_hours, older_than_days 或 before_timestamp (YYYY-MM-DD HH:MM:SS)",
  629. code=400
  630. )), 400
  631. preview = {
  632. 'time_condition': time_condition,
  633. 'cutoff_time': cutoff_time.isoformat(),
  634. 'will_be_removed': {
  635. 'sessions': []
  636. },
  637. 'will_be_kept': {
  638. 'sessions_count': 0,
  639. 'conversations_count': 0
  640. },
  641. 'summary': {
  642. 'sessions_to_remove': 0,
  643. 'conversations_to_remove': 0,
  644. 'sessions_to_keep': 0,
  645. 'conversations_to_keep': 0
  646. }
  647. }
  648. # 预览按session删除
  649. sessions_to_remove_count = 0
  650. conversations_to_remove_count = 0
  651. for session_id, session_data in cache.session_info.items():
  652. session_preview = {
  653. 'session_id': session_id,
  654. 'start_time': session_data['start_time'].isoformat(),
  655. 'conversation_count': len(session_data['conversations']),
  656. 'conversations': []
  657. }
  658. # 添加conversation详情
  659. for conv_id in session_data['conversations']:
  660. if conv_id in cache.cache:
  661. conv_data = cache.cache[conv_id]
  662. session_preview['conversations'].append({
  663. 'conversation_id': conv_id,
  664. 'question': conv_data.get('question', '')[:50] + '...' if conv_data.get('question') else '',
  665. 'start_time': cache.conversation_start_times.get(conv_id, '').isoformat() if cache.conversation_start_times.get(conv_id) else ''
  666. })
  667. if session_data['start_time'] < cutoff_time:
  668. preview['will_be_removed']['sessions'].append(session_preview)
  669. sessions_to_remove_count += 1
  670. conversations_to_remove_count += len(session_data['conversations'])
  671. else:
  672. preview['will_be_kept']['sessions_count'] += 1
  673. preview['will_be_kept']['conversations_count'] += len(session_data['conversations'])
  674. # 更新摘要统计
  675. preview['summary'] = {
  676. 'sessions_to_remove': sessions_to_remove_count,
  677. 'conversations_to_remove': conversations_to_remove_count,
  678. 'sessions_to_keep': preview['will_be_kept']['sessions_count'],
  679. 'conversations_to_keep': preview['will_be_kept']['conversations_count']
  680. }
  681. return jsonify(result.success(data=preview))
  682. except Exception as e:
  683. return jsonify(result.failed(
  684. message=f"预览清理操作失败: {str(e)}",
  685. code=500
  686. )), 500
  687. @app.flask_app.route('/api/v0/cache_cleanup', methods=['POST'])
  688. def cache_cleanup():
  689. """清理功能:实际删除缓存 - 保持原功能"""
  690. try:
  691. req = request.get_json(force=True)
  692. # 时间条件 - 支持三种方式
  693. older_than_hours = req.get('older_than_hours')
  694. older_than_days = req.get('older_than_days')
  695. before_timestamp = req.get('before_timestamp') # YYYY-MM-DD HH:MM:SS 格式
  696. cache = app.cache
  697. if not hasattr(cache, 'session_info'):
  698. return jsonify(result.failed(
  699. message="缓存不支持会话功能",
  700. code=400
  701. )), 400
  702. # 计算截止时间
  703. cutoff_time = None
  704. time_condition = None
  705. if older_than_hours:
  706. cutoff_time = datetime.now() - timedelta(hours=older_than_hours)
  707. time_condition = f"older_than_hours: {older_than_hours}"
  708. elif older_than_days:
  709. cutoff_time = datetime.now() - timedelta(days=older_than_days)
  710. time_condition = f"older_than_days: {older_than_days}"
  711. elif before_timestamp:
  712. try:
  713. # 支持 YYYY-MM-DD HH:MM:SS 格式
  714. cutoff_time = datetime.strptime(before_timestamp, '%Y-%m-%d %H:%M:%S')
  715. time_condition = f"before_timestamp: {before_timestamp}"
  716. except ValueError:
  717. return jsonify(result.failed(
  718. message="before_timestamp格式错误,请使用 YYYY-MM-DD HH:MM:SS 格式",
  719. code=400
  720. )), 400
  721. else:
  722. return jsonify(result.failed(
  723. message="必须提供时间条件:older_than_hours, older_than_days 或 before_timestamp (YYYY-MM-DD HH:MM:SS)",
  724. code=400
  725. )), 400
  726. cleanup_stats = {
  727. 'time_condition': time_condition,
  728. 'cutoff_time': cutoff_time.isoformat(),
  729. 'sessions_removed': 0,
  730. 'conversations_removed': 0,
  731. 'sessions_kept': 0,
  732. 'conversations_kept': 0,
  733. 'removed_session_ids': [],
  734. 'removed_conversation_ids': []
  735. }
  736. # 按session删除
  737. sessions_to_remove = []
  738. for session_id, session_data in cache.session_info.items():
  739. if session_data['start_time'] < cutoff_time:
  740. sessions_to_remove.append(session_id)
  741. # 删除符合条件的sessions及其所有conversations
  742. for session_id in sessions_to_remove:
  743. session_data = cache.session_info[session_id]
  744. conversations_in_session = session_data['conversations'].copy()
  745. # 删除session中的所有conversations
  746. for conv_id in conversations_in_session:
  747. if conv_id in cache.cache:
  748. del cache.cache[conv_id]
  749. cleanup_stats['conversations_removed'] += 1
  750. cleanup_stats['removed_conversation_ids'].append(conv_id)
  751. # 清理conversation相关的时间记录
  752. if hasattr(cache, 'conversation_start_times') and conv_id in cache.conversation_start_times:
  753. del cache.conversation_start_times[conv_id]
  754. if hasattr(cache, 'conversation_to_session') and conv_id in cache.conversation_to_session:
  755. del cache.conversation_to_session[conv_id]
  756. # 删除session记录
  757. del cache.session_info[session_id]
  758. cleanup_stats['sessions_removed'] += 1
  759. cleanup_stats['removed_session_ids'].append(session_id)
  760. # 统计保留的sessions和conversations
  761. cleanup_stats['sessions_kept'] = len(cache.session_info)
  762. cleanup_stats['conversations_kept'] = len(cache.cache)
  763. return jsonify(result.success(data=cleanup_stats))
  764. except Exception as e:
  765. return jsonify(result.failed(
  766. message=f"清理缓存失败: {str(e)}",
  767. code=500
  768. )), 500
  769. @app.flask_app.route('/api/v0/training_error_question_sql', methods=['POST'])
  770. def training_error_question_sql():
  771. """
  772. 存储错误的question-sql对到error_sql集合中
  773. 此API将接收的错误question/sql pair写入到error_sql集合中,用于记录和分析错误的SQL查询。
  774. Args:
  775. question (str, required): 用户问题
  776. sql (str, required): 对应的错误SQL查询语句
  777. Returns:
  778. JSON: 包含训练ID和成功消息的响应
  779. """
  780. try:
  781. data = request.get_json()
  782. question = data.get('question')
  783. sql = data.get('sql')
  784. print(f"[DEBUG] 接收到错误SQL训练请求: question={question}, sql={sql}")
  785. if not question or not sql:
  786. return jsonify(result.failed(
  787. message="question和sql参数都是必需的",
  788. code=400
  789. )), 400
  790. # 使用vn实例的train_error_sql方法存储错误SQL
  791. id = vn.train_error_sql(question=question, sql=sql)
  792. print(f"[INFO] 成功存储错误SQL,ID: {id}")
  793. return jsonify(result.success(data={
  794. "id": id,
  795. "message": "错误SQL对已成功存储到error_sql集合"
  796. }))
  797. except Exception as e:
  798. print(f"[ERROR] 存储错误SQL失败: {str(e)}")
  799. return jsonify(result.failed(
  800. message=f"存储错误SQL失败: {str(e)}",
  801. code=500
  802. )), 500
  803. # 前端JavaScript示例 - 如何维持会话
  804. """
  805. // 前端需要维护一个会话ID
  806. class ChatSession {
  807. constructor() {
  808. // 从localStorage获取或创建新的会话ID
  809. this.sessionId = localStorage.getItem('chat_session_id') || this.generateSessionId();
  810. localStorage.setItem('chat_session_id', this.sessionId);
  811. }
  812. generateSessionId() {
  813. return 'session_' + Date.now() + '_' + Math.random().toString(36).substr(2, 9);
  814. }
  815. async askQuestion(question) {
  816. const response = await fetch('/api/v0/ask', {
  817. method: 'POST',
  818. headers: {
  819. 'Content-Type': 'application/json',
  820. },
  821. body: JSON.stringify({
  822. question: question,
  823. session_id: this.sessionId // 关键:传递会话ID
  824. })
  825. });
  826. return await response.json();
  827. }
  828. // 开始新会话
  829. startNewSession() {
  830. this.sessionId = this.generateSessionId();
  831. localStorage.setItem('chat_session_id', this.sessionId);
  832. }
  833. }
  834. // 使用示例
  835. const chatSession = new ChatSession();
  836. chatSession.askQuestion("各年龄段客户的流失率如何?");
  837. """
  838. print("正在启动Flask应用: http://localhost:8084")
  839. app.run(host="0.0.0.0", port=8084, debug=True)