# agent/citu_agent.py
from typing import Dict, Any, Literal, Optional

from langchain.agents import AgentExecutor, create_openai_tools_agent
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import SystemMessage, HumanMessage
from langgraph.graph import StateGraph, END

from agent.classifier import QuestionClassifier
from agent.state import AgentState
from agent.tools import TOOLS, generate_sql, execute_sql, generate_summary, general_chat
from agent.utils import get_compatible_llm
from app_config import ENABLE_RESULT_SUMMARY
  12. class CituLangGraphAgent:
  13. """Citu LangGraph智能助手主类 - 使用@tool装饰器 + Agent工具调用"""
  14. def __init__(self):
  15. # 加载配置
  16. try:
  17. from agent.config import get_current_config, get_nested_config
  18. self.config = get_current_config()
  19. print("[CITU_AGENT] 加载Agent配置完成")
  20. except ImportError:
  21. self.config = {}
  22. print("[CITU_AGENT] 配置文件不可用,使用默认配置")
  23. self.classifier = QuestionClassifier()
  24. self.tools = TOOLS
  25. self.llm = get_compatible_llm()
  26. # 注意:现在使用直接工具调用模式,不再需要预创建Agent执行器
  27. print("[CITU_AGENT] 使用直接工具调用模式")
  28. self.workflow = self._create_workflow()
  29. print("[CITU_AGENT] LangGraph Agent with Direct Tools初始化完成")
  30. def _create_workflow(self) -> StateGraph:
  31. """根据路由模式创建不同的工作流"""
  32. try:
  33. from app_config import QUESTION_ROUTING_MODE
  34. print(f"[CITU_AGENT] 创建工作流,路由模式: {QUESTION_ROUTING_MODE}")
  35. except ImportError:
  36. QUESTION_ROUTING_MODE = "hybrid"
  37. print(f"[CITU_AGENT] 配置导入失败,使用默认路由模式: {QUESTION_ROUTING_MODE}")
  38. workflow = StateGraph(AgentState)
  39. # 根据路由模式创建不同的工作流
  40. if QUESTION_ROUTING_MODE == "database_direct":
  41. # 直接数据库模式:跳过分类,直接进入数据库处理
  42. workflow.add_node("init_direct_database", self._init_direct_database_node)
  43. workflow.add_node("agent_database", self._agent_database_node)
  44. workflow.add_node("format_response", self._format_response_node)
  45. workflow.set_entry_point("init_direct_database")
  46. workflow.add_edge("init_direct_database", "agent_database")
  47. workflow.add_edge("agent_database", "format_response")
  48. workflow.add_edge("format_response", END)
  49. elif QUESTION_ROUTING_MODE == "chat_direct":
  50. # 直接聊天模式:跳过分类,直接进入聊天处理
  51. workflow.add_node("init_direct_chat", self._init_direct_chat_node)
  52. workflow.add_node("agent_chat", self._agent_chat_node)
  53. workflow.add_node("format_response", self._format_response_node)
  54. workflow.set_entry_point("init_direct_chat")
  55. workflow.add_edge("init_direct_chat", "agent_chat")
  56. workflow.add_edge("agent_chat", "format_response")
  57. workflow.add_edge("format_response", END)
  58. else:
  59. # 其他模式(hybrid, llm_only):使用原有的分类工作流
  60. workflow.add_node("classify_question", self._classify_question_node)
  61. workflow.add_node("agent_chat", self._agent_chat_node)
  62. workflow.add_node("agent_database", self._agent_database_node)
  63. workflow.add_node("format_response", self._format_response_node)
  64. workflow.set_entry_point("classify_question")
  65. # 添加条件边:分类后的路由
  66. workflow.add_conditional_edges(
  67. "classify_question",
  68. self._route_after_classification,
  69. {
  70. "DATABASE": "agent_database",
  71. "CHAT": "agent_chat"
  72. }
  73. )
  74. workflow.add_edge("agent_chat", "format_response")
  75. workflow.add_edge("agent_database", "format_response")
  76. workflow.add_edge("format_response", END)
  77. return workflow.compile()
  78. def _init_direct_database_node(self, state: AgentState) -> AgentState:
  79. """初始化直接数据库模式的状态"""
  80. try:
  81. from app_config import QUESTION_ROUTING_MODE
  82. # 设置直接数据库模式的分类状态
  83. state["question_type"] = "DATABASE"
  84. state["classification_confidence"] = 1.0
  85. state["classification_reason"] = "配置为直接数据库查询模式"
  86. state["classification_method"] = "direct_database"
  87. state["routing_mode"] = QUESTION_ROUTING_MODE
  88. state["current_step"] = "direct_database_init"
  89. state["execution_path"].append("init_direct_database")
  90. print(f"[DIRECT_DATABASE] 直接数据库模式初始化完成")
  91. return state
  92. except Exception as e:
  93. print(f"[ERROR] 直接数据库模式初始化异常: {str(e)}")
  94. state["error"] = f"直接数据库模式初始化失败: {str(e)}"
  95. state["error_code"] = 500
  96. state["execution_path"].append("init_direct_database_error")
  97. return state
  98. def _init_direct_chat_node(self, state: AgentState) -> AgentState:
  99. """初始化直接聊天模式的状态"""
  100. try:
  101. from app_config import QUESTION_ROUTING_MODE
  102. # 设置直接聊天模式的分类状态
  103. state["question_type"] = "CHAT"
  104. state["classification_confidence"] = 1.0
  105. state["classification_reason"] = "配置为直接聊天模式"
  106. state["classification_method"] = "direct_chat"
  107. state["routing_mode"] = QUESTION_ROUTING_MODE
  108. state["current_step"] = "direct_chat_init"
  109. state["execution_path"].append("init_direct_chat")
  110. print(f"[DIRECT_CHAT] 直接聊天模式初始化完成")
  111. return state
  112. except Exception as e:
  113. print(f"[ERROR] 直接聊天模式初始化异常: {str(e)}")
  114. state["error"] = f"直接聊天模式初始化失败: {str(e)}"
  115. state["error_code"] = 500
  116. state["execution_path"].append("init_direct_chat_error")
  117. return state
  118. def _classify_question_node(self, state: AgentState) -> AgentState:
  119. """问题分类节点 - 支持路由模式"""
  120. try:
  121. from app_config import QUESTION_ROUTING_MODE
  122. print(f"[CLASSIFY_NODE] 开始分类问题: {state['question']}")
  123. classification_result = self.classifier.classify(state["question"])
  124. # 更新状态
  125. state["question_type"] = classification_result.question_type
  126. state["classification_confidence"] = classification_result.confidence
  127. state["classification_reason"] = classification_result.reason
  128. state["classification_method"] = classification_result.method
  129. state["routing_mode"] = QUESTION_ROUTING_MODE
  130. state["current_step"] = "classified"
  131. state["execution_path"].append("classify")
  132. print(f"[CLASSIFY_NODE] 分类结果: {classification_result.question_type}, 置信度: {classification_result.confidence}")
  133. print(f"[CLASSIFY_NODE] 路由模式: {QUESTION_ROUTING_MODE}, 分类方法: {classification_result.method}")
  134. return state
  135. except Exception as e:
  136. print(f"[ERROR] 问题分类异常: {str(e)}")
  137. state["error"] = f"问题分类失败: {str(e)}"
  138. state["error_code"] = 500
  139. state["execution_path"].append("classify_error")
  140. return state
  141. def _agent_database_node(self, state: AgentState) -> AgentState:
  142. """数据库Agent节点 - 直接工具调用模式"""
  143. try:
  144. print(f"[DATABASE_AGENT] 开始处理数据库查询: {state['question']}")
  145. question = state["question"]
  146. # 步骤1:生成SQL
  147. print(f"[DATABASE_AGENT] 步骤1:生成SQL")
  148. sql_result = generate_sql.invoke({"question": question, "allow_llm_to_see_data": True})
  149. if not sql_result.get("success"):
  150. print(f"[DATABASE_AGENT] SQL生成失败: {sql_result.get('error')}")
  151. state["error"] = sql_result.get("error", "SQL生成失败")
  152. state["error_code"] = 500
  153. state["current_step"] = "database_error"
  154. state["execution_path"].append("agent_database_error")
  155. return state
  156. sql = sql_result.get("sql")
  157. state["sql"] = sql
  158. print(f"[DATABASE_AGENT] SQL生成成功: {sql}")
  159. # 步骤1.5:检查是否为解释性响应而非SQL
  160. error_type = sql_result.get("error_type")
  161. if error_type == "llm_explanation":
  162. # LLM返回了解释性文本,直接作为最终答案
  163. explanation = sql_result.get("error", "")
  164. state["chat_response"] = explanation + " 请尝试提问其它问题。"
  165. state["current_step"] = "database_completed"
  166. state["execution_path"].append("agent_database")
  167. print(f"[DATABASE_AGENT] 返回LLM解释性答案: {explanation}")
  168. return state
  169. # 额外验证:检查SQL格式(防止工具误判)
  170. from agent.utils import _is_valid_sql_format
  171. if not _is_valid_sql_format(sql):
  172. # 内容看起来不是SQL,当作解释性响应处理
  173. state["chat_response"] = sql + " 请尝试提问其它问题。"
  174. state["current_step"] = "database_completed"
  175. state["execution_path"].append("agent_database")
  176. print(f"[DATABASE_AGENT] 内容不是有效SQL,当作解释返回: {sql}")
  177. return state
  178. # 步骤2:执行SQL
  179. print(f"[DATABASE_AGENT] 步骤2:执行SQL")
  180. execute_result = execute_sql.invoke({"sql": sql})
  181. if not execute_result.get("success"):
  182. print(f"[DATABASE_AGENT] SQL执行失败: {execute_result.get('error')}")
  183. state["error"] = execute_result.get("error", "SQL执行失败")
  184. state["error_code"] = 500
  185. state["current_step"] = "database_error"
  186. state["execution_path"].append("agent_database_error")
  187. return state
  188. query_result = execute_result.get("data_result")
  189. state["query_result"] = query_result
  190. print(f"[DATABASE_AGENT] SQL执行成功,返回 {query_result.get('row_count', 0)} 行数据")
  191. # 步骤3:生成摘要(可通过配置控制,仅在有数据时生成)
  192. if ENABLE_RESULT_SUMMARY and query_result.get('row_count', 0) > 0:
  193. print(f"[DATABASE_AGENT] 步骤3:生成摘要")
  194. summary_result = generate_summary.invoke({
  195. "question": question,
  196. "query_result": query_result,
  197. "sql": sql
  198. })
  199. if not summary_result.get("success"):
  200. print(f"[DATABASE_AGENT] 摘要生成失败: {summary_result.get('message')}")
  201. # 摘要生成失败不是致命错误,使用默认摘要
  202. state["summary"] = f"查询执行完成,共返回 {query_result.get('row_count', 0)} 条记录。"
  203. else:
  204. state["summary"] = summary_result.get("summary")
  205. print(f"[DATABASE_AGENT] 摘要生成成功")
  206. else:
  207. print(f"[DATABASE_AGENT] 跳过摘要生成(ENABLE_RESULT_SUMMARY={ENABLE_RESULT_SUMMARY},数据行数={query_result.get('row_count', 0)})")
  208. # 不生成摘要时,不设置summary字段,让格式化响应节点决定如何处理
  209. state["current_step"] = "database_completed"
  210. state["execution_path"].append("agent_database")
  211. print(f"[DATABASE_AGENT] 数据库查询完成")
  212. return state
  213. except Exception as e:
  214. print(f"[ERROR] 数据库Agent异常: {str(e)}")
  215. import traceback
  216. print(f"[ERROR] 详细错误信息: {traceback.format_exc()}")
  217. state["error"] = f"数据库查询失败: {str(e)}"
  218. state["error_code"] = 500
  219. state["current_step"] = "database_error"
  220. state["execution_path"].append("agent_database_error")
  221. return state
  222. def _agent_chat_node(self, state: AgentState) -> AgentState:
  223. """聊天Agent节点 - 直接工具调用模式"""
  224. try:
  225. print(f"[CHAT_AGENT] 开始处理聊天: {state['question']}")
  226. question = state["question"]
  227. # 构建上下文 - 仅使用真实的对话历史上下文
  228. # 注意:不要将分类原因传递给LLM,那是系统内部的路由信息
  229. enable_context_injection = self.config.get("chat_agent", {}).get("enable_context_injection", True)
  230. context = None
  231. if enable_context_injection:
  232. # TODO: 在这里可以添加真实的对话历史上下文
  233. # 例如从Redis或其他存储中获取最近的对话记录
  234. # context = get_conversation_history(state.get("session_id"))
  235. pass
  236. # 直接调用general_chat工具
  237. print(f"[CHAT_AGENT] 调用general_chat工具")
  238. chat_result = general_chat.invoke({
  239. "question": question,
  240. "context": context
  241. })
  242. if chat_result.get("success"):
  243. state["chat_response"] = chat_result.get("response", "")
  244. print(f"[CHAT_AGENT] 聊天处理成功")
  245. else:
  246. # 处理失败,使用备用响应
  247. state["chat_response"] = chat_result.get("response", "抱歉,我暂时无法处理您的问题。请稍后再试。")
  248. print(f"[CHAT_AGENT] 聊天处理失败,使用备用响应: {chat_result.get('error')}")
  249. state["current_step"] = "chat_completed"
  250. state["execution_path"].append("agent_chat")
  251. print(f"[CHAT_AGENT] 聊天处理完成")
  252. return state
  253. except Exception as e:
  254. print(f"[ERROR] 聊天Agent异常: {str(e)}")
  255. import traceback
  256. print(f"[ERROR] 详细错误信息: {traceback.format_exc()}")
  257. state["chat_response"] = "抱歉,我暂时无法处理您的问题。请稍后再试,或者尝试询问数据相关的问题。"
  258. state["current_step"] = "chat_error"
  259. state["execution_path"].append("agent_chat_error")
  260. return state
  261. def _format_response_node(self, state: AgentState) -> AgentState:
  262. """格式化最终响应节点"""
  263. try:
  264. print(f"[FORMAT_NODE] 开始格式化响应,问题类型: {state['question_type']}")
  265. state["current_step"] = "completed"
  266. state["execution_path"].append("format_response")
  267. # 根据问题类型和执行状态格式化响应
  268. if state.get("error"):
  269. # 有错误的情况
  270. state["final_response"] = {
  271. "success": False,
  272. "error": state["error"],
  273. "error_code": state.get("error_code", 500),
  274. "question_type": state["question_type"],
  275. "execution_path": state["execution_path"],
  276. "classification_info": {
  277. "confidence": state.get("classification_confidence", 0),
  278. "reason": state.get("classification_reason", ""),
  279. "method": state.get("classification_method", "")
  280. }
  281. }
  282. elif state["question_type"] == "DATABASE":
  283. # 数据库查询类型
  284. if state.get("chat_response"):
  285. # SQL生成失败的解释性响应(不受ENABLE_RESULT_SUMMARY配置影响)
  286. state["final_response"] = {
  287. "success": True,
  288. "response": state["chat_response"],
  289. "type": "DATABASE",
  290. "sql": state.get("sql"),
  291. "query_result": state.get("query_result"), # 获取query_result字段
  292. "execution_path": state["execution_path"],
  293. "classification_info": {
  294. "confidence": state["classification_confidence"],
  295. "reason": state["classification_reason"],
  296. "method": state["classification_method"]
  297. }
  298. }
  299. elif state.get("summary"):
  300. # 正常的数据库查询结果,有摘要的情况
  301. # 不将summary复制到response,让response保持为空
  302. state["final_response"] = {
  303. "success": True,
  304. "type": "DATABASE",
  305. "sql": state.get("sql"),
  306. "query_result": state.get("query_result"), # 获取query_result字段
  307. "summary": state["summary"],
  308. "execution_path": state["execution_path"],
  309. "classification_info": {
  310. "confidence": state["classification_confidence"],
  311. "reason": state["classification_reason"],
  312. "method": state["classification_method"]
  313. }
  314. }
  315. elif state.get("query_result"):
  316. # 有数据但没有摘要(摘要被配置禁用)
  317. query_result = state.get("query_result")
  318. row_count = query_result.get("row_count", 0)
  319. # 构建基本响应,不包含summary字段和response字段
  320. # 用户应该直接从query_result.columns和query_result.rows获取数据
  321. state["final_response"] = {
  322. "success": True,
  323. "type": "DATABASE",
  324. "sql": state.get("sql"),
  325. "query_result": query_result, # 获取query_result字段
  326. "execution_path": state["execution_path"],
  327. "classification_info": {
  328. "confidence": state["classification_confidence"],
  329. "reason": state["classification_reason"],
  330. "method": state["classification_method"]
  331. }
  332. }
  333. else:
  334. # 数据库查询失败,没有任何结果
  335. state["final_response"] = {
  336. "success": False,
  337. "error": state.get("error", "数据库查询未完成"),
  338. "type": "DATABASE",
  339. "sql": state.get("sql"),
  340. "execution_path": state["execution_path"]
  341. }
  342. else:
  343. # 聊天类型
  344. state["final_response"] = {
  345. "success": True,
  346. "response": state.get("chat_response", ""),
  347. "type": "CHAT",
  348. "execution_path": state["execution_path"],
  349. "classification_info": {
  350. "confidence": state["classification_confidence"],
  351. "reason": state["classification_reason"],
  352. "method": state["classification_method"]
  353. }
  354. }
  355. print(f"[FORMAT_NODE] 响应格式化完成")
  356. return state
  357. except Exception as e:
  358. print(f"[ERROR] 响应格式化异常: {str(e)}")
  359. state["final_response"] = {
  360. "success": False,
  361. "error": f"响应格式化异常: {str(e)}",
  362. "error_code": 500,
  363. "execution_path": state["execution_path"]
  364. }
  365. return state
  366. def _route_after_classification(self, state: AgentState) -> Literal["DATABASE", "CHAT"]:
  367. """
  368. 分类后的路由决策
  369. 完全信任QuestionClassifier的决策:
  370. - DATABASE类型 → 数据库Agent
  371. - CHAT和UNCERTAIN类型 → 聊天Agent
  372. 这样避免了双重决策的冲突,所有分类逻辑都集中在QuestionClassifier中
  373. """
  374. question_type = state["question_type"]
  375. confidence = state["classification_confidence"]
  376. print(f"[ROUTE] 分类路由: {question_type}, 置信度: {confidence} (完全信任分类器决策)")
  377. if question_type == "DATABASE":
  378. return "DATABASE"
  379. else:
  380. # 将 "CHAT" 和 "UNCERTAIN" 类型都路由到聊天流程
  381. # 聊天Agent可以处理不确定的情况,并在必要时引导用户提供更多信息
  382. return "CHAT"
  383. def process_question(self, question: str, session_id: str = None) -> Dict[str, Any]:
  384. """
  385. 统一的问题处理入口
  386. Args:
  387. question: 用户问题
  388. session_id: 会话ID
  389. Returns:
  390. Dict包含完整的处理结果
  391. """
  392. try:
  393. print(f"[CITU_AGENT] 开始处理问题: {question}")
  394. # 初始化状态
  395. initial_state = self._create_initial_state(question, session_id)
  396. # 执行工作流
  397. final_state = self.workflow.invoke(
  398. initial_state,
  399. config={
  400. "configurable": {"session_id": session_id}
  401. } if session_id else None
  402. )
  403. # 提取最终结果
  404. result = final_state["final_response"]
  405. print(f"[CITU_AGENT] 问题处理完成: {result.get('success', False)}")
  406. return result
  407. except Exception as e:
  408. print(f"[ERROR] Agent执行异常: {str(e)}")
  409. return {
  410. "success": False,
  411. "error": f"Agent系统异常: {str(e)}",
  412. "error_code": 500,
  413. "execution_path": ["error"]
  414. }
  415. def _create_initial_state(self, question: str, session_id: str = None) -> AgentState:
  416. """创建初始状态 - 支持路由模式"""
  417. try:
  418. from app_config import QUESTION_ROUTING_MODE
  419. except ImportError:
  420. QUESTION_ROUTING_MODE = "hybrid"
  421. return AgentState(
  422. # 输入信息
  423. question=question,
  424. session_id=session_id,
  425. # 分类结果 (初始值,会在分类节点或直接模式初始化节点中更新)
  426. question_type="UNCERTAIN",
  427. classification_confidence=0.0,
  428. classification_reason="",
  429. classification_method="",
  430. # 数据库查询流程状态
  431. sql=None,
  432. sql_generation_attempts=0,
  433. query_result=None,
  434. summary=None,
  435. # 聊天响应
  436. chat_response=None,
  437. # 最终输出
  438. final_response={},
  439. # 错误处理
  440. error=None,
  441. error_code=None,
  442. # 流程控制
  443. current_step="initialized",
  444. execution_path=["start"],
  445. retry_count=0,
  446. max_retries=3,
  447. # 调试信息
  448. debug_info={},
  449. # 路由模式
  450. routing_mode=QUESTION_ROUTING_MODE
  451. )
  452. def health_check(self) -> Dict[str, Any]:
  453. """健康检查"""
  454. try:
  455. # 从配置获取健康检查参数
  456. from agent.config import get_nested_config
  457. test_question = get_nested_config(self.config, "health_check.test_question", "你好")
  458. enable_full_test = get_nested_config(self.config, "health_check.enable_full_test", True)
  459. if enable_full_test:
  460. # 完整流程测试
  461. test_result = self.process_question(test_question, "health_check")
  462. return {
  463. "status": "healthy" if test_result.get("success") else "degraded",
  464. "test_result": test_result.get("success", False),
  465. "workflow_compiled": self.workflow is not None,
  466. "tools_count": len(self.tools),
  467. "agent_reuse_enabled": False,
  468. "message": "Agent健康检查完成"
  469. }
  470. else:
  471. # 简单检查
  472. return {
  473. "status": "healthy",
  474. "test_result": True,
  475. "workflow_compiled": self.workflow is not None,
  476. "tools_count": len(self.tools),
  477. "agent_reuse_enabled": False,
  478. "message": "Agent简单健康检查完成"
  479. }
  480. except Exception as e:
  481. return {
  482. "status": "unhealthy",
  483. "error": str(e),
  484. "workflow_compiled": self.workflow is not None,
  485. "tools_count": len(self.tools) if hasattr(self, 'tools') else 0,
  486. "agent_reuse_enabled": False,
  487. "message": "Agent健康检查失败"
  488. }