Преглед на файлове

对database agent的改造完成,拆分为两个节点:_agent_sql_generation_node/_agent_sql_execution_node.部分测试通过.

wangxq преди 1 седмица
родител
ревизия
ecc0fdbb71
променени са 5 файла, в които са добавени 586 реда и са изтрити 27 реда
  1. 528 14
      agent/citu_agent.py
  2. 27 0
      agent/config.py
  3. 8 0
      agent/state.py
  4. 2 2
      app_config.py
  5. 21 11
      citu_app.py

+ 528 - 14
agent/citu_agent.py

@@ -53,14 +53,25 @@ class CituLangGraphAgent:
         
         # 根据路由模式创建不同的工作流
         if QUESTION_ROUTING_MODE == "database_direct":
-            # 直接数据库模式:跳过分类,直接进入数据库处理
+            # 直接数据库模式:跳过分类,直接进入数据库处理(使用新的拆分节点)
             workflow.add_node("init_direct_database", self._init_direct_database_node)
-            workflow.add_node("agent_database", self._agent_database_node)
+            workflow.add_node("agent_sql_generation", self._agent_sql_generation_node)
+            workflow.add_node("agent_sql_execution", self._agent_sql_execution_node)
             workflow.add_node("format_response", self._format_response_node)
             
             workflow.set_entry_point("init_direct_database")
-            workflow.add_edge("init_direct_database", "agent_database")
-            workflow.add_edge("agent_database", "format_response")
+            
+            # 添加条件路由
+            workflow.add_edge("init_direct_database", "agent_sql_generation")
+            workflow.add_conditional_edges(
+                "agent_sql_generation",
+                self._route_after_sql_generation,
+                {
+                    "continue_execution": "agent_sql_execution",
+                    "return_to_user": "format_response"
+                }
+            )
+            workflow.add_edge("agent_sql_execution", "format_response")
             workflow.add_edge("format_response", END)
             
         elif QUESTION_ROUTING_MODE == "chat_direct":
@@ -75,10 +86,11 @@ class CituLangGraphAgent:
             workflow.add_edge("format_response", END)
             
         else:
-            # 其他模式(hybrid, llm_only):使用原有的分类工作流
+            # 其他模式(hybrid, llm_only):使用新的拆分工作流
             workflow.add_node("classify_question", self._classify_question_node)
             workflow.add_node("agent_chat", self._agent_chat_node)
-            workflow.add_node("agent_database", self._agent_database_node)
+            workflow.add_node("agent_sql_generation", self._agent_sql_generation_node)
+            workflow.add_node("agent_sql_execution", self._agent_sql_execution_node)
             workflow.add_node("format_response", self._format_response_node)
             
             workflow.set_entry_point("classify_question")
@@ -88,13 +100,24 @@ class CituLangGraphAgent:
                 "classify_question",
                 self._route_after_classification,
                 {
-                    "DATABASE": "agent_database",
+                    "DATABASE": "agent_sql_generation",
                     "CHAT": "agent_chat"
                 }
             )
             
+            # 添加条件边:SQL生成后的路由
+            workflow.add_conditional_edges(
+                "agent_sql_generation",
+                self._route_after_sql_generation,
+                {
+                    "continue_execution": "agent_sql_execution",
+                    "return_to_user": "format_response"
+                }
+            )
+            
+            # 普通边
             workflow.add_edge("agent_chat", "format_response")
-            workflow.add_edge("agent_database", "format_response")
+            workflow.add_edge("agent_sql_execution", "format_response")
             workflow.add_edge("format_response", END)
         
         return workflow.compile()
@@ -188,9 +211,250 @@ class CituLangGraphAgent:
             state["execution_path"].append("classify_error")
             return state
         
