import asyncio
import time
import logging
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional, Type, List

from data_pipeline.utils.data_structures import ProcessingResult, TableProcessingContext


class ToolRegistry:
    """Registry mapping tool names to tool classes, with one cached instance per name."""

    # Class-level state shared by every user of the registry.
    _tools: Dict[str, Type['BaseTool']] = {}
    _instances: Dict[str, 'BaseTool'] = {}

    @classmethod
    def register(cls, name: str):
        """Return a class decorator that registers the decorated tool under ``name``.

        Registering a name that already exists replaces the previous class and
        discards any instance cached for it, so ``get_tool`` never hands back an
        instance of a class that is no longer registered.

        Args:
            name: Key under which the tool class is stored.

        Returns:
            A decorator that records the class and returns it unchanged.
        """
        def decorator(tool_class: Type['BaseTool']):
            cls._tools[name] = tool_class
            # A cached instance built from a previously registered class is stale.
            cls._instances.pop(name, None)
            logging.debug(f"注册工具: {name} -> {tool_class.__name__}")
            return tool_class
        return decorator

    @classmethod
    def get_tool(cls, name: str, **kwargs) -> 'BaseTool':
        """Return the cached instance for ``name``, creating it on first use.

        ``kwargs`` are forwarded to the tool constructor only when the instance
        is first created; they are ignored once an instance is cached.

        Args:
            name: Registered tool name.
            **kwargs: Constructor arguments for the first instantiation.

        Returns:
            The (possibly newly created) tool instance.

        Raises:
            ValueError: If no tool is registered under ``name``.
        """
        if name in cls._instances:
            return cls._instances[name]

        if name not in cls._tools:
            raise ValueError(f"工具 '{name}' 未注册")

        tool_class = cls._tools[name]

        # Inject an LLM instance into tools that declare they need one, unless
        # the caller already supplied a usable ``vn``.
        if getattr(tool_class, 'needs_llm', False) and kwargs.get('vn') is None:
            # Imported lazily so the factory is only loaded when a tool needs it.
            from core.vanna_llm_factory import create_vanna_instance
            kwargs['vn'] = create_vanna_instance()
            logging.debug(f"为工具 {name} 注入LLM实例")

        instance = tool_class(**kwargs)
        cls._instances[name] = instance
        return instance

    @classmethod
    def list_tools(cls) -> List[str]:
        """Return the names of all registered tools."""
        return list(cls._tools.keys())

    @classmethod
    def clear_instances(cls):
        """Drop every cached tool instance (intended for tests)."""
        cls._instances.clear()


class BaseTool(ABC):
    """Abstract base class for pipeline tools."""

    needs_llm: bool = False  # Whether the tool requires an LLM instance (``vn``).
    tool_name: str = ""      # Tool name used in log messages.

    def __init__(self, **kwargs):
        """Set up the logger and store the injected LLM instance, if any.

        Args:
            **kwargs: May contain ``vn``, the LLM instance.

        Raises:
            ValueError: If the tool sets ``needs_llm`` but no ``vn`` was passed.
        """
        self.logger = logging.getLogger(f"schema_tools.{self.__class__.__name__}")

        # A tool that needs an LLM must have had one injected.
        if self.needs_llm and 'vn' not in kwargs:
            raise ValueError(f"工具 {self.__class__.__name__} 需要LLM实例但未提供")

        # ``self.vn`` is only defined when a ``vn`` argument was supplied.
        if 'vn' in kwargs:
            self.vn = kwargs['vn']

    @abstractmethod
    async def execute(self, context: TableProcessingContext) -> ProcessingResult:
        """Run the tool's logic.

        Args:
            context: Table processing context.

        Returns:
            ProcessingResult: Outcome of the tool run.
        """
        pass

    async def _execute_with_timing(self, context: TableProcessingContext) -> ProcessingResult:
        """Run ``execute`` and record how long it took.

        Exceptions raised by ``execute`` are logged and converted into a failed
        ``ProcessingResult`` instead of propagating.

        Args:
            context: Table processing context passed through to ``execute``.

        Returns:
            The result of ``execute`` with ``execution_time`` set, or a failed
            result describing the exception.
        """
        # perf_counter is monotonic, so the elapsed time cannot be skewed by
        # system clock adjustments the way time.time() can.
        start_time = time.perf_counter()

        try:
            self.logger.info(f"开始执行工具: {self.tool_name}")
            result = await self.execute(context)
            execution_time = time.perf_counter() - start_time
            result.execution_time = execution_time

            if result.success:
                self.logger.info(f"工具 {self.tool_name} 执行成功,耗时: {execution_time:.2f}秒")
            else:
                self.logger.error(f"工具 {self.tool_name} 执行失败: {result.error_message}")

            return result

        except Exception as e:
            execution_time = time.perf_counter() - start_time
            self.logger.exception(f"工具 {self.tool_name} 执行异常")

            return ProcessingResult(
                success=False,
                error_message=f"工具执行异常: {str(e)}",
                execution_time=execution_time
            )

    def validate_input(self, context: TableProcessingContext) -> bool:
        """Return True when ``context.table_metadata`` is set (subclasses may override)."""
        return context.table_metadata is not None


