base.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. import asyncio
  2. import time
  3. from abc import ABC, abstractmethod
  4. import logging
  5. from typing import Dict, Any, Optional, Type, List
  6. from data_pipeline.utils.data_structures import ProcessingResult, TableProcessingContext
  7. class ToolRegistry:
  8. """工具注册管理器"""
  9. _tools: Dict[str, Type['BaseTool']] = {}
  10. @classmethod
  11. def register(cls, name: str):
  12. """装饰器:注册工具"""
  13. def decorator(tool_class: Type['BaseTool']):
  14. cls._tools[name] = tool_class
  15. logger = logging.getLogger("ToolRegistry")
  16. logger.debug(f"注册工具: {name} -> {tool_class.__name__}")
  17. return tool_class
  18. return decorator
  19. @classmethod
  20. def get_tool(cls, name: str, **kwargs) -> 'BaseTool':
  21. """获取工具实例,每次返回新实例确保参数正确传递"""
  22. if name not in cls._tools:
  23. raise ValueError(f"工具 '{name}' 未注册")
  24. tool_class = cls._tools[name]
  25. # 自动注入vanna实例到需要LLM的工具
  26. if hasattr(tool_class, 'needs_llm') and tool_class.needs_llm:
  27. from core.vanna_llm_factory import create_vanna_instance
  28. kwargs['vn'] = create_vanna_instance()
  29. logger = logging.getLogger("ToolRegistry")
  30. logger.debug(f"为工具 {name} 注入LLM实例")
  31. # 直接返回新实例,不使用单例模式
  32. return tool_class(**kwargs)
  33. @classmethod
  34. def list_tools(cls) -> List[str]:
  35. """列出所有已注册的工具"""
  36. return list(cls._tools.keys())
  37. class BaseTool(ABC):
  38. """工具基类"""
  39. needs_llm: bool = False # 是否需要LLM实例
  40. tool_name: str = "" # 工具名称
  41. def __init__(self, **kwargs):
  42. self.logger = logging.getLogger(f"tools.{self.__class__.__name__}")
  43. # 如果工具需要LLM,检查是否已注入
  44. if self.needs_llm and 'vn' not in kwargs:
  45. raise ValueError(f"工具 {self.__class__.__name__} 需要LLM实例但未提供")
  46. # 存储vanna实例
  47. if 'vn' in kwargs:
  48. self.vn = kwargs['vn']
  49. @abstractmethod
  50. async def execute(self, context: TableProcessingContext) -> ProcessingResult:
  51. """
  52. 执行工具逻辑
  53. Args:
  54. context: 表处理上下文
  55. Returns:
  56. ProcessingResult: 处理结果
  57. """
  58. pass
  59. async def _execute_with_timing(self, context: TableProcessingContext) -> ProcessingResult:
  60. """带计时的执行包装器"""
  61. start_time = time.time()
  62. try:
  63. self.logger.info(f"开始执行工具: {self.tool_name}")
  64. result = await self.execute(context)
  65. execution_time = time.time() - start_time
  66. result.execution_time = execution_time
  67. if result.success:
  68. self.logger.info(f"工具 {self.tool_name} 执行成功,耗时: {execution_time:.2f}秒")
  69. else:
  70. self.logger.error(f"工具 {self.tool_name} 执行失败: {result.error_message}")
  71. return result
  72. except Exception as e:
  73. execution_time = time.time() - start_time
  74. self.logger.exception(f"工具 {self.tool_name} 执行异常")
  75. return ProcessingResult(
  76. success=False,
  77. error_message=f"工具执行异常: {str(e)}",
  78. execution_time=execution_time
  79. )
  80. def validate_input(self, context: TableProcessingContext) -> bool:
  81. """输入验证(子类可重写)"""
  82. return context.table_metadata is not None
  83. class PipelineExecutor:
  84. """处理链执行器"""
  85. def __init__(self, pipeline_config: Dict[str, List[str]]):
  86. self.pipeline_config = pipeline_config
  87. self.logger = logging.getLogger("tools.PipelineExecutor")
  88. async def execute_pipeline(self, pipeline_name: str, context: TableProcessingContext) -> Dict[str, ProcessingResult]:
  89. """执行指定的处理链"""
  90. if pipeline_name not in self.pipeline_config:
  91. raise ValueError(f"未知的处理链: {pipeline_name}")
  92. steps = self.pipeline_config[pipeline_name]
  93. results = {}
  94. self.logger.info(f"开始执行处理链 '{pipeline_name}': {' -> '.join(steps)}")
  95. for step_name in steps:
  96. try:
  97. tool = ToolRegistry.get_tool(step_name)
  98. # 验证输入
  99. if not tool.validate_input(context):
  100. result = ProcessingResult(
  101. success=False,
  102. error_message=f"工具 {step_name} 输入验证失败"
  103. )
  104. else:
  105. result = await tool._execute_with_timing(context)
  106. results[step_name] = result
  107. context.update_step(step_name, result)
  108. # 如果步骤失败且不允许继续,则停止
  109. if not result.success:
  110. from data_pipeline.config import SCHEMA_TOOLS_CONFIG
  111. if not SCHEMA_TOOLS_CONFIG["continue_on_error"]:
  112. self.logger.error(f"步骤 {step_name} 失败,停止处理链执行")
  113. break
  114. else:
  115. self.logger.warning(f"步骤 {step_name} 失败,继续执行下一步")
  116. except Exception as e:
  117. self.logger.exception(f"执行步骤 {step_name} 时发生异常")
  118. results[step_name] = ProcessingResult(
  119. success=False,
  120. error_message=f"步骤执行异常: {str(e)}"
  121. )
  122. break
  123. return results