Browse Source

给ask_agent() 增加API_MAX_RETURN_ROWS参数.

wangxq 2 weeks ago
parent
commit
303446ea75
5 changed files with 98 additions and 42 deletions
  1. 39 15
      agent/citu_agent.py
  2. 7 2
      agent/tools/sql_execution.py
  3. 7 2
      app_config.py
  4. 35 20
      citu_app.py
  5. 10 3
      docs/app_config参数说明.md

+ 39 - 15
agent/citu_agent.py

@@ -9,6 +9,7 @@ from agent.state import AgentState
 from agent.classifier import QuestionClassifier
 from agent.tools import TOOLS, generate_sql, execute_sql, generate_summary, general_chat
 from agent.utils import get_compatible_llm
+from app_config import ENABLE_RESULT_SUMMARY
 
 class CituLangGraphAgent:
     """Citu LangGraph智能助手主类 - 使用@tool装饰器 + Agent工具调用"""
@@ -136,7 +137,7 @@ class CituLangGraphAgent:
             
             # 步骤2:执行SQL
             print(f"[DATABASE_AGENT] 步骤2:执行SQL")
-            execute_result = execute_sql.invoke({"sql": sql, "max_rows": 200})
+            execute_result = execute_sql.invoke({"sql": sql})
             
             if not execute_result.get("success"):
                 print(f"[DATABASE_AGENT] SQL执行失败: {execute_result.get('error')}")
@@ -150,21 +151,25 @@ class CituLangGraphAgent:
             state["data_result"] = data_result
             print(f"[DATABASE_AGENT] SQL执行成功,返回 {data_result.get('row_count', 0)} 行数据")
             
-            # 步骤3:生成摘要
-            print(f"[DATABASE_AGENT] 步骤3:生成摘要")
-            summary_result = generate_summary.invoke({
-                "question": question,
-                "data_result": data_result,
-                "sql": sql
-            })
-            
-            if not summary_result.get("success"):
-                print(f"[DATABASE_AGENT] 摘要生成失败: {summary_result.get('message')}")
-                # 摘要生成失败不是致命错误,使用默认摘要
-                state["summary"] = f"查询执行完成,共返回 {data_result.get('row_count', 0)} 条记录。"
+            # 步骤3:生成摘要(可通过配置控制,仅在有数据时生成)
+            if ENABLE_RESULT_SUMMARY and data_result.get('row_count', 0) > 0:
+                print(f"[DATABASE_AGENT] 步骤3:生成摘要")
+                summary_result = generate_summary.invoke({
+                    "question": question,
+                    "data_result": data_result,
+                    "sql": sql
+                })
+                
+                if not summary_result.get("success"):
+                    print(f"[DATABASE_AGENT] 摘要生成失败: {summary_result.get('message')}")
+                    # 摘要生成失败不是致命错误,使用默认摘要
+                    state["summary"] = f"查询执行完成,共返回 {data_result.get('row_count', 0)} 条记录。"
+                else:
+                    state["summary"] = summary_result.get("summary")
+                    print(f"[DATABASE_AGENT] 摘要生成成功")
             else:
-                state["summary"] = summary_result.get("summary")
-                print(f"[DATABASE_AGENT] 摘要生成成功")
+                print(f"[DATABASE_AGENT] 跳过摘要生成(ENABLE_RESULT_SUMMARY={ENABLE_RESULT_SUMMARY},数据行数={data_result.get('row_count', 0)})")
+                # 不生成摘要时,不设置summary字段,让格式化响应节点决定如何处理
             
             state["current_step"] = "database_completed"
             state["execution_path"].append("agent_database")
@@ -267,6 +272,25 @@ class CituLangGraphAgent:
                             "method": state["classification_method"]
                         }
                     }
