base.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. import asyncio
  2. import time
  3. from abc import ABC, abstractmethod
  4. from core.logging import get_data_pipeline_logger
  5. from typing import Dict, Any, Optional, Type, List
  6. from data_pipeline.utils.data_structures import ProcessingResult, TableProcessingContext
  7. class ToolRegistry:
  8. """工具注册管理器"""
  9. _tools: Dict[str, Type['BaseTool']] = {}
  10. _instances: Dict[str, 'BaseTool'] = {}
  11. @classmethod
  12. def register(cls, name: str):
  13. """装饰器:注册工具"""
  14. def decorator(tool_class: Type['BaseTool']):
  15. cls._tools[name] = tool_class
  16. logger = get_data_pipeline_logger("ToolRegistry")
  17. logger.debug(f"注册工具: {name} -> {tool_class.__name__}")
  18. return tool_class
  19. return decorator
  20. @classmethod
  21. def get_tool(cls, name: str, **kwargs) -> 'BaseTool':
  22. """获取工具实例,支持单例模式"""
  23. if name not in cls._instances:
  24. if name not in cls._tools:
  25. raise ValueError(f"工具 '{name}' 未注册")
  26. tool_class = cls._tools[name]
  27. # 自动注入vanna实例到需要LLM的工具
  28. if hasattr(tool_class, 'needs_llm') and tool_class.needs_llm:
  29. from core.vanna_llm_factory import create_vanna_instance
  30. kwargs['vn'] = create_vanna_instance()
  31. logger = get_data_pipeline_logger("ToolRegistry")
  32. logger.debug(f"为工具 {name} 注入LLM实例")
  33. cls._instances[name] = tool_class(**kwargs)
  34. return cls._instances[name]
  35. @classmethod
  36. def list_tools(cls) -> List[str]:
  37. """列出所有已注册的工具"""
  38. return list(cls._tools.keys())
  39. @classmethod
  40. def clear_instances(cls):
  41. """清除所有工具实例(用于测试)"""
  42. cls._instances.clear()
  43. class BaseTool(ABC):
  44. """工具基类"""
  45. needs_llm: bool = False # 是否需要LLM实例
  46. tool_name: str = "" # 工具名称
  47. def __init__(self, **kwargs):
  48. self.logger = get_data_pipeline_logger(f"tools.{self.__class__.__name__}")
  49. # 如果工具需要LLM,检查是否已注入
  50. if self.needs_llm and 'vn' not in kwargs:
  51. raise ValueError(f"工具 {self.__class__.__name__} 需要LLM实例但未提供")
  52. # 存储vanna实例
  53. if 'vn' in kwargs:
  54. self.vn = kwargs['vn']
  55. @abstractmethod
  56. async def execute(self, context: TableProcessingContext) -> ProcessingResult:
  57. """
  58. 执行工具逻辑
  59. Args:
  60. context: 表处理上下文
  61. Returns:
  62. ProcessingResult: 处理结果
  63. """
  64. pass
  65. async def _execute_with_timing(self, context: TableProcessingContext) -> ProcessingResult:
  66. """带计时的执行包装器"""
  67. start_time = time.time()
  68. try:
  69. self.logger.info(f"开始执行工具: {self.tool_name}")
  70. result = await self.execute(context)
  71. execution_time = time.time() - start_time
  72. result.execution_time = execution_time
  73. if result.success:
  74. self.logger.info(f"工具 {self.tool_name} 执行成功,耗时: {execution_time:.2f}秒")
  75. else:
  76. self.logger.error(f"工具 {self.tool_name} 执行失败: {result.error_message}")
  77. return result
  78. except Exception as e:
  79. execution_time = time.time() - start_time
  80. self.logger.exception(f"工具 {self.tool_name} 执行异常")
  81. return ProcessingResult(
  82. success=False,
  83. error_message=f"工具执行异常: {str(e)}",
  84. execution_time=execution_time
  85. )
  86. def validate_input(self, context: TableProcessingContext) -> bool:
  87. """输入验证(子类可重写)"""
  88. return context.table_metadata is not None
  89. class PipelineExecutor:
  90. """处理链执行器"""
  91. def __init__(self, pipeline_config: Dict[str, List[str]]):
  92. self.pipeline_config = pipeline_config
  93. self.logger = get_data_pipeline_logger("tools.PipelineExecutor")
  94. async def execute_pipeline(self, pipeline_name: str, context: TableProcessingContext) -> Dict[str, ProcessingResult]:
  95. """执行指定的处理链"""
  96. if pipeline_name not in self.pipeline_config:
  97. raise ValueError(f"未知的处理链: {pipeline_name}")
  98. steps = self.pipeline_config[pipeline_name]
  99. results = {}
  100. self.logger.info(f"开始执行处理链 '{pipeline_name}': {' -> '.join(steps)}")
  101. for step_name in steps:
  102. try:
  103. tool = ToolRegistry.get_tool(step_name)
  104. # 验证输入
  105. if not tool.validate_input(context):
  106. result = ProcessingResult(
  107. success=False,
  108. error_message=f"工具 {step_name} 输入验证失败"
  109. )
  110. else:
  111. result = await tool._execute_with_timing(context)
  112. results[step_name] = result
  113. context.update_step(step_name, result)
  114. # 如果步骤失败且不允许继续,则停止
  115. if not result.success:
  116. from data_pipeline.config import SCHEMA_TOOLS_CONFIG
  117. if not SCHEMA_TOOLS_CONFIG["continue_on_error"]:
  118. self.logger.error(f"步骤 {step_name} 失败,停止处理链执行")
  119. break
  120. else:
  121. self.logger.warning(f"步骤 {step_name} 失败,继续执行下一步")
  122. except Exception as e:
  123. self.logger.exception(f"执行步骤 {step_name} 时发生异常")
  124. results[step_name] = ProcessingResult(
  125. success=False,
  126. error_message=f"步骤执行异常: {str(e)}"
  127. )
  128. break
  129. return results