将当前的 `_agent_database_node` 拆分为两个独立节点,实现 SQL 生成/验证与 SQL 执行的分离,并提供中间路由能力,以提升系统的智能性和用户体验。
当前流程:
classify_question → agent_database (generate_sql + execute_sql + summary) → format_response → END
新流程:
classify_question → agent_sql_generation_node (generate_sql + validation) → agent_sql_execution_node (execute_sql + summary) → format_response → END
失败路由:agent_sql_generation_node ↓ format_response → END
generate_sql()
# (prose continued from the previous line) ...tool that generates the SQL.
async def _agent_sql_generation_node(self, state: AgentState) -> AgentState:
    """Generate SQL for the question and, when enabled, validate it.

    Declared ``async`` because the validation helpers are awaited below; a
    plain ``def`` containing ``await`` does not compile.

    Args:
        state: Agent state. ``state["question"]`` is read; on success
            ``state["execution_path"]`` is appended to.

    Returns:
        The state object. On success ``state["sql"]`` holds the generated SQL
        and ``state["sql_generation_success"]`` is True. Generation failures
        are delegated to ``_handle_sql_generation_failure`` and validation
        failures to ``_handle_sql_validation_failure``. Any exception is
        recorded in ``state["error"]`` instead of propagating.
    """
    try:
        question = state["question"]

        # Step 1: SQL generation.
        # NOTE(review): the execution node calls its tools via ``.invoke({...})``;
        # confirm that calling generate_sql directly is intended here.
        sql_result = generate_sql(question, allow_llm_to_see_data=True)
        if not sql_result.get("success"):
            return self._handle_sql_generation_failure(state, sql_result)

        sql = sql_result.get("sql")
        state["sql"] = sql

        # Step 2: SQL validation (only when enabled by configuration).
        if self._is_sql_validation_enabled():
            validation_result = await self._validate_sql_with_schema_tools(sql)
            if not validation_result.get("valid"):
                # Validation failed: the handler decides whether to repair.
                return await self._handle_sql_validation_failure(state, sql, validation_result)

        # Generation and validation both succeeded.
        state["sql_generation_success"] = True
        state["execution_path"].append("agent_sql_generation")
        return state
    except Exception as e:
        state["error"] = f"SQL生成节点异常: {str(e)}"
        return state
async def _validate_sql_with_schema_tools(self, sql: str) -> Dict[str, Any]:
    """Validate generated SQL, honouring each check's own config switch.

    The syntax check and the forbidden-keyword check each run only when their
    flag in ``SQL_VALIDATION_CONFIG`` is set, so enabling one check no longer
    silently runs the other.

    Args:
        sql: The SQL statement to validate.

    Returns:
        ``{"valid": True}`` when every enabled check passes. Otherwise a dict
        with ``valid=False``, ``error_type`` (``"syntax_error"``,
        ``"forbidden_keywords"`` or ``"validation_exception"``),
        ``error_message`` and ``can_repair``.
    """
    try:
        # 1. Syntax validation (EXPLAIN SQL).
        # NOTE(review): this sends the statement to the database before the
        # keyword check has run; confirm that ordering is acceptable.
        if SQL_VALIDATION_CONFIG.get("enable_syntax_validation", False):
            syntax_valid = await self._validate_sql_syntax(sql)
            if not syntax_valid.get("valid"):
                return {
                    "valid": False,
                    "error_type": "syntax_error",
                    "error_message": syntax_valid.get("error"),
                    "can_repair": True
                }

        # 2. Forbidden keyword check (never repairable).
        if SQL_VALIDATION_CONFIG.get("enable_forbidden_check", False):
            forbidden_check = self._check_forbidden_keywords(sql)
            if not forbidden_check.get("valid"):
                return {
                    "valid": False,
                    "error_type": "forbidden_keywords",
                    "error_message": forbidden_check.get("error"),
                    "can_repair": False
                }

        return {"valid": True}
    except Exception as e:
        return {
            "valid": False,
            "error_type": "validation_exception",
            "error_message": str(e),
            "can_repair": False
        }
