Преглед изворни кода

load_file.py增加了支持xls/xlsx格式的文件。

wangxq пре 2 месеци
родитељ
комит
ab4896e331
2 измењених фајлова са 195 додато и 121 уклоњено
  1. 192 121
      dataops/scripts/load_file.py
  2. 3 0
      requirements.txt

+ 192 - 121
dataops/scripts/load_file.py

@@ -9,6 +9,7 @@ from datetime import datetime
 import csv
 import glob
 import shutil
+import re
 
 # 配置日志记录器
 logging.basicConfig(
@@ -102,136 +103,164 @@ def get_table_columns(table_name):
         cursor.close()
         conn.close()
 
-def match_csv_columns(csv_headers, table_columns):
+def match_file_columns(file_headers, table_columns):
     """
-    匹配CSV列名与表列名
+    匹配文件列名与表列名
     
     策略:
-    1. 尝试通过表字段注释匹配CSV列名 (忽略大小写和空格)
+    1. 尝试通过表字段注释匹配文件列名 (忽略大小写和空格)
     2. 尝试通过名称直接匹配 (忽略大小写和空格)
     
     参数:
-        csv_headers (list): CSV文件的列名列表
+        file_headers (list): 文件的列名列表
         table_columns (dict): {数据库列名: 列注释} 的字典
     
     返回:
-        dict: {CSV列名: 数据库列名} 的映射字典
+        dict: {文件列名: 数据库列名} 的映射字典
     """
     mapping = {}
     matched_table_cols = set()
 
     # 数据库列名通常不区分大小写(除非加引号),注释可能区分
-    # 为了匹配更健壮,我们将CSV和数据库列名/注释都转为小写处理
+    # 为了匹配更健壮,我们将文件和数据库列名/注释都转为小写处理
     processed_table_columns_lower = {col.lower(): col for col in table_columns.keys()}
     processed_comment_to_column_lower = {
         str(comment).lower(): col
         for col, comment in table_columns.items() if comment
     }
 
-    # 预处理 CSV headers
-    processed_csv_headers_lower = {str(header).lower(): header for header in csv_headers}
+    # 预处理文件headers
+    processed_file_headers_lower = {str(header).lower(): header for header in file_headers}
 
     # 1. 通过注释匹配 (忽略大小写)
-    for processed_header, original_header in processed_csv_headers_lower.items():
+    for processed_header, original_header in processed_file_headers_lower.items():
         if processed_header in processed_comment_to_column_lower:
             table_col_original_case = processed_comment_to_column_lower[processed_header]
             if table_col_original_case not in matched_table_cols:
                 mapping[original_header] = table_col_original_case
                 matched_table_cols.add(table_col_original_case)
-                logger.info(f"通过注释匹配: CSV 列 '{original_header}' -> 表列 '{table_col_original_case}'")
+                logger.info(f"通过注释匹配: 文件列 '{original_header}' -> 表列 '{table_col_original_case}'")
 
     # 2. 通过名称直接匹配 (忽略大小写),仅匹配尚未映射的列
-    for processed_header, original_header in processed_csv_headers_lower.items():
-         if original_header not in mapping: # 仅当此 CSV 列尚未映射时才进行名称匹配
+    for processed_header, original_header in processed_file_headers_lower.items():
+         if original_header not in mapping: # 仅当此文件列尚未映射时才进行名称匹配
             if processed_header in processed_table_columns_lower:
                 table_col_original_case = processed_table_columns_lower[processed_header]
                 if table_col_original_case not in matched_table_cols:
                     mapping[original_header] = table_col_original_case
                     matched_table_cols.add(table_col_original_case)
-                    logger.info(f"通过名称匹配: CSV 列 '{original_header}' -> 表列 '{table_col_original_case}'")
+                    logger.info(f"通过名称匹配: 文件列 '{original_header}' -> 表列 '{table_col_original_case}'")
 
-    unmapped_csv = [h for h in csv_headers if h not in mapping]
-    if unmapped_csv:
-         logger.warning(f"以下 CSV 列未能匹配到表列: {unmapped_csv}")
+    unmapped_file = [h for h in file_headers if h not in mapping]
+    if unmapped_file:
+         logger.warning(f"以下文件列未能匹配到表列: {unmapped_file}")
 
     unmapped_table = [col for col in table_columns if col not in matched_table_cols]
     if unmapped_table:
-        logger.warning(f"以下表列未能匹配到 CSV 列: {unmapped_table}")
+        logger.warning(f"以下表列未能匹配到文件列: {unmapped_table}")
 
     return mapping
 
-def load_csv_to_table(csv_file, table_name, execution_mode='append'):
+def read_excel_file(excel_file, sheet_name=0):
     """
-    将单个CSV文件数据加载到目标表
+    读取Excel文件内容
+    
+    参数:
+        excel_file (str): Excel文件路径
+        sheet_name: 工作表名称或索引,默认为第一个工作表
+    
+    返回:
+        pandas.DataFrame: 读取的数据
+    """
+    try:
+        # 尝试读取Excel文件,使用dtype=str确保所有列按字符串读取
+        df = pd.read_excel(excel_file, sheet_name=sheet_name, keep_default_na=False, 
+                          na_values=[r'\N', '', 'NULL', 'null'], dtype=str)
+        logger.info(f"成功读取Excel文件: {os.path.basename(excel_file)}, 共 {len(df)} 行")
+        return df
+    except Exception as e:
+        logger.error(f"读取Excel文件 {excel_file} 时发生错误: {str(e)}")
+        raise
+
+def read_csv_file(csv_file):
+    """
+    读取CSV文件内容,尝试自动检测编码
     
     参数:
         csv_file (str): CSV文件路径
-        table_name (str): 目标表名 (大小写可能敏感,取决于数据库)
-        execution_mode (str): 执行模式,'append'或'full_refresh'
     
     返回:
-        bool: 成功返回True,失败返回False
+        pandas.DataFrame: 读取的数据
     """
-    conn = None
-    cursor = None # 初始化 cursor
-    logger.info(f"开始处理文件: {csv_file}")
     try:
-        # 读取CSV文件,尝试自动检测编码
+        # 使用 dtype=str 确保所有列按字符串读取,避免类型推断问题
+        df = pd.read_csv(csv_file, encoding='utf-8', keep_default_na=False, 
+                        na_values=[r'\N', '', 'NULL', 'null'], dtype=str)
+        return df
+    except UnicodeDecodeError:
         try:
-            # 使用 dtype=str 确保所有列按字符串读取,避免类型推断问题,特别是对于ID类字段
-            df = pd.read_csv(csv_file, encoding='utf-8', keep_default_na=False, na_values=[r'\N', '', 'NULL', 'null'], dtype=str)
+            logger.warning(f"UTF-8 读取失败,尝试 GBK: {csv_file}")
+            df = pd.read_csv(csv_file, encoding='gbk', keep_default_na=False, 
+                            na_values=[r'\N', '', 'NULL', 'null'], dtype=str)
+            return df
         except UnicodeDecodeError:
-            try:
-                logger.warning(f"UTF-8 读取失败,尝试 GBK: {csv_file}")
-                df = pd.read_csv(csv_file, encoding='gbk', keep_default_na=False, na_values=[r'\N', '', 'NULL', 'null'], dtype=str)
-            except UnicodeDecodeError:
-                logger.warning(f"GBK 读取也失败,尝试 latin1: {csv_file}")
-                df = pd.read_csv(csv_file, encoding='latin1', keep_default_na=False, na_values=[r'\N', '', 'NULL', 'null'], dtype=str)
-        except Exception as read_err:
-             logger.error(f"读取 CSV 文件 {csv_file} 时发生未知错误: {str(read_err)}")
-             return False
-
-        logger.info(f"成功读取CSV文件: {os.path.basename(csv_file)}, 共 {len(df)} 行")
-        
+            logger.warning(f"GBK 读取也失败,尝试 latin1: {csv_file}")
+            df = pd.read_csv(csv_file, encoding='latin1', keep_default_na=False, 
+                            na_values=[r'\N', '', 'NULL', 'null'], dtype=str)
+            return df
+    except Exception as e:
+        logger.error(f"读取CSV文件 {csv_file} 时发生错误: {str(e)}")
+        raise
+
+def load_dataframe_to_table(df, file_path, table_name):
+    """
+    将DataFrame数据加载到目标表
+    
+    参数:
+        df (pandas.DataFrame): 需要加载的数据
+        file_path (str): 源文件路径(仅用于日志记录)
+        table_name (str): 目标表名
+    
+    返回:
+        bool: 成功返回True,失败返回False
+    """
+    conn = None
+    cursor = None
+    try:
         # 清理列名中的潜在空白符
         df.columns = df.columns.str.strip()
 
-        # 如果CSV为空,则直接认为成功并返回
+        # 如果DataFrame为空,则直接认为成功并返回
         if df.empty:
-             logger.info(f"CSV 文件 {csv_file} 为空,无需加载数据。")
-             return True
+            logger.info(f"文件 {file_path} 为空,无需加载数据。")
+            return True
 
-        # 获取CSV列名 (清理后)
-        csv_headers = df.columns.tolist()
-        logger.info(f"清理后的 CSV 列名: {csv_headers}")
+        # 获取文件列名(清理后)
+        file_headers = df.columns.tolist()
+        logger.info(f"清理后的文件列名: {file_headers}")
         
         # 获取表结构
         table_columns = get_table_columns(table_name)
         if not table_columns:
-            logger.error(f"无法获取表 '{table_name}' 的列信息,跳过文件 {csv_file}")
+            logger.error(f"无法获取表 '{table_name}' 的列信息,跳过文件 {file_path}")
             return False
         
-        logger.info(f"表 '{table_name}' 的列信息 (列名: 注释): {table_columns}")
-        
-        # 匹配CSV列与表列
-        column_mapping = match_csv_columns(csv_headers, table_columns)
-        logger.info(f"列映射关系 (CSV列名: 表列名): {column_mapping}")
+        # 匹配文件列与表列
+        column_mapping = match_file_columns(file_headers, table_columns)
         
         # 检查是否有任何列成功映射
         if not column_mapping:
-            logger.error(f"文件 {csv_file} 的列无法与表 '{table_name}' 的列建立任何映射关系,跳过此文件。")
-            return False # 如果一个都没匹配上,则认为失败
+            logger.error(f"文件 {file_path} 的列无法与表 '{table_name}' 的列建立任何映射关系,跳过此文件。")
+            return False
         
         # 仅选择成功映射的列进行加载
-        mapped_csv_headers = list(column_mapping.keys())
-        # 使用 .copy() 避免 SettingWithCopyWarning
-        df_mapped = df[mapped_csv_headers].copy()
+        mapped_file_headers = list(column_mapping.keys())
+        # 避免SettingWithCopyWarning
+        df_mapped = df[mapped_file_headers].copy()
         df_mapped.rename(columns=column_mapping, inplace=True)
         logger.info(f"将加载以下映射后的列: {df_mapped.columns.tolist()}")
 
-        # 将空字符串 '' 替换为 None,以便插入数据库时为 NULL
-        # 使用 map 替代已废弃的 applymap 方法
-        # 对每一列单独应用 map 函数
+        # 将空字符串替换为None
         for col in df_mapped.columns:
             df_mapped[col] = df_mapped[col].map(lambda x: None if isinstance(x, str) and x == '' else x)
 
@@ -239,55 +268,75 @@ def load_csv_to_table(csv_file, table_name, execution_mode='append'):
         conn = get_pg_conn()
         cursor = conn.cursor()
         
-        # 根据执行模式确定操作 - 注意:full_refresh 在 run 函数层面控制,这里仅处理单个文件追加
-        # if execution_mode == 'full_refresh':
-        #     logger.warning(f"在 load_csv_to_table 中收到 full_refresh,但清空操作应在 run 函数完成。此处按 append 处理文件:{csv_file}")
-            # # 如果是全量刷新,先清空表 - 这个逻辑移到 run 函数
-            # logger.info(f"执行全量刷新,清空表 {table_name}")
-            # cursor.execute(f"TRUNCATE TABLE {table_name}")
-        
         # 构建INSERT语句
-        # 使用原始大小写的数据库列名(从 column_mapping 的 value 获取)并加引号
         columns = ', '.join([f'"{col}"' for col in df_mapped.columns])
         placeholders = ', '.join(['%s'] * len(df_mapped.columns))
-        # 假设表在 public schema,并为表名加引号以处理大小写或特殊字符
         insert_sql = f'INSERT INTO public."{table_name}" ({columns}) VALUES ({placeholders})'
         
         # 批量插入数据
-        # df_mapped.values 会产生 numpy array,需要转换为 list of tuples
-        # 确保 None 值正确传递
         rows = [tuple(row) for row in df_mapped.values]
         
         try:
             cursor.executemany(insert_sql, rows)
             conn.commit()
-            logger.info(f"成功将文件 {os.path.basename(csv_file)} 的 {len(rows)} 行数据插入到表 '{table_name}'")
+            logger.info(f"成功将文件 {os.path.basename(file_path)} 的 {len(rows)} 行数据插入到表 '{table_name}'")
+            return True
         except Exception as insert_err:
-             logger.error(f"向表 '{table_name}' 插入数据时出错: {str(insert_err)}")
-             logger.error(f"出错的 SQL 语句大致为: {insert_sql}")
-             # 可以考虑记录前几行出错的数据 (注意隐私和日志大小)
-             try:
-                 logger.error(f"出错的前3行数据 (部分): {rows[:3]}")
-             except: pass # 防御性编程
-             conn.rollback() # 回滚事务
-             return False # 插入失败则返回 False
-
-        
-        return True
-    except pd.errors.EmptyDataError:
-         logger.info(f"CSV 文件 {csv_file} 为空或只有表头,无需加载数据。")
-         return True # 空文件视为成功处理
+            logger.error(f"向表 '{table_name}' 插入数据时出错: {str(insert_err)}")
+            logger.error(f"出错的 SQL 语句大致为: {insert_sql}")
+            try:
+                logger.error(f"出错的前3行数据 (部分): {rows[:3]}")
+            except:
+                pass
+            conn.rollback()
+            return False
+            
     except Exception as e:
-        # 使用 exc_info=True 获取更详细的堆栈跟踪信息
-        logger.error(f"处理文件 {csv_file} 加载到表 '{table_name}' 时发生意外错误", exc_info=True)
+        logger.error(f"处理文件 {file_path} 加载到表 '{table_name}' 时发生意外错误", exc_info=True)
         if conn:
             conn.rollback()
         return False
     finally:
         if cursor:
-             cursor.close()
+            cursor.close()
         if conn:
-             conn.close()
+            conn.close()
+
+def load_file_to_table(file_path, table_name, execution_mode='append'):
+    """
+    根据文件类型,加载文件数据到目标表
+    
+    参数:
+        file_path (str): 文件路径
+        table_name (str): 目标表名
+        execution_mode (str): 执行模式,'append'或'full_refresh'
+    
+    返回:
+        bool: 成功返回True,失败返回False
+    """
+    logger.info(f"开始处理文件: {file_path}")
+    try:
+        file_extension = os.path.splitext(file_path)[1].lower()
+        
+        # 根据文件扩展名选择合适的加载方法
+        if file_extension == '.csv':
+            # CSV文件处理
+            df = read_csv_file(file_path)
+            return load_dataframe_to_table(df, file_path, table_name)
+        elif file_extension in ['.xlsx', '.xls']:
+            # Excel文件处理
+            df = read_excel_file(file_path)
+            return load_dataframe_to_table(df, file_path, table_name)
+        else:
+            logger.error(f"不支持的文件类型: {file_extension},文件: {file_path}")
+            return False
+            
+    except pd.errors.EmptyDataError:
+        logger.info(f"文件 {file_path} 为空或只有表头,无需加载数据。")
+        return True
+    except Exception as e:
+        logger.error(f"处理文件 {file_path} 时发生意外错误", exc_info=True)
+        return False
 
 def run(table_name, execution_mode='append', exec_date=None, target_type=None, 
         storage_location=None, frequency=None, script_name=None, **kwargs):
@@ -334,30 +383,56 @@ def run(table_name, execution_mode='append', exec_date=None, target_type=None,
         storage_location = storage_location.lstrip('/')
         logger.info(f"检测到storage_location以斜杠开头,已移除: {storage_location}")
     
-    full_search_pattern = os.path.normpath(os.path.join(STRUCTURE_UPLOAD_BASE_PATH, storage_location))
-    logger.info(f"完整文件搜索模式: {full_search_pattern}")
+    # 检查storage_location是否包含扩展名
+    has_extension = bool(re.search(r'\.[a-zA-Z0-9]+$', storage_location))
     
-    # 检查路径是否存在(至少目录部分)
-    search_dir = os.path.dirname(full_search_pattern)
-    if not os.path.exists(search_dir):
-        error_msg = f"错误: 搜索目录不存在: {search_dir}"
-        logger.error(error_msg)
-        raise FileNotFoundError(error_msg)  # 抛出异常而不是返回False
+    full_search_patterns = []
+    if has_extension:
+        # 如果指定了扩展名,使用原始模式
+        full_search_patterns.append(os.path.normpath(os.path.join(STRUCTURE_UPLOAD_BASE_PATH, storage_location)))
+    else:
+        # 如果没有指定扩展名,自动添加所有支持的扩展名
+        base_pattern = storage_location.rstrip('/')
+        if base_pattern.endswith('*'):
+            # 如果已经以*结尾,添加扩展名
+            for ext in ['.csv', '.xlsx', '.xls']:
+                pattern = os.path.normpath(os.path.join(STRUCTURE_UPLOAD_BASE_PATH, f"{base_pattern}{ext}"))
+                full_search_patterns.append(pattern)
+        else:
+            # 如果不以*结尾,添加/*.扩展名
+            if not base_pattern.endswith('/*'):
+                base_pattern = f"{base_pattern}/*"
+            for ext in ['.csv', '.xlsx', '.xls']:
+                pattern = os.path.normpath(os.path.join(STRUCTURE_UPLOAD_BASE_PATH, f"{base_pattern}{ext}"))
+                full_search_patterns.append(pattern)
+    
+    logger.info(f"完整文件搜索模式: {full_search_patterns}")
     
     # 查找匹配的文件
-    try:
-        # 增加 recursive=True 如果需要递归查找子目录中的文件 (例如 storage_location 是 a/b/**/*.csv)
-        # 当前假设模式只在指定目录下匹配,例如 /data/subdir/*.csv
-        found_files = glob.glob(full_search_pattern, recursive=False)
-    except Exception as glob_err:
-         logger.error(f"查找文件时发生错误 (模式: {full_search_pattern}): {str(glob_err)}")
-         raise  # 重新抛出异常
-
+    found_files = []
+    for search_pattern in full_search_patterns:
+        # 检查目录是否存在
+        search_dir = os.path.dirname(search_pattern)
+        if not os.path.exists(search_dir):
+            logger.warning(f"搜索目录不存在: {search_dir}")
+            continue
+            
+        try:
+            # 查找匹配的文件
+            matching_files = glob.glob(search_pattern, recursive=False)
+            if matching_files:
+                found_files.extend(matching_files)
+                logger.info(f"在模式 {search_pattern} 下找到 {len(matching_files)} 个文件")
+        except Exception as glob_err:
+            logger.error(f"查找文件时发生错误 (模式: {search_pattern}): {str(glob_err)}")
+            # 继续查找其他模式
+    
     if not found_files:
-        logger.warning(f"在目录 {search_dir} 下未找到匹配模式 '{os.path.basename(full_search_pattern)}' 的文件")
+        logger.warning(f"使用搜索模式 {full_search_patterns} 未找到任何匹配文件")
         return True  # 找不到文件视为正常情况,返回成功
 
-    logger.info(f"找到 {len(found_files)} 个匹配文件: {found_files}")
+    found_files = list(set(found_files))  # 去重
+    logger.info(f"总共找到 {len(found_files)} 个匹配文件: {found_files}")
 
     # 如果是全量刷新,在处理任何文件前清空表
     if execution_mode == 'full_refresh':
@@ -384,7 +459,6 @@ def run(table_name, execution_mode='append', exec_date=None, target_type=None,
 
     # 处理并归档每个找到的文件
     processed_files_count = 0
-    failed_files = []
 
     for file_path in found_files:
         file_start_time = datetime.now()
@@ -392,16 +466,15 @@ def run(table_name, execution_mode='append', exec_date=None, target_type=None,
         normalized_file_path = os.path.normpath(file_path)
         logger.info(f"--- 开始处理文件: {os.path.basename(normalized_file_path)} ---")
         try:
-            # 加载CSV数据到表 (注意:full_refresh时也是append模式加载,因为表已清空)
-            load_success = load_csv_to_table(normalized_file_path, table_name, 'append')
+            # 根据文件类型加载数据到表
+            load_success = load_file_to_table(normalized_file_path, table_name, 'append')
             
             if load_success:
                 logger.info(f"文件 {os.path.basename(normalized_file_path)} 加载成功。")
                 processed_files_count += 1
                 # 归档文件
                 try:
-                    # 计算相对路径部分 (storage_location 可能包含子目录)
-                    # 使用 os.path.dirname 获取 storage_location 的目录部分
+                    # 计算相对路径部分
                     relative_dir = os.path.dirname(storage_location)
                     
                     # 获取当前日期
@@ -434,8 +507,7 @@ def run(table_name, execution_mode='append', exec_date=None, target_type=None,
                     logger.error(f"原始文件路径: {normalized_file_path}")
             else:
                 logger.error(f"文件 {os.path.basename(normalized_file_path)} 加载失败,中止处理。")
-                # 修改:任何一个文件加载失败就直接返回 False
-                # 记录最终统计
+                # 任何一个文件加载失败就直接返回 False
                 overall_end_time = datetime.now()
                 overall_duration = (overall_end_time - overall_start_time).total_seconds()
                 logger.info(f"===== {script_name} 执行完成 (失败) =====")
@@ -447,8 +519,7 @@ def run(table_name, execution_mode='append', exec_date=None, target_type=None,
 
         except Exception as file_proc_err:
             logger.error(f"处理文件 {os.path.basename(normalized_file_path)} 时发生意外错误", exc_info=True)
-            # 修改:任何一个文件处理异常就直接返回 False
-            # 记录最终统计
+            # 任何一个文件处理异常就直接返回 False
             overall_end_time = datetime.now()
             overall_duration = (overall_end_time - overall_start_time).total_seconds()
             logger.info(f"===== {script_name} 执行完成 (异常) =====")
@@ -478,9 +549,9 @@ if __name__ == "__main__":
     # 直接执行时的测试代码
     import argparse
     
-    parser = argparse.ArgumentParser(description='从CSV文件加载数据到表(支持通配符)')
+    parser = argparse.ArgumentParser(description='从CSV或Excel文件加载数据到表(支持通配符)')
     parser.add_argument('--table', type=str, required=True, help='目标表名')
-    parser.add_argument('--pattern', type=str, required=True, help='CSV文件查找模式 (相对于基准上传路径的相对路径,例如: data/*.csv 或 *.csv)')
+    parser.add_argument('--pattern', type=str, required=True, help='文件查找模式 (相对于基准上传路径的相对路径,例如: data/*.csv 或 data/*.xlsx 或 data/*)')
     parser.add_argument('--mode', type=str, default='append', choices=['append', 'full_refresh'], help='执行模式: append 或 full_refresh')
     
     args = parser.parse_args()
@@ -501,8 +572,8 @@ if __name__ == "__main__":
     success = run(**run_kwargs)
     
     if success:
-        print("CSV文件加载任务执行完毕,所有文件处理成功。")
+        print("文件加载任务执行完毕,所有文件处理成功。")
         sys.exit(0)
     else:
-        print("CSV文件加载任务执行完毕,但有部分或全部文件处理失败。")
+        print("文件加载任务执行完毕,但有部分或全部文件处理失败。")
         sys.exit(1)

+ 3 - 0
requirements.txt

@@ -4,3 +4,6 @@ neo4j>=5.19.0
 pendulum>=3.0.0
 networkx>=3.4.2
 pandas>=2.2.3
+xlrd>=2.0.1
+openpyxl>=3.1.5
+