load_data.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import logging
import sys
import os
import traceback
from datetime import datetime

import pandas as pd
from sqlalchemy import create_engine, inspect, Table, Column, MetaData, text, TIMESTAMP

# Extend the Python import path so the script_utils module in the same
# directory can be found.
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)

# Import the script utility module.
try:
    import script_utils
    from script_utils import get_pg_config, get_date_range, get_one_day_range, get_neo4j_driver, get_target_dt_column
    logger_utils = logging.getLogger("script_utils")
except ImportError as e:
    logger_utils = None
    print(f"Failed to import the script_utils module: {str(e)}")

# Configure the logger.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger("load_data")
def get_source_database_info(table_name, script_name=None):
    """
    Fetch the source database connection info for a table from Neo4j.

    Args:
        table_name (str): Table name
        script_name (str, optional): Script name

    Returns:
        dict: Database connection info

    Raises:
        Exception: If the connection info cannot be retrieved
    """
    logger.info(f"Fetching source database connection info for table {table_name}")
    driver = get_neo4j_driver()
    try:
        with driver.session() as session:
            # First, look up the DataSource node the table points to.
            query = """
                MATCH (target {en_name: $table_name})-[rel]->(source:DataSource)
                RETURN source.en_name AS source_name,
                       source.database AS database,
                       source.host AS host,
                       source.port AS port,
                       source.username AS username,
                       source.password AS password,
                       source.type AS db_type,
                       source.schema AS schema,
                       source.table AS source_table,
                       labels(source) AS labels
            """
            result = session.run(query, table_name=table_name)
            record = result.single()
            if not record:
                error_msg = f"No source node found for table {table_name}"
                logger.error(error_msg)
                raise Exception(error_msg)

            # Collect the connection info carried by the source node.
            database_info = {
                "source_name": record.get("source_name"),
                "database": record.get("database"),
                "host": record.get("host"),
                "port": record.get("port"),
                "username": record.get("username"),
                "password": record.get("password"),
                "db_type": record.get("db_type"),
                "schema": record.get("schema"),
                "source_table": record.get("source_table"),
                "labels": record.get("labels", [])
            }

            # Make sure the node actually has database connection info.
            if not database_info.get("database"):
                error_msg = f"Source node {database_info['source_name']} has no database connection info"
                logger.error(error_msg)
                raise Exception(error_msg)

            logger.info(f"Got source database connection info for table {table_name}: {database_info['host']}:{database_info['port']}/{database_info['database']}")
            return database_info
    except Exception as e:
        logger.error(f"Error while fetching source database connection info: {str(e)}")
        logger.error(traceback.format_exc())
        raise Exception(f"Failed to fetch source database connection info: {str(e)}")
    finally:
        driver.close()
def get_target_database_info():
    """
    Fetch the target database connection info.

    Returns:
        dict: Database connection info

    Raises:
        Exception: If the target connection info cannot be retrieved
    """
    logger.info("Fetching target database connection info")
    try:
        # Try to get PG_CONFIG via script_utils.
        pg_config = get_pg_config()
        if not pg_config:
            raise ValueError("Unable to get the PG_CONFIG configuration")

        # Check the required configuration keys.
        required_keys = ["host", "port", "user", "password", "database"]
        missing_keys = [key for key in required_keys if key not in pg_config]
        if missing_keys:
            raise ValueError(f"PG_CONFIG is missing required keys: {', '.join(missing_keys)}")

        # Build the connection info.
        database_info = {
            "host": pg_config.get("host"),
            "port": pg_config.get("port"),
            "username": pg_config.get("user"),
            "password": pg_config.get("password"),
            "database": pg_config.get("database"),
            "db_type": "postgresql"
        }

        # Carry the schema over if the config provides one.
        if "schema" in pg_config and pg_config["schema"]:
            database_info["schema"] = pg_config["schema"]

        logger.info(f"Got target database connection info: {database_info['host']}:{database_info['port']}/{database_info['database']}")
        return database_info
    except Exception as e:
        logger.error(f"Error while fetching target database connection info: {str(e)}")
        logger.error(traceback.format_exc())
        raise  # Re-raise directly; deliberately no fallback connection info.
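
# For reference, a PG_CONFIG shape that satisfies the checks above (values are
# hypothetical; get_pg_config() is provided by script_utils):
#
#   PG_CONFIG = {
#       "host": "127.0.0.1",
#       "port": 5432,
#       "user": "etl_user",
#       "password": "******",
#       "database": "dataops",
#       "schema": "ods",      # optional
#   }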
def get_sqlalchemy_engine(db_info):
    """
    Create a SQLAlchemy engine from database connection info.

    Args:
        db_info (dict): Database connection info

    Returns:
        Engine: SQLAlchemy engine object, or None on failure
    """
    if not db_info:
        logger.error("Connection info is empty; cannot create a SQLAlchemy engine")
        return None
    try:
        db_type = db_info.get("db_type", "").lower()
        if db_type == "postgresql":
            url = f"postgresql+psycopg2://{db_info['username']}:{db_info['password']}@{db_info['host']}:{db_info['port']}/{db_info['database']}"
        elif db_type == "mysql":
            url = f"mysql+pymysql://{db_info['username']}:{db_info['password']}@{db_info['host']}:{db_info['port']}/{db_info['database']}"
        elif db_type == "oracle":
            url = f"oracle+cx_oracle://{db_info['username']}:{db_info['password']}@{db_info['host']}:{db_info['port']}/{db_info['database']}"
        elif db_type == "mssql":
            url = f"mssql+pymssql://{db_info['username']}:{db_info['password']}@{db_info['host']}:{db_info['port']}/{db_info['database']}"
        else:
            logger.error(f"Unsupported database type: {db_type}")
            return None

        # Create the engine and let SQLAlchemy handle dialect differences.
        engine = create_engine(url)
        return engine
    except Exception as e:
        logger.error(f"Error while creating the SQLAlchemy engine: {str(e)}")
        logger.error(traceback.format_exc())
        return None
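
# Example of the URL this builds (hypothetical credentials):
#
#   get_sqlalchemy_engine({
#       "db_type": "mysql", "username": "u", "password": "p",
#       "host": "db.internal", "port": 3306, "database": "erp",
#   })
#   # -> create_engine("mysql+pymysql://u:p@db.internal:3306/erp")
#
# Note: a password containing URL-reserved characters (e.g. "@" or "/") would
# need escaping, e.g. urllib.parse.quote_plus, before being embedded in the URL.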
def create_table_if_not_exists(source_engine, target_engine, source_table, target_table, schema=None):
    """
    Create the target table from the source table's structure if it does not exist.

    Args:
        source_engine: Source database engine
        target_engine: Target database engine
        source_table: Source table name
        target_table: Target table name
        schema: Schema name; falls back to "ods" if None or empty

    Returns:
        bool: True on success

    Raises:
        Exception: If the source table does not exist or the target table cannot be created
    """
    logger.info(f"Checking whether target table {target_table} exists; it will be created if not")
    try:
        # Normalize the schema argument.
        if schema == "" or schema is None:
            # Fall back to "ods" when no schema was passed.
            schema = "ods"
            logger.info(f"Schema argument is empty; using default schema: {schema}")
        else:
            # Use the schema passed in.
            if schema != "ods":
                logger.warning(f"Using non-standard schema: {schema}; 'ods' is recommended as the target schema")
            logger.info(f"Using the schema passed in: {schema}")

        table_display_name = f"{schema}.{target_table}"
        logger.info(f"Fully qualified target table name: {table_display_name}")

        # Check whether the target table already exists.
        target_inspector = inspect(target_engine)
        target_exists = target_inspector.has_table(target_table, schema=schema)
        if target_exists:
            logger.info(f"Target table {table_display_name} already exists; nothing to create")
            return True

        # Target table is missing; read the structure from the source table.
        source_inspector = inspect(source_engine)

        # The source table name may carry its own schema prefix.
        source_schema = None
        if '.' in source_table:
            source_schema, source_table = source_table.split('.', 1)

        if not source_inspector.has_table(source_table, schema=source_schema):
            error_msg = f"Source table {source_table} does not exist"
            if source_schema:
                error_msg = f"Source table {source_schema}.{source_table} does not exist"
            logger.error(error_msg)
            raise Exception(error_msg)

        # Read the source table's column definitions.
        source_columns = source_inspector.get_columns(source_table, schema=source_schema)
        if not source_columns:
            error_msg = f"Source table {source_table} has no column info"
            logger.error(error_msg)
            raise Exception(error_msg)

        # Create the metadata object.
        metadata = MetaData()

        # Check whether the source already has create_time / update_time columns.
        existing_column_names = [col['name'].lower() for col in source_columns]
        has_create_time = 'create_time' in existing_column_names
        has_update_time = 'update_time' in existing_column_names

        # Build the column definitions.
        columns = [Column(col['name'], col['type']) for col in source_columns]

        # Add create_time if it is missing.
        if not has_create_time:
            columns.append(Column('create_time', TIMESTAMP, nullable=True))
            logger.info(f"Adding create_time column to table {table_display_name}")

        # Add update_time if it is missing.
        if not has_update_time:
            columns.append(Column('update_time', TIMESTAMP, nullable=True))
            logger.info(f"Adding update_time column to table {table_display_name}")

        # Define the target table and let SQLAlchemy map the data types.
        # At this point schema is always set (at least "ods").
        Table(
            target_table,
            metadata,
            *columns,
            schema=schema
        )

        # Create the table in the target database.
        metadata.create_all(target_engine)
        logger.info(f"Created table {table_display_name} in the target database")
        return True
    except Exception as e:
        logger.error(f"Error while creating the table: {str(e)}")
        logger.error(traceback.format_exc())
        raise Exception(f"Failed to create the table: {str(e)}")
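
# A minimal sketch of how the loader below calls this (engines and table names
# are hypothetical; the target schema is forced to "ods"):
#
#   create_table_if_not_exists(src_engine, tgt_engine,
#                              "public.sales_order", "ods_sales_order", "ods")
#   # -> creates ods.ods_sales_order with the source columns, plus nullable
#   #    create_time / update_time TIMESTAMP columns if they were missing.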
def load_data_from_source(table_name, exec_date=None, update_mode=None, script_name=None,
                          schedule_frequency=None, is_manual_dag_trigger=False, **kwargs):
    """
    Load data from the source database into the target database.
    """
    start_time = datetime.now()
    logger.info("===== Starting data load from source =====")
    logger.info(f"Table name: {table_name}")
    logger.info(f"Execution date: {exec_date}")
    logger.info(f"Update mode: {update_mode}")
    logger.info(f"Script name: {script_name}")
    logger.info(f"Schedule frequency: {schedule_frequency}")
    logger.info(f"Manual DAG trigger: {is_manual_dag_trigger}")
    for key, value in kwargs.items():
        logger.info(f"Other argument - {key}: {value}")

    if exec_date is None:
        exec_date = datetime.now().strftime('%Y-%m-%d')
        logger.info(f"Execution date is empty; using today's date: {exec_date}")
    if schedule_frequency is None:
        schedule_frequency = "daily"
        logger.info(f"Schedule frequency is empty; using default: {schedule_frequency}")

    try:
        # Fetch source and target database connection info.
        source_db_info = get_source_database_info(table_name, script_name)
        target_db_info = get_target_database_info()

        # Create the database engines.
        source_engine = get_sqlalchemy_engine(source_db_info)
        target_engine = get_sqlalchemy_engine(target_db_info)
        if not source_engine or not target_engine:
            raise Exception("Unable to create database engines; cannot load data")

        # Resolve the source table name and schema.
        source_table = source_db_info.get("source_table", table_name) or table_name
        source_schema = source_db_info.get("schema")

        # Build the fully qualified source table name.
        if source_schema:
            full_source_table_name = f"{source_schema}.{source_table}"
        else:
            full_source_table_name = source_table
        logger.info(f"Fully qualified source table name: {full_source_table_name}")

        # Resolve the target schema.
        target_schema = target_db_info.get("schema")

        # Build the fully qualified target table name.
        if target_schema:
            full_table_name = f"{target_schema}.{table_name}"
        else:
            full_table_name = table_name
        logger.info(f"Fully qualified target table name: {full_table_name}")

        # Make sure the target table exists; create_table_if_not_exists must use "ods" as the schema.
        if not create_table_if_not_exists(source_engine, target_engine, full_source_table_name, table_name, "ods"):
            raise Exception(f"Unable to create target table {full_table_name}; cannot load data")

        # Handle the data according to the update mode.
        if update_mode == "full_refresh":
            # Full refresh: truncate the table first.
            logger.info(f"Performing a full refresh; truncating table {full_table_name}")
            with target_engine.begin() as conn:  # begin() manages the transaction
                conn.execute(text(f"TRUNCATE TABLE {full_table_name}"))
                logger.info(f"Truncated table {full_table_name}")

            # Build the full-load query.
            query = f"SELECT * FROM {full_source_table_name}"
        else:
            # Incremental update: we need the target date column and a date range.
            target_dt_column = get_target_dt_column(table_name, script_name)
            if not target_dt_column:
                logger.error(f"Unable to get the target date column for table {table_name}; cannot run an incremental load")
                return False

            try:
                # The date range depends on whether the DAG was triggered manually.
                if is_manual_dag_trigger:
                    # Manual trigger
                    start_date, end_date = get_date_range(exec_date, schedule_frequency)
                    logger.info(f"Manual DAG trigger; date range: {start_date} to {end_date}")

                    # Delete the rows already in that range.
                    delete_sql = f"""
                        DELETE FROM {full_table_name}
                        WHERE {target_dt_column} >= '{start_date}'
                        AND {target_dt_column} < '{end_date}'
                    """
                    with target_engine.begin() as conn:  # begin() manages the transaction
                        conn.execute(text(delete_sql))
                        logger.info(f"Deleted rows from {full_table_name} where {target_dt_column} is between {start_date} and {end_date}")
                else:
                    # Automatic scheduling
                    start_datetime, end_datetime = get_one_day_range(exec_date)
                    start_date = start_datetime.strftime('%Y-%m-%d %H:%M:%S')
                    end_date = end_datetime.strftime('%Y-%m-%d %H:%M:%S')
                    logger.info(f"Automatic schedule; date range: {start_date} to {end_date}")

                    # Delete the rows already in that range, keyed by load time.
                    delete_sql = f"""
                        DELETE FROM {full_table_name}
                        WHERE create_time >= '{start_date}'
                        AND create_time < '{end_date}'
                    """
                    try:
                        with target_engine.begin() as conn:  # begin() manages the transaction
                            conn.execute(text(delete_sql))
                            logger.info(f"Deleted rows from {full_table_name} where create_time is between {start_date} and {end_date}")
                    except Exception as del_err:
                        logger.error(f"Error while deleting data: {str(del_err)}")
                        logger.warning("Continuing with the data load")

                # Build the incremental query.
                if not is_manual_dag_trigger:
                    # For automatic scheduling, recompute the date range for the query.
                    start_date, end_date = get_date_range(exec_date, schedule_frequency)

                # Check whether the source table has the target date column.
                source_inspector = inspect(source_engine)
                if source_schema:
                    source_columns = [col['name'].lower() for col in source_inspector.get_columns(source_table, schema=source_schema)]
                else:
                    source_columns = [col['name'].lower() for col in source_inspector.get_columns(source_table)]

                if target_dt_column.lower() in source_columns:
                    # The source table has the date column; filter on it.
                    query = f"""
                        SELECT * FROM {full_source_table_name}
                        WHERE {target_dt_column} >= '{start_date}'
                        AND {target_dt_column} < '{end_date}'
                    """
                else:
                    # No target date column in the source table; fall back to a full load.
                    logger.warning(f"Source table {full_source_table_name} has no target date column {target_dt_column}; loading all data")
                    query = f"SELECT * FROM {full_source_table_name}"
            except Exception as date_err:
                logger.error(f"Error while computing the date range: {str(date_err)}")
                logger.error(traceback.format_exc())
                return False

        # Run the query and load the data.
        logger.info(f"Running query: {query}")
        try:
            # Run the query with SQLAlchemy, then build the DataFrame by hand.
            rows = []
            column_names = []
            with source_engine.connect() as connection:
                result_proxy = connection.execute(text(query))
                rows = result_proxy.fetchall()
                column_names = list(result_proxy.keys())
            df = pd.DataFrame(rows, columns=column_names)

            # Nothing to do for an empty result set.
            if df.empty:
                logger.warning("The query returned no rows; nothing to load")
                return True

            # Take the current timestamp.
            current_time = datetime.now()

            # Set create_time to now: it records when the data was loaded.
            df['create_time'] = current_time
            logger.info(f"Setting the create_time column to the current time: {current_time}")

            # Keep update_time NULL: this is a load, not an update.
            # update_time should only be set when the data is modified.
            if 'update_time' not in df.columns:
                df['update_time'] = None  # explicitly NULL
                logger.info("Adding the update_time column with initial value NULL (not set during a load)")

            # Write the data into the target table.
            logger.info(f"Writing data into target table {full_table_name}, {len(df)} rows in total")
            with target_engine.begin() as connection:  # begin() so the insert is committed (required on SQLAlchemy 2.x)
                # Only pass the schema argument when a schema is configured.
                if target_schema:
                    df.to_sql(
                        name=table_name,
                        con=connection,
                        if_exists='append',
                        index=False,
                        schema=target_schema
                    )
                else:
                    df.to_sql(
                        name=table_name,
                        con=connection,
                        if_exists='append',
                        index=False
                    )
            logger.info(f"Wrote data into target table {full_table_name}")
            return True
        except Exception as query_err:
            logger.error(f"Error while querying or writing data: {str(query_err)}")
            logger.error(traceback.format_exc())
            raise Exception(f"Data query or write failed: {str(query_err)}")
    except Exception as e:
        logger.error(f"Error during the data load: {str(e)}")
        logger.error(traceback.format_exc())
        raise Exception(f"Data load failed: {str(e)}")
    finally:
        end_time = datetime.now()
        duration = (end_time - start_time).total_seconds()
        logger.info(f"Data load finished; took {int(duration // 60)} min {int(duration % 60)} s")
        logger.info("===== Data load finished =====")
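
# A minimal sketch of the two call shapes (the table name is hypothetical;
# assumes the Neo4j metadata and PG_CONFIG described above are in place):
#
#   # Incremental load for one execution date:
#   load_data_from_source("ods_sales_order", exec_date="2024-01-01",
#                         update_mode="append", schedule_frequency="daily")
#
#   # Full refresh (truncates ods.ods_sales_order first):
#   load_data_from_source("ods_sales_order", update_mode="full_refresh")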
def run(table_name, update_mode, schedule_frequency=None, script_name=None, exec_date=None, is_manual_dag_trigger=False, **kwargs):
    """
    Unified entry point, following the Airflow dynamic-script calling convention.

    Args:
        table_name (str): Table to process
        update_mode (str): Update mode (append/full_refresh)
        schedule_frequency (str): Schedule frequency
        script_name (str): Script name
        exec_date: Execution date
        is_manual_dag_trigger (bool): Whether the DAG was triggered manually
        **kwargs: Any other arguments

    Returns:
        bool: True on success; otherwise an exception is raised
    """
    logger.info("Starting script execution...")

    # Fall back to this file's name if no script name was passed in.
    if script_name is None:
        script_name = os.path.basename(__file__)

    # Log all incoming arguments.
    logger.info("===== Incoming arguments =====")
    logger.info(f"table_name: {table_name}")
    logger.info(f"update_mode: {update_mode}")
    logger.info(f"schedule_frequency: {schedule_frequency}")
    logger.info(f"exec_date: {exec_date}")
    logger.info(f"script_name: {script_name}")
    logger.info(f"is_manual_dag_trigger: {is_manual_dag_trigger}")

    # Log any extra arguments.
    for key, value in kwargs.items():
        logger.info(f"Extra argument - {key}: {value}")
    logger.info("========================")

    # Delegate to the internal loader; exceptions are not caught here and
    # propagate directly to the caller.
    return load_data_from_source(
        table_name=table_name,
        exec_date=exec_date,
        update_mode=update_mode,
        script_name=script_name,
        schedule_frequency=schedule_frequency,
        is_manual_dag_trigger=is_manual_dag_trigger,
        **kwargs
    )
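
# A hedged sketch of how a dynamically generated Airflow task might invoke
# run() (the operator wiring and task_id are hypothetical; only run()'s
# signature comes from this file):
#
#   from airflow.operators.python import PythonOperator
#
#   PythonOperator(
#       task_id="load_ods_sales_order",
#       python_callable=run,
#       op_kwargs={
#           "table_name": "ods_sales_order",
#           "update_mode": "append",
#           "schedule_frequency": "daily",
#           "exec_date": "{{ ds }}",   # templated at runtime
#           "is_manual_dag_trigger": False,
#       },
#   )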
if __name__ == "__main__":
    # When run directly, call the unified entry point with test arguments.
    run(
        table_name="test_table",
        update_mode="append",
        schedule_frequency="daily",
        exec_date=datetime.now().strftime('%Y-%m-%d'),
        script_name=os.path.basename(__file__),
        is_manual_dag_trigger=True
    )