citu_agent.py 59 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279
  1. # agent/citu_agent.py
  2. from typing import Dict, Any, Literal
  3. from langgraph.graph import StateGraph, END
  4. from langchain.agents import AgentExecutor, create_openai_tools_agent
  5. from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
  6. from langchain_core.messages import SystemMessage, HumanMessage
  7. from core.logging import get_agent_logger
  8. from agent.state import AgentState
  9. from agent.classifier import QuestionClassifier
  10. from agent.tools import TOOLS, generate_sql, execute_sql, generate_summary, general_chat
  11. from agent.tools.utils import get_compatible_llm
  12. from app_config import ENABLE_RESULT_SUMMARY
  13. class CituLangGraphAgent:
  14. """Citu LangGraph智能助手主类 - 使用@tool装饰器 + Agent工具调用"""
  def __init__(self):
      """Initialise the agent: logger, configuration, classifier, tools and LLM.

      The LangGraph workflow is not built here; ``_create_workflow`` builds it
      per request so that the routing mode can be chosen per call.
      """
      self.logger = get_agent_logger("CituAgent")

      # Load agent configuration; fall back to an empty dict when the config
      # module is unavailable so later `self.config.get(...)` calls still work.
      try:
          # Import only what is used: an unused name in this import would make
          # a missing helper silently discard the whole configuration.
          from agent.config import get_current_config
          self.config = get_current_config()
          self.logger.info("加载Agent配置完成")
      except ImportError:
          self.config = {}
          self.logger.warning("配置文件不可用,使用默认配置")

      self.classifier = QuestionClassifier()
      self.tools = TOOLS
      self.llm = get_compatible_llm()

      # Graph nodes call the tools directly, so no AgentExecutor is pre-built.
      self.logger.info("使用直接工具调用模式")
      self.logger.info("LangGraph Agent with Direct Tools初始化完成")
  def _create_workflow(self, routing_mode: str = None) -> StateGraph:
      """Build and compile the single workflow shared by every routing mode.

      Flow: ``classify_question`` -> either ``agent_chat`` or
      ``agent_sql_generation`` (-> ``agent_sql_execution`` on success); every
      path ends in ``format_response`` and then ``END``.

      Args:
          routing_mode: Accepted for interface compatibility only; the graph
              shape does not depend on it (the mode travels in the state).

      Returns:
          The compiled graph produced by ``StateGraph.compile()``.
      """
      self.logger.info("🏗️ [WORKFLOW] 创建统一workflow")
      graph = StateGraph(AgentState)

      # Register nodes (same order as always).
      node_table = (
          ("classify_question", self._classify_question_node),
          ("agent_chat", self._agent_chat_node),
          ("agent_sql_generation", self._agent_sql_generation_node),
          ("agent_sql_execution", self._agent_sql_execution_node),
          ("format_response", self._format_response_node),
      )
      for node_name, handler in node_table:
          graph.add_node(node_name, handler)

      # Every request starts with classification.
      graph.set_entry_point("classify_question")

      # Branch on the classifier's verdict.
      graph.add_conditional_edges(
          "classify_question",
          self._route_after_classification,
          {"DATABASE": "agent_sql_generation", "CHAT": "agent_chat"},
      )
      # Only execute SQL that was generated (and validated) successfully.
      graph.add_conditional_edges(
          "agent_sql_generation",
          self._route_after_sql_generation,
          {
              "continue_execution": "agent_sql_execution",
              "return_to_user": "format_response",
          },
      )

      # Unconditional edges converge on the formatter, then finish.
      edge_table = (
          ("agent_chat", "format_response"),
          ("agent_sql_execution", "format_response"),
          ("format_response", END),
      )
      for source, target in edge_table:
          graph.add_edge(source, target)

      return graph.compile()
  69. def _classify_question_node(self, state: AgentState) -> AgentState:
  70. """问题分类节点 - 使用混合分类策略(规则+LLM)"""
  71. try:
  72. # 从state中获取路由模式,而不是从配置文件读取
  73. routing_mode = state.get("routing_mode", "hybrid")
  74. self.logger.info(f"开始分类问题: {state['question']}")
  75. # 获取上下文类型(保留兼容性,但不在分类中使用)
  76. context_type = state.get("context_type")
  77. if context_type:
  78. self.logger.info(f"检测到上下文类型: {context_type}")
  79. # 使用混合分类策略(规则+LLM),传递路由模式
  80. classification_result = self.classifier.classify(state["question"], context_type, routing_mode)
  81. # 更新状态
  82. state["question_type"] = classification_result.question_type
  83. state["classification_confidence"] = classification_result.confidence
  84. state["classification_reason"] = classification_result.reason
  85. state["classification_method"] = classification_result.method
  86. state["routing_mode"] = routing_mode
  87. state["current_step"] = "classified"
  88. state["execution_path"].append("classify")
  89. self.logger.info(f"分类结果: {classification_result.question_type}, 置信度: {classification_result.confidence}")
  90. self.logger.info(f"路由模式: {routing_mode}, 分类方法: {classification_result.method}")
  91. return state
  92. except Exception as e:
  93. self.logger.error(f"问题分类异常: {str(e)}")
  94. state["error"] = f"问题分类失败: {str(e)}"
  95. state["error_code"] = 500
  96. state["execution_path"].append("classify_error")
  97. return state
  async def _agent_sql_generation_node(self, state: AgentState) -> AgentState:
      """Generate SQL for the question, optionally validate/repair it, and record the outcome.

      ``_route_after_sql_generation`` reads ``sql_generation_success`` from the
      returned state:

      * ``True``  -> the (possibly repaired) SQL in ``state["sql"]`` is executed next.
      * ``False`` -> the graph goes straight to ``format_response`` with either
        ``chat_response`` (the generator produced explanatory text instead of
        SQL) or ``user_prompt`` + ``validation_error_type`` describing the failure.

      The node never raises; unexpected exceptions are recorded with
      ``validation_error_type="node_exception"``.
      """
      explanation_types = ("llm_explanation", "generation_failed_with_explanation")

      def finish_with_explanation(text, error_kind):
          # Surface non-SQL text from the generator as the final answer.
          state["chat_response"] = text + " 请尝试提问其它问题。"
          state["sql_generation_success"] = False
          state["validation_error_type"] = error_kind
          state["current_step"] = "sql_generation_completed"
          state["execution_path"].append("agent_sql_generation")

      try:
          self.logger.info(f"开始处理SQL生成和验证: {state['question']}")
          question = state["question"]

          # Step 1: generate SQL.
          self.logger.info("步骤1:生成SQL")
          sql_result = generate_sql.invoke({"question": question, "allow_llm_to_see_data": True})

          if not sql_result.get("success"):
              # `or ""`: the tool may return an explicit None, which would
              # otherwise crash on `.lower()` / string concatenation below.
              error_message = sql_result.get("error") or ""
              error_type = sql_result.get("error_type", "")
              self.logger.debug(f"error_type = '{error_type}'")
              lowered = error_message.lower()

              # Map the failure to a user-facing hint (checked in this order).
              if "no relevant tables" in lowered or "table not found" in lowered:
                  user_prompt = "数据库中没有相关的表或字段信息,请您提供更多具体信息或修改问题。"
                  failure_reason = "missing_database_info"
              elif "ambiguous" in lowered or "more information" in lowered:
                  user_prompt = "您的问题需要更多信息才能准确查询,请提供更详细的描述。"
                  failure_reason = "ambiguous_question"
              elif error_type in explanation_types:
                  # Explanatory text is returned directly as the answer.
                  finish_with_explanation(error_message, "llm_explanation")
                  self.logger.info(f"返回LLM解释性答案: {error_message}")
                  return state
              else:
                  user_prompt = "无法生成有效的SQL查询,请尝试重新描述您的问题。"
                  failure_reason = "unknown_generation_failure"

              state["sql_generation_success"] = False
              state["user_prompt"] = user_prompt
              state["validation_error_type"] = failure_reason
              state["current_step"] = "sql_generation_failed"
              state["execution_path"].append("agent_sql_generation_failed")
              self.logger.warning(f"生成失败: {failure_reason} - {user_prompt}")
              return state

          sql = sql_result.get("sql")
          state["sql"] = sql

          # Step 1.5: the tool may report success yet flag the output as an explanation.
          if sql_result.get("error_type") in explanation_types:
              explanation = sql_result.get("error") or ""
              finish_with_explanation(explanation, "llm_explanation")
              self.logger.info(f"返回LLM解释性答案: {explanation}")
              return state

          if not sql:
              # Record an explicit failure with an actionable prompt; without
              # these flags the formatter can only emit a generic error.
              self.logger.warning("SQL为空,但不是解释性响应")
              state["sql_generation_success"] = False
              state["user_prompt"] = "无法生成有效的SQL查询,请尝试重新描述您的问题。"
              state["validation_error_type"] = "empty_sql"
              state["current_step"] = "sql_generation_failed"
              state["execution_path"].append("agent_sql_generation_failed")
              return state
          self.logger.info(f"SQL生成成功: {sql}")

          # Guard against the tool labelling non-SQL text as SQL.
          from agent.tools.utils import _is_valid_sql_format
          if not _is_valid_sql_format(sql):
              finish_with_explanation(sql, "invalid_sql_format")
              self.logger.info(f"内容不是有效SQL,当作解释返回: {sql}")
              return state

          # Step 2: optional validation, with at most one repair attempt.
          if self._is_sql_validation_enabled():
              self.logger.info("步骤2:验证SQL")
              validation_result = await self._validate_sql_with_custom_priority(sql)
              if not validation_result.get("valid"):
                  error_type = validation_result.get("error_type")
                  error_message = validation_result.get("error_message")
                  can_repair = validation_result.get("can_repair", False)
                  self.logger.warning(f"SQL验证失败: {error_type} - {error_message}")

                  if error_type == "forbidden_keywords":
                      # Forbidden keywords are never repaired.
                      state["sql_generation_success"] = False
                      state["sql_validation_success"] = False
                      state["user_prompt"] = error_message
                      state["validation_error_type"] = "forbidden_keywords"
                      state["current_step"] = "sql_validation_failed"
                      state["execution_path"].append("forbidden_keywords_failed")
                      self.logger.warning("禁止词验证失败,直接结束")
                      return state

                  if error_type == "syntax_error" and can_repair and self._is_auto_repair_enabled():
                      # Syntax errors get exactly one repair attempt.
                      self.logger.info(f"尝试修复SQL语法错误(仅一次): {error_message}")
                      state["sql_repair_attempted"] = True
                      repair_result = await self._attempt_sql_repair_once(sql, error_message)
                      if repair_result.get("success"):
                          repaired_sql = repair_result.get("repaired_sql")
                          state["sql"] = repaired_sql
                          state["sql_generation_success"] = True
                          state["sql_validation_success"] = True
                          state["sql_repair_success"] = True
                          state["current_step"] = "sql_generation_completed"
                          state["execution_path"].append("sql_repair_success")
                          self.logger.info(f"SQL修复成功: {repaired_sql}")
                          return state

                      repair_error = repair_result.get("error", "修复失败")
                      self.logger.warning(f"SQL修复失败: {repair_error}")
                      state["sql_generation_success"] = False
                      state["sql_validation_success"] = False
                      state["sql_repair_success"] = False
                      state["user_prompt"] = f"SQL语法修复失败: {repair_error}"
                      state["validation_error_type"] = "syntax_repair_failed"
                      state["current_step"] = "sql_repair_failed"
                      state["execution_path"].append("sql_repair_failed")
                      return state

                  # Repair disabled or not applicable to this error type.
                  state["sql_generation_success"] = False
                  state["sql_validation_success"] = False
                  state["user_prompt"] = f"SQL验证失败: {error_message}"
                  state["validation_error_type"] = error_type
                  state["current_step"] = "sql_validation_failed"
                  state["execution_path"].append("sql_validation_failed")
                  self.logger.warning("SQL验证失败,不尝试修复")
                  return state

              self.logger.info("SQL验证通过")
              state["sql_validation_success"] = True
          else:
              self.logger.info("跳过SQL验证(未启用)")
              state["sql_validation_success"] = True

          # Generation (and validation, if enabled) succeeded.
          state["sql_generation_success"] = True
          state["current_step"] = "sql_generation_completed"
          state["execution_path"].append("agent_sql_generation")
          self.logger.info("SQL生成验证完成,准备执行")
          return state
      except Exception as e:
          self.logger.error(f"SQL生成验证节点异常: {str(e)}")
          import traceback
          self.logger.error(f"详细错误信息: {traceback.format_exc()}")
          state["sql_generation_success"] = False
          state["sql_validation_success"] = False
          state["user_prompt"] = f"SQL生成验证异常: {str(e)}"
          state["validation_error_type"] = "node_exception"
          state["current_step"] = "sql_generation_error"
          state["execution_path"].append("agent_sql_generation_error")
          return state
  def _agent_sql_execution_node(self, state: AgentState) -> AgentState:
      """Execute ``state["sql"]`` and optionally summarise the result.

      On success sets ``query_result``. When ``ENABLE_RESULT_SUMMARY`` is on
      and at least one row came back, it also sets ``summary``, falling back
      to a row-count sentence if the summary tool fails. Failures set
      ``error``/``error_code`` and ``current_step="sql_execution_error"``;
      the node never raises.
      """
      try:
          self.logger.info(f"开始执行SQL: {state.get('sql', 'N/A')}")
          sql = state.get("sql")
          question = state["question"]

          if not sql:
              self.logger.warning("没有可执行的SQL")
              state["error"] = "没有可执行的SQL语句"
              state["error_code"] = 500
              state["current_step"] = "sql_execution_error"
              state["execution_path"].append("agent_sql_execution_error")
              return state

          # Step 1: execute.
          self.logger.info("步骤1:执行SQL")
          execute_result = execute_sql.invoke({"sql": sql})
          if not execute_result.get("success"):
              self.logger.error(f"SQL执行失败: {execute_result.get('error')}")
              # `or`: an explicit None/"" error would leave `error` falsy and the
              # response formatter would not treat this as a failure.
              state["error"] = execute_result.get("error") or "SQL执行失败"
              state["error_code"] = 500
              state["current_step"] = "sql_execution_error"
              state["execution_path"].append("agent_sql_execution_error")
              return state

          query_result = execute_result.get("data_result")
          state["query_result"] = query_result
          row_count = query_result.get("row_count", 0)
          self.logger.info(f"SQL执行成功,返回 {row_count} 行数据")

          # Step 2: summary (config-gated, only when there is data).
          if ENABLE_RESULT_SUMMARY and row_count > 0:
              self.logger.info("步骤2:生成摘要")
              # Summarise against the original question, not the history-enhanced
              # one, to avoid nesting earlier turns into the summary.
              original_question = self._extract_original_question(question)
              self.logger.debug(f"原始问题: {original_question}")
              summary_result = generate_summary.invoke({
                  "question": original_question,
                  "query_result": query_result,
                  "sql": sql,
              })
              if summary_result.get("success"):
                  state["summary"] = summary_result.get("summary")
                  self.logger.info("摘要生成成功")
              else:
                  # Non-fatal: fall back to a row-count sentence.
                  self.logger.warning(f"摘要生成失败: {summary_result.get('message')}")
                  state["summary"] = f"查询执行完成,共返回 {row_count} 条记录。"
          else:
              # Leave `summary` unset; the formatter then returns the raw query_result.
              self.logger.info(f"跳过摘要生成(ENABLE_RESULT_SUMMARY={ENABLE_RESULT_SUMMARY},数据行数={row_count})")

          state["current_step"] = "sql_execution_completed"
          state["execution_path"].append("agent_sql_execution")
          self.logger.info("SQL执行完成")
          return state
      except Exception as e:
          self.logger.error(f"SQL执行节点异常: {str(e)}")
          import traceback
          self.logger.error(f"详细错误信息: {traceback.format_exc()}")
          state["error"] = f"SQL执行失败: {str(e)}"
          state["error_code"] = 500
          state["current_step"] = "sql_execution_error"
          state["execution_path"].append("agent_sql_execution_error")
          return state
  def _agent_database_node(self, state: AgentState) -> AgentState:
      """Deprecated single-node database flow: generate SQL, execute it, summarise.

      Superseded by ``_agent_sql_generation_node`` + ``_agent_sql_execution_node``
      and not registered by ``_create_workflow``; kept only for backward
      compatibility. Never raises: failures set ``error``/``error_code`` and
      ``current_step="database_error"``.
      """
      try:
          self.logger.warning("使用已废弃的database节点,建议使用新的拆分节点")
          self.logger.info(f"开始处理数据库查询: {state['question']}")
          question = state["question"]

          # Step 1: generate SQL.
          self.logger.info("步骤1:生成SQL")
          sql_result = generate_sql.invoke({"question": question, "allow_llm_to_see_data": True})
          if not sql_result.get("success"):
              self.logger.error(f"SQL生成失败: {sql_result.get('error')}")
              # `or`: an explicit None error must not end up as a falsy `error`.
              state["error"] = sql_result.get("error") or "SQL生成失败"
              state["error_code"] = 500
              state["current_step"] = "database_error"
              state["execution_path"].append("agent_database_error")
              return state

          sql = sql_result.get("sql")
          state["sql"] = sql

          # Step 1.5: explanatory text instead of SQL -> answer directly. Both
          # explanation markers are accepted, as in _agent_sql_generation_node.
          if sql_result.get("error_type") in ("llm_explanation", "generation_failed_with_explanation"):
              explanation = sql_result.get("error") or ""
              state["chat_response"] = explanation + " 请尝试提问其它问题。"
              state["current_step"] = "database_completed"
              state["execution_path"].append("agent_database")
              self.logger.info(f"返回LLM解释性答案: {explanation}")
              return state

          if not sql:
              # Nothing to execute: report it instead of passing None downstream.
              self.logger.error("SQL生成失败: 返回的SQL为空")
              state["error"] = "SQL生成失败"
              state["error_code"] = 500
              state["current_step"] = "database_error"
              state["execution_path"].append("agent_database_error")
              return state
          self.logger.info(f"SQL生成成功: {sql}")

          # Guard against the tool labelling non-SQL text as SQL.
          from agent.tools.utils import _is_valid_sql_format
          if not _is_valid_sql_format(sql):
              state["chat_response"] = sql + " 请尝试提问其它问题。"
              state["current_step"] = "database_completed"
              state["execution_path"].append("agent_database")
              self.logger.info(f"内容不是有效SQL,当作解释返回: {sql}")
              return state

          # Step 2: execute.
          self.logger.info("步骤2:执行SQL")
          execute_result = execute_sql.invoke({"sql": sql})
          if not execute_result.get("success"):
              self.logger.error(f"SQL执行失败: {execute_result.get('error')}")
              state["error"] = execute_result.get("error") or "SQL执行失败"
              state["error_code"] = 500
              state["current_step"] = "database_error"
              state["execution_path"].append("agent_database_error")
              return state

          query_result = execute_result.get("data_result")
          state["query_result"] = query_result
          row_count = query_result.get("row_count", 0)
          self.logger.info(f"SQL执行成功,返回 {row_count} 行数据")

          # Step 3: summary (config-gated, only when there is data).
          if ENABLE_RESULT_SUMMARY and row_count > 0:
              self.logger.info("步骤3:生成摘要")
              # Summarise against the original question, not the history-enhanced
              # one, to avoid nesting earlier turns into the summary.
              original_question = self._extract_original_question(question)
              self.logger.debug(f"原始问题: {original_question}")
              summary_result = generate_summary.invoke({
                  "question": original_question,
                  "query_result": query_result,
                  "sql": sql,
              })
              if summary_result.get("success"):
                  state["summary"] = summary_result.get("summary")
                  self.logger.info("摘要生成成功")
              else:
                  # Non-fatal: fall back to a row-count sentence.
                  self.logger.warning(f"摘要生成失败: {summary_result.get('message')}")
                  state["summary"] = f"查询执行完成,共返回 {row_count} 条记录。"
          else:
              # Leave `summary` unset; the formatter decides how to present the data.
              self.logger.info(f"跳过摘要生成(ENABLE_RESULT_SUMMARY={ENABLE_RESULT_SUMMARY},数据行数={row_count})")

          state["current_step"] = "database_completed"
          state["execution_path"].append("agent_database")
          self.logger.info("数据库查询完成")
          return state
      except Exception as e:
          self.logger.error(f"数据库Agent异常: {str(e)}")
          import traceback
          self.logger.error(f"详细错误信息: {traceback.format_exc()}")
          state["error"] = f"数据库查询失败: {str(e)}"
          state["error_code"] = 500
          state["current_step"] = "database_error"
          state["execution_path"].append("agent_database_error")
          return state
  def _agent_chat_node(self, state: AgentState) -> AgentState:
      """Answer a non-database question via the ``general_chat`` tool.

      ``context`` is always passed as ``None``: per the notes below,
      conversation history is already merged into ``state["question"]`` at
      the API layer. Never raises; on tool failure or exception a canned
      apology is stored in ``chat_response``.
      """
      import json
      import traceback

      try:
          # Debug: dump the full incoming state, falling back to repr when it
          # is not JSON-serialisable.
          try:
              snapshot = dict(state)
              self.logger.debug(f"agent_chat接收到的State内容: {json.dumps(snapshot, ensure_ascii=False, indent=2)}")
          except Exception as dump_err:
              self.logger.debug(f"State序列化失败: {dump_err}")
              self.logger.debug(f"agent_chat接收到的State内容: {state}")

          self.logger.info(f"开始处理聊天: {state['question']}")
          user_question = state["question"]

          # Only genuine conversation history may reach the LLM; the
          # classification reason is internal routing data and is never sent.
          # History is already merged into the question upstream, so the flag
          # is read but nothing extra is injected here.
          chat_settings = self.config.get("chat_agent", {})
          _inject_context = chat_settings.get("enable_context_injection", True)
          extra_context = None

          self.logger.info("调用general_chat工具")
          reply = general_chat.invoke({
              "question": user_question,
              "context": extra_context,
          })

          if not reply.get("success"):
              # Tool reported failure: use its response if any, else an apology.
              state["chat_response"] = reply.get("response", "抱歉,我暂时无法处理您的问题。请稍后再试。")
              self.logger.warning(f"聊天处理失败,使用备用响应: {reply.get('error')}")
          else:
              state["chat_response"] = reply.get("response", "")
              self.logger.info("聊天处理成功")

          state["current_step"] = "chat_completed"
          state["execution_path"].append("agent_chat")
          self.logger.info("聊天处理完成")
          return state
      except Exception as err:
          self.logger.error(f"聊天Agent异常: {str(err)}")
          self.logger.error(f"详细错误信息: {traceback.format_exc()}")
          state["chat_response"] = "抱歉,我暂时无法处理您的问题。请稍后再试,或者尝试询问数据相关的问题。"
          state["current_step"] = "chat_error"
          state["execution_path"].append("agent_chat_error")
          return state
  445. def _format_response_node(self, state: AgentState) -> AgentState:
  446. """格式化最终响应节点"""
  447. try:
  448. self.logger.info(f"开始格式化响应,问题类型: {state['question_type']}")
  449. state["current_step"] = "completed"
  450. state["execution_path"].append("format_response")
  451. # 根据问题类型和执行状态格式化响应
  452. if state.get("error"):
  453. # 有错误的情况
  454. state["final_response"] = {
  455. "success": False,
  456. "error": state["error"],
  457. "error_code": state.get("error_code", 500),
  458. "question_type": state["question_type"],
  459. "execution_path": state["execution_path"],
  460. "classification_info": {
  461. "confidence": state.get("classification_confidence", 0),
  462. "reason": state.get("classification_reason", ""),
  463. "method": state.get("classification_method", "")
  464. }
  465. }
  466. elif state["question_type"] == "DATABASE":
  467. # 数据库查询类型
  468. # 处理SQL生成失败的情况
  469. if not state.get("sql_generation_success", True) and state.get("user_prompt"):
  470. state["final_response"] = {
  471. "success": False,
  472. "response": state["user_prompt"],
  473. "type": "DATABASE",
  474. "sql_generation_failed": True,
  475. "validation_error_type": state.get("validation_error_type"),
  476. "sql": state.get("sql"),
  477. "execution_path": state["execution_path"],
  478. "classification_info": {
  479. "confidence": state["classification_confidence"],
  480. "reason": state["classification_reason"],
  481. "method": state["classification_method"]
  482. },
  483. "sql_validation_info": {
  484. "sql_generation_success": state.get("sql_generation_success", False),
  485. "sql_validation_success": state.get("sql_validation_success", False),
  486. "sql_repair_attempted": state.get("sql_repair_attempted", False),
  487. "sql_repair_success": state.get("sql_repair_success", False)
  488. }
  489. }
  490. elif state.get("chat_response"):
  491. # SQL生成失败的解释性响应(不受ENABLE_RESULT_SUMMARY配置影响)
  492. state["final_response"] = {
  493. "success": True,
  494. "response": state["chat_response"],
  495. "type": "DATABASE",
  496. "sql": state.get("sql"),
  497. "query_result": state.get("query_result"), # 保持内部字段名不变
  498. "execution_path": state["execution_path"],
  499. "classification_info": {
  500. "confidence": state["classification_confidence"],
  501. "reason": state["classification_reason"],
  502. "method": state["classification_method"]
  503. }
  504. }
  505. elif state.get("summary"):
  506. # 正常的数据库查询结果,有摘要的情况
  507. # 将summary的值同时赋给response字段(为将来移除summary字段做准备)
  508. state["final_response"] = {
  509. "success": True,
  510. "type": "DATABASE",
  511. "response": state["summary"], # 新增:将summary的值赋给response
  512. "sql": state.get("sql"),
  513. "query_result": state.get("query_result"), # 保持内部字段名不变
  514. "summary": state["summary"], # 暂时保留summary字段
  515. "execution_path": state["execution_path"],
  516. "classification_info": {
  517. "confidence": state["classification_confidence"],
  518. "reason": state["classification_reason"],
  519. "method": state["classification_method"]
  520. }
  521. }
  522. elif state.get("query_result"):
  523. # 有数据但没有摘要(摘要被配置禁用)
  524. query_result = state.get("query_result")
  525. row_count = query_result.get("row_count", 0)
  526. # 构建基本响应,不包含summary字段和response字段
  527. # 用户应该直接从query_result.columns和query_result.rows获取数据
  528. state["final_response"] = {
  529. "success": True,
  530. "type": "DATABASE",
  531. "sql": state.get("sql"),
  532. "query_result": query_result, # 保持内部字段名不变
  533. "execution_path": state["execution_path"],
  534. "classification_info": {
  535. "confidence": state["classification_confidence"],
  536. "reason": state["classification_reason"],
  537. "method": state["classification_method"]
  538. }
  539. }
  540. else:
  541. # 数据库查询失败,没有任何结果
  542. state["final_response"] = {
  543. "success": False,
  544. "error": state.get("error", "数据库查询未完成"),
  545. "type": "DATABASE",
  546. "sql": state.get("sql"),
  547. "execution_path": state["execution_path"]
  548. }
  549. else:
  550. # 聊天类型
  551. state["final_response"] = {
  552. "success": True,
  553. "response": state.get("chat_response", ""),
  554. "type": "CHAT",
  555. "execution_path": state["execution_path"],
  556. "classification_info": {
  557. "confidence": state["classification_confidence"],
  558. "reason": state["classification_reason"],
  559. "method": state["classification_method"]
  560. }
  561. }
  562. self.logger.info("响应格式化完成")
  563. # 输出完整的 STATE 内容用于调试
  564. import json
  565. try:
  566. # 创建一个可序列化的 state 副本
  567. debug_state = dict(state)
  568. self.logger.debug(f"format_response_node 完整 STATE 内容: {json.dumps(debug_state, ensure_ascii=False, indent=2)}")
  569. except Exception as debug_e:
  570. self.logger.debug(f"STATE 序列化失败,使用简单输出: {debug_e}")
  571. self.logger.debug(f"format_response_node STATE 内容: {state}")
  572. return state
  573. except Exception as e:
  574. self.logger.error(f"响应格式化异常: {str(e)}")
  575. state["final_response"] = {
  576. "success": False,
  577. "error": f"响应格式化异常: {str(e)}",
  578. "error_code": 500,
  579. "execution_path": state["execution_path"]
  580. }
  581. # 即使在异常情况下也输出 STATE 内容用于调试
  582. import json
  583. try:
  584. debug_state = dict(state)
  585. self.logger.debug(f"format_response_node 异常情况下的完整 STATE 内容: {json.dumps(debug_state, ensure_ascii=False, indent=2)}")
  586. except Exception as debug_e:
  587. self.logger.debug(f"异常情况下 STATE 序列化失败: {debug_e}")
  588. self.logger.debug(f"format_response_node 异常情况下的 STATE 内容: {state}")
  589. return state
  590. def _route_after_sql_generation(self, state: AgentState) -> Literal["continue_execution", "return_to_user"]:
  591. """
  592. SQL生成后的路由决策
  593. 根据SQL生成和验证的结果决定后续流向:
  594. - SQL生成验证成功 → 继续执行SQL
  595. - SQL生成验证失败 → 直接返回用户提示
  596. """
  597. sql_generation_success = state.get("sql_generation_success", False)
  598. self.logger.debug(f"SQL生成路由: success={sql_generation_success}")
  599. if sql_generation_success:
  600. return "continue_execution" # 路由到SQL执行节点
  601. else:
  602. return "return_to_user" # 路由到format_response,结束流程
  603. def _route_after_classification(self, state: AgentState) -> Literal["DATABASE", "CHAT"]:
  604. """
  605. 分类后的路由决策
  606. 完全信任QuestionClassifier的决策:
  607. - DATABASE类型 → 数据库Agent
  608. - CHAT和UNCERTAIN类型 → 聊天Agent
  609. 这样避免了双重决策的冲突,所有分类逻辑都集中在QuestionClassifier中
  610. """
  611. question_type = state["question_type"]
  612. confidence = state["classification_confidence"]
  613. self.logger.debug(f"分类路由: {question_type}, 置信度: {confidence} (完全信任分类器决策)")
  614. if question_type == "DATABASE":
  615. return "DATABASE"
  616. else:
  617. # 将 "CHAT" 和 "UNCERTAIN" 类型都路由到聊天流程
  618. # 聊天Agent可以处理不确定的情况,并在必要时引导用户提供更多信息
  619. return "CHAT"
  async def process_question(self, question: str, conversation_id: str = None, context_type: str = None, routing_mode: str = None) -> Dict[str, Any]:
      """Run one question through a freshly built workflow and return its final response.

      Args:
          question: The user's question.
          conversation_id: Optional conversation id; when given it is passed to
              the graph as ``{"configurable": {"conversation_id": ...}}``.
          context_type: Kept for compatibility; only logged and stored in the
              initial state.
          routing_mode: Optional override for the routing mode.

      Returns:
          ``final_state["final_response"]`` on success; on any exception a
          failure dict with ``success=False``, ``error``, ``error_code=500`` and
          ``execution_path=["error"]``.
      """
      try:
          self.logger.info(f"开始处理问题: {question}")
          if context_type:
              self.logger.info(f"上下文类型: {context_type}")
          if routing_mode:
              self.logger.info(f"使用指定路由模式: {routing_mode}")

          # The graph is built per call so the routing mode can vary.
          self.logger.info("🔄 [PROCESS] 调用动态创建workflow")
          graph = self._create_workflow(routing_mode)
          start_state = self._create_initial_state(question, conversation_id, context_type, routing_mode)

          run_config = None
          if conversation_id:
              run_config = {"configurable": {"conversation_id": conversation_id}}
          end_state = await graph.ainvoke(start_state, config=run_config)

          response = end_state["final_response"]
          self.logger.info(f"问题处理完成: {response.get('success', False)}")
          return response
      except Exception as exc:
          self.logger.error(f"Agent执行异常: {str(exc)}")
          return {
              "success": False,
              "error": f"Agent系统异常: {str(exc)}",
              "error_code": 500,
              "execution_path": ["error"],
          }
  661. async def process_question_stream(self, question: str, user_id: str, conversation_id: str = None, context_type: str = None, routing_mode: str = None):
  662. """
  663. 流式处理用户问题 - 复用process_question()的所有逻辑
  664. Args:
  665. question: 用户问题
  666. user_id: 用户ID,用于生成conversation_id
  667. conversation_id: 对话ID,可选,不提供则自动生成
  668. context_type: 上下文类型(保留兼容性参数,当前未使用)
  669. routing_mode: 路由模式,可选,用于覆盖配置文件设置
  670. Yields:
  671. Dict: 流式状态更新,包含进度信息或最终结果
  672. """
  673. try:
  674. self.logger.info(f"🌊 [STREAM] 开始流式处理问题: {question}")
  675. if context_type:
  676. self.logger.info(f"🌊 [STREAM] 上下文类型: {context_type}")
  677. if routing_mode:
  678. self.logger.info(f"🌊 [STREAM] 使用指定路由模式: {routing_mode}")
  679. # 生成conversation_id(如果未提供)
  680. if not conversation_id:
  681. conversation_id = self._generate_conversation_id(user_id)
  682. # 1. 复用现有的初始化逻辑
  683. self.logger.info(f"🌊 [STREAM] 动态创建workflow")
  684. workflow = self._create_workflow(routing_mode)
  685. # 2. 创建初始状态(复用现有逻辑)
  686. initial_state = self._create_initial_state(question, conversation_id, context_type, routing_mode)
  687. # 3. 使用astream流式执行
  688. self.logger.info(f"🌊 [STREAM] 开始流式执行workflow")
  689. async for chunk in workflow.astream(
  690. initial_state,
  691. config={
  692. "configurable": {"conversation_id": conversation_id}
  693. } if conversation_id else None
  694. ):
  695. # 处理每个节点的输出
  696. for node_name, node_data in chunk.items():
  697. self.logger.debug(f"🌊 [STREAM] 收到节点输出: {node_name}")
  698. # 映射节点状态为用户友好的进度信息
  699. progress_info = self._map_node_to_progress(node_name, node_data)
  700. if progress_info:
  701. yield {
  702. "type": "progress",
  703. "node": node_name,
  704. "progress": progress_info,
  705. "state_data": self._extract_relevant_state(node_data),
  706. "conversation_id": conversation_id
  707. }
  708. # 4. 最终结果处理(复用现有的结果提取逻辑)
  709. # 注意:由于astream的特性,最后一个chunk包含最终状态
  710. final_result = node_data.get("final_response", {})
  711. self.logger.info(f"🌊 [STREAM] 流式处理完成: {final_result.get('success', False)}")
  712. yield {
  713. "type": "completed",
  714. "result": final_result,
  715. "conversation_id": conversation_id
  716. }
  717. except Exception as e:
  718. self.logger.error(f"🌊 [STREAM] Agent流式执行异常: {str(e)}")
  719. yield {
  720. "type": "error",
  721. "error": str(e),
  722. "conversation_id": conversation_id
  723. }
  724. def _create_initial_state(self, question: str, conversation_id: str = None, context_type: str = None, routing_mode: str = None) -> AgentState:
  725. """创建初始状态 - 支持兼容性参数"""
  726. # 确定使用的路由模式
  727. if routing_mode:
  728. effective_routing_mode = routing_mode
  729. else:
  730. try:
  731. from app_config import QUESTION_ROUTING_MODE
  732. effective_routing_mode = QUESTION_ROUTING_MODE
  733. except ImportError:
  734. effective_routing_mode = "hybrid"
  735. return AgentState(
  736. # 输入信息
  737. question=question,
  738. conversation_id=conversation_id,
  739. # 上下文信息
  740. context_type=context_type,
  741. # 分类结果 (初始值,会在分类节点或直接模式初始化节点中更新)
  742. question_type="UNCERTAIN",
  743. classification_confidence=0.0,
  744. classification_reason="",
  745. classification_method="",
  746. # 数据库查询流程状态
  747. sql=None,
  748. query_result=None,
  749. summary=None,
  750. # SQL验证和修复相关状态
  751. sql_generation_success=False,
  752. sql_validation_success=False,
  753. sql_repair_attempted=False,
  754. sql_repair_success=False,
  755. validation_error_type=None,
  756. user_prompt=None,
  757. # 聊天响应
  758. chat_response=None,
  759. # 最终输出
  760. final_response={},
  761. # 错误处理
  762. error=None,
  763. error_code=None,
  764. # 流程控制
  765. current_step="initialized",
  766. execution_path=["start"],
  767. # 路由模式
  768. routing_mode=effective_routing_mode
  769. )
  770. # ==================== SQL验证和修复相关方法 ====================
  771. def _is_sql_validation_enabled(self) -> bool:
  772. """检查是否启用SQL验证"""
  773. from agent.config import get_nested_config
  774. return (get_nested_config(self.config, "sql_validation.enable_syntax_validation", False) or
  775. get_nested_config(self.config, "sql_validation.enable_forbidden_check", False))
  776. def _is_auto_repair_enabled(self) -> bool:
  777. """检查是否启用自动修复"""
  778. from agent.config import get_nested_config
  779. return (get_nested_config(self.config, "sql_validation.enable_auto_repair", False) and
  780. get_nested_config(self.config, "sql_validation.enable_syntax_validation", False))
  781. async def _validate_sql_with_custom_priority(self, sql: str) -> Dict[str, Any]:
  782. """
  783. 按照自定义优先级验证SQL:先禁止词,再语法
  784. Args:
  785. sql: 要验证的SQL语句
  786. Returns:
  787. 验证结果字典
  788. """
  789. try:
  790. from agent.config import get_nested_config
  791. # 1. 优先检查禁止词(您要求的优先级)
  792. if get_nested_config(self.config, "sql_validation.enable_forbidden_check", True):
  793. forbidden_result = self._check_forbidden_keywords(sql)
  794. if not forbidden_result.get("valid"):
  795. return {
  796. "valid": False,
  797. "error_type": "forbidden_keywords",
  798. "error_message": forbidden_result.get("error"),
  799. "can_repair": False # 禁止词错误不能修复
  800. }
  801. # 2. 再检查语法(EXPLAIN SQL)
  802. if get_nested_config(self.config, "sql_validation.enable_syntax_validation", True):
  803. syntax_result = await self._validate_sql_syntax(sql)
  804. if not syntax_result.get("valid"):
  805. return {
  806. "valid": False,
  807. "error_type": "syntax_error",
  808. "error_message": syntax_result.get("error"),
  809. "can_repair": True # 语法错误可以尝试修复
  810. }
  811. return {"valid": True}
  812. except Exception as e:
  813. return {
  814. "valid": False,
  815. "error_type": "validation_exception",
  816. "error_message": str(e),
  817. "can_repair": False
  818. }
  819. def _check_forbidden_keywords(self, sql: str) -> Dict[str, Any]:
  820. """检查禁止的SQL关键词"""
  821. try:
  822. from agent.config import get_nested_config
  823. forbidden_operations = get_nested_config(
  824. self.config,
  825. "sql_validation.forbidden_operations",
  826. ['UPDATE', 'DELETE', 'DROP', 'ALTER', 'INSERT']
  827. )
  828. sql_upper = sql.upper().strip()
  829. for operation in forbidden_operations:
  830. if sql_upper.startswith(operation.upper()):
  831. return {
  832. "valid": False,
  833. "error": f"不允许的操作: {operation}。本系统只支持查询操作(SELECT)。"
  834. }
  835. return {"valid": True}
  836. except Exception as e:
  837. return {
  838. "valid": False,
  839. "error": f"禁止词检查异常: {str(e)}"
  840. }
  841. async def _validate_sql_syntax(self, sql: str) -> Dict[str, Any]:
  842. """语法验证 - 使用EXPLAIN SQL"""
  843. try:
  844. from common.vanna_instance import get_vanna_instance
  845. import asyncio
  846. vn = get_vanna_instance()
  847. # 构建EXPLAIN查询
  848. explain_sql = f"EXPLAIN {sql}"
  849. # 异步执行验证
  850. result = await asyncio.to_thread(vn.run_sql, explain_sql)
  851. if result is not None:
  852. return {"valid": True}
  853. else:
  854. return {
  855. "valid": False,
  856. "error": "SQL语法验证失败"
  857. }
  858. except Exception as e:
  859. return {
  860. "valid": False,
  861. "error": str(e)
  862. }
  863. async def _attempt_sql_repair_once(self, sql: str, error_message: str) -> Dict[str, Any]:
  864. """
  865. 使用LLM尝试修复SQL - 只修复一次
  866. Args:
  867. sql: 原始SQL
  868. error_message: 错误信息
  869. Returns:
  870. 修复结果字典
  871. """
  872. try:
  873. from common.vanna_instance import get_vanna_instance
  874. from agent.config import get_nested_config
  875. import asyncio
  876. vn = get_vanna_instance()
  877. # 构建修复提示词
  878. repair_prompt = f"""你是一个PostgreSQL SQL专家,请修复以下SQL语句的语法错误。
  879. 当前数据库类型: PostgreSQL
  880. 错误信息: {error_message}
  881. 需要修复的SQL:
  882. {sql}
  883. 修复要求:
  884. 1. 只修复语法错误和表结构错误
  885. 2. 保持SQL的原始业务逻辑不变
  886. 3. 使用PostgreSQL标准语法
  887. 4. 确保修复后的SQL语法正确
  888. 请直接输出修复后的SQL语句,不要添加其他说明文字。"""
  889. # 获取超时配置
  890. timeout = get_nested_config(self.config, "sql_validation.repair_timeout", 60)
  891. # 异步调用LLM修复
  892. response = await asyncio.wait_for(
  893. asyncio.to_thread(
  894. vn.chat_with_llm,
  895. question=repair_prompt,
  896. system_prompt="你是一个专业的PostgreSQL SQL专家,专门负责修复SQL语句中的语法错误。"
  897. ),
  898. timeout=timeout
  899. )
  900. if response and response.strip():
  901. repaired_sql = response.strip()
  902. # 验证修复后的SQL
  903. validation_result = await self._validate_sql_syntax(repaired_sql)
  904. if validation_result.get("valid"):
  905. return {
  906. "success": True,
  907. "repaired_sql": repaired_sql,
  908. "error": None
  909. }
  910. else:
  911. return {
  912. "success": False,
  913. "repaired_sql": None,
  914. "error": f"修复后的SQL仍然无效: {validation_result.get('error')}"
  915. }
  916. else:
  917. return {
  918. "success": False,
  919. "repaired_sql": None,
  920. "error": "LLM返回空响应"
  921. }
  922. except asyncio.TimeoutError:
  923. return {
  924. "success": False,
  925. "repaired_sql": None,
  926. "error": f"修复超时({get_nested_config(self.config, 'sql_validation.repair_timeout', 60)}秒)"
  927. }
  928. except Exception as e:
  929. return {
  930. "success": False,
  931. "repaired_sql": None,
  932. "error": f"修复异常: {str(e)}"
  933. }
  934. def _generate_conversation_id(self, user_id: str) -> str:
  935. """生成对话ID - 使用与React Agent一致的格式"""
  936. import pandas as pd
  937. timestamp = pd.Timestamp.now().strftime('%Y%m%d%H%M%S%f')[:-3] # 去掉最后3位微秒
  938. return f"{user_id}:{timestamp}"
  939. def _map_node_to_progress(self, node_name: str, node_data: dict) -> dict:
  940. """将节点执行状态映射为用户友好的进度信息"""
  941. if node_name == "classify_question":
  942. question_type = node_data.get("question_type", "UNCERTAIN")
  943. confidence = node_data.get("classification_confidence", 0)
  944. return {
  945. "display_name": "分析问题类型",
  946. "icon": "🤔",
  947. "details": f"问题类型: {question_type} (置信度: {confidence:.2f})",
  948. "sub_status": f"使用{node_data.get('classification_method', '未知')}方法分类"
  949. }
  950. elif node_name == "agent_sql_generation":
  951. if node_data.get("sql_generation_success"):
  952. sql = node_data.get("sql", "")
  953. sql_preview = sql[:50] + "..." if len(sql) > 50 else sql
  954. return {
  955. "display_name": "SQL生成成功",
  956. "icon": "✅",
  957. "details": f"生成SQL: {sql_preview}",
  958. "sub_status": "验证通过,准备执行"
  959. }
  960. else:
  961. error_type = node_data.get("validation_error_type", "unknown")
  962. return {
  963. "display_name": "SQL生成处理中",
  964. "icon": "🔧",
  965. "details": f"验证状态: {error_type}",
  966. "sub_status": node_data.get("user_prompt", "正在处理")
  967. }
  968. elif node_name == "agent_sql_execution":
  969. query_result = node_data.get("query_result", {})
  970. row_count = query_result.get("row_count", 0)
  971. return {
  972. "display_name": "执行数据查询",
  973. "icon": "⚙️",
  974. "details": f"查询完成,返回 {row_count} 行数据",
  975. "sub_status": "正在生成摘要" if row_count > 0 else "查询执行完成"
  976. }
  977. elif node_name == "agent_chat":
  978. return {
  979. "display_name": "思考回答",
  980. "icon": "💭",
  981. "details": "正在处理您的问题",
  982. "sub_status": "使用智能对话模式"
  983. }
  984. elif node_name == "format_response":
  985. return {
  986. "display_name": "整理结果",
  987. "icon": "📝",
  988. "details": "正在格式化响应结果",
  989. "sub_status": "即将完成"
  990. }
  991. return None
  992. def _extract_relevant_state(self, node_data: dict) -> dict:
  993. """从节点数据中提取相关的状态信息,过滤敏感信息"""
  994. try:
  995. relevant_keys = [
  996. "current_step", "execution_path", "question_type",
  997. "classification_confidence", "classification_method",
  998. "sql_generation_success", "sql_validation_success",
  999. "routing_mode"
  1000. ]
  1001. extracted = {}
  1002. for key in relevant_keys:
  1003. if key in node_data:
  1004. extracted[key] = node_data[key]
  1005. # 特殊处理SQL:只返回前100个字符避免过长
  1006. if "sql" in node_data and node_data["sql"]:
  1007. sql = str(node_data["sql"])
  1008. extracted["sql_preview"] = sql[:100] + "..." if len(sql) > 100 else sql
  1009. # 特殊处理查询结果:只返回行数统计
  1010. if "query_result" in node_data and node_data["query_result"]:
  1011. query_result = node_data["query_result"]
  1012. if isinstance(query_result, dict):
  1013. extracted["query_summary"] = {
  1014. "row_count": query_result.get("row_count", 0),
  1015. "column_count": len(query_result.get("columns", []))
  1016. }
  1017. return extracted
  1018. except Exception as e:
  1019. self.logger.warning(f"提取状态信息失败: {str(e)}")
  1020. return {"error": "state_extraction_failed"}
  1021. # ==================== 原有方法 ====================
  1022. def _extract_original_question(self, question: str) -> str:
  1023. """
  1024. 从enhanced_question中提取原始问题
  1025. Args:
  1026. question: 可能包含上下文的问题
  1027. Returns:
  1028. str: 原始问题
  1029. """
  1030. try:
  1031. # 检查是否为enhanced_question格式
  1032. if "\n[CONTEXT]\n" in question and "\n[CURRENT]\n" in question:
  1033. # 提取[CURRENT]标签后的内容
  1034. current_start = question.find("\n[CURRENT]\n")
  1035. if current_start != -1:
  1036. original_question = question[current_start + len("\n[CURRENT]\n"):].strip()
  1037. return original_question
  1038. # 如果不是enhanced_question格式,直接返回原问题
  1039. return question.strip()
  1040. except Exception as e:
  1041. self.logger.warning(f"提取原始问题失败: {str(e)}")
  1042. return question.strip()
  1043. async def health_check(self) -> Dict[str, Any]:
  1044. """健康检查"""
  1045. try:
  1046. # 从配置获取健康检查参数
  1047. from agent.config import get_nested_config
  1048. test_question = get_nested_config(self.config, "health_check.test_question", "你好")
  1049. enable_full_test = get_nested_config(self.config, "health_check.enable_full_test", True)
  1050. if enable_full_test:
  1051. # 完整流程测试
  1052. test_result = await self.process_question(test_question, conversation_id="health_check")
  1053. return {
  1054. "status": "healthy" if test_result.get("success") else "degraded",
  1055. "test_result": test_result.get("success", False),
  1056. "workflow_compiled": True, # 动态创建,始终可用
  1057. "tools_count": len(self.tools),
  1058. "agent_reuse_enabled": False,
  1059. "message": "Agent健康检查完成"
  1060. }
  1061. else:
  1062. # 简单检查
  1063. return {
  1064. "status": "healthy",
  1065. "test_result": True,
  1066. "workflow_compiled": True, # 动态创建,始终可用
  1067. "tools_count": len(self.tools),
  1068. "agent_reuse_enabled": False,
  1069. "message": "Agent简单健康检查完成"
  1070. }
  1071. except Exception as e:
  1072. return {
  1073. "status": "unhealthy",
  1074. "error": str(e),
  1075. "workflow_compiled": True, # 动态创建,始终可用
  1076. "tools_count": len(self.tools) if hasattr(self, 'tools') else 0,
  1077. "agent_reuse_enabled": False,
  1078. "message": "Agent健康检查失败"
  1079. }