# agent/citu_agent.py
from typing import Any, Dict, Literal, Optional

from langchain.agents import AgentExecutor, create_openai_tools_agent
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import HumanMessage, SystemMessage
from langgraph.graph import END, StateGraph

from agent.classifier import QuestionClassifier
from agent.state import AgentState
from agent.tools import TOOLS, execute_sql, general_chat, generate_sql, generate_summary
from agent.utils import get_compatible_llm
  11. class CituLangGraphAgent:
  12. """Citu LangGraph智能助手主类 - 使用@tool装饰器 + Agent工具调用"""
  13. def __init__(self):
  14. # 加载配置
  15. try:
  16. from agent.config import get_current_config, get_nested_config
  17. self.config = get_current_config()
  18. print("[CITU_AGENT] 加载Agent配置完成")
  19. except ImportError:
  20. self.config = {}
  21. print("[CITU_AGENT] 配置文件不可用,使用默认配置")
  22. self.classifier = QuestionClassifier()
  23. self.tools = TOOLS
  24. self.llm = get_compatible_llm()
  25. # 预创建Agent实例以提升性能
  26. enable_reuse = self.config.get("performance", {}).get("enable_agent_reuse", True)
  27. if enable_reuse:
  28. print("[CITU_AGENT] 预创建Agent实例中...")
  29. self._database_executor = self._create_database_agent()
  30. self._chat_executor = self._create_chat_agent()
  31. print("[CITU_AGENT] Agent实例预创建完成")
  32. else:
  33. self._database_executor = None
  34. self._chat_executor = None
  35. print("[CITU_AGENT] Agent实例重用已禁用,将在运行时创建")
  36. self.workflow = self._create_workflow()
  37. print("[CITU_AGENT] LangGraph Agent with Tools初始化完成")
  38. def _create_workflow(self) -> StateGraph:
  39. """创建LangGraph工作流"""
  40. workflow = StateGraph(AgentState)
  41. # 添加节点
  42. workflow.add_node("classify_question", self._classify_question_node)
  43. workflow.add_node("agent_chat", self._agent_chat_node)
  44. workflow.add_node("agent_database", self._agent_database_node)
  45. workflow.add_node("format_response", self._format_response_node)
  46. # 设置入口点
  47. workflow.set_entry_point("classify_question")
  48. # 添加条件边:分类后的路由
  49. # 完全信任QuestionClassifier的决策,不再进行二次判断
  50. workflow.add_conditional_edges(
  51. "classify_question",
  52. self._route_after_classification,
  53. {
  54. "DATABASE": "agent_database",
  55. "CHAT": "agent_chat" # CHAT分支处理所有非DATABASE的情况(包括UNCERTAIN)
  56. }
  57. )
  58. # 添加边
  59. workflow.add_edge("agent_chat", "format_response")
  60. workflow.add_edge("agent_database", "format_response")
  61. workflow.add_edge("format_response", END)
  62. return workflow.compile()
  63. def _classify_question_node(self, state: AgentState) -> AgentState:
  64. """问题分类节点"""
  65. try:
  66. print(f"[CLASSIFY_NODE] 开始分类问题: {state['question']}")
  67. classification_result = self.classifier.classify(state["question"])
  68. # 更新状态
  69. state["question_type"] = classification_result.question_type
  70. state["classification_confidence"] = classification_result.confidence
  71. state["classification_reason"] = classification_result.reason
  72. state["classification_method"] = classification_result.method
  73. state["current_step"] = "classified"
  74. state["execution_path"].append("classify")
  75. print(f"[CLASSIFY_NODE] 分类结果: {classification_result.question_type}, 置信度: {classification_result.confidence}")
  76. return state
  77. except Exception as e:
  78. print(f"[ERROR] 问题分类异常: {str(e)}")
  79. state["error"] = f"问题分类失败: {str(e)}"
  80. state["error_code"] = 500
  81. state["execution_path"].append("classify_error")
  82. return state
  83. def _create_database_agent(self):
  84. """创建数据库专用Agent(预创建)"""
  85. from agent.config import get_nested_config
  86. # 获取配置
  87. max_iterations = get_nested_config(self.config, "database_agent.max_iterations", 5)
  88. enable_verbose = get_nested_config(self.config, "database_agent.enable_verbose", True)
  89. early_stopping_method = get_nested_config(self.config, "database_agent.early_stopping_method", "generate")
  90. database_prompt = ChatPromptTemplate.from_messages([
  91. SystemMessage(content="""
  92. 你是一个专业的数据库查询助手。你的任务是帮助用户查询数据库并生成报告。
  93. 工具使用流程:
  94. 1. 首先使用 generate_sql 工具将用户问题转换为SQL
  95. 2. 然后使用 execute_sql 工具执行SQL查询
  96. 3. 最后使用 generate_summary 工具为结果生成自然语言摘要
  97. 如果任何步骤失败,请提供清晰的错误信息并建议解决方案。
  98. """),
  99. MessagesPlaceholder(variable_name="chat_history", optional=True),
  100. HumanMessage(content="{input}"),
  101. MessagesPlaceholder(variable_name="agent_scratchpad")
  102. ])
  103. database_tools = [generate_sql, execute_sql, generate_summary]
  104. agent = create_openai_tools_agent(self.llm, database_tools, database_prompt)
  105. return AgentExecutor(
  106. agent=agent,
  107. tools=database_tools,
  108. verbose=enable_verbose,
  109. handle_parsing_errors=True,
  110. max_iterations=max_iterations,
  111. early_stopping_method=early_stopping_method
  112. )
  113. def _agent_database_node(self, state: AgentState) -> AgentState:
  114. """数据库Agent节点 - 使用预创建或动态创建的Agent"""
  115. try:
  116. print(f"[DATABASE_AGENT] 开始处理数据库查询: {state['question']}")
  117. # 使用预创建的Agent或动态创建
  118. if self._database_executor is not None:
  119. executor = self._database_executor
  120. print(f"[DATABASE_AGENT] 使用预创建的Agent实例")
  121. else:
  122. executor = self._create_database_agent()
  123. print(f"[DATABASE_AGENT] 动态创建Agent实例")
  124. # 执行Agent
  125. result = executor.invoke({
  126. "input": state["question"]
  127. })
  128. # 解析Agent执行结果
  129. self._parse_database_agent_result(state, result)
  130. state["current_step"] = "database_completed"
  131. state["execution_path"].append("agent_database")
  132. print(f"[DATABASE_AGENT] 数据库查询完成")
  133. return state
  134. except Exception as e:
  135. print(f"[ERROR] 数据库Agent异常: {str(e)}")
  136. state["error"] = f"数据库查询失败: {str(e)}"
  137. state["error_code"] = 500
  138. state["current_step"] = "database_error"
  139. state["execution_path"].append("agent_database_error")
  140. return state
  141. def _create_chat_agent(self):
  142. """创建聊天专用Agent(预创建)"""
  143. from agent.config import get_nested_config
  144. # 获取配置
  145. max_iterations = get_nested_config(self.config, "chat_agent.max_iterations", 3)
  146. enable_verbose = get_nested_config(self.config, "chat_agent.enable_verbose", True)
  147. chat_prompt = ChatPromptTemplate.from_messages([
  148. SystemMessage(content="""
  149. 你是Citu智能数据问答平台的友好助手。
  150. 使用 general_chat 工具来处理用户的一般性问题、概念解释、操作指导等。
  151. 特别注意:
  152. - 如果用户的问题可能涉及数据查询,建议他们尝试数据库查询功能
  153. - 如果问题不够明确,主动询问更多细节以便更好地帮助用户
  154. - 对于模糊的问题,可以提供多种可能的解决方案
  155. - 当遇到不确定的问题时,通过友好的对话来澄清用户意图
  156. """),
  157. MessagesPlaceholder(variable_name="chat_history", optional=True),
  158. HumanMessage(content="{input}"),
  159. MessagesPlaceholder(variable_name="agent_scratchpad")
  160. ])
  161. chat_tools = [general_chat]
  162. agent = create_openai_tools_agent(self.llm, chat_tools, chat_prompt)
  163. return AgentExecutor(
  164. agent=agent,
  165. tools=chat_tools,
  166. verbose=enable_verbose,
  167. handle_parsing_errors=True,
  168. max_iterations=max_iterations
  169. )
  170. def _agent_chat_node(self, state: AgentState) -> AgentState:
  171. """聊天Agent节点 - 使用预创建或动态创建的Agent"""
  172. try:
  173. print(f"[CHAT_AGENT] 开始处理聊天: {state['question']}")
  174. # 使用预创建的Agent或动态创建
  175. if self._chat_executor is not None:
  176. executor = self._chat_executor
  177. print(f"[CHAT_AGENT] 使用预创建的Agent实例")
  178. else:
  179. executor = self._create_chat_agent()
  180. print(f"[CHAT_AGENT] 动态创建Agent实例")
  181. # 构建上下文
  182. enable_context_injection = self.config.get("chat_agent", {}).get("enable_context_injection", True)
  183. context = None
  184. if enable_context_injection and state.get("classification_reason"):
  185. context = f"分类原因: {state['classification_reason']}"
  186. # 执行Agent
  187. input_text = state["question"]
  188. if context:
  189. input_text = f"{state['question']}\n\n上下文: {context}"
  190. result = executor.invoke({
  191. "input": input_text
  192. })
  193. # 提取聊天响应
  194. state["chat_response"] = result.get("output", "")
  195. state["current_step"] = "chat_completed"
  196. state["execution_path"].append("agent_chat")
  197. print(f"[CHAT_AGENT] 聊天处理完成")
  198. return state
  199. except Exception as e:
  200. print(f"[ERROR] 聊天Agent异常: {str(e)}")
  201. state["chat_response"] = "抱歉,我暂时无法处理您的问题。请稍后再试,或者尝试询问数据相关的问题。"
  202. state["current_step"] = "chat_error"
  203. state["execution_path"].append("agent_chat_error")
  204. return state
  205. def _format_response_node(self, state: AgentState) -> AgentState:
  206. """格式化最终响应节点"""
  207. try:
  208. print(f"[FORMAT_NODE] 开始格式化响应,问题类型: {state['question_type']}")
  209. state["current_step"] = "completed"
  210. state["execution_path"].append("format_response")
  211. # 根据问题类型和执行状态格式化响应
  212. if state.get("error"):
  213. # 有错误的情况
  214. state["final_response"] = {
  215. "success": False,
  216. "error": state["error"],
  217. "error_code": state.get("error_code", 500),
  218. "question_type": state["question_type"],
  219. "execution_path": state["execution_path"],
  220. "classification_info": {
  221. "confidence": state.get("classification_confidence", 0),
  222. "reason": state.get("classification_reason", ""),
  223. "method": state.get("classification_method", "")
  224. }
  225. }
  226. elif state["question_type"] == "DATABASE":
  227. # 数据库查询类型
  228. if state.get("data_result") and state.get("summary"):
  229. # 完整的数据库查询流程
  230. state["final_response"] = {
  231. "success": True,
  232. "response": state["summary"],
  233. "type": "DATABASE",
  234. "sql": state.get("sql"),
  235. "data_result": state["data_result"],
  236. "summary": state["summary"],
  237. "execution_path": state["execution_path"],
  238. "classification_info": {
  239. "confidence": state["classification_confidence"],
  240. "reason": state["classification_reason"],
  241. "method": state["classification_method"]
  242. }
  243. }
  244. else:
  245. # 数据库查询失败,但有部分结果
  246. state["final_response"] = {
  247. "success": False,
  248. "error": state.get("error", "数据库查询未完成"),
  249. "type": "DATABASE",
  250. "sql": state.get("sql"),
  251. "execution_path": state["execution_path"]
  252. }
  253. else:
  254. # 聊天类型
  255. state["final_response"] = {
  256. "success": True,
  257. "response": state.get("chat_response", ""),
  258. "type": "CHAT",
  259. "execution_path": state["execution_path"],
  260. "classification_info": {
  261. "confidence": state["classification_confidence"],
  262. "reason": state["classification_reason"],
  263. "method": state["classification_method"]
  264. }
  265. }
  266. print(f"[FORMAT_NODE] 响应格式化完成")
  267. return state
  268. except Exception as e:
  269. print(f"[ERROR] 响应格式化异常: {str(e)}")
  270. state["final_response"] = {
  271. "success": False,
  272. "error": f"响应格式化异常: {str(e)}",
  273. "error_code": 500,
  274. "execution_path": state["execution_path"]
  275. }
  276. return state
  277. def _route_after_classification(self, state: AgentState) -> Literal["DATABASE", "CHAT"]:
  278. """
  279. 分类后的路由决策
  280. 完全信任QuestionClassifier的决策:
  281. - DATABASE类型 → 数据库Agent
  282. - CHAT和UNCERTAIN类型 → 聊天Agent
  283. 这样避免了双重决策的冲突,所有分类逻辑都集中在QuestionClassifier中
  284. """
  285. question_type = state["question_type"]
  286. confidence = state["classification_confidence"]
  287. print(f"[ROUTE] 分类路由: {question_type}, 置信度: {confidence} (完全信任分类器决策)")
  288. if question_type == "DATABASE":
  289. return "DATABASE"
  290. else:
  291. # 将 "CHAT" 和 "UNCERTAIN" 类型都路由到聊天流程
  292. # 聊天Agent可以处理不确定的情况,并在必要时引导用户提供更多信息
  293. return "CHAT"
  294. def _parse_database_agent_result(self, state: AgentState, agent_result: Dict[str, Any]):
  295. """解析数据库Agent的执行结果"""
  296. try:
  297. output = agent_result.get("output", "")
  298. intermediate_steps = agent_result.get("intermediate_steps", [])
  299. # 从intermediate_steps中提取工具调用结果
  300. for step in intermediate_steps:
  301. if len(step) >= 2:
  302. action, observation = step[0], step[1]
  303. if hasattr(action, 'tool') and hasattr(action, 'tool_input'):
  304. tool_name = action.tool
  305. tool_result = observation
  306. # 解析工具结果
  307. if tool_name == "generate_sql" and isinstance(tool_result, dict):
  308. if tool_result.get("success"):
  309. state["sql"] = tool_result.get("sql")
  310. else:
  311. state["error"] = tool_result.get("error")
  312. elif tool_name == "execute_sql" and isinstance(tool_result, dict):
  313. if tool_result.get("success"):
  314. state["data_result"] = tool_result.get("data_result")
  315. else:
  316. state["error"] = tool_result.get("error")
  317. elif tool_name == "generate_summary" and isinstance(tool_result, dict):
  318. if tool_result.get("success"):
  319. state["summary"] = tool_result.get("summary")
  320. # 如果没有从工具结果中获取到摘要,使用Agent的最终输出
  321. if not state.get("summary") and output:
  322. state["summary"] = output
  323. except Exception as e:
  324. print(f"[WARNING] 解析数据库Agent结果失败: {str(e)}")
  325. # 使用Agent的输出作为摘要
  326. state["summary"] = agent_result.get("output", "查询处理完成")
  327. def process_question(self, question: str, session_id: str = None) -> Dict[str, Any]:
  328. """
  329. 统一的问题处理入口
  330. Args:
  331. question: 用户问题
  332. session_id: 会话ID
  333. Returns:
  334. Dict包含完整的处理结果
  335. """
  336. try:
  337. print(f"[CITU_AGENT] 开始处理问题: {question}")
  338. # 初始化状态
  339. initial_state = self._create_initial_state(question, session_id)
  340. # 执行工作流
  341. final_state = self.workflow.invoke(
  342. initial_state,
  343. config={
  344. "configurable": {"session_id": session_id}
  345. } if session_id else None
  346. )
  347. # 提取最终结果
  348. result = final_state["final_response"]
  349. print(f"[CITU_AGENT] 问题处理完成: {result.get('success', False)}")
  350. return result
  351. except Exception as e:
  352. print(f"[ERROR] Agent执行异常: {str(e)}")
  353. return {
  354. "success": False,
  355. "error": f"Agent系统异常: {str(e)}",
  356. "error_code": 500,
  357. "execution_path": ["error"]
  358. }
  359. def _create_initial_state(self, question: str, session_id: str = None) -> AgentState:
  360. """创建初始状态"""
  361. return AgentState(
  362. # 输入信息
  363. question=question,
  364. session_id=session_id,
  365. # 分类结果
  366. question_type="",
  367. classification_confidence=0.0,
  368. classification_reason="",
  369. classification_method="",
  370. # 数据库查询流程状态
  371. sql=None,
  372. sql_generation_attempts=0,
  373. data_result=None,
  374. summary=None,
  375. # 聊天响应
  376. chat_response=None,
  377. # 最终输出
  378. final_response={},
  379. # 错误处理
  380. error=None,
  381. error_code=None,
  382. # 流程控制
  383. current_step="start",
  384. execution_path=[],
  385. retry_count=0,
  386. max_retries=2,
  387. # 调试信息
  388. debug_info={}
  389. )
  390. def health_check(self) -> Dict[str, Any]:
  391. """健康检查"""
  392. try:
  393. # 从配置获取健康检查参数
  394. from agent.config import get_nested_config
  395. test_question = get_nested_config(self.config, "health_check.test_question", "你好")
  396. enable_full_test = get_nested_config(self.config, "health_check.enable_full_test", True)
  397. if enable_full_test:
  398. # 完整流程测试
  399. test_result = self.process_question(test_question, "health_check")
  400. return {
  401. "status": "healthy" if test_result.get("success") else "degraded",
  402. "test_result": test_result.get("success", False),
  403. "workflow_compiled": self.workflow is not None,
  404. "tools_count": len(self.tools),
  405. "agent_reuse_enabled": self._database_executor is not None and self._chat_executor is not None,
  406. "message": "Agent健康检查完成"
  407. }
  408. else:
  409. # 简单检查
  410. return {
  411. "status": "healthy",
  412. "test_result": True,
  413. "workflow_compiled": self.workflow is not None,
  414. "tools_count": len(self.tools),
  415. "agent_reuse_enabled": self._database_executor is not None and self._chat_executor is not None,
  416. "message": "Agent简单健康检查完成"
  417. }
  418. except Exception as e:
  419. return {
  420. "status": "unhealthy",
  421. "error": str(e),
  422. "workflow_compiled": self.workflow is not None,
  423. "tools_count": len(self.tools) if hasattr(self, 'tools') else 0,
  424. "agent_reuse_enabled": False,
  425. "message": "Agent健康检查失败"
  426. }