123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161 |
- import asyncio
- import time
- import logging
- from abc import ABC, abstractmethod
- from typing import Dict, Any, Optional, Type, List
- from schema_tools.utils.data_structures import ProcessingResult, TableProcessingContext
class ToolRegistry:
    """Registry of tool classes plus a per-name cache of tool instances.

    Tool classes are registered under a name with the ``register`` decorator.
    ``get_tool`` lazily builds one instance per name and returns that same
    cached instance on every later call.
    """

    # Registered tool classes, keyed by registration name.
    _tools: Dict[str, Type['BaseTool']] = {}
    # Cached tool instances, keyed by registration name (at most one per name).
    _instances: Dict[str, 'BaseTool'] = {}

    @classmethod
    def register(cls, name: str):
        """Return a class decorator that registers the decorated tool under ``name``.

        Registering another class under an existing name replaces the earlier
        class. An instance already cached for that name is not replaced. The
        decorated class itself is returned unchanged.
        """
        def decorator(tool_class: Type['BaseTool']):
            cls._tools[name] = tool_class
            logging.debug(f"注册工具: {name} -> {tool_class.__name__}")
            return tool_class
        return decorator

    @classmethod
    def get_tool(cls, name: str, **kwargs) -> 'BaseTool':
        """Return the cached instance of tool ``name``, creating it on first use.

        Args:
            name: Registration name of the tool.
            **kwargs: Constructor keyword arguments. They are used only when the
                instance is first created. Later calls return the cached
                instance and ignore them. For tools whose class sets a truthy
                ``needs_llm``, an LLM instance is created and passed as ``vn``
                unless the caller already supplied ``vn``.

        Returns:
            The single cached instance registered under ``name``.

        Raises:
            ValueError: If no tool is registered under ``name``.
        """
        if name in cls._instances:
            return cls._instances[name]

        if name not in cls._tools:
            raise ValueError(f"工具 '{name}' 未注册")

        tool_class = cls._tools[name]

        # Inject an LLM instance only when the tool needs one and the caller did
        # not pass its own; an explicitly supplied ``vn`` must not be clobbered.
        if getattr(tool_class, 'needs_llm', False) and 'vn' not in kwargs:
            # Imported lazily so the factory is only loaded when actually needed.
            from core.vanna_llm_factory import create_vanna_instance
            kwargs['vn'] = create_vanna_instance()
            logging.debug(f"为工具 {name} 注入LLM实例")

        instance = tool_class(**kwargs)
        cls._instances[name] = instance
        return instance

    @classmethod
    def list_tools(cls) -> List[str]:
        """Return the names of all registered tools, in registration order."""
        return list(cls._tools.keys())

    @classmethod
    def clear_instances(cls):
        """Drop all cached tool instances (intended for tests); registrations are kept."""
        cls._instances.clear()
class BaseTool(ABC):
    """Abstract base class for pipeline tools.

    Subclasses implement ``execute``. ``_execute_with_timing`` wraps it with
    timing, logging and exception capture.
    """

    needs_llm: bool = False  # whether this tool requires an LLM instance passed as ``vn``
    tool_name: str = ""  # tool name used in log messages

    def __init__(self, **kwargs):
        """Create the per-class logger and store the optional LLM instance.

        Args:
            **kwargs: Only ``vn`` (the LLM instance) is read; other keys are ignored.

        Raises:
            ValueError: If ``needs_llm`` is true and no ``vn`` was supplied.
        """
        self.logger = logging.getLogger(f"schema_tools.{self.__class__.__name__}")

        # A tool that declares it needs an LLM must receive one.
        if self.needs_llm and 'vn' not in kwargs:
            raise ValueError(f"工具 {self.__class__.__name__} 需要LLM实例但未提供")

        # ``self.vn`` exists only when an instance was supplied.
        if 'vn' in kwargs:
            self.vn = kwargs['vn']

    @abstractmethod
    async def execute(self, context: TableProcessingContext) -> ProcessingResult:
        """Run the tool's logic.

        Args:
            context: Table processing context.

        Returns:
            ProcessingResult: The processing result.
        """
        pass

    async def _execute_with_timing(self, context: TableProcessingContext) -> ProcessingResult:
        """Run ``execute`` and record its elapsed time on the result.

        Any ``Exception`` raised by ``execute``, or while annotating its result,
        is logged with a traceback. It is then turned into a failed
        ``ProcessingResult`` carrying the elapsed time instead of propagating.

        Args:
            context: Table processing context forwarded to ``execute``.

        Returns:
            ProcessingResult: The tool's result with ``execution_time`` set
            (seconds), or a failure result if an exception occurred.
        """
        # perf_counter is monotonic, so durations are not skewed (or made
        # negative) by system clock adjustments the way time.time() can be.
        start_time = time.perf_counter()

        try:
            self.logger.info(f"开始执行工具: {self.tool_name}")
            result = await self.execute(context)
            execution_time = time.perf_counter() - start_time
            result.execution_time = execution_time

            if result.success:
                self.logger.info(f"工具 {self.tool_name} 执行成功,耗时: {execution_time:.2f}秒")
            else:
                self.logger.error(f"工具 {self.tool_name} 执行失败: {result.error_message}")

            return result

        except Exception as e:
            execution_time = time.perf_counter() - start_time
            self.logger.exception(f"工具 {self.tool_name} 执行异常")

            return ProcessingResult(
                success=False,
                error_message=f"工具执行异常: {str(e)}",
                execution_time=execution_time
            )

    def validate_input(self, context: TableProcessingContext) -> bool:
        """Return True when ``context.table_metadata`` is set; subclasses may override."""
        return context.table_metadata is not None
class PipelineExecutor:
    """Runs named processing pipelines, each an ordered list of registered tool names."""

    def __init__(self, pipeline_config: Dict[str, List[str]]):
        """
        Args:
            pipeline_config: Maps a pipeline name to the ordered list of tool
                names (as registered with ``ToolRegistry``) that it runs.
        """
        self.pipeline_config = pipeline_config
        self.logger = logging.getLogger("schema_tools.PipelineExecutor")

    async def execute_pipeline(self, pipeline_name: str, context: TableProcessingContext) -> Dict[str, ProcessingResult]:
        """Run every step of ``pipeline_name`` in order against ``context``.

        Each step's result is stored under its step name and passed to
        ``context.update_step``. When a step returns a failed result, the
        pipeline stops unless the ``continue_on_error`` setting is truthy.
        An exception raised while running a step is logged and recorded as a
        failed result for that step, and the pipeline then stops.

        Args:
            pipeline_name: Key into ``pipeline_config``.
            context: Table processing context shared by all steps.

        Returns:
            Dict mapping each step name that was attempted to its result.

        Raises:
            ValueError: If ``pipeline_name`` is not configured.
        """
        if pipeline_name not in self.pipeline_config:
            raise ValueError(f"未知的处理链: {pipeline_name}")

        steps = self.pipeline_config[pipeline_name]
        results: Dict[str, ProcessingResult] = {}

        self.logger.info(f"开始执行处理链 '{pipeline_name}': {' -> '.join(steps)}")

        for step_name in steps:
            try:
                tool = ToolRegistry.get_tool(step_name)

                # Validate input before running the tool.
                if not tool.validate_input(context):
                    result = ProcessingResult(
                        success=False,
                        error_message=f"工具 {step_name} 输入验证失败"
                    )
                else:
                    result = await tool._execute_with_timing(context)

                results[step_name] = result
                context.update_step(step_name, result)

                # On failure, stop unless configuration allows continuing.
                if not result.success:
                    if not self._continue_on_error():
                        self.logger.error(f"步骤 {step_name} 失败,停止处理链执行")
                        break
                    self.logger.warning(f"步骤 {step_name} 失败,继续执行下一步")

            except Exception as e:
                self.logger.exception(f"执行步骤 {step_name} 时发生异常")
                results[step_name] = ProcessingResult(
                    success=False,
                    error_message=f"步骤执行异常: {str(e)}"
                )
                break

        return results

    @staticmethod
    def _continue_on_error() -> bool:
        """Return the ``continue_on_error`` setting, defaulting to False (stop) when absent.

        Using ``.get`` keeps a missing key from raising KeyError. Otherwise that
        KeyError would replace the failing step's real result with a misleading
        exception result.
        """
        # Imported lazily, as before; presumably to avoid an import cycle with
        # the config module — TODO confirm. Assumes SCHEMA_TOOLS_CONFIG is a dict.
        from schema_tools.config import SCHEMA_TOOLS_CONFIG
        return bool(SCHEMA_TOOLS_CONFIG.get("continue_on_error", False))
|