Bläddra i källkod

第三次修改返回结果规范化,替换剩余的data_result.

wangxq 1 vecka sedan
förälder
incheckning
58c7729f20
5 ändrade filer med 33 tillägg och 33 borttagningar
  1. 14 14
      agent/citu_agent.py
  2. 1 1
      agent/state.py
  3. 1 1
      agent/tools/sql_execution.py
  4. 15 15
      agent/tools/summary_generation.py
  5. 2 2
      citu_app.py

+ 14 - 14
agent/citu_agent.py

@@ -229,28 +229,28 @@ class CituLangGraphAgent:
                 state["execution_path"].append("agent_database_error")
                 return state
             
-            data_result = execute_result.get("data_result")
-            state["data_result"] = data_result
-            print(f"[DATABASE_AGENT] SQL执行成功,返回 {data_result.get('row_count', 0)} 行数据")
+            query_result = execute_result.get("data_result")
+            state["query_result"] = query_result
+            print(f"[DATABASE_AGENT] SQL执行成功,返回 {query_result.get('row_count', 0)} 行数据")
             
             # 步骤3:生成摘要(可通过配置控制,仅在有数据时生成)
-            if ENABLE_RESULT_SUMMARY and data_result.get('row_count', 0) > 0:
+            if ENABLE_RESULT_SUMMARY and query_result.get('row_count', 0) > 0:
                 print(f"[DATABASE_AGENT] 步骤3:生成摘要")
                 summary_result = generate_summary.invoke({
                     "question": question,
-                    "data_result": data_result,
+                    "query_result": query_result,
                     "sql": sql
                 })
                 
                 if not summary_result.get("success"):
                     print(f"[DATABASE_AGENT] 摘要生成失败: {summary_result.get('message')}")
                     # 摘要生成失败不是致命错误,使用默认摘要
-                    state["summary"] = f"查询执行完成,共返回 {data_result.get('row_count', 0)} 条记录。"
+                    state["summary"] = f"查询执行完成,共返回 {query_result.get('row_count', 0)} 条记录。"
                 else:
                     state["summary"] = summary_result.get("summary")
                     print(f"[DATABASE_AGENT] 摘要生成成功")
             else:
-                print(f"[DATABASE_AGENT] 跳过摘要生成(ENABLE_RESULT_SUMMARY={ENABLE_RESULT_SUMMARY},数据行数={data_result.get('row_count', 0)})")
+                print(f"[DATABASE_AGENT] 跳过摘要生成(ENABLE_RESULT_SUMMARY={ENABLE_RESULT_SUMMARY},数据行数={query_result.get('row_count', 0)})")
                 # 不生成摘要时,不设置summary字段,让格式化响应节点决定如何处理
             
             state["current_step"] = "database_completed"
@@ -345,7 +345,7 @@ class CituLangGraphAgent:
                         "response": state["chat_response"],
                         "type": "DATABASE",
                         "sql": state.get("sql"),
-                        "query_result": state.get("data_result"),  # 字段重命名:data_result → query_result
+                        "query_result": state.get("query_result"),  # 获取query_result字段
                         "execution_path": state["execution_path"],
                         "classification_info": {
                             "confidence": state["classification_confidence"],
@@ -360,7 +360,7 @@ class CituLangGraphAgent:
                         "success": True,
                         "type": "DATABASE",
                         "sql": state.get("sql"),
-                        "query_result": state.get("data_result"),  # 字段重命名:data_result → query_result
+                        "query_result": state.get("query_result"),  # 获取query_result字段
                         "summary": state["summary"],
                         "execution_path": state["execution_path"],
                         "classification_info": {
@@ -369,10 +369,10 @@ class CituLangGraphAgent:
                             "method": state["classification_method"]
                         }
                     }
-                elif state.get("data_result"):
+                elif state.get("query_result"):
                     # 有数据但没有摘要(摘要被配置禁用)
-                    data_result = state.get("data_result")
-                    row_count = data_result.get("row_count", 0)
+                    query_result = state.get("query_result")
+                    row_count = query_result.get("row_count", 0)
                     
                     # 构建基本响应,不包含summary字段和response字段
                     # 用户应该直接从query_result.columns和query_result.rows获取数据
@@ -380,7 +380,7 @@ class CituLangGraphAgent:
                         "success": True,
                         "type": "DATABASE",
                         "sql": state.get("sql"),