+    async def _agent_sql_generation_node(self, state: AgentState) -> AgentState:
+        """SQL生成验证节点 - 负责生成SQL、验证SQL和决定路由"""
+        try:
+            print(f"[SQL_GENERATION] 开始处理SQL生成和验证: {state['question']}")
+            
+            question = state["question"]
+            
+            # 步骤1:生成SQL
+            print(f"[SQL_GENERATION] 步骤1:生成SQL")
+            sql_result = generate_sql.invoke({"question": question, "allow_llm_to_see_data": True})
+            
+            if not sql_result.get("success"):
+                # SQL生成失败的统一处理
+                error_message = sql_result.get("error", "")
+                error_type = sql_result.get("error_type", "")
+                
+                print(f"[SQL_GENERATION] SQL生成失败: {error_message}")
+                
+                # 根据错误类型生成用户提示
+                if "no relevant tables" in error_message.lower() or "table not found" in error_message.lower():
+                    user_prompt = "数据库中没有相关的表或字段信息,请您提供更多具体信息或修改问题。"
+                    failure_reason = "missing_database_info"
+                elif "ambiguous" in error_message.lower() or "more information" in error_message.lower():
+                    user_prompt = "您的问题需要更多信息才能准确查询,请提供更详细的描述。"
+                    failure_reason = "ambiguous_question"
+                elif error_type == "llm_explanation":
+                    user_prompt = error_message + " 请尝试重新描述您的问题或询问其他内容。"
+                    failure_reason = "llm_explanation"
+                else:
+                    user_prompt = "无法生成有效的SQL查询,请尝试重新描述您的问题。"
+                    failure_reason = "unknown_generation_failure"
+                
+                # 统一返回失败状态
+                state["sql_generation_success"] = False
+                state["user_prompt"] = user_prompt
+                state["validation_error_type"] = failure_reason
+                state["current_step"] = "sql_generation_failed"
+                state["execution_path"].append("agent_sql_generation_failed")
+                
+                print(f"[SQL_GENERATION] 生成失败: {failure_reason} - {user_prompt}")
+                return state
+            
+            sql = sql_result.get("sql")
+            state["sql"] = sql
+            print(f"[SQL_GENERATION] SQL生成成功: {sql}")
+            
+            # 步骤1.5:检查是否为解释性响应而非SQL
+            error_type = sql_result.get("error_type")
+            if error_type == "llm_explanation":
+                # LLM返回了解释性文本,直接作为最终答案
+                explanation = sql_result.get("error", "")
+                state["chat_response"] = explanation + " 请尝试提问其它问题。"
+                state["sql_generation_success"] = False
+                state["validation_error_type"] = "llm_explanation"
+                state["current_step"] = "sql_generation_completed"
+                state["execution_path"].append("agent_sql_generation")
+                print(f"[SQL_GENERATION] 返回LLM解释性答案: {explanation}")
+                return state
+            
+            # 额外验证:检查SQL格式(防止工具误判)
+            from agent.utils import _is_valid_sql_format
+            if not _is_valid_sql_format(sql):
+                # 内容看起来不是SQL,当作解释性响应处理
+                state["chat_response"] = sql + " 请尝试提问其它问题。"
+                state["sql_generation_success"] = False
+                state["validation_error_type"] = "invalid_sql_format"
+                state["current_step"] = "sql_generation_completed"  
+                state["execution_path"].append("agent_sql_generation")
+                print(f"[SQL_GENERATION] 内容不是有效SQL,当作解释返回: {sql}")
+                return state
+            
+            # 步骤2:SQL验证(如果启用)
+            if self._is_sql_validation_enabled():
+                print(f"[SQL_GENERATION] 步骤2:验证SQL")
+                validation_result = await self._validate_sql_with_custom_priority(sql)
+                
+                if not validation_result.get("valid"):
+                    # 验证失败,检查是否可以修复
+                    error_type = validation_result.get("error_type")
+                    error_message = validation_result.get("error_message")
+                    can_repair = validation_result.get("can_repair", False)
+                    
+                    print(f"[SQL_GENERATION] SQL验证失败: {error_type} - {error_message}")
+                    
+                    if error_type == "forbidden_keywords":
+                        # 禁止词错误,直接失败,不尝试修复
+                        state["sql_generation_success"] = False
+                        state["sql_validation_success"] = False
+                        state["user_prompt"] = error_message
+                        state["validation_error_type"] = "forbidden_keywords"
+                        state["current_step"] = "sql_validation_failed"
+                        state["execution_path"].append("forbidden_keywords_failed")
+                        print(f"[SQL_GENERATION] 禁止词验证失败,直接结束")
+                        return state
+                    
+                    elif error_type == "syntax_error" and can_repair and self._is_auto_repair_enabled():
+                        # 语法错误,尝试修复(仅一次)
+                        print(f"[SQL_GENERATION] 尝试修复SQL语法错误(仅一次): {error_message}")
+                        state["sql_repair_attempted"] = True
+                        
+                        repair_result = await self._attempt_sql_repair_once(sql, error_message)
+                        
+                        if repair_result.get("success"):
+                            # 修复成功
+                            repaired_sql = repair_result.get("repaired_sql")
+                            state["sql"] = repaired_sql
+                            state["sql_generation_success"] = True
+                            state["sql_validation_success"] = True
+                            state["sql_repair_success"] = True
+                            state["current_step"] = "sql_generation_completed"
+                            state["execution_path"].append("sql_repair_success")
+                            print(f"[SQL_GENERATION] SQL修复成功: {repaired_sql}")
+                            return state
+                        else:
+                            # 修复失败,直接结束
+                            repair_error = repair_result.get("error", "修复失败")
+                            print(f"[SQL_GENERATION] SQL修复失败: {repair_error}")
+                            state["sql_generation_success"] = False
+                            state["sql_validation_success"] = False
+                            state["sql_repair_success"] = False
+                            state["user_prompt"] = f"SQL语法修复失败: {repair_error}"
+                            state["validation_error_type"] = "syntax_repair_failed"
+                            state["current_step"] = "sql_repair_failed"
+                            state["execution_path"].append("sql_repair_failed")
+                            return state
+                    else:
+                        # 不启用修复或其他错误类型,直接失败
+                        state["sql_generation_success"] = False
+                        state["sql_validation_success"] = False
+                        state["user_prompt"] = f"SQL验证失败: {error_message}"
+                        state["validation_error_type"] = error_type
+                        state["current_step"] = "sql_validation_failed"
+                        state["execution_path"].append("sql_validation_failed")
+                        print(f"[SQL_GENERATION] SQL验证失败,不尝试修复")
+                        return state
+                else:
+                    print(f"[SQL_GENERATION] SQL验证通过")
+                    state["sql_validation_success"] = True
+            else:
+                print(f"[SQL_GENERATION] 跳过SQL验证(未启用)")
+                state["sql_validation_success"] = True
+            
+            # 生成和验证都成功
+            state["sql_generation_success"] = True
+            state["current_step"] = "sql_generation_completed"
+            state["execution_path"].append("agent_sql_generation")
+            
+            print(f"[SQL_GENERATION] SQL生成验证完成,准备执行")
+            return state
+            
+        except Exception as e:
+            print(f"[ERROR] SQL生成验证节点异常: {str(e)}")
+            import traceback
+            print(f"[ERROR] 详细错误信息: {traceback.format_exc()}")
+            state["sql_generation_success"] = False
+            state["sql_validation_success"] = False
+            state["user_prompt"] = f"SQL生成验证异常: {str(e)}"
+            state["validation_error_type"] = "node_exception"
+            state["current_step"] = "sql_generation_error"
+            state["execution_path"].append("agent_sql_generation_error")
+            return state
+
+    def _agent_sql_execution_node(self, state: AgentState) -> AgentState:
+        """SQL执行节点 - 负责执行已验证的SQL和生成摘要"""
+        try:
+            print(f"[SQL_EXECUTION] 开始执行SQL: {state.get('sql', 'N/A')}")
+            
+            sql = state.get("sql")
+            question = state["question"]
+            
+            if not sql:
+                print(f"[SQL_EXECUTION] 没有可执行的SQL")
+                state["error"] = "没有可执行的SQL语句"
+                state["error_code"] = 500
+                state["current_step"] = "sql_execution_error"
+                state["execution_path"].append("agent_sql_execution_error")
+                return state
+            
+            # 步骤1:执行SQL
+            print(f"[SQL_EXECUTION] 步骤1:执行SQL")
+            execute_result = execute_sql.invoke({"sql": sql})
+            
+            if not execute_result.get("success"):
+                print(f"[SQL_EXECUTION] SQL执行失败: {execute_result.get('error')}")
+                state["error"] = execute_result.get("error", "SQL执行失败")
+                state["error_code"] = 500
+                state["current_step"] = "sql_execution_error"
+                state["execution_path"].append("agent_sql_execution_error")
+                return state
+            
+            query_result = execute_result.get("data_result")
+            state["query_result"] = query_result
+            print(f"[SQL_EXECUTION] SQL执行成功,返回 {query_result.get('row_count', 0)} 行数据")
+            
+            # 步骤2:生成摘要(根据配置和数据情况)
+            if ENABLE_RESULT_SUMMARY and query_result.get('row_count', 0) > 0:
+                print(f"[SQL_EXECUTION] 步骤2:生成摘要")
+                
+                # 重要:提取原始问题用于摘要生成,避免历史记录循环嵌套
+                original_question = self._extract_original_question(question)
+                print(f"[SQL_EXECUTION] 原始问题: {original_question}")
+                
+                summary_result = generate_summary.invoke({
+                    "question": original_question,  # 使用原始问题而不是enhanced_question
+                    "query_result": query_result,
+                    "sql": sql
+                })
+                
+                if not summary_result.get("success"):
+                    print(f"[SQL_EXECUTION] 摘要生成失败: {summary_result.get('message')}")
+                    # 摘要生成失败不是致命错误,使用默认摘要
+                    state["summary"] = f"查询执行完成,共返回 {query_result.get('row_count', 0)} 条记录。"
+                else:
+                    state["summary"] = summary_result.get("summary")
+                    print(f"[SQL_EXECUTION] 摘要生成成功")
+            else:
+                print(f"[SQL_EXECUTION] 跳过摘要生成(ENABLE_RESULT_SUMMARY={ENABLE_RESULT_SUMMARY},数据行数={query_result.get('row_count', 0)})")
+                # 不生成摘要时,不设置summary字段,让格式化响应节点决定如何处理
+            
+            state["current_step"] = "sql_execution_completed"
+            state["execution_path"].append("agent_sql_execution")
+            
+            print(f"[SQL_EXECUTION] SQL执行完成")
+            return state
+            
+        except Exception as e:
+            print(f"[ERROR] SQL执行节点异常: {str(e)}")
+            import traceback
+            print(f"[ERROR] 详细错误信息: {traceback.format_exc()}")
+            state["error"] = f"SQL执行失败: {str(e)}"
+            state["error_code"] = 500
+            state["current_step"] = "sql_execution_error"
+            state["execution_path"].append("agent_sql_execution_error")
+            return state
+
     def _agent_database_node(self, state: AgentState) -> AgentState:
