citu_app.py 77 KB


  1. # 给dataops 对话助手返回结果
  2. from vanna.flask import VannaFlaskApp
  3. from core.vanna_llm_factory import create_vanna_instance
  4. from flask import request, jsonify
  5. import pandas as pd
  6. import common.result as result
  7. from datetime import datetime, timedelta
  8. from common.session_aware_cache import WebSessionAwareMemoryCache
  9. from app_config import API_MAX_RETURN_ROWS, ENABLE_RESULT_SUMMARY
  10. import re
  11. import chainlit as cl
  12. import json
  13. from flask import session # 添加session导入
  14. from common.redis_conversation_manager import RedisConversationManager # 添加Redis对话管理器导入
  15. from common.result import ( # 统一导入所有需要的响应函数
  16. bad_request_response, service_unavailable_response,
  17. agent_success_response, agent_error_response,
  18. internal_error_response, success_response,
  19. validation_failed_response
  20. )
  21. from app_config import ( # 添加Redis相关配置导入
  22. USER_MAX_CONVERSATIONS,
  23. CONVERSATION_CONTEXT_COUNT,
  24. DEFAULT_ANONYMOUS_USER,
  25. ENABLE_QUESTION_ANSWER_CACHE
  26. )
  27. # 设置默认的最大返回行数
  28. DEFAULT_MAX_RETURN_ROWS = 200
  29. MAX_RETURN_ROWS = API_MAX_RETURN_ROWS if API_MAX_RETURN_ROWS is not None else DEFAULT_MAX_RETURN_ROWS
  30. vn = create_vanna_instance()
  31. # 创建带时间戳的缓存
  32. timestamped_cache = WebSessionAwareMemoryCache()
  33. # 实例化 VannaFlaskApp,使用自定义缓存
  34. app = VannaFlaskApp(
  35. vn,
  36. cache=timestamped_cache, # 使用带时间戳的缓存
  37. title="辞图智能数据问答平台",
  38. logo = "https://www.citupro.com/img/logo-black-2.png",
  39. subtitle="让 AI 为你写 SQL",
  40. chart=False,
  41. allow_llm_to_see_data=True,
  42. ask_results_correct=True,
  43. followup_questions=True,
  44. debug=True
  45. )
  46. # 创建Redis对话管理器实例
  47. redis_conversation_manager = RedisConversationManager()
  48. # 修改ask接口,支持前端传递session_id
  49. @app.flask_app.route('/api/v0/ask', methods=['POST'])
  50. def ask_full():
  51. req = request.get_json(force=True)
  52. question = req.get("question", None)
  53. browser_session_id = req.get("session_id", None) # 前端传递的会话ID
  54. if not question:
  55. from common.result import bad_request_response
  56. return jsonify(bad_request_response(
  57. response_text="缺少必需参数:question",
  58. missing_params=["question"]
  59. )), 400
  60. # 如果使用WebSessionAwareMemoryCache
  61. if hasattr(app.cache, 'generate_id_with_browser_session') and browser_session_id:
  62. # 这里需要修改vanna的ask方法来支持传递session_id
  63. # 或者预先调用generate_id来建立会话关联
  64. conversation_id = app.cache.generate_id_with_browser_session(
  65. question=question,
  66. browser_session_id=browser_session_id
  67. )
  68. try:
  69. sql, df, _ = vn.ask(
  70. question=question,
  71. print_results=False,
  72. visualize=False,
  73. allow_llm_to_see_data=True
  74. )
  75. # 关键:检查是否有LLM解释性文本(无法生成SQL的情况)
  76. if sql is None and hasattr(vn, 'last_llm_explanation') and vn.last_llm_explanation:
  77. # 在解释性文本末尾添加提示语
  78. explanation_message = vn.last_llm_explanation + "请尝试提问其它问题。"
  79. # 使用标准化错误响应
  80. from common.result import validation_failed_response
  81. return jsonify(validation_failed_response(
  82. response_text=explanation_message
  83. )), 422 # 修改HTTP状态码为422
  84. # 如果sql为None但没有解释性文本,返回通用错误
  85. if sql is None:
  86. from common.result import validation_failed_response
  87. return jsonify(validation_failed_response(
  88. response_text="无法生成SQL查询,请检查问题描述或数据表结构"
  89. )), 422
  90. # 处理返回数据 - 使用新的query_result结构
  91. query_result = {
  92. "rows": [],
  93. "columns": [],
  94. "row_count": 0,
  95. "is_limited": False,
  96. "total_row_count": 0
  97. }
  98. summary = None
  99. if isinstance(df, pd.DataFrame):
  100. query_result["columns"] = list(df.columns)
  101. if not df.empty:
  102. total_rows = len(df)
  103. limited_df = df.head(MAX_RETURN_ROWS)
  104. query_result["rows"] = limited_df.to_dict(orient="records")
  105. query_result["row_count"] = len(limited_df)
  106. query_result["total_row_count"] = total_rows
  107. query_result["is_limited"] = total_rows > MAX_RETURN_ROWS
  108. # 生成数据摘要(可通过配置控制,仅在有数据时生成)
  109. if ENABLE_RESULT_SUMMARY:
  110. try:
  111. summary = vn.generate_summary(question=question, df=df)
  112. print(f"[INFO] 成功生成摘要: {summary}")
  113. except Exception as e:
  114. print(f"[WARNING] 生成摘要失败: {str(e)}")
  115. summary = None
  116. # 构建返回数据
  117. response_data = {
  118. "sql": sql,
  119. "query_result": query_result,
  120. "conversation_id": conversation_id if 'conversation_id' in locals() else None,
  121. "session_id": browser_session_id
  122. }
  123. # 添加摘要(如果启用且生成成功)
  124. if ENABLE_RESULT_SUMMARY and summary is not None:
  125. response_data["summary"] = summary
  126. response_data["response"] = summary # 同时添加response字段
  127. from common.result import success_response
  128. return jsonify(success_response(
  129. response_text="查询执行完成" if summary is None else None,
  130. data=response_data
  131. ))
  132. except Exception as e:
  133. print(f"[ERROR] ask_full执行失败: {str(e)}")
  134. # 即使发生异常,也检查是否有业务层面的解释
  135. if hasattr(vn, 'last_llm_explanation') and vn.last_llm_explanation:
  136. # 在解释性文本末尾添加提示语
  137. explanation_message = vn.last_llm_explanation + "请尝试提问其它问题。"
  138. from common.result import validation_failed_response
  139. return jsonify(validation_failed_response(
  140. response_text=explanation_message
  141. )), 422
  142. else:
  143. # 技术错误,使用500错误码
  144. from common.result import internal_error_response
  145. return jsonify(internal_error_response(
  146. response_text="查询处理失败,请稍后重试"
  147. )), 500
  148. @app.flask_app.route('/api/v0/citu_run_sql', methods=['POST'])
  149. def citu_run_sql():
  150. req = request.get_json(force=True)
  151. sql = req.get('sql')
  152. if not sql:
  153. from common.result import bad_request_response
  154. return jsonify(bad_request_response(
  155. response_text="缺少必需参数:sql",
  156. missing_params=["sql"]
  157. )), 400
  158. try:
  159. df = vn.run_sql(sql)
  160. # 处理返回数据 - 使用新的query_result结构
  161. query_result = {
  162. "rows": [],
  163. "columns": [],
  164. "row_count": 0,
  165. "is_limited": False,
  166. "total_row_count": 0
  167. }
  168. if isinstance(df, pd.DataFrame):
  169. query_result["columns"] = list(df.columns)
  170. if not df.empty:
  171. total_rows = len(df)
  172. limited_df = df.head(MAX_RETURN_ROWS)
  173. query_result["rows"] = limited_df.to_dict(orient="records")
  174. query_result["row_count"] = len(limited_df)
  175. query_result["total_row_count"] = total_rows
  176. query_result["is_limited"] = total_rows > MAX_RETURN_ROWS
  177. from common.result import success_response
  178. return jsonify(success_response(
  179. response_text=f"SQL执行完成,共返回 {query_result['total_row_count']} 条记录" +
  180. (f",已限制显示前 {MAX_RETURN_ROWS} 条" if query_result["is_limited"] else ""),
  181. data={
  182. "sql": sql,
  183. "query_result": query_result
  184. }
  185. ))
  186. except Exception as e:
  187. print(f"[ERROR] citu_run_sql执行失败: {str(e)}")
  188. from common.result import internal_error_response
  189. return jsonify(internal_error_response(
  190. response_text=f"SQL执行失败,请检查SQL语句是否正确"
  191. )), 500
  192. @app.flask_app.route('/api/v0/ask_cached', methods=['POST'])
  193. def ask_cached():
  194. """
  195. 带缓存功能的智能查询接口
  196. 支持会话管理和结果缓存,提高查询效率
  197. """
  198. req = request.get_json(force=True)
  199. question = req.get("question", None)
  200. browser_session_id = req.get("session_id", None)
  201. if not question:
  202. from common.result import bad_request_response
  203. return jsonify(bad_request_response(
  204. response_text="缺少必需参数:question",
  205. missing_params=["question"]
  206. )), 400
  207. try:
  208. # 生成conversation_id
  209. # 调试:查看generate_id的实际行为
  210. print(f"[DEBUG] 输入问题: '{question}'")
  211. conversation_id = app.cache.generate_id(question=question)
  212. print(f"[DEBUG] 生成的conversation_id: {conversation_id}")
  213. # 再次用相同问题测试
  214. conversation_id2 = app.cache.generate_id(question=question)
  215. print(f"[DEBUG] 再次生成的conversation_id: {conversation_id2}")
  216. print(f"[DEBUG] 两次ID是否相同: {conversation_id == conversation_id2}")
  217. # 检查缓存
  218. cached_sql = app.cache.get(id=conversation_id, field="sql")
  219. if cached_sql is not None:
  220. # 缓存命中
  221. print(f"[CACHE HIT] 使用缓存结果: {conversation_id}")
  222. sql = cached_sql
  223. df = app.cache.get(id=conversation_id, field="df")
  224. summary = app.cache.get(id=conversation_id, field="summary")
  225. else:
  226. # 缓存未命中,执行新查询
  227. print(f"[CACHE MISS] 执行新查询: {conversation_id}")
  228. sql, df, _ = vn.ask(
  229. question=question,
  230. print_results=False,
  231. visualize=False,
  232. allow_llm_to_see_data=True
  233. )
  234. # 检查是否有LLM解释性文本(无法生成SQL的情况)
  235. if sql is None and hasattr(vn, 'last_llm_explanation') and vn.last_llm_explanation:
  236. # 在解释性文本末尾添加提示语
  237. explanation_message = vn.last_llm_explanation + "请尝试用其它方式提问。"
  238. from common.result import validation_failed_response
  239. return jsonify(validation_failed_response(
  240. response_text=explanation_message
  241. )), 422
  242. # 如果sql为None但没有解释性文本,返回通用错误
  243. if sql is None:
  244. from common.result import validation_failed_response
  245. return jsonify(validation_failed_response(
  246. response_text="无法生成SQL查询,请检查问题描述或数据表结构"
  247. )), 422
  248. # 缓存结果
  249. app.cache.set(id=conversation_id, field="question", value=question)
  250. app.cache.set(id=conversation_id, field="sql", value=sql)
  251. app.cache.set(id=conversation_id, field="df", value=df)
  252. # 生成并缓存摘要(可通过配置控制,仅在有数据时生成)
  253. summary = None
  254. if ENABLE_RESULT_SUMMARY and isinstance(df, pd.DataFrame) and not df.empty:
  255. try:
  256. summary = vn.generate_summary(question=question, df=df)
  257. print(f"[INFO] 成功生成摘要: {summary}")
  258. except Exception as e:
  259. print(f"[WARNING] 生成摘要失败: {str(e)}")
  260. summary = None
  261. app.cache.set(id=conversation_id, field="summary", value=summary)
  262. # 处理返回数据 - 使用新的query_result结构
  263. query_result = {
  264. "rows": [],
  265. "columns": [],
  266. "row_count": 0,
  267. "is_limited": False,
  268. "total_row_count": 0
  269. }
  270. if isinstance(df, pd.DataFrame):
  271. query_result["columns"] = list(df.columns)
  272. if not df.empty:
  273. total_rows = len(df)
  274. limited_df = df.head(MAX_RETURN_ROWS)
  275. query_result["rows"] = limited_df.to_dict(orient="records")
  276. query_result["row_count"] = len(limited_df)
  277. query_result["total_row_count"] = total_rows
  278. query_result["is_limited"] = total_rows > MAX_RETURN_ROWS
  279. # 构建返回数据
  280. response_data = {
  281. "sql": sql,
  282. "query_result": query_result,
  283. "conversation_id": conversation_id,
  284. "session_id": browser_session_id,
  285. "cached": cached_sql is not None # 标识是否来自缓存
  286. }
  287. # 添加摘要(如果启用且生成成功)
  288. if ENABLE_RESULT_SUMMARY and summary is not None:
  289. response_data["summary"] = summary
  290. response_data["response"] = summary # 同时添加response字段
  291. from common.result import success_response
  292. return jsonify(success_response(
  293. response_text="查询执行完成" if summary is None else None,
  294. data=response_data
  295. ))
  296. except Exception as e:
  297. print(f"[ERROR] ask_cached执行失败: {str(e)}")
  298. from common.result import internal_error_response
  299. return jsonify(internal_error_response(
  300. response_text="查询处理失败,请稍后重试"
  301. )), 500
  302. @app.flask_app.route('/api/v0/citu_train_question_sql', methods=['POST'])
  303. def citu_train_question_sql():
  304. """
  305. 训练问题-SQL对接口
  306. 此API将接收的question/sql pair写入到training库中,用于训练和改进AI模型。
  307. 支持仅传入SQL或同时传入问题和SQL进行训练。
  308. Args:
  309. question (str, optional): 用户问题
  310. sql (str, required): 对应的SQL查询语句
  311. Returns:
  312. JSON: 包含训练ID和成功消息的响应
  313. """
  314. try:
  315. req = request.get_json(force=True)
  316. question = req.get('question')
  317. sql = req.get('sql')
  318. if not sql:
  319. from common.result import bad_request_response
  320. return jsonify(bad_request_response(
  321. response_text="缺少必需参数:sql",
  322. missing_params=["sql"]
  323. )), 400
  324. # 正确的调用方式:同时传递question和sql
  325. if question:
  326. training_id = vn.train(question=question, sql=sql)
  327. print(f"训练成功,训练ID为:{training_id},问题:{question},SQL:{sql}")
  328. else:
  329. training_id = vn.train(sql=sql)
  330. print(f"训练成功,训练ID为:{training_id},SQL:{sql}")
  331. from common.result import success_response
  332. return jsonify(success_response(
  333. response_text="问题-SQL对训练成功",
  334. data={
  335. "training_id": training_id,
  336. "message": "Question-SQL pair trained successfully"
  337. }
  338. ))
  339. except Exception as e:
  340. from common.result import internal_error_response
  341. return jsonify(internal_error_response(
  342. response_text="训练失败,请稍后重试"
  343. )), 500
  344. # ============ LangGraph Agent 集成 ============
  345. # 全局Agent实例(单例模式)
  346. citu_langraph_agent = None
  347. def get_citu_langraph_agent():
  348. """获取LangGraph Agent实例(懒加载)"""
  349. global citu_langraph_agent
  350. if citu_langraph_agent is None:
  351. try:
  352. from agent.citu_agent import CituLangGraphAgent
  353. print("[CITU_APP] 开始创建LangGraph Agent实例...")
  354. citu_langraph_agent = CituLangGraphAgent()
  355. print("[CITU_APP] LangGraph Agent实例创建成功")
  356. except ImportError as e:
  357. print(f"[CRITICAL] Agent模块导入失败: {str(e)}")
  358. print("[CRITICAL] 请检查agent模块是否存在以及依赖是否正确安装")
  359. raise Exception(f"Agent模块导入失败: {str(e)}")
  360. except Exception as e:
  361. print(f"[CRITICAL] LangGraph Agent实例创建失败: {str(e)}")
  362. print(f"[CRITICAL] 错误类型: {type(e).__name__}")
  363. # 提供更有用的错误信息
  364. if "config" in str(e).lower():
  365. print("[CRITICAL] 可能是配置文件问题,请检查配置")
  366. elif "llm" in str(e).lower():
  367. print("[CRITICAL] 可能是LLM连接问题,请检查LLM配置")
  368. elif "tool" in str(e).lower():
  369. print("[CRITICAL] 可能是工具加载问题,请检查工具模块")
  370. raise Exception(f"Agent初始化失败: {str(e)}")
  371. return citu_langraph_agent
  372. @app.flask_app.route('/api/v0/ask_agent', methods=['POST'])
  373. def ask_agent():
  374. """
  375. 支持对话上下文的ask_agent API - 修正版
  376. """
  377. req = request.get_json(force=True)
  378. question = req.get("question", None)
  379. browser_session_id = req.get("session_id", None)
  380. # 新增参数解析
  381. user_id_input = req.get("user_id", None)
  382. conversation_id_input = req.get("conversation_id", None)
  383. continue_conversation = req.get("continue_conversation", False)
  384. # 新增:路由模式参数解析和验证
  385. api_routing_mode = req.get("routing_mode", None)
  386. VALID_ROUTING_MODES = ["database_direct", "chat_direct", "hybrid", "llm_only"]
  387. if not question:
  388. return jsonify(bad_request_response(
  389. response_text="缺少必需参数:question",
  390. missing_params=["question"]
  391. )), 400
  392. # 验证routing_mode参数
  393. if api_routing_mode and api_routing_mode not in VALID_ROUTING_MODES:
  394. return jsonify(bad_request_response(
  395. response_text=f"无效的routing_mode参数值: {api_routing_mode},支持的值: {VALID_ROUTING_MODES}",
  396. invalid_params=["routing_mode"]
  397. )), 400
  398. try:
  399. # 1. 获取登录用户ID(修正:在函数中获取session信息)
  400. login_user_id = session.get('user_id') if 'user_id' in session else None
  401. # 2. 智能ID解析(修正:传入登录用户ID)
  402. user_id = redis_conversation_manager.resolve_user_id(
  403. user_id_input, browser_session_id, request.remote_addr, login_user_id
  404. )
  405. conversation_id, conversation_status = redis_conversation_manager.resolve_conversation_id(
  406. user_id, conversation_id_input, continue_conversation
  407. )
  408. # 3. 获取上下文和上下文类型(提前到缓存检查之前)
  409. context = redis_conversation_manager.get_context(conversation_id)
  410. # 获取上下文类型:从最后一条助手消息的metadata中获取类型
  411. context_type = None
  412. if context:
  413. try:
  414. # 获取最后一条助手消息的metadata
  415. messages = redis_conversation_manager.get_messages(conversation_id, limit=10)
  416. for message in reversed(messages): # 从最新的开始找
  417. if message.get("role") == "assistant":
  418. metadata = message.get("metadata", {})
  419. context_type = metadata.get("type")
  420. if context_type:
  421. print(f"[AGENT_API] 检测到上下文类型: {context_type}")
  422. break
  423. except Exception as e:
  424. print(f"[WARNING] 获取上下文类型失败: {str(e)}")
  425. # 4. 检查缓存(新逻辑:放宽使用条件,严控存储条件)
  426. cached_answer = redis_conversation_manager.get_cached_answer(question, context)
  427. if cached_answer:
  428. print(f"[AGENT_API] 使用缓存答案")
  429. # 确定缓存答案的助手回复内容(使用与非缓存相同的优先级逻辑)
  430. cached_response_type = cached_answer.get("type", "UNKNOWN")
  431. if cached_response_type == "DATABASE":
  432. # DATABASE类型:按优先级选择内容
  433. if cached_answer.get("response"):
  434. # 优先级1:错误或解释性回复(如SQL生成失败)
  435. assistant_response = cached_answer.get("response")
  436. elif cached_answer.get("summary"):
  437. # 优先级2:查询成功的摘要
  438. assistant_response = cached_answer.get("summary")
  439. elif cached_answer.get("query_result"):
  440. # 优先级3:构造简单描述
  441. query_result = cached_answer.get("query_result")
  442. row_count = query_result.get("row_count", 0)
  443. assistant_response = f"查询执行完成,共返回 {row_count} 条记录。"
  444. else:
  445. # 异常情况
  446. assistant_response = "数据库查询已处理。"
  447. else:
  448. # CHAT类型:直接使用response
  449. assistant_response = cached_answer.get("response", "")
  450. # 更新对话历史
  451. redis_conversation_manager.save_message(conversation_id, "user", question)
  452. redis_conversation_manager.save_message(
  453. conversation_id, "assistant",
  454. assistant_response,
  455. metadata={"from_cache": True}
  456. )
  457. # 添加对话信息到缓存结果
  458. cached_answer["conversation_id"] = conversation_id
  459. cached_answer["user_id"] = user_id
  460. cached_answer["from_cache"] = True
  461. cached_answer.update(conversation_status)
  462. # 使用agent_success_response返回标准格式
  463. return jsonify(agent_success_response(
  464. response_type=cached_answer.get("type", "UNKNOWN"),
  465. response=cached_answer.get("response", ""), # 修正:使用response而不是response_text
  466. sql=cached_answer.get("sql"),
  467. query_result=cached_answer.get("query_result"),
  468. summary=cached_answer.get("summary"),
  469. session_id=browser_session_id,
  470. execution_path=cached_answer.get("execution_path", []),
  471. classification_info=cached_answer.get("classification_info", {}),
  472. conversation_id=conversation_id,
  473. user_id=user_id,
  474. is_guest_user=(user_id == DEFAULT_ANONYMOUS_USER),
  475. context_used=bool(context),
  476. from_cache=True,
  477. conversation_status=conversation_status["status"],
  478. conversation_message=conversation_status["message"],
  479. requested_conversation_id=conversation_status.get("requested_id")
  480. ))
  481. # 5. 保存用户消息
  482. redis_conversation_manager.save_message(conversation_id, "user", question)
  483. # 6. 构建带上下文的问题
  484. if context:
  485. enhanced_question = f"\n[CONTEXT]\n{context}\n\n[CURRENT]\n{question}"
  486. print(f"[AGENT_API] 使用上下文,长度: {len(context)}字符")
  487. else:
  488. enhanced_question = question
  489. print(f"[AGENT_API] 新对话,无上下文")
  490. # 7. 确定最终使用的路由模式(优先级逻辑)
  491. if api_routing_mode:
  492. # API传了参数,优先使用
  493. effective_routing_mode = api_routing_mode
  494. print(f"[AGENT_API] 使用API指定的路由模式: {effective_routing_mode}")
  495. else:
  496. # API没传参数,使用配置文件
  497. try:
  498. from app_config import QUESTION_ROUTING_MODE
  499. effective_routing_mode = QUESTION_ROUTING_MODE
  500. print(f"[AGENT_API] 使用配置文件路由模式: {effective_routing_mode}")
  501. except ImportError:
  502. effective_routing_mode = "hybrid"
  503. print(f"[AGENT_API] 配置文件读取失败,使用默认路由模式: {effective_routing_mode}")
  504. # 8. 现有Agent处理逻辑(修改为传递路由模式)
  505. try:
  506. agent = get_citu_langraph_agent()
  507. except Exception as e:
  508. print(f"[CRITICAL] Agent初始化失败: {str(e)}")
  509. return jsonify(service_unavailable_response(
  510. response_text="AI服务暂时不可用,请稍后重试",
  511. can_retry=True
  512. )), 503
  513. agent_result = agent.process_question(
  514. question=enhanced_question, # 使用增强后的问题
  515. session_id=browser_session_id,
  516. context_type=context_type, # 传递上下文类型
  517. routing_mode=effective_routing_mode # 新增:传递路由模式
  518. )
  519. # 8. 处理Agent结果
  520. if agent_result.get("success", False):
  521. # 修正:直接从agent_result获取字段,因为它就是final_response
  522. response_type = agent_result.get("type", "UNKNOWN")
  523. response_text = agent_result.get("response", "")
  524. sql = agent_result.get("sql")
  525. query_result = agent_result.get("query_result")
  526. summary = agent_result.get("summary")
  527. execution_path = agent_result.get("execution_path", [])
  528. classification_info = agent_result.get("classification_info", {})
  529. # 确定助手回复内容的优先级
  530. if response_type == "DATABASE":
  531. # DATABASE类型:按优先级选择内容
  532. if response_text:
  533. # 优先级1:错误或解释性回复(如SQL生成失败)
  534. assistant_response = response_text
  535. elif summary:
  536. # 优先级2:查询成功的摘要
  537. assistant_response = summary
  538. elif query_result:
  539. # 优先级3:构造简单描述
  540. row_count = query_result.get("row_count", 0)
  541. assistant_response = f"查询执行完成,共返回 {row_count} 条记录。"
  542. else:
  543. # 异常情况
  544. assistant_response = "数据库查询已处理。"
  545. else:
  546. # CHAT类型:直接使用response
  547. assistant_response = response_text
  548. # 保存助手回复
  549. redis_conversation_manager.save_message(
  550. conversation_id, "assistant", assistant_response,
  551. metadata={
  552. "type": response_type,
  553. "sql": sql,
  554. "execution_path": execution_path
  555. }
  556. )
  557. # 缓存成功的答案(新逻辑:只缓存无上下文的问答)
  558. # 直接缓存agent_result,它已经包含所有需要的字段
  559. redis_conversation_manager.cache_answer(question, agent_result, context)
  560. # 使用agent_success_response的正确方式
  561. return jsonify(agent_success_response(
  562. response_type=response_type,
  563. response=response_text, # 修正:使用response而不是response_text
  564. sql=sql,
  565. query_result=query_result,
  566. summary=summary,
  567. session_id=browser_session_id,
  568. execution_path=execution_path,
  569. classification_info=classification_info,
  570. conversation_id=conversation_id,
  571. user_id=user_id,
  572. is_guest_user=(user_id == DEFAULT_ANONYMOUS_USER),
  573. context_used=bool(context),
  574. from_cache=False,
  575. conversation_status=conversation_status["status"],
  576. conversation_message=conversation_status["message"],
  577. requested_conversation_id=conversation_status.get("requested_id"),
  578. routing_mode_used=effective_routing_mode, # 新增:实际使用的路由模式
  579. routing_mode_source="api" if api_routing_mode else "config" # 新增:路由模式来源
  580. ))
  581. else:
  582. # 错误处理(修正:确保使用现有的错误响应格式)
  583. error_message = agent_result.get("error", "Agent处理失败")
  584. error_code = agent_result.get("error_code", 500)
  585. return jsonify(agent_error_response(
  586. response_text=error_message,
  587. error_type="agent_processing_failed",
  588. code=error_code,
  589. session_id=browser_session_id,
  590. conversation_id=conversation_id,
  591. user_id=user_id
  592. )), error_code
  593. except Exception as e:
  594. print(f"[ERROR] ask_agent执行失败: {str(e)}")
  595. return jsonify(internal_error_response(
  596. response_text="查询处理失败,请稍后重试"
  597. )), 500
  598. @app.flask_app.route('/api/v0/agent_health', methods=['GET'])
  599. def agent_health():
  600. """
  601. Agent健康检查接口
  602. 响应格式:
  603. {
  604. "success": true/false,
  605. "code": 200/503,
  606. "message": "healthy/degraded/unhealthy",
  607. "data": {
  608. "status": "healthy/degraded/unhealthy",
  609. "test_result": true/false,
  610. "workflow_compiled": true/false,
  611. "tools_count": 4,
  612. "message": "详细信息",
  613. "timestamp": "2024-01-01T12:00:00",
  614. "checks": {
  615. "agent_creation": true/false,
  616. "tools_import": true/false,
  617. "llm_connection": true/false,
  618. "classifier_ready": true/false
  619. }
  620. }
  621. }
  622. """
  623. try:
  624. # 基础健康检查
  625. health_data = {
  626. "status": "unknown",
  627. "test_result": False,
  628. "workflow_compiled": False,
  629. "tools_count": 0,
  630. "message": "",
  631. "timestamp": datetime.now().isoformat(),
  632. "checks": {
  633. "agent_creation": False,
  634. "tools_import": False,
  635. "llm_connection": False,
  636. "classifier_ready": False
  637. }
  638. }
  639. # 检查1: Agent创建
  640. try:
  641. agent = get_citu_langraph_agent()
  642. health_data["checks"]["agent_creation"] = True
  643. health_data["workflow_compiled"] = agent.workflow is not None
  644. health_data["tools_count"] = len(agent.tools) if hasattr(agent, 'tools') else 0
  645. except Exception as e:
  646. health_data["message"] = f"Agent创建失败: {str(e)}"
  647. from common.result import health_error_response
  648. return jsonify(health_error_response(
  649. status="unhealthy",
  650. **health_data
  651. )), 503
  652. # 检查2: 工具导入
  653. try:
  654. from agent.tools import TOOLS
  655. health_data["checks"]["tools_import"] = len(TOOLS) > 0
  656. except Exception as e:
  657. health_data["message"] = f"工具导入失败: {str(e)}"
  658. # 检查3: LLM连接(简单测试)
  659. try:
  660. from agent.utils import get_compatible_llm
  661. llm = get_compatible_llm()
  662. health_data["checks"]["llm_connection"] = llm is not None
  663. except Exception as e:
  664. health_data["message"] = f"LLM连接失败: {str(e)}"
  665. # 检查4: 分类器准备
  666. try:
  667. from agent.classifier import QuestionClassifier
  668. classifier = QuestionClassifier()
  669. health_data["checks"]["classifier_ready"] = True
  670. except Exception as e:
  671. health_data["message"] = f"分类器失败: {str(e)}"
  672. # 检查5: 完整流程测试(可选)
  673. try:
  674. if all(health_data["checks"].values()):
  675. test_result = agent.health_check()
  676. health_data["test_result"] = test_result.get("status") == "healthy"
  677. health_data["status"] = test_result.get("status", "unknown")
  678. health_data["message"] = test_result.get("message", "健康检查完成")
  679. else:
  680. health_data["status"] = "degraded"
  681. health_data["message"] = "部分组件异常"
  682. except Exception as e:
  683. health_data["status"] = "degraded"
  684. health_data["message"] = f"完整测试失败: {str(e)}"
  685. # 根据状态返回相应的HTTP代码 - 使用标准化健康检查响应
  686. from common.result import health_success_response, health_error_response
  687. if health_data["status"] == "healthy":
  688. return jsonify(health_success_response(**health_data))
  689. elif health_data["status"] == "degraded":
  690. return jsonify(health_error_response(status="degraded", **health_data)), 503
  691. else:
  692. return jsonify(health_error_response(status="unhealthy", **health_data)), 503
  693. except Exception as e:
  694. print(f"[ERROR] 健康检查异常: {str(e)}")
  695. from common.result import internal_error_response
  696. return jsonify(internal_error_response(
  697. response_text="健康检查失败,请稍后重试"
  698. )), 500
  699. # ==================== 日常管理API ====================
  700. @app.flask_app.route('/api/v0/cache_overview', methods=['GET'])
  701. def cache_overview():
  702. """日常管理:轻量概览 - 合并原cache_inspect的核心功能"""
  703. try:
  704. cache = app.cache
  705. result_data = {
  706. 'overview_summary': {
  707. 'total_conversations': 0,
  708. 'total_sessions': 0,
  709. 'query_time': datetime.now().isoformat()
  710. },
  711. 'recent_conversations': [], # 最近的对话
  712. 'session_summary': [] # 会话摘要
  713. }
  714. if hasattr(cache, 'cache') and isinstance(cache.cache, dict):
  715. result_data['overview_summary']['total_conversations'] = len(cache.cache)
  716. # 获取会话信息
  717. if hasattr(cache, 'get_all_sessions'):
  718. all_sessions = cache.get_all_sessions()
  719. result_data['overview_summary']['total_sessions'] = len(all_sessions)
  720. # 会话摘要(按最近活动排序)
  721. session_list = []
  722. for session_id, session_data in all_sessions.items():
  723. session_summary = {
  724. 'session_id': session_id,
  725. 'start_time': session_data['start_time'].isoformat(),
  726. 'conversation_count': session_data.get('conversation_count', 0),
  727. 'duration_seconds': session_data.get('session_duration_seconds', 0),
  728. 'last_activity': session_data.get('last_activity', session_data['start_time']).isoformat(),
  729. 'is_active': (datetime.now() - session_data.get('last_activity', session_data['start_time'])).total_seconds() < 1800 # 30分钟内活跃
  730. }
  731. session_list.append(session_summary)
  732. # 按最后活动时间排序
  733. session_list.sort(key=lambda x: x['last_activity'], reverse=True)
  734. result_data['session_summary'] = session_list
  735. # 最近的对话(最多显示10个)
  736. conversation_list = []
  737. for conversation_id, conversation_data in cache.cache.items():
  738. conversation_start_time = cache.conversation_start_times.get(conversation_id)
  739. conversation_info = {
  740. 'conversation_id': conversation_id,
  741. 'conversation_start_time': conversation_start_time.isoformat() if conversation_start_time else None,
  742. 'session_id': cache.conversation_to_session.get(conversation_id),
  743. 'has_question': 'question' in conversation_data,
  744. 'has_sql': 'sql' in conversation_data,
  745. 'has_data': 'df' in conversation_data and conversation_data['df'] is not None,
  746. 'question_preview': conversation_data.get('question', '')[:80] + '...' if len(conversation_data.get('question', '')) > 80 else conversation_data.get('question', ''),
  747. }
  748. # 计算对话持续时间
  749. if conversation_start_time:
  750. duration = datetime.now() - conversation_start_time
  751. conversation_info['conversation_duration_seconds'] = duration.total_seconds()
  752. conversation_list.append(conversation_info)
  753. # 按对话开始时间排序,显示最新的10个
  754. conversation_list.sort(key=lambda x: x['conversation_start_time'] or '', reverse=True)
  755. result_data['recent_conversations'] = conversation_list[:10]
  756. from common.result import success_response
  757. return jsonify(success_response(
  758. response_text="缓存概览查询完成",
  759. data=result_data
  760. ))
  761. except Exception as e:
  762. from common.result import internal_error_response
  763. return jsonify(internal_error_response(
  764. response_text="获取缓存概览失败,请稍后重试"
  765. )), 500
  766. @app.flask_app.route('/api/v0/cache_stats', methods=['GET'])
  767. def cache_stats():
  768. """日常管理:统计信息 - 合并原session_stats和cache_stats功能"""
  769. try:
  770. cache = app.cache
  771. current_time = datetime.now()
  772. stats = {
  773. 'basic_stats': {
  774. 'total_sessions': len(getattr(cache, 'session_info', {})),
  775. 'total_conversations': len(getattr(cache, 'cache', {})),
  776. 'active_sessions': 0, # 最近30分钟有活动
  777. 'average_conversations_per_session': 0
  778. },
  779. 'time_distribution': {
  780. 'sessions': {
  781. 'last_1_hour': 0,
  782. 'last_6_hours': 0,
  783. 'last_24_hours': 0,
  784. 'last_7_days': 0,
  785. 'older': 0
  786. },
  787. 'conversations': {
  788. 'last_1_hour': 0,
  789. 'last_6_hours': 0,
  790. 'last_24_hours': 0,
  791. 'last_7_days': 0,
  792. 'older': 0
  793. }
  794. },
  795. 'session_details': [],
  796. 'time_ranges': {
  797. 'oldest_session': None,
  798. 'newest_session': None,
  799. 'oldest_conversation': None,
  800. 'newest_conversation': None
  801. }
  802. }
  803. # 会话统计
  804. if hasattr(cache, 'session_info'):
  805. session_times = []
  806. total_conversations = 0
  807. for session_id, session_data in cache.session_info.items():
  808. start_time = session_data['start_time']
  809. session_times.append(start_time)
  810. conversation_count = len(session_data.get('conversations', []))
  811. total_conversations += conversation_count
  812. # 检查活跃状态
  813. last_activity = session_data.get('last_activity', session_data['start_time'])
  814. if (current_time - last_activity).total_seconds() < 1800:
  815. stats['basic_stats']['active_sessions'] += 1
  816. # 时间分布统计
  817. age_hours = (current_time - start_time).total_seconds() / 3600
  818. if age_hours <= 1:
  819. stats['time_distribution']['sessions']['last_1_hour'] += 1
  820. elif age_hours <= 6:
  821. stats['time_distribution']['sessions']['last_6_hours'] += 1
  822. elif age_hours <= 24:
  823. stats['time_distribution']['sessions']['last_24_hours'] += 1
  824. elif age_hours <= 168: # 7 days
  825. stats['time_distribution']['sessions']['last_7_days'] += 1
  826. else:
  827. stats['time_distribution']['sessions']['older'] += 1
  828. # 会话详细信息
  829. session_duration = current_time - start_time
  830. stats['session_details'].append({
  831. 'session_id': session_id,
  832. 'start_time': start_time.isoformat(),
  833. 'last_activity': last_activity.isoformat(),
  834. 'conversation_count': conversation_count,
  835. 'duration_seconds': session_duration.total_seconds(),
  836. 'duration_formatted': str(session_duration),
  837. 'is_active': (current_time - last_activity).total_seconds() < 1800,
  838. 'browser_session_id': session_data.get('browser_session_id')
  839. })
  840. # 计算平均值
  841. if len(cache.session_info) > 0:
  842. stats['basic_stats']['average_conversations_per_session'] = total_conversations / len(cache.session_info)
  843. # 时间范围
  844. if session_times:
  845. stats['time_ranges']['oldest_session'] = min(session_times).isoformat()
  846. stats['time_ranges']['newest_session'] = max(session_times).isoformat()
  847. # 对话统计
  848. if hasattr(cache, 'conversation_start_times'):
  849. conversation_times = []
  850. for conv_time in cache.conversation_start_times.values():
  851. conversation_times.append(conv_time)
  852. age_hours = (current_time - conv_time).total_seconds() / 3600
  853. if age_hours <= 1:
  854. stats['time_distribution']['conversations']['last_1_hour'] += 1
  855. elif age_hours <= 6:
  856. stats['time_distribution']['conversations']['last_6_hours'] += 1
  857. elif age_hours <= 24:
  858. stats['time_distribution']['conversations']['last_24_hours'] += 1
  859. elif age_hours <= 168:
  860. stats['time_distribution']['conversations']['last_7_days'] += 1
  861. else:
  862. stats['time_distribution']['conversations']['older'] += 1
  863. if conversation_times:
  864. stats['time_ranges']['oldest_conversation'] = min(conversation_times).isoformat()
  865. stats['time_ranges']['newest_conversation'] = max(conversation_times).isoformat()
  866. # 按最近活动排序会话详情
  867. stats['session_details'].sort(key=lambda x: x['last_activity'], reverse=True)
  868. from common.result import success_response
  869. return jsonify(success_response(
  870. response_text="缓存统计信息查询完成",
  871. data=stats
  872. ))
  873. except Exception as e:
  874. from common.result import internal_error_response
  875. return jsonify(internal_error_response(
  876. response_text="获取缓存统计失败,请稍后重试"
  877. )), 500
  878. # ==================== 高级功能API ====================
  879. @app.flask_app.route('/api/v0/cache_export', methods=['GET'])
  880. def cache_export():
  881. """高级功能:完整导出 - 保持原cache_raw_export的完整功能"""
  882. try:
  883. cache = app.cache
  884. # 验证缓存的实际结构
  885. if not hasattr(cache, 'cache'):
  886. from common.result import internal_error_response
  887. return jsonify(internal_error_response(
  888. response_text="缓存对象结构异常,请联系系统管理员"
  889. )), 500
  890. if not isinstance(cache.cache, dict):
  891. from common.result import internal_error_response
  892. return jsonify(internal_error_response(
  893. response_text="缓存数据类型异常,请联系系统管理员"
  894. )), 500
  895. # 定义JSON序列化辅助函数
  896. def make_json_serializable(obj):
  897. """将对象转换为JSON可序列化的格式"""
  898. if obj is None:
  899. return None
  900. elif isinstance(obj, (str, int, float, bool)):
  901. return obj
  902. elif isinstance(obj, (list, tuple)):
  903. return [make_json_serializable(item) for item in obj]
  904. elif isinstance(obj, dict):
  905. return {str(k): make_json_serializable(v) for k, v in obj.items()}
  906. elif hasattr(obj, 'isoformat'): # datetime objects
  907. return obj.isoformat()
  908. elif hasattr(obj, 'item'): # numpy scalars
  909. return obj.item()
  910. elif hasattr(obj, 'tolist'): # numpy arrays
  911. return obj.tolist()
  912. elif hasattr(obj, '__dict__'): # pandas dtypes and other objects
  913. return str(obj)
  914. else:
  915. return str(obj)
  916. # 获取完整的原始缓存数据
  917. raw_cache = cache.cache
  918. # 获取会话和对话时间信息
  919. conversation_times = getattr(cache, 'conversation_start_times', {})
  920. session_info = getattr(cache, 'session_info', {})
  921. conversation_to_session = getattr(cache, 'conversation_to_session', {})
  922. export_data = {
  923. 'export_metadata': {
  924. 'export_time': datetime.now().isoformat(),
  925. 'total_conversations': len(raw_cache),
  926. 'total_sessions': len(session_info),
  927. 'cache_type': type(cache).__name__,
  928. 'cache_object_info': str(cache),
  929. 'has_session_times': bool(session_info),
  930. 'has_conversation_times': bool(conversation_times)
  931. },
  932. 'session_info': {
  933. session_id: {
  934. 'start_time': session_data['start_time'].isoformat(),
  935. 'last_activity': session_data.get('last_activity', session_data['start_time']).isoformat(),
  936. 'conversations': session_data['conversations'],
  937. 'conversation_count': len(session_data['conversations']),
  938. 'browser_session_id': session_data.get('browser_session_id'),
  939. 'user_info': session_data.get('user_info', {})
  940. }
  941. for session_id, session_data in session_info.items()
  942. },
  943. 'conversation_times': {
  944. conversation_id: start_time.isoformat()
  945. for conversation_id, start_time in conversation_times.items()
  946. },
  947. 'conversation_to_session_mapping': conversation_to_session,
  948. 'conversations': {}
  949. }
  950. # 处理每个对话的完整数据
  951. for conversation_id, conversation_data in raw_cache.items():
  952. # 获取时间信息
  953. conversation_start_time = conversation_times.get(conversation_id)
  954. session_id = conversation_to_session.get(conversation_id)
  955. session_start_time = None
  956. if session_id and session_id in session_info:
  957. session_start_time = session_info[session_id]['start_time']
  958. processed_conversation = {
  959. 'conversation_id': conversation_id,
  960. 'conversation_start_time': conversation_start_time.isoformat() if conversation_start_time else None,
  961. 'session_id': session_id,
  962. 'session_start_time': session_start_time.isoformat() if session_start_time else None,
  963. 'field_count': len(conversation_data),
  964. 'fields': {}
  965. }
  966. # 添加时间计算
  967. if conversation_start_time:
  968. conversation_duration = datetime.now() - conversation_start_time
  969. processed_conversation['conversation_duration_seconds'] = conversation_duration.total_seconds()
  970. processed_conversation['conversation_duration_formatted'] = str(conversation_duration)
  971. if session_start_time:
  972. session_duration = datetime.now() - session_start_time
  973. processed_conversation['session_duration_seconds'] = session_duration.total_seconds()
  974. processed_conversation['session_duration_formatted'] = str(session_duration)
  975. # 处理每个字段,确保JSON序列化安全
  976. for field_name, field_value in conversation_data.items():
  977. field_info = {
  978. 'field_name': field_name,
  979. 'data_type': type(field_value).__name__,
  980. 'is_none': field_value is None
  981. }
  982. try:
  983. if field_value is None:
  984. field_info['value'] = None
  985. elif field_name in ['conversation_start_time', 'session_start_time']:
  986. # 处理时间字段
  987. field_info['content'] = make_json_serializable(field_value)
  988. elif field_name == 'df' and field_value is not None:
  989. # DataFrame的安全处理
  990. if hasattr(field_value, 'to_dict'):
  991. # 安全地处理dtypes
  992. try:
  993. dtypes_dict = {}
  994. for col, dtype in field_value.dtypes.items():
  995. dtypes_dict[col] = str(dtype)
  996. except Exception:
  997. dtypes_dict = {"error": "无法序列化dtypes"}
  998. # 安全地处理内存使用
  999. try:
  1000. memory_usage = field_value.memory_usage(deep=True)
  1001. memory_dict = {}
  1002. for idx, usage in memory_usage.items():
  1003. memory_dict[str(idx)] = int(usage) if hasattr(usage, 'item') else int(usage)
  1004. except Exception:
  1005. memory_dict = {"error": "无法获取内存使用信息"}
  1006. field_info.update({
  1007. 'dataframe_info': {
  1008. 'shape': list(field_value.shape),
  1009. 'columns': list(field_value.columns),
  1010. 'dtypes': dtypes_dict,
  1011. 'index_info': {
  1012. 'type': type(field_value.index).__name__,
  1013. 'length': len(field_value.index)
  1014. }
  1015. },
  1016. 'data': make_json_serializable(field_value.to_dict('records')),
  1017. 'memory_usage': memory_dict
  1018. })
  1019. else:
  1020. field_info['value'] = str(field_value)
  1021. field_info['note'] = 'not_standard_dataframe'
  1022. elif field_name == 'fig_json':
  1023. # 图表JSON数据处理
  1024. if isinstance(field_value, str):
  1025. try:
  1026. import json
  1027. parsed_fig = json.loads(field_value)
  1028. field_info.update({
  1029. 'json_valid': True,
  1030. 'json_size_bytes': len(field_value),
  1031. 'plotly_structure': {
  1032. 'has_data': 'data' in parsed_fig,
  1033. 'has_layout': 'layout' in parsed_fig,
  1034. 'data_traces_count': len(parsed_fig.get('data', [])),
  1035. },
  1036. 'raw_json': field_value
  1037. })
  1038. except json.JSONDecodeError:
  1039. field_info.update({
  1040. 'json_valid': False,
  1041. 'raw_content': str(field_value)
  1042. })
  1043. else:
  1044. field_info['value'] = make_json_serializable(field_value)
  1045. elif field_name == 'followup_questions':
  1046. # 后续问题列表
  1047. field_info.update({
  1048. 'content': make_json_serializable(field_value)
  1049. })
  1050. elif field_name in ['question', 'sql', 'summary']:
  1051. # 文本字段
  1052. if isinstance(field_value, str):
  1053. field_info.update({
  1054. 'text_length': len(field_value),
  1055. 'content': field_value
  1056. })
  1057. else:
  1058. field_info['value'] = make_json_serializable(field_value)
  1059. else:
  1060. # 未知字段的安全处理
  1061. field_info['content'] = make_json_serializable(field_value)
  1062. except Exception as e:
  1063. field_info.update({
  1064. 'processing_error': str(e),
  1065. 'fallback_value': str(field_value)[:500] + '...' if len(str(field_value)) > 500 else str(field_value)
  1066. })
  1067. processed_conversation['fields'][field_name] = field_info
  1068. export_data['conversations'][conversation_id] = processed_conversation
  1069. # 添加缓存统计信息
  1070. field_frequency = {}
  1071. data_types_found = set()
  1072. total_dataframes = 0
  1073. total_questions = 0
  1074. for conv_data in export_data['conversations'].values():
  1075. for field_name, field_info in conv_data['fields'].items():
  1076. field_frequency[field_name] = field_frequency.get(field_name, 0) + 1
  1077. data_types_found.add(field_info['data_type'])
  1078. if field_name == 'df' and not field_info['is_none']:
  1079. total_dataframes += 1
  1080. if field_name == 'question' and not field_info['is_none']:
  1081. total_questions += 1
  1082. export_data['cache_statistics'] = {
  1083. 'field_frequency': field_frequency,
  1084. 'data_types_found': list(data_types_found),
  1085. 'total_dataframes': total_dataframes,
  1086. 'total_questions': total_questions,
  1087. 'has_session_timing': 'session_start_time' in field_frequency,
  1088. 'has_conversation_timing': 'conversation_start_time' in field_frequency
  1089. }
  1090. from common.result import success_response
  1091. return jsonify(success_response(
  1092. response_text="缓存数据导出完成",
  1093. data=export_data
  1094. ))
  1095. except Exception as e:
  1096. import traceback
  1097. error_details = {
  1098. 'error_message': str(e),
  1099. 'error_type': type(e).__name__,
  1100. 'traceback': traceback.format_exc()
  1101. }
  1102. from common.result import internal_error_response
  1103. return jsonify(internal_error_response(
  1104. response_text="导出缓存失败,请稍后重试"
  1105. )), 500
  1106. # ==================== 清理功能API ====================
  1107. @app.flask_app.route('/api/v0/cache_preview_cleanup', methods=['POST'])
  1108. def cache_preview_cleanup():
  1109. """清理功能:预览删除操作 - 保持原功能"""
  1110. try:
  1111. req = request.get_json(force=True)
  1112. # 时间条件 - 支持三种方式
  1113. older_than_hours = req.get('older_than_hours')
  1114. older_than_days = req.get('older_than_days')
  1115. before_timestamp = req.get('before_timestamp') # YYYY-MM-DD HH:MM:SS 格式
  1116. cache = app.cache
  1117. # 计算截止时间
  1118. cutoff_time = None
  1119. time_condition = None
  1120. if older_than_hours:
  1121. cutoff_time = datetime.now() - timedelta(hours=older_than_hours)
  1122. time_condition = f"older_than_hours: {older_than_hours}"
  1123. elif older_than_days:
  1124. cutoff_time = datetime.now() - timedelta(days=older_than_days)
  1125. time_condition = f"older_than_days: {older_than_days}"
  1126. elif before_timestamp:
  1127. try:
  1128. # 支持 YYYY-MM-DD HH:MM:SS 格式
  1129. cutoff_time = datetime.strptime(before_timestamp, '%Y-%m-%d %H:%M:%S')
  1130. time_condition = f"before_timestamp: {before_timestamp}"
  1131. except ValueError:
  1132. from common.result import validation_failed_response
  1133. return jsonify(validation_failed_response(
  1134. response_text="before_timestamp格式错误,请使用 YYYY-MM-DD HH:MM:SS 格式"
  1135. )), 422
  1136. else:
  1137. from common.result import bad_request_response
  1138. return jsonify(bad_request_response(
  1139. response_text="必须提供时间条件:older_than_hours, older_than_days 或 before_timestamp (YYYY-MM-DD HH:MM:SS)",
  1140. missing_params=["older_than_hours", "older_than_days", "before_timestamp"]
  1141. )), 400
  1142. preview = {
  1143. 'time_condition': time_condition,
  1144. 'cutoff_time': cutoff_time.isoformat(),
  1145. 'will_be_removed': {
  1146. 'sessions': []
  1147. },
  1148. 'will_be_kept': {
  1149. 'sessions_count': 0,
  1150. 'conversations_count': 0
  1151. },
  1152. 'summary': {
  1153. 'sessions_to_remove': 0,
  1154. 'conversations_to_remove': 0,
  1155. 'sessions_to_keep': 0,
  1156. 'conversations_to_keep': 0
  1157. }
  1158. }
  1159. # 预览按session删除
  1160. sessions_to_remove_count = 0
  1161. conversations_to_remove_count = 0
  1162. for session_id, session_data in cache.session_info.items():
  1163. session_preview = {
  1164. 'session_id': session_id,
  1165. 'start_time': session_data['start_time'].isoformat(),
  1166. 'conversation_count': len(session_data['conversations']),
  1167. 'conversations': []
  1168. }
  1169. # 添加conversation详情
  1170. for conv_id in session_data['conversations']:
  1171. if conv_id in cache.cache:
  1172. conv_data = cache.cache[conv_id]
  1173. session_preview['conversations'].append({
  1174. 'conversation_id': conv_id,
  1175. 'question': conv_data.get('question', '')[:50] + '...' if conv_data.get('question') else '',
  1176. 'start_time': cache.conversation_start_times.get(conv_id, '').isoformat() if cache.conversation_start_times.get(conv_id) else ''
  1177. })
  1178. if session_data['start_time'] < cutoff_time:
  1179. preview['will_be_removed']['sessions'].append(session_preview)
  1180. sessions_to_remove_count += 1
  1181. conversations_to_remove_count += len(session_data['conversations'])
  1182. else:
  1183. preview['will_be_kept']['sessions_count'] += 1
  1184. preview['will_be_kept']['conversations_count'] += len(session_data['conversations'])
  1185. # 更新摘要统计
  1186. preview['summary'] = {
  1187. 'sessions_to_remove': sessions_to_remove_count,
  1188. 'conversations_to_remove': conversations_to_remove_count,
  1189. 'sessions_to_keep': preview['will_be_kept']['sessions_count'],
  1190. 'conversations_to_keep': preview['will_be_kept']['conversations_count']
  1191. }
  1192. from common.result import success_response
  1193. return jsonify(success_response(
  1194. response_text=f"清理预览完成,将删除 {sessions_to_remove_count} 个会话和 {conversations_to_remove_count} 个对话",
  1195. data=preview
  1196. ))
  1197. except Exception as e:
  1198. from common.result import internal_error_response
  1199. return jsonify(internal_error_response(
  1200. response_text="预览清理操作失败,请稍后重试"
  1201. )), 500
  1202. @app.flask_app.route('/api/v0/cache_cleanup', methods=['POST'])
  1203. def cache_cleanup():
  1204. """清理功能:实际删除缓存 - 保持原功能"""
  1205. try:
  1206. req = request.get_json(force=True)
  1207. # 时间条件 - 支持三种方式
  1208. older_than_hours = req.get('older_than_hours')
  1209. older_than_days = req.get('older_than_days')
  1210. before_timestamp = req.get('before_timestamp') # YYYY-MM-DD HH:MM:SS 格式
  1211. cache = app.cache
  1212. if not hasattr(cache, 'session_info'):
  1213. from common.result import service_unavailable_response
  1214. return jsonify(service_unavailable_response(
  1215. response_text="缓存不支持会话功能"
  1216. )), 503
  1217. # 计算截止时间
  1218. cutoff_time = None
  1219. time_condition = None
  1220. if older_than_hours:
  1221. cutoff_time = datetime.now() - timedelta(hours=older_than_hours)
  1222. time_condition = f"older_than_hours: {older_than_hours}"
  1223. elif older_than_days:
  1224. cutoff_time = datetime.now() - timedelta(days=older_than_days)
  1225. time_condition = f"older_than_days: {older_than_days}"
  1226. elif before_timestamp:
  1227. try:
  1228. # 支持 YYYY-MM-DD HH:MM:SS 格式
  1229. cutoff_time = datetime.strptime(before_timestamp, '%Y-%m-%d %H:%M:%S')
  1230. time_condition = f"before_timestamp: {before_timestamp}"
  1231. except ValueError:
  1232. from common.result import validation_failed_response
  1233. return jsonify(validation_failed_response(
  1234. response_text="before_timestamp格式错误,请使用 YYYY-MM-DD HH:MM:SS 格式"
  1235. )), 422
  1236. else:
  1237. from common.result import bad_request_response
  1238. return jsonify(bad_request_response(
  1239. response_text="必须提供时间条件:older_than_hours, older_than_days 或 before_timestamp (YYYY-MM-DD HH:MM:SS)",
  1240. missing_params=["older_than_hours", "older_than_days", "before_timestamp"]
  1241. )), 400
  1242. cleanup_stats = {
  1243. 'time_condition': time_condition,
  1244. 'cutoff_time': cutoff_time.isoformat(),
  1245. 'sessions_removed': 0,
  1246. 'conversations_removed': 0,
  1247. 'sessions_kept': 0,
  1248. 'conversations_kept': 0,
  1249. 'removed_session_ids': [],
  1250. 'removed_conversation_ids': []
  1251. }
  1252. # 按session删除
  1253. sessions_to_remove = []
  1254. for session_id, session_data in cache.session_info.items():
  1255. if session_data['start_time'] < cutoff_time:
  1256. sessions_to_remove.append(session_id)
  1257. # 删除符合条件的sessions及其所有conversations
  1258. for session_id in sessions_to_remove:
  1259. session_data = cache.session_info[session_id]
  1260. conversations_in_session = session_data['conversations'].copy()
  1261. # 删除session中的所有conversations
  1262. for conv_id in conversations_in_session:
  1263. if conv_id in cache.cache:
  1264. del cache.cache[conv_id]
  1265. cleanup_stats['conversations_removed'] += 1
  1266. cleanup_stats['removed_conversation_ids'].append(conv_id)
  1267. # 清理conversation相关的时间记录
  1268. if hasattr(cache, 'conversation_start_times') and conv_id in cache.conversation_start_times:
  1269. del cache.conversation_start_times[conv_id]
  1270. if hasattr(cache, 'conversation_to_session') and conv_id in cache.conversation_to_session:
  1271. del cache.conversation_to_session[conv_id]
  1272. # 删除session记录
  1273. del cache.session_info[session_id]
  1274. cleanup_stats['sessions_removed'] += 1
  1275. cleanup_stats['removed_session_ids'].append(session_id)
  1276. # 统计保留的sessions和conversations
  1277. cleanup_stats['sessions_kept'] = len(cache.session_info)
  1278. cleanup_stats['conversations_kept'] = len(cache.cache)
  1279. from common.result import success_response
  1280. return jsonify(success_response(
  1281. response_text=f"缓存清理完成,删除了 {cleanup_stats['sessions_removed']} 个会话和 {cleanup_stats['conversations_removed']} 个对话",
  1282. data=cleanup_stats
  1283. ))
  1284. except Exception as e:
  1285. from common.result import internal_error_response
  1286. return jsonify(internal_error_response(
  1287. response_text="缓存清理失败,请稍后重试"
  1288. )), 500
  1289. @app.flask_app.route('/api/v0/training_error_question_sql', methods=['POST'])
  1290. def training_error_question_sql():
  1291. """
  1292. 存储错误的question-sql对到error_sql集合中
  1293. 此API将接收的错误question/sql pair写入到error_sql集合中,用于记录和分析错误的SQL查询。
  1294. Args:
  1295. question (str, required): 用户问题
  1296. sql (str, required): 对应的错误SQL查询语句
  1297. Returns:
  1298. JSON: 包含训练ID和成功消息的响应
  1299. """
  1300. try:
  1301. data = request.get_json()
  1302. question = data.get('question')
  1303. sql = data.get('sql')
  1304. print(f"[DEBUG] 接收到错误SQL训练请求: question={question}, sql={sql}")
  1305. if not question or not sql:
  1306. from common.result import bad_request_response
  1307. missing_params = []
  1308. if not question:
  1309. missing_params.append("question")
  1310. if not sql:
  1311. missing_params.append("sql")
  1312. return jsonify(bad_request_response(
  1313. response_text="question和sql参数都是必需的",
  1314. missing_params=missing_params
  1315. )), 400
  1316. # 使用vn实例的train_error_sql方法存储错误SQL
  1317. id = vn.train_error_sql(question=question, sql=sql)
  1318. print(f"[INFO] 成功存储错误SQL,ID: {id}")
  1319. from common.result import success_response
  1320. return jsonify(success_response(
  1321. response_text="错误SQL对已成功存储",
  1322. data={
  1323. "id": id,
  1324. "message": "错误SQL对已成功存储到error_sql集合"
  1325. }
  1326. ))
  1327. except Exception as e:
  1328. print(f"[ERROR] 存储错误SQL失败: {str(e)}")
  1329. from common.result import internal_error_response
  1330. return jsonify(internal_error_response(
  1331. response_text="存储错误SQL失败,请稍后重试"
  1332. )), 500
  1333. # ==================== Redis对话管理API ====================
  1334. @app.flask_app.route('/api/v0/user/<user_id>/conversations', methods=['GET'])
  1335. def get_user_conversations(user_id: str):
  1336. """获取用户的对话列表(按时间倒序)"""
  1337. try:
  1338. limit = request.args.get('limit', USER_MAX_CONVERSATIONS, type=int)
  1339. conversations = redis_conversation_manager.get_conversations(user_id, limit)
  1340. return jsonify(success_response(
  1341. response_text="获取用户对话列表成功",
  1342. data={
  1343. "user_id": user_id,
  1344. "conversations": conversations,
  1345. "total_count": len(conversations)
  1346. }
  1347. ))
  1348. except Exception as e:
  1349. return jsonify(internal_error_response(
  1350. response_text="获取对话列表失败,请稍后重试"
  1351. )), 500
  1352. @app.flask_app.route('/api/v0/conversation/<conversation_id>/messages', methods=['GET'])
  1353. def get_conversation_messages(conversation_id: str):
  1354. """获取特定对话的消息历史"""
  1355. try:
  1356. limit = request.args.get('limit', type=int) # 可选参数
  1357. messages = redis_conversation_manager.get_conversation_messages(conversation_id, limit)
  1358. meta = redis_conversation_manager.get_conversation_meta(conversation_id)
  1359. return jsonify(success_response(
  1360. response_text="获取对话消息成功",
  1361. data={
  1362. "conversation_id": conversation_id,
  1363. "conversation_meta": meta,
  1364. "messages": messages,
  1365. "message_count": len(messages)
  1366. }
  1367. ))
  1368. except Exception as e:
  1369. return jsonify(internal_error_response(
  1370. response_text="获取对话消息失败"
  1371. )), 500
  1372. @app.flask_app.route('/api/v0/conversation/<conversation_id>/context', methods=['GET'])
  1373. def get_conversation_context(conversation_id: str):
  1374. """获取对话上下文(格式化用于LLM)"""
  1375. try:
  1376. count = request.args.get('count', CONVERSATION_CONTEXT_COUNT, type=int)
  1377. context = redis_conversation_manager.get_context(conversation_id, count)
  1378. return jsonify(success_response(
  1379. response_text="获取对话上下文成功",
  1380. data={
  1381. "conversation_id": conversation_id,
  1382. "context": context,
  1383. "context_message_count": count
  1384. }
  1385. ))
  1386. except Exception as e:
  1387. return jsonify(internal_error_response(
  1388. response_text="获取对话上下文失败"
  1389. )), 500
  1390. @app.flask_app.route('/api/v0/conversation_stats', methods=['GET'])
  1391. def conversation_stats():
  1392. """获取对话系统统计信息"""
  1393. try:
  1394. stats = redis_conversation_manager.get_stats()
  1395. return jsonify(success_response(
  1396. response_text="获取统计信息成功",
  1397. data=stats
  1398. ))
  1399. except Exception as e:
  1400. return jsonify(internal_error_response(
  1401. response_text="获取统计信息失败,请稍后重试"
  1402. )), 500
  1403. @app.flask_app.route('/api/v0/conversation_cleanup', methods=['POST'])
  1404. def conversation_cleanup():
  1405. """手动清理过期对话"""
  1406. try:
  1407. redis_conversation_manager.cleanup_expired_conversations()
  1408. return jsonify(success_response(
  1409. response_text="对话清理完成"
  1410. ))
  1411. except Exception as e:
  1412. return jsonify(internal_error_response(
  1413. response_text="对话清理失败,请稍后重试"
  1414. )), 500
  1415. @app.flask_app.route('/api/v0/user/<user_id>/conversations/full', methods=['GET'])
  1416. def get_user_conversations_with_messages(user_id: str):
  1417. """
  1418. 获取用户的完整对话数据(包含所有消息)
  1419. 一次性返回用户的所有对话和每个对话下的消息历史
  1420. Args:
  1421. user_id: 用户ID(路径参数)
  1422. conversation_limit: 对话数量限制(查询参数,可选,不传则返回所有对话)
  1423. message_limit: 每个对话的消息数限制(查询参数,可选,不传则返回所有消息)
  1424. Returns:
  1425. 包含用户所有对话和消息的完整数据
  1426. """
  1427. try:
  1428. # 获取可选参数,不传递时使用None(返回所有记录)
  1429. conversation_limit = request.args.get('conversation_limit', type=int)
  1430. message_limit = request.args.get('message_limit', type=int)
  1431. # 获取用户的对话列表
  1432. conversations = redis_conversation_manager.get_conversations(user_id, conversation_limit)
  1433. # 为每个对话获取消息历史
  1434. full_conversations = []
  1435. total_messages = 0
  1436. for conversation in conversations:
  1437. conversation_id = conversation['conversation_id']
  1438. # 获取对话消息
  1439. messages = redis_conversation_manager.get_conversation_messages(
  1440. conversation_id, message_limit
  1441. )
  1442. # 获取对话元数据
  1443. meta = redis_conversation_manager.get_conversation_meta(conversation_id)
  1444. # 组合完整数据
  1445. full_conversation = {
  1446. **conversation, # 基础对话信息
  1447. 'meta': meta, # 对话元数据
  1448. 'messages': messages, # 消息列表
  1449. 'message_count': len(messages)
  1450. }
  1451. full_conversations.append(full_conversation)
  1452. total_messages += len(messages)
  1453. return jsonify(success_response(
  1454. response_text="获取用户完整对话数据成功",
  1455. data={
  1456. "user_id": user_id,
  1457. "conversations": full_conversations,
  1458. "total_conversations": len(full_conversations),
  1459. "total_messages": total_messages,
  1460. "conversation_limit_applied": conversation_limit,
  1461. "message_limit_applied": message_limit,
  1462. "query_time": datetime.now().isoformat()
  1463. }
  1464. ))
  1465. except Exception as e:
  1466. print(f"[ERROR] 获取用户完整对话数据失败: {str(e)}")
  1467. return jsonify(internal_error_response(
  1468. response_text="获取用户对话数据失败,请稍后重试"
  1469. )), 500
  1470. # ==================== Embedding缓存管理接口 ====================
  1471. @app.flask_app.route('/api/v0/embedding_cache_stats', methods=['GET'])
  1472. def embedding_cache_stats():
  1473. """获取embedding缓存统计信息"""
  1474. try:
  1475. from common.embedding_cache_manager import get_embedding_cache_manager
  1476. cache_manager = get_embedding_cache_manager()
  1477. stats = cache_manager.get_cache_stats()
  1478. return jsonify(success_response(
  1479. response_text="获取embedding缓存统计成功",
  1480. data=stats
  1481. ))
  1482. except Exception as e:
  1483. print(f"[ERROR] 获取embedding缓存统计失败: {str(e)}")
  1484. return jsonify(internal_error_response(
  1485. response_text="获取embedding缓存统计失败,请稍后重试"
  1486. )), 500
  1487. @app.flask_app.route('/api/v0/embedding_cache_cleanup', methods=['POST'])
  1488. def embedding_cache_cleanup():
  1489. """清空所有embedding缓存"""
  1490. try:
  1491. from common.embedding_cache_manager import get_embedding_cache_manager
  1492. cache_manager = get_embedding_cache_manager()
  1493. if not cache_manager.is_available():
  1494. return jsonify(internal_error_response(
  1495. response_text="Embedding缓存功能未启用或不可用"
  1496. )), 400
  1497. success = cache_manager.clear_all_cache()
  1498. if success:
  1499. return jsonify(success_response(
  1500. response_text="所有embedding缓存已清空",
  1501. data={"cleared": True}
  1502. ))
  1503. else:
  1504. return jsonify(internal_error_response(
  1505. response_text="清空embedding缓存失败"
  1506. )), 500
  1507. except Exception as e:
  1508. print(f"[ERROR] 清空embedding缓存失败: {str(e)}")
  1509. return jsonify(internal_error_response(
  1510. response_text="清空embedding缓存失败,请稍后重试"
  1511. )), 500
  1512. @app.flask_app.route('/api/v0/cache_overview_full', methods=['GET'])
  1513. def cache_overview_full():
  1514. """获取所有缓存系统的综合概览"""
  1515. try:
  1516. from common.embedding_cache_manager import get_embedding_cache_manager
  1517. from common.vanna_instance import get_vanna_instance
  1518. from common.session_aware_cache import get_cache
  1519. # 获取现有的缓存统计
  1520. vanna_cache = get_vanna_instance()
  1521. cache = get_cache()
  1522. cache_overview = {
  1523. "conversation_aware_cache": {
  1524. "enabled": True,
  1525. "total_items": len(cache.cache),
  1526. "sessions": list(cache.cache.keys()) if hasattr(cache, 'cache') else []
  1527. },
  1528. "question_answer_cache": {
  1529. "enabled": ENABLE_QUESTION_ANSWER_CACHE,
  1530. "stats": redis_conversation_manager.get_stats() if redis_conversation_manager.is_available() else None
  1531. },
  1532. "embedding_cache": get_embedding_cache_manager().get_cache_stats()
  1533. }
  1534. return jsonify(success_response(
  1535. response_text="获取综合缓存概览成功",
  1536. data=cache_overview
  1537. ))
  1538. except Exception as e:
  1539. print(f"[ERROR] 获取综合缓存概览失败: {str(e)}")
  1540. return jsonify(internal_error_response(
  1541. response_text="获取缓存概览失败,请稍后重试"
  1542. )), 500
# Front-end JavaScript example: how to keep one session across requests
"""
// 前端需要维护一个会话ID
class ChatSession {
    constructor() {
        // 从localStorage获取或创建新的会话ID
        this.sessionId = localStorage.getItem('chat_session_id') || this.generateSessionId();
        localStorage.setItem('chat_session_id', this.sessionId);
    }

    generateSessionId() {
        return 'session_' + Date.now() + '_' + Math.random().toString(36).substr(2, 9);
    }

    async askQuestion(question) {
        const response = await fetch('/api/v0/ask', {
            method: 'POST',
            headers: {
                'Content-Type': 'application/json',
            },
            body: JSON.stringify({
                question: question,
                session_id: this.sessionId  // 关键:传递会话ID
            })
        });
        return await response.json();
    }

    // 开始新会话
    startNewSession() {
        this.sessionId = this.generateSessionId();
        localStorage.setItem('chat_session_id', this.sessionId);
    }
}

// 使用示例
const chatSession = new ChatSession();
chatSession.askQuestion("各年龄段客户的流失率如何?");
"""
  1578. print("正在启动Flask应用: http://localhost:8084")
  1579. app.run(host="0.0.0.0", port=8084, debug=True)