瀏覽代碼

增加了citu_run_sql和ask_cached

wangxq 4 周之前
父節點
當前提交
fc1521efad
共有 1 個文件被更改,包括 120 次插入和 3 次删除
  1. 120 3
      citu_app.py

+ 120 - 3
citu_app.py

@@ -5,12 +5,12 @@ from flask import request, jsonify
 import pandas as pd
 import pandas as pd
 import common.result as result
 import common.result as result
 from datetime import datetime, timedelta
 from datetime import datetime, timedelta
-from common.session_aware_cache import SessionAwareMemoryCache
+from common.session_aware_cache import WebSessionAwareMemoryCache
 
 
 vn = create_vanna_instance()
 vn = create_vanna_instance()
 
 
 # 创建带时间戳的缓存
 # 创建带时间戳的缓存
-timestamped_cache = SessionAwareMemoryCache()
+timestamped_cache = WebSessionAwareMemoryCache()
 
 
 # 实例化 VannaFlaskApp,使用自定义缓存
 # 实例化 VannaFlaskApp,使用自定义缓存
 app = VannaFlaskApp(
 app = VannaFlaskApp(
@@ -46,7 +46,7 @@ def ask_full():
         )
         )
 
 
     try:
     try:
-        sql, df, fig = vn.ask(
+        sql, df, _ = vn.ask(
             question=question,
             question=question,
             print_results=False,
             print_results=False,
             visualize=False,
             visualize=False,
@@ -84,6 +84,123 @@ def ask_full():
             code=500
             code=500
         )), 500
         )), 500
 
 
