sql_validator.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. import asyncio
  2. import time
  3. from typing import Dict, Any, List, Optional
  4. from dataclasses import dataclass, field
  5. from data_pipeline.config import SCHEMA_TOOLS_CONFIG
  6. import logging
  7. @dataclass
  8. class SQLValidationResult:
  9. """SQL验证结果"""
  10. sql: str
  11. valid: bool
  12. error_message: str = ""
  13. execution_time: float = 0.0
  14. retry_count: int = 0
  15. # SQL修复相关字段
  16. repair_attempted: bool = False
  17. repair_successful: bool = False
  18. repaired_sql: str = ""
  19. repair_error: str = ""
  20. @dataclass
  21. class ValidationStats:
  22. """验证统计信息"""
  23. total_sqls: int = 0
  24. valid_sqls: int = 0
  25. invalid_sqls: int = 0
  26. total_time: float = 0.0
  27. avg_time_per_sql: float = 0.0
  28. retry_count: int = 0
  29. # SQL修复统计
  30. repair_attempted: int = 0
  31. repair_successful: int = 0
  32. repair_failed: int = 0
  33. class SQLValidator:
  34. """SQL验证器"""
  35. def __init__(self, db_connection: str = None):
  36. """
  37. 初始化SQL验证器
  38. Args:
  39. db_connection: 数据库连接字符串(可选,用于复用连接池)
  40. """
  41. self.db_connection = db_connection
  42. self.connection_pool = None
  43. self.config = SCHEMA_TOOLS_CONFIG['sql_validation']
  44. self.logger = logging.getLogger("SQLValidator")
  45. async def _get_connection_pool(self):
  46. """获取或创建连接池"""
  47. if not self.connection_pool:
  48. if self.db_connection:
  49. # 直接创建自己的连接池,避免复用问题
  50. import asyncpg
  51. try:
  52. self.connection_pool = await asyncpg.create_pool(
  53. self.db_connection,
  54. min_size=1,
  55. max_size=5,
  56. command_timeout=30
  57. )
  58. self.logger.info("SQL验证器连接池创建成功")
  59. except Exception as e:
  60. self.logger.error(f"创建SQL验证器连接池失败: {e}")
  61. raise
  62. else:
  63. raise ValueError("需要提供数据库连接字符串")
  64. return self.connection_pool
  65. async def validate_sql(self, sql: str, retry_count: int = 0) -> SQLValidationResult:
  66. """
  67. 验证单个SQL语句
  68. Args:
  69. sql: 要验证的SQL语句
  70. retry_count: 当前重试次数
  71. Returns:
  72. SQLValidationResult: 验证结果
  73. """
  74. start_time = time.time()
  75. try:
  76. pool = await self._get_connection_pool()
  77. async with pool.acquire() as conn:
  78. # 设置超时
  79. timeout = self.config['validation_timeout']
  80. # 设置只读模式(安全考虑)
  81. if self.config['readonly_mode']:
  82. await asyncio.wait_for(
  83. conn.execute("SET default_transaction_read_only = on"),
  84. timeout=timeout
  85. )
  86. # 执行EXPLAIN验证SQL
  87. await asyncio.wait_for(
  88. conn.execute(f"EXPLAIN {sql}"),
  89. timeout=timeout
  90. )
  91. execution_time = time.time() - start_time
  92. self.logger.debug(f"SQL验证成功: {sql[:50]}... ({execution_time:.3f}s)")
  93. return SQLValidationResult(
  94. sql=sql,
  95. valid=True,
  96. execution_time=execution_time,
  97. retry_count=retry_count
  98. )
  99. except asyncio.TimeoutError:
  100. execution_time = time.time() - start_time
  101. error_msg = f"验证超时({self.config['validation_timeout']}秒)"
  102. self.logger.warning(f"SQL验证超时: {sql[:50]}...")
  103. return SQLValidationResult(
  104. sql=sql,
  105. valid=False,
  106. error_message=error_msg,
  107. execution_time=execution_time,
  108. retry_count=retry_count
  109. )
  110. except Exception as e:
  111. execution_time = time.time() - start_time
  112. error_msg = str(e)
  113. # 检查是否需要重试
  114. max_retries = self.config['max_retry_count']
  115. if retry_count < max_retries and self._should_retry(e):
  116. self.logger.debug(f"SQL验证失败,重试 {retry_count + 1}/{max_retries}: {error_msg}")
  117. await asyncio.sleep(0.5) # 短暂等待后重试
  118. return await self.validate_sql(sql, retry_count + 1)
  119. self.logger.debug(f"SQL验证失败: {sql[:50]}... - {error_msg}")
  120. return SQLValidationResult(
  121. sql=sql,
  122. valid=False,
  123. error_message=error_msg,
  124. execution_time=execution_time,
  125. retry_count=retry_count
  126. )
  127. async def validate_sqls_batch(self, sqls: List[str]) -> List[SQLValidationResult]:
  128. """
  129. 批量验证SQL语句
  130. Args:
  131. sqls: SQL语句列表
  132. Returns:
  133. 验证结果列表
  134. """
  135. if not sqls:
  136. return []
  137. max_concurrent = self.config['max_concurrent_validations']
  138. semaphore = asyncio.Semaphore(max_concurrent)
  139. async def validate_with_semaphore(sql):
  140. async with semaphore:
  141. return await self.validate_sql(sql)
  142. self.logger.info(f"开始批量验证 {len(sqls)} 个SQL语句 (并发度: {max_concurrent})")
  143. # 并发执行验证
  144. tasks = [validate_with_semaphore(sql) for sql in sqls]
  145. results = await asyncio.gather(*tasks, return_exceptions=True)
  146. # 处理异常结果
  147. validated_results = []
  148. for i, result in enumerate(results):
  149. if isinstance(result, Exception):
  150. self.logger.error(f"SQL验证任务异常: {sqls[i][:50]}... - {result}")
  151. validated_results.append(SQLValidationResult(
  152. sql=sqls[i],
  153. valid=False,
  154. error_message=f"验证任务异常: {str(result)}"
  155. ))
  156. else:
  157. validated_results.append(result)
  158. return validated_results
  159. def _should_retry(self, error: Exception) -> bool:
  160. """
  161. 判断是否应该重试
  162. Args:
  163. error: 异常对象
  164. Returns:
  165. 是否应该重试
  166. """
  167. # 一般网络或连接相关的错误可以重试
  168. retry_indicators = [
  169. "connection",
  170. "network",
  171. "timeout",
  172. "server closed",
  173. "pool",
  174. ]
  175. error_str = str(error).lower()
  176. return any(indicator in error_str for indicator in retry_indicators)
  177. def calculate_stats(self, results: List[SQLValidationResult]) -> ValidationStats:
  178. """
  179. 计算验证统计信息
  180. Args:
  181. results: 验证结果列表
  182. Returns:
  183. ValidationStats: 统计信息
  184. """
  185. total_sqls = len(results)
  186. valid_sqls = sum(1 for r in results if r.valid)
  187. invalid_sqls = total_sqls - valid_sqls
  188. total_time = sum(r.execution_time for r in results)
  189. avg_time = total_time / total_sqls if total_sqls > 0 else 0.0
  190. total_retries = sum(r.retry_count for r in results)
  191. # 计算修复统计
  192. repair_attempted = sum(1 for r in results if r.repair_attempted)
  193. repair_successful = sum(1 for r in results if r.repair_successful)
  194. repair_failed = repair_attempted - repair_successful
  195. return ValidationStats(
  196. total_sqls=total_sqls,
  197. valid_sqls=valid_sqls,
  198. invalid_sqls=invalid_sqls,
  199. total_time=total_time,
  200. avg_time_per_sql=avg_time,
  201. retry_count=total_retries,
  202. repair_attempted=repair_attempted,
  203. repair_successful=repair_successful,
  204. repair_failed=repair_failed
  205. )