def _check_forbidden_keywords(self, sql: str) -> Dict[str, Any]:
    """Reject SQL that contains a data-modifying keyword.

    Keywords are matched as whole words, case-insensitively, so identifiers
    that merely contain a keyword (``updated_at``, ``is_deleted``) are not
    reported as forbidden operations.

    Args:
        sql: The SQL statement to inspect.

    Returns:
        ``{"valid": True}`` when no forbidden keyword is present, otherwise
        ``{"valid": False, "error": <message naming the first keyword found>}``.
    """
    import re  # local import: only part of the file is visible here

    forbidden_keywords = ['UPDATE', 'DELETE', 'DROP', 'ALTER', 'INSERT']
    sql_upper = sql.upper()
    for keyword in forbidden_keywords:
        # \b treats "_" as a word character, so UPDATED_AT does not match UPDATE.
        if re.search(rf"\b{keyword}\b", sql_upper):
            return {
                "valid": False,
                "error": f"不允许的操作: {keyword}。本系统只支持查询操作(SELECT)。"
            }
    return {"valid": True}
async def _validate_sql_syntax(self, sql: str) -> Dict[str, Any]:
    """Check SQL syntax by running ``EXPLAIN <sql>`` through the Vanna instance.

    Args:
        sql: The SQL statement to check.

    Returns:
        ``{"valid": True}`` when the EXPLAIN call completes without raising,
        otherwise ``{"valid": False, "error": <exception text>}``.
    """
    try:
        import asyncio  # local import: only part of the file is visible here

        # Obtain the database access object (reuses the existing connection logic).
        from common.vanna_instance import get_vanna_instance
        vn = get_vanna_instance()

        explain_sql = f"EXPLAIN {sql}"
        # The repair helper in this file wraps this same instance's calls in
        # asyncio.to_thread; do the same here instead of awaiting the result
        # of run_sql directly.
        # NOTE(review): assumes vn.run_sql is a regular blocking callable -
        # confirm against the actual database connection layer.
        await asyncio.to_thread(vn.run_sql, explain_sql)
        return {"valid": True}
    except Exception as e:
        return {
            "valid": False,
            "error": str(e)
        }
async def _handle_sql_validation_failure(self, state: AgentState, sql: str, validation_result: Dict) -> AgentState:
    """Handle a failed SQL validation, attempting one LLM repair if allowed.

    Forbidden-keyword failures are reported straight back to the user. Syntax
    errors get one repair attempt when auto repair is enabled; the repaired
    SQL is accepted only if it passes validation again. Every other outcome
    is recorded as a failure with a user-facing prompt.
    """
    failure_kind = validation_result.get("error_type")
    failure_message = validation_result.get("error_message")

    # Forbidden keywords: fail immediately.
    if failure_kind == "forbidden_keywords":
        state["sql_generation_success"] = False
        state["user_prompt"] = failure_message
        return state

    # Syntax error: a single LLM repair attempt.
    if failure_kind == "syntax_error" and self._is_auto_repair_enabled():
        candidate = await self._repair_sql_with_llm(sql, failure_message)
        if candidate:
            # Validate the repaired SQL once more before accepting it.
            recheck = await self._validate_sql_with_schema_tools(candidate)
            if recheck.get("valid"):
                state["sql"] = candidate
                state["sql_generation_success"] = True
                state["sql_repair_applied"] = True
                return state

    # Repair failed, was not attempted, or is not supported for this error.
    state["sql_generation_success"] = False
    state["user_prompt"] = f"SQL生成遇到问题: {failure_message}"
    return state
