#!/usr/bin/env python
# -*- coding: utf-8 -*-
import csv
import logging
import os
import sys
from datetime import datetime

import pandas as pd
import psycopg2
from psycopg2 import sql

from dags.config import PG_CONFIG
# Configure logging: INFO level, timestamped records, emitted on stdout.
logging.basicConfig(
    handlers=[logging.StreamHandler(stream=sys.stdout)],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-wide logger used by every function in this script.
logger = logging.getLogger("load_file")
def get_pg_conn():
    """Open and return a new PostgreSQL connection.

    The connection parameters are taken from ``PG_CONFIG``, which is
    expanded into keyword arguments for ``psycopg2.connect``.
    """
    connection = psycopg2.connect(**PG_CONFIG)
    return connection
def get_table_columns(table_name):
    """Fetch the column names of a table together with their comments.

    The lookup goes through ``information_schema.columns`` filtered by
    table name only (no schema filter), ordered by ordinal position.

    Args:
        table_name (str): Name of the table to inspect.

    Returns:
        dict: ``{column_name: column_comment}``. A column without a comment
        maps to its own name. An empty dict is returned when the query
        fails (the error is logged) or when no column matches.
    """
    connection = get_pg_conn()
    cur = connection.cursor()
    column_query = """
        SELECT
            column_name,
            col_description((table_schema || '.' || table_name)::regclass::oid, ordinal_position) as column_comment
        FROM
            information_schema.columns
        WHERE
            table_name = %s
        ORDER BY
            ordinal_position
    """
    try:
        cur.execute(column_query, (table_name,))
        # An empty/NULL comment falls back to the column name itself.
        return {name: (comment or name) for name, comment in cur.fetchall()}
    except Exception as e:
        logger.error(f"获取表 {table_name} 的列信息时出错: {str(e)}")
        return {}
    finally:
        # Always release the cursor and the connection, success or failure.
        cur.close()
        connection.close()
def match_csv_columns(csv_headers, table_columns):
    """Map CSV header names onto table column names.

    Each header is resolved with the following priority:

    1. the header equals a column's comment -> that column;
    2. the header equals a column name -> that column.

    Headers matching neither rule are left out of the result. Every table
    column is assigned to at most one header: when several headers resolve
    to the same column, the first one (in CSV order) wins and the others are
    skipped with a warning. Without this, the generated INSERT would name the
    same column twice and PostgreSQL would reject the whole statement.

    Args:
        csv_headers (list): Header names read from the CSV file.
        table_columns (dict): ``{column_name: column_comment}``.

    Returns:
        dict: ``{csv_header: table_column_name}`` in CSV header order.
    """
    # Reverse lookup: comment text -> column name.
    comment_to_column = {comment: col for col, comment in table_columns.items()}

    mapping = {}
    claimed_columns = set()
    for header in csv_headers:
        if header in mapping:
            # Same header seen again; it is already mapped.
            continue

        if header in comment_to_column:
            target = comment_to_column[header]
        elif header in table_columns:
            target = header
        else:
            continue

        if target in claimed_columns:
            # A previous header already feeds this column; keep the first one.
            logging.getLogger("load_file").warning(
                f"CSV列 {header} 与其他列映射到同一表列 {target},已忽略该列"
            )
            continue

        claimed_columns.add(target)
        mapping[header] = target

    return mapping
def load_csv_to_table(csv_file, table_name, execution_mode='append'):
    """Load the rows of a CSV file into a PostgreSQL table.

    The CSV is read with pandas, its headers are matched against the target
    table's column comments/names, and the matched columns are inserted in a
    single transaction. In ``'full_refresh'`` mode the table is truncated
    first, inside the same transaction, so a failed insert also rolls back
    the truncate. Any other ``execution_mode`` value behaves like
    ``'append'``.

    Args:
        csv_file (str): Path of the CSV file to read.
        table_name (str): Name of the target table.
        execution_mode (str): ``'append'`` or ``'full_refresh'``.

    Returns:
        bool: True when the rows were committed; False on any failure
        (the error is logged and the transaction is rolled back).
    """
    conn = None
    try:
        # Try encodings from most to least specific; latin1 can decode any
        # byte sequence, so the last attempt never raises a decode error.
        try:
            df = pd.read_csv(csv_file, encoding='utf-8')
        except UnicodeDecodeError:
            try:
                df = pd.read_csv(csv_file, encoding='gbk')
            except UnicodeDecodeError:
                df = pd.read_csv(csv_file, encoding='latin1')

        logger.info(f"成功读取CSV文件: {csv_file}, 共 {len(df)} 行")

        csv_headers = df.columns.tolist()
        logger.info(f"CSV列名: {csv_headers}")

        table_columns = get_table_columns(table_name)
        if not table_columns:
            logger.error(f"无法获取表 {table_name} 的列信息")
            return False

        logger.info(f"表 {table_name} 的列信息: {table_columns}")

        column_mapping = match_csv_columns(csv_headers, table_columns)
        logger.info(f"列映射关系: {column_mapping}")

        if not column_mapping:
            logger.error("无法建立CSV列与表列的映射关系")
            return False

        # Keep only the matched CSV columns, renamed to table column names.
        df_mapped = df[list(column_mapping.keys())].rename(columns=column_mapping)
        target_columns = list(df_mapped.columns)

        # Build plain-Python rows: casting to object avoids numpy scalar
        # types that psycopg2 cannot adapt, and missing cells (NaN) become
        # None so they are stored as SQL NULL instead of a float NaN.
        rows = [
            tuple(None if pd.isna(value) else value for value in record)
            for record in df_mapped.astype(object).itertuples(index=False, name=None)
        ]

        # Compose the statements with quoted identifiers rather than string
        # interpolation. The names were just confirmed against
        # information_schema, so exact (quoted) matching is correct.
        table_ident = sql.Identifier(table_name)
        insert_sql = sql.SQL("INSERT INTO {} ({}) VALUES ({})").format(
            table_ident,
            sql.SQL(', ').join(sql.Identifier(col) for col in target_columns),
            sql.SQL(', ').join(sql.Placeholder() * len(target_columns)),
        )

        conn = get_pg_conn()
        with conn.cursor() as cursor:
            if execution_mode == 'full_refresh':
                logger.info(f"执行全量刷新,清空表 {table_name}")
                cursor.execute(sql.SQL("TRUNCATE TABLE {}").format(table_ident))

            cursor.executemany(insert_sql, rows)

        conn.commit()
        logger.info(f"成功插入 {len(rows)} 行数据到表 {table_name}")

        return True
    except Exception as e:
        logger.error(f"加载CSV数据到表时出错: {str(e)}")
        if conn:
            conn.rollback()
        return False
    finally:
        if conn:
            conn.close()
def run(table_name, execution_mode='append', exec_date=None, target_type=None,
        storage_location=None, frequency=None, **kwargs):
    """Unified entry point for loading a CSV file into a table.

    Logs all received parameters, validates that a CSV path was given and
    exists, then delegates to ``load_csv_to_table`` and logs the elapsed
    time.

    Args:
        table_name (str): Name of the table to load into.
        execution_mode (str): Execution mode (append/full_refresh).
        exec_date: Execution date (only logged here).
        target_type: Target type (only logged here).
        storage_location: Path of the CSV file.
        frequency: Update frequency (only logged here).
        **kwargs: Any other parameters; each one is logged.

    Returns:
        bool: True when the load succeeded, False otherwise.
    """
    logger.info("===== 开始执行CSV文件加载 =====")
    known_params = (
        ("表名", table_name),
        ("执行模式", execution_mode),
        ("执行日期", exec_date),
        ("目标类型", target_type),
        ("文件路径", storage_location),
        ("更新频率", frequency),
    )
    for label, value in known_params:
        logger.info(f"{label}: {value}")

    # Log whatever extra parameters the caller passed along.
    for extra_name, extra_value in kwargs.items():
        logger.info(f"其他参数 - {extra_name}: {extra_value}")

    # Guard clauses: a path must be supplied and must exist.
    if not storage_location:
        logger.error("未提供CSV文件路径")
        return False
    if not os.path.exists(storage_location):
        logger.error(f"CSV文件不存在: {storage_location}")
        return False

    started_at = datetime.now()

    def elapsed_seconds():
        # Seconds elapsed since the load started.
        return (datetime.now() - started_at).total_seconds()

    try:
        succeeded = load_csv_to_table(storage_location, table_name, execution_mode)
        seconds = elapsed_seconds()

        if not succeeded:
            logger.error(f"CSV文件加载失败,耗时: {seconds:.2f}秒")
        else:
            logger.info(f"CSV文件加载成功,耗时: {seconds:.2f}秒")

        return succeeded
    except Exception as exc:
        seconds = elapsed_seconds()

        logger.error(f"CSV文件加载过程中出错: {str(exc)}")
        logger.error(f"CSV文件加载失败,耗时: {seconds:.2f}秒")

        return False
    finally:
        logger.info("===== CSV文件加载执行完成 =====")
if __name__ == "__main__":
    # Command-line entry point for running the loader directly.
    import argparse

    cli = argparse.ArgumentParser(description='从CSV文件加载数据到表')
    cli.add_argument('--table', type=str, required=True, help='目标表名')
    cli.add_argument('--file', type=str, required=True, help='CSV文件路径')
    cli.add_argument('--mode', type=str, default='append', help='执行模式: append或full_refresh')

    cli_args = cli.parse_args()

    succeeded = run(
        table_name=cli_args.table,
        execution_mode=cli_args.mode,
        storage_location=cli_args.file,
        target_type='structure',
    )

    # Report the outcome and mirror it in the process exit status.
    message, exit_code = ("CSV文件加载成功", 0) if succeeded else ("CSV文件加载失败", 1)
    print(message)
    sys.exit(exit_code)
|