|
@@ -5,12 +5,12 @@ from flask import request, jsonify
|
|
|
import pandas as pd
|
|
|
import common.result as result
|
|
|
from datetime import datetime, timedelta
|
|
|
-from common.session_aware_cache import SessionAwareMemoryCache
|
|
|
+from common.session_aware_cache import WebSessionAwareMemoryCache
|
|
|
|
|
|
vn = create_vanna_instance()
|
|
|
|
|
|
# 创建带时间戳的缓存
|
|
|
-timestamped_cache = SessionAwareMemoryCache()
|
|
|
+timestamped_cache = WebSessionAwareMemoryCache()
|
|
|
|
|
|
# 实例化 VannaFlaskApp,使用自定义缓存
|
|
|
app = VannaFlaskApp(
|
|
@@ -46,7 +46,7 @@ def ask_full():
|
|
|
)
|
|
|
|
|
|
try:
|
|
|
- sql, df, fig = vn.ask(
|
|
|
+ sql, df, _ = vn.ask(
|
|
|
question=question,
|
|
|
print_results=False,
|
|
|
visualize=False,
|
|
@@ -84,6 +84,123 @@ def ask_full():
|
|
|
code=500
|
|
|
)), 500
|
|
|
|
|
|
+@app.flask_app.route('/api/v0/citu_run_sql', methods=['POST'])
|
|
|
+def citu_run_sql():
|
|
|
+ req = request.get_json(force=True)
|
|
|
+ sql = req.get('sql')
|
|
|
+
|
|
|
+ if not sql:
|
|
|
+ return jsonify(result.failed(message="未提供SQL查询", code=400)), 400
|
|
|
+
|
|
|
+ try:
|
|
|
+ df = vn.run_sql(sql)
|
|
|
+
|
|
|
+ rows, columns = [], []
|
|
|
+
|
|
|
+ if isinstance(df, pd.DataFrame) and not df.empty:
|
|
|
+ rows = df.head(1000).to_dict(orient="records")
|
|
|
+ columns = list(df.columns)
|
|
|
+
|
|
|
+ return jsonify(result.success(data={
|
|
|
+ "sql": sql,
|
|
|
+ "rows": rows,
|
|
|
+ "columns": columns
|
|
|
+ }))
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ print(f"[ERROR] citu_run_sql执行失败: {str(e)}")
|
|
|
+ return jsonify(result.failed(
|
|
|
+ message=f"SQL执行失败: {str(e)}",
|
|
|
+ code=500
|
|
|
+ )), 500
|
|
|
+
|
|
|
+@app.flask_app.route('/api/v0/ask_cached', methods=['POST'])
|
|
|
+def ask_cached():
|
|
|
+ """
|
|
|
+ 带缓存功能的智能查询接口
|
|
|
+ 支持会话管理和结果缓存,提高查询效率
|
|
|
+ """
|
|
|
+ req = request.get_json(force=True)
|
|
|
+ question = req.get("question", None)
|
|
|
+ browser_session_id = req.get("session_id", None)
|
|
|
+
|
|
|
+ if not question:
|
|
|
+ return jsonify(result.failed(message="未提供问题", code=400)), 400
|
|
|
+
|
|
|
+ try:
|
|
|
+ # 生成conversation_id
|
|
|
+ # 调试:查看generate_id的实际行为
|
|
|
+ print(f"[DEBUG] 输入问题: '{question}'")
|
|
|
+ conversation_id = app.cache.generate_id(question=question)
|
|
|
+ print(f"[DEBUG] 生成的conversation_id: {conversation_id}")
|
|
|
+
|
|
|
+ # 再次用相同问题测试
|
|
|
+ conversation_id2 = app.cache.generate_id(question=question)
|
|
|
+ print(f"[DEBUG] 再次生成的conversation_id: {conversation_id2}")
|
|
|
+ print(f"[DEBUG] 两次ID是否相同: {conversation_id == conversation_id2}")
|
|
|
+
|
|
|
+ # 检查缓存
|
|
|
+ cached_sql = app.cache.get(id=conversation_id, field="sql")
|
|
|
+
|
|
|
+ if cached_sql is not None:
|
|
|
+ # 缓存命中
|
|
|
+ print(f"[CACHE HIT] 使用缓存结果: {conversation_id}")
|
|
|
+ sql = cached_sql
|
|
|
+ df = app.cache.get(id=conversation_id, field="df")
|
|
|
+ summary = app.cache.get(id=conversation_id, field="summary")
|
|
|
+ else:
|
|
|
+ # 缓存未命中,执行新查询
|
|
|
+ print(f"[CACHE MISS] 执行新查询: {conversation_id}")
|
|
|
+
|
|
|
+ sql, df, _ = vn.ask(
|
|
|
+ question=question,
|
|
|
+ print_results=False,
|
|
|
+ visualize=False,
|
|
|
+ allow_llm_to_see_data=True
|
|
|
+ )
|
|
|
+
|
|
|
+ # 缓存结果
|
|
|
+ app.cache.set(id=conversation_id, field="question", value=question)
|
|
|
+ app.cache.set(id=conversation_id, field="sql", value=sql)
|
|
|
+ app.cache.set(id=conversation_id, field="df", value=df)
|
|
|
+
|
|
|
+ # 生成并缓存摘要
|
|
|
+ summary = None
|
|
|
+ if isinstance(df, pd.DataFrame) and not df.empty:
|
|
|
+ try:
|
|
|
+ summary = vn.generate_summary(question=question, df=df)
|
|
|
+ print(f"[INFO] 成功生成摘要: {summary}")
|
|
|
+ except Exception as e:
|
|
|
+ print(f"[WARNING] 生成摘要失败: {str(e)}")
|
|
|
+ summary = None
|
|
|
+
|
|
|
+ app.cache.set(id=conversation_id, field="summary", value=summary)
|
|
|
+
|
|
|
+ # 处理返回数据
|
|
|
+ rows, columns = [], []
|
|
|
+
|
|
|
+ if isinstance(df, pd.DataFrame) and not df.empty:
|
|
|
+ rows = df.head(1000).to_dict(orient="records")
|
|
|
+ columns = list(df.columns)
|
|
|
+
|
|
|
+ return jsonify(result.success(data={
|
|
|
+ "sql": sql,
|
|
|
+ "rows": rows,
|
|
|
+ "columns": columns,
|
|
|
+ "summary": summary,
|
|
|
+ "conversation_id": conversation_id,
|
|
|
+ "session_id": browser_session_id,
|
|
|
+ "cached": cached_sql is not None # 标识是否来自缓存
|
|
|
+ }))
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ print(f"[ERROR] ask_cached执行失败: {str(e)}")
|
|
|
+ return jsonify(result.failed(
|
|
|
+ message=f"查询处理失败: {str(e)}",
|
|
|
+ code=500
|
|
|
+ )), 500
|
|
|
+
|
|
|
+
|
|
|
|
|
|
@app.flask_app.route('/api/v1/citu_train_question_sql', methods=['POST'])
|
|
|
def citu_train_question_sql():
|