load_file.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import logging
import sys
import os
import pandas as pd
import psycopg2
from datetime import datetime

from dags.config import PG_CONFIG

# Configure the logger
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger("load_file")


def get_pg_conn():
    """Get a PostgreSQL connection."""
    return psycopg2.connect(**PG_CONFIG)
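

# Note: PG_CONFIG is assumed to be a dict of psycopg2.connect() keyword
# arguments defined in dags/config.py, e.g. (illustrative values only):
#   PG_CONFIG = {"host": "localhost", "port": 5432, "dbname": "etl",
#                "user": "etl_user", "password": "..."}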


def get_table_columns(table_name):
    """
    Get a table's column information, including column names and comments.

    Returns:
        dict: a {column_name: column_comment} mapping
    """
    conn = get_pg_conn()
    cursor = conn.cursor()
    try:
        # Query column metadata for the table
        cursor.execute("""
            SELECT
                column_name,
                col_description((table_schema || '.' || table_name)::regclass::oid, ordinal_position) AS column_comment
            FROM
                information_schema.columns
            WHERE
                table_name = %s
            ORDER BY
                ordinal_position
        """, (table_name,))

        columns = {}
        for row in cursor.fetchall():
            col_name = row[0]
            col_comment = row[1] if row[1] else col_name  # Fall back to the column name when the comment is empty
            columns[col_name] = col_comment
        return columns
    except Exception as e:
        logger.error(f"Error fetching column information for table {table_name}: {str(e)}")
        return {}
    finally:
        cursor.close()
        conn.close()
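

# Illustrative example (hypothetical table): for a table created as
#   CREATE TABLE ods_user (user_id int, age int);
#   COMMENT ON COLUMN ods_user.user_id IS '用户ID';
# get_table_columns("ods_user") would return
#   {"user_id": "用户ID", "age": "age"}   # "age" has no comment, so the name is used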


def match_csv_columns(csv_headers, table_columns):
    """
    Match CSV header names to table column names.

    Strategy:
    1. Try to match a CSV header against the table columns' comments.
    2. Fall back to a direct match on the column name.

    Args:
        csv_headers (list): header names from the CSV file
        table_columns (dict): a {column_name: column_comment} mapping

    Returns:
        dict: a {csv_header: table_column} mapping
    """
    mapping = {}
    # Invert the mapping so comments can be looked up directly
    comment_to_column = {comment: col for col, comment in table_columns.items()}
    for header in csv_headers:
        # Match by column comment first
        if header in comment_to_column:
            mapping[header] = comment_to_column[header]
            continue
        # Fall back to a direct name match
        if header in table_columns:
            mapping[header] = header
    return mapping
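

# Illustrative example, reusing the hypothetical ods_user table above:
#   match_csv_columns(["用户ID", "age", "extra"], {"user_id": "用户ID", "age": "age"})
# returns {"用户ID": "user_id", "age": "age"}; headers with no comment
# match and no name match (here "extra") are silently dropped from the load.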


def load_csv_to_table(csv_file, table_name, execution_mode='append'):
    """
    Load data from a CSV file into the target table.

    Args:
        csv_file (str): path to the CSV file
        table_name (str): target table name
        execution_mode (str): 'append' or 'full_refresh'

    Returns:
        bool: True on success, False on failure
    """
    conn = None
    try:
        # Read the CSV file, falling back through a list of likely encodings
        try:
            df = pd.read_csv(csv_file, encoding='utf-8')
        except UnicodeDecodeError:
            try:
                df = pd.read_csv(csv_file, encoding='gbk')
            except UnicodeDecodeError:
                df = pd.read_csv(csv_file, encoding='latin1')
        logger.info(f"Read CSV file {csv_file}: {len(df)} rows")

        # Get the CSV header names
        csv_headers = df.columns.tolist()
        logger.info(f"CSV headers: {csv_headers}")

        # Get the table structure
        table_columns = get_table_columns(table_name)
        if not table_columns:
            logger.error(f"Could not fetch column information for table {table_name}")
            return False
        logger.info(f"Columns of table {table_name}: {table_columns}")

        # Match CSV headers to table columns
        column_mapping = match_csv_columns(csv_headers, table_columns)
        logger.info(f"Column mapping: {column_mapping}")
        if not column_mapping:
            logger.error("Could not map any CSV header to a table column")
            return False

        # Keep only the matched columns and rename them to the table's column names
        df_mapped = df[list(column_mapping.keys())].rename(columns=column_mapping)
        # Replace NaN with None so that missing values are inserted as SQL NULL
        df_mapped = df_mapped.astype(object).where(df_mapped.notnull(), None)

        # Connect to the database
        conn = get_pg_conn()
        cursor = conn.cursor()

        # Choose the operation based on the execution mode
        if execution_mode == 'full_refresh':
            # Full refresh: truncate the table first. Note that table_name is
            # interpolated directly into the SQL, so it must come from trusted
            # metadata, never from user input.
            logger.info(f"Full refresh: truncating table {table_name}")
            cursor.execute(f"TRUNCATE TABLE {table_name}")

        # Build the INSERT statement
        columns = ', '.join(df_mapped.columns)
        placeholders = ', '.join(['%s'] * len(df_mapped.columns))
        insert_sql = f"INSERT INTO {table_name} ({columns}) VALUES ({placeholders})"

        # Insert the rows in bulk
        rows = [tuple(row) for row in df_mapped.values]
        cursor.executemany(insert_sql, rows)

        # Commit the transaction
        conn.commit()
        logger.info(f"Inserted {len(rows)} rows into table {table_name}")
        return True
    except Exception as e:
        logger.error(f"Error loading CSV data into table: {str(e)}")
        if conn:
            conn.rollback()
        return False
    finally:
        if conn:
            conn.close()
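

# Performance note: executemany() issues one INSERT per row, which can be
# slow for large files. A faster variant (sketch only, reusing the same
# cursor, insert columns, and rows as above) would use
# psycopg2.extras.execute_values:
#
#   from psycopg2.extras import execute_values
#   execute_values(cursor, f"INSERT INTO {table_name} ({columns}) VALUES %s", rows)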


def run(table_name, execution_mode='append', exec_date=None, target_type=None,
        storage_location=None, frequency=None, **kwargs):
    """
    Unified entry point, following the calling convention for dynamic Airflow scripts.

    Args:
        table_name (str): name of the table to process
        execution_mode (str): execution mode (append/full_refresh)
        exec_date: execution date
        target_type: target type; should be 'structure' for CSV files
        storage_location: path to the CSV file
        frequency: update frequency
        **kwargs: any additional parameters

    Returns:
        bool: True on success, False otherwise
    """
    logger.info("===== Starting CSV file load =====")
    logger.info(f"Table name: {table_name}")
    logger.info(f"Execution mode: {execution_mode}")
    logger.info(f"Execution date: {exec_date}")
    logger.info(f"Target type: {target_type}")
    logger.info(f"File path: {storage_location}")
    logger.info(f"Update frequency: {frequency}")
    # Log any additional parameters
    for key, value in kwargs.items():
        logger.info(f"Additional parameter - {key}: {value}")

    # Check required parameters
    if not storage_location:
        logger.error("No CSV file path provided")
        return False

    # Check that the file exists
    if not os.path.exists(storage_location):
        logger.error(f"CSV file does not exist: {storage_location}")
        return False

    # Record the start time
    start_time = datetime.now()
    try:
        # Load the CSV data into the table
        result = load_csv_to_table(storage_location, table_name, execution_mode)

        # Record the end time
        end_time = datetime.now()
        duration = (end_time - start_time).total_seconds()
        if result:
            logger.info(f"CSV file loaded successfully in {duration:.2f}s")
        else:
            logger.error(f"CSV file load failed after {duration:.2f}s")
        return result
    except Exception as e:
        # Record the end time
        end_time = datetime.now()
        duration = (end_time - start_time).total_seconds()
        logger.error(f"Error during CSV file load: {str(e)}")
        logger.error(f"CSV file load failed after {duration:.2f}s")
        return False
    finally:
        logger.info("===== CSV file load finished =====")
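

# Illustrative Airflow-style invocation (hypothetical table and path):
#   run(table_name="ods_user", execution_mode="full_refresh",
#       storage_location="/data/uploads/ods_user.csv", target_type="structure")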


if __name__ == "__main__":
    # Test code for direct execution
    import argparse

    parser = argparse.ArgumentParser(description='Load data from a CSV file into a table')
    parser.add_argument('--table', type=str, required=True, help='Target table name')
    parser.add_argument('--file', type=str, required=True, help='Path to the CSV file')
    parser.add_argument('--mode', type=str, default='append', help="Execution mode: 'append' or 'full_refresh'")
    args = parser.parse_args()

    success = run(
        table_name=args.table,
        execution_mode=args.mode,
        storage_location=args.file,
        target_type='structure'
    )

    if success:
        print("CSV file loaded successfully")
        sys.exit(0)
    else:
        print("CSV file load failed")
        sys.exit(1)