-                        "query_result": data_result,  # 字段重命名:data_result → query_result
+                        "query_result": query_result,  # 获取query_result字段
                         "execution_path": state["execution_path"],
                         "classification_info": {
                             "confidence": state["classification_confidence"],
@@ -509,7 +509,7 @@ class CituLangGraphAgent:
             # 数据库查询流程状态
             sql=None,
             sql_generation_attempts=0,
-            data_result=None,
+            query_result=None,
             summary=None,
             
             # 聊天响应

+ 1 - 1
agent/state.py

@@ -18,7 +18,7 @@ class AgentState(TypedDict):
     # 数据库查询流程状态
     sql: Optional[str]
     sql_generation_attempts: int
-    data_result: Optional[Dict[str, Any]]
+    query_result: Optional[Dict[str, Any]]
     summary: Optional[str]
     
     # 聊天响应

+ 1 - 1
agent/tools/sql_execution.py

@@ -65,7 +65,7 @@ def execute_sql(sql: str, max_rows: int = None) -> Dict[str, Any]:
         包含查询结果的字典,格式:
         {
             "success": bool,
-            "data_result": dict或None,
+            "data_result": dict或None,  # 注意:工具内部仍使用data_result,但会被Agent重命名为query_result
             "error": str或None,
             "can_retry": bool
         }

+ 15 - 15
agent/tools/summary_generation.py

@@ -5,13 +5,13 @@ import pandas as pd
 from common.vanna_instance import get_vanna_instance
 
 @tool
-def generate_summary(question: str, data_result: Dict[str, Any], sql: str) -> Dict[str, Any]:
+def generate_summary(question: str, query_result: Dict[str, Any], sql: str) -> Dict[str, Any]:
     """
     为查询结果生成自然语言摘要。
     
     Args:
         question: 原始问题
-        data_result: 查询结果数据
+        query_result: 查询结果数据
         sql: 执行的SQL语句
         
     Returns:
@@ -25,7 +25,7 @@ def generate_summary(question: str, data_result: Dict[str, Any], sql: str) -> Di
     try:
         print(f"[TOOL:generate_summary] 开始生成摘要,问题: {question}")
         
-        if not data_result or not data_result.get("rows"):
+        if not query_result or not query_result.get("rows"):
             return {
                 "success": True,
                 "summary": "查询执行完成,但没有找到符合条件的数据。",
@@ -33,7 +33,7 @@ def generate_summary(question: str, data_result: Dict[str, Any], sql: str) -> Di
             }
         
         # 重构DataFrame用于摘要生成
-        df = _reconstruct_dataframe(data_result)
+        df = _reconstruct_dataframe(query_result)
         
         if df is None or df.empty:
             return {
@@ -48,7 +48,7 @@ def generate_summary(question: str, data_result: Dict[str, Any], sql: str) -> Di
         
         if summary is None:
             # 生成默认摘要
-            summary = _generate_default_summary(question, data_result, sql)
+            summary = _generate_default_summary(question, query_result, sql)
         
         print(f"[TOOL:generate_summary] 摘要生成成功: {summary[:100]}...")
         
@@ -62,7 +62,7 @@ def generate_summary(question: str, data_result: Dict[str, Any], sql: str) -> Di
         print(f"[ERROR] 摘要生成异常: {str(e)}")
         
         # 生成备用摘要
-        fallback_summary = _generate_fallback_summary(question, data_result, sql)
+        fallback_summary = _generate_fallback_summary(question, query_result, sql)
         
         return {
             "success": True,  # 即使异常也返回成功,因为有备用摘要
@@ -70,11 +70,11 @@ def generate_summary(question: str, data_result: Dict[str, Any], sql: str) -> Di
             "message": f"使用备用摘要生成: {str(e)}"
         }
 
-def _reconstruct_dataframe(data_result: Dict[str, Any]) -> pd.DataFrame:
+def _reconstruct_dataframe(query_result: Dict[str, Any]) -> pd.DataFrame:
     """从查询结果重构DataFrame"""
     try:
-        rows = data_result.get("rows", [])
-        columns = data_result.get("columns", [])
+        rows = query_result.get("rows", [])
+        columns = query_result.get("columns", [])
         
         if not rows or not columns:
             return pd.DataFrame()
@@ -85,11 +85,11 @@ def _reconstruct_dataframe(data_result: Dict[str, Any]) -> pd.DataFrame:
         print(f"[WARNING] DataFrame重构失败: {str(e)}")
         return pd.DataFrame()
 
-def _generate_default_summary(question: str, data_result: Dict[str, Any], sql: str) -> str:
+def _generate_default_summary(question: str, query_result: Dict[str, Any], sql: str) -> str:
     """生成默认摘要"""
     try:
-        row_count = data_result.get("row_count", 0)
-        columns = data_result.get("columns", [])
+        row_count = query_result.get("row_count", 0)
+        columns = query_result.get("columns", [])
         
         if row_count == 0:
             return "查询执行完成,但没有找到符合条件的数据。"
@@ -102,11 +102,11 @@ def _generate_default_summary(question: str, data_result: Dict[str, Any], sql: s
         return ' '.join(summary_parts)
         
     except Exception:
-        return f"查询执行完成,共返回 {data_result.get('row_count', 0)} 条记录。"
+        return f"查询执行完成,共返回 {query_result.get('row_count', 0)} 条记录。"
 
-def _generate_fallback_summary(question: str, data_result: Dict[str, Any], sql: str) -> str:
+def _generate_fallback_summary(question: str, query_result: Dict[str, Any], sql: str) -> str:
     """生成备用摘要"""
-    row_count = data_result.get("row_count", 0)
+    row_count = query_result.get("row_count", 0)
     
     if row_count == 0:
         return "查询执行完成,但没有找到符合条件的数据。请检查查询条件是否正确。"

+ 2 - 2
citu_app.py

@@ -435,7 +435,7 @@ def ask_agent():
             "response": "最终回答",
             "type": "DATABASE/CHAT",
             "sql": "生成的SQL(如果是数据库查询)",
-            "data_result": {
+            "query_result": {
                 "rows": [...],
                 "columns": [...],
                 "row_count": 数字
@@ -491,7 +491,7 @@ def ask_agent():
                 classification_info=agent_result.get("classification_info", {}),
                 response=agent_result.get("response", ""),
                 sql=agent_result.get("sql"),
-                query_result=agent_result.get("data_result"),  # 字段重命名:data_result → query_result
+                query_result=agent_result.get("query_result"),  # 修复:从query_result字段获取数据
                 summary=agent_result.get("summary")
             ))
         else: