|
@@ -35,39 +35,119 @@ class CituLangGraphAgent:
|
|
print("[CITU_AGENT] LangGraph Agent with Direct Tools初始化完成")
|
|
print("[CITU_AGENT] LangGraph Agent with Direct Tools初始化完成")
|
|
|
|
|
|
def _create_workflow(self) -> StateGraph:
|
|
def _create_workflow(self) -> StateGraph:
|
|
- """创建LangGraph工作流"""
|
|
|
|
- workflow = StateGraph(AgentState)
|
|
|
|
-
|
|
|
|
- # 添加节点
|
|
|
|
- workflow.add_node("classify_question", self._classify_question_node)
|
|
|
|
- workflow.add_node("agent_chat", self._agent_chat_node)
|
|
|
|
- workflow.add_node("agent_database", self._agent_database_node)
|
|
|
|
- workflow.add_node("format_response", self._format_response_node)
|
|
|
|
-
|
|
|
|
- # 设置入口点
|
|
|
|
- workflow.set_entry_point("classify_question")
|
|
|
|
|
|
+ """根据路由模式创建不同的工作流"""
|
|
|
|
+ try:
|
|
|
|
+ from app_config import QUESTION_ROUTING_MODE
|
|
|
|
+ print(f"[CITU_AGENT] 创建工作流,路由模式: {QUESTION_ROUTING_MODE}")
|
|
|
|
+ except ImportError:
|
|
|
|
+ QUESTION_ROUTING_MODE = "hybrid"
|
|
|
|
+ print(f"[CITU_AGENT] 配置导入失败,使用默认路由模式: {QUESTION_ROUTING_MODE}")
|
|
|
|
|
|
- # 添加条件边:分类后的路由
|
|
|
|
- # 完全信任QuestionClassifier的决策,不再进行二次判断
|
|
|
|
- workflow.add_conditional_edges(
|
|
|
|
- "classify_question",
|
|
|
|
- self._route_after_classification,
|
|
|
|
- {
|
|
|
|
- "DATABASE": "agent_database",
|
|
|
|
- "CHAT": "agent_chat" # CHAT分支处理所有非DATABASE的情况(包括UNCERTAIN)
|
|
|
|
- }
|
|
|
|
- )
|
|
|
|
|
|
+ workflow = StateGraph(AgentState)
|
|
|
|
|
|
- # 添加边
|
|
|
|
- workflow.add_edge("agent_chat", "format_response")
|
|
|
|
- workflow.add_edge("agent_database", "format_response")
|
|
|
|
- workflow.add_edge("format_response", END)
|
|
|
|
|
|
+ # 根据路由模式创建不同的工作流
|
|
|
|
+ if QUESTION_ROUTING_MODE == "database_direct":
|
|
|
|
+ # 直接数据库模式:跳过分类,直接进入数据库处理
|
|
|
|
+ workflow.add_node("init_direct_database", self._init_direct_database_node)
|
|
|
|
+ workflow.add_node("agent_database", self._agent_database_node)
|
|
|
|
+ workflow.add_node("format_response", self._format_response_node)
|
|
|
|
+
|
|
|
|
+ workflow.set_entry_point("init_direct_database")
|
|
|
|
+ workflow.add_edge("init_direct_database", "agent_database")
|
|
|
|
+ workflow.add_edge("agent_database", "format_response")
|
|
|
|
+ workflow.add_edge("format_response", END)
|
|
|
|
+
|
|
|
|
+ elif QUESTION_ROUTING_MODE == "chat_direct":
|
|
|
|
+ # 直接聊天模式:跳过分类,直接进入聊天处理
|
|
|
|
+ workflow.add_node("init_direct_chat", self._init_direct_chat_node)
|
|
|
|
+ workflow.add_node("agent_chat", self._agent_chat_node)
|
|
|
|
+ workflow.add_node("format_response", self._format_response_node)
|
|
|
|
+
|
|
|
|
+ workflow.set_entry_point("init_direct_chat")
|
|
|
|
+ workflow.add_edge("init_direct_chat", "agent_chat")
|
|
|
|
+ workflow.add_edge("agent_chat", "format_response")
|
|
|
|
+ workflow.add_edge("format_response", END)
|
|
|
|
+
|
|
|
|
+ else:
|
|
|
|
+ # 其他模式(hybrid, llm_only):使用原有的分类工作流
|
|
|
|
+ workflow.add_node("classify_question", self._classify_question_node)
|
|
|
|
+ workflow.add_node("agent_chat", self._agent_chat_node)
|
|
|
|
+ workflow.add_node("agent_database", self._agent_database_node)
|
|
|
|
+ workflow.add_node("format_response", self._format_response_node)
|
|
|
|
+
|
|
|
|
+ workflow.set_entry_point("classify_question")
|
|
|
|
+
|
|
|
|
+ # 添加条件边:分类后的路由
|
|
|
|
+ workflow.add_conditional_edges(
|
|
|
|
+ "classify_question",
|
|
|
|
+ self._route_after_classification,
|
|
|
|
+ {
|
|
|
|
+ "DATABASE": "agent_database",
|
|
|
|
+ "CHAT": "agent_chat"
|
|
|
|
+ }
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ workflow.add_edge("agent_chat", "format_response")
|
|
|
|
+ workflow.add_edge("agent_database", "format_response")
|
|
|
|
+ workflow.add_edge("format_response", END)
|
|
|
|
|
|
return workflow.compile()
|
|
return workflow.compile()
|
|
|
|
|
|
|
|
+ def _init_direct_database_node(self, state: AgentState) -> AgentState:
|
|
|
|
+ """初始化直接数据库模式的状态"""
|
|
|
|
+ try:
|
|
|
|
+ from app_config import QUESTION_ROUTING_MODE
|
|
|
|
+
|
|
|
|
+ # 设置直接数据库模式的分类状态
|
|
|
|
+ state["question_type"] = "DATABASE"
|
|
|
|
+ state["classification_confidence"] = 1.0
|
|
|
|
+ state["classification_reason"] = "配置为直接数据库查询模式"
|
|
|
|
+ state["classification_method"] = "direct_database"
|
|
|
|
+ state["routing_mode"] = QUESTION_ROUTING_MODE
|
|
|
|
+ state["current_step"] = "direct_database_init"
|
|
|
|
+ state["execution_path"].append("init_direct_database")
|
|
|
|
+
|
|
|
|
+ print(f"[DIRECT_DATABASE] 直接数据库模式初始化完成")
|
|
|
|
+
|
|
|
|
+ return state
|
|
|
|
+
|
|
|
|
+ except Exception as e:
|
|
|
|
+ print(f"[ERROR] 直接数据库模式初始化异常: {str(e)}")
|
|
|
|
+ state["error"] = f"直接数据库模式初始化失败: {str(e)}"
|
|
|
|
+ state["error_code"] = 500
|
|
|
|
+ state["execution_path"].append("init_direct_database_error")
|
|
|
|
+ return state
|
|
|
|
+
|
|
|
|
+ def _init_direct_chat_node(self, state: AgentState) -> AgentState:
|
|
|
|
+ """初始化直接聊天模式的状态"""
|
|
|
|
+ try:
|
|
|
|
+ from app_config import QUESTION_ROUTING_MODE
|
|
|
|
+
|
|
|
|
+ # 设置直接聊天模式的分类状态
|
|
|
|
+ state["question_type"] = "CHAT"
|
|
|
|
+ state["classification_confidence"] = 1.0
|
|
|
|
+ state["classification_reason"] = "配置为直接聊天模式"
|
|
|
|
+ state["classification_method"] = "direct_chat"
|
|
|
|
+ state["routing_mode"] = QUESTION_ROUTING_MODE
|
|
|
|
+ state["current_step"] = "direct_chat_init"
|
|
|
|
+ state["execution_path"].append("init_direct_chat")
|
|
|
|
+
|
|
|
|
+ print(f"[DIRECT_CHAT] 直接聊天模式初始化完成")
|
|
|
|
+
|
|
|
|
+ return state
|
|
|
|
+
|
|
|
|
+ except Exception as e:
|
|
|
|
+ print(f"[ERROR] 直接聊天模式初始化异常: {str(e)}")
|
|
|
|
+ state["error"] = f"直接聊天模式初始化失败: {str(e)}"
|
|
|
|
+ state["error_code"] = 500
|
|
|
|
+ state["execution_path"].append("init_direct_chat_error")
|
|
|
|
+ return state
|
|
|
|
+
|
|
def _classify_question_node(self, state: AgentState) -> AgentState:
|
|
def _classify_question_node(self, state: AgentState) -> AgentState:
|
|
- """问题分类节点"""
|
|
|
|
|
|
+ """问题分类节点 - 支持路由模式"""
|
|
try:
|
|
try:
|
|
|
|
+ from app_config import QUESTION_ROUTING_MODE
|
|
|
|
+
|
|
print(f"[CLASSIFY_NODE] 开始分类问题: {state['question']}")
|
|
print(f"[CLASSIFY_NODE] 开始分类问题: {state['question']}")
|
|
|
|
|
|
classification_result = self.classifier.classify(state["question"])
|
|
classification_result = self.classifier.classify(state["question"])
|
|
@@ -77,10 +157,12 @@ class CituLangGraphAgent:
|
|
state["classification_confidence"] = classification_result.confidence
|
|
state["classification_confidence"] = classification_result.confidence
|
|
state["classification_reason"] = classification_result.reason
|
|
state["classification_reason"] = classification_result.reason
|
|
state["classification_method"] = classification_result.method
|
|
state["classification_method"] = classification_result.method
|
|
|
|
+ state["routing_mode"] = QUESTION_ROUTING_MODE
|
|
state["current_step"] = "classified"
|
|
state["current_step"] = "classified"
|
|
state["execution_path"].append("classify")
|
|
state["execution_path"].append("classify")
|
|
|
|
|
|
print(f"[CLASSIFY_NODE] 分类结果: {classification_result.question_type}, 置信度: {classification_result.confidence}")
|
|
print(f"[CLASSIFY_NODE] 分类结果: {classification_result.question_type}, 置信度: {classification_result.confidence}")
|
|
|
|
+ print(f"[CLASSIFY_NODE] 路由模式: {QUESTION_ROUTING_MODE}, 分类方法: {classification_result.method}")
|
|
|
|
|
|
return state
|
|
return state
|
|
|
|
|
|
@@ -90,7 +172,7 @@ class CituLangGraphAgent:
|
|
state["error_code"] = 500
|
|
state["error_code"] = 500
|
|
state["execution_path"].append("classify_error")
|
|
state["execution_path"].append("classify_error")
|
|
return state
|
|
return state
|
|
-
|
|
|
|
|
|
+
|
|
def _agent_database_node(self, state: AgentState) -> AgentState:
|
|
def _agent_database_node(self, state: AgentState) -> AgentState:
|
|
"""数据库Agent节点 - 直接工具调用模式"""
|
|
"""数据库Agent节点 - 直接工具调用模式"""
|
|
try:
|
|
try:
|
|
@@ -407,14 +489,19 @@ class CituLangGraphAgent:
|
|
}
|
|
}
|
|
|
|
|
|
def _create_initial_state(self, question: str, session_id: str = None) -> AgentState:
|
|
def _create_initial_state(self, question: str, session_id: str = None) -> AgentState:
|
|
- """创建初始状态"""
|
|
|
|
|
|
+ """创建初始状态 - 支持路由模式"""
|
|
|
|
+ try:
|
|
|
|
+ from app_config import QUESTION_ROUTING_MODE
|
|
|
|
+ except ImportError:
|
|
|
|
+ QUESTION_ROUTING_MODE = "hybrid"
|
|
|
|
+
|
|
return AgentState(
|
|
return AgentState(
|
|
# 输入信息
|
|
# 输入信息
|
|
question=question,
|
|
question=question,
|
|
session_id=session_id,
|
|
session_id=session_id,
|
|
|
|
|
|
- # 分类结果
|
|
|
|
- question_type="",
|
|
|
|
|
|
+ # 分类结果 (初始值,会在分类节点或直接模式初始化节点中更新)
|
|
|
|
+ question_type="UNCERTAIN",
|
|
classification_confidence=0.0,
|
|
classification_confidence=0.0,
|
|
classification_reason="",
|
|
classification_reason="",
|
|
classification_method="",
|
|
classification_method="",
|
|
@@ -436,13 +523,16 @@ class CituLangGraphAgent:
|
|
error_code=None,
|
|
error_code=None,
|
|
|
|
|
|
# 流程控制
|
|
# 流程控制
|
|
- current_step="start",
|
|
|
|
- execution_path=[],
|
|
|
|
|
|
+ current_step="initialized",
|
|
|
|
+ execution_path=["start"],
|
|
retry_count=0,
|
|
retry_count=0,
|
|
- max_retries=2,
|
|
|
|
|
|
+ max_retries=3,
|
|
|
|
|
|
# 调试信息
|
|
# 调试信息
|
|
- debug_info={}
|
|
|
|
|
|
+ debug_info={},
|
|
|
|
+
|
|
|
|
+ # 路由模式
|
|
|
|
+ routing_mode=QUESTION_ROUTING_MODE
|
|
)
|
|
)
|
|
|
|
|
|
def health_check(self) -> Dict[str, Any]:
|
|
def health_check(self) -> Dict[str, Any]:
|