|
@@ -53,14 +53,25 @@ class CituLangGraphAgent:
|
|
|
|
|
|
# 根据路由模式创建不同的工作流
|
|
|
if QUESTION_ROUTING_MODE == "database_direct":
|
|
|
- # 直接数据库模式:跳过分类,直接进入数据库处理
|
|
|
+ # 直接数据库模式:跳过分类,直接进入数据库处理(使用新的拆分节点)
|
|
|
workflow.add_node("init_direct_database", self._init_direct_database_node)
|
|
|
- workflow.add_node("agent_database", self._agent_database_node)
|
|
|
+ workflow.add_node("agent_sql_generation", self._agent_sql_generation_node)
|
|
|
+ workflow.add_node("agent_sql_execution", self._agent_sql_execution_node)
|
|
|
workflow.add_node("format_response", self._format_response_node)
|
|
|
|
|
|
workflow.set_entry_point("init_direct_database")
|
|
|
- workflow.add_edge("init_direct_database", "agent_database")
|
|
|
- workflow.add_edge("agent_database", "format_response")
|
|
|
+
|
|
|
+ # 添加条件路由
|
|
|
+ workflow.add_edge("init_direct_database", "agent_sql_generation")
|
|
|
+ workflow.add_conditional_edges(
|
|
|
+ "agent_sql_generation",
|
|
|
+ self._route_after_sql_generation,
|
|
|
+ {
|
|
|
+ "continue_execution": "agent_sql_execution",
|
|
|
+ "return_to_user": "format_response"
|
|
|
+ }
|
|
|
+ )
|
|
|
+ workflow.add_edge("agent_sql_execution", "format_response")
|
|
|
workflow.add_edge("format_response", END)
|
|
|
|
|
|
elif QUESTION_ROUTING_MODE == "chat_direct":
|
|
@@ -75,10 +86,11 @@ class CituLangGraphAgent:
|
|
|
workflow.add_edge("format_response", END)
|
|
|
|
|
|
else:
|
|
|
- # 其他模式(hybrid, llm_only):使用原有的分类工作流
|
|
|
+ # 其他模式(hybrid, llm_only):使用新的拆分工作流
|
|
|
workflow.add_node("classify_question", self._classify_question_node)
|
|
|
workflow.add_node("agent_chat", self._agent_chat_node)
|
|
|
- workflow.add_node("agent_database", self._agent_database_node)
|
|
|
+ workflow.add_node("agent_sql_generation", self._agent_sql_generation_node)
|
|
|
+ workflow.add_node("agent_sql_execution", self._agent_sql_execution_node)
|
|
|
workflow.add_node("format_response", self._format_response_node)
|
|
|
|
|
|
workflow.set_entry_point("classify_question")
|
|
@@ -88,13 +100,24 @@ class CituLangGraphAgent:
|
|
|
"classify_question",
|
|
|
self._route_after_classification,
|
|
|
{
|
|
|
- "DATABASE": "agent_database",
|
|
|
+ "DATABASE": "agent_sql_generation",
|
|
|
"CHAT": "agent_chat"
|
|
|
}
|
|
|
)
|
|
|
|
|
|
+ # 添加条件边:SQL生成后的路由
|
|
|
+ workflow.add_conditional_edges(
|
|
|
+ "agent_sql_generation",
|
|
|
+ self._route_after_sql_generation,
|
|
|
+ {
|
|
|
+ "continue_execution": "agent_sql_execution",
|
|
|
+ "return_to_user": "format_response"
|
|
|
+ }
|
|
|
+ )
|
|
|
+
|
|
|
+ # 普通边
|
|
|
workflow.add_edge("agent_chat", "format_response")
|
|
|
- workflow.add_edge("agent_database", "format_response")
|
|
|
+ workflow.add_edge("agent_sql_execution", "format_response")
|
|
|
workflow.add_edge("format_response", END)
|
|
|
|
|
|
return workflow.compile()
|
|
@@ -188,9 +211,250 @@ class CituLangGraphAgent:
|
|
|
state["execution_path"].append("classify_error")
|
|
|
return state
|
|
|
|
|
|
async def _agent_sql_generation_node(self, state: AgentState) -> AgentState:
    """Generate SQL for the question, validate it, and record the outcome in ``state``.

    Steps performed, in order:
      1. Call the ``generate_sql`` tool with the question.
      2. Treat LLM explanations / non-SQL text as a chat-style answer.
      3. If validation is enabled, run the forbidden-operation and syntax
         checks and, for repairable syntax errors, attempt one LLM repair.

    The routing decision is communicated through ``state["sql_generation_success"]``;
    failures additionally set ``user_prompt`` (or ``chat_response``) and
    ``validation_error_type``.

    Args:
        state: Mutable workflow state; must contain ``question`` and
            ``execution_path``.

    Returns:
        The same ``state`` object, updated in place. Exceptions are caught
        and reported through the state rather than raised.
    """

    def mark_generation_failed(failure_reason, user_prompt):
        # Shared bookkeeping for "no usable SQL was produced".
        state["sql_generation_success"] = False
        state["user_prompt"] = user_prompt
        state["validation_error_type"] = failure_reason
        state["current_step"] = "sql_generation_failed"
        state["execution_path"].append("agent_sql_generation_failed")
        print(f"[SQL_GENERATION] 生成失败: {failure_reason} - {user_prompt}")
        return state

    try:
        print(f"[SQL_GENERATION] 开始处理SQL生成和验证: {state['question']}")

        question = state["question"]

        # Step 1: generate SQL.
        print("[SQL_GENERATION] 步骤1:生成SQL")
        sql_result = generate_sql.invoke({"question": question, "allow_llm_to_see_data": True})

        if not sql_result.get("success"):
            # `or ""` guards against an explicit None stored under the key,
            # which would otherwise break `.lower()` and the concatenation.
            error_message = str(sql_result.get("error") or "")
            error_type = sql_result.get("error_type") or ""

            print(f"[SQL_GENERATION] SQL生成失败: {error_message}")

            # Map the tool error onto a user-facing prompt.
            lowered = error_message.lower()
            if "no relevant tables" in lowered or "table not found" in lowered:
                user_prompt = "数据库中没有相关的表或字段信息,请您提供更多具体信息或修改问题。"
                failure_reason = "missing_database_info"
            elif "ambiguous" in lowered or "more information" in lowered:
                user_prompt = "您的问题需要更多信息才能准确查询,请提供更详细的描述。"
                failure_reason = "ambiguous_question"
            elif error_type == "llm_explanation":
                user_prompt = error_message + " 请尝试重新描述您的问题或询问其他内容。"
                failure_reason = "llm_explanation"
            else:
                user_prompt = "无法生成有效的SQL查询,请尝试重新描述您的问题。"
                failure_reason = "unknown_generation_failure"

            return mark_generation_failed(failure_reason, user_prompt)

        sql = sql_result.get("sql")
        state["sql"] = sql
        print(f"[SQL_GENERATION] SQL生成成功: {sql}")

        # Step 1.5: the tool may flag its output as an explanation rather than SQL.
        error_type = sql_result.get("error_type")
        if error_type == "llm_explanation":
            # Return the explanation directly as the final answer.
            explanation = str(sql_result.get("error") or "")
            state["chat_response"] = explanation + " 请尝试提问其它问题。"
            state["sql_generation_success"] = False
            state["validation_error_type"] = "llm_explanation"
            state["current_step"] = "sql_generation_completed"
            state["execution_path"].append("agent_sql_generation")
            print(f"[SQL_GENERATION] 返回LLM解释性答案: {explanation}")
            return state

        # A "successful" result without SQL text cannot be validated or
        # echoed back; report it as a generic generation failure instead of
        # letting the string concatenation below raise.
        if not isinstance(sql, str) or not sql.strip():
            return mark_generation_failed(
                "unknown_generation_failure",
                "无法生成有效的SQL查询,请尝试重新描述您的问题。",
            )

        # Extra format check in case the tool misclassified plain text as SQL.
        from agent.utils import _is_valid_sql_format
        if not _is_valid_sql_format(sql):
            # Content does not look like SQL: treat it as an explanation.
            state["chat_response"] = sql + " 请尝试提问其它问题。"
            state["sql_generation_success"] = False
            state["validation_error_type"] = "invalid_sql_format"
            state["current_step"] = "sql_generation_completed"
            state["execution_path"].append("agent_sql_generation")
            print(f"[SQL_GENERATION] 内容不是有效SQL,当作解释返回: {sql}")
            return state

        # Step 2: validate the SQL (when enabled in the configuration).
        if self._is_sql_validation_enabled():
            print("[SQL_GENERATION] 步骤2:验证SQL")
            validation_result = await self._validate_sql_with_custom_priority(sql)

            if not validation_result.get("valid"):
                error_type = validation_result.get("error_type")
                error_message = validation_result.get("error_message")
                can_repair = validation_result.get("can_repair", False)

                print(f"[SQL_GENERATION] SQL验证失败: {error_type} - {error_message}")

                if error_type == "forbidden_keywords":
                    # Forbidden operations are never repaired.
                    state["sql_generation_success"] = False
                    state["sql_validation_success"] = False
                    state["user_prompt"] = error_message
                    state["validation_error_type"] = "forbidden_keywords"
                    state["current_step"] = "sql_validation_failed"
                    state["execution_path"].append("forbidden_keywords_failed")
                    print("[SQL_GENERATION] 禁止词验证失败,直接结束")
                    return state

                elif error_type == "syntax_error" and can_repair and self._is_auto_repair_enabled():
                    # Syntax error: try exactly one LLM-based repair.
                    print(f"[SQL_GENERATION] 尝试修复SQL语法错误(仅一次): {error_message}")
                    state["sql_repair_attempted"] = True

                    repair_result = await self._attempt_sql_repair_once(sql, error_message)

                    if repair_result.get("success"):
                        repaired_sql = repair_result.get("repaired_sql")
                        state["sql"] = repaired_sql
                        state["sql_generation_success"] = True
                        state["sql_validation_success"] = True
                        state["sql_repair_success"] = True
                        state["current_step"] = "sql_generation_completed"
                        state["execution_path"].append("sql_repair_success")
                        print(f"[SQL_GENERATION] SQL修复成功: {repaired_sql}")
                        return state
                    else:
                        # Repair failed: stop here.
                        repair_error = repair_result.get("error", "修复失败")
                        print(f"[SQL_GENERATION] SQL修复失败: {repair_error}")
                        state["sql_generation_success"] = False
                        state["sql_validation_success"] = False
                        state["sql_repair_success"] = False
                        state["user_prompt"] = f"SQL语法修复失败: {repair_error}"
                        state["validation_error_type"] = "syntax_repair_failed"
                        state["current_step"] = "sql_repair_failed"
                        state["execution_path"].append("sql_repair_failed")
                        return state
                else:
                    # Repair disabled, or an error type that is not repairable.
                    state["sql_generation_success"] = False
                    state["sql_validation_success"] = False
                    state["user_prompt"] = f"SQL验证失败: {error_message}"
                    state["validation_error_type"] = error_type
                    state["current_step"] = "sql_validation_failed"
                    state["execution_path"].append("sql_validation_failed")
                    print("[SQL_GENERATION] SQL验证失败,不尝试修复")
                    return state
            else:
                print("[SQL_GENERATION] SQL验证通过")
                state["sql_validation_success"] = True
        else:
            print("[SQL_GENERATION] 跳过SQL验证(未启用)")
            state["sql_validation_success"] = True

        # Generation and validation both succeeded.
        state["sql_generation_success"] = True
        state["current_step"] = "sql_generation_completed"
        state["execution_path"].append("agent_sql_generation")

        print("[SQL_GENERATION] SQL生成验证完成,准备执行")
        return state

    except Exception as e:
        print(f"[ERROR] SQL生成验证节点异常: {str(e)}")
        import traceback
        print(f"[ERROR] 详细错误信息: {traceback.format_exc()}")
        state["sql_generation_success"] = False
        state["sql_validation_success"] = False
        state["user_prompt"] = f"SQL生成验证异常: {str(e)}"
        state["validation_error_type"] = "node_exception"
        state["current_step"] = "sql_generation_error"
        state["execution_path"].append("agent_sql_generation_error")
        return state