+                elif state.get("data_result"):
+                    # 有数据但没有摘要(摘要被配置禁用)
+                    data_result = state.get("data_result")
+                    row_count = data_result.get("row_count", 0)
+                    
+                    # 构建基本响应,不包含summary字段
+                    state["final_response"] = {
+                        "success": True,
+                        "response": f"查询执行完成,共返回 {row_count} 条记录。",
+                        "type": "DATABASE",
+                        "sql": state.get("sql"),
+                        "data_result": data_result,
+                        "execution_path": state["execution_path"],
+                        "classification_info": {
+                            "confidence": state["classification_confidence"],
+                            "reason": state["classification_reason"],
+                            "method": state["classification_method"]
+                        }
+                    }
                 else:
                     # 数据库查询失败,没有任何结果
                     state["final_response"] = {

+ 7 - 2
agent/tools/sql_execution.py

@@ -5,6 +5,7 @@ import pandas as pd
 import time
 import functools
 from common.vanna_instance import get_vanna_instance
+from app_config import API_MAX_RETURN_ROWS
 
 def retry_on_failure(max_retries: int = 2, delay: float = 1.0, backoff_factor: float = 2.0):
     """
@@ -52,13 +53,13 @@ def retry_on_failure(max_retries: int = 2, delay: float = 1.0, backoff_factor: f
 
 @tool
 @retry_on_failure(max_retries=2)
-def execute_sql(sql: str, max_rows: int = 200) -> Dict[str, Any]:
+def execute_sql(sql: str, max_rows: int = None) -> Dict[str, Any]:
     """
     执行SQL查询并返回结果。
     
     Args:
         sql: 要执行的SQL查询语句
-        max_rows: 最大返回行数,默认200
+        max_rows: 最大返回行数,默认使用API_MAX_RETURN_ROWS配置
         
     Returns:
         包含查询结果的字典,格式:
@@ -69,6 +70,10 @@ def execute_sql(sql: str, max_rows: int = 200) -> Dict[str, Any]:
             "can_retry": bool
         }
     """
+    # 设置默认的最大返回行数,与ask()接口保持一致
+    DEFAULT_MAX_RETURN_ROWS = 200
+    if max_rows is None:
+        max_rows = API_MAX_RETURN_ROWS if API_MAX_RETURN_ROWS is not None else DEFAULT_MAX_RETURN_ROWS
     try:
         print(f"[TOOL:execute_sql] 开始执行SQL: {sql[:100]}...")
         

+ 7 - 2
app_config.py

@@ -37,12 +37,12 @@ API_DEEPSEEK_CONFIG = {
 API_QIANWEN_CONFIG = {
     "api_key": os.getenv("QWEN_API_KEY"),  # 从环境变量读取API密钥
     "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",  # 千问API地址
-    "model": "qwen-plus",
+    "model": "qwen3-235b-a22b",
     "allow_llm_to_see_data": True,
     "temperature": 0.6,
     "n_results": 6,
     "language": "Chinese",
-    "stream": False,  # 是否使用流式模式
+    "stream": True,  # 是否使用流式模式
     "enable_thinking": False  # 是否启用思考功能(要求stream=True)
 }
 #qwen3-30b-a3b
@@ -134,6 +134,11 @@ TRAINING_DATA_PATH = "./training/data"
 # 是否启用问题重写功能,也就是上下文问题合并。
 REWRITE_QUESTION_ENABLED = False
 
+# 是否启用数据库查询结果摘要生成
+# True: 执行完SQL后生成摘要(默认)
+# False: 只返回SQL执行结果,跳过摘要生成,节省LLM调用
+ENABLE_RESULT_SUMMARY = True
+
 # 是否在返回结果中显示thinking过程
 # True: 显示 <think></think> 内容
 # False: 隐藏 <think></think> 内容,只显示最终答案

+ 35 - 20
citu_app.py

@@ -6,7 +6,7 @@ import pandas as pd
 import common.result as result
 from datetime import datetime, timedelta
 from common.session_aware_cache import WebSessionAwareMemoryCache
-from app_config import API_MAX_RETURN_ROWS
+from app_config import API_MAX_RETURN_ROWS, ENABLE_RESULT_SUMMARY
 import re
 import chainlit as cl
 import json
@@ -99,26 +99,34 @@ def ask_full():
         rows, columns = [], []
         summary = None
         
-        if isinstance(df, pd.DataFrame) and not df.empty:
-            rows = df.head(MAX_RETURN_ROWS).to_dict(orient="records")
+        if isinstance(df, pd.DataFrame):
+            if not df.empty:
+                rows = df.head(MAX_RETURN_ROWS).to_dict(orient="records")
             columns = list(df.columns)
             
-            # 生成数据摘要(thinking内容已在base_llm_chat.py中统一处理)
-            try:
-                summary = vn.generate_summary(question=question, df=df)
-                print(f"[INFO] 成功生成摘要: {summary}")
-            except Exception as e:
-                print(f"[WARNING] 生成摘要失败: {str(e)}")
-                summary = None
+            # 生成数据摘要(可通过配置控制,仅在有数据时生成)
+            if ENABLE_RESULT_SUMMARY and not df.empty:
+                try:
+                    summary = vn.generate_summary(question=question, df=df)
+                    print(f"[INFO] 成功生成摘要: {summary}")
+                except Exception as e:
+                    print(f"[WARNING] 生成摘要失败: {str(e)}")
+                    summary = None
 
-        return jsonify(result.success(data={
+        # 构建返回数据,根据摘要配置决定是否包含summary字段
+        response_data = {
             "sql": sql,
             "rows": rows,
             "columns": columns,
-            "summary": summary,  # 添加摘要到返回结果
             "conversation_id": conversation_id if 'conversation_id' in locals() else None,
             "session_id": browser_session_id
-        }))
+        }
+        
+        # 只有启用摘要且确实生成了摘要时才添加summary字段
+        if ENABLE_RESULT_SUMMARY and summary is not None:
+            response_data["summary"] = summary
+            
+        return jsonify(result.success(data=response_data))
         
     except Exception as e:
         print(f"[ERROR] ask_full执行失败: {str(e)}")
@@ -262,9 +270,9 @@ def ask_cached():
             app.cache.set(id=conversation_id, field="sql", value=sql)
             app.cache.set(id=conversation_id, field="df", value=df)
             
-            # 生成并缓存摘要
+            # 生成并缓存摘要(可通过配置控制,仅在有数据时生成)
             summary = None
-            if isinstance(df, pd.DataFrame) and not df.empty:
+            if ENABLE_RESULT_SUMMARY and isinstance(df, pd.DataFrame) and not df.empty:
                 try:
                     summary = vn.generate_summary(question=question, df=df)
                     print(f"[INFO] 成功生成摘要: {summary}")
@@ -277,19 +285,26 @@ def ask_cached():
         # 处理返回数据
         rows, columns = [], []
         
-        if isinstance(df, pd.DataFrame) and not df.empty:
-            rows = df.head(MAX_RETURN_ROWS).to_dict(orient="records")
+        if isinstance(df, pd.DataFrame):
+            if not df.empty:
+                rows = df.head(MAX_RETURN_ROWS).to_dict(orient="records")
             columns = list(df.columns)
 
-        return jsonify(result.success(data={
+        # 构建返回数据,根据摘要配置决定是否包含summary字段
+        response_data = {
             "sql": sql,
             "rows": rows,
             "columns": columns,
-            "summary": summary,
             "conversation_id": conversation_id,
             "session_id": browser_session_id,
             "cached": cached_sql is not None  # 标识是否来自缓存
-        }))
+        }
+        
+        # 只有启用摘要且确实生成了摘要时才添加summary字段
+        if ENABLE_RESULT_SUMMARY and summary is not None:
+            response_data["summary"] = summary
+            
+        return jsonify(result.success(data=response_data))
         
     except Exception as e:
         print(f"[ERROR] ask_cached执行失败: {str(e)}")

+ 10 - 3
docs/app_config参数说明.md

@@ -235,14 +235,21 @@
   - 样例值:`False`
   - 功能:将上下文问题合并优化
 
-### 2. 思考过程显示
-- **`DISPLAY_SUMMARY_THINKING`**: 是否在摘要中显示思考过程
+### 2. 结果摘要生成
+- **`ENABLE_RESULT_SUMMARY`**: 是否启用数据库查询结果摘要生成
+  - 可选值:`True` 或 `False`
+  - 默认值:`True`
+  - 样例值:`True`
+  - 功能:控制是否为数据库查询结果生成自然语言摘要。禁用时可节省LLM调用,仅影响数据库查询,不影响一般聊天
+
+### 3. 思考过程显示
+- **`DISPLAY_RESULT_THINKING`**: 是否在返回结果中显示思考过程
   - 可选值:`True` 或 `False`
   - 默认值:`False`
   - 样例值:`False`
   - 功能:控制是否显示 `<think></think>` 标签内容
 
-### 3. SQL错误修正
+### 4. SQL错误修正
 - **`ENABLE_ERROR_SQL_PROMPT`**: 是否启用SQL错误修正提示
   - 可选值:`True` 或 `False`
   - 默认值:`True`