```python
async def _repair_sql_with_llm(self, sql: str, error_message: str) -> Optional[str]:
    """Ask the LLM to repair a SQL statement; a single attempt is made.

    Args:
        sql: The SQL statement that failed validation.
        error_message: The validation error text passed to the LLM.

    Returns:
        The LLM's answer with surrounding whitespace removed and, if the
        answer is wrapped in a Markdown code fence, with that fence removed.
        None when the answer is empty or the call raises.
    """
    try:
        import re  # local import: only part of the file is visible here

        from common.vanna_instance import get_vanna_instance
        vn = get_vanna_instance()

        # Build the repair prompt.
        repair_prompt = f"""你是一个PostgreSQL SQL专家,请修复以下SQL语句的语法错误。
当前数据库类型: PostgreSQL
错误信息: {error_message}
需要修复的SQL:
{sql}
修复要求:
1. 只修复语法错误和表结构错误
2. 保持SQL的原始业务逻辑不变
3. 使用PostgreSQL标准语法
4. 确保修复后的SQL语法正确
请直接输出修复后的SQL语句,不要添加其他说明文字。"""

        # Call the LLM in a worker thread so the awaiting coroutine is not blocked.
        response = await asyncio.to_thread(
            vn.chat_with_llm,
            question=repair_prompt,
            system_prompt="你是一个专业的PostgreSQL SQL专家,专门负责修复SQL语句中的语法错误。"
        )
        if not response:
            return None

        cleaned = response.strip()
        # A reply wrapped in ```sql ... ``` would otherwise be validated
        # fence and all; keep only the statement inside the fence.
        fenced = re.fullmatch(r"```[\w+-]*[ \t]*\r?\n(.*?)```", cleaned, re.DOTALL)
        if fenced:
            cleaned = fenced.group(1).strip()
        return cleaned or None
    except Exception as e:
        print(f"[ERROR] SQL修复失败: {str(e)}")
        return None
async def _handle_sql_validation_failure(self, state: AgentState, sql: str, validation_result: Dict) -> AgentState:
    """Handle a failed SQL validation; at most one repair attempt is made.

    Each outcome sets ``sql_generation_success`` on the state and appends a
    step name to ``state["execution_path"]``. Failures also set
    ``user_prompt``; a successful repair stores the repaired SQL and sets
    ``sql_repair_applied``.
    """

    def _reject(prompt, step):
        # Record a terminal failure on the state and hand the state back.
        state["sql_generation_success"] = False
        state["user_prompt"] = prompt
        state["execution_path"].append(step)
        return state

    failure_kind = validation_result.get("error_type")
    original_error = validation_result.get("error_message")

    # Forbidden keywords: fail immediately, never repaired.
    if failure_kind == "forbidden_keywords":
        return _reject(original_error, "forbidden_keywords_failed")

    # Repair disabled, or an error type that is not a syntax error.
    if failure_kind != "syntax_error" or not self._is_auto_repair_enabled():
        return _reject(f"SQL验证失败: {original_error}", "sql_validation_failed")

    # Syntax error: exactly one repair attempt.
    print(f"[SQL_REPAIR] 尝试修复SQL语法错误(仅一次): {original_error}")
    candidate = await self._repair_sql_with_llm(sql, original_error)
    if not candidate:
        # The LLM produced nothing usable.
        print("[SQL_REPAIR] LLM修复调用失败")
        return _reject(f"SQL语法修复失败: {original_error}", "sql_repair_failed")

    # Validate the repaired SQL; whatever the result, there is no retry.
    recheck = await self._validate_sql_with_schema_tools(candidate)
    if not recheck.get("valid"):
        recheck_error = recheck.get("error_message")
        print(f"[SQL_REPAIR] 修复后验证仍然失败: {recheck_error}")
        return _reject(f"SQL修复尝试失败: {recheck_error}", "sql_repair_failed")

    # Repair succeeded.
    state["sql"] = candidate
    state["sql_generation_success"] = True
    state["sql_repair_applied"] = True
    state["execution_path"].append("sql_repair_success")
    print("[SQL_REPAIR] SQL修复成功")
    return state
def _agent_sql_execution_node(self, state: AgentState) -> AgentState:
    """Execute the SQL stored on the state and optionally summarise the result.

    Reads ``state["sql"]`` and ``state["question"]``. On success stores
    ``state["query_result"]``, possibly ``state["summary"]``, and appends
    ``"agent_sql_execution"`` to ``state["execution_path"]``. An execution
    failure or any exception is recorded in ``state["error"]``.
    """
    try:
        statement = state.get("sql")
        user_question = state["question"]

        # Step 1: run the SQL.
        execution = execute_sql.invoke({"sql": statement})
        if not execution.get("success"):
            state["error"] = execution.get("error", "SQL执行失败")
            return state

        rows = execution.get("data_result")
        state["query_result"] = rows

        # Step 2: build a summary when configuration and the result call for one.
        if self._should_generate_summary(rows):
            summary_outcome = generate_summary.invoke({
                "question": self._extract_original_question(user_question),
                "query_result": rows,
                "sql": statement,
            })
            if summary_outcome.get("success"):
                summary_text = summary_outcome.get("summary")
            else:
                # A failed summary is not fatal: fall back to a row-count message.
                summary_text = f"查询执行完成,共返回 {rows.get('row_count', 0)} 条记录。"
            state["summary"] = summary_text

        state["execution_path"].append("agent_sql_execution")
    except Exception as exc:
        state["error"] = f"SQL执行节点异常: {exc!s}"
    return state
def _should_generate_summary(self, query_result: Dict) -> bool:
    """Decide whether a result summary should be generated.

    Args:
        query_result: The query result; its ``row_count`` entry is consulted.
            ``None`` is tolerated and treated as "no rows".

    Returns:
        True only when ``ENABLE_RESULT_SUMMARY`` is truthy and the result
        reports more than zero rows. Always a real ``bool``.
    """
    from app_config import ENABLE_RESULT_SUMMARY

    if not ENABLE_RESULT_SUMMARY or query_result is None:
        return False
    # A missing or None row_count counts as zero rows rather than raising.
    return (query_result.get('row_count') or 0) > 0
def _route_after_sql_generation(self, state: AgentState) -> Literal["continue_execution", "return_to_user"]:
    """Choose the branch taken after the SQL generation node.

    Returns ``"continue_execution"`` (on to the SQL execution node) when
    ``state["sql_generation_success"]`` is truthy, otherwise
    ``"return_to_user"`` (straight to format_response, ending the flow).
    """
    generated_ok = state.get("sql_generation_success")
    return "continue_execution" if generated_ok else "return_to_user"
def _create_workflow(self, routing_mode: str = None) -> StateGraph:
    """Build and compile the agent workflow with split SQL nodes.

    Args:
        routing_mode: Accepted for compatibility; not used by this function.

    Returns:
        The compiled workflow produced by ``workflow.compile()``.
    """
    workflow = StateGraph(AgentState)

    # Register the nodes.
    workflow.add_node("classify_question", self._classify_question_node)
    workflow.add_node("agent_sql_generation", self._agent_sql_generation_node)
    workflow.add_node("agent_sql_execution", self._agent_sql_execution_node)
    workflow.add_node("agent_chat", self._agent_chat_node)
    workflow.add_node("format_response", self._format_response_node)

    # The graph needs an entry point before it can be compiled; every flow
    # described for this workflow starts at classify_question.
    workflow.set_entry_point("classify_question")

    # Route by classification result.
    workflow.add_conditional_edges(
        "classify_question",
        self._route_after_classification,
        {
            "DATABASE": "agent_sql_generation",
            "CHAT": "agent_chat"
        }
    )

    # Route after SQL generation: execute on success, otherwise answer the user.
    workflow.add_conditional_edges(
        "agent_sql_generation",
        self._route_after_sql_generation,
        {
            "continue_execution": "agent_sql_execution",
            "return_to_user": "format_response"
        }
    )

    # Unconditional edges.
    workflow.add_edge("agent_sql_execution", "format_response")
    workflow.add_edge("agent_chat", "format_response")
    workflow.add_edge("format_response", END)
    return workflow.compile()
# Add to app_config.py
SQL_VALIDATION_CONFIG = dict(
    enable_syntax_validation=True,  # whether syntax validation (EXPLAIN SQL) is enabled
    enable_forbidden_check=True,    # whether the forbidden-keyword check is enabled
    enable_auto_repair=True,        # whether automatic repair (one attempt only) is enabled
)
# Existing configuration stays unchanged
ENABLE_RESULT_SUMMARY = True  # controls summary generation
def _is_sql_validation_enabled(self) -> bool:
    """Report whether any SQL validation check is switched on."""
    syntax_check = SQL_VALIDATION_CONFIG.get("enable_syntax_validation", False)
    forbidden_check = SQL_VALIDATION_CONFIG.get("enable_forbidden_check", False)
    # Either check being enabled counts as validation being enabled.
    return syntax_check or forbidden_check
def _is_auto_repair_enabled(self) -> bool:
    """Report whether automatic SQL repair is switched on."""
    repair_flag = SQL_VALIDATION_CONFIG.get("enable_auto_repair", False)
    syntax_flag = SQL_VALIDATION_CONFIG.get("enable_syntax_validation", False)
    # Auto repair only makes sense while syntax validation is enabled.
    return repair_flag and syntax_flag
def _should_skip_validation(self) -> bool:
    """Report whether the validation step should be skipped entirely."""
    # Validation is skipped exactly when every validation check is disabled.
    validation_on = self._is_sql_validation_enabled()
    return not validation_on
# 验证流程的完整决策树
if not self._is_sql_validation_enabled():
# 跳过所有验证,直接使用生成的SQL
pass
else:
# 按优先级执行验证
if SQL_VALIDATION_CONFIG.get("enable_syntax_validation"):
# 1. 语法验证 (EXPLAIN SQL)
syntax_result = await self._validate_sql_syntax(sql)
if not syntax_result.valid and self._is_auto_repair_enabled():
# 尝试修复 (只一次)
repaired_sql = await self._repair_sql_with_llm(sql, syntax_result.error)
# 修复后不管成功失败,都不再重试
if SQL_VALIDATION_CONFIG.get("enable_forbidden_check"):
# 2. 禁止词检查 (不可修复)
forbidden_result = self._check_forbidden_keywords(sql)
if not forbidden_result.valid:
# 直接失败,不尝试修复
return self._handle_forbidden_keywords_error(state, forbidden_result)
class AgentState(TypedDict):
    # Existing fields stay unchanged...
    # New fields
    sql_generation_success: bool  # whether SQL generation succeeded
    sql_repair_applied: bool  # whether a SQL repair was applied
    user_prompt: Optional[str]  # message shown to the user
    # Written by _handle_sql_generation_failure; declared here so every key
    # the nodes store on the state is part of the schema.
    sql_generation_failure_reason: Optional[str]
```python
def _handle_sql_generation_failure(self, state: AgentState, sql_result: Dict) -> AgentState:
    """Turn a failed SQL generation into a user-facing prompt on the state.

    No second classification is performed: the question was already
    classified as DATABASE, so every failure is answered with a prompt
    instead of being routed to CHAT.

    Args:
        state: Agent state; ``state["execution_path"]`` is appended to.
        sql_result: Result of the generation tool. ``error`` and
            ``error_type`` are read; a missing or ``None`` value is treated
            as an empty string.

    Returns:
        The state with ``sql_generation_success=False``, ``user_prompt`` and
        ``sql_generation_failure_reason`` set.
    """
    # ``or ""`` also covers an explicit None, which .get's default would not.
    error_message = str(sql_result.get("error") or "")
    error_type = sql_result.get("error_type") or ""
    lowered = error_message.lower()

    if "no relevant tables" in lowered or "table not found" in lowered:
        # Case 1: the database lacks the table or column.
        user_prompt = "数据库中没有相关的表或字段信息,请您提供更多具体信息或修改问题。"
        failure_reason = "missing_database_info"
    elif "ambiguous" in lowered or "more information" in lowered:
        # Case 2: the question is too vague.
        user_prompt = "您的问题需要更多信息才能准确查询,请提供更详细的描述。"
        failure_reason = "ambiguous_question"
    elif error_type == "llm_explanation":
        # Case 3: the LLM answered with explanatory text instead of SQL.
        user_prompt = error_message + " 请尝试重新描述您的问题或询问其他内容。"
        failure_reason = "llm_explanation"
    else:
        # Any other, unclassified failure.
        user_prompt = "无法生成有效的SQL查询,请尝试重新描述您的问题。"
        failure_reason = "unknown_generation_failure"

    state["sql_generation_success"] = False
    state["user_prompt"] = user_prompt
    state["sql_generation_failure_reason"] = failure_reason
    state["execution_path"].append("sql_generation_failed")
    print(f"[SQL_GENERATION] 生成失败: {failure_reason} - {user_prompt}")
    return state
def _format_response_node(self, state: AgentState) -> AgentState:
    """Format the final response, including the SQL-generation failure case.

    When SQL generation is marked as failed and a user prompt is present,
    ``state["final_response"]`` is set to a failure payload and the state is
    returned. The handling of every other case is elided in this snippet.
    """
    generation_failed = not state.get("sql_generation_success", True)
    if generation_failed and state.get("user_prompt"):
        classification_info = {
            "confidence": state.get("classification_confidence", 0),
            "reason": state.get("classification_reason", ""),
            "method": state.get("classification_method", ""),
        }
        state["final_response"] = {
            "success": False,
            "response": state["user_prompt"],
            "type": "DATABASE",
            "sql_generation_failed": True,
            "execution_path": state["execution_path"],
            "classification_info": classification_info,
        }
        return state
    # All other cases keep the original logic
    # ... (the original format_response logic)
`_agent_sql_generation_node` 节点
`_agent_sql_execution_node` 节点