-        """数据库Agent节点 - 直接工具调用模式"""
+        """
+        数据库Agent节点 - 直接工具调用模式 [已废弃]
+        
+        注意:此方法已被拆分为 _agent_sql_generation_node 和 _agent_sql_execution_node
+        保留此方法仅为向后兼容,新的工作流使用拆分后的节点
+        """
         try:
+            print(f"[DATABASE_AGENT] ⚠️  使用已废弃的database节点,建议使用新的拆分节点")
             print(f"[DATABASE_AGENT] 开始处理数据库查询: {state['question']}")
             
             question = state["question"]
@@ -362,7 +626,30 @@ class CituLangGraphAgent:
             
             elif state["question_type"] == "DATABASE":
                 # 数据库查询类型
-                if state.get("chat_response"):
+                
+                # 处理SQL生成失败的情况
+                if not state.get("sql_generation_success", True) and state.get("user_prompt"):
+                    state["final_response"] = {
+                        "success": False,
+                        "response": state["user_prompt"],
+                        "type": "DATABASE",
+                        "sql_generation_failed": True,
+                        "validation_error_type": state.get("validation_error_type"),
+                        "sql": state.get("sql"),
+                        "execution_path": state["execution_path"],
+                        "classification_info": {
+                            "confidence": state["classification_confidence"],
+                            "reason": state["classification_reason"],
+                            "method": state["classification_method"]
+                        },
+                        "sql_validation_info": {
+                            "sql_generation_success": state.get("sql_generation_success", False),
+                            "sql_validation_success": state.get("sql_validation_success", False),
+                            "sql_repair_attempted": state.get("sql_repair_attempted", False),
+                            "sql_repair_success": state.get("sql_repair_success", False)
+                        }
+                    }
+                elif state.get("chat_response"):
                     # SQL生成失败的解释性响应(不受ENABLE_RESULT_SUMMARY配置影响)
                     state["final_response"] = {
                         "success": True,
@@ -450,6 +737,23 @@ class CituLangGraphAgent:
             }
             return state
     
