Bläddra i källkod

增加了citu_run_sql和ask_cache

wangxq 4 veckor sedan
förälder
incheckning
fc1521efad
1 ändrade filer med 120 tillägg och 3 borttagningar
  1. 120 3
      citu_app.py

+ 120 - 3
citu_app.py

@@ -5,12 +5,12 @@ from flask import request, jsonify
 import pandas as pd
 import common.result as result
 from datetime import datetime, timedelta
-from common.session_aware_cache import SessionAwareMemoryCache
+from common.session_aware_cache import WebSessionAwareMemoryCache
 
 vn = create_vanna_instance()
 
 # 创建带时间戳的缓存
-timestamped_cache = SessionAwareMemoryCache()
+timestamped_cache = WebSessionAwareMemoryCache()
 
 # 实例化 VannaFlaskApp,使用自定义缓存
 app = VannaFlaskApp(
@@ -46,7 +46,7 @@ def ask_full():
         )
 
     try:
-        sql, df, fig = vn.ask(
+        sql, df, _ = vn.ask(
             question=question,
             print_results=False,
             visualize=False,
@@ -84,6 +84,123 @@ def ask_full():
             code=500
         )), 500
 
+@app.flask_app.route('/api/v0/citu_run_sql', methods=['POST'])
+def citu_run_sql():
+    req = request.get_json(force=True)
+    sql = req.get('sql')
+    
+    if not sql:
+        return jsonify(result.failed(message="未提供SQL查询", code=400)), 400
+    
+    try:
+        df = vn.run_sql(sql)
+        
+        rows, columns = [], []
+        
+        if isinstance(df, pd.DataFrame) and not df.empty:
+            rows = df.head(1000).to_dict(orient="records")
+            columns = list(df.columns)
+        
+        return jsonify(result.success(data={
+            "sql": sql,
+            "rows": rows,
+            "columns": columns
+        }))
+        
+    except Exception as e:
+        print(f"[ERROR] citu_run_sql执行失败: {str(e)}")
+        return jsonify(result.failed(
+            message=f"SQL执行失败: {str(e)}", 
+            code=500
+        )), 500
+
+@app.flask_app.route('/api/v0/ask_cached', methods=['POST'])
+def ask_cached():
+    """
+    带缓存功能的智能查询接口
+    支持会话管理和结果缓存,提高查询效率
+    """
+    req = request.get_json(force=True)
+    question = req.get("question", None)
+    browser_session_id = req.get("session_id", None)
+    
+    if not question:
+        return jsonify(result.failed(message="未提供问题", code=400)), 400
+
+    try:
+        # 生成conversation_id
+        # 调试:查看generate_id的实际行为
+        print(f"[DEBUG] 输入问题: '{question}'")
+        conversation_id = app.cache.generate_id(question=question)
+        print(f"[DEBUG] 生成的conversation_id: {conversation_id}")
+        
+        # 再次用相同问题测试
+        conversation_id2 = app.cache.generate_id(question=question)
+        print(f"[DEBUG] 再次生成的conversation_id: {conversation_id2}")
+        print(f"[DEBUG] 两次ID是否相同: {conversation_id == conversation_id2}")
+        
+        # 检查缓存
+        cached_sql = app.cache.get(id=conversation_id, field="sql")
+        
+        if cached_sql is not None:
+            # 缓存命中
+            print(f"[CACHE HIT] 使用缓存结果: {conversation_id}")
+            sql = cached_sql
+            df = app.cache.get(id=conversation_id, field="df")
+            summary = app.cache.get(id=conversation_id, field="summary")
+        else:
+            # 缓存未命中,执行新查询
+            print(f"[CACHE MISS] 执行新查询: {conversation_id}")
+            
+            sql, df, _ = vn.ask(
+                question=question,
+                print_results=False,
+                visualize=False,
+                allow_llm_to_see_data=True
+            )
+            
+            # 缓存结果
+            app.cache.set(id=conversation_id, field="question", value=question)
+            app.cache.set(id=conversation_id, field="sql", value=sql)
+            app.cache.set(id=conversation_id, field="df", value=df)
+            
+            # 生成并缓存摘要
+            summary = None
+            if isinstance(df, pd.DataFrame) and not df.empty:
+                try:
+                    summary = vn.generate_summary(question=question, df=df)
+                    print(f"[INFO] 成功生成摘要: {summary}")
+                except Exception as e:
+                    print(f"[WARNING] 生成摘要失败: {str(e)}")
+                    summary = None
+            
+            app.cache.set(id=conversation_id, field="summary", value=summary)
+
+        # 处理返回数据
+        rows, columns = [], []
+        
+        if isinstance(df, pd.DataFrame) and not df.empty:
+            rows = df.head(1000).to_dict(orient="records")
+            columns = list(df.columns)
+
+        return jsonify(result.success(data={
+            "sql": sql,
+            "rows": rows,
+            "columns": columns,
+            "summary": summary,
+            "conversation_id": conversation_id,
+            "session_id": browser_session_id,
+            "cached": cached_sql is not None  # 标识是否来自缓存
+        }))
+        
+    except Exception as e:
+        print(f"[ERROR] ask_cached执行失败: {str(e)}")
+        return jsonify(result.failed(
+            message=f"查询处理失败: {str(e)}", 
+            code=500
+        )), 500
+    
+
 
 @app.flask_app.route('/api/v1/citu_train_question_sql', methods=['POST'])
 def citu_train_question_sql():