citu_agent.py 54 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175
  1. # agent/citu_agent.py
  2. from typing import Dict, Any, Literal
  3. from langgraph.graph import StateGraph, END
  4. from langchain.agents import AgentExecutor, create_openai_tools_agent
  5. from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
  6. from langchain_core.messages import SystemMessage, HumanMessage
  7. from core.logging import get_agent_logger
  8. from agent.state import AgentState
  9. from agent.classifier import QuestionClassifier
  10. from agent.tools import TOOLS, generate_sql, execute_sql, generate_summary, general_chat
  11. from agent.tools.utils import get_compatible_llm
  12. from app_config import ENABLE_RESULT_SUMMARY
  13. class CituLangGraphAgent:
  14. """Citu LangGraph智能助手主类 - 使用@tool装饰器 + Agent工具调用"""
  def __init__(self):
      """Initialise the logger, configuration, classifier, tools and LLM.

      No workflow is compiled here: the graph is built per request so the
      routing mode can be chosen at call time.
      """
      self.logger = get_agent_logger("CituAgent")

      # The agent configuration module is optional; an ImportError while
      # loading it downgrades to an empty configuration.
      try:
          from agent.config import get_current_config, get_nested_config
          loaded_config = get_current_config()
      except ImportError:
          self.config = {}
          self.logger.warning("配置文件不可用,使用默认配置")
      else:
          self.config = loaded_config
          self.logger.info("加载Agent配置完成")

      self.classifier = QuestionClassifier()
      self.tools = TOOLS
      self.llm = get_compatible_llm()

      # Tools are invoked directly by the node methods, so no agent executor
      # is created up front.
      self.logger.info("使用直接工具调用模式")
      self.logger.info("LangGraph Agent with Direct Tools初始化完成")
  def _create_workflow(self, routing_mode: str = None) -> StateGraph:
      """Build and compile the LangGraph workflow for one routing mode.

      Args:
          routing_mode: Explicit routing mode. When falsy, the mode is read
              from ``app_config.QUESTION_ROUTING_MODE``; if that import
              fails, ``"hybrid"`` is used.

      Returns:
          The result of ``workflow.compile()``.
      """
      # Resolve which routing mode drives the graph layout.
      if routing_mode:
          mode = routing_mode
          self.logger.info(f"创建工作流,使用传入的路由模式: {mode}")
      else:
          try:
              from app_config import QUESTION_ROUTING_MODE as mode
              self.logger.info(f"创建工作流,使用配置文件路由模式: {mode}")
          except ImportError:
              mode = "hybrid"
              self.logger.warning(f"配置导入失败,使用默认路由模式: {mode}")

      graph = StateGraph(AgentState)

      # Every SQL-capable layout branches the same way after SQL generation.
      sql_generation_routes = {
          "continue_execution": "agent_sql_execution",
          "return_to_user": "format_response",
      }

      if mode == "database_direct":
          # Skip classification and go straight to the SQL nodes.
          for node_name, handler in (
              ("init_direct_database", self._init_direct_database_node),
              ("agent_sql_generation", self._agent_sql_generation_node),
              ("agent_sql_execution", self._agent_sql_execution_node),
              ("format_response", self._format_response_node),
          ):
              graph.add_node(node_name, handler)
          graph.set_entry_point("init_direct_database")
          graph.add_edge("init_direct_database", "agent_sql_generation")
          graph.add_conditional_edges(
              "agent_sql_generation",
              self._route_after_sql_generation,
              sql_generation_routes,
          )
          graph.add_edge("agent_sql_execution", "format_response")
          graph.add_edge("format_response", END)
      elif mode == "chat_direct":
          # Skip classification and go straight to the chat node.
          for node_name, handler in (
              ("init_direct_chat", self._init_direct_chat_node),
              ("agent_chat", self._agent_chat_node),
              ("format_response", self._format_response_node),
          ):
              graph.add_node(node_name, handler)
          graph.set_entry_point("init_direct_chat")
          for source, target in (
              ("init_direct_chat", "agent_chat"),
              ("agent_chat", "format_response"),
              ("format_response", END),
          ):
              graph.add_edge(source, target)
      else:
          # Any other mode (e.g. hybrid, llm_only): classify first, then
          # branch into the chat node or the SQL nodes.
          for node_name, handler in (
              ("classify_question", self._classify_question_node),
              ("agent_chat", self._agent_chat_node),
              ("agent_sql_generation", self._agent_sql_generation_node),
              ("agent_sql_execution", self._agent_sql_execution_node),
              ("format_response", self._format_response_node),
          ):
              graph.add_node(node_name, handler)
          graph.set_entry_point("classify_question")
          graph.add_conditional_edges(
              "classify_question",
              self._route_after_classification,
              {
                  "DATABASE": "agent_sql_generation",
                  "CHAT": "agent_chat",
              },
          )
          graph.add_conditional_edges(
              "agent_sql_generation",
              self._route_after_sql_generation,
              sql_generation_routes,
          )
          for source, target in (
              ("agent_chat", "format_response"),
              ("agent_sql_execution", "format_response"),
              ("format_response", END),
          ):
              graph.add_edge(source, target)

      return graph.compile()
  def _init_direct_database_node(self, state: AgentState) -> AgentState:
      """Seed the state for the direct-database mode, where no classifier runs.

      The question is marked as ``DATABASE`` with confidence 1.0. The routing
      mode already stored in the state is kept, defaulting to
      ``"database_direct"``.

      Returns:
          The same state object; on failure it carries ``error`` and
          ``error_code`` 500.
      """
      try:
          # The routing mode comes from the state, not from the config file.
          preset_fields = {
              "question_type": "DATABASE",
              "classification_confidence": 1.0,
              "classification_reason": "配置为直接数据库查询模式",
              "classification_method": "direct_database",
              "routing_mode": state.get("routing_mode", "database_direct"),
              "current_step": "direct_database_init",
          }
          for field, value in preset_fields.items():
              state[field] = value
          state["execution_path"].append("init_direct_database")
          self.logger.info("直接数据库模式初始化完成")
      except Exception as e:
          self.logger.error(f"直接数据库模式初始化异常: {str(e)}")
          state["error"] = f"直接数据库模式初始化失败: {str(e)}"
          state["error_code"] = 500
          state["execution_path"].append("init_direct_database_error")
      return state
  def _init_direct_chat_node(self, state: AgentState) -> AgentState:
      """Seed the state for the direct-chat mode, where no classifier runs.

      The question is marked as ``CHAT`` with confidence 1.0. The routing
      mode already stored in the state is kept, defaulting to
      ``"chat_direct"``.

      Returns:
          The same state object; on failure it carries ``error`` and
          ``error_code`` 500.
      """
      try:
          # The routing mode comes from the state, not from the config file.
          preset_fields = {
              "question_type": "CHAT",
              "classification_confidence": 1.0,
              "classification_reason": "配置为直接聊天模式",
              "classification_method": "direct_chat",
              "routing_mode": state.get("routing_mode", "chat_direct"),
              "current_step": "direct_chat_init",
          }
          for field, value in preset_fields.items():
              state[field] = value
          state["execution_path"].append("init_direct_chat")
          self.logger.info("直接聊天模式初始化完成")
      except Exception as e:
          self.logger.error(f"直接聊天模式初始化异常: {str(e)}")
          state["error"] = f"直接聊天模式初始化失败: {str(e)}"
          state["error_code"] = 500
          state["execution_path"].append("init_direct_chat_error")
      return state
  def _classify_question_node(self, state: AgentState) -> AgentState:
      """Classify the question and store the classifier's verdict in the state.

      The question, the optional ``context_type`` and the routing mode
      (default ``"hybrid"``) are passed to ``self.classifier.classify``. The
      result's ``question_type``, ``confidence``, ``reason`` and ``method``
      are copied into the state.

      Returns:
          The same state object; on failure it carries ``error`` and
          ``error_code`` 500.
      """
      try:
          # The routing mode comes from the state, not from the config file.
          mode = state.get("routing_mode", "hybrid")
          self.logger.info(f"开始分类问题: {state['question']}")

          # An optional context type is forwarded to the classifier.
          ctx_type = state.get("context_type")
          if ctx_type:
              self.logger.info(f"检测到上下文类型: {ctx_type}")

          verdict = self.classifier.classify(state["question"], ctx_type, mode)

          state["question_type"] = verdict.question_type
          state["classification_confidence"] = verdict.confidence
          state["classification_reason"] = verdict.reason
          state["classification_method"] = verdict.method
          state["routing_mode"] = mode
          state["current_step"] = "classified"
          state["execution_path"].append("classify")

          self.logger.info(f"分类结果: {verdict.question_type}, 置信度: {verdict.confidence}")
          self.logger.info(f"路由模式: {mode}, 分类方法: {verdict.method}")
      except Exception as e:
          self.logger.error(f"问题分类异常: {str(e)}")
          state["error"] = f"问题分类失败: {str(e)}"
          state["error_code"] = 500
          state["execution_path"].append("classify_error")
      return state
  async def _agent_sql_generation_node(self, state: AgentState) -> AgentState:
      """Generate SQL for the question, validate it, and record the outcome.

      Steps:
        1. Call the ``generate_sql`` tool with the question.
        2. If the tool reports an explanation instead of SQL, or the returned
           text does not look like SQL, store that text in ``chat_response``.
        3. If SQL validation is enabled, validate the SQL. A repairable
           syntax error is repaired at most once when auto-repair is enabled;
           forbidden keywords are never repaired.

      ``sql_generation_success`` is written on every exit path, including the
      empty-SQL path, so the router that follows this node and the response
      formatter always see an explicit outcome.

      Args:
          state: Workflow state; ``state["question"]`` must be present.

      Returns:
          The same state object, updated in place.
      """
      explanation_types = ("llm_explanation", "generation_failed_with_explanation")

      def reply_with_text(text, validation_error_type):
          # Hand non-SQL text back to the user through chat_response.
          state["chat_response"] = text + " 请尝试提问其它问题。"
          state["sql_generation_success"] = False
          state["validation_error_type"] = validation_error_type
          state["current_step"] = "sql_generation_completed"
          state["execution_path"].append("agent_sql_generation")
          return state

      def reject(user_prompt, validation_error_type, step, path_marker, **extra_flags):
          # Record a failure whose user_prompt is shown to the user.
          state["sql_generation_success"] = False
          for flag, value in extra_flags.items():
              state[flag] = value
          state["user_prompt"] = user_prompt
          state["validation_error_type"] = validation_error_type
          state["current_step"] = step
          state["execution_path"].append(path_marker)
          return state

      try:
          question = state["question"]
          self.logger.info(f"开始处理SQL生成和验证: {question}")

          # Step 1: generate SQL.
          self.logger.info("步骤1:生成SQL")
          sql_result = generate_sql.invoke({"question": question, "allow_llm_to_see_data": True})

          if not sql_result.get("success"):
              # `or ""` guards against an explicit None error value.
              error_message = sql_result.get("error") or ""
              error_type = sql_result.get("error_type", "")
              self.logger.debug(f"error_type = '{error_type}'")
              lowered = error_message.lower()

              # Map the tool's error onto a user-facing prompt.
              if "no relevant tables" in lowered or "table not found" in lowered:
                  user_prompt = "数据库中没有相关的表或字段信息,请您提供更多具体信息或修改问题。"
                  failure_reason = "missing_database_info"
              elif "ambiguous" in lowered or "more information" in lowered:
                  user_prompt = "您的问题需要更多信息才能准确查询,请提供更详细的描述。"
                  failure_reason = "ambiguous_question"
              elif error_type in explanation_types:
                  # Explanatory text from the LLM is returned as the answer.
                  self.logger.info(f"返回LLM解释性答案: {error_message}")
                  return reply_with_text(error_message, "llm_explanation")
              else:
                  user_prompt = "无法生成有效的SQL查询,请尝试重新描述您的问题。"
                  failure_reason = "unknown_generation_failure"

              self.logger.warning(f"生成失败: {failure_reason} - {user_prompt}")
              return reject(
                  user_prompt,
                  failure_reason,
                  "sql_generation_failed",
                  "agent_sql_generation_failed",
              )

          sql = sql_result.get("sql")
          state["sql"] = sql

          # Step 1.5: a "successful" result may still be an explanation.
          if sql_result.get("error_type") in explanation_types:
              explanation = sql_result.get("error") or ""
              self.logger.info(f"返回LLM解释性答案: {explanation}")
              return reply_with_text(explanation, "llm_explanation")

          if not sql:
              # Previously this path returned without recording a failure,
              # leaving the outcome flags and user prompt unset.
              self.logger.warning("SQL为空,但不是解释性响应")
              return reject(
                  "无法生成有效的SQL查询,请尝试重新描述您的问题。",
                  "unknown_generation_failure",
                  "sql_generation_failed",
                  "agent_sql_generation_failed",
              )
          self.logger.info(f"SQL生成成功: {sql}")

          # Extra guard against the tool mislabelling plain text as SQL.
          from agent.tools.utils import _is_valid_sql_format
          if not _is_valid_sql_format(sql):
              self.logger.info(f"内容不是有效SQL,当作解释返回: {sql}")
              return reply_with_text(sql, "invalid_sql_format")

          # Step 2: validate the SQL when validation is enabled.
          if not self._is_sql_validation_enabled():
              self.logger.info("跳过SQL验证(未启用)")
          else:
              self.logger.info("步骤2:验证SQL")
              validation_result = await self._validate_sql_with_custom_priority(sql)

              if not validation_result.get("valid"):
                  error_type = validation_result.get("error_type")
                  error_message = validation_result.get("error_message")
                  can_repair = validation_result.get("can_repair", False)
                  self.logger.warning(f"SQL验证失败: {error_type} - {error_message}")

                  if error_type == "forbidden_keywords":
                      # Forbidden keywords are never repaired.
                      self.logger.warning("禁止词验证失败,直接结束")
                      return reject(
                          error_message,
                          "forbidden_keywords",
                          "sql_validation_failed",
                          "forbidden_keywords_failed",
                          sql_validation_success=False,
                      )

                  if error_type == "syntax_error" and can_repair and self._is_auto_repair_enabled():
                      # Syntax errors get exactly one repair attempt.
                      self.logger.info(f"尝试修复SQL语法错误(仅一次): {error_message}")
                      state["sql_repair_attempted"] = True
                      repair_result = await self._attempt_sql_repair_once(sql, error_message)

                      if not repair_result.get("success"):
                          repair_error = repair_result.get("error", "修复失败")
                          self.logger.warning(f"SQL修复失败: {repair_error}")
                          return reject(
                              f"SQL语法修复失败: {repair_error}",
                              "syntax_repair_failed",
                              "sql_repair_failed",
                              "sql_repair_failed",
                              sql_validation_success=False,
                              sql_repair_success=False,
                          )

                      repaired_sql = repair_result.get("repaired_sql")
                      state["sql"] = repaired_sql
                      state["sql_generation_success"] = True
                      state["sql_validation_success"] = True
                      state["sql_repair_success"] = True
                      state["current_step"] = "sql_generation_completed"
                      state["execution_path"].append("sql_repair_success")
                      self.logger.info(f"SQL修复成功: {repaired_sql}")
                      return state

                  # Repair disabled, not repairable, or another error type.
                  self.logger.warning("SQL验证失败,不尝试修复")
                  return reject(
                      f"SQL验证失败: {error_message}",
                      error_type,
                      "sql_validation_failed",
                      "sql_validation_failed",
                      sql_validation_success=False,
                  )

              self.logger.info("SQL验证通过")

          # Generation succeeded and validation passed or was skipped.
          state["sql_validation_success"] = True
          state["sql_generation_success"] = True
          state["current_step"] = "sql_generation_completed"
          state["execution_path"].append("agent_sql_generation")
          self.logger.info("SQL生成验证完成,准备执行")
          return state

      except Exception as e:
          self.logger.error(f"SQL生成验证节点异常: {str(e)}")
          import traceback
          self.logger.error(f"详细错误信息: {traceback.format_exc()}")
          return reject(
              f"SQL生成验证异常: {str(e)}",
              "node_exception",
              "sql_generation_error",
              "agent_sql_generation_error",
              sql_validation_success=False,
          )
  def _agent_sql_execution_node(self, state: AgentState) -> AgentState:
      """Run the SQL stored in the state and optionally summarise the result.

      The ``execute_sql`` tool is called with ``state["sql"]``. When
      ``ENABLE_RESULT_SUMMARY`` is truthy and the result has at least one
      row, ``generate_summary`` is called; a failed summary falls back to a
      generic row-count sentence. When no summary is generated, ``summary``
      is left unset.

      Returns:
          The same state object; failures set ``error`` and ``error_code`` 500.
      """
      def mark_failed(message):
          # All failure exits of this node record the same four fields.
          state["error"] = message
          state["error_code"] = 500
          state["current_step"] = "sql_execution_error"
          state["execution_path"].append("agent_sql_execution_error")
          return state

      try:
          self.logger.info(f"开始执行SQL: {state.get('sql', 'N/A')}")
          sql = state.get("sql")
          question = state["question"]

          if not sql:
              self.logger.warning("没有可执行的SQL")
              return mark_failed("没有可执行的SQL语句")

          # Step 1: execute the SQL.
          self.logger.info("步骤1:执行SQL")
          execute_result = execute_sql.invoke({"sql": sql})
          if not execute_result.get("success"):
              self.logger.error(f"SQL执行失败: {execute_result.get('error')}")
              return mark_failed(execute_result.get("error", "SQL执行失败"))

          query_result = execute_result.get("data_result")
          state["query_result"] = query_result
          row_count = query_result.get('row_count', 0)
          self.logger.info(f"SQL执行成功,返回 {row_count} 行数据")

          # Step 2: summarise, depending on configuration and row count.
          if ENABLE_RESULT_SUMMARY and row_count > 0:
              self.logger.info("步骤2:生成摘要")
              # Summarise against the original question so that injected
              # history is not nested into the summary prompt.
              original_question = self._extract_original_question(question)
              self.logger.debug(f"原始问题: {original_question}")
              summary_result = generate_summary.invoke({
                  "question": original_question,
                  "query_result": query_result,
                  "sql": sql
              })
              if summary_result.get("success"):
                  state["summary"] = summary_result.get("summary")
                  self.logger.info("摘要生成成功")
              else:
                  # A failed summary is not fatal; use a default sentence.
                  self.logger.warning(f"摘要生成失败: {summary_result.get('message')}")
                  state["summary"] = f"查询执行完成,共返回 {row_count} 条记录。"
          else:
              # summary stays unset; the formatting node decides what to do.
              self.logger.info(f"跳过摘要生成(ENABLE_RESULT_SUMMARY={ENABLE_RESULT_SUMMARY},数据行数={row_count})")

          state["current_step"] = "sql_execution_completed"
          state["execution_path"].append("agent_sql_execution")
          self.logger.info("SQL执行完成")
          return state
      except Exception as e:
          self.logger.error(f"SQL执行节点异常: {str(e)}")
          import traceback
          self.logger.error(f"详细错误信息: {traceback.format_exc()}")
          return mark_failed(f"SQL执行失败: {str(e)}")
  def _agent_database_node(self, state: AgentState) -> AgentState:
      """Deprecated single-node database flow: generate, execute, summarise.

      This method has been split into ``_agent_sql_generation_node`` and
      ``_agent_sql_execution_node``. It is kept only for backward
      compatibility; the workflows built by this class use the split nodes.

      Returns:
          The same state object; failures set ``error`` and ``error_code`` 500.
      """
      def mark_failed(message):
          # All failure exits of this node record the same four fields.
          state["error"] = message
          state["error_code"] = 500
          state["current_step"] = "database_error"
          state["execution_path"].append("agent_database_error")
          return state

      def reply_with_text(text):
          # Return non-SQL text to the user as the final answer.
          state["chat_response"] = text + " 请尝试提问其它问题。"
          state["current_step"] = "database_completed"
          state["execution_path"].append("agent_database")
          return state

      try:
          self.logger.warning("使用已废弃的database节点,建议使用新的拆分节点")
          self.logger.info(f"开始处理数据库查询: {state['question']}")
          question = state["question"]

          # Step 1: generate SQL.
          self.logger.info("步骤1:生成SQL")
          sql_result = generate_sql.invoke({"question": question, "allow_llm_to_see_data": True})
          if not sql_result.get("success"):
              self.logger.error(f"SQL生成失败: {sql_result.get('error')}")
              return mark_failed(sql_result.get("error", "SQL生成失败"))

          sql = sql_result.get("sql")
          state["sql"] = sql
          self.logger.info(f"SQL生成成功: {sql}")

          # Step 1.5: the tool may have returned an explanation, not SQL.
          if sql_result.get("error_type") == "llm_explanation":
              explanation = sql_result.get("error", "")
              result_state = reply_with_text(explanation)
              self.logger.info(f"返回LLM解释性答案: {explanation}")
              return result_state

          # Extra guard against the tool mislabelling plain text as SQL.
          from agent.tools.utils import _is_valid_sql_format
          if not _is_valid_sql_format(sql):
              result_state = reply_with_text(sql)
              self.logger.info(f"内容不是有效SQL,当作解释返回: {sql}")
              return result_state

          # Step 2: execute the SQL.
          self.logger.info("步骤2:执行SQL")
          execute_result = execute_sql.invoke({"sql": sql})
          if not execute_result.get("success"):
              self.logger.error(f"SQL执行失败: {execute_result.get('error')}")
              return mark_failed(execute_result.get("error", "SQL执行失败"))

          query_result = execute_result.get("data_result")
          state["query_result"] = query_result
          row_count = query_result.get('row_count', 0)
          self.logger.info(f"SQL执行成功,返回 {row_count} 行数据")

          # Step 3: summarise only when enabled and there is data.
          if ENABLE_RESULT_SUMMARY and row_count > 0:
              self.logger.info("步骤3:生成摘要")
              # Summarise against the original question so that injected
              # history is not nested into the summary prompt.
              original_question = self._extract_original_question(question)
              self.logger.debug(f"原始问题: {original_question}")
              summary_result = generate_summary.invoke({
                  "question": original_question,
                  "query_result": query_result,
                  "sql": sql
              })
              if summary_result.get("success"):
                  state["summary"] = summary_result.get("summary")
                  self.logger.info("摘要生成成功")
              else:
                  # A failed summary is not fatal; use a default sentence.
                  self.logger.warning(f"摘要生成失败: {summary_result.get('message')}")
                  state["summary"] = f"查询执行完成,共返回 {row_count} 条记录。"
          else:
              # summary stays unset; the formatting node decides what to do.
              self.logger.info(f"跳过摘要生成(ENABLE_RESULT_SUMMARY={ENABLE_RESULT_SUMMARY},数据行数={row_count})")

          state["current_step"] = "database_completed"
          state["execution_path"].append("agent_database")
          self.logger.info("数据库查询完成")
          return state
      except Exception as e:
          self.logger.error(f"数据库Agent异常: {str(e)}")
          import traceback
          self.logger.error(f"详细错误信息: {traceback.format_exc()}")
          return mark_failed(f"数据库查询失败: {str(e)}")
  def _agent_chat_node(self, state: AgentState) -> AgentState:
      """Answer the question with the ``general_chat`` tool.

      The tool's reply is stored in ``chat_response``. If the tool reports
      failure or raises, a fallback apology is stored instead; this node
      never sets ``error``.

      Returns:
          The same state object, updated in place.
      """
      try:
          self.logger.info(f"开始处理聊天: {state['question']}")
          question = state["question"]

          # Only real conversation history belongs in the context. The
          # classification reason is internal routing data and must not be
          # passed to the LLM.
          chat_settings = self.config.get("chat_agent", {})
          enable_context_injection = chat_settings.get("enable_context_injection", True)
          context = None
          if enable_context_injection:
              # TODO: load recent conversation history here (e.g. from Redis
              # or another store, keyed by the session id).
              pass

          self.logger.info("调用general_chat工具")
          chat_result = general_chat.invoke({
              "question": question,
              "context": context
          })

          if not chat_result.get("success"):
              # The tool failed: prefer its own reply, else a stock apology.
              state["chat_response"] = chat_result.get("response", "抱歉,我暂时无法处理您的问题。请稍后再试。")
              self.logger.warning(f"聊天处理失败,使用备用响应: {chat_result.get('error')}")
          else:
              state["chat_response"] = chat_result.get("response", "")
              self.logger.info("聊天处理成功")

          state["current_step"] = "chat_completed"
          state["execution_path"].append("agent_chat")
          self.logger.info("聊天处理完成")
      except Exception as e:
          self.logger.error(f"聊天Agent异常: {str(e)}")
          import traceback
          self.logger.error(f"详细错误信息: {traceback.format_exc()}")
          state["chat_response"] = "抱歉,我暂时无法处理您的问题。请稍后再试,或者尝试询问数据相关的问题。"
          state["current_step"] = "chat_error"
          state["execution_path"].append("agent_chat_error")
      return state
  def _format_response_node(self, state: AgentState) -> AgentState:
      """Build ``state["final_response"]`` from whatever earlier nodes produced.

      Branches, in priority order:
        * ``error`` is set: failure response carrying the error and code.
        * ``DATABASE`` question:
            - SQL generation failed with a ``user_prompt``: failure response.
            - ``chat_response`` present: explanatory text as the answer.
            - ``summary`` present: query result plus summary.
            - ``query_result`` present: query result only.
            - otherwise: failure response.
        * anything else: chat response.

      ``question_type`` is read with ``.get`` so that an upstream ``error``
      is still reported when classification never populated it, instead of
      being replaced by a formatting exception.

      Returns:
          The same state object with ``final_response`` set.
      """
      try:
          question_type = state.get("question_type")
          self.logger.info(f"开始格式化响应,问题类型: {question_type}")
          state["current_step"] = "completed"
          state["execution_path"].append("format_response")
          execution_path = state["execution_path"]

          def classification_info():
              # Strict lookup: on success paths these keys must exist. A
              # missing key raises KeyError, handled by the except below.
              return {
                  "confidence": state["classification_confidence"],
                  "reason": state["classification_reason"],
                  "method": state["classification_method"]
              }

          if state.get("error"):
              # Error path: tolerate missing classification fields.
              state["final_response"] = {
                  "success": False,
                  "error": state["error"],
                  "error_code": state.get("error_code", 500),
                  "question_type": question_type,
                  "execution_path": execution_path,
                  "classification_info": {
                      "confidence": state.get("classification_confidence", 0),
                      "reason": state.get("classification_reason", ""),
                      "method": state.get("classification_method", "")
                  }
              }
          elif question_type == "DATABASE":
              if not state.get("sql_generation_success", True) and state.get("user_prompt"):
                  # SQL generation/validation failed with a user prompt.
                  state["final_response"] = {
                      "success": False,
                      "response": state["user_prompt"],
                      "type": "DATABASE",
                      "sql_generation_failed": True,
                      "validation_error_type": state.get("validation_error_type"),
                      "sql": state.get("sql"),
                      "execution_path": execution_path,
                      "classification_info": classification_info(),
                      "sql_validation_info": {
                          "sql_generation_success": state.get("sql_generation_success", False),
                          "sql_validation_success": state.get("sql_validation_success", False),
                          "sql_repair_attempted": state.get("sql_repair_attempted", False),
                          "sql_repair_success": state.get("sql_repair_success", False)
                      }
                  }
              elif state.get("chat_response"):
                  # Explanatory text produced instead of SQL; independent of
                  # the ENABLE_RESULT_SUMMARY setting.
                  state["final_response"] = {
                      "success": True,
                      "response": state["chat_response"],
                      "type": "DATABASE",
                      "sql": state.get("sql"),
                      "query_result": state.get("query_result"),
                      "execution_path": execution_path,
                      "classification_info": classification_info()
                  }
              elif state.get("summary"):
                  # Normal result with a summary. The summary is mirrored
                  # into "response" ahead of retiring the "summary" field.
                  state["final_response"] = {
                      "success": True,
                      "type": "DATABASE",
                      "response": state["summary"],
                      "sql": state.get("sql"),
                      "query_result": state.get("query_result"),
                      "summary": state["summary"],
                      "execution_path": execution_path,
                      "classification_info": classification_info()
                  }
              elif state.get("query_result"):
                  # Data without a summary: neither "summary" nor "response"
                  # is included; callers read the data from query_result.
                  state["final_response"] = {
                      "success": True,
                      "type": "DATABASE",
                      "sql": state.get("sql"),
                      "query_result": state.get("query_result"),
                      "execution_path": execution_path,
                      "classification_info": classification_info()
                  }
              else:
                  # The database flow produced nothing usable.
                  state["final_response"] = {
                      "success": False,
                      "error": state.get("error", "数据库查询未完成"),
                      "type": "DATABASE",
                      "sql": state.get("sql"),
                      "execution_path": execution_path
                  }
          else:
              # Chat answer.
              state["final_response"] = {
                  "success": True,
                  "response": state.get("chat_response", ""),
                  "type": "CHAT",
                  "execution_path": execution_path,
                  "classification_info": classification_info()
              }

          self.logger.info("响应格式化完成")
          return state
      except Exception as e:
          self.logger.error(f"响应格式化异常: {str(e)}")
          state["final_response"] = {
              "success": False,
              "error": f"响应格式化异常: {str(e)}",
              "error_code": 500,
              "execution_path": state["execution_path"]
          }
          return state
  def _route_after_sql_generation(self, state: AgentState) -> Literal["continue_execution", "return_to_user"]:
      """Choose the edge to follow after the SQL generation node.

      Returns:
          ``"continue_execution"`` when ``sql_generation_success`` is truthy
          (go on to SQL execution), otherwise ``"return_to_user"`` (go
          straight to response formatting).
      """
      generated_ok = state.get("sql_generation_success", False)
      self.logger.debug(f"SQL生成路由: success={generated_ok}")
      return "continue_execution" if generated_ok else "return_to_user"
  def _route_after_classification(self, state: AgentState) -> Literal["DATABASE", "CHAT"]:
      """
      Choose the outgoing edge once the question has been classified.

      The classifier's verdict stored in ``state["question_type"]`` is used
      as-is; no second decision is made here:

      - ``"DATABASE"`` is routed to the database branch.
      - Every other value (the original notes name "CHAT" and "UNCERTAIN")
        is routed to the chat branch, which is expected to cope with
        uncertain questions and ask the user for more detail if needed.

      The confidence value is only read for the debug log line.
      """
      classified_as = state["question_type"]
      score = state["classification_confidence"]
      self.logger.debug(f"分类路由: {classified_as}, 置信度: {score} (完全信任分类器决策)")

      if classified_as == "DATABASE":
          return "DATABASE"
      # Anything that is not explicitly DATABASE falls through to chat.
      return "CHAT"
  async def process_question(self, question: str, session_id: str = None, context_type: str = None, routing_mode: str = None) -> Dict[str, Any]:
      """
      Single entry point for answering a user question.

      Args:
          question: The user's question.
          session_id: Conversation/session identifier. When truthy it is
              also passed to the workflow as ``configurable.session_id``.
          context_type: Context type ("DATABASE" or "CHAT") used for
              progressive classification.
          routing_mode: Optional routing mode that overrides the configured
              one.

      Returns:
          The ``final_response`` dict produced by the workflow. If anything
          raises, a failure dict with ``error_code`` 500 and
          ``execution_path`` ``["error"]`` is returned instead.
      """
      try:
          self.logger.info(f"开始处理问题: {question}")
          if context_type:
              self.logger.info(f"上下文类型: {context_type}")
          if routing_mode:
              self.logger.info(f"使用指定路由模式: {routing_mode}")

          # The workflow is built per call because it depends on the routing mode.
          graph = self._create_workflow(routing_mode)
          starting_state = self._create_initial_state(question, session_id, context_type, routing_mode)

          # Only attach a run config when there is a session to attach.
          run_config = {"configurable": {"session_id": session_id}} if session_id else None
          finished_state = await graph.ainvoke(starting_state, config=run_config)

          answer = finished_state["final_response"]
          self.logger.info(f"问题处理完成: {answer.get('success', False)}")
          return answer
      except Exception as e:
          self.logger.error(f"Agent执行异常: {str(e)}")
          return {
              "success": False,
              "error": f"Agent系统异常: {str(e)}",
              "error_code": 500,
              "execution_path": ["error"]
          }
  def _create_initial_state(self, question: str, session_id: str = None, context_type: str = None, routing_mode: str = None) -> AgentState:
      """
      Build the starting workflow state for one question.

      The routing mode is ``routing_mode`` when that argument is truthy;
      otherwise ``app_config.QUESTION_ROUTING_MODE``, falling back to
      ``"hybrid"`` when that import fails. All remaining fields start at
      neutral defaults and are expected to be filled in by later nodes.
      """
      effective_routing_mode = routing_mode
      if not effective_routing_mode:
          try:
              from app_config import QUESTION_ROUTING_MODE
              effective_routing_mode = QUESTION_ROUTING_MODE
          except ImportError:
              effective_routing_mode = "hybrid"

      initial_values = {
          # Input
          "question": question,
          "session_id": session_id,
          # Context
          "context_type": context_type,
          # Classification result (placeholders; updated by the
          # classification node or by the direct-mode init node)
          "question_type": "UNCERTAIN",
          "classification_confidence": 0.0,
          "classification_reason": "",
          "classification_method": "",
          # Database query flow
          "sql": None,
          "sql_generation_attempts": 0,
          "query_result": None,
          "summary": None,
          # SQL validation and repair
          "sql_generation_success": False,
          "sql_validation_success": False,
          "sql_repair_attempted": False,
          "sql_repair_success": False,
          "validation_error_type": None,
          "user_prompt": None,
          # Chat response
          "chat_response": None,
          # Final output
          "final_response": {},
          # Error handling
          "error": None,
          "error_code": None,
          # Flow control
          "current_step": "initialized",
          "execution_path": ["start"],
          "retry_count": 0,
          "max_retries": 3,
          # Debug information
          "debug_info": {},
          # Routing mode
          "routing_mode": effective_routing_mode,
      }
      return AgentState(**initial_values)
  768. # ==================== SQL验证和修复相关方法 ====================
  def _is_sql_validation_enabled(self) -> bool:
      """
      Tell whether any SQL validation step is switched on.

      Returns the ``sql_validation.enable_syntax_validation`` setting when
      it is truthy, otherwise the ``sql_validation.enable_forbidden_check``
      setting. Both default to False when missing.
      """
      from agent.config import get_nested_config

      # NOTE(review): these lookups default to False, while
      # _validate_sql_with_custom_priority defaults the same keys to True.
      # Confirm which default is intended.
      syntax_flag = get_nested_config(self.config, "sql_validation.enable_syntax_validation", False)
      if syntax_flag:
          return syntax_flag
      return get_nested_config(self.config, "sql_validation.enable_forbidden_check", False)
  def _is_auto_repair_enabled(self) -> bool:
      """
      Tell whether automatic SQL repair is switched on.

      Repair counts as enabled only when both
      ``sql_validation.enable_auto_repair`` and
      ``sql_validation.enable_syntax_validation`` are truthy; both default
      to False when missing. The second setting is looked up only if the
      first one is truthy.
      """
      from agent.config import get_nested_config

      repair_flag = get_nested_config(self.config, "sql_validation.enable_auto_repair", False)
      if not repair_flag:
          return repair_flag
      return get_nested_config(self.config, "sql_validation.enable_syntax_validation", False)
  779. async def _validate_sql_with_custom_priority(self, sql: str) -> Dict[str, Any]:
  780. """
  781. 按照自定义优先级验证SQL:先禁止词,再语法
  782. Args:
  783. sql: 要验证的SQL语句
  784. Returns:
  785. 验证结果字典
  786. """
  787. try:
  788. from agent.config import get_nested_config
  789. # 1. 优先检查禁止词(您要求的优先级)
  790. if get_nested_config(self.config, "sql_validation.enable_forbidden_check", True):
  791. forbidden_result = self._check_forbidden_keywords(sql)
  792. if not forbidden_result.get("valid"):
  793. return {
  794. "valid": False,
  795. "error_type": "forbidden_keywords",
  796. "error_message": forbidden_result.get("error"),
  797. "can_repair": False # 禁止词错误不能修复
  798. }
  799. # 2. 再检查语法(EXPLAIN SQL)
  800. if get_nested_config(self.config, "sql_validation.enable_syntax_validation", True):
  801. syntax_result = await self._validate_sql_syntax(sql)
  802. if not syntax_result.get("valid"):
  803. return {
  804. "valid": False,
  805. "error_type": "syntax_error",
  806. "error_message": syntax_result.get("error"),
  807. "can_repair": True # 语法错误可以尝试修复
  808. }
  809. return {"valid": True}
  810. except Exception as e:
  811. return {
  812. "valid": False,
  813. "error_type": "validation_exception",
  814. "error_message": str(e),
  815. "can_repair": False
  816. }
  def _check_forbidden_keywords(self, sql: str) -> Dict[str, Any]:
      """
      Reject SQL in which any statement begins with a forbidden operation.

      The forbidden list is read from ``sql_validation.forbidden_operations``
      (default: UPDATE, DELETE, DROP, ALTER, INSERT) and matched
      case-insensitively. Two things are inspected:

      1. The raw text with surrounding whitespace removed, exactly as the
         check has always worked.
      2. The start of every ``;``-separated statement after comments and
         quoted text have been blanked out, so that a leading comment or a
         stacked statement such as ``SELECT 1; DROP TABLE t`` can no longer
         slip past the gate.

      Only the first keyword of each statement is examined; data-modifying
      statements nested inside a ``WITH`` clause are not detected here.

      Args:
          sql: SQL text to inspect.

      Returns:
          ``{"valid": True}`` when nothing forbidden is found, otherwise
          ``{"valid": False, "error": <message>}``. Any exception raised
          while checking is reported the same way (fail closed).
      """
      try:
          import re
          from agent.config import get_nested_config

          forbidden_operations = get_nested_config(
              self.config,
              "sql_validation.forbidden_operations",
              ['UPDATE', 'DELETE', 'DROP', 'ALTER', 'INSERT']
          )

          sql_upper = sql.upper().strip()

          # Quoted text and comments may legitimately contain ';' or
          # keywords, so they are replaced by a space before splitting.
          # One left-to-right alternation keeps a '--' inside a string (or a
          # quote inside a comment) from being misread.
          ignorable = re.compile(
              r"'(?:[^']|'')*'"                   # single-quoted string literal
              r'|"(?:[^"]|"")*"'                  # double-quoted identifier
              r"|\$(?P<tag>\w*)\$.*?\$(?P=tag)\$"  # dollar-quoted string
              r"|--[^\n]*"                        # line comment
              r"|/\*.*?\*/",                      # block comment
              re.DOTALL,
          )
          cleaned = ignorable.sub(" ", sql_upper)

          # The raw text goes first so that everything rejected before is
          # still rejected, with the same message.
          statement_heads = [sql_upper]
          statement_heads.extend(part.strip() for part in cleaned.split(";"))

          for head in statement_heads:
              for operation in forbidden_operations:
                  if head.startswith(operation.upper()):
                      return {
                          "valid": False,
                          "error": f"不允许的操作: {operation}。本系统只支持查询操作(SELECT)。"
                      }

          return {"valid": True}
      except Exception as e:
          return {
              "valid": False,
              "error": f"禁止词检查异常: {str(e)}"
          }
  839. async def _validate_sql_syntax(self, sql: str) -> Dict[str, Any]:
  840. """语法验证 - 使用EXPLAIN SQL"""
  841. try:
  842. from common.vanna_instance import get_vanna_instance
  843. import asyncio
  844. vn = get_vanna_instance()
  845. # 构建EXPLAIN查询
  846. explain_sql = f"EXPLAIN {sql}"
  847. # 异步执行验证
  848. result = await asyncio.to_thread(vn.run_sql, explain_sql)
  849. if result is not None:
  850. return {"valid": True}
  851. else:
  852. return {
  853. "valid": False,
  854. "error": "SQL语法验证失败"
  855. }
  856. except Exception as e:
  857. return {
  858. "valid": False,
  859. "error": str(e)
  860. }
  861. async def _attempt_sql_repair_once(self, sql: str, error_message: str) -> Dict[str, Any]:
  862. """
  863. 使用LLM尝试修复SQL - 只修复一次
  864. Args:
  865. sql: 原始SQL
  866. error_message: 错误信息
  867. Returns:
  868. 修复结果字典
  869. """
  870. try:
  871. from common.vanna_instance import get_vanna_instance
  872. from agent.config import get_nested_config
  873. import asyncio
  874. vn = get_vanna_instance()
  875. # 构建修复提示词
  876. repair_prompt = f"""你是一个PostgreSQL SQL专家,请修复以下SQL语句的语法错误。
  877. 当前数据库类型: PostgreSQL
  878. 错误信息: {error_message}
  879. 需要修复的SQL:
  880. {sql}
  881. 修复要求:
  882. 1. 只修复语法错误和表结构错误
  883. 2. 保持SQL的原始业务逻辑不变
  884. 3. 使用PostgreSQL标准语法
  885. 4. 确保修复后的SQL语法正确
  886. 请直接输出修复后的SQL语句,不要添加其他说明文字。"""
  887. # 获取超时配置
  888. timeout = get_nested_config(self.config, "sql_validation.repair_timeout", 60)
  889. # 异步调用LLM修复
  890. response = await asyncio.wait_for(
  891. asyncio.to_thread(
  892. vn.chat_with_llm,
  893. question=repair_prompt,
  894. system_prompt="你是一个专业的PostgreSQL SQL专家,专门负责修复SQL语句中的语法错误。"
  895. ),
  896. timeout=timeout
  897. )
  898. if response and response.strip():
  899. repaired_sql = response.strip()
  900. # 验证修复后的SQL
  901. validation_result = await self._validate_sql_syntax(repaired_sql)
  902. if validation_result.get("valid"):
  903. return {
  904. "success": True,
  905. "repaired_sql": repaired_sql,
  906. "error": None
  907. }
  908. else:
  909. return {
  910. "success": False,
  911. "repaired_sql": None,
  912. "error": f"修复后的SQL仍然无效: {validation_result.get('error')}"
  913. }
  914. else:
  915. return {
  916. "success": False,
  917. "repaired_sql": None,
  918. "error": "LLM返回空响应"
  919. }
  920. except asyncio.TimeoutError:
  921. return {
  922. "success": False,
  923. "repaired_sql": None,
  924. "error": f"修复超时({get_nested_config(self.config, 'sql_validation.repair_timeout', 60)}秒)"
  925. }
  926. except Exception as e:
  927. return {
  928. "success": False,
  929. "repaired_sql": None,
  930. "error": f"修复异常: {str(e)}"
  931. }
  932. # ==================== 原有方法 ====================
  933. def _extract_original_question(self, question: str) -> str:
  934. """
  935. 从enhanced_question中提取原始问题
  936. Args:
  937. question: 可能包含上下文的问题
  938. Returns:
  939. str: 原始问题
  940. """
  941. try:
  942. # 检查是否为enhanced_question格式
  943. if "\n[CONTEXT]\n" in question and "\n[CURRENT]\n" in question:
  944. # 提取[CURRENT]标签后的内容
  945. current_start = question.find("\n[CURRENT]\n")
  946. if current_start != -1:
  947. original_question = question[current_start + len("\n[CURRENT]\n"):].strip()
  948. return original_question
  949. # 如果不是enhanced_question格式,直接返回原问题
  950. return question.strip()
  951. except Exception as e:
  952. self.logger.warning(f"提取原始问题失败: {str(e)}")
  953. return question.strip()
  async def health_check(self) -> Dict[str, Any]:
      """
      Report whether the agent is able to answer questions.

      Config keys read: ``health_check.test_question`` (default "你好") and
      ``health_check.enable_full_test`` (default True).

      - Full test enabled: the test question is run through
        ``process_question`` with the session id "health_check"; status is
        "healthy" when that result reports success, "degraded" otherwise.
      - Full test disabled: a "healthy" report is returned without running
        a question.
      - Any exception: an "unhealthy" report carrying the error text.

      ``workflow_compiled`` is always True because the workflow is created
      dynamically per request.
      """
      try:
          from agent.config import get_nested_config

          probe_question = get_nested_config(self.config, "health_check.test_question", "你好")
          run_full_test = get_nested_config(self.config, "health_check.enable_full_test", True)

          if not run_full_test:
              # Lightweight check: nothing is executed.
              return {
                  "status": "healthy",
                  "test_result": True,
                  "workflow_compiled": True,
                  "tools_count": len(self.tools),
                  "agent_reuse_enabled": False,
                  "message": "Agent简单健康检查完成"
              }

          # Full check: push one question through the whole pipeline.
          probe_result = await self.process_question(probe_question, "health_check")
          return {
              "status": "healthy" if probe_result.get("success") else "degraded",
              "test_result": probe_result.get("success", False),
              "workflow_compiled": True,
              "tools_count": len(self.tools),
              "agent_reuse_enabled": False,
              "message": "Agent健康检查完成"
          }
      except Exception as e:
          return {
              "status": "unhealthy",
              "error": str(e),
              "workflow_compiled": True,
              # `tools` may be missing if the failure happened early.
              "tools_count": len(self.tools) if hasattr(self, 'tools') else 0,
              "agent_reuse_enabled": False,
              "message": "Agent健康检查失败"
          }