+    def _route_after_sql_generation(self, state: AgentState) -> Literal["continue_execution", "return_to_user"]:
+        """
+        SQL生成后的路由决策
+        
+        根据SQL生成和验证的结果决定后续流向:
+        - SQL生成验证成功 → 继续执行SQL
+        - SQL生成验证失败 → 直接返回用户提示
+        """
+        sql_generation_success = state.get("sql_generation_success", False)
+        
+        print(f"[ROUTE] SQL生成路由: success={sql_generation_success}")
+        
+        if sql_generation_success:
+            return "continue_execution"  # 路由到SQL执行节点
+        else:
+            return "return_to_user"      # 路由到format_response,结束流程
+
     def _route_after_classification(self, state: AgentState) -> Literal["DATABASE", "CHAT"]:
         """
         分类后的路由决策
@@ -472,7 +776,7 @@ class CituLangGraphAgent:
             # 聊天Agent可以处理不确定的情况,并在必要时引导用户提供更多信息
             return "CHAT"
     
-    def process_question(self, question: str, session_id: str = None, context_type: str = None, routing_mode: str = None) -> Dict[str, Any]:
+    async def process_question(self, question: str, session_id: str = None, context_type: str = None, routing_mode: str = None) -> Dict[str, Any]:
         """
         统一的问题处理入口
         
