data_sampler.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. import random
  2. from typing import List, Dict, Any
  3. from schema_tools.tools.base import BaseTool, ToolRegistry
  4. from schema_tools.utils.data_structures import ProcessingResult, TableProcessingContext, TableMetadata
  5. @ToolRegistry.register("data_sampler")
  6. class DataSamplerTool(BaseTool):
  7. """数据采样工具"""
  8. needs_llm = False
  9. tool_name = "数据采样器"
  10. def __init__(self, **kwargs):
  11. super().__init__(**kwargs)
  12. self.db_connection = kwargs.get('db_connection')
  13. async def execute(self, context: TableProcessingContext) -> ProcessingResult:
  14. """执行数据采样"""
  15. try:
  16. from schema_tools.config import SCHEMA_TOOLS_CONFIG
  17. table_metadata = context.table_metadata
  18. sample_limit = SCHEMA_TOOLS_CONFIG["sample_data_limit"]
  19. large_table_threshold = SCHEMA_TOOLS_CONFIG["large_table_threshold"]
  20. # 判断是否为大表,使用不同的采样策略
  21. if table_metadata.row_count and table_metadata.row_count > large_table_threshold:
  22. sample_data = await self._smart_sample_large_table(table_metadata, sample_limit)
  23. self.logger.info(f"大表 {table_metadata.full_name} 使用智能采样策略")
  24. else:
  25. sample_data = await self._simple_sample(table_metadata, sample_limit)
  26. # 更新上下文中的采样数据
  27. context.table_metadata.sample_data = sample_data
  28. return ProcessingResult(
  29. success=True,
  30. data={
  31. 'sample_count': len(sample_data),
  32. 'sampling_strategy': 'smart' if table_metadata.row_count and table_metadata.row_count > large_table_threshold else 'simple'
  33. },
  34. metadata={'tool': self.tool_name}
  35. )
  36. except Exception as e:
  37. self.logger.exception(f"数据采样失败")
  38. return ProcessingResult(
  39. success=False,
  40. error_message=f"数据采样失败: {str(e)}"
  41. )
  42. async def _simple_sample(self, table_metadata: TableMetadata, limit: int) -> List[Dict[str, Any]]:
  43. """简单采样策略"""
  44. from schema_tools.tools.database_inspector import DatabaseInspectorTool
  45. # 复用数据库检查工具的连接
  46. inspector = ToolRegistry.get_tool("database_inspector")
  47. query = f"SELECT * FROM {table_metadata.full_name} LIMIT {limit}"
  48. async with inspector.connection_pool.acquire() as conn:
  49. rows = await conn.fetch(query)
  50. return [dict(row) for row in rows]
  51. async def _smart_sample_large_table(self, table_metadata: TableMetadata, limit: int) -> List[Dict[str, Any]]:
  52. """智能采样策略(用于大表)"""
  53. from schema_tools.tools.database_inspector import DatabaseInspectorTool
  54. inspector = ToolRegistry.get_tool("database_inspector")
  55. samples_per_section = max(1, limit // 3)
  56. samples = []
  57. async with inspector.connection_pool.acquire() as conn:
  58. # 1. 前N行采样
  59. front_query = f"SELECT * FROM {table_metadata.full_name} LIMIT {samples_per_section}"
  60. front_rows = await conn.fetch(front_query)
  61. samples.extend([dict(row) for row in front_rows])
  62. # 2. 随机中间采样(使用TABLESAMPLE)
  63. if table_metadata.row_count > samples_per_section * 2:
  64. try:
  65. # 计算采样百分比
  66. sample_percent = min(1.0, (samples_per_section * 100.0) / table_metadata.row_count)
  67. middle_query = f"""
  68. SELECT * FROM {table_metadata.full_name}
  69. TABLESAMPLE SYSTEM({sample_percent})
  70. LIMIT {samples_per_section}
  71. """
  72. middle_rows = await conn.fetch(middle_query)
  73. samples.extend([dict(row) for row in middle_rows])
  74. except Exception as e:
  75. self.logger.warning(f"TABLESAMPLE采样失败,使用OFFSET采样: {e}")
  76. # 回退到OFFSET采样
  77. offset = random.randint(samples_per_section, table_metadata.row_count - samples_per_section)
  78. offset_query = f"SELECT * FROM {table_metadata.full_name} OFFSET {offset} LIMIT {samples_per_section}"
  79. offset_rows = await conn.fetch(offset_query)
  80. samples.extend([dict(row) for row in offset_rows])
  81. # 3. 后N行采样
  82. remaining = limit - len(samples)
  83. if remaining > 0:
  84. # 使用ORDER BY ... DESC来获取最后的行
  85. tail_query = f"""
  86. SELECT * FROM (
  87. SELECT *, ROW_NUMBER() OVER() as rn
  88. FROM {table_metadata.full_name}
  89. ) sub
  90. WHERE sub.rn > (SELECT COUNT(*) FROM {table_metadata.full_name}) - {remaining}
  91. ORDER BY sub.rn
  92. """
  93. try:
  94. tail_rows = await conn.fetch(tail_query)
  95. # 移除ROW_NUMBER列
  96. for row in tail_rows:
  97. row_dict = dict(row)
  98. row_dict.pop('rn', None)
  99. samples.append(row_dict)
  100. except Exception as e:
  101. self.logger.warning(f"尾部采样失败: {e}")
  102. return samples[:limit] # 确保不超过限制