citu_agent.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577
  1. # agent/citu_agent.py
  2. from typing import Dict, Any, Literal
  3. from langgraph.graph import StateGraph, END
  4. from langchain.agents import AgentExecutor, create_openai_tools_agent
  5. from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
  6. from langchain_core.messages import SystemMessage, HumanMessage
  7. from agent.state import AgentState
  8. from agent.classifier import QuestionClassifier
  9. from agent.tools import TOOLS, generate_sql, execute_sql, generate_summary, general_chat
  10. from agent.utils import get_compatible_llm
  11. from app_config import ENABLE_RESULT_SUMMARY
  12. class CituLangGraphAgent:
  13. """Citu LangGraph智能助手主类 - 使用@tool装饰器 + Agent工具调用"""
  14. def __init__(self):
  15. # 加载配置
  16. try:
  17. from agent.config import get_current_config, get_nested_config
  18. self.config = get_current_config()
  19. print("[CITU_AGENT] 加载Agent配置完成")
  20. except ImportError:
  21. self.config = {}
  22. print("[CITU_AGENT] 配置文件不可用,使用默认配置")
  23. self.classifier = QuestionClassifier()
  24. self.tools = TOOLS
  25. self.llm = get_compatible_llm()
  26. # 注意:现在使用直接工具调用模式,不再需要预创建Agent执行器
  27. print("[CITU_AGENT] 使用直接工具调用模式")
  28. self.workflow = self._create_workflow()
  29. print("[CITU_AGENT] LangGraph Agent with Direct Tools初始化完成")
  30. def _create_workflow(self) -> StateGraph:
  31. """根据路由模式创建不同的工作流"""
  32. try:
  33. from app_config import QUESTION_ROUTING_MODE
  34. print(f"[CITU_AGENT] 创建工作流,路由模式: {QUESTION_ROUTING_MODE}")
  35. except ImportError:
  36. QUESTION_ROUTING_MODE = "hybrid"
  37. print(f"[CITU_AGENT] 配置导入失败,使用默认路由模式: {QUESTION_ROUTING_MODE}")
  38. workflow = StateGraph(AgentState)
  39. # 根据路由模式创建不同的工作流
  40. if QUESTION_ROUTING_MODE == "database_direct":
  41. # 直接数据库模式:跳过分类,直接进入数据库处理
  42. workflow.add_node("init_direct_database", self._init_direct_database_node)
  43. workflow.add_node("agent_database", self._agent_database_node)
  44. workflow.add_node("format_response", self._format_response_node)
  45. workflow.set_entry_point("init_direct_database")
  46. workflow.add_edge("init_direct_database", "agent_database")
  47. workflow.add_edge("agent_database", "format_response")
  48. workflow.add_edge("format_response", END)
  49. elif QUESTION_ROUTING_MODE == "chat_direct":
  50. # 直接聊天模式:跳过分类,直接进入聊天处理
  51. workflow.add_node("init_direct_chat", self._init_direct_chat_node)
  52. workflow.add_node("agent_chat", self._agent_chat_node)
  53. workflow.add_node("format_response", self._format_response_node)
  54. workflow.set_entry_point("init_direct_chat")
  55. workflow.add_edge("init_direct_chat", "agent_chat")
  56. workflow.add_edge("agent_chat", "format_response")
  57. workflow.add_edge("format_response", END)
  58. else:
  59. # 其他模式(hybrid, llm_only):使用原有的分类工作流
  60. workflow.add_node("classify_question", self._classify_question_node)
  61. workflow.add_node("agent_chat", self._agent_chat_node)
  62. workflow.add_node("agent_database", self._agent_database_node)
  63. workflow.add_node("format_response", self._format_response_node)
  64. workflow.set_entry_point("classify_question")
  65. # 添加条件边:分类后的路由
  66. workflow.add_conditional_edges(
  67. "classify_question",
  68. self._route_after_classification,
  69. {
  70. "DATABASE": "agent_database",
  71. "CHAT": "agent_chat"
  72. }
  73. )
  74. workflow.add_edge("agent_chat", "format_response")
  75. workflow.add_edge("agent_database", "format_response")
  76. workflow.add_edge("format_response", END)
  77. return workflow.compile()
  78. def _init_direct_database_node(self, state: AgentState) -> AgentState:
  79. """初始化直接数据库模式的状态"""
  80. try:
  81. from app_config import QUESTION_ROUTING_MODE
  82. # 设置直接数据库模式的分类状态
  83. state["question_type"] = "DATABASE"
  84. state["classification_confidence"] = 1.0
  85. state["classification_reason"] = "配置为直接数据库查询模式"
  86. state["classification_method"] = "direct_database"
  87. state["routing_mode"] = QUESTION_ROUTING_MODE
  88. state["current_step"] = "direct_database_init"
  89. state["execution_path"].append("init_direct_database")
  90. print(f"[DIRECT_DATABASE] 直接数据库模式初始化完成")
  91. return state
  92. except Exception as e:
  93. print(f"[ERROR] 直接数据库模式初始化异常: {str(e)}")
  94. state["error"] = f"直接数据库模式初始化失败: {str(e)}"
  95. state["error_code"] = 500
  96. state["execution_path"].append("init_direct_database_error")
  97. return state
  98. def _init_direct_chat_node(self, state: AgentState) -> AgentState:
  99. """初始化直接聊天模式的状态"""
  100. try:
  101. from app_config import QUESTION_ROUTING_MODE
  102. # 设置直接聊天模式的分类状态
  103. state["question_type"] = "CHAT"
  104. state["classification_confidence"] = 1.0
  105. state["classification_reason"] = "配置为直接聊天模式"
  106. state["classification_method"] = "direct_chat"
  107. state["routing_mode"] = QUESTION_ROUTING_MODE
  108. state["current_step"] = "direct_chat_init"
  109. state["execution_path"].append("init_direct_chat")
  110. print(f"[DIRECT_CHAT] 直接聊天模式初始化完成")
  111. return state
  112. except Exception as e:
  113. print(f"[ERROR] 直接聊天模式初始化异常: {str(e)}")
  114. state["error"] = f"直接聊天模式初始化失败: {str(e)}"
  115. state["error_code"] = 500
  116. state["execution_path"].append("init_direct_chat_error")
  117. return state
  118. def _classify_question_node(self, state: AgentState) -> AgentState:
  119. """问题分类节点 - 支持路由模式"""
  120. try:
  121. from app_config import QUESTION_ROUTING_MODE
  122. print(f"[CLASSIFY_NODE] 开始分类问题: {state['question']}")
  123. classification_result = self.classifier.classify(state["question"])
  124. # 更新状态
  125. state["question_type"] = classification_result.question_type
  126. state["classification_confidence"] = classification_result.confidence
  127. state["classification_reason"] = classification_result.reason
  128. state["classification_method"] = classification_result.method
  129. state["routing_mode"] = QUESTION_ROUTING_MODE
  130. state["current_step"] = "classified"
  131. state["execution_path"].append("classify")
  132. print(f"[CLASSIFY_NODE] 分类结果: {classification_result.question_type}, 置信度: {classification_result.confidence}")
  133. print(f"[CLASSIFY_NODE] 路由模式: {QUESTION_ROUTING_MODE}, 分类方法: {classification_result.method}")
  134. return state
  135. except Exception as e:
  136. print(f"[ERROR] 问题分类异常: {str(e)}")
  137. state["error"] = f"问题分类失败: {str(e)}"
  138. state["error_code"] = 500
  139. state["execution_path"].append("classify_error")
  140. return state
  141. def _agent_database_node(self, state: AgentState) -> AgentState:
  142. """数据库Agent节点 - 直接工具调用模式"""
  143. try:
  144. print(f"[DATABASE_AGENT] 开始处理数据库查询: {state['question']}")
  145. question = state["question"]
  146. # 步骤1:生成SQL
  147. print(f"[DATABASE_AGENT] 步骤1:生成SQL")
  148. sql_result = generate_sql.invoke({"question": question, "allow_llm_to_see_data": True})
  149. if not sql_result.get("success"):
  150. print(f"[DATABASE_AGENT] SQL生成失败: {sql_result.get('error')}")
  151. state["error"] = sql_result.get("error", "SQL生成失败")
  152. state["error_code"] = 500
  153. state["current_step"] = "database_error"
  154. state["execution_path"].append("agent_database_error")
  155. return state
  156. sql = sql_result.get("sql")
  157. state["sql"] = sql
  158. print(f"[DATABASE_AGENT] SQL生成成功: {sql}")
  159. # 步骤1.5:检查是否为解释性响应而非SQL
  160. error_type = sql_result.get("error_type")
  161. if error_type == "llm_explanation":
  162. # LLM返回了解释性文本,直接作为最终答案
  163. explanation = sql_result.get("error", "")
  164. state["chat_response"] = explanation + " 请尝试提问其它问题。"
  165. state["current_step"] = "database_completed"
  166. state["execution_path"].append("agent_database")
  167. print(f"[DATABASE_AGENT] 返回LLM解释性答案: {explanation}")
  168. return state
  169. # 额外验证:检查SQL格式(防止工具误判)
  170. from agent.utils import _is_valid_sql_format
  171. if not _is_valid_sql_format(sql):
  172. # 内容看起来不是SQL,当作解释性响应处理
  173. state["chat_response"] = sql + " 请尝试提问其它问题。"
  174. state["current_step"] = "database_completed"
  175. state["execution_path"].append("agent_database")
  176. print(f"[DATABASE_AGENT] 内容不是有效SQL,当作解释返回: {sql}")
  177. return state
  178. # 步骤2:执行SQL
  179. print(f"[DATABASE_AGENT] 步骤2:执行SQL")
  180. execute_result = execute_sql.invoke({"sql": sql})
  181. if not execute_result.get("success"):
  182. print(f"[DATABASE_AGENT] SQL执行失败: {execute_result.get('error')}")
  183. state["error"] = execute_result.get("error", "SQL执行失败")
  184. state["error_code"] = 500
  185. state["current_step"] = "database_error"
  186. state["execution_path"].append("agent_database_error")
  187. return state
  188. data_result = execute_result.get("data_result")
  189. state["data_result"] = data_result
  190. print(f"[DATABASE_AGENT] SQL执行成功,返回 {data_result.get('row_count', 0)} 行数据")
  191. # 步骤3:生成摘要(可通过配置控制,仅在有数据时生成)
  192. if ENABLE_RESULT_SUMMARY and data_result.get('row_count', 0) > 0:
  193. print(f"[DATABASE_AGENT] 步骤3:生成摘要")
  194. summary_result = generate_summary.invoke({
  195. "question": question,
  196. "data_result": data_result,
  197. "sql": sql
  198. })
  199. if not summary_result.get("success"):
  200. print(f"[DATABASE_AGENT] 摘要生成失败: {summary_result.get('message')}")
  201. # 摘要生成失败不是致命错误,使用默认摘要
  202. state["summary"] = f"查询执行完成,共返回 {data_result.get('row_count', 0)} 条记录。"
  203. else:
  204. state["summary"] = summary_result.get("summary")
  205. print(f"[DATABASE_AGENT] 摘要生成成功")
  206. else:
  207. print(f"[DATABASE_AGENT] 跳过摘要生成(ENABLE_RESULT_SUMMARY={ENABLE_RESULT_SUMMARY},数据行数={data_result.get('row_count', 0)})")
  208. # 不生成摘要时,不设置summary字段,让格式化响应节点决定如何处理
  209. state["current_step"] = "database_completed"
  210. state["execution_path"].append("agent_database")
  211. print(f"[DATABASE_AGENT] 数据库查询完成")
  212. return state
  213. except Exception as e:
  214. print(f"[ERROR] 数据库Agent异常: {str(e)}")
  215. import traceback
  216. print(f"[ERROR] 详细错误信息: {traceback.format_exc()}")
  217. state["error"] = f"数据库查询失败: {str(e)}"
  218. state["error_code"] = 500
  219. state["current_step"] = "database_error"
  220. state["execution_path"].append("agent_database_error")
  221. return state
  222. def _agent_chat_node(self, state: AgentState) -> AgentState:
  223. """聊天Agent节点 - 直接工具调用模式"""
  224. try:
  225. print(f"[CHAT_AGENT] 开始处理聊天: {state['question']}")
  226. question = state["question"]
  227. # 构建上下文
  228. enable_context_injection = self.config.get("chat_agent", {}).get("enable_context_injection", True)
  229. context = None
  230. if enable_context_injection and state.get("classification_reason"):
  231. context = f"分类原因: {state['classification_reason']}"
  232. # 直接调用general_chat工具
  233. print(f"[CHAT_AGENT] 调用general_chat工具")
  234. chat_result = general_chat.invoke({
  235. "question": question,
  236. "context": context
  237. })
  238. if chat_result.get("success"):
  239. state["chat_response"] = chat_result.get("response", "")
  240. print(f"[CHAT_AGENT] 聊天处理成功")
  241. else:
  242. # 处理失败,使用备用响应
  243. state["chat_response"] = chat_result.get("response", "抱歉,我暂时无法处理您的问题。请稍后再试。")
  244. print(f"[CHAT_AGENT] 聊天处理失败,使用备用响应: {chat_result.get('error')}")
  245. state["current_step"] = "chat_completed"
  246. state["execution_path"].append("agent_chat")
  247. print(f"[CHAT_AGENT] 聊天处理完成")
  248. return state
  249. except Exception as e:
  250. print(f"[ERROR] 聊天Agent异常: {str(e)}")
  251. import traceback
  252. print(f"[ERROR] 详细错误信息: {traceback.format_exc()}")
  253. state["chat_response"] = "抱歉,我暂时无法处理您的问题。请稍后再试,或者尝试询问数据相关的问题。"
  254. state["current_step"] = "chat_error"
  255. state["execution_path"].append("agent_chat_error")
  256. return state
  257. def _format_response_node(self, state: AgentState) -> AgentState:
  258. """格式化最终响应节点"""
  259. try:
  260. print(f"[FORMAT_NODE] 开始格式化响应,问题类型: {state['question_type']}")
  261. state["current_step"] = "completed"
  262. state["execution_path"].append("format_response")
  263. # 根据问题类型和执行状态格式化响应
  264. if state.get("error"):
  265. # 有错误的情况
  266. state["final_response"] = {
  267. "success": False,
  268. "error": state["error"],
  269. "error_code": state.get("error_code", 500),
  270. "question_type": state["question_type"],
  271. "execution_path": state["execution_path"],
  272. "classification_info": {
  273. "confidence": state.get("classification_confidence", 0),
  274. "reason": state.get("classification_reason", ""),
  275. "method": state.get("classification_method", "")
  276. }
  277. }
  278. elif state["question_type"] == "DATABASE":
  279. # 数据库查询类型
  280. if state.get("chat_response"):
  281. # SQL生成失败的解释性响应(不受ENABLE_RESULT_SUMMARY配置影响)
  282. state["final_response"] = {
  283. "success": True,
  284. "response": state["chat_response"],
  285. "type": "DATABASE",
  286. "sql": state.get("sql"),
  287. "query_result": state.get("data_result"), # 字段重命名:data_result → query_result
  288. "execution_path": state["execution_path"],
  289. "classification_info": {
  290. "confidence": state["classification_confidence"],
  291. "reason": state["classification_reason"],
  292. "method": state["classification_method"]
  293. }
  294. }
  295. elif state.get("summary"):
  296. # 正常的数据库查询结果,有摘要的情况
  297. # 不将summary复制到response,让response保持为空
  298. state["final_response"] = {
  299. "success": True,
  300. "type": "DATABASE",
  301. "sql": state.get("sql"),
  302. "query_result": state.get("data_result"), # 字段重命名:data_result → query_result
  303. "summary": state["summary"],
  304. "execution_path": state["execution_path"],
  305. "classification_info": {
  306. "confidence": state["classification_confidence"],
  307. "reason": state["classification_reason"],
  308. "method": state["classification_method"]
  309. }
  310. }
  311. elif state.get("data_result"):
  312. # 有数据但没有摘要(摘要被配置禁用)
  313. data_result = state.get("data_result")
  314. row_count = data_result.get("row_count", 0)
  315. # 构建基本响应,不包含summary字段和response字段
  316. # 用户应该直接从query_result.columns和query_result.rows获取数据
  317. state["final_response"] = {
  318. "success": True,
  319. "type": "DATABASE",
  320. "sql": state.get("sql"),
  321. "query_result": data_result, # 字段重命名:data_result → query_result
  322. "execution_path": state["execution_path"],
  323. "classification_info": {
  324. "confidence": state["classification_confidence"],
  325. "reason": state["classification_reason"],
  326. "method": state["classification_method"]
  327. }
  328. }
  329. else:
  330. # 数据库查询失败,没有任何结果
  331. state["final_response"] = {
  332. "success": False,
  333. "error": state.get("error", "数据库查询未完成"),
  334. "type": "DATABASE",
  335. "sql": state.get("sql"),
  336. "execution_path": state["execution_path"]
  337. }
  338. else:
  339. # 聊天类型
  340. state["final_response"] = {
  341. "success": True,
  342. "response": state.get("chat_response", ""),
  343. "type": "CHAT",
  344. "execution_path": state["execution_path"],
  345. "classification_info": {
  346. "confidence": state["classification_confidence"],
  347. "reason": state["classification_reason"],
  348. "method": state["classification_method"]
  349. }
  350. }
  351. print(f"[FORMAT_NODE] 响应格式化完成")
  352. return state
  353. except Exception as e:
  354. print(f"[ERROR] 响应格式化异常: {str(e)}")
  355. state["final_response"] = {
  356. "success": False,
  357. "error": f"响应格式化异常: {str(e)}",
  358. "error_code": 500,
  359. "execution_path": state["execution_path"]
  360. }
  361. return state
  362. def _route_after_classification(self, state: AgentState) -> Literal["DATABASE", "CHAT"]:
  363. """
  364. 分类后的路由决策
  365. 完全信任QuestionClassifier的决策:
  366. - DATABASE类型 → 数据库Agent
  367. - CHAT和UNCERTAIN类型 → 聊天Agent
  368. 这样避免了双重决策的冲突,所有分类逻辑都集中在QuestionClassifier中
  369. """
  370. question_type = state["question_type"]
  371. confidence = state["classification_confidence"]
  372. print(f"[ROUTE] 分类路由: {question_type}, 置信度: {confidence} (完全信任分类器决策)")
  373. if question_type == "DATABASE":
  374. return "DATABASE"
  375. else:
  376. # 将 "CHAT" 和 "UNCERTAIN" 类型都路由到聊天流程
  377. # 聊天Agent可以处理不确定的情况,并在必要时引导用户提供更多信息
  378. return "CHAT"
  379. def process_question(self, question: str, session_id: str = None) -> Dict[str, Any]:
  380. """
  381. 统一的问题处理入口
  382. Args:
  383. question: 用户问题
  384. session_id: 会话ID
  385. Returns:
  386. Dict包含完整的处理结果
  387. """
  388. try:
  389. print(f"[CITU_AGENT] 开始处理问题: {question}")
  390. # 初始化状态
  391. initial_state = self._create_initial_state(question, session_id)
  392. # 执行工作流
  393. final_state = self.workflow.invoke(
  394. initial_state,
  395. config={
  396. "configurable": {"session_id": session_id}
  397. } if session_id else None
  398. )
  399. # 提取最终结果
  400. result = final_state["final_response"]
  401. print(f"[CITU_AGENT] 问题处理完成: {result.get('success', False)}")
  402. return result
  403. except Exception as e:
  404. print(f"[ERROR] Agent执行异常: {str(e)}")
  405. return {
  406. "success": False,
  407. "error": f"Agent系统异常: {str(e)}",
  408. "error_code": 500,
  409. "execution_path": ["error"]
  410. }
  411. def _create_initial_state(self, question: str, session_id: str = None) -> AgentState:
  412. """创建初始状态 - 支持路由模式"""
  413. try:
  414. from app_config import QUESTION_ROUTING_MODE
  415. except ImportError:
  416. QUESTION_ROUTING_MODE = "hybrid"
  417. return AgentState(
  418. # 输入信息
  419. question=question,
  420. session_id=session_id,
  421. # 分类结果 (初始值,会在分类节点或直接模式初始化节点中更新)
  422. question_type="UNCERTAIN",
  423. classification_confidence=0.0,
  424. classification_reason="",
  425. classification_method="",
  426. # 数据库查询流程状态
  427. sql=None,
  428. sql_generation_attempts=0,
  429. data_result=None,
  430. summary=None,
  431. # 聊天响应
  432. chat_response=None,
  433. # 最终输出
  434. final_response={},
  435. # 错误处理
  436. error=None,
  437. error_code=None,
  438. # 流程控制
  439. current_step="initialized",
  440. execution_path=["start"],
  441. retry_count=0,
  442. max_retries=3,
  443. # 调试信息
  444. debug_info={},
  445. # 路由模式
  446. routing_mode=QUESTION_ROUTING_MODE
  447. )
  448. def health_check(self) -> Dict[str, Any]:
  449. """健康检查"""
  450. try:
  451. # 从配置获取健康检查参数
  452. from agent.config import get_nested_config
  453. test_question = get_nested_config(self.config, "health_check.test_question", "你好")
  454. enable_full_test = get_nested_config(self.config, "health_check.enable_full_test", True)
  455. if enable_full_test:
  456. # 完整流程测试
  457. test_result = self.process_question(test_question, "health_check")
  458. return {
  459. "status": "healthy" if test_result.get("success") else "degraded",
  460. "test_result": test_result.get("success", False),
  461. "workflow_compiled": self.workflow is not None,
  462. "tools_count": len(self.tools),
  463. "agent_reuse_enabled": False,
  464. "message": "Agent健康检查完成"
  465. }
  466. else:
  467. # 简单检查
  468. return {
  469. "status": "healthy",
  470. "test_result": True,
  471. "workflow_compiled": self.workflow is not None,
  472. "tools_count": len(self.tools),
  473. "agent_reuse_enabled": False,
  474. "message": "Agent简单健康检查完成"
  475. }
  476. except Exception as e:
  477. return {
  478. "status": "unhealthy",
  479. "error": str(e),
  480. "workflow_compiled": self.workflow is not None,
  481. "tools_count": len(self.tools) if hasattr(self, 'tools') else 0,
  482. "agent_reuse_enabled": False,
  483. "message": "Agent健康检查失败"
  484. }