|
@@ -31,17 +31,23 @@ class CituLangGraphAgent:
|
|
# 注意:现在使用直接工具调用模式,不再需要预创建Agent执行器
|
|
# 注意:现在使用直接工具调用模式,不再需要预创建Agent执行器
|
|
print("[CITU_AGENT] 使用直接工具调用模式")
|
|
print("[CITU_AGENT] 使用直接工具调用模式")
|
|
|
|
|
|
- self.workflow = self._create_workflow()
|
|
|
|
|
|
+ # 不在构造时创建workflow,改为动态创建以支持路由模式参数
|
|
|
|
+ # self.workflow = self._create_workflow()
|
|
print("[CITU_AGENT] LangGraph Agent with Direct Tools初始化完成")
|
|
print("[CITU_AGENT] LangGraph Agent with Direct Tools初始化完成")
|
|
|
|
|
|
- def _create_workflow(self) -> StateGraph:
|
|
|
|
|
|
+ def _create_workflow(self, routing_mode: str = None) -> StateGraph:
|
|
"""根据路由模式创建不同的工作流"""
|
|
"""根据路由模式创建不同的工作流"""
|
|
- try:
|
|
|
|
- from app_config import QUESTION_ROUTING_MODE
|
|
|
|
- print(f"[CITU_AGENT] 创建工作流,路由模式: {QUESTION_ROUTING_MODE}")
|
|
|
|
- except ImportError:
|
|
|
|
- QUESTION_ROUTING_MODE = "hybrid"
|
|
|
|
- print(f"[CITU_AGENT] 配置导入失败,使用默认路由模式: {QUESTION_ROUTING_MODE}")
|
|
|
|
|
|
+ # 确定使用的路由模式
|
|
|
|
+ if routing_mode:
|
|
|
|
+ QUESTION_ROUTING_MODE = routing_mode
|
|
|
|
+ print(f"[CITU_AGENT] 创建工作流,使用传入的路由模式: {QUESTION_ROUTING_MODE}")
|
|
|
|
+ else:
|
|
|
|
+ try:
|
|
|
|
+ from app_config import QUESTION_ROUTING_MODE
|
|
|
|
+ print(f"[CITU_AGENT] 创建工作流,使用配置文件路由模式: {QUESTION_ROUTING_MODE}")
|
|
|
|
+ except ImportError:
|
|
|
|
+ QUESTION_ROUTING_MODE = "hybrid"
|
|
|
|
+ print(f"[CITU_AGENT] 配置导入失败,使用默认路由模式: {QUESTION_ROUTING_MODE}")
|
|
|
|
|
|
workflow = StateGraph(AgentState)
|
|
workflow = StateGraph(AgentState)
|
|
|
|
|
|
@@ -96,14 +102,15 @@ class CituLangGraphAgent:
|
|
def _init_direct_database_node(self, state: AgentState) -> AgentState:
|
|
def _init_direct_database_node(self, state: AgentState) -> AgentState:
|
|
"""初始化直接数据库模式的状态"""
|
|
"""初始化直接数据库模式的状态"""
|
|
try:
|
|
try:
|
|
- from app_config import QUESTION_ROUTING_MODE
|
|
|
|
|
|
+ # 从state中获取路由模式,而不是从配置文件读取
|
|
|
|
+ routing_mode = state.get("routing_mode", "database_direct")
|
|
|
|
|
|
# 设置直接数据库模式的分类状态
|
|
# 设置直接数据库模式的分类状态
|
|
state["question_type"] = "DATABASE"
|
|
state["question_type"] = "DATABASE"
|
|
state["classification_confidence"] = 1.0
|
|
state["classification_confidence"] = 1.0
|
|
state["classification_reason"] = "配置为直接数据库查询模式"
|
|
state["classification_reason"] = "配置为直接数据库查询模式"
|
|
state["classification_method"] = "direct_database"
|
|
state["classification_method"] = "direct_database"
|
|
- state["routing_mode"] = QUESTION_ROUTING_MODE
|
|
|
|
|
|
+ state["routing_mode"] = routing_mode
|
|
state["current_step"] = "direct_database_init"
|
|
state["current_step"] = "direct_database_init"
|
|
state["execution_path"].append("init_direct_database")
|
|
state["execution_path"].append("init_direct_database")
|
|
|
|
|
|
@@ -121,14 +128,15 @@ class CituLangGraphAgent:
|
|
def _init_direct_chat_node(self, state: AgentState) -> AgentState:
|
|
def _init_direct_chat_node(self, state: AgentState) -> AgentState:
|
|
"""初始化直接聊天模式的状态"""
|
|
"""初始化直接聊天模式的状态"""
|
|
try:
|
|
try:
|
|
- from app_config import QUESTION_ROUTING_MODE
|
|
|
|
|
|
+ # 从state中获取路由模式,而不是从配置文件读取
|
|
|
|
+ routing_mode = state.get("routing_mode", "chat_direct")
|
|
|
|
|
|
# 设置直接聊天模式的分类状态
|
|
# 设置直接聊天模式的分类状态
|
|
state["question_type"] = "CHAT"
|
|
state["question_type"] = "CHAT"
|
|
state["classification_confidence"] = 1.0
|
|
state["classification_confidence"] = 1.0
|
|
state["classification_reason"] = "配置为直接聊天模式"
|
|
state["classification_reason"] = "配置为直接聊天模式"
|
|
state["classification_method"] = "direct_chat"
|
|
state["classification_method"] = "direct_chat"
|
|
- state["routing_mode"] = QUESTION_ROUTING_MODE
|
|
|
|
|
|
+ state["routing_mode"] = routing_mode
|
|
state["current_step"] = "direct_chat_init"
|
|
state["current_step"] = "direct_chat_init"
|
|
state["execution_path"].append("init_direct_chat")
|
|
state["execution_path"].append("init_direct_chat")
|
|
|
|
|
|
@@ -146,7 +154,8 @@ class CituLangGraphAgent:
|
|
def _classify_question_node(self, state: AgentState) -> AgentState:
|
|
def _classify_question_node(self, state: AgentState) -> AgentState:
|
|
"""问题分类节点 - 支持渐进式分类策略"""
|
|
"""问题分类节点 - 支持渐进式分类策略"""
|
|
try:
|
|
try:
|
|
- from app_config import QUESTION_ROUTING_MODE
|
|
|
|
|
|
+ # 从state中获取路由模式,而不是从配置文件读取
|
|
|
|
+ routing_mode = state.get("routing_mode", "hybrid")
|
|
|
|
|
|
print(f"[CLASSIFY_NODE] 开始分类问题: {state['question']}")
|
|
print(f"[CLASSIFY_NODE] 开始分类问题: {state['question']}")
|
|
|
|
|
|
@@ -155,20 +164,20 @@ class CituLangGraphAgent:
|
|
if context_type:
|
|
if context_type:
|
|
print(f"[CLASSIFY_NODE] 检测到上下文类型: {context_type}")
|
|
print(f"[CLASSIFY_NODE] 检测到上下文类型: {context_type}")
|
|
|
|
|
|
- # 使用渐进式分类策略
|
|
|
|
- classification_result = self.classifier.classify(state["question"], context_type)
|
|
|
|
|
|
+ # 使用渐进式分类策略,传递路由模式
|
|
|
|
+ classification_result = self.classifier.classify(state["question"], context_type, routing_mode)
|
|
|
|
|
|
# 更新状态
|
|
# 更新状态
|
|
state["question_type"] = classification_result.question_type
|
|
state["question_type"] = classification_result.question_type
|
|
state["classification_confidence"] = classification_result.confidence
|
|
state["classification_confidence"] = classification_result.confidence
|
|
state["classification_reason"] = classification_result.reason
|
|
state["classification_reason"] = classification_result.reason
|
|
state["classification_method"] = classification_result.method
|
|
state["classification_method"] = classification_result.method
|
|
- state["routing_mode"] = QUESTION_ROUTING_MODE
|
|
|
|
|
|
+ state["routing_mode"] = routing_mode
|
|
state["current_step"] = "classified"
|
|
state["current_step"] = "classified"
|
|
state["execution_path"].append("classify")
|
|
state["execution_path"].append("classify")
|
|
|
|
|
|
print(f"[CLASSIFY_NODE] 分类结果: {classification_result.question_type}, 置信度: {classification_result.confidence}")
|
|
print(f"[CLASSIFY_NODE] 分类结果: {classification_result.question_type}, 置信度: {classification_result.confidence}")
|
|
- print(f"[CLASSIFY_NODE] 路由模式: {QUESTION_ROUTING_MODE}, 分类方法: {classification_result.method}")
|
|
|
|
|
|
+ print(f"[CLASSIFY_NODE] 路由模式: {routing_mode}, 分类方法: {classification_result.method}")
|
|
|
|
|
|
return state
|
|
return state
|
|
|
|
|
|
@@ -370,13 +379,14 @@ class CituLangGraphAgent:
|
|
}
|
|
}
|
|
elif state.get("summary"):
|
|
elif state.get("summary"):
|
|
# 正常的数据库查询结果,有摘要的情况
|
|
# 正常的数据库查询结果,有摘要的情况
|
|
- # 不将summary复制到response,让response保持为空
|
|
|
|
|
|
+ # 将summary的值同时赋给response字段(为将来移除summary字段做准备)
|
|
state["final_response"] = {
|
|
state["final_response"] = {
|
|
"success": True,
|
|
"success": True,
|
|
"type": "DATABASE",
|
|
"type": "DATABASE",
|
|
|
|
+ "response": state["summary"], # 新增:将summary的值赋给response
|
|
"sql": state.get("sql"),
|
|
"sql": state.get("sql"),
|
|
"query_result": state.get("query_result"), # 获取query_result字段
|
|
"query_result": state.get("query_result"), # 获取query_result字段
|
|
- "summary": state["summary"],
|
|
|
|
|
|
+ "summary": state["summary"], # 暂时保留summary字段
|
|
"execution_path": state["execution_path"],
|
|
"execution_path": state["execution_path"],
|
|
"classification_info": {
|
|
"classification_info": {
|
|
"confidence": state["classification_confidence"],
|
|
"confidence": state["classification_confidence"],
|
|
@@ -462,7 +472,7 @@ class CituLangGraphAgent:
|
|
# 聊天Agent可以处理不确定的情况,并在必要时引导用户提供更多信息
|
|
# 聊天Agent可以处理不确定的情况,并在必要时引导用户提供更多信息
|
|
return "CHAT"
|
|
return "CHAT"
|
|
|
|
|
|
- def process_question(self, question: str, session_id: str = None, context_type: str = None) -> Dict[str, Any]:
|
|
|
|
|
|
+ def process_question(self, question: str, session_id: str = None, context_type: str = None, routing_mode: str = None) -> Dict[str, Any]:
|
|
"""
|
|
"""
|
|
统一的问题处理入口
|
|
统一的问题处理入口
|
|
|
|
|
|
@@ -470,6 +480,7 @@ class CituLangGraphAgent:
|
|
question: 用户问题
|
|
question: 用户问题
|
|
session_id: 会话ID
|
|
session_id: 会话ID
|
|
context_type: 上下文类型 ("DATABASE" 或 "CHAT"),用于渐进式分类
|
|
context_type: 上下文类型 ("DATABASE" 或 "CHAT"),用于渐进式分类
|
|
|
|
+ routing_mode: 路由模式,可选,用于覆盖配置文件设置
|
|
|
|
|
|
Returns:
|
|
Returns:
|
|
Dict包含完整的处理结果
|
|
Dict包含完整的处理结果
|
|
@@ -478,12 +489,17 @@ class CituLangGraphAgent:
|
|
print(f"[CITU_AGENT] 开始处理问题: {question}")
|
|
print(f"[CITU_AGENT] 开始处理问题: {question}")
|
|
if context_type:
|
|
if context_type:
|
|
print(f"[CITU_AGENT] 上下文类型: {context_type}")
|
|
print(f"[CITU_AGENT] 上下文类型: {context_type}")
|
|
|
|
+ if routing_mode:
|
|
|
|
+ print(f"[CITU_AGENT] 使用指定路由模式: {routing_mode}")
|
|
|
|
+
|
|
|
|
+ # 动态创建workflow(基于路由模式)
|
|
|
|
+ workflow = self._create_workflow(routing_mode)
|
|
|
|
|
|
# 初始化状态
|
|
# 初始化状态
|
|
- initial_state = self._create_initial_state(question, session_id, context_type)
|
|
|
|
|
|
+ initial_state = self._create_initial_state(question, session_id, context_type, routing_mode)
|
|
|
|
|
|
# 执行工作流
|
|
# 执行工作流
|
|
- final_state = self.workflow.invoke(
|
|
|
|
|
|
+ final_state = workflow.invoke(
|
|
initial_state,
|
|
initial_state,
|
|
config={
|
|
config={
|
|
"configurable": {"session_id": session_id}
|
|
"configurable": {"session_id": session_id}
|
|
@@ -506,12 +522,17 @@ class CituLangGraphAgent:
|
|
"execution_path": ["error"]
|
|
"execution_path": ["error"]
|
|
}
|
|
}
|
|
|
|
|
|
- def _create_initial_state(self, question: str, session_id: str = None, context_type: str = None) -> AgentState:
|
|
|
|
|
|
+ def _create_initial_state(self, question: str, session_id: str = None, context_type: str = None, routing_mode: str = None) -> AgentState:
|
|
"""创建初始状态 - 支持渐进式分类"""
|
|
"""创建初始状态 - 支持渐进式分类"""
|
|
- try:
|
|
|
|
- from app_config import QUESTION_ROUTING_MODE
|
|
|
|
- except ImportError:
|
|
|
|
- QUESTION_ROUTING_MODE = "hybrid"
|
|
|
|
|
|
+ # 确定使用的路由模式
|
|
|
|
+ if routing_mode:
|
|
|
|
+ effective_routing_mode = routing_mode
|
|
|
|
+ else:
|
|
|
|
+ try:
|
|
|
|
+ from app_config import QUESTION_ROUTING_MODE
|
|
|
|
+ effective_routing_mode = QUESTION_ROUTING_MODE
|
|
|
|
+ except ImportError:
|
|
|
|
+ effective_routing_mode = "hybrid"
|
|
|
|
|
|
return AgentState(
|
|
return AgentState(
|
|
# 输入信息
|
|
# 输入信息
|
|
@@ -553,7 +574,7 @@ class CituLangGraphAgent:
|
|
debug_info={},
|
|
debug_info={},
|
|
|
|
|
|
# 路由模式
|
|
# 路由模式
|
|
- routing_mode=QUESTION_ROUTING_MODE
|
|
|
|
|
|
+ routing_mode=effective_routing_mode
|
|
)
|
|
)
|
|
|
|
|
|
def _extract_original_question(self, question: str) -> str:
|
|
def _extract_original_question(self, question: str) -> str:
|
|
@@ -597,7 +618,7 @@ class CituLangGraphAgent:
|
|
return {
|
|
return {
|
|
"status": "healthy" if test_result.get("success") else "degraded",
|
|
"status": "healthy" if test_result.get("success") else "degraded",
|
|
"test_result": test_result.get("success", False),
|
|
"test_result": test_result.get("success", False),
|
|
- "workflow_compiled": self.workflow is not None,
|
|
|
|
|
|
+ "workflow_compiled": True, # 动态创建,始终可用
|
|
"tools_count": len(self.tools),
|
|
"tools_count": len(self.tools),
|
|
"agent_reuse_enabled": False,
|
|
"agent_reuse_enabled": False,
|
|
"message": "Agent健康检查完成"
|
|
"message": "Agent健康检查完成"
|
|
@@ -607,7 +628,7 @@ class CituLangGraphAgent:
|
|
return {
|
|
return {
|
|
"status": "healthy",
|
|
"status": "healthy",
|
|
"test_result": True,
|
|
"test_result": True,
|
|
- "workflow_compiled": self.workflow is not None,
|
|
|
|
|
|
+ "workflow_compiled": True, # 动态创建,始终可用
|
|
"tools_count": len(self.tools),
|
|
"tools_count": len(self.tools),
|
|
"agent_reuse_enabled": False,
|
|
"agent_reuse_enabled": False,
|
|
"message": "Agent简单健康检查完成"
|
|
"message": "Agent简单健康检查完成"
|
|
@@ -617,7 +638,7 @@ class CituLangGraphAgent:
|
|
return {
|
|
return {
|
|
"status": "unhealthy",
|
|
"status": "unhealthy",
|
|
"error": str(e),
|
|
"error": str(e),
|
|
- "workflow_compiled": self.workflow is not None,
|
|
|
|
|
|
+ "workflow_compiled": True, # 动态创建,始终可用
|
|
"tools_count": len(self.tools) if hasattr(self, 'tools') else 0,
|
|
"tools_count": len(self.tools) if hasattr(self, 'tools') else 0,
|
|
"agent_reuse_enabled": False,
|
|
"agent_reuse_enabled": False,
|
|
"message": "Agent健康检查失败"
|
|
"message": "Agent健康检查失败"
|