| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176 | # agent/citu_agent.pyfrom typing import Dict, Any, Literalfrom langgraph.graph import StateGraph, ENDfrom langchain.agents import AgentExecutor, create_openai_tools_agentfrom langchain.prompts import ChatPromptTemplate, MessagesPlaceholderfrom langchain_core.messages import SystemMessage, HumanMessagefrom core.logging import get_agent_loggerfrom agent.state import AgentStatefrom agent.classifier import QuestionClassifierfrom agent.tools import TOOLS, generate_sql, execute_sql, generate_summary, general_chatfrom agent.tools.utils import get_compatible_llmfrom app_config import ENABLE_RESULT_SUMMARYclass CituLangGraphAgent:    """Citu LangGraph智能助手主类 - 使用@tool装饰器 + Agent工具调用"""        def __init__(self):        # 初始化日志        self.logger = get_agent_logger("CituAgent")                # 加载配置        try:            from agent.config import get_current_config, get_nested_config            self.config = get_current_config()            self.logger.info("加载Agent配置完成")        except ImportError:            self.config = {}            self.logger.warning("配置文件不可用,使用默认配置")                self.classifier = QuestionClassifier()        self.tools = TOOLS        self.llm = get_compatible_llm()                # 注意:现在使用直接工具调用模式,不再需要预创建Agent执行器        self.logger.info("使用直接工具调用模式")                # 不在构造时创建workflow,改为动态创建以支持路由模式参数        # self.workflow = self._create_workflow()        self.logger.info("LangGraph Agent with Direct Tools初始化完成")        def _create_workflow(self, routing_mode: str = None) -> StateGraph:        """根据路由模式创建不同的工作流"""        # 确定使用的路由模式        if routing_mode:            QUESTION_ROUTING_MODE = routing_mode            self.logger.info(f"创建工作流,使用传入的路由模式: {QUESTION_ROUTING_MODE}")        else:            try:                from app_config import QUESTION_ROUTING_MODE                self.logger.info(f"创建工作流,使用配置文件路由模式: {QUESTION_ROUTING_MODE}")            except ImportError:                QUESTION_ROUTING_MODE = "hybrid"                self.logger.warning(f"配置导入失败,使用默认路由模式: {QUESTION_ROUTING_MODE}")                workflow = StateGraph(AgentState)                # 根据路由模式创建不同的工作流        if QUESTION_ROUTING_MODE == "database_direct":            # 直接数据库模式:跳过分类,直接进入数据库处理(使用新的拆分节点)            workflow.add_node("init_direct_database", self._init_direct_database_node)            workflow.add_node("agent_sql_generation", self._agent_sql_generation_node)            workflow.add_node("agent_sql_execution", self._agent_sql_execution_node)            workflow.add_node("format_response", self._format_response_node)                        workflow.set_entry_point("init_direct_database")                        # 添加条件路由            workflow.add_edge("init_direct_database", "agent_sql_generation")            workflow.add_conditional_edges(                "agent_sql_generation",                self._route_after_sql_generation,                {                    "continue_execution": "agent_sql_execution",                    "return_to_user": "format_response"                }            )            workflow.add_edge("agent_sql_execution", "format_response")            workflow.add_edge("format_response", END)                    elif QUESTION_ROUTING_MODE == "chat_direct":            # 直接聊天模式:跳过分类,直接进入聊天处理            workflow.add_node("init_direct_chat", self._init_direct_chat_node)            workflow.add_node("agent_chat", self._agent_chat_node)            workflow.add_node("format_response", self._format_response_node)                        workflow.set_entry_point("init_direct_chat")            workflow.add_edge("init_direct_chat", "agent_chat")            workflow.add_edge("agent_chat", "format_response")            workflow.add_edge("format_response", END)                    else:            # 其他模式(hybrid, llm_only):使用新的拆分工作流            workflow.add_node("classify_question", self._classify_question_node)            workflow.add_node("agent_chat", self._agent_chat_node)            workflow.add_node("agent_sql_generation", self._agent_sql_generation_node)            workflow.add_node("agent_sql_execution", self._agent_sql_execution_node)            workflow.add_node("format_response", self._format_response_node)                        workflow.set_entry_point("classify_question")                        # 添加条件边:分类后的路由            workflow.add_conditional_edges(                "classify_question",                self._route_after_classification,                {                    "DATABASE": "agent_sql_generation",                    "CHAT": "agent_chat"                }            )                        # 添加条件边:SQL生成后的路由            workflow.add_conditional_edges(                "agent_sql_generation",                self._route_after_sql_generation,                {                    "continue_execution": "agent_sql_execution",                    "return_to_user": "format_response"                }            )                        # 普通边            workflow.add_edge("agent_chat", "format_response")            workflow.add_edge("agent_sql_execution", "format_response")            workflow.add_edge("format_response", END)                return workflow.compile()        def _init_direct_database_node(self, state: AgentState) -> AgentState:        """初始化直接数据库模式的状态"""        try:            # 从state中获取路由模式,而不是从配置文件读取            routing_mode = state.get("routing_mode", "database_direct")                        # 设置直接数据库模式的分类状态            state["question_type"] = "DATABASE"            state["classification_confidence"] = 1.0            state["classification_reason"] = "配置为直接数据库查询模式"            state["classification_method"] = "direct_database"            state["routing_mode"] = routing_mode            state["current_step"] = "direct_database_init"            state["execution_path"].append("init_direct_database")                        self.logger.info("直接数据库模式初始化完成")                        return state                    except Exception as e:            self.logger.error(f"直接数据库模式初始化异常: {str(e)}")            state["error"] = f"直接数据库模式初始化失败: {str(e)}"            state["error_code"] = 500            state["execution_path"].append("init_direct_database_error")            return state    def _init_direct_chat_node(self, state: AgentState) -> AgentState:        """初始化直接聊天模式的状态"""        try:            # 从state中获取路由模式,而不是从配置文件读取            routing_mode = state.get("routing_mode", "chat_direct")                        # 设置直接聊天模式的分类状态            state["question_type"] = "CHAT"            state["classification_confidence"] = 1.0            state["classification_reason"] = "配置为直接聊天模式"            state["classification_method"] = "direct_chat"            state["routing_mode"] = routing_mode            state["current_step"] = "direct_chat_init"            state["execution_path"].append("init_direct_chat")                        self.logger.info("直接聊天模式初始化完成")                        return state                    except Exception as e:            self.logger.error(f"直接聊天模式初始化异常: {str(e)}")            state["error"] = f"直接聊天模式初始化失败: {str(e)}"            state["error_code"] = 500            state["execution_path"].append("init_direct_chat_error")            return state        def _classify_question_node(self, state: AgentState) -> AgentState:        """问题分类节点 - 支持渐进式分类策略"""        try:            # 从state中获取路由模式,而不是从配置文件读取            routing_mode = state.get("routing_mode", "hybrid")                        self.logger.info(f"开始分类问题: {state['question']}")                        # 获取上下文类型(如果有的话)            context_type = state.get("context_type")            if context_type:                self.logger.info(f"检测到上下文类型: {context_type}")                        # 使用渐进式分类策略,传递路由模式            classification_result = self.classifier.classify(state["question"], context_type, routing_mode)                        # 更新状态            state["question_type"] = classification_result.question_type            state["classification_confidence"] = classification_result.confidence            state["classification_reason"] = classification_result.reason            state["classification_method"] = classification_result.method            state["routing_mode"] = routing_mode            state["current_step"] = "classified"            state["execution_path"].append("classify")                        self.logger.info(f"分类结果: {classification_result.question_type}, 置信度: {classification_result.confidence}")            self.logger.info(f"路由模式: {routing_mode}, 分类方法: {classification_result.method}")                        return state                    except Exception as e:            self.logger.error(f"问题分类异常: {str(e)}")            state["error"] = f"问题分类失败: {str(e)}"            state["error_code"] = 500            state["execution_path"].append("classify_error")            return state            async def _agent_sql_generation_node(self, state: AgentState) -> AgentState:        """SQL生成验证节点 - 负责生成SQL、验证SQL和决定路由"""        try:            self.logger.info(f"开始处理SQL生成和验证: {state['question']}")                        question = state["question"]                        # 步骤1:生成SQL            self.logger.info("步骤1:生成SQL")            sql_result = generate_sql.invoke({"question": question, "allow_llm_to_see_data": True})                        if not sql_result.get("success"):                # SQL生成失败的统一处理                error_message = sql_result.get("error", "")                error_type = sql_result.get("error_type", "")                                #print(f"[SQL_GENERATION] SQL生成失败: {error_message}")                self.logger.debug(f"error_type = '{error_type}'")                                # 根据错误类型生成用户提示                if "no relevant tables" in error_message.lower() or "table not found" in error_message.lower():                    user_prompt = "数据库中没有相关的表或字段信息,请您提供更多具体信息或修改问题。"                    failure_reason = "missing_database_info"                elif "ambiguous" in error_message.lower() or "more information" in error_message.lower():                    user_prompt = "您的问题需要更多信息才能准确查询,请提供更详细的描述。"                    failure_reason = "ambiguous_question"                elif error_type == "llm_explanation" or error_type == "generation_failed_with_explanation":                    # 对于解释性文本,直接设置为聊天响应                    state["chat_response"] = error_message + " 请尝试提问其它问题。"                    state["sql_generation_success"] = False                    state["validation_error_type"] = "llm_explanation"                    state["current_step"] = "sql_generation_completed"                    state["execution_path"].append("agent_sql_generation")                    self.logger.info(f"返回LLM解释性答案: {error_message}")                    return state                else:                    user_prompt = "无法生成有效的SQL查询,请尝试重新描述您的问题。"                    failure_reason = "unknown_generation_failure"                                # 统一返回失败状态                state["sql_generation_success"] = False                state["user_prompt"] = user_prompt                state["validation_error_type"] = failure_reason                state["current_step"] = "sql_generation_failed"                state["execution_path"].append("agent_sql_generation_failed")                                self.logger.warning(f"生成失败: {failure_reason} - {user_prompt}")                return state                        sql = sql_result.get("sql")            state["sql"] = sql                        # 步骤1.5:检查是否为解释性响应而非SQL            error_type = sql_result.get("error_type")            if error_type == "llm_explanation" or error_type == "generation_failed_with_explanation":                # LLM返回了解释性文本,直接作为最终答案                explanation = sql_result.get("error", "")                state["chat_response"] = explanation + " 请尝试提问其它问题。"                state["sql_generation_success"] = False                state["validation_error_type"] = "llm_explanation"                state["current_step"] = "sql_generation_completed"                state["execution_path"].append("agent_sql_generation")                self.logger.info(f"返回LLM解释性答案: {explanation}")                return state                        if sql:                self.logger.info(f"SQL生成成功: {sql}")            else:                self.logger.warning("SQL为空,但不是解释性响应")                # 这种情况应该很少见,但为了安全起见保留原有的错误处理                return state                        # 额外验证:检查SQL格式(防止工具误判)            from agent.tools.utils import _is_valid_sql_format            if not _is_valid_sql_format(sql):                # 内容看起来不是SQL,当作解释性响应处理                state["chat_response"] = sql + " 请尝试提问其它问题。"                state["sql_generation_success"] = False                state["validation_error_type"] = "invalid_sql_format"                state["current_step"] = "sql_generation_completed"                  state["execution_path"].append("agent_sql_generation")                self.logger.info(f"内容不是有效SQL,当作解释返回: {sql}")                return state                        # 步骤2:SQL验证(如果启用)            if self._is_sql_validation_enabled():                self.logger.info("步骤2:验证SQL")                validation_result = await self._validate_sql_with_custom_priority(sql)                                if not validation_result.get("valid"):                    # 验证失败,检查是否可以修复                    error_type = validation_result.get("error_type")                    error_message = validation_result.get("error_message")                    can_repair = validation_result.get("can_repair", False)                                        self.logger.warning(f"SQL验证失败: {error_type} - {error_message}")                                        if error_type == "forbidden_keywords":                        # 禁止词错误,直接失败,不尝试修复                        state["sql_generation_success"] = False                        state["sql_validation_success"] = False                        state["user_prompt"] = error_message                        state["validation_error_type"] = "forbidden_keywords"                        state["current_step"] = "sql_validation_failed"                        state["execution_path"].append("forbidden_keywords_failed")                        self.logger.warning("禁止词验证失败,直接结束")                        return state                                        elif error_type == "syntax_error" and can_repair and self._is_auto_repair_enabled():                        # 语法错误,尝试修复(仅一次)                        self.logger.info(f"尝试修复SQL语法错误(仅一次): {error_message}")                        state["sql_repair_attempted"] = True                                                repair_result = await self._attempt_sql_repair_once(sql, error_message)                                                if repair_result.get("success"):                            # 修复成功                            repaired_sql = repair_result.get("repaired_sql")                            state["sql"] = repaired_sql                            state["sql_generation_success"] = True                            state["sql_validation_success"] = True                            state["sql_repair_success"] = True                            state["current_step"] = "sql_generation_completed"                            state["execution_path"].append("sql_repair_success")                            self.logger.info(f"SQL修复成功: {repaired_sql}")                            return state                        else:                            # 修复失败,直接结束                            repair_error = repair_result.get("error", "修复失败")                            self.logger.warning(f"SQL修复失败: {repair_error}")                            state["sql_generation_success"] = False                            state["sql_validation_success"] = False                            state["sql_repair_success"] = False                            state["user_prompt"] = f"SQL语法修复失败: {repair_error}"                            state["validation_error_type"] = "syntax_repair_failed"                            state["current_step"] = "sql_repair_failed"                            state["execution_path"].append("sql_repair_failed")                            return state                    else:                        # 不启用修复或其他错误类型,直接失败                        state["sql_generation_success"] = False                        state["sql_validation_success"] = False                        state["user_prompt"] = f"SQL验证失败: {error_message}"                        state["validation_error_type"] = error_type                        state["current_step"] = "sql_validation_failed"                        state["execution_path"].append("sql_validation_failed")                        self.logger.warning("SQL验证失败,不尝试修复")                        return state                else:                    self.logger.info("SQL验证通过")                    state["sql_validation_success"] = True            else:                self.logger.info("跳过SQL验证(未启用)")                state["sql_validation_success"] = True                        # 生成和验证都成功            state["sql_generation_success"] = True            state["current_step"] = "sql_generation_completed"            state["execution_path"].append("agent_sql_generation")                        self.logger.info("SQL生成验证完成,准备执行")            return state                    except Exception as e:            self.logger.error(f"SQL生成验证节点异常: {str(e)}")            import traceback            self.logger.error(f"详细错误信息: {traceback.format_exc()}")            state["sql_generation_success"] = False            state["sql_validation_success"] = False            state["user_prompt"] = f"SQL生成验证异常: {str(e)}"            state["validation_error_type"] = "node_exception"            state["current_step"] = "sql_generation_error"            state["execution_path"].append("agent_sql_generation_error")            return state    def _agent_sql_execution_node(self, state: AgentState) -> AgentState:        """SQL执行节点 - 负责执行已验证的SQL和生成摘要"""        try:            self.logger.info(f"开始执行SQL: {state.get('sql', 'N/A')}")                        sql = state.get("sql")            question = state["question"]                        if not sql:                self.logger.warning("没有可执行的SQL")                state["error"] = "没有可执行的SQL语句"                state["error_code"] = 500                state["current_step"] = "sql_execution_error"                state["execution_path"].append("agent_sql_execution_error")                return state                        # 步骤1:执行SQL            self.logger.info("步骤1:执行SQL")            execute_result = execute_sql.invoke({"sql": sql})                        if not execute_result.get("success"):                self.logger.error(f"SQL执行失败: {execute_result.get('error')}")                state["error"] = execute_result.get("error", "SQL执行失败")                state["error_code"] = 500                state["current_step"] = "sql_execution_error"                state["execution_path"].append("agent_sql_execution_error")                return state                        query_result = execute_result.get("data_result")            state["query_result"] = query_result            self.logger.info(f"SQL执行成功,返回 {query_result.get('row_count', 0)} 行数据")                        # 步骤2:生成摘要(根据配置和数据情况)            if ENABLE_RESULT_SUMMARY and query_result.get('row_count', 0) > 0:                self.logger.info("步骤2:生成摘要")                                # 重要:提取原始问题用于摘要生成,避免历史记录循环嵌套                original_question = self._extract_original_question(question)                self.logger.debug(f"原始问题: {original_question}")                                summary_result = generate_summary.invoke({                    "question": original_question,  # 使用原始问题而不是enhanced_question                    "query_result": query_result,                    "sql": sql                })                                if not summary_result.get("success"):                    self.logger.warning(f"摘要生成失败: {summary_result.get('message')}")                    # 摘要生成失败不是致命错误,使用默认摘要                    state["summary"] = f"查询执行完成,共返回 {query_result.get('row_count', 0)} 条记录。"                else:                    state["summary"] = summary_result.get("summary")                    self.logger.info("摘要生成成功")            else:                self.logger.info(f"跳过摘要生成(ENABLE_RESULT_SUMMARY={ENABLE_RESULT_SUMMARY},数据行数={query_result.get('row_count', 0)})")                # 不生成摘要时,不设置summary字段,让格式化响应节点决定如何处理                        state["current_step"] = "sql_execution_completed"            state["execution_path"].append("agent_sql_execution")                        self.logger.info("SQL执行完成")            return state                    except Exception as e:            self.logger.error(f"SQL执行节点异常: {str(e)}")            import traceback            self.logger.error(f"详细错误信息: {traceback.format_exc()}")            state["error"] = f"SQL执行失败: {str(e)}"            state["error_code"] = 500            state["current_step"] = "sql_execution_error"            state["execution_path"].append("agent_sql_execution_error")            return state    def _agent_database_node(self, state: AgentState) -> AgentState:        """        数据库Agent节点 - 直接工具调用模式 [已废弃]                注意:此方法已被拆分为 _agent_sql_generation_node 和 _agent_sql_execution_node        保留此方法仅为向后兼容,新的工作流使用拆分后的节点        """        try:            self.logger.warning("使用已废弃的database节点,建议使用新的拆分节点")            self.logger.info(f"开始处理数据库查询: {state['question']}")                        question = state["question"]                        # 步骤1:生成SQL            self.logger.info("步骤1:生成SQL")            sql_result = generate_sql.invoke({"question": question, "allow_llm_to_see_data": True})                        if not sql_result.get("success"):                self.logger.error(f"SQL生成失败: {sql_result.get('error')}")                state["error"] = sql_result.get("error", "SQL生成失败")                state["error_code"] = 500                state["current_step"] = "database_error"                state["execution_path"].append("agent_database_error")                return state                        sql = sql_result.get("sql")            state["sql"] = sql            self.logger.info(f"SQL生成成功: {sql}")                        # 步骤1.5:检查是否为解释性响应而非SQL            error_type = sql_result.get("error_type")            if error_type == "llm_explanation":                # LLM返回了解释性文本,直接作为最终答案                explanation = sql_result.get("error", "")                state["chat_response"] = explanation + " 请尝试提问其它问题。"                state["current_step"] = "database_completed"                state["execution_path"].append("agent_database")                self.logger.info(f"返回LLM解释性答案: {explanation}")                return state                        # 额外验证:检查SQL格式(防止工具误判)            from agent.tools.utils import _is_valid_sql_format            if not _is_valid_sql_format(sql):                # 内容看起来不是SQL,当作解释性响应处理                state["chat_response"] = sql + " 请尝试提问其它问题。"                state["current_step"] = "database_completed"                  state["execution_path"].append("agent_database")                self.logger.info(f"内容不是有效SQL,当作解释返回: {sql}")                return state                        # 步骤2:执行SQL            self.logger.info("步骤2:执行SQL")            execute_result = execute_sql.invoke({"sql": sql})                        if not execute_result.get("success"):                self.logger.error(f"SQL执行失败: {execute_result.get('error')}")                state["error"] = execute_result.get("error", "SQL执行失败")                state["error_code"] = 500                state["current_step"] = "database_error"                state["execution_path"].append("agent_database_error")                return state                        query_result = execute_result.get("data_result")            state["query_result"] = query_result            self.logger.info(f"SQL执行成功,返回 {query_result.get('row_count', 0)} 行数据")                        # 步骤3:生成摘要(可通过配置控制,仅在有数据时生成)            if ENABLE_RESULT_SUMMARY and query_result.get('row_count', 0) > 0:                self.logger.info("步骤3:生成摘要")                                # 重要:提取原始问题用于摘要生成,避免历史记录循环嵌套                original_question = self._extract_original_question(question)                self.logger.debug(f"原始问题: {original_question}")                                summary_result = generate_summary.invoke({                    "question": original_question,  # 使用原始问题而不是enhanced_question                    "query_result": query_result,                    "sql": sql                })                                if not summary_result.get("success"):                    self.logger.warning(f"摘要生成失败: {summary_result.get('message')}")                    # 摘要生成失败不是致命错误,使用默认摘要                    state["summary"] = f"查询执行完成,共返回 {query_result.get('row_count', 0)} 条记录。"                else:                    state["summary"] = summary_result.get("summary")                    self.logger.info("摘要生成成功")            else:                self.logger.info(f"跳过摘要生成(ENABLE_RESULT_SUMMARY={ENABLE_RESULT_SUMMARY},数据行数={query_result.get('row_count', 0)})")                # 不生成摘要时,不设置summary字段,让格式化响应节点决定如何处理                        state["current_step"] = "database_completed"            state["execution_path"].append("agent_database")                        self.logger.info("数据库查询完成")            return state                    except Exception as e:            self.logger.error(f"数据库Agent异常: {str(e)}")            import traceback            self.logger.error(f"详细错误信息: {traceback.format_exc()}")            state["error"] = f"数据库查询失败: {str(e)}"            state["error_code"] = 500            state["current_step"] = "database_error"            state["execution_path"].append("agent_database_error")            return state        def _agent_chat_node(self, state: AgentState) -> AgentState:        """聊天Agent节点 - 直接工具调用模式"""        try:            self.logger.info(f"开始处理聊天: {state['question']}")                        question = state["question"]                        # 构建上下文 - 仅使用真实的对话历史上下文            # 注意:不要将分类原因传递给LLM,那是系统内部的路由信息            enable_context_injection = self.config.get("chat_agent", {}).get("enable_context_injection", True)            context = None            if enable_context_injection:                # TODO: 在这里可以添加真实的对话历史上下文                # 例如从Redis或其他存储中获取最近的对话记录                # context = get_conversation_history(state.get("session_id"))                pass                        # 直接调用general_chat工具            self.logger.info("调用general_chat工具")            chat_result = general_chat.invoke({                "question": question,                "context": context            })                        if chat_result.get("success"):                state["chat_response"] = chat_result.get("response", "")                self.logger.info("聊天处理成功")            else:                # 处理失败,使用备用响应                state["chat_response"] = chat_result.get("response", "抱歉,我暂时无法处理您的问题。请稍后再试。")                self.logger.warning(f"聊天处理失败,使用备用响应: {chat_result.get('error')}")                        state["current_step"] = "chat_completed"            state["execution_path"].append("agent_chat")                        self.logger.info("聊天处理完成")            return state                    except Exception as e:            self.logger.error(f"聊天Agent异常: {str(e)}")            import traceback            self.logger.error(f"详细错误信息: {traceback.format_exc()}")            state["chat_response"] = "抱歉,我暂时无法处理您的问题。请稍后再试,或者尝试询问数据相关的问题。"            state["current_step"] = "chat_error"            state["execution_path"].append("agent_chat_error")            return state        def _format_response_node(self, state: AgentState) -> AgentState:        """格式化最终响应节点"""        try:            self.logger.info(f"开始格式化响应,问题类型: {state['question_type']}")                        state["current_step"] = "completed"            state["execution_path"].append("format_response")                        # 根据问题类型和执行状态格式化响应            if state.get("error"):                # 有错误的情况                state["final_response"] = {                    "success": False,                    "error": state["error"],                    "error_code": state.get("error_code", 500),                    "question_type": state["question_type"],                    "execution_path": state["execution_path"],                    "classification_info": {                        "confidence": state.get("classification_confidence", 0),                        "reason": state.get("classification_reason", ""),                        "method": state.get("classification_method", "")                    }                }                        elif state["question_type"] == "DATABASE":                # 数据库查询类型                                # 处理SQL生成失败的情况                if not state.get("sql_generation_success", True) and state.get("user_prompt"):                    state["final_response"] = {                        "success": False,                        "response": state["user_prompt"],                        "type": "DATABASE",                        "sql_generation_failed": True,                        "validation_error_type": state.get("validation_error_type"),                        "sql": state.get("sql"),                        "execution_path": state["execution_path"],                        "classification_info": {                            "confidence": state["classification_confidence"],                            "reason": state["classification_reason"],                            "method": state["classification_method"]                        },                        "sql_validation_info": {                            "sql_generation_success": state.get("sql_generation_success", False),                            "sql_validation_success": state.get("sql_validation_success", False),                            "sql_repair_attempted": state.get("sql_repair_attempted", False),                            "sql_repair_success": state.get("sql_repair_success", False)                        }                    }                elif state.get("chat_response"):                    # SQL生成失败的解释性响应(不受ENABLE_RESULT_SUMMARY配置影响)                    state["final_response"] = {                        "success": True,                        "response": state["chat_response"],                        "type": "DATABASE",                        "sql": state.get("sql"),                        "query_result": state.get("query_result"),  # 保持内部字段名不变                        "execution_path": state["execution_path"],                        "classification_info": {                            "confidence": state["classification_confidence"],                            "reason": state["classification_reason"],                            "method": state["classification_method"]                        }                    }                elif state.get("summary"):                    # 正常的数据库查询结果,有摘要的情况                    # 将summary的值同时赋给response字段(为将来移除summary字段做准备)                    state["final_response"] = {                        "success": True,                        "type": "DATABASE",                        "response": state["summary"],  # 新增:将summary的值赋给response                        "sql": state.get("sql"),                        "query_result": state.get("query_result"),  # 保持内部字段名不变                        "summary": state["summary"],  # 暂时保留summary字段                        "execution_path": state["execution_path"],                        "classification_info": {                            "confidence": state["classification_confidence"],                            "reason": state["classification_reason"],                            "method": state["classification_method"]                        }                    }                elif state.get("query_result"):                    # 有数据但没有摘要(摘要被配置禁用)                    query_result = state.get("query_result")                    row_count = query_result.get("row_count", 0)                                        # 构建基本响应,不包含summary字段和response字段                    # 用户应该直接从query_result.columns和query_result.rows获取数据                    state["final_response"] = {                        "success": True,                        "type": "DATABASE",                        "sql": state.get("sql"),                        "query_result": query_result,  # 保持内部字段名不变                        "execution_path": state["execution_path"],                        "classification_info": {                            "confidence": state["classification_confidence"],                            "reason": state["classification_reason"],                            "method": state["classification_method"]                        }                    }                else:                    # 数据库查询失败,没有任何结果                    state["final_response"] = {                        "success": False,                        "error": state.get("error", "数据库查询未完成"),                        "type": "DATABASE",                        "sql": state.get("sql"),                        "execution_path": state["execution_path"]                    }                        else:                # 聊天类型                state["final_response"] = {                    "success": True,                    "response": state.get("chat_response", ""),                    "type": "CHAT",                    "execution_path": state["execution_path"],                    "classification_info": {                        "confidence": state["classification_confidence"],                        "reason": state["classification_reason"],                        "method": state["classification_method"]                    }                }                        self.logger.info("响应格式化完成")            return state                    except Exception as e:            self.logger.error(f"响应格式化异常: {str(e)}")            state["final_response"] = {                "success": False,                "error": f"响应格式化异常: {str(e)}",                "error_code": 500,                "execution_path": state["execution_path"]            }            return state        def _route_after_sql_generation(self, state: AgentState) -> Literal["continue_execution", "return_to_user"]:        """        SQL生成后的路由决策                根据SQL生成和验证的结果决定后续流向:        - SQL生成验证成功 → 继续执行SQL        - SQL生成验证失败 → 直接返回用户提示        """        sql_generation_success = state.get("sql_generation_success", False)                self.logger.debug(f"SQL生成路由: success={sql_generation_success}")                if sql_generation_success:            return "continue_execution"  # 路由到SQL执行节点        else:            return "return_to_user"      # 路由到format_response,结束流程    def _route_after_classification(self, state: AgentState) -> Literal["DATABASE", "CHAT"]:        """        分类后的路由决策                完全信任QuestionClassifier的决策:        - DATABASE类型 → 数据库Agent        - CHAT和UNCERTAIN类型 → 聊天Agent                这样避免了双重决策的冲突,所有分类逻辑都集中在QuestionClassifier中        """        question_type = state["question_type"]        confidence = state["classification_confidence"]                self.logger.debug(f"分类路由: {question_type}, 置信度: {confidence} (完全信任分类器决策)")                if question_type == "DATABASE":            return "DATABASE"        else:            # 将 "CHAT" 和 "UNCERTAIN" 类型都路由到聊天流程            # 聊天Agent可以处理不确定的情况,并在必要时引导用户提供更多信息            return "CHAT"        async def process_question(self, question: str, session_id: str = None, context_type: str = None, routing_mode: str = None) -> Dict[str, Any]:        """        统一的问题处理入口                Args:            question: 用户问题            session_id: 会话ID            context_type: 上下文类型 ("DATABASE" 或 "CHAT"),用于渐进式分类            routing_mode: 路由模式,可选,用于覆盖配置文件设置                    Returns:            Dict包含完整的处理结果        """        try:            self.logger.info(f"开始处理问题: {question}")            if context_type:                self.logger.info(f"上下文类型: {context_type}")            if routing_mode:                self.logger.info(f"使用指定路由模式: {routing_mode}")                        # 动态创建workflow(基于路由模式)            workflow = self._create_workflow(routing_mode)                        # 初始化状态            initial_state = self._create_initial_state(question, session_id, context_type, routing_mode)                        # 执行工作流            final_state = await workflow.ainvoke(                initial_state,                config={                    "configurable": {"session_id": session_id}                } if session_id else None            )                        # 提取最终结果            result = final_state["final_response"]                        self.logger.info(f"问题处理完成: {result.get('success', False)}")                        return result                    except Exception as e:            self.logger.error(f"Agent执行异常: {str(e)}")            return {                "success": False,                "error": f"Agent系统异常: {str(e)}",                "error_code": 500,                "execution_path": ["error"]            }        def _create_initial_state(self, question: str, session_id: str = None, context_type: str = None, routing_mode: str = None) -> AgentState:        """创建初始状态 - 支持渐进式分类"""        # 确定使用的路由模式        if routing_mode:            effective_routing_mode = routing_mode        else:            try:                from app_config import QUESTION_ROUTING_MODE                effective_routing_mode = QUESTION_ROUTING_MODE            except ImportError:                effective_routing_mode = "hybrid"                return AgentState(            # 输入信息            question=question,            session_id=session_id,                        # 上下文信息            context_type=context_type,                        # 分类结果 (初始值,会在分类节点或直接模式初始化节点中更新)            question_type="UNCERTAIN",            classification_confidence=0.0,            classification_reason="",            classification_method="",                        # 数据库查询流程状态            sql=None,            sql_generation_attempts=0,            query_result=None,            summary=None,                        # SQL验证和修复相关状态            sql_generation_success=False,            sql_validation_success=False,            sql_repair_attempted=False,            sql_repair_success=False,            validation_error_type=None,            user_prompt=None,                        # 聊天响应            chat_response=None,                        # 最终输出            final_response={},                        # 错误处理            error=None,            error_code=None,                        # 流程控制            current_step="initialized",            execution_path=["start"],            retry_count=0,            max_retries=3,                        # 调试信息            debug_info={},                        # 路由模式            routing_mode=effective_routing_mode        )        # ==================== SQL验证和修复相关方法 ====================        def _is_sql_validation_enabled(self) -> bool:        """检查是否启用SQL验证"""        from agent.config import get_nested_config        return (get_nested_config(self.config, "sql_validation.enable_syntax_validation", False) or                 get_nested_config(self.config, "sql_validation.enable_forbidden_check", False))    def _is_auto_repair_enabled(self) -> bool:        """检查是否启用自动修复"""        from agent.config import get_nested_config        return (get_nested_config(self.config, "sql_validation.enable_auto_repair", False) and                 get_nested_config(self.config, "sql_validation.enable_syntax_validation", False))    async def _validate_sql_with_custom_priority(self, sql: str) -> Dict[str, Any]:        """        按照自定义优先级验证SQL:先禁止词,再语法                Args:            sql: 要验证的SQL语句                    Returns:            验证结果字典        """        try:            from agent.config import get_nested_config                        # 1. 优先检查禁止词(您要求的优先级)            if get_nested_config(self.config, "sql_validation.enable_forbidden_check", True):                forbidden_result = self._check_forbidden_keywords(sql)                if not forbidden_result.get("valid"):                    return {                        "valid": False,                        "error_type": "forbidden_keywords",                        "error_message": forbidden_result.get("error"),                        "can_repair": False  # 禁止词错误不能修复                    }                        # 2. 再检查语法(EXPLAIN SQL)            if get_nested_config(self.config, "sql_validation.enable_syntax_validation", True):                syntax_result = await self._validate_sql_syntax(sql)                if not syntax_result.get("valid"):                    return {                        "valid": False,                        "error_type": "syntax_error",                        "error_message": syntax_result.get("error"),                        "can_repair": True  # 语法错误可以尝试修复                    }                        return {"valid": True}                    except Exception as e:            return {                "valid": False,                "error_type": "validation_exception",                "error_message": str(e),                "can_repair": False            }    def _check_forbidden_keywords(self, sql: str) -> Dict[str, Any]:        """检查禁止的SQL关键词"""        try:            from agent.config import get_nested_config            forbidden_operations = get_nested_config(                self.config,                 "sql_validation.forbidden_operations",                 ['UPDATE', 'DELETE', 'DROP', 'ALTER', 'INSERT']            )                        sql_upper = sql.upper().strip()                        for operation in forbidden_operations:                if sql_upper.startswith(operation.upper()):                    return {                        "valid": False,                        "error": f"不允许的操作: {operation}。本系统只支持查询操作(SELECT)。"                    }                        return {"valid": True}                    except Exception as e:            return {                "valid": False,                "error": f"禁止词检查异常: {str(e)}"            }    async def _validate_sql_syntax(self, sql: str) -> Dict[str, Any]:        """语法验证 - 使用EXPLAIN SQL"""        try:            from common.vanna_instance import get_vanna_instance            import asyncio                        vn = get_vanna_instance()                        # 构建EXPLAIN查询            explain_sql = f"EXPLAIN {sql}"                        # 异步执行验证            result = await asyncio.to_thread(vn.run_sql, explain_sql)                        if result is not None:                return {"valid": True}            else:                return {                    "valid": False,                    "error": "SQL语法验证失败"                }                        except Exception as e:            return {                "valid": False,                "error": str(e)            }    async def _attempt_sql_repair_once(self, sql: str, error_message: str) -> Dict[str, Any]:        """        使用LLM尝试修复SQL - 只修复一次                Args:            sql: 原始SQL            error_message: 错误信息                    Returns:            修复结果字典        """        try:            from common.vanna_instance import get_vanna_instance            from agent.config import get_nested_config            import asyncio                        vn = get_vanna_instance()                        # 构建修复提示词            repair_prompt = f"""你是一个PostgreSQL SQL专家,请修复以下SQL语句的语法错误。当前数据库类型: PostgreSQL错误信息: {error_message}需要修复的SQL:{sql}修复要求:1. 只修复语法错误和表结构错误2. 保持SQL的原始业务逻辑不变  3. 使用PostgreSQL标准语法4. 确保修复后的SQL语法正确请直接输出修复后的SQL语句,不要添加其他说明文字。"""            # 获取超时配置            timeout = get_nested_config(self.config, "sql_validation.repair_timeout", 60)                        # 异步调用LLM修复            response = await asyncio.wait_for(                asyncio.to_thread(                    vn.chat_with_llm,                    question=repair_prompt,                    system_prompt="你是一个专业的PostgreSQL SQL专家,专门负责修复SQL语句中的语法错误。"                ),                timeout=timeout            )                        if response and response.strip():                repaired_sql = response.strip()                                # 验证修复后的SQL                validation_result = await self._validate_sql_syntax(repaired_sql)                                if validation_result.get("valid"):                    return {                        "success": True,                        "repaired_sql": repaired_sql,                        "error": None                    }                else:                    return {                        "success": False,                        "repaired_sql": None,                        "error": f"修复后的SQL仍然无效: {validation_result.get('error')}"                    }            else:                return {                    "success": False,                    "repaired_sql": None,                    "error": "LLM返回空响应"                }                        except asyncio.TimeoutError:            return {                "success": False,                "repaired_sql": None,                "error": f"修复超时({get_nested_config(self.config, 'sql_validation.repair_timeout', 60)}秒)"            }        except Exception as e:            return {                "success": False,                "repaired_sql": None,                "error": f"修复异常: {str(e)}"            }    # ==================== 原有方法 ====================        def _extract_original_question(self, question: str) -> str:        """        从enhanced_question中提取原始问题                Args:            question: 可能包含上下文的问题                    Returns:            str: 原始问题        """        try:            # 检查是否为enhanced_question格式            if "\n[CONTEXT]\n" in question and "\n[CURRENT]\n" in question:                # 提取[CURRENT]标签后的内容                current_start = question.find("\n[CURRENT]\n")                if current_start != -1:                    original_question = question[current_start + len("\n[CURRENT]\n"):].strip()                    return original_question                        # 如果不是enhanced_question格式,直接返回原问题            return question.strip()                    except Exception as e:            self.logger.warning(f"提取原始问题失败: {str(e)}")            return question.strip()    async def health_check(self) -> Dict[str, Any]:        """健康检查"""        try:            # 从配置获取健康检查参数            from agent.config import get_nested_config            test_question = get_nested_config(self.config, "health_check.test_question", "你好")            enable_full_test = get_nested_config(self.config, "health_check.enable_full_test", True)                        if enable_full_test:                # 完整流程测试                test_result = await self.process_question(test_question, "health_check")                                return {                    "status": "healthy" if test_result.get("success") else "degraded",                    "test_result": test_result.get("success", False),                    "workflow_compiled": True,  # 动态创建,始终可用                    "tools_count": len(self.tools),                    "agent_reuse_enabled": False,                    "message": "Agent健康检查完成"                }            else:                # 简单检查                return {                    "status": "healthy",                    "test_result": True,                    "workflow_compiled": True,  # 动态创建,始终可用                    "tools_count": len(self.tools),                    "agent_reuse_enabled": False,                    "message": "Agent简单健康检查完成"                }                    except Exception as e:            return {                "status": "unhealthy",                "error": str(e),                "workflow_compiled": True,  # 动态创建,始终可用                "tools_count": len(self.tools) if hasattr(self, 'tools') else 0,                "agent_reuse_enabled": False,                "message": "Agent健康检查失败"            }
 |