Преглед на файлове

已经修复了data_pipeline API的问题,支持truncate_vector_tables和backup_vector_tables两个参数。

wangxq преди 1 месец
родител
ревизия
ef0a85b5d9

+ 26 - 1
data_pipeline/qa_generation/qs_agent.py

@@ -73,7 +73,10 @@ class QuestionSQLGenerationAgent:
         try:
             self.logger.info("🚀 开始生成Question-SQL训练数据")
             
-            # 1. 验证文件数量
+            # 1. 重命名现有文件
+            await self._rename_existing_files()
+            
+            # 2. 验证文件数量
             self.logger.info("📋 验证文件数量...")
             validation_result = self.validator.validate(self.table_list_file, str(self.output_dir))
             
@@ -167,6 +170,28 @@ class QuestionSQLGenerationAgent:
             
             raise
     
+    async def _rename_existing_files(self):
+        """重命名现有的输出文件"""
+        try:
+            # 查找现有的 *_pair.json 文件
+            pair_files = list(self.output_dir.glob("*_pair.json"))
+            
+            for pair_file in pair_files:
+                old_name = f"{pair_file}_old"
+                pair_file.rename(old_name)
+                self.logger.info(f"重命名文件: {pair_file.name} → {Path(old_name).name}")
+            
+            # 查找现有的 backup 文件
+            backup_files = list(self.output_dir.glob("*_pair.json.backup"))
+            
+            for backup_file in backup_files:
+                old_name = f"{backup_file}_old"
+                backup_file.rename(old_name)
+                self.logger.info(f"重命名备份文件: {backup_file.name} → {Path(old_name).name}")
+                
+        except Exception as e:
+            self.logger.warning(f"重命名现有文件时出错: {e}")
+
     def _initialize_llm_components(self):
         """初始化LLM相关组件"""
         if not self.vn:

+ 24 - 0
data_pipeline/trainer/run_training.py

@@ -262,6 +262,25 @@ def train_formatted_question_sql_pairs(formatted_file):
     
     print(f"格式化问答训练完成,共成功处理 {successfully_processed} 对问答(总计 {len(pairs)} 对)")
 
+def _is_valid_training_file(filename: str) -> bool:
+    """判断是否为有效的训练文件"""
+    import re
+    filename_lower = filename.lower()
+    
+    # 排除带数字后缀的文件
+    if re.search(r'\.(ddl|md)_\d+$', filename_lower):
+        return False
+    
+    # 排除 _old 后缀的文件
+    if filename_lower.endswith('_old'):
+        return False
+    
+    # 排除 .backup 相关文件
+    if '.backup' in filename_lower:
+        return False
+    
+    return True
+
 def train_json_question_sql_pairs(json_file):
     """训练JSON格式的问答对
     
@@ -427,6 +446,11 @@ def process_training_files(data_path, task_id=None, backup_vector_tables=False,
             
             # 根据文件类型调用相应的处理函数
             try:
+                # 检查是否为有效的训练文件
+                if not _is_valid_training_file(item):
+                    log_message(f"跳过无效训练文件: {item}")
+                    continue
+                    
                 if file_lower.endswith(".ddl"):
                     log_message(f"处理DDL文件: {item_path}")
                     train_ddl_statements(item_path)

+ 2 - 3
data_pipeline/utils/file_manager.py

@@ -101,12 +101,11 @@ class FileNameManager:
         if filename not in self.used_names:
             return filename
         
-        # 如果重名,添加数字后缀
-        base, ext = os.path.splitext(filename)
+        # 如果重名,在扩展名后添加数字后缀
         counter = 1
         
         while True:
-            unique_name = f"{base}_{counter}{ext}"
+            unique_name = f"{filename}_{counter}"
             if unique_name not in self.used_names:
                 self.logger.warning(f"文件名冲突,'{filename}' 重命名为 '{unique_name}'")
                 return unique_name

+ 5 - 1
data_pipeline/validators/sql_validate_cli.py

@@ -315,7 +315,11 @@ def resolve_input_file_and_output_dir(args):
         
         # 在任务目录中查找Question-SQL文件
         if task_dir.exists():
-            possible_files = list(task_dir.glob("*_pair.json"))
+            # 只搜索标准命名的文件,排除 _old 后缀
+            possible_files = [
+                f for f in task_dir.glob("*_pair.json") 
+                if not f.name.endswith('_old') and '.backup' not in f.name
+            ]
             if possible_files:
                 # 选择最新的文件(按修改时间排序)
                 input_file = str(max(possible_files, key=lambda f: f.stat().st_mtime))