citu_agent.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487
  1. # agent/citu_agent.py
  2. from typing import Dict, Any, Literal
  3. from langgraph.graph import StateGraph, END
  4. from langchain.agents import AgentExecutor, create_openai_tools_agent
  5. from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
  6. from langchain_core.messages import SystemMessage, HumanMessage
  7. from agent.state import AgentState
  8. from agent.classifier import QuestionClassifier
  9. from agent.tools import TOOLS, generate_sql, execute_sql, generate_summary, general_chat
  10. from agent.utils import get_compatible_llm
  11. from app_config import ENABLE_RESULT_SUMMARY
  12. class CituLangGraphAgent:
  13. """Citu LangGraph智能助手主类 - 使用@tool装饰器 + Agent工具调用"""
  14. def __init__(self):
  15. # 加载配置
  16. try:
  17. from agent.config import get_current_config, get_nested_config
  18. self.config = get_current_config()
  19. print("[CITU_AGENT] 加载Agent配置完成")
  20. except ImportError:
  21. self.config = {}
  22. print("[CITU_AGENT] 配置文件不可用,使用默认配置")
  23. self.classifier = QuestionClassifier()
  24. self.tools = TOOLS
  25. self.llm = get_compatible_llm()
  26. # 注意:现在使用直接工具调用模式,不再需要预创建Agent执行器
  27. print("[CITU_AGENT] 使用直接工具调用模式")
  28. self.workflow = self._create_workflow()
  29. print("[CITU_AGENT] LangGraph Agent with Direct Tools初始化完成")
  30. def _create_workflow(self) -> StateGraph:
  31. """创建LangGraph工作流"""
  32. workflow = StateGraph(AgentState)
  33. # 添加节点
  34. workflow.add_node("classify_question", self._classify_question_node)
  35. workflow.add_node("agent_chat", self._agent_chat_node)
  36. workflow.add_node("agent_database", self._agent_database_node)
  37. workflow.add_node("format_response", self._format_response_node)
  38. # 设置入口点
  39. workflow.set_entry_point("classify_question")
  40. # 添加条件边:分类后的路由
  41. # 完全信任QuestionClassifier的决策,不再进行二次判断
  42. workflow.add_conditional_edges(
  43. "classify_question",
  44. self._route_after_classification,
  45. {
  46. "DATABASE": "agent_database",
  47. "CHAT": "agent_chat" # CHAT分支处理所有非DATABASE的情况(包括UNCERTAIN)
  48. }
  49. )
  50. # 添加边
  51. workflow.add_edge("agent_chat", "format_response")
  52. workflow.add_edge("agent_database", "format_response")
  53. workflow.add_edge("format_response", END)
  54. return workflow.compile()
  55. def _classify_question_node(self, state: AgentState) -> AgentState:
  56. """问题分类节点"""
  57. try:
  58. print(f"[CLASSIFY_NODE] 开始分类问题: {state['question']}")
  59. classification_result = self.classifier.classify(state["question"])
  60. # 更新状态
  61. state["question_type"] = classification_result.question_type
  62. state["classification_confidence"] = classification_result.confidence
  63. state["classification_reason"] = classification_result.reason
  64. state["classification_method"] = classification_result.method
  65. state["current_step"] = "classified"
  66. state["execution_path"].append("classify")
  67. print(f"[CLASSIFY_NODE] 分类结果: {classification_result.question_type}, 置信度: {classification_result.confidence}")
  68. return state
  69. except Exception as e:
  70. print(f"[ERROR] 问题分类异常: {str(e)}")
  71. state["error"] = f"问题分类失败: {str(e)}"
  72. state["error_code"] = 500
  73. state["execution_path"].append("classify_error")
  74. return state
  75. def _agent_database_node(self, state: AgentState) -> AgentState:
  76. """数据库Agent节点 - 直接工具调用模式"""
  77. try:
  78. print(f"[DATABASE_AGENT] 开始处理数据库查询: {state['question']}")
  79. question = state["question"]
  80. # 步骤1:生成SQL
  81. print(f"[DATABASE_AGENT] 步骤1:生成SQL")
  82. sql_result = generate_sql.invoke({"question": question, "allow_llm_to_see_data": True})
  83. if not sql_result.get("success"):
  84. print(f"[DATABASE_AGENT] SQL生成失败: {sql_result.get('error')}")
  85. state["error"] = sql_result.get("error", "SQL生成失败")
  86. state["error_code"] = 500
  87. state["current_step"] = "database_error"
  88. state["execution_path"].append("agent_database_error")
  89. return state
  90. sql = sql_result.get("sql")
  91. state["sql"] = sql
  92. print(f"[DATABASE_AGENT] SQL生成成功: {sql}")
  93. # 步骤1.5:检查是否为解释性响应而非SQL
  94. error_type = sql_result.get("error_type")
  95. if error_type == "llm_explanation":
  96. # LLM返回了解释性文本,直接作为最终答案
  97. explanation = sql_result.get("error", "")
  98. state["chat_response"] = explanation + " 请尝试提问其它问题。"
  99. state["current_step"] = "database_completed"
  100. state["execution_path"].append("agent_database")
  101. print(f"[DATABASE_AGENT] 返回LLM解释性答案: {explanation}")
  102. return state
  103. # 额外验证:检查SQL格式(防止工具误判)
  104. from agent.utils import _is_valid_sql_format
  105. if not _is_valid_sql_format(sql):
  106. # 内容看起来不是SQL,当作解释性响应处理
  107. state["chat_response"] = sql + " 请尝试提问其它问题。"
  108. state["current_step"] = "database_completed"
  109. state["execution_path"].append("agent_database")
  110. print(f"[DATABASE_AGENT] 内容不是有效SQL,当作解释返回: {sql}")
  111. return state
  112. # 步骤2:执行SQL
  113. print(f"[DATABASE_AGENT] 步骤2:执行SQL")
  114. execute_result = execute_sql.invoke({"sql": sql})
  115. if not execute_result.get("success"):
  116. print(f"[DATABASE_AGENT] SQL执行失败: {execute_result.get('error')}")
  117. state["error"] = execute_result.get("error", "SQL执行失败")
  118. state["error_code"] = 500
  119. state["current_step"] = "database_error"
  120. state["execution_path"].append("agent_database_error")
  121. return state
  122. data_result = execute_result.get("data_result")
  123. state["data_result"] = data_result
  124. print(f"[DATABASE_AGENT] SQL执行成功,返回 {data_result.get('row_count', 0)} 行数据")
  125. # 步骤3:生成摘要(可通过配置控制,仅在有数据时生成)
  126. if ENABLE_RESULT_SUMMARY and data_result.get('row_count', 0) > 0:
  127. print(f"[DATABASE_AGENT] 步骤3:生成摘要")
  128. summary_result = generate_summary.invoke({
  129. "question": question,
  130. "data_result": data_result,
  131. "sql": sql
  132. })
  133. if not summary_result.get("success"):
  134. print(f"[DATABASE_AGENT] 摘要生成失败: {summary_result.get('message')}")
  135. # 摘要生成失败不是致命错误,使用默认摘要
  136. state["summary"] = f"查询执行完成,共返回 {data_result.get('row_count', 0)} 条记录。"
  137. else:
  138. state["summary"] = summary_result.get("summary")
  139. print(f"[DATABASE_AGENT] 摘要生成成功")
  140. else:
  141. print(f"[DATABASE_AGENT] 跳过摘要生成(ENABLE_RESULT_SUMMARY={ENABLE_RESULT_SUMMARY},数据行数={data_result.get('row_count', 0)})")
  142. # 不生成摘要时,不设置summary字段,让格式化响应节点决定如何处理
  143. state["current_step"] = "database_completed"
  144. state["execution_path"].append("agent_database")
  145. print(f"[DATABASE_AGENT] 数据库查询完成")
  146. return state
  147. except Exception as e:
  148. print(f"[ERROR] 数据库Agent异常: {str(e)}")
  149. import traceback
  150. print(f"[ERROR] 详细错误信息: {traceback.format_exc()}")
  151. state["error"] = f"数据库查询失败: {str(e)}"
  152. state["error_code"] = 500
  153. state["current_step"] = "database_error"
  154. state["execution_path"].append("agent_database_error")
  155. return state
  156. def _agent_chat_node(self, state: AgentState) -> AgentState:
  157. """聊天Agent节点 - 直接工具调用模式"""
  158. try:
  159. print(f"[CHAT_AGENT] 开始处理聊天: {state['question']}")
  160. question = state["question"]
  161. # 构建上下文
  162. enable_context_injection = self.config.get("chat_agent", {}).get("enable_context_injection", True)
  163. context = None
  164. if enable_context_injection and state.get("classification_reason"):
  165. context = f"分类原因: {state['classification_reason']}"
  166. # 直接调用general_chat工具
  167. print(f"[CHAT_AGENT] 调用general_chat工具")
  168. chat_result = general_chat.invoke({
  169. "question": question,
  170. "context": context
  171. })
  172. if chat_result.get("success"):
  173. state["chat_response"] = chat_result.get("response", "")
  174. print(f"[CHAT_AGENT] 聊天处理成功")
  175. else:
  176. # 处理失败,使用备用响应
  177. state["chat_response"] = chat_result.get("response", "抱歉,我暂时无法处理您的问题。请稍后再试。")
  178. print(f"[CHAT_AGENT] 聊天处理失败,使用备用响应: {chat_result.get('error')}")
  179. state["current_step"] = "chat_completed"
  180. state["execution_path"].append("agent_chat")
  181. print(f"[CHAT_AGENT] 聊天处理完成")
  182. return state
  183. except Exception as e:
  184. print(f"[ERROR] 聊天Agent异常: {str(e)}")
  185. import traceback
  186. print(f"[ERROR] 详细错误信息: {traceback.format_exc()}")
  187. state["chat_response"] = "抱歉,我暂时无法处理您的问题。请稍后再试,或者尝试询问数据相关的问题。"
  188. state["current_step"] = "chat_error"
  189. state["execution_path"].append("agent_chat_error")
  190. return state
  191. def _format_response_node(self, state: AgentState) -> AgentState:
  192. """格式化最终响应节点"""
  193. try:
  194. print(f"[FORMAT_NODE] 开始格式化响应,问题类型: {state['question_type']}")
  195. state["current_step"] = "completed"
  196. state["execution_path"].append("format_response")
  197. # 根据问题类型和执行状态格式化响应
  198. if state.get("error"):
  199. # 有错误的情况
  200. state["final_response"] = {
  201. "success": False,
  202. "error": state["error"],
  203. "error_code": state.get("error_code", 500),
  204. "question_type": state["question_type"],
  205. "execution_path": state["execution_path"],
  206. "classification_info": {
  207. "confidence": state.get("classification_confidence", 0),
  208. "reason": state.get("classification_reason", ""),
  209. "method": state.get("classification_method", "")
  210. }
  211. }
  212. elif state["question_type"] == "DATABASE":
  213. # 数据库查询类型
  214. if state.get("chat_response"):
  215. # SQL生成失败的解释性响应(不受ENABLE_RESULT_SUMMARY配置影响)
  216. state["final_response"] = {
  217. "success": True,
  218. "response": state["chat_response"],
  219. "type": "DATABASE",
  220. "sql": state.get("sql"),
  221. "data_result": state.get("data_result"),
  222. "execution_path": state["execution_path"],
  223. "classification_info": {
  224. "confidence": state["classification_confidence"],
  225. "reason": state["classification_reason"],
  226. "method": state["classification_method"]
  227. }
  228. }
  229. elif state.get("summary"):
  230. # 正常的数据库查询结果,有摘要的情况
  231. # 不将summary复制到response,让response保持为空
  232. state["final_response"] = {
  233. "success": True,
  234. "type": "DATABASE",
  235. "sql": state.get("sql"),
  236. "data_result": state.get("data_result"),
  237. "summary": state["summary"],
  238. "execution_path": state["execution_path"],
  239. "classification_info": {
  240. "confidence": state["classification_confidence"],
  241. "reason": state["classification_reason"],
  242. "method": state["classification_method"]
  243. }
  244. }
  245. elif state.get("data_result"):
  246. # 有数据但没有摘要(摘要被配置禁用)
  247. data_result = state.get("data_result")
  248. row_count = data_result.get("row_count", 0)
  249. # 构建基本响应,不包含summary字段和response字段
  250. # 用户应该直接从data_result.columns和data_result.rows获取数据
  251. state["final_response"] = {
  252. "success": True,
  253. "type": "DATABASE",
  254. "sql": state.get("sql"),
  255. "data_result": data_result,
  256. "execution_path": state["execution_path"],
  257. "classification_info": {
  258. "confidence": state["classification_confidence"],
  259. "reason": state["classification_reason"],
  260. "method": state["classification_method"]
  261. }
  262. }
  263. else:
  264. # 数据库查询失败,没有任何结果
  265. state["final_response"] = {
  266. "success": False,
  267. "error": state.get("error", "数据库查询未完成"),
  268. "type": "DATABASE",
  269. "sql": state.get("sql"),
  270. "execution_path": state["execution_path"]
  271. }
  272. else:
  273. # 聊天类型
  274. state["final_response"] = {
  275. "success": True,
  276. "response": state.get("chat_response", ""),
  277. "type": "CHAT",
  278. "execution_path": state["execution_path"],
  279. "classification_info": {
  280. "confidence": state["classification_confidence"],
  281. "reason": state["classification_reason"],
  282. "method": state["classification_method"]
  283. }
  284. }
  285. print(f"[FORMAT_NODE] 响应格式化完成")
  286. return state
  287. except Exception as e:
  288. print(f"[ERROR] 响应格式化异常: {str(e)}")
  289. state["final_response"] = {
  290. "success": False,
  291. "error": f"响应格式化异常: {str(e)}",
  292. "error_code": 500,
  293. "execution_path": state["execution_path"]
  294. }
  295. return state
  296. def _route_after_classification(self, state: AgentState) -> Literal["DATABASE", "CHAT"]:
  297. """
  298. 分类后的路由决策
  299. 完全信任QuestionClassifier的决策:
  300. - DATABASE类型 → 数据库Agent
  301. - CHAT和UNCERTAIN类型 → 聊天Agent
  302. 这样避免了双重决策的冲突,所有分类逻辑都集中在QuestionClassifier中
  303. """
  304. question_type = state["question_type"]
  305. confidence = state["classification_confidence"]
  306. print(f"[ROUTE] 分类路由: {question_type}, 置信度: {confidence} (完全信任分类器决策)")
  307. if question_type == "DATABASE":
  308. return "DATABASE"
  309. else:
  310. # 将 "CHAT" 和 "UNCERTAIN" 类型都路由到聊天流程
  311. # 聊天Agent可以处理不确定的情况,并在必要时引导用户提供更多信息
  312. return "CHAT"
  313. def process_question(self, question: str, session_id: str = None) -> Dict[str, Any]:
  314. """
  315. 统一的问题处理入口
  316. Args:
  317. question: 用户问题
  318. session_id: 会话ID
  319. Returns:
  320. Dict包含完整的处理结果
  321. """
  322. try:
  323. print(f"[CITU_AGENT] 开始处理问题: {question}")
  324. # 初始化状态
  325. initial_state = self._create_initial_state(question, session_id)
  326. # 执行工作流
  327. final_state = self.workflow.invoke(
  328. initial_state,
  329. config={
  330. "configurable": {"session_id": session_id}
  331. } if session_id else None
  332. )
  333. # 提取最终结果
  334. result = final_state["final_response"]
  335. print(f"[CITU_AGENT] 问题处理完成: {result.get('success', False)}")
  336. return result
  337. except Exception as e:
  338. print(f"[ERROR] Agent执行异常: {str(e)}")
  339. return {
  340. "success": False,
  341. "error": f"Agent系统异常: {str(e)}",
  342. "error_code": 500,
  343. "execution_path": ["error"]
  344. }
  345. def _create_initial_state(self, question: str, session_id: str = None) -> AgentState:
  346. """创建初始状态"""
  347. return AgentState(
  348. # 输入信息
  349. question=question,
  350. session_id=session_id,
  351. # 分类结果
  352. question_type="",
  353. classification_confidence=0.0,
  354. classification_reason="",
  355. classification_method="",
  356. # 数据库查询流程状态
  357. sql=None,
  358. sql_generation_attempts=0,
  359. data_result=None,
  360. summary=None,
  361. # 聊天响应
  362. chat_response=None,
  363. # 最终输出
  364. final_response={},
  365. # 错误处理
  366. error=None,
  367. error_code=None,
  368. # 流程控制
  369. current_step="start",
  370. execution_path=[],
  371. retry_count=0,
  372. max_retries=2,
  373. # 调试信息
  374. debug_info={}
  375. )
  376. def health_check(self) -> Dict[str, Any]:
  377. """健康检查"""
  378. try:
  379. # 从配置获取健康检查参数
  380. from agent.config import get_nested_config
  381. test_question = get_nested_config(self.config, "health_check.test_question", "你好")
  382. enable_full_test = get_nested_config(self.config, "health_check.enable_full_test", True)
  383. if enable_full_test:
  384. # 完整流程测试
  385. test_result = self.process_question(test_question, "health_check")
  386. return {
  387. "status": "healthy" if test_result.get("success") else "degraded",
  388. "test_result": test_result.get("success", False),
  389. "workflow_compiled": self.workflow is not None,
  390. "tools_count": len(self.tools),
  391. "agent_reuse_enabled": False,
  392. "message": "Agent健康检查完成"
  393. }
  394. else:
  395. # 简单检查
  396. return {
  397. "status": "healthy",
  398. "test_result": True,
  399. "workflow_compiled": self.workflow is not None,
  400. "tools_count": len(self.tools),
  401. "agent_reuse_enabled": False,
  402. "message": "Agent简单健康检查完成"
  403. }
  404. except Exception as e:
  405. return {
  406. "status": "unhealthy",
  407. "error": str(e),
  408. "workflow_compiled": self.workflow is not None,
  409. "tools_count": len(self.tools) if hasattr(self, 'tools') else 0,
  410. "agent_reuse_enabled": False,
  411. "message": "Agent健康检查失败"
  412. }