|
|
|
+
|
|
|
def _agent_sql_execution_node(self, state: AgentState) -> AgentState:
    """Run the SQL stored in ``state`` and, when configured, summarise the result.

    Reads ``state["sql"]`` and ``state["question"]``; writes ``query_result``,
    optionally ``summary``, and the ``current_step`` / ``execution_path``
    bookkeeping. Any failure is reported through ``state["error"]`` with
    ``error_code`` 500 instead of being raised.

    Args:
        state: Mutable workflow state.

    Returns:
        The same ``state`` object, updated in place.
    """

    def record_failure(message):
        # Common error bookkeeping shared by every failure exit of this node.
        state["error"] = message
        state["error_code"] = 500
        state["current_step"] = "sql_execution_error"
        state["execution_path"].append("agent_sql_execution_error")
        return state

    try:
        print(f"[SQL_EXECUTION] 开始执行SQL: {state.get('sql', 'N/A')}")

        sql = state.get("sql")
        question = state["question"]

        # Guard: nothing to run.
        if not sql:
            print(f"[SQL_EXECUTION] 没有可执行的SQL")
            return record_failure("没有可执行的SQL语句")

        # Step 1: execute the statement through the execute_sql tool.
        print(f"[SQL_EXECUTION] 步骤1:执行SQL")
        execution = execute_sql.invoke({"sql": sql})

        if not execution.get("success"):
            print(f"[SQL_EXECUTION] SQL执行失败: {execution.get('error')}")
            return record_failure(execution.get("error", "SQL执行失败"))

        rows = execution.get("data_result")
        state["query_result"] = rows
        row_count = rows.get('row_count', 0)
        print(f"[SQL_EXECUTION] SQL执行成功,返回 {row_count} 行数据")

        # Step 2: summarise, but only when enabled and there is data.
        if not (ENABLE_RESULT_SUMMARY and row_count > 0):
            print(f"[SQL_EXECUTION] 跳过摘要生成(ENABLE_RESULT_SUMMARY={ENABLE_RESULT_SUMMARY},数据行数={row_count})")
            # `summary` is deliberately left untouched here; the response
            # formatting node decides what to show.
        else:
            print(f"[SQL_EXECUTION] 步骤2:生成摘要")

            # Use the original question, not the history-enhanced one, so
            # conversation history is not nested into the summary prompt.
            original_question = self._extract_original_question(question)
            print(f"[SQL_EXECUTION] 原始问题: {original_question}")

            summary_payload = {
                "question": original_question,
                "query_result": rows,
                "sql": sql,
            }
            summary_result = generate_summary.invoke(summary_payload)

            if summary_result.get("success"):
                state["summary"] = summary_result.get("summary")
                print(f"[SQL_EXECUTION] 摘要生成成功")
            else:
                print(f"[SQL_EXECUTION] 摘要生成失败: {summary_result.get('message')}")
                # A failed summary is not fatal: fall back to a default text.
                state["summary"] = f"查询执行完成,共返回 {row_count} 条记录。"

        state["current_step"] = "sql_execution_completed"
        state["execution_path"].append("agent_sql_execution")

        print(f"[SQL_EXECUTION] SQL执行完成")
        return state

    except Exception as e:
        print(f"[ERROR] SQL执行节点异常: {str(e)}")
        import traceback
        print(f"[ERROR] 详细错误信息: {traceback.format_exc()}")
        return record_failure(f"SQL执行失败: {str(e)}")