+@app.flask_app.route('/api/v0/citu_run_sql', methods=['POST'])
+def citu_run_sql():
+    """Execute a client-supplied SQL statement and return the result as JSON.
+
+    Expects a JSON request body with a ``sql`` key. On success the response
+    carries the echoed SQL, at most the first 1000 result rows (as a list of
+    records) and the list of column names.
+
+    Returns HTTP 400 when the body does not contain a non-blank SQL string,
+    and HTTP 500 when running the statement raises.
+    """
+    # silent=True makes a malformed body yield None instead of raising, so
+    # the client receives the same JSON error envelope as for a missing key.
+    req = request.get_json(force=True, silent=True)
+    # Valid JSON is not necessarily an object (it may be null, a list, ...);
+    # guard before calling .get() to avoid an unhandled AttributeError.
+    sql = req.get('sql') if isinstance(req, dict) else None
+
+    if not isinstance(sql, str) or not sql.strip():
+        return jsonify(result.failed(message="未提供SQL查询", code=400)), 400
+
+    try:
+        # NOTE(review): this executes arbitrary SQL taken from the request
+        # body. Confirm the route is protected by authentication and that the
+        # database account behind vn.run_sql has the minimum privileges.
+        df = vn.run_sql(sql)
+
+        rows, columns = [], []
+
+        # Anything other than a non-empty DataFrame is reported as no data.
+        if isinstance(df, pd.DataFrame) and not df.empty:
+            # Cap the payload at 1000 rows.
+            rows = df.head(1000).to_dict(orient="records")
+            columns = list(df.columns)
+
+        return jsonify(result.success(data={
+            "sql": sql,
+            "rows": rows,
+            "columns": columns
+        }))
+
+    except Exception as e:
+        # Route-level boundary: log the failure and report it in the JSON
+        # error envelope instead of letting Flask render a generic error page.
+        print(f"[ERROR] citu_run_sql执行失败: {str(e)}")
+        return jsonify(result.failed(
+            message=f"SQL执行失败: {str(e)}", 
+            code=500
+        )), 500
+
+@app.flask_app.route('/api/v0/ask_cached', methods=['POST'])
+def ask_cached():
+    """Answer a natural-language question, reusing cached results if present.
+
+    Expects a JSON request body with a ``question`` key and an optional
+    ``session_id`` key. A conversation id is derived from the question via
+    ``app.cache.generate_id``; when the cache already holds SQL under that id,
+    the cached SQL, DataFrame and summary are returned, otherwise the question
+    is passed to ``vn.ask`` and the outcome is stored in the cache.
+
+    The ``session_id`` is only echoed back in the response; this function does
+    not use it when reading or writing the cache.
+
+    Returns HTTP 400 when no non-blank question string is supplied, and
+    HTTP 500 when processing raises.
+    """
+    # silent=True makes a malformed body yield None instead of raising, so
+    # the client receives the same JSON error envelope as for a missing key.
+    req = request.get_json(force=True, silent=True)
+    # Valid JSON is not necessarily an object (it may be null, a list, ...);
+    # fall back to an empty mapping so the lookups below cannot crash.
+    payload = req if isinstance(req, dict) else {}
+    question = payload.get("question", None)
+    browser_session_id = payload.get("session_id", None)
+
+    if not isinstance(question, str) or not question.strip():
+        return jsonify(result.failed(message="未提供问题", code=400)), 400
+
+    try:
+        # Derive the conversation id once per request. (An earlier debugging
+        # aid generated it a second time just to compare the two values.)
+        print(f"[DEBUG] 输入问题: '{question}'")
+        conversation_id = app.cache.generate_id(question=question)
+        print(f"[DEBUG] 生成的conversation_id: {conversation_id}")
+
+        # Cached SQL under this id is what marks a cache hit.
+        cached_sql = app.cache.get(id=conversation_id, field="sql")
+
+        if cached_sql is not None:
+            # Cache hit: reuse everything stored for this conversation.
+            print(f"[CACHE HIT] 使用缓存结果: {conversation_id}")
+            sql = cached_sql
+            df = app.cache.get(id=conversation_id, field="df")
+            summary = app.cache.get(id=conversation_id, field="summary")
+        else:
+            # Cache miss: run a fresh query.
+            print(f"[CACHE MISS] 执行新查询: {conversation_id}")
+
+            sql, df, _ = vn.ask(
+                question=question,
+                print_results=False,
+                visualize=False,
+                allow_llm_to_see_data=True
+            )
+
+            # Store the outcome for later requests with the same id.
+            app.cache.set(id=conversation_id, field="question", value=question)
+            app.cache.set(id=conversation_id, field="sql", value=sql)
+            app.cache.set(id=conversation_id, field="df", value=df)
+
+            # A summary is attempted only for a non-empty result; a failure
+            # here is logged and leaves the summary as None (best effort).
+            summary = None
+            if isinstance(df, pd.DataFrame) and not df.empty:
+                try:
+                    summary = vn.generate_summary(question=question, df=df)
+                    print(f"[INFO] 成功生成摘要: {summary}")
+                except Exception as e:
+                    print(f"[WARNING] 生成摘要失败: {str(e)}")
+                    summary = None
+
+            app.cache.set(id=conversation_id, field="summary", value=summary)
+
+        # Shape the response payload; anything other than a non-empty
+        # DataFrame is reported as no data.
+        rows, columns = [], []
+
+        if isinstance(df, pd.DataFrame) and not df.empty:
+            # Cap the payload at 1000 rows.
+            rows = df.head(1000).to_dict(orient="records")
+            columns = list(df.columns)
+
+        return jsonify(result.success(data={
+            "sql": sql,
+            "rows": rows,
+            "columns": columns,
+            "summary": summary,
+            "conversation_id": conversation_id,
+            "session_id": browser_session_id,
+            "cached": cached_sql is not None  # True when served from cache
+        }))
+
+    except Exception as e:
+        # Route-level boundary: log the failure and report it in the JSON
+        # error envelope instead of letting Flask render a generic error page.
+        print(f"[ERROR] ask_cached执行失败: {str(e)}")
+        return jsonify(result.failed(
+            message=f"查询处理失败: {str(e)}", 
+            code=500
+        )), 500
+    
+
 
 
 @app.flask_app.route('/api/v1/citu_train_question_sql', methods=['POST'])
 @app.flask_app.route('/api/v1/citu_train_question_sql', methods=['POST'])
 def citu_train_question_sql():
 def citu_train_question_sql():