base.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. import asyncio
  2. import time
  3. import logging
  4. from abc import ABC, abstractmethod
  5. from typing import Dict, Any, Optional, Type, List
  6. from schema_tools.utils.data_structures import ProcessingResult, TableProcessingContext
  class ToolRegistry:
      """Tool registry manager.

      Maps tool names to tool classes (filled by the ``register`` decorator) and
      caches one instance per name (created lazily by ``get_tool``). Both maps
      are class attributes, so they are shared by every user of the registry.
      """

      # name -> registered tool class
      _tools: Dict[str, Type['BaseTool']] = {}
      # name -> cached instance created by get_tool
      _instances: Dict[str, 'BaseTool'] = {}

      @classmethod
      def register(cls, name: str):
          """Class decorator that registers a tool class under ``name``.

          If ``name`` is already registered with a different class, the class is
          replaced and any cached instance of the old class is discarded, so the
          next ``get_tool(name)`` builds the newly registered class.

          Args:
              name: registry key later passed to ``get_tool``.

          Returns:
              A decorator that records the class and returns it unchanged.
          """
          def decorator(tool_class: Type['BaseTool']):
              previous = cls._tools.get(name)
              if previous is not None and previous is not tool_class:
                  logging.warning(f"工具 '{name}' 被重新注册: {previous.__name__} -> {tool_class.__name__}")
                  # Without this, get_tool would keep returning an instance of the old class.
                  cls._instances.pop(name, None)
              cls._tools[name] = tool_class
              logging.debug(f"注册工具: {name} -> {tool_class.__name__}")
              return tool_class
          return decorator

      @classmethod
      def get_tool(cls, name: str, **kwargs) -> 'BaseTool':
          """Return the cached instance of the tool registered as ``name``.

          The instance is constructed on the first request with ``kwargs``;
          later calls return the same object and ignore ``kwargs``. When the
          tool class has a truthy ``needs_llm`` attribute and the caller did not
          supply a non-None ``vn``, a vanna instance is created via
          ``core.vanna_llm_factory.create_vanna_instance`` and passed as ``vn``.

          Args:
              name: registered tool name.
              **kwargs: constructor arguments used only on first creation.

          Returns:
              The tool instance.

          Raises:
              ValueError: if ``name`` has not been registered.
          """
          if name in cls._instances:
              return cls._instances[name]
          tool_class = cls._tools.get(name)
          if tool_class is None:
              raise ValueError(f"工具 '{name}' 未注册")
          # Respect an LLM instance supplied by the caller; only inject when missing.
          if getattr(tool_class, 'needs_llm', False) and kwargs.get('vn') is None:
              # Imported lazily so tools that do not need an LLM never load the factory.
              from core.vanna_llm_factory import create_vanna_instance
              kwargs['vn'] = create_vanna_instance()
              logging.debug(f"为工具 {name} 注入LLM实例")
          cls._instances[name] = tool_class(**kwargs)
          return cls._instances[name]

      @classmethod
      def list_tools(cls) -> List[str]:
          """Return the names of all registered tools, in registration order."""
          return list(cls._tools.keys())

      @classmethod
      def clear_instances(cls):
          """Drop every cached tool instance (registrations are kept); intended for tests."""
          cls._instances.clear()
  class BaseTool(ABC):
      """Abstract base class for tools.

      Subclasses implement ``execute``. Callers such as the pipeline executor
      go through ``_execute_with_timing``, which adds timing, logging and
      conversion of exceptions into failed results.
      """

      needs_llm: bool = False  # whether the tool requires an LLM instance passed as ``vn``
      tool_name: str = ""  # name shown in log messages; the class name is used when empty

      def __init__(self, **kwargs):
          """Set up the per-class logger and store the LLM instance if given.

          Args:
              **kwargs: construction options; only ``vn`` is read here and is
                  stored as ``self.vn`` when present.

          Raises:
              ValueError: if ``needs_llm`` is true and ``vn`` was not supplied.
          """
          self.logger = logging.getLogger(f"schema_tools.{self.__class__.__name__}")
          # Tools that need an LLM must have received one (normally injected by ToolRegistry.get_tool).
          if self.needs_llm and 'vn' not in kwargs:
              raise ValueError(f"工具 {self.__class__.__name__} 需要LLM实例但未提供")
          if 'vn' in kwargs:
              self.vn = kwargs['vn']

      @abstractmethod
      async def execute(self, context: TableProcessingContext) -> ProcessingResult:
          """Run the tool's logic.

          Args:
              context: table processing context.

          Returns:
              ProcessingResult: the processing result.
          """
          pass

      async def _execute_with_timing(self, context: TableProcessingContext) -> ProcessingResult:
          """Run ``execute`` and record the elapsed time on its result.

          Any ``Exception`` raised by ``execute`` (or while handling its result)
          is logged and returned as a failed ``ProcessingResult`` instead of
          being propagated.

          Args:
              context: table processing context forwarded to ``execute``.

          Returns:
              The result of ``execute`` with ``execution_time`` set, or a failed
              result describing the exception.
          """
          name = self.tool_name or self.__class__.__name__
          # perf_counter is monotonic, so durations are unaffected by system clock changes.
          start_time = time.perf_counter()
          try:
              self.logger.info(f"开始执行工具: {name}")
              result = await self.execute(context)
              execution_time = time.perf_counter() - start_time
              result.execution_time = execution_time
              if result.success:
                  self.logger.info(f"工具 {name} 执行成功,耗时: {execution_time:.2f}秒")
              else:
                  self.logger.error(f"工具 {name} 执行失败: {result.error_message}")
              return result
          except Exception as e:
              execution_time = time.perf_counter() - start_time
              self.logger.exception(f"工具 {name} 执行异常")
              return ProcessingResult(
                  success=False,
                  error_message=f"工具执行异常: {str(e)}",
                  execution_time=execution_time
              )

      def validate_input(self, context: TableProcessingContext) -> bool:
          """Check the input before execution (subclasses may override).

          Returns:
              True when ``context.table_metadata`` is not None.
          """
          return context.table_metadata is not None
  class PipelineExecutor:
      """Pipeline executor: runs named pipelines made of registered tool names."""

      def __init__(self, pipeline_config: Dict[str, List[str]]):
          """
          Args:
              pipeline_config: maps a pipeline name to the ordered list of tool
                  names (as registered with ``ToolRegistry``) that form it.
          """
          self.pipeline_config = pipeline_config
          self.logger = logging.getLogger("schema_tools.PipelineExecutor")

      def _continue_after_failure(self, step_name: str) -> bool:
          """Decide, and log, whether the pipeline proceeds after ``step_name`` failed.

          Reads ``continue_on_error`` from ``SCHEMA_TOOLS_CONFIG``; a missing key
          is treated as False, meaning the pipeline stops.
          """
          # Lazy import kept from the original; presumably avoids an import cycle — TODO confirm.
          from schema_tools.config import SCHEMA_TOOLS_CONFIG
          # .get instead of [...]: a missing key used to raise KeyError, which the
          # caller's except branch turned into a result that overwrote the step's
          # real failure.
          if SCHEMA_TOOLS_CONFIG.get("continue_on_error", False):
              self.logger.warning(f"步骤 {step_name} 失败,继续执行下一步")
              return True
          self.logger.error(f"步骤 {step_name} 失败,停止处理链执行")
          return False

      async def execute_pipeline(self, pipeline_name: str, context: TableProcessingContext) -> Dict[str, ProcessingResult]:
          """Execute the steps of ``pipeline_name`` in order against ``context``.

          For each step the tool is fetched from ``ToolRegistry``. If its
          ``validate_input`` rejects the context, a failed result is recorded
          without running the tool. Every recorded result is also passed to
          ``context.update_step``. After a failed step the pipeline stops unless
          ``continue_on_error`` is enabled. An exception raised while handling a
          step is recorded as a failed result for that step and always stops
          the pipeline.

          Args:
              pipeline_name: key into ``pipeline_config``.
              context: processing context shared by all steps.

          Returns:
              Mapping of step name to its ``ProcessingResult`` for the steps
              that were reached.

          Raises:
              ValueError: if ``pipeline_name`` is not configured.
          """
          if pipeline_name not in self.pipeline_config:
              raise ValueError(f"未知的处理链: {pipeline_name}")
          steps = self.pipeline_config[pipeline_name]
          results = {}
          self.logger.info(f"开始执行处理链 '{pipeline_name}': {' -> '.join(steps)}")
          for step_name in steps:
              try:
                  tool = ToolRegistry.get_tool(step_name)
                  if tool.validate_input(context):
                      result = await tool._execute_with_timing(context)
                  else:
                      result = ProcessingResult(
                          success=False,
                          error_message=f"工具 {step_name} 输入验证失败"
                      )
                  results[step_name] = result
                  context.update_step(step_name, result)
                  if not result.success and not self._continue_after_failure(step_name):
                      break
              except Exception as e:
                  # Unexpected errors (e.g. unregistered tool, update_step failure) always stop the pipeline.
                  self.logger.exception(f"执行步骤 {step_name} 时发生异常")
                  results[step_name] = ProcessingResult(
                      success=False,
                      error_message=f"步骤执行异常: {str(e)}"
                  )
                  break
          return results