class PipelineExecutor:
    """Runs named processing chains made of registered tools."""

    def __init__(self, pipeline_config: Dict[str, List[str]]):
        """Store the pipeline configuration.

        Args:
            pipeline_config: Maps a pipeline name to its ordered list of tool names.
        """
        self.pipeline_config = pipeline_config
        self.logger = logging.getLogger("schema_tools.PipelineExecutor")

    async def execute_pipeline(self, pipeline_name: str, context: TableProcessingContext) -> Dict[str, ProcessingResult]:
        """Run every step of the named pipeline in order.

        Each step's tool is fetched from ``ToolRegistry``, validated against the
        context and executed. A failed step stops the chain unless the
        ``continue_on_error`` setting is truthy; an exception raised while
        handling a step always stops the chain.

        Args:
            pipeline_name: Key into ``pipeline_config``.
            context: Table processing context shared by all steps.

        Returns:
            Mapping of step name to its ``ProcessingResult`` for every step attempted.

        Raises:
            ValueError: If ``pipeline_name`` is not in ``pipeline_config``.
        """
        if pipeline_name not in self.pipeline_config:
            raise ValueError(f"未知的处理链: {pipeline_name}")

        steps = self.pipeline_config[pipeline_name]
        results: Dict[str, ProcessingResult] = {}

        self.logger.info(f"开始执行处理链 '{pipeline_name}': {' -> '.join(steps)}")

        for step_name in steps:
            try:
                tool = ToolRegistry.get_tool(step_name)

                # Validate the input before running the tool.
                if not tool.validate_input(context):
                    result = ProcessingResult(
                        success=False,
                        error_message=f"工具 {step_name} 输入验证失败"
                    )
                else:
                    result = await tool._execute_with_timing(context)

                results[step_name] = result
                context.update_step(step_name, result)

                # Stop on failure unless the configuration allows continuing.
                if not result.success:
                    from data_pipeline.config import SCHEMA_TOOLS_CONFIG
                    # .get keeps a missing key from raising here, which would
                    # replace the step's real failure result with a generic
                    # exception result in the handler below.
                    if not SCHEMA_TOOLS_CONFIG.get("continue_on_error", False):
                        self.logger.error(f"步骤 {step_name} 失败,停止处理链执行")
                        break
                    else:
                        self.logger.warning(f"步骤 {step_name} 失败,继续执行下一步")

            except Exception as e:
                self.logger.exception(f"执行步骤 {step_name} 时发生异常")
                results[step_name] = ProcessingResult(
                    success=False,
                    error_message=f"步骤执行异常: {str(e)}"
                )
                break

        return results