# citu_app.py
# 给 dataops 对话助手返回结果 (serves results to the dataops chat assistant).
import json
import math
from datetime import datetime, timedelta

from vanna.flask import VannaFlaskApp
from flask import request, jsonify
import pandas as pd

from core.vanna_llm_factory import create_vanna_instance
import common.result as result
from common.session_aware_cache import WebSessionAwareMemoryCache
from app_config import API_MAX_RETURN_ROWS
  10. # 设置默认的最大返回行数
  11. DEFAULT_MAX_RETURN_ROWS = 200
  12. MAX_RETURN_ROWS = API_MAX_RETURN_ROWS if API_MAX_RETURN_ROWS is not None else DEFAULT_MAX_RETURN_ROWS
  13. vn = create_vanna_instance()
  14. # 创建带时间戳的缓存
  15. timestamped_cache = WebSessionAwareMemoryCache()
  16. # 实例化 VannaFlaskApp,使用自定义缓存
  17. app = VannaFlaskApp(
  18. vn,
  19. cache=timestamped_cache, # 使用带时间戳的缓存
  20. title="辞图智能数据问答平台",
  21. logo = "https://www.citupro.com/img/logo-black-2.png",
  22. subtitle="让 AI 为你写 SQL",
  23. chart=False,
  24. allow_llm_to_see_data=True,
  25. ask_results_correct=True,
  26. followup_questions=True,
  27. debug=True
  28. )
  29. # 修改ask接口,支持前端传递session_id
  30. @app.flask_app.route('/api/v0/ask', methods=['POST'])
  31. def ask_full():
  32. req = request.get_json(force=True)
  33. question = req.get("question", None)
  34. browser_session_id = req.get("session_id", None) # 前端传递的会话ID
  35. if not question:
  36. return jsonify(result.failed(message="未提供问题", code=400)), 400
  37. # 如果使用WebSessionAwareMemoryCache
  38. if hasattr(app.cache, 'generate_id_with_browser_session') and browser_session_id:
  39. # 这里需要修改vanna的ask方法来支持传递session_id
  40. # 或者预先调用generate_id来建立会话关联
  41. conversation_id = app.cache.generate_id_with_browser_session(
  42. question=question,
  43. browser_session_id=browser_session_id
  44. )
  45. try:
  46. sql, df, _ = vn.ask(
  47. question=question,
  48. print_results=False,
  49. visualize=False,
  50. allow_llm_to_see_data=True
  51. )
  52. rows, columns = [], []
  53. summary = None
  54. if isinstance(df, pd.DataFrame) and not df.empty:
  55. rows = df.head(MAX_RETURN_ROWS).to_dict(orient="records")
  56. columns = list(df.columns)
  57. # 生成数据摘要
  58. try:
  59. summary = vn.generate_summary(question=question, df=df)
  60. print(f"[INFO] 成功生成摘要: {summary}")
  61. except Exception as e:
  62. print(f"[WARNING] 生成摘要失败: {str(e)}")
  63. summary = None
  64. return jsonify(result.success(data={
  65. "sql": sql,
  66. "rows": rows,
  67. "columns": columns,
  68. "summary": summary, # 添加摘要到返回结果
  69. "conversation_id": conversation_id if 'conversation_id' in locals() else None,
  70. "session_id": browser_session_id
  71. }))
  72. except Exception as e:
  73. print(f"[ERROR] ask_full执行失败: {str(e)}")
  74. return jsonify(result.failed(
  75. message=f"查询处理失败: {str(e)}",
  76. code=500
  77. )), 500
  78. @app.flask_app.route('/api/v0/citu_run_sql', methods=['POST'])
  79. def citu_run_sql():
  80. req = request.get_json(force=True)
  81. sql = req.get('sql')
  82. if not sql:
  83. return jsonify(result.failed(message="未提供SQL查询", code=400)), 400
  84. try:
  85. df = vn.run_sql(sql)
  86. rows, columns = [], []
  87. if isinstance(df, pd.DataFrame) and not df.empty:
  88. rows = df.head(MAX_RETURN_ROWS).to_dict(orient="records")
  89. columns = list(df.columns)
  90. return jsonify(result.success(data={
  91. "sql": sql,
  92. "rows": rows,
  93. "columns": columns
  94. }))
  95. except Exception as e:
  96. print(f"[ERROR] citu_run_sql执行失败: {str(e)}")
  97. return jsonify(result.failed(
  98. message=f"SQL执行失败: {str(e)}",
  99. code=500
  100. )), 500
  101. @app.flask_app.route('/api/v0/ask_cached', methods=['POST'])
  102. def ask_cached():
  103. """
  104. 带缓存功能的智能查询接口
  105. 支持会话管理和结果缓存,提高查询效率
  106. """
  107. req = request.get_json(force=True)
  108. question = req.get("question", None)
  109. browser_session_id = req.get("session_id", None)
  110. if not question:
  111. return jsonify(result.failed(message="未提供问题", code=400)), 400
  112. try:
  113. # 生成conversation_id
  114. # 调试:查看generate_id的实际行为
  115. print(f"[DEBUG] 输入问题: '{question}'")
  116. conversation_id = app.cache.generate_id(question=question)
  117. print(f"[DEBUG] 生成的conversation_id: {conversation_id}")
  118. # 再次用相同问题测试
  119. conversation_id2 = app.cache.generate_id(question=question)
  120. print(f"[DEBUG] 再次生成的conversation_id: {conversation_id2}")
  121. print(f"[DEBUG] 两次ID是否相同: {conversation_id == conversation_id2}")
  122. # 检查缓存
  123. cached_sql = app.cache.get(id=conversation_id, field="sql")
  124. if cached_sql is not None:
  125. # 缓存命中
  126. print(f"[CACHE HIT] 使用缓存结果: {conversation_id}")
  127. sql = cached_sql
  128. df = app.cache.get(id=conversation_id, field="df")
  129. summary = app.cache.get(id=conversation_id, field="summary")
  130. else:
  131. # 缓存未命中,执行新查询
  132. print(f"[CACHE MISS] 执行新查询: {conversation_id}")
  133. sql, df, _ = vn.ask(
  134. question=question,
  135. print_results=False,
  136. visualize=False,
  137. allow_llm_to_see_data=True
  138. )
  139. # 缓存结果
  140. app.cache.set(id=conversation_id, field="question", value=question)
  141. app.cache.set(id=conversation_id, field="sql", value=sql)
  142. app.cache.set(id=conversation_id, field="df", value=df)
  143. # 生成并缓存摘要
  144. summary = None
  145. if isinstance(df, pd.DataFrame) and not df.empty:
  146. try:
  147. summary = vn.generate_summary(question=question, df=df)
  148. print(f"[INFO] 成功生成摘要: {summary}")
  149. except Exception as e:
  150. print(f"[WARNING] 生成摘要失败: {str(e)}")
  151. summary = None
  152. app.cache.set(id=conversation_id, field="summary", value=summary)
  153. # 处理返回数据
  154. rows, columns = [], []
  155. if isinstance(df, pd.DataFrame) and not df.empty:
  156. rows = df.head(MAX_RETURN_ROWS).to_dict(orient="records")
  157. columns = list(df.columns)
  158. return jsonify(result.success(data={
  159. "sql": sql,
  160. "rows": rows,
  161. "columns": columns,
  162. "summary": summary,
  163. "conversation_id": conversation_id,
  164. "session_id": browser_session_id,
  165. "cached": cached_sql is not None # 标识是否来自缓存
  166. }))
  167. except Exception as e:
  168. print(f"[ERROR] ask_cached执行失败: {str(e)}")
  169. return jsonify(result.failed(
  170. message=f"查询处理失败: {str(e)}",
  171. code=500
  172. )), 500
  173. @app.flask_app.route('/api/v0/citu_train_question_sql', methods=['POST'])
  174. def citu_train_question_sql():
  175. """
  176. 训练问题-SQL对接口
  177. 此API将接收的question/sql pair写入到training库中,用于训练和改进AI模型。
  178. 支持仅传入SQL或同时传入问题和SQL进行训练。
  179. Args:
  180. question (str, optional): 用户问题
  181. sql (str, required): 对应的SQL查询语句
  182. Returns:
  183. JSON: 包含训练ID和成功消息的响应
  184. """
  185. try:
  186. req = request.get_json(force=True)
  187. question = req.get('question')
  188. sql = req.get('sql')
  189. if not sql:
  190. return jsonify(result.failed(
  191. message="'sql' are required",
  192. code=400
  193. )), 400
  194. # 正确的调用方式:同时传递question和sql
  195. if question:
  196. training_id = vn.train(question=question, sql=sql)
  197. print(f"训练成功,训练ID为:{training_id},问题:{question},SQL:{sql}")
  198. else:
  199. training_id = vn.train(sql=sql)
  200. print(f"训练成功,训练ID为:{training_id},SQL:{sql}")
  201. return jsonify(result.success(data={
  202. "training_id": training_id,
  203. "message": "Question-SQL pair trained successfully"
  204. }))
  205. except Exception as e:
  206. return jsonify(result.failed(
  207. message=f"Training failed: {str(e)}",
  208. code=500
  209. )), 500
  210. # ==================== 日常管理API ====================
  211. @app.flask_app.route('/api/v0/cache_overview', methods=['GET'])
  212. def cache_overview():
  213. """日常管理:轻量概览 - 合并原cache_inspect的核心功能"""
  214. try:
  215. cache = app.cache
  216. result_data = {
  217. 'overview_summary': {
  218. 'total_conversations': 0,
  219. 'total_sessions': 0,
  220. 'query_time': datetime.now().isoformat()
  221. },
  222. 'recent_conversations': [], # 最近的对话
  223. 'session_summary': [] # 会话摘要
  224. }
  225. if hasattr(cache, 'cache') and isinstance(cache.cache, dict):
  226. result_data['overview_summary']['total_conversations'] = len(cache.cache)
  227. # 获取会话信息
  228. if hasattr(cache, 'get_all_sessions'):
  229. all_sessions = cache.get_all_sessions()
  230. result_data['overview_summary']['total_sessions'] = len(all_sessions)
  231. # 会话摘要(按最近活动排序)
  232. session_list = []
  233. for session_id, session_data in all_sessions.items():
  234. session_summary = {
  235. 'session_id': session_id,
  236. 'start_time': session_data['start_time'].isoformat(),
  237. 'conversation_count': session_data.get('conversation_count', 0),
  238. 'duration_seconds': session_data.get('session_duration_seconds', 0),
  239. 'last_activity': session_data.get('last_activity', session_data['start_time']).isoformat(),
  240. 'is_active': (datetime.now() - session_data.get('last_activity', session_data['start_time'])).total_seconds() < 1800 # 30分钟内活跃
  241. }
  242. session_list.append(session_summary)
  243. # 按最后活动时间排序
  244. session_list.sort(key=lambda x: x['last_activity'], reverse=True)
  245. result_data['session_summary'] = session_list
  246. # 最近的对话(最多显示10个)
  247. conversation_list = []
  248. for conversation_id, conversation_data in cache.cache.items():
  249. conversation_start_time = cache.conversation_start_times.get(conversation_id)
  250. conversation_info = {
  251. 'conversation_id': conversation_id,
  252. 'conversation_start_time': conversation_start_time.isoformat() if conversation_start_time else None,
  253. 'session_id': cache.conversation_to_session.get(conversation_id),
  254. 'has_question': 'question' in conversation_data,
  255. 'has_sql': 'sql' in conversation_data,
  256. 'has_data': 'df' in conversation_data and conversation_data['df'] is not None,
  257. 'question_preview': conversation_data.get('question', '')[:80] + '...' if len(conversation_data.get('question', '')) > 80 else conversation_data.get('question', ''),
  258. }
  259. # 计算对话持续时间
  260. if conversation_start_time:
  261. duration = datetime.now() - conversation_start_time
  262. conversation_info['conversation_duration_seconds'] = duration.total_seconds()
  263. conversation_list.append(conversation_info)
  264. # 按对话开始时间排序,显示最新的10个
  265. conversation_list.sort(key=lambda x: x['conversation_start_time'] or '', reverse=True)
  266. result_data['recent_conversations'] = conversation_list[:10]
  267. return jsonify(result.success(data=result_data))
  268. except Exception as e:
  269. return jsonify(result.failed(
  270. message=f"获取缓存概览失败: {str(e)}",
  271. code=500
  272. )), 500
  273. @app.flask_app.route('/api/v0/cache_stats', methods=['GET'])
  274. def cache_stats():
  275. """日常管理:统计信息 - 合并原session_stats和cache_stats功能"""
  276. try:
  277. cache = app.cache
  278. current_time = datetime.now()
  279. stats = {
  280. 'basic_stats': {
  281. 'total_sessions': len(getattr(cache, 'session_info', {})),
  282. 'total_conversations': len(getattr(cache, 'cache', {})),
  283. 'active_sessions': 0, # 最近30分钟有活动
  284. 'average_conversations_per_session': 0
  285. },
  286. 'time_distribution': {
  287. 'sessions': {
  288. 'last_1_hour': 0,
  289. 'last_6_hours': 0,
  290. 'last_24_hours': 0,
  291. 'last_7_days': 0,
  292. 'older': 0
  293. },
  294. 'conversations': {
  295. 'last_1_hour': 0,
  296. 'last_6_hours': 0,
  297. 'last_24_hours': 0,
  298. 'last_7_days': 0,
  299. 'older': 0
  300. }
  301. },
  302. 'session_details': [],
  303. 'time_ranges': {
  304. 'oldest_session': None,
  305. 'newest_session': None,
  306. 'oldest_conversation': None,
  307. 'newest_conversation': None
  308. }
  309. }
  310. # 会话统计
  311. if hasattr(cache, 'session_info'):
  312. session_times = []
  313. total_conversations = 0
  314. for session_id, session_data in cache.session_info.items():
  315. start_time = session_data['start_time']
  316. session_times.append(start_time)
  317. conversation_count = len(session_data.get('conversations', []))
  318. total_conversations += conversation_count
  319. # 检查活跃状态
  320. last_activity = session_data.get('last_activity', session_data['start_time'])
  321. if (current_time - last_activity).total_seconds() < 1800:
  322. stats['basic_stats']['active_sessions'] += 1
  323. # 时间分布统计
  324. age_hours = (current_time - start_time).total_seconds() / 3600
  325. if age_hours <= 1:
  326. stats['time_distribution']['sessions']['last_1_hour'] += 1
  327. elif age_hours <= 6:
  328. stats['time_distribution']['sessions']['last_6_hours'] += 1
  329. elif age_hours <= 24:
  330. stats['time_distribution']['sessions']['last_24_hours'] += 1
  331. elif age_hours <= 168: # 7 days
  332. stats['time_distribution']['sessions']['last_7_days'] += 1
  333. else:
  334. stats['time_distribution']['sessions']['older'] += 1
  335. # 会话详细信息
  336. session_duration = current_time - start_time
  337. stats['session_details'].append({
  338. 'session_id': session_id,
  339. 'start_time': start_time.isoformat(),
  340. 'last_activity': last_activity.isoformat(),
  341. 'conversation_count': conversation_count,
  342. 'duration_seconds': session_duration.total_seconds(),
  343. 'duration_formatted': str(session_duration),
  344. 'is_active': (current_time - last_activity).total_seconds() < 1800,
  345. 'browser_session_id': session_data.get('browser_session_id')
  346. })
  347. # 计算平均值
  348. if len(cache.session_info) > 0:
  349. stats['basic_stats']['average_conversations_per_session'] = total_conversations / len(cache.session_info)
  350. # 时间范围
  351. if session_times:
  352. stats['time_ranges']['oldest_session'] = min(session_times).isoformat()
  353. stats['time_ranges']['newest_session'] = max(session_times).isoformat()
  354. # 对话统计
  355. if hasattr(cache, 'conversation_start_times'):
  356. conversation_times = []
  357. for conv_time in cache.conversation_start_times.values():
  358. conversation_times.append(conv_time)
  359. age_hours = (current_time - conv_time).total_seconds() / 3600
  360. if age_hours <= 1:
  361. stats['time_distribution']['conversations']['last_1_hour'] += 1
  362. elif age_hours <= 6:
  363. stats['time_distribution']['conversations']['last_6_hours'] += 1
  364. elif age_hours <= 24:
  365. stats['time_distribution']['conversations']['last_24_hours'] += 1
  366. elif age_hours <= 168:
  367. stats['time_distribution']['conversations']['last_7_days'] += 1
  368. else:
  369. stats['time_distribution']['conversations']['older'] += 1
  370. if conversation_times:
  371. stats['time_ranges']['oldest_conversation'] = min(conversation_times).isoformat()
  372. stats['time_ranges']['newest_conversation'] = max(conversation_times).isoformat()
  373. # 按最近活动排序会话详情
  374. stats['session_details'].sort(key=lambda x: x['last_activity'], reverse=True)
  375. return jsonify(result.success(data=stats))
  376. except Exception as e:
  377. return jsonify(result.failed(
  378. message=f"获取缓存统计失败: {str(e)}",
  379. code=500
  380. )), 500
  381. # ==================== 高级功能API ====================
  382. @app.flask_app.route('/api/v0/cache_export', methods=['GET'])
  383. def cache_export():
  384. """高级功能:完整导出 - 保持原cache_raw_export的完整功能"""
  385. try:
  386. cache = app.cache
  387. # 验证缓存的实际结构
  388. if not hasattr(cache, 'cache'):
  389. return jsonify(result.failed(message="缓存对象没有cache属性", code=500)), 500
  390. if not isinstance(cache.cache, dict):
  391. return jsonify(result.failed(message="缓存不是字典类型", code=500)), 500
  392. # 定义JSON序列化辅助函数
  393. def make_json_serializable(obj):
  394. """将对象转换为JSON可序列化的格式"""
  395. if obj is None:
  396. return None
  397. elif isinstance(obj, (str, int, float, bool)):
  398. return obj
  399. elif isinstance(obj, (list, tuple)):
  400. return [make_json_serializable(item) for item in obj]
  401. elif isinstance(obj, dict):
  402. return {str(k): make_json_serializable(v) for k, v in obj.items()}
  403. elif hasattr(obj, 'isoformat'): # datetime objects
  404. return obj.isoformat()
  405. elif hasattr(obj, 'item'): # numpy scalars
  406. return obj.item()
  407. elif hasattr(obj, 'tolist'): # numpy arrays
  408. return obj.tolist()
  409. elif hasattr(obj, '__dict__'): # pandas dtypes and other objects
  410. return str(obj)
  411. else:
  412. return str(obj)
  413. # 获取完整的原始缓存数据
  414. raw_cache = cache.cache
  415. # 获取会话和对话时间信息
  416. conversation_times = getattr(cache, 'conversation_start_times', {})
  417. session_info = getattr(cache, 'session_info', {})
  418. conversation_to_session = getattr(cache, 'conversation_to_session', {})
  419. export_data = {
  420. 'export_metadata': {
  421. 'export_time': datetime.now().isoformat(),
  422. 'total_conversations': len(raw_cache),
  423. 'total_sessions': len(session_info),
  424. 'cache_type': type(cache).__name__,
  425. 'cache_object_info': str(cache),
  426. 'has_session_times': bool(session_info),
  427. 'has_conversation_times': bool(conversation_times)
  428. },
  429. 'session_info': {
  430. session_id: {
  431. 'start_time': session_data['start_time'].isoformat(),
  432. 'last_activity': session_data.get('last_activity', session_data['start_time']).isoformat(),
  433. 'conversations': session_data['conversations'],
  434. 'conversation_count': len(session_data['conversations']),
  435. 'browser_session_id': session_data.get('browser_session_id'),
  436. 'user_info': session_data.get('user_info', {})
  437. }
  438. for session_id, session_data in session_info.items()
  439. },
  440. 'conversation_times': {
  441. conversation_id: start_time.isoformat()
  442. for conversation_id, start_time in conversation_times.items()
  443. },
  444. 'conversation_to_session_mapping': conversation_to_session,
  445. 'conversations': {}
  446. }
  447. # 处理每个对话的完整数据
  448. for conversation_id, conversation_data in raw_cache.items():
  449. # 获取时间信息
  450. conversation_start_time = conversation_times.get(conversation_id)
  451. session_id = conversation_to_session.get(conversation_id)
  452. session_start_time = None
  453. if session_id and session_id in session_info:
  454. session_start_time = session_info[session_id]['start_time']
  455. processed_conversation = {
  456. 'conversation_id': conversation_id,
  457. 'conversation_start_time': conversation_start_time.isoformat() if conversation_start_time else None,
  458. 'session_id': session_id,
  459. 'session_start_time': session_start_time.isoformat() if session_start_time else None,
  460. 'field_count': len(conversation_data),
  461. 'fields': {}
  462. }
  463. # 添加时间计算
  464. if conversation_start_time:
  465. conversation_duration = datetime.now() - conversation_start_time
  466. processed_conversation['conversation_duration_seconds'] = conversation_duration.total_seconds()
  467. processed_conversation['conversation_duration_formatted'] = str(conversation_duration)
  468. if session_start_time:
  469. session_duration = datetime.now() - session_start_time
  470. processed_conversation['session_duration_seconds'] = session_duration.total_seconds()
  471. processed_conversation['session_duration_formatted'] = str(session_duration)
  472. # 处理每个字段,确保JSON序列化安全
  473. for field_name, field_value in conversation_data.items():
  474. field_info = {
  475. 'field_name': field_name,
  476. 'data_type': type(field_value).__name__,
  477. 'is_none': field_value is None
  478. }
  479. try:
  480. if field_value is None:
  481. field_info['value'] = None
  482. elif field_name in ['conversation_start_time', 'session_start_time']:
  483. # 处理时间字段
  484. field_info['content'] = make_json_serializable(field_value)
  485. elif field_name == 'df' and field_value is not None:
  486. # DataFrame的安全处理
  487. if hasattr(field_value, 'to_dict'):
  488. # 安全地处理dtypes
  489. try:
  490. dtypes_dict = {}
  491. for col, dtype in field_value.dtypes.items():
  492. dtypes_dict[col] = str(dtype)
  493. except Exception:
  494. dtypes_dict = {"error": "无法序列化dtypes"}
  495. # 安全地处理内存使用
  496. try:
  497. memory_usage = field_value.memory_usage(deep=True)
  498. memory_dict = {}
  499. for idx, usage in memory_usage.items():
  500. memory_dict[str(idx)] = int(usage) if hasattr(usage, 'item') else int(usage)
  501. except Exception:
  502. memory_dict = {"error": "无法获取内存使用信息"}
  503. field_info.update({
  504. 'dataframe_info': {
  505. 'shape': list(field_value.shape),
  506. 'columns': list(field_value.columns),
  507. 'dtypes': dtypes_dict,
  508. 'index_info': {
  509. 'type': type(field_value.index).__name__,
  510. 'length': len(field_value.index)
  511. }
  512. },
  513. 'data': make_json_serializable(field_value.to_dict('records')),
  514. 'memory_usage': memory_dict
  515. })
  516. else:
  517. field_info['value'] = str(field_value)
  518. field_info['note'] = 'not_standard_dataframe'
  519. elif field_name == 'fig_json':
  520. # 图表JSON数据处理
  521. if isinstance(field_value, str):
  522. try:
  523. import json
  524. parsed_fig = json.loads(field_value)
  525. field_info.update({
  526. 'json_valid': True,
  527. 'json_size_bytes': len(field_value),
  528. 'plotly_structure': {
  529. 'has_data': 'data' in parsed_fig,
  530. 'has_layout': 'layout' in parsed_fig,
  531. 'data_traces_count': len(parsed_fig.get('data', [])),
  532. },
  533. 'raw_json': field_value
  534. })
  535. except json.JSONDecodeError:
  536. field_info.update({
  537. 'json_valid': False,
  538. 'raw_content': str(field_value)
  539. })
  540. else:
  541. field_info['value'] = make_json_serializable(field_value)
  542. elif field_name == 'followup_questions':
  543. # 后续问题列表
  544. field_info.update({
  545. 'content': make_json_serializable(field_value)
  546. })
  547. elif field_name in ['question', 'sql', 'summary']:
  548. # 文本字段
  549. if isinstance(field_value, str):
  550. field_info.update({
  551. 'text_length': len(field_value),
  552. 'content': field_value
  553. })
  554. else:
  555. field_info['value'] = make_json_serializable(field_value)
  556. else:
  557. # 未知字段的安全处理
  558. field_info['content'] = make_json_serializable(field_value)
  559. except Exception as e:
  560. field_info.update({
  561. 'processing_error': str(e),
  562. 'fallback_value': str(field_value)[:500] + '...' if len(str(field_value)) > 500 else str(field_value)
  563. })
  564. processed_conversation['fields'][field_name] = field_info
  565. export_data['conversations'][conversation_id] = processed_conversation
  566. # 添加缓存统计信息
  567. field_frequency = {}
  568. data_types_found = set()
  569. total_dataframes = 0
  570. total_questions = 0
  571. for conv_data in export_data['conversations'].values():
  572. for field_name, field_info in conv_data['fields'].items():
  573. field_frequency[field_name] = field_frequency.get(field_name, 0) + 1
  574. data_types_found.add(field_info['data_type'])
  575. if field_name == 'df' and not field_info['is_none']:
  576. total_dataframes += 1
  577. if field_name == 'question' and not field_info['is_none']:
  578. total_questions += 1
  579. export_data['cache_statistics'] = {
  580. 'field_frequency': field_frequency,
  581. 'data_types_found': list(data_types_found),
  582. 'total_dataframes': total_dataframes,
  583. 'total_questions': total_questions,
  584. 'has_session_timing': 'session_start_time' in field_frequency,
  585. 'has_conversation_timing': 'conversation_start_time' in field_frequency
  586. }
  587. return jsonify(result.success(data=export_data))
  588. except Exception as e:
  589. import traceback
  590. error_details = {
  591. 'error_message': str(e),
  592. 'error_type': type(e).__name__,
  593. 'traceback': traceback.format_exc()
  594. }
  595. return jsonify(result.failed(
  596. message=f"导出缓存失败: {str(e)}",
  597. code=500,
  598. data=error_details
  599. )), 500
  600. # ==================== 清理功能API ====================
  601. @app.flask_app.route('/api/v0/cache_preview_cleanup', methods=['POST'])
  602. def cache_preview_cleanup():
  603. """清理功能:预览删除操作 - 保持原功能"""
  604. try:
  605. req = request.get_json(force=True)
  606. # 时间条件 - 支持三种方式
  607. older_than_hours = req.get('older_than_hours')
  608. older_than_days = req.get('older_than_days')
  609. before_timestamp = req.get('before_timestamp') # YYYY-MM-DD HH:MM:SS 格式
  610. cache = app.cache
  611. # 计算截止时间
  612. cutoff_time = None
  613. time_condition = None
  614. if older_than_hours:
  615. cutoff_time = datetime.now() - timedelta(hours=older_than_hours)
  616. time_condition = f"older_than_hours: {older_than_hours}"
  617. elif older_than_days:
  618. cutoff_time = datetime.now() - timedelta(days=older_than_days)
  619. time_condition = f"older_than_days: {older_than_days}"
  620. elif before_timestamp:
  621. try:
  622. # 支持 YYYY-MM-DD HH:MM:SS 格式
  623. cutoff_time = datetime.strptime(before_timestamp, '%Y-%m-%d %H:%M:%S')
  624. time_condition = f"before_timestamp: {before_timestamp}"
  625. except ValueError:
  626. return jsonify(result.failed(
  627. message="before_timestamp格式错误,请使用 YYYY-MM-DD HH:MM:SS 格式",
  628. code=400
  629. )), 400
  630. else:
  631. return jsonify(result.failed(
  632. message="必须提供时间条件:older_than_hours, older_than_days 或 before_timestamp (YYYY-MM-DD HH:MM:SS)",
  633. code=400
  634. )), 400
  635. preview = {
  636. 'time_condition': time_condition,
  637. 'cutoff_time': cutoff_time.isoformat(),
  638. 'will_be_removed': {
  639. 'sessions': []
  640. },
  641. 'will_be_kept': {
  642. 'sessions_count': 0,
  643. 'conversations_count': 0
  644. },
  645. 'summary': {
  646. 'sessions_to_remove': 0,
  647. 'conversations_to_remove': 0,
  648. 'sessions_to_keep': 0,
  649. 'conversations_to_keep': 0
  650. }
  651. }
  652. # 预览按session删除
  653. sessions_to_remove_count = 0
  654. conversations_to_remove_count = 0
  655. for session_id, session_data in cache.session_info.items():
  656. session_preview = {
  657. 'session_id': session_id,
  658. 'start_time': session_data['start_time'].isoformat(),
  659. 'conversation_count': len(session_data['conversations']),
  660. 'conversations': []
  661. }
  662. # 添加conversation详情
  663. for conv_id in session_data['conversations']:
  664. if conv_id in cache.cache:
  665. conv_data = cache.cache[conv_id]
  666. session_preview['conversations'].append({
  667. 'conversation_id': conv_id,
  668. 'question': conv_data.get('question', '')[:50] + '...' if conv_data.get('question') else '',
  669. 'start_time': cache.conversation_start_times.get(conv_id, '').isoformat() if cache.conversation_start_times.get(conv_id) else ''
  670. })
  671. if session_data['start_time'] < cutoff_time:
  672. preview['will_be_removed']['sessions'].append(session_preview)
  673. sessions_to_remove_count += 1
  674. conversations_to_remove_count += len(session_data['conversations'])
  675. else:
  676. preview['will_be_kept']['sessions_count'] += 1
  677. preview['will_be_kept']['conversations_count'] += len(session_data['conversations'])
  678. # 更新摘要统计
  679. preview['summary'] = {
  680. 'sessions_to_remove': sessions_to_remove_count,
  681. 'conversations_to_remove': conversations_to_remove_count,
  682. 'sessions_to_keep': preview['will_be_kept']['sessions_count'],
  683. 'conversations_to_keep': preview['will_be_kept']['conversations_count']
  684. }
  685. return jsonify(result.success(data=preview))
  686. except Exception as e:
  687. return jsonify(result.failed(
  688. message=f"预览清理操作失败: {str(e)}",
  689. code=500
  690. )), 500
  691. @app.flask_app.route('/api/v0/cache_cleanup', methods=['POST'])
  692. def cache_cleanup():
  693. """清理功能:实际删除缓存 - 保持原功能"""
  694. try:
  695. req = request.get_json(force=True)
  696. # 时间条件 - 支持三种方式
  697. older_than_hours = req.get('older_than_hours')
  698. older_than_days = req.get('older_than_days')
  699. before_timestamp = req.get('before_timestamp') # YYYY-MM-DD HH:MM:SS 格式
  700. cache = app.cache
  701. if not hasattr(cache, 'session_info'):
  702. return jsonify(result.failed(
  703. message="缓存不支持会话功能",
  704. code=400
  705. )), 400
  706. # 计算截止时间
  707. cutoff_time = None
  708. time_condition = None
  709. if older_than_hours:
  710. cutoff_time = datetime.now() - timedelta(hours=older_than_hours)
  711. time_condition = f"older_than_hours: {older_than_hours}"
  712. elif older_than_days:
  713. cutoff_time = datetime.now() - timedelta(days=older_than_days)
  714. time_condition = f"older_than_days: {older_than_days}"
  715. elif before_timestamp:
  716. try:
  717. # 支持 YYYY-MM-DD HH:MM:SS 格式
  718. cutoff_time = datetime.strptime(before_timestamp, '%Y-%m-%d %H:%M:%S')
  719. time_condition = f"before_timestamp: {before_timestamp}"
  720. except ValueError:
  721. return jsonify(result.failed(
  722. message="before_timestamp格式错误,请使用 YYYY-MM-DD HH:MM:SS 格式",
  723. code=400
  724. )), 400
  725. else:
  726. return jsonify(result.failed(
  727. message="必须提供时间条件:older_than_hours, older_than_days 或 before_timestamp (YYYY-MM-DD HH:MM:SS)",
  728. code=400
  729. )), 400
  730. cleanup_stats = {
  731. 'time_condition': time_condition,
  732. 'cutoff_time': cutoff_time.isoformat(),
  733. 'sessions_removed': 0,
  734. 'conversations_removed': 0,
  735. 'sessions_kept': 0,
  736. 'conversations_kept': 0,
  737. 'removed_session_ids': [],
  738. 'removed_conversation_ids': []
  739. }
  740. # 按session删除
  741. sessions_to_remove = []
  742. for session_id, session_data in cache.session_info.items():
  743. if session_data['start_time'] < cutoff_time:
  744. sessions_to_remove.append(session_id)
  745. # 删除符合条件的sessions及其所有conversations
  746. for session_id in sessions_to_remove:
  747. session_data = cache.session_info[session_id]
  748. conversations_in_session = session_data['conversations'].copy()
  749. # 删除session中的所有conversations
  750. for conv_id in conversations_in_session:
  751. if conv_id in cache.cache:
  752. del cache.cache[conv_id]
  753. cleanup_stats['conversations_removed'] += 1
  754. cleanup_stats['removed_conversation_ids'].append(conv_id)
  755. # 清理conversation相关的时间记录
  756. if hasattr(cache, 'conversation_start_times') and conv_id in cache.conversation_start_times:
  757. del cache.conversation_start_times[conv_id]
  758. if hasattr(cache, 'conversation_to_session') and conv_id in cache.conversation_to_session:
  759. del cache.conversation_to_session[conv_id]
  760. # 删除session记录
  761. del cache.session_info[session_id]
  762. cleanup_stats['sessions_removed'] += 1
  763. cleanup_stats['removed_session_ids'].append(session_id)
  764. # 统计保留的sessions和conversations
  765. cleanup_stats['sessions_kept'] = len(cache.session_info)
  766. cleanup_stats['conversations_kept'] = len(cache.cache)
  767. return jsonify(result.success(data=cleanup_stats))
  768. except Exception as e:
  769. return jsonify(result.failed(
  770. message=f"清理缓存失败: {str(e)}",
  771. code=500
  772. )), 500
  773. @app.flask_app.route('/api/v0/training_error_question_sql', methods=['POST'])
  774. def training_error_question_sql():
  775. """
  776. 存储错误的question-sql对到error_sql集合中
  777. 此API将接收的错误question/sql pair写入到error_sql集合中,用于记录和分析错误的SQL查询。
  778. Args:
  779. question (str, required): 用户问题
  780. sql (str, required): 对应的错误SQL查询语句
  781. Returns:
  782. JSON: 包含训练ID和成功消息的响应
  783. """
  784. try:
  785. data = request.get_json()
  786. question = data.get('question')
  787. sql = data.get('sql')
  788. print(f"[DEBUG] 接收到错误SQL训练请求: question={question}, sql={sql}")
  789. if not question or not sql:
  790. return jsonify(result.failed(
  791. message="question和sql参数都是必需的",
  792. code=400
  793. )), 400
  794. # 使用vn实例的train_error_sql方法存储错误SQL
  795. id = vn.train_error_sql(question=question, sql=sql)
  796. print(f"[INFO] 成功存储错误SQL,ID: {id}")
  797. return jsonify(result.success(data={
  798. "id": id,
  799. "message": "错误SQL对已成功存储到error_sql集合"
  800. }))
  801. except Exception as e:
  802. print(f"[ERROR] 存储错误SQL失败: {str(e)}")
  803. return jsonify(result.failed(
  804. message=f"存储错误SQL失败: {str(e)}",
  805. code=500
  806. )), 500
  807. # 前端JavaScript示例 - 如何维持会话
  808. """
  809. // 前端需要维护一个会话ID
  810. class ChatSession {
  811. constructor() {
  812. // 从localStorage获取或创建新的会话ID
  813. this.sessionId = localStorage.getItem('chat_session_id') || this.generateSessionId();
  814. localStorage.setItem('chat_session_id', this.sessionId);
  815. }
  816. generateSessionId() {
  817. return 'session_' + Date.now() + '_' + Math.random().toString(36).substr(2, 9);
  818. }
  819. async askQuestion(question) {
  820. const response = await fetch('/api/v0/ask', {
  821. method: 'POST',
  822. headers: {
  823. 'Content-Type': 'application/json',
  824. },
  825. body: JSON.stringify({
  826. question: question,
  827. session_id: this.sessionId // 关键:传递会话ID
  828. })
  829. });
  830. return await response.json();
  831. }
  832. // 开始新会话
  833. startNewSession() {
  834. this.sessionId = this.generateSessionId();
  835. localStorage.setItem('chat_session_id', this.sessionId);
  836. }
  837. }
  838. // 使用示例
  839. const chatSession = new ChatSession();
  840. chatSession.askQuestion("各年龄段客户的流失率如何?");
  841. """
  842. print("正在启动Flask应用: http://localhost:8084")
  843. app.run(host="0.0.0.0", port=8084, debug=True)