sql_validator.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. import asyncio
  2. import logging
  3. import time
  4. from typing import Dict, Any, List, Optional
  5. from dataclasses import dataclass, field
  6. from schema_tools.config import SCHEMA_TOOLS_CONFIG
  7. @dataclass
  8. class SQLValidationResult:
  9. """SQL验证结果"""
  10. sql: str
  11. valid: bool
  12. error_message: str = ""
  13. execution_time: float = 0.0
  14. retry_count: int = 0
  15. # SQL修复相关字段
  16. repair_attempted: bool = False
  17. repair_successful: bool = False
  18. repaired_sql: str = ""
  19. repair_error: str = ""
  20. @dataclass
  21. class ValidationStats:
  22. """验证统计信息"""
  23. total_sqls: int = 0
  24. valid_sqls: int = 0
  25. invalid_sqls: int = 0
  26. total_time: float = 0.0
  27. avg_time_per_sql: float = 0.0
  28. retry_count: int = 0
  29. # SQL修复统计
  30. repair_attempted: int = 0
  31. repair_successful: int = 0
  32. repair_failed: int = 0
  33. class SQLValidator:
  34. """SQL验证器"""
  35. def __init__(self, db_connection: str = None):
  36. """
  37. 初始化SQL验证器
  38. Args:
  39. db_connection: 数据库连接字符串(可选,用于复用连接池)
  40. """
  41. self.db_connection = db_connection
  42. self.connection_pool = None
  43. self.config = SCHEMA_TOOLS_CONFIG['sql_validation']
  44. self.logger = logging.getLogger("schema_tools.SQLValidator")
  45. async def _get_connection_pool(self):
  46. """获取或复用现有连接池"""
  47. if not self.connection_pool:
  48. if self.config['reuse_connection_pool'] and self.db_connection:
  49. # 复用现有的DatabaseInspector连接池
  50. from schema_tools.tools.base import ToolRegistry
  51. db_tool = ToolRegistry.get_tool("database_inspector",
  52. db_connection=self.db_connection)
  53. # 如果连接池不存在,则创建
  54. if not db_tool.connection_pool:
  55. await db_tool._create_connection_pool()
  56. # 复用连接池
  57. self.connection_pool = db_tool.connection_pool
  58. self.logger.info("复用现有数据库连接池进行SQL验证")
  59. else:
  60. raise ValueError("需要提供数据库连接字符串或启用连接池复用")
  61. return self.connection_pool
  62. async def validate_sql(self, sql: str, retry_count: int = 0) -> SQLValidationResult:
  63. """
  64. 验证单个SQL语句
  65. Args:
  66. sql: 要验证的SQL语句
  67. retry_count: 当前重试次数
  68. Returns:
  69. SQLValidationResult: 验证结果
  70. """
  71. start_time = time.time()
  72. try:
  73. pool = await self._get_connection_pool()
  74. async with pool.acquire() as conn:
  75. # 设置超时
  76. timeout = self.config['validation_timeout']
  77. # 设置只读模式(安全考虑)
  78. if self.config['readonly_mode']:
  79. await asyncio.wait_for(
  80. conn.execute("SET default_transaction_read_only = on"),
  81. timeout=timeout
  82. )
  83. # 执行EXPLAIN验证SQL
  84. await asyncio.wait_for(
  85. conn.execute(f"EXPLAIN {sql}"),
  86. timeout=timeout
  87. )
  88. execution_time = time.time() - start_time
  89. self.logger.debug(f"SQL验证成功: {sql[:50]}... ({execution_time:.3f}s)")
  90. return SQLValidationResult(
  91. sql=sql,
  92. valid=True,
  93. execution_time=execution_time,
  94. retry_count=retry_count
  95. )
  96. except asyncio.TimeoutError:
  97. execution_time = time.time() - start_time
  98. error_msg = f"验证超时({self.config['validation_timeout']}秒)"
  99. self.logger.warning(f"SQL验证超时: {sql[:50]}...")
  100. return SQLValidationResult(
  101. sql=sql,
  102. valid=False,
  103. error_message=error_msg,
  104. execution_time=execution_time,
  105. retry_count=retry_count
  106. )
  107. except Exception as e:
  108. execution_time = time.time() - start_time
  109. error_msg = str(e)
  110. # 检查是否需要重试
  111. max_retries = self.config['max_retry_count']
  112. if retry_count < max_retries and self._should_retry(e):
  113. self.logger.debug(f"SQL验证失败,重试 {retry_count + 1}/{max_retries}: {error_msg}")
  114. await asyncio.sleep(0.5) # 短暂等待后重试
  115. return await self.validate_sql(sql, retry_count + 1)
  116. self.logger.debug(f"SQL验证失败: {sql[:50]}... - {error_msg}")
  117. return SQLValidationResult(
  118. sql=sql,
  119. valid=False,
  120. error_message=error_msg,
  121. execution_time=execution_time,
  122. retry_count=retry_count
  123. )
  124. async def validate_sqls_batch(self, sqls: List[str]) -> List[SQLValidationResult]:
  125. """
  126. 批量验证SQL语句
  127. Args:
  128. sqls: SQL语句列表
  129. Returns:
  130. 验证结果列表
  131. """
  132. if not sqls:
  133. return []
  134. max_concurrent = self.config['max_concurrent_validations']
  135. semaphore = asyncio.Semaphore(max_concurrent)
  136. async def validate_with_semaphore(sql):
  137. async with semaphore:
  138. return await self.validate_sql(sql)
  139. self.logger.info(f"开始批量验证 {len(sqls)} 个SQL语句 (并发度: {max_concurrent})")
  140. # 并发执行验证
  141. tasks = [validate_with_semaphore(sql) for sql in sqls]
  142. results = await asyncio.gather(*tasks, return_exceptions=True)
  143. # 处理异常结果
  144. validated_results = []
  145. for i, result in enumerate(results):
  146. if isinstance(result, Exception):
  147. self.logger.error(f"SQL验证任务异常: {sqls[i][:50]}... - {result}")
  148. validated_results.append(SQLValidationResult(
  149. sql=sqls[i],
  150. valid=False,
  151. error_message=f"验证任务异常: {str(result)}"
  152. ))
  153. else:
  154. validated_results.append(result)
  155. return validated_results
  156. def _should_retry(self, error: Exception) -> bool:
  157. """
  158. 判断是否应该重试
  159. Args:
  160. error: 异常对象
  161. Returns:
  162. 是否应该重试
  163. """
  164. # 一般网络或连接相关的错误可以重试
  165. retry_indicators = [
  166. "connection",
  167. "network",
  168. "timeout",
  169. "server closed",
  170. "pool",
  171. ]
  172. error_str = str(error).lower()
  173. return any(indicator in error_str for indicator in retry_indicators)
  174. def calculate_stats(self, results: List[SQLValidationResult]) -> ValidationStats:
  175. """
  176. 计算验证统计信息
  177. Args:
  178. results: 验证结果列表
  179. Returns:
  180. ValidationStats: 统计信息
  181. """
  182. total_sqls = len(results)
  183. valid_sqls = sum(1 for r in results if r.valid)
  184. invalid_sqls = total_sqls - valid_sqls
  185. total_time = sum(r.execution_time for r in results)
  186. avg_time = total_time / total_sqls if total_sqls > 0 else 0.0
  187. total_retries = sum(r.retry_count for r in results)
  188. # 计算修复统计
  189. repair_attempted = sum(1 for r in results if r.repair_attempted)
  190. repair_successful = sum(1 for r in results if r.repair_successful)
  191. repair_failed = repair_attempted - repair_successful
  192. return ValidationStats(
  193. total_sqls=total_sqls,
  194. valid_sqls=valid_sqls,
  195. invalid_sqls=invalid_sqls,
  196. total_time=total_time,
  197. avg_time_per_sql=avg_time,
  198. retry_count=total_retries,
  199. repair_attempted=repair_attempted,
  200. repair_successful=repair_successful,
  201. repair_failed=repair_failed
  202. )