|
|
|
+
|
|
|
def _agent_database_node(self, state: AgentState) -> AgentState:
|
|
|
- """数据库Agent节点 - 直接工具调用模式"""
|
|
|
+ """
|
|
|
+ 数据库Agent节点 - 直接工具调用模式 [已废弃]
|
|
|
+
|
|
|
+ 注意:此方法已被拆分为 _agent_sql_generation_node 和 _agent_sql_execution_node
|
|
|
+ 保留此方法仅为向后兼容,新的工作流使用拆分后的节点
|
|
|
+ """
|
|
|
try:
|
|
|
+ print(f"[DATABASE_AGENT] ⚠️ 使用已废弃的database节点,建议使用新的拆分节点")
|
|
|
print(f"[DATABASE_AGENT] 开始处理数据库查询: {state['question']}")
|
|
|
|
|
|
question = state["question"]
|
|
@@ -362,7 +626,30 @@ class CituLangGraphAgent:
|
|
|
|
|
|
elif state["question_type"] == "DATABASE":
|
|
|
# 数据库查询类型
|
|
|
- if state.get("chat_response"):
|
|
|
+
|
|
|
+ # 处理SQL生成失败的情况
|
|
|
+ if not state.get("sql_generation_success", True) and state.get("user_prompt"):
|
|
|
+ state["final_response"] = {
|
|
|
+ "success": False,
|
|
|
+ "response": state["user_prompt"],
|
|
|
+ "type": "DATABASE",
|
|
|
+ "sql_generation_failed": True,
|
|
|
+ "validation_error_type": state.get("validation_error_type"),
|
|
|
+ "sql": state.get("sql"),
|
|
|
+ "execution_path": state["execution_path"],
|
|
|
+ "classification_info": {
|
|
|
+ "confidence": state["classification_confidence"],
|
|
|
+ "reason": state["classification_reason"],
|
|
|
+ "method": state["classification_method"]
|
|
|
+ },
|
|
|
+ "sql_validation_info": {
|
|
|
+ "sql_generation_success": state.get("sql_generation_success", False),
|
|
|
+ "sql_validation_success": state.get("sql_validation_success", False),
|
|
|
+ "sql_repair_attempted": state.get("sql_repair_attempted", False),
|
|
|
+ "sql_repair_success": state.get("sql_repair_success", False)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ elif state.get("chat_response"):
|
|
|
# SQL生成失败的解释性响应(不受ENABLE_RESULT_SUMMARY配置影响)
|
|
|
state["final_response"] = {
|
|
|
"success": True,
|
|
@@ -450,6 +737,23 @@ class CituLangGraphAgent:
|
|
|
}
|
|
|
return state
|
|
|
|
|
|
def _route_after_sql_generation(self, state: AgentState) -> Literal["continue_execution", "return_to_user"]:
    """Choose the edge to follow after the SQL generation node.

    Returns ``"continue_execution"`` when ``state["sql_generation_success"]``
    is truthy (the graph maps it to the SQL execution node) and
    ``"return_to_user"`` otherwise (mapped to ``format_response``). A missing
    flag counts as failure.
    """
    generated_ok = state.get("sql_generation_success", False)

    print(f"[ROUTE] SQL生成路由: success={generated_ok}")

    return "continue_execution" if generated_ok else "return_to_user"
|
|
|
+
|
|
|
def _route_after_classification(self, state: AgentState) -> Literal["DATABASE", "CHAT"]:
|
|
|
"""
|
|
|
分类后的路由决策
|
|
@@ -472,7 +776,7 @@ class CituLangGraphAgent:
|
|
|
# 聊天Agent可以处理不确定的情况,并在必要时引导用户提供更多信息
|
|
|
return "CHAT"
|
|
|
|
|
|
- def process_question(self, question: str, session_id: str = None, context_type: str = None, routing_mode: str = None) -> Dict[str, Any]:
|
|
|
+ async def process_question(self, question: str, session_id: str = None, context_type: str = None, routing_mode: str = None) -> Dict[str, Any]:
|
|
|
"""
|
|
|
统一的问题处理入口
|
|
|
|
|
@@ -499,7 +803,7 @@ class CituLangGraphAgent:
|
|
|
initial_state = self._create_initial_state(question, session_id, context_type, routing_mode)
|
|
|
|
|
|
# 执行工作流
|
|
|
- final_state = workflow.invoke(
|
|
|
+ final_state = await workflow.ainvoke(
|
|
|
initial_state,
|
|
|
config={
|
|
|
"configurable": {"session_id": session_id}
|
|
@@ -554,6 +858,14 @@ class CituLangGraphAgent:
|
|
|
query_result=None,
|
|
|
summary=None,
|
|
|
|
|
|
+ # SQL验证和修复相关状态
|
|
|
+ sql_generation_success=False,
|
|
|
+ sql_validation_success=False,
|
|
|
+ sql_repair_attempted=False,
|
|
|
+ sql_repair_success=False,
|
|
|
+ validation_error_type=None,
|
|
|
+ user_prompt=None,
|
|
|
+
|
|
|
# 聊天响应
|
|
|
chat_response=None,
|
|
|
|
|
@@ -577,6 +889,208 @@ class CituLangGraphAgent:
|
|
|
routing_mode=effective_routing_mode
|
|
|
)
|
|
|
|
|
|
+ # ==================== SQL验证和修复相关方法 ====================
|
|
|
+
|
|
|
def _is_sql_validation_enabled(self) -> bool:
    """Report whether any SQL validation step is switched on.

    Validation counts as enabled when either
    ``sql_validation.enable_syntax_validation`` or
    ``sql_validation.enable_forbidden_check`` is set in ``self.config``;
    both default to ``False`` here.
    """
    from agent.config import get_nested_config

    syntax_flag = get_nested_config(self.config, "sql_validation.enable_syntax_validation", False)
    if syntax_flag:
        return syntax_flag
    # Only consulted when syntax validation is off.
    return get_nested_config(self.config, "sql_validation.enable_forbidden_check", False)
|
|
|
+
|
|
|
def _is_auto_repair_enabled(self) -> bool:
    """Report whether automatic SQL repair may be attempted.

    Repair requires both ``sql_validation.enable_auto_repair`` and
    ``sql_validation.enable_syntax_validation`` to be set in ``self.config``;
    both default to ``False`` here.
    """
    from agent.config import get_nested_config

    repair_flag = get_nested_config(self.config, "sql_validation.enable_auto_repair", False)
    if not repair_flag:
        return repair_flag
    # Repair only makes sense when syntax validation is active as well.
    return get_nested_config(self.config, "sql_validation.enable_syntax_validation", False)
|
|
|
+
|
|
|
async def _validate_sql_with_custom_priority(self, sql: str) -> Dict[str, Any]:
    """Validate SQL in a fixed order: forbidden operations first, then syntax.

    Each check runs only when its configuration flag
    (``sql_validation.enable_forbidden_check`` /
    ``sql_validation.enable_syntax_validation``, both defaulting to ``True``
    here) is set.

    Args:
        sql: The SQL statement to validate.

    Returns:
        ``{"valid": True}`` on success; otherwise a dict with ``valid``
        ``False``, ``error_type`` (``forbidden_keywords``, ``syntax_error`` or
        ``validation_exception``), ``error_message`` and ``can_repair``.
    """

    def rejected(kind, message, repairable):
        # Uniform shape for every negative verdict.
        return {
            "valid": False,
            "error_type": kind,
            "error_message": message,
            "can_repair": repairable
        }

    try:
        from agent.config import get_nested_config

        # 1. Forbidden operations take priority and can never be repaired.
        forbidden_check_on = get_nested_config(self.config, "sql_validation.enable_forbidden_check", True)
        if forbidden_check_on:
            keyword_verdict = self._check_forbidden_keywords(sql)
            if not keyword_verdict.get("valid"):
                return rejected("forbidden_keywords", keyword_verdict.get("error"), False)

        # 2. Syntax check via EXPLAIN; such errors are candidates for repair.
        syntax_check_on = get_nested_config(self.config, "sql_validation.enable_syntax_validation", True)
        if syntax_check_on:
            syntax_verdict = await self._validate_sql_syntax(sql)
            if not syntax_verdict.get("valid"):
                return rejected("syntax_error", syntax_verdict.get("error"), True)

        return {"valid": True}

    except Exception as exc:
        return rejected("validation_exception", str(exc), False)
|
|
|
+
|
|
|
def _check_forbidden_keywords(self, sql: str) -> Dict[str, Any]:
    """Reject SQL in which any statement starts with a forbidden operation.

    The forbidden operations come from ``sql_validation.forbidden_operations``
    in ``self.config`` (default: UPDATE, DELETE, DROP, ALTER, INSERT).

    Every ``;``-separated statement is checked, not just the start of the
    whole text, and leading comments are ignored. Otherwise input such as
    ``SELECT 1; DROP TABLE t`` or ``/* x */ DELETE ...`` would pass.

    Quoted literals and comments are scrubbed in a single left-to-right pass
    first, so a ``;`` or keyword inside them is not misread.

    NOTE(review): dollar-quoted strings and data-modifying CTEs
    (``WITH ... AS (DELETE ...)``) are not recognised by this check.

    Args:
        sql: The SQL text to inspect.

    Returns:
        ``{"valid": True}`` when nothing forbidden is found, otherwise
        ``{"valid": False, "error": <message>}``. Internal errors are also
        reported as invalid rather than raised.
    """
    try:
        import re
        from agent.config import get_nested_config
        forbidden_operations = get_nested_config(
            self.config,
            "sql_validation.forbidden_operations",
            ['UPDATE', 'DELETE', 'DROP', 'ALTER', 'INSERT']
        )

        # One alternation, scanned left to right, so that whichever construct
        # opens first (literal, quoted identifier, or comment) wins.
        noise = re.compile(
            r"'(?:[^']|'')*'"        # single-quoted string literal
            r'|"(?:[^"]|"")*"'       # double-quoted identifier
            r"|--[^\n]*"             # line comment
            r"|/\*.*?\*/",           # block comment
            re.DOTALL,
        )

        def blank_out(match):
            token = match.group(0)
            # Keep an empty quoted token in place; drop comments entirely.
            return token[0] * 2 if token[0] in "'\"" else " "

        scrubbed = noise.sub(blank_out, sql)
        statements = [part.strip().upper() for part in scrubbed.split(";")]

        for statement in statements:
            if not statement:
                continue
            for operation in forbidden_operations:
                keyword = operation.upper()
                # An empty configured entry would match everything; skip it.
                if keyword and statement.startswith(keyword):
                    return {
                        "valid": False,
                        "error": f"不允许的操作: {operation}。本系统只支持查询操作(SELECT)。"
                    }

        return {"valid": True}

    except Exception as e:
        return {
            "valid": False,
            "error": f"禁止词检查异常: {str(e)}"
        }
|
|
|
+
|
|
|
async def _validate_sql_syntax(self, sql: str) -> Dict[str, Any]:
    """Check SQL syntax by running ``EXPLAIN <sql>`` through the Vanna instance.

    The blocking ``run_sql`` call is executed in a worker thread via
    ``asyncio.to_thread``.

    Args:
        sql: The SQL statement to check.

    Returns:
        ``{"valid": True}`` when ``run_sql`` returns a non-``None`` result,
        otherwise ``{"valid": False, "error": <message>}``. Any exception is
        converted into an invalid verdict carrying the exception text.
    """
    try:
        from common.vanna_instance import get_vanna_instance
        import asyncio

        vanna = get_vanna_instance()
        explain_output = await asyncio.to_thread(vanna.run_sql, f"EXPLAIN {sql}")
    except Exception as exc:
        return {
            "valid": False,
            "error": str(exc)
        }

    # A None result is treated as a failed validation.
    if explain_output is None:
        return {
            "valid": False,
            "error": "SQL语法验证失败"
        }
    return {"valid": True}
|
|
|
+
|
|
|
async def _attempt_sql_repair_once(self, sql: str, error_message: str) -> Dict[str, Any]:
    """Ask the LLM to repair a syntactically invalid SQL statement, exactly once.

    The LLM answer is stripped of an optional Markdown code fence, then
    re-checked against the forbidden operations (when that check is enabled),
    and finally re-validated with ``EXPLAIN`` before it is accepted.

    Args:
        sql: The original SQL statement.
        error_message: The validation error reported for ``sql``.

    Returns:
        A dict with ``success`` (bool), ``repaired_sql`` (str or ``None``) and
        ``error`` (str or ``None``). Timeouts and all other exceptions are
        reported through ``error`` rather than raised.
    """
    # Standard-library imports stay outside the try block: the
    # `except asyncio.TimeoutError` clause needs `asyncio` bound even when a
    # project import inside the try fails.
    import asyncio
    import re

    def failure(reason):
        return {
            "success": False,
            "repaired_sql": None,
            "error": reason
        }

    timeout = 60  # replaced by the configured value once the config is readable
    try:
        from common.vanna_instance import get_vanna_instance
        from agent.config import get_nested_config

        vn = get_vanna_instance()

        # Build the repair prompt.
        repair_prompt = f"""你是一个PostgreSQL SQL专家,请修复以下SQL语句的语法错误。

当前数据库类型: PostgreSQL
错误信息: {error_message}

需要修复的SQL:
{sql}

修复要求:
1. 只修复语法错误和表结构错误
2. 保持SQL的原始业务逻辑不变
3. 使用PostgreSQL标准语法
4. 确保修复后的SQL语法正确

请直接输出修复后的SQL语句,不要添加其他说明文字。"""

        timeout = get_nested_config(self.config, "sql_validation.repair_timeout", 60)

        # Run the blocking LLM call in a worker thread, bounded by the timeout.
        # NOTE: on timeout the await is abandoned; the worker thread itself is
        # not interrupted.
        response = await asyncio.wait_for(
            asyncio.to_thread(
                vn.chat_with_llm,
                question=repair_prompt,
                system_prompt="你是一个专业的PostgreSQL SQL专家,专门负责修复SQL语句中的语法错误。"
            ),
            timeout=timeout
        )

        repaired_sql = response.strip() if response else ""

        # LLMs frequently wrap SQL in a ```sql ... ``` fence; keep the inside only.
        fenced = re.search(r"```[ \t]*[A-Za-z0-9_+-]*[ \t]*\r?\n(.*?)```", repaired_sql, flags=re.DOTALL)
        if fenced:
            repaired_sql = fenced.group(1).strip()

        if not repaired_sql:
            return failure("LLM返回空响应")

        # The repaired text is new, unchecked LLM output: apply the
        # forbidden-operation check again before it can reach the database.
        if get_nested_config(self.config, "sql_validation.enable_forbidden_check", True):
            forbidden_result = self._check_forbidden_keywords(repaired_sql)
            if not forbidden_result.get("valid"):
                return failure(f"修复后的SQL仍然无效: {forbidden_result.get('error')}")

        # Re-validate the syntax of the repaired SQL.
        validation_result = await self._validate_sql_syntax(repaired_sql)
        if validation_result.get("valid"):
            return {
                "success": True,
                "repaired_sql": repaired_sql,
                "error": None
            }
        return failure(f"修复后的SQL仍然无效: {validation_result.get('error')}")

    except asyncio.TimeoutError:
        return failure(f"修复超时({timeout}秒)")
    except Exception as e:
        return failure(f"修复异常: {str(e)}")
|
|
|
+
|
|
|
+ # ==================== 原有方法 ====================
|
|
|
+
|
|
|
def _extract_original_question(self, question: str) -> str:
|
|
|
"""
|
|
|
从enhanced_question中提取原始问题
|
|
@@ -603,7 +1117,7 @@ class CituLangGraphAgent:
|
|
|
print(f"[WARNING] 提取原始问题失败: {str(e)}")
|
|
|
return question.strip()
|
|
|
|
|
|
- def health_check(self) -> Dict[str, Any]:
|
|
|
+ async def health_check(self) -> Dict[str, Any]:
|
|
|
"""健康检查"""
|
|
|
try:
|
|
|
# 从配置获取健康检查参数
|
|
@@ -613,7 +1127,7 @@ class CituLangGraphAgent:
|
|
|
|
|
|
if enable_full_test:
|
|
|
# 完整流程测试
|
|
|
- test_result = self.process_question(test_question, "health_check")
|
|
|
+ test_result = await self.process_question(test_question, "health_check")
|
|
|
|
|
|
return {
|
|
|
"status": "healthy" if test_result.get("success") else "degraded",
|