@@ -499,7 +803,7 @@ class CituLangGraphAgent:
             initial_state = self._create_initial_state(question, session_id, context_type, routing_mode)
             
             # 执行工作流
-            final_state = workflow.invoke(
+            final_state = await workflow.ainvoke(
                 initial_state,
                 config={
                     "configurable": {"session_id": session_id}
@@ -554,6 +858,14 @@ class CituLangGraphAgent:
             query_result=None,
             summary=None,
             
+            # SQL验证和修复相关状态
+            sql_generation_success=False,
+            sql_validation_success=False,
+            sql_repair_attempted=False,
+            sql_repair_success=False,
+            validation_error_type=None,
+            user_prompt=None,
+            
             # 聊天响应
             chat_response=None,
             
@@ -577,6 +889,208 @@ class CituLangGraphAgent:
             routing_mode=effective_routing_mode
         )
     
+    # ==================== SQL验证和修复相关方法 ====================
+    
+    def _is_sql_validation_enabled(self) -> bool:
+        """检查是否启用SQL验证"""
+        from agent.config import get_nested_config
+        return (get_nested_config(self.config, "sql_validation.enable_syntax_validation", False) or 
+                get_nested_config(self.config, "sql_validation.enable_forbidden_check", False))
+
+    def _is_auto_repair_enabled(self) -> bool:
+        """检查是否启用自动修复"""
+        from agent.config import get_nested_config
+        return (get_nested_config(self.config, "sql_validation.enable_auto_repair", False) and 
+                get_nested_config(self.config, "sql_validation.enable_syntax_validation", False))
+
+    async def _validate_sql_with_custom_priority(self, sql: str) -> Dict[str, Any]:
+        """
+        按照自定义优先级验证SQL:先禁止词,再语法
+        
+        Args:
+            sql: 要验证的SQL语句
+            
+        Returns:
+            验证结果字典
+        """
+        try:
+            from agent.config import get_nested_config
+            
+            # 1. 优先检查禁止词(您要求的优先级)
+            if get_nested_config(self.config, "sql_validation.enable_forbidden_check", True):
+                forbidden_result = self._check_forbidden_keywords(sql)
+                if not forbidden_result.get("valid"):
+                    return {
+                        "valid": False,
+                        "error_type": "forbidden_keywords",
+                        "error_message": forbidden_result.get("error"),
+                        "can_repair": False  # 禁止词错误不能修复
+                    }
+            
+            # 2. 再检查语法(EXPLAIN SQL)
+            if get_nested_config(self.config, "sql_validation.enable_syntax_validation", True):
+                syntax_result = await self._validate_sql_syntax(sql)
+                if not syntax_result.get("valid"):
+                    return {
+                        "valid": False,
+                        "error_type": "syntax_error",
+                        "error_message": syntax_result.get("error"),
+                        "can_repair": True  # 语法错误可以尝试修复
+                    }
+            
+            return {"valid": True}
+            
+        except Exception as e:
+            return {
+                "valid": False,
+                "error_type": "validation_exception",
+                "error_message": str(e),
+                "can_repair": False
+            }
+
+    def _check_forbidden_keywords(self, sql: str) -> Dict[str, Any]:
+        """检查禁止的SQL关键词"""
+        try:
+            from agent.config import get_nested_config
+            forbidden_operations = get_nested_config(
+                self.config, 
+                "sql_validation.forbidden_operations", 
+                ['UPDATE', 'DELETE', 'DROP', 'ALTER', 'INSERT']
+            )
+            
+            sql_upper = sql.upper().strip()
+            
+            for operation in forbidden_operations:
+                if sql_upper.startswith(operation.upper()):
+                    return {
+                        "valid": False,
+                        "error": f"不允许的操作: {operation}。本系统只支持查询操作(SELECT)。"
+                    }
+            
+            return {"valid": True}
+            
+        except Exception as e:
+            return {
+                "valid": False,
+                "error": f"禁止词检查异常: {str(e)}"
+            }
+
+    async def _validate_sql_syntax(self, sql: str) -> Dict[str, Any]:
+        """语法验证 - 使用EXPLAIN SQL"""
+        try:
+            from common.vanna_instance import get_vanna_instance
+            import asyncio
+            
+            vn = get_vanna_instance()
+            
+            # 构建EXPLAIN查询
+            explain_sql = f"EXPLAIN {sql}"
+            
+            # 异步执行验证
+            result = await asyncio.to_thread(vn.run_sql, explain_sql)
+            
+            if result is not None:
+                return {"valid": True}
+            else:
+                return {
+                    "valid": False,
+                    "error": "SQL语法验证失败"
+                }
+                
+        except Exception as e:
+            return {
+                "valid": False,
+                "error": str(e)
+            }
+
+    async def _attempt_sql_repair_once(self, sql: str, error_message: str) -> Dict[str, Any]:
+        """
+        使用LLM尝试修复SQL - 只修复一次
+        
+        Args:
+            sql: 原始SQL
+            error_message: 错误信息
+            
+        Returns:
+            修复结果字典
+        """
+        try:
+            from common.vanna_instance import get_vanna_instance
+            from agent.config import get_nested_config
+            import asyncio
+            
+            vn = get_vanna_instance()
+            
+            # 构建修复提示词
+            repair_prompt = f"""你是一个PostgreSQL SQL专家,请修复以下SQL语句的语法错误。
+
+当前数据库类型: PostgreSQL
+错误信息: {error_message}
+
+需要修复的SQL:
+{sql}
+
+修复要求:
+1. 只修复语法错误和表结构错误
+2. 保持SQL的原始业务逻辑不变  
+3. 使用PostgreSQL标准语法
+4. 确保修复后的SQL语法正确
+
+请直接输出修复后的SQL语句,不要添加其他说明文字。"""
+
+            # 获取超时配置
+            timeout = get_nested_config(self.config, "sql_validation.repair_timeout", 60)
+            
+            # 异步调用LLM修复
+            response = await asyncio.wait_for(
+                asyncio.to_thread(
+                    vn.chat_with_llm,
+                    question=repair_prompt,
+                    system_prompt="你是一个专业的PostgreSQL SQL专家,专门负责修复SQL语句中的语法错误。"
+                ),
+                timeout=timeout
+            )
+            
+            if response and response.strip():
+                repaired_sql = response.strip()
+                
+                # 验证修复后的SQL
+                validation_result = await self._validate_sql_syntax(repaired_sql)
+                
+                if validation_result.get("valid"):
+                    return {
+                        "success": True,
+                        "repaired_sql": repaired_sql,
+                        "error": None
+                    }
+                else:
+                    return {
+                        "success": False,
+                        "repaired_sql": None,
+                        "error": f"修复后的SQL仍然无效: {validation_result.get('error')}"
+                    }
+            else:
+                return {
+                    "success": False,
+                    "repaired_sql": None,
+                    "error": "LLM返回空响应"
+                }
+                
+        except asyncio.TimeoutError:
+            return {
+                "success": False,
+                "repaired_sql": None,
+                "error": f"修复超时({get_nested_config(self.config, 'sql_validation.repair_timeout', 60)}秒)"
+            }
+        except Exception as e:
+            return {
+                "success": False,
+                "repaired_sql": None,
+                "error": f"修复异常: {str(e)}"
+            }
+
+    # ==================== 原有方法 ====================
+    
     def _extract_original_question(self, question: str) -> str:
         """
         从enhanced_question中提取原始问题
@@ -603,7 +1117,7 @@ class CituLangGraphAgent:
             print(f"[WARNING] 提取原始问题失败: {str(e)}")
             return question.strip()
 
-    def health_check(self) -> Dict[str, Any]:
+    async def health_check(self) -> Dict[str, Any]:
         """健康检查"""
         try:
             # 从配置获取健康检查参数
@@ -613,7 +1127,7 @@ class CituLangGraphAgent:
             
             if enable_full_test:
                 # 完整流程测试
-                test_result = self.process_question(test_question, "health_check")
+                test_result = await self.process_question(test_question, "health_check")
                 
                 return {
                     "status": "healthy" if test_result.get("success") else "degraded",

+ 27 - 0
agent/config.py

@@ -91,6 +91,33 @@ AGENT_CONFIG = {
         # 生产环境建议启用,内存受限环境可关闭
         "enable_agent_reuse": True,
     },
+    
+    # ==================== SQL验证配置 ====================
+    "sql_validation": {
+        # 是否启用禁止词检查:检查SQL中是否包含危险操作
+        # 禁止词检查优先级高于语法检查,失败时不尝试修复
+        "enable_forbidden_check": True,
+        
+        # 是否启用语法验证:使用EXPLAIN SQL验证语法正确性
+        # 语法验证失败时可以尝试LLM修复
+        "enable_syntax_validation": True,
+        
+        # 是否启用自动修复:当语法验证失败时,调用LLM尝试修复
+        # 仅对语法错误有效,禁止词错误不会尝试修复
+        "enable_auto_repair": True,
+        
+        # 禁止的SQL操作:这些操作会被直接拒绝,不允许执行
+        # 系统只支持查询操作,不允许修改数据
+        "forbidden_operations": ['UPDATE', 'DELETE', 'DROP', 'ALTER', 'INSERT'],
+        
+        # LLM修复超时时间:单次修复调用的最大等待时间(秒)
+        # 超时后将放弃修复,直接返回失败
+        "repair_timeout": 60,
+        
+        # 修复重试次数:目前固定为1次,不进行多次重试
+        # 这是设计约束,避免无限修复循环
+        "max_repair_attempts": 1,
+    },
 }
 
 def get_nested_config(config: dict, key_path: str, default=None):

+ 8 - 0
agent/state.py

@@ -24,6 +24,14 @@ class AgentState(TypedDict):
     query_result: Optional[Dict[str, Any]]
     summary: Optional[str]
     
+    # SQL验证和修复相关状态
+    sql_generation_success: bool
+    sql_validation_success: bool
+    sql_repair_attempted: bool
+    sql_repair_success: bool
+    validation_error_type: Optional[str]  # "forbidden_keywords" | "syntax_error" | None
+    user_prompt: Optional[str]
+    
     # 聊天响应
     chat_response: Optional[str]
     

+ 2 - 2
app_config.py

@@ -181,8 +181,8 @@ REDIS_DB = 0
 REDIS_PASSWORD = None
 
 # 缓存开关配置
-ENABLE_CONVERSATION_CONTEXT = True      # 是否启用对话上下文
-ENABLE_QUESTION_ANSWER_CACHE = True     # 是否启用问答结果缓存
+ENABLE_CONVERSATION_CONTEXT = False      # 是否启用对话上下文
+ENABLE_QUESTION_ANSWER_CACHE = False     # 是否启用问答结果缓存
 ENABLE_EMBEDDING_CACHE = True           # 是否启用embedding向量缓存
 
 # TTL配置(单位:秒)

+ 21 - 11
citu_app.py

@@ -597,12 +597,14 @@ def ask_agent():
                 can_retry=True
             )), 503
         
-        agent_result = agent.process_question(
+        # 异步调用Agent处理问题
+        import asyncio
+        agent_result = asyncio.run(agent.process_question(
             question=enhanced_question,  # 使用增强后的问题
             session_id=browser_session_id,
             context_type=context_type,  # 传递上下文类型
             routing_mode=effective_routing_mode  # 新增:传递路由模式
-        )
+        ))
         
         # 8. 处理Agent结果
         if agent_result.get("success", False):
@@ -737,15 +739,14 @@ def agent_health():
         try:
             agent = get_citu_langraph_agent()
             health_data["checks"]["agent_creation"] = True
-            health_data["workflow_compiled"] = agent.workflow is not None
+            # 修正:Agent现在是动态创建workflow的,不再有预创建的workflow属性
+            health_data["workflow_compiled"] = True  # 动态创建,始终可用
             health_data["tools_count"] = len(agent.tools) if hasattr(agent, 'tools') else 0
         except Exception as e:
             health_data["message"] = f"Agent创建失败: {str(e)}"
+            health_data["status"] = "unhealthy"  # 设置状态
             from common.result import health_error_response
-            return jsonify(health_error_response(
-                status="unhealthy",
-                **health_data
-            )), 503
+            return jsonify(health_error_response(**health_data)), 503
         
         # 检查2: 工具导入
         try:
@@ -773,7 +774,9 @@ def agent_health():
         # 检查5: 完整流程测试(可选)
         try:
             if all(health_data["checks"].values()):
-                test_result = agent.health_check()
+                import asyncio
+                # 异步调用健康检查
+                test_result = asyncio.run(agent.health_check())
                 health_data["test_result"] = test_result.get("status") == "healthy"
                 health_data["status"] = test_result.get("status", "unknown")
                 health_data["message"] = test_result.get("message", "健康检查完成")
@@ -781,6 +784,9 @@ def agent_health():
                 health_data["status"] = "degraded"
                 health_data["message"] = "部分组件异常"
         except Exception as e:
+            print(f"[ERROR] 健康检查异常: {str(e)}")
+            import traceback
+            print(f"[ERROR] 详细健康检查错误: {traceback.format_exc()}")
             health_data["status"] = "degraded"
             health_data["message"] = f"完整测试失败: {str(e)}"
         
@@ -790,12 +796,16 @@ def agent_health():
         if health_data["status"] == "healthy":
             return jsonify(health_success_response(**health_data))
         elif health_data["status"] == "degraded":
-            return jsonify(health_error_response(status="degraded", **health_data)), 503
+            return jsonify(health_error_response(**health_data)), 503
         else:
-            return jsonify(health_error_response(status="unhealthy", **health_data)), 503
+            # 确保状态设置为unhealthy
+            health_data["status"] = "unhealthy"
+            return jsonify(health_error_response(**health_data)), 503
             
     except Exception as e:
-        print(f"[ERROR] 健康检查异常: {str(e)}")
+        print(f"[ERROR] 顶层健康检查异常: {str(e)}")
+        import traceback
+        print(f"[ERROR] 详细错误信息: {traceback.format_exc()}")
         from common.result import internal_error_response
         return jsonify(internal_error_response(
             response_text="健康检查失败,请稍后重试"