data_sampler.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. import random
  2. from typing import List, Dict, Any
  3. from data_pipeline.tools.base import BaseTool, ToolRegistry
  4. from data_pipeline.utils.data_structures import ProcessingResult, TableProcessingContext, TableMetadata
  5. @ToolRegistry.register("data_sampler")
  6. class DataSamplerTool(BaseTool):
  7. """数据采样工具"""
  8. needs_llm = False
  9. tool_name = "数据采样器"
  10. def __init__(self, **kwargs):
  11. super().__init__(**kwargs)
  12. self.db_connection = kwargs.get('db_connection')
  13. async def execute(self, context: TableProcessingContext) -> ProcessingResult:
  14. """执行数据采样"""
  15. try:
  16. from data_pipeline.config import SCHEMA_TOOLS_CONFIG
  17. table_metadata = context.table_metadata
  18. sample_limit = SCHEMA_TOOLS_CONFIG["sample_data_limit"]
  19. large_table_threshold = SCHEMA_TOOLS_CONFIG["large_table_threshold"]
  20. # 判断是否为大表,使用不同的采样策略
  21. if table_metadata.row_count and table_metadata.row_count > large_table_threshold:
  22. sample_data = await self._smart_sample_large_table(table_metadata, sample_limit)
  23. self.logger.info(f"大表 {table_metadata.full_name} 使用智能采样策略")
  24. else:
  25. sample_data = await self._simple_sample(table_metadata, sample_limit)
  26. # 更新上下文中的采样数据
  27. context.table_metadata.sample_data = sample_data
  28. return ProcessingResult(
  29. success=True,
  30. data={
  31. 'sample_count': len(sample_data),
  32. 'sampling_strategy': 'smart' if table_metadata.row_count and table_metadata.row_count > large_table_threshold else 'simple'
  33. },
  34. metadata={'tool': self.tool_name}
  35. )
  36. except Exception as e:
  37. self.logger.exception(f"数据采样失败")
  38. return ProcessingResult(
  39. success=False,
  40. error_message=f"数据采样失败: {str(e)}"
  41. )
  42. async def _simple_sample(self, table_metadata: TableMetadata, limit: int) -> List[Dict[str, Any]]:
  43. """简单采样策略"""
  44. import asyncpg
  45. # 直接使用数据库连接字符串创建连接
  46. query = f"SELECT * FROM {table_metadata.full_name} LIMIT {limit}"
  47. conn = await asyncpg.connect(self.db_connection)
  48. try:
  49. rows = await conn.fetch(query)
  50. return [dict(row) for row in rows]
  51. finally:
  52. await conn.close()
  53. async def _smart_sample_large_table(self, table_metadata: TableMetadata, limit: int) -> List[Dict[str, Any]]:
  54. """智能采样策略(用于大表)"""
  55. import asyncpg
  56. samples_per_section = max(1, limit // 3)
  57. samples = []
  58. conn = await asyncpg.connect(self.db_connection)
  59. try:
  60. # 1. 前N行采样
  61. front_query = f"SELECT * FROM {table_metadata.full_name} LIMIT {samples_per_section}"
  62. front_rows = await conn.fetch(front_query)
  63. samples.extend([dict(row) for row in front_rows])
  64. # 2. 随机中间采样(使用TABLESAMPLE)
  65. if table_metadata.row_count > samples_per_section * 2:
  66. try:
  67. # 计算采样百分比
  68. sample_percent = min(1.0, (samples_per_section * 100.0) / table_metadata.row_count)
  69. middle_query = f"""
  70. SELECT * FROM {table_metadata.full_name}
  71. TABLESAMPLE SYSTEM({sample_percent})
  72. LIMIT {samples_per_section}
  73. """
  74. middle_rows = await conn.fetch(middle_query)
  75. samples.extend([dict(row) for row in middle_rows])
  76. except Exception as e:
  77. self.logger.warning(f"TABLESAMPLE采样失败,使用OFFSET采样: {e}")
  78. # 回退到OFFSET采样
  79. offset = random.randint(samples_per_section, table_metadata.row_count - samples_per_section)
  80. offset_query = f"SELECT * FROM {table_metadata.full_name} OFFSET {offset} LIMIT {samples_per_section}"
  81. offset_rows = await conn.fetch(offset_query)
  82. samples.extend([dict(row) for row in offset_rows])
  83. # 3. 后N行采样
  84. remaining = limit - len(samples)
  85. if remaining > 0:
  86. # 使用ORDER BY ... DESC来获取最后的行
  87. tail_query = f"""
  88. SELECT * FROM (
  89. SELECT *, ROW_NUMBER() OVER() as rn
  90. FROM {table_metadata.full_name}
  91. ) sub
  92. WHERE sub.rn > (SELECT COUNT(*) FROM {table_metadata.full_name}) - {remaining}
  93. ORDER BY sub.rn
  94. """
  95. try:
  96. tail_rows = await conn.fetch(tail_query)
  97. # 移除ROW_NUMBER列
  98. for row in tail_rows:
  99. row_dict = dict(row)
  100. row_dict.pop('rn', None)
  101. samples.append(row_dict)
  102. except Exception as e:
  103. self.logger.warning(f"尾部采样失败: {e}")
  104. finally:
  105. await conn.close()
  106. return samples[:limit] # 确保不超过限制