123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194 |
- import logging
- from pathlib import Path
- from typing import Dict, List, Tuple, Set
- from dataclasses import dataclass, field
- from schema_tools.utils.table_parser import TableListParser
- from schema_tools.config import SCHEMA_TOOLS_CONFIG
@dataclass
class ValidationResult:
    """Outcome of a file-count validation, as returned by FileCountValidator.validate.

    The three list fields default to a fresh empty list for every instance,
    so results never share list objects.
    """

    # Whether validation passed.
    is_valid: bool
    # Number of tables parsed from the table-list file.
    table_count: int
    # Number of ``*.ddl`` files found in the output directory.
    ddl_count: int
    # Number of ``*_detail.md`` files found in the output directory.
    md_count: int
    # Human-readable description of the failure; empty string when none.
    error: str = field(default="")
    # Table specs for which no matching DDL file was found.
    missing_ddl: List[str] = field(default_factory=list)
    # Table specs for which no matching MD file was found.
    missing_md: List[str] = field(default_factory=list)
    # Duplicate entries reported by the table-list parser (presumably repeated lines).
    duplicate_tables: List[str] = field(default_factory=list)
class FileCountValidator:
    """File-count validator.

    Checks that the output directory holds as many ``*.ddl`` files and
    ``*_detail.md`` files as there are tables in the table-list file, and
    reports which tables appear to lack their files when the counts differ.
    """

    def __init__(self):
        self.logger = logging.getLogger("schema_tools.FileCountValidator")
        self.config = SCHEMA_TOOLS_CONFIG

    def validate(self, table_list_file: str, output_dir: str) -> "ValidationResult":
        """
        Verify that the number of generated files matches the number of tables.

        Args:
            table_list_file: path to the table-list file.
            output_dir: path to the output directory.

        Returns:
            ValidationResult: the validation outcome. Any exception raised
            during validation is logged and turned into an invalid result
            instead of being propagated.
        """
        try:
            # 1. Parse the table list to get the table count (the parser is
            #    expected to de-duplicate entries).
            table_parser = TableListParser()
            tables = table_parser.parse_file(table_list_file)
            table_count = len(tables)

            # Duplicate information is only reported, not used for counting.
            _, duplicate_tables = table_parser.get_duplicate_info(table_list_file)

            # 2. Enforce the configured table-count limit.
            max_tables = self.config['qs_generation']['max_tables']
            if table_count > max_tables:
                return ValidationResult(
                    is_valid=False,
                    table_count=table_count,
                    ddl_count=0,
                    md_count=0,
                    error=f"表数量({table_count})超过限制({max_tables})。请分批处理或调整配置中的max_tables参数。",
                    duplicate_tables=duplicate_tables
                )

            # 3. The output directory must exist before it can be scanned.
            output_path = Path(output_dir)
            if not output_path.exists():
                return ValidationResult(
                    is_valid=False,
                    table_count=table_count,
                    ddl_count=0,
                    md_count=0,
                    error=f"输出目录不存在: {output_dir}",
                    duplicate_tables=duplicate_tables
                )

            # 4. Count DDL and MD files (MD files use the "_detail.md" suffix).
            ddl_files = list(output_path.glob("*.ddl"))
            md_files = list(output_path.glob("*_detail.md"))

            ddl_count = len(ddl_files)
            md_count = len(md_files)

            self.logger.info(f"文件统计 - 表: {table_count}, DDL: {ddl_count}, MD: {md_count}")
            if duplicate_tables:
                self.logger.info(f"表清单中存在 {len(duplicate_tables)} 个重复项")

            # 5. Counts must match; on mismatch, identify which tables lack files.
            if ddl_count != table_count or md_count != table_count:
                missing_ddl, missing_md = self._find_missing_files(tables, ddl_files, md_files)

                error_parts = []
                if ddl_count != table_count:
                    error_parts.append(f"DDL文件数量({ddl_count})与表数量({table_count})不一致")
                    if missing_ddl:
                        self.logger.error(f"缺失的DDL文件对应的表: {', '.join(missing_ddl)}")

                if md_count != table_count:
                    error_parts.append(f"MD文件数量({md_count})与表数量({table_count})不一致")
                    if missing_md:
                        self.logger.error(f"缺失的MD文件对应的表: {', '.join(missing_md)}")

                return ValidationResult(
                    is_valid=False,
                    table_count=table_count,
                    ddl_count=ddl_count,
                    md_count=md_count,
                    error="; ".join(error_parts),
                    missing_ddl=missing_ddl,
                    missing_md=missing_md,
                    duplicate_tables=duplicate_tables
                )

            # 6. Validation passed.
            self.logger.info(f"文件验证通过:{table_count}个表,{ddl_count}个DDL,{md_count}个MD")

            return ValidationResult(
                is_valid=True,
                table_count=table_count,
                ddl_count=ddl_count,
                md_count=md_count,
                duplicate_tables=duplicate_tables
            )

        except Exception as e:
            # Top-level boundary: log with traceback and report as a failed result.
            self.logger.exception("文件验证失败")
            return ValidationResult(
                is_valid=False,
                table_count=0,
                ddl_count=0,
                md_count=0,
                error=f"验证过程发生异常: {str(e)}"
            )

    def _find_missing_files(self, tables: List[str], ddl_files: List[Path], md_files: List[Path]) -> Tuple[List[str], List[str]]:
        """Find tables whose DDL and/or MD file is missing.

        Args:
            tables: table specs (``schema.table`` or bare ``table``).
            ddl_files: ``*.ddl`` files found in the output directory.
            md_files: ``*_detail.md`` files found in the output directory.

        Returns:
            ``(missing_ddl, missing_md)``: table specs, in input order, for
            which no candidate file name matches a DDL file / an MD file.
        """
        md_suffix = "_detail"

        # File stems without extension.
        ddl_names = {f.stem for f in ddl_files}
        # Strip only the *trailing* "_detail" suffix. str.replace would also
        # remove "_detail" inside the table name itself (e.g. the stem
        # "order_details_detail" would become "orders"), making an existing
        # MD file look missing.
        md_names = {
            f.stem[:-len(md_suffix)] if f.stem.endswith(md_suffix) else f.stem
            for f in md_files
        }

        table_to_filenames = self._get_table_filename_mapping(tables)

        missing_ddl = []
        missing_md = []
        for table_spec in tables:
            possible_filenames = table_to_filenames[table_spec]
            # Disjoint means none of the candidate names was generated.
            if possible_filenames.isdisjoint(ddl_names):
                missing_ddl.append(table_spec)
            if possible_filenames.isdisjoint(md_names):
                missing_md.append(table_spec)

        return missing_ddl, missing_md

    def _get_table_filename_mapping(self, tables: List[str]) -> Dict[str, Set[str]]:
        """Map each table spec to the set of file stems that may represent it.

        A spec is split on its first ``.`` into schema and table; a bare name
        is treated as schema ``public``. Public tables map to the bare table
        name; other schemas map to both ``schema__table`` and ``schema_table``.
        If the table name contains ``-`` or spaces, the variants with those
        characters replaced by ``_`` are accepted too.
        """
        mapping: Dict[str, Set[str]] = {}

        for table_spec in tables:
            if '.' in table_spec:
                schema, table = table_spec.split('.', 1)
            else:
                schema, table = 'public', table_spec

            # Original name plus a sanitized variant (identical when nothing changes).
            variants = {table, table.replace('-', '_').replace(' ', '_')}

            if schema.lower() == 'public':
                possible_names = set(variants)
            else:
                possible_names = set()
                for name in variants:
                    possible_names.add(f"{schema}__{name}")
                    possible_names.add(f"{schema}_{name}")  # tolerate single-underscore format

            mapping[table_spec] = possible_names

        return mapping
|