load_data.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import logging
import sys
import os
from datetime import datetime
from sqlalchemy import create_engine, inspect, Table, Column, MetaData, text
import pandas as pd
import traceback

# Extend the Python import path so the script_utils module in the same
# directory can be found.
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)

# Import the script utility module.
try:
    import script_utils
    from script_utils import get_pg_config, get_date_range, get_one_day_range, get_neo4j_driver, get_target_dt_column
    logger_utils = logging.getLogger("script_utils")
except ImportError as e:
    logger_utils = None
    print(f"Failed to import the script_utils module: {str(e)}")

# Configure the logger.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger("load_data")

def get_source_database_info(table_name, script_name=None):
    """
    Look up the source database connection info for a table in Neo4j.

    Args:
        table_name (str): Table name.
        script_name (str, optional): Script name.

    Returns:
        dict: Database connection info.

    Raises:
        Exception: If the connection info cannot be retrieved.
    """
    logger.info(f"Fetching source database connection info for table {table_name}")
    driver = get_neo4j_driver()
    try:
        with driver.session() as session:
            # First, find the source node that the table points to.
            query = """
                MATCH (target {en_name: $table_name})-[rel]->(source:DataSource)
                RETURN source.en_name AS source_name,
                       source.database AS database,
                       source.host AS host,
                       source.port AS port,
                       source.username AS username,
                       source.password AS password,
                       source.type AS db_type,
                       source.schema AS schema,
                       source.table AS source_table,
                       labels(source) AS labels
            """
            result = session.run(query, table_name=table_name)
            record = result.single()
            if not record:
                error_msg = f"No source node found for table {table_name}"
                logger.error(error_msg)
                raise Exception(error_msg)

            # Collect the connection info carried by the source node.
            database_info = {
                "source_name": record.get("source_name"),
                "database": record.get("database"),
                "host": record.get("host"),
                "port": record.get("port"),
                "username": record.get("username"),
                "password": record.get("password"),
                "db_type": record.get("db_type"),
                "schema": record.get("schema", "public"),
                "source_table": record.get("source_table"),
                "labels": record.get("labels", [])
            }

            # Make sure the node actually carries connection info.
            if not database_info.get("database"):
                error_msg = f"Source node {database_info['source_name']} has no database connection info"
                logger.error(error_msg)
                raise Exception(error_msg)

            logger.info(f"Got source database connection info for table {table_name}: "
                        f"{database_info['host']}:{database_info['port']}/{database_info['database']}")
            return database_info
    except Exception as e:
        logger.error(f"Error while fetching source database connection info: {str(e)}")
        logger.error(traceback.format_exc())
        raise Exception(f"Failed to fetch source database connection info: {str(e)}")
    finally:
        driver.close()
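
# Illustrative only: the Cypher above does not constrain the relationship type,
# so any graph shaped like the following (hypothetical labels and values) would
# match:
#   CREATE (t:Table {en_name: 'ods_orders'})-[:SOURCED_FROM]->
#          (s:DataSource {en_name: 'crm', type: 'postgresql',
#                         host: '10.0.0.5', port: 5432, username: 'etl',
#                         password: '...', database: 'crm_db',
#                         schema: 'public', table: 'orders'})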

def get_target_database_info():
    """
    Get the target database connection info.

    Returns:
        dict: Database connection info.

    Raises:
        Exception: If the target connection info cannot be retrieved.
    """
    logger.info("Fetching target database connection info")
    try:
        # Read PG_CONFIG via script_utils.
        pg_config = get_pg_config()
        if not pg_config:
            raise ValueError("Unable to read the PG_CONFIG settings")

        # Check the required keys.
        required_keys = ["host", "port", "user", "password", "database"]
        missing_keys = [key for key in required_keys if key not in pg_config]
        if missing_keys:
            raise ValueError(f"PG_CONFIG is missing required keys: {', '.join(missing_keys)}")

        # Build the connection info.
        database_info = {
            "host": pg_config.get("host"),
            "port": pg_config.get("port"),
            "username": pg_config.get("user"),
            "password": pg_config.get("password"),
            "database": pg_config.get("database"),
            "db_type": "postgresql",
            "schema": "public"
        }
        logger.info(f"Got target database connection info: "
                    f"{database_info['host']}:{database_info['port']}/{database_info['database']}")
        return database_info
    except Exception as e:
        logger.error(f"Error while fetching target database connection info: {str(e)}")
        logger.error(traceback.format_exc())
        raise  # Re-raise; do not fall back to default connection info.
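
# Expected PG_CONFIG shape, per the required_keys check above (illustrative
# values; the real ones come from script_utils.get_pg_config()):
#   {"host": "localhost", "port": 5432, "user": "etl",
#    "password": "secret", "database": "dataops"}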

def get_sqlalchemy_engine(db_info):
    """
    Create a SQLAlchemy engine from connection info.

    Args:
        db_info (dict): Database connection info.

    Returns:
        Engine: A SQLAlchemy engine, or None on failure.
    """
    if not db_info:
        logger.error("Connection info is empty; cannot create a SQLAlchemy engine")
        return None
    try:
        db_type = db_info.get("db_type", "").lower()
        if db_type == "postgresql":
            url = f"postgresql://{db_info['username']}:{db_info['password']}@{db_info['host']}:{db_info['port']}/{db_info['database']}"
        elif db_type == "mysql":
            url = f"mysql+pymysql://{db_info['username']}:{db_info['password']}@{db_info['host']}:{db_info['port']}/{db_info['database']}"
        elif db_type == "oracle":
            url = f"oracle+cx_oracle://{db_info['username']}:{db_info['password']}@{db_info['host']}:{db_info['port']}/{db_info['database']}"
        elif db_type == "mssql":
            url = f"mssql+pymssql://{db_info['username']}:{db_info['password']}@{db_info['host']}:{db_info['port']}/{db_info['database']}"
        else:
            logger.error(f"Unsupported database type: {db_type}")
            return None
        # Create the database engine.
        engine = create_engine(url)
        return engine
    except Exception as e:
        logger.error(f"Error while creating the SQLAlchemy engine: {str(e)}")
        logger.error(traceback.format_exc())
        return None
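
# A minimal usage sketch (hypothetical credentials):
#   engine = get_sqlalchemy_engine({
#       "db_type": "postgresql", "username": "etl", "password": "secret",
#       "host": "localhost", "port": 5432, "database": "dataops",
#   })
#   if engine is not None:
#       with engine.connect() as conn:
#           conn.execute(text("SELECT 1"))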

def create_table_if_not_exists(source_engine, target_engine, source_table, target_table, schema="public"):
    """
    Create the target table from the source table's structure if it does not exist.

    Args:
        source_engine: Source database engine.
        target_engine: Target database engine.
        source_table: Source table name.
        target_table: Target table name.
        schema: Schema name.

    Returns:
        bool: Whether the operation succeeded.

    Raises:
        Exception: If the source table is missing or the target table cannot be created.
    """
    logger.info(f"Checking whether target table {target_table} exists; creating it if not")
    try:
        # Does the target table already exist?
        target_inspector = inspect(target_engine)
        target_exists = target_inspector.has_table(target_table, schema=schema)
        if target_exists:
            logger.info(f"Target table {target_table} already exists; nothing to create")
            return True

        # The target table is missing; read the structure from the source table.
        source_inspector = inspect(source_engine)
        if not source_inspector.has_table(source_table):
            error_msg = f"Source table {source_table} does not exist"
            logger.error(error_msg)
            raise Exception(error_msg)

        # Read the source table's column info.
        source_columns = source_inspector.get_columns(source_table)
        if not source_columns:
            error_msg = f"Source table {source_table} has no column info"
            logger.error(error_msg)
            raise Exception(error_msg)

        # Build a metadata object and define the target table.
        metadata = MetaData()
        table_def = Table(
            target_table,
            metadata,
            *[Column(col['name'], col['type']) for col in source_columns],
            schema=schema
        )
        # Create the table in the target database.
        metadata.create_all(target_engine)
        logger.info(f"Created table {target_table} in the target database")
        return True
    except Exception as e:
        logger.error(f"Error while creating the table: {str(e)}")
        logger.error(traceback.format_exc())
        raise Exception(f"Table creation failed: {str(e)}")
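
# Note: the column types returned by the source inspector are dialect-specific,
# so a cross-dialect copy (e.g. MySQL -> PostgreSQL) may need an explicit type
# mapping; for same-dialect copies the reflected types can usually be reused
# as-is.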

def load_data_from_source(table_name, exec_date=None, update_mode=None, script_name=None,
                          schedule_frequency=None, is_manual_dag_trigger=False, **kwargs):
    """
    Load data from the source database into the target database.
    """
    start_time = datetime.now()
    logger.info("===== Starting load from source =====")
    logger.info(f"Table name: {table_name}")
    logger.info(f"Execution date: {exec_date}")
    logger.info(f"Update mode: {update_mode}")
    logger.info(f"Script name: {script_name}")
    logger.info(f"Schedule frequency: {schedule_frequency}")
    logger.info(f"Manual DAG trigger: {is_manual_dag_trigger}")
    for key, value in kwargs.items():
        logger.info(f"Other parameter - {key}: {value}")

    if exec_date is None:
        exec_date = datetime.now().strftime('%Y-%m-%d')
        logger.info(f"Execution date is empty; using today: {exec_date}")
    if schedule_frequency is None:
        schedule_frequency = "daily"
        logger.info(f"Schedule frequency is empty; using default: {schedule_frequency}")

    try:
        # Resolve the source and target connection info.
        source_db_info = get_source_database_info(table_name, script_name)
        target_db_info = get_target_database_info()

        # Create the database engines.
        source_engine = get_sqlalchemy_engine(source_db_info)
        target_engine = get_sqlalchemy_engine(target_db_info)
        if not source_engine or not target_engine:
            raise Exception("Could not create database engines; cannot load data")

        # Resolve the source table name.
        source_table = source_db_info.get("source_table", table_name) or table_name

        # Make sure the target table exists.
        if not create_table_if_not_exists(source_engine, target_engine, source_table, table_name):
            raise Exception(f"Could not create target table {table_name}; cannot load data")

        # Branch on the update mode.
        if update_mode == "full_refresh":
            # Full refresh: truncate the target table first.
            logger.info(f"Full refresh: truncating table {table_name}")
            with target_engine.begin() as conn:  # begin() manages the transaction
                conn.execute(text(f"TRUNCATE TABLE {table_name}"))
            logger.info(f"Truncated table {table_name}")
            # Build the full-load query.
            query = f"SELECT * FROM {source_table}"
        else:
            # Incremental load: needs the target date column and a date range.
            target_dt_column = get_target_dt_column(table_name, script_name)
            if not target_dt_column:
                logger.error(f"Could not resolve the target date column for table {table_name}; cannot run an incremental load")
                return False
            try:
                # The date range depends on whether this is a manual DAG trigger.
                if is_manual_dag_trigger:
                    # Manual trigger.
                    start_date, end_date = get_date_range(exec_date, schedule_frequency)
                    logger.info(f"Manual DAG trigger; date range: {start_date} to {end_date}")
                    # Delete the range that is about to be reloaded.
                    delete_sql = f"""
                        DELETE FROM {table_name}
                        WHERE {target_dt_column} >= '{start_date}'
                          AND {target_dt_column} < '{end_date}'
                    """
                    with target_engine.begin() as conn:  # begin() manages the transaction
                        conn.execute(text(delete_sql))
                    logger.info(f"Deleted rows from {table_name} where {target_dt_column} is between {start_date} and {end_date}")
                else:
                    # Scheduled run.
                    start_datetime, end_datetime = get_one_day_range(exec_date)
                    start_date = start_datetime.strftime('%Y-%m-%d %H:%M:%S')
                    end_date = end_datetime.strftime('%Y-%m-%d %H:%M:%S')
                    logger.info(f"Scheduled run; date range: {start_date} to {end_date}")
                    # Delete by create_time.
                    delete_sql = f"""
                        DELETE FROM {table_name}
                        WHERE create_time >= '{start_date}'
                          AND create_time < '{end_date}'
                    """
                    try:
                        with target_engine.begin() as conn:  # begin() manages the transaction
                            conn.execute(text(delete_sql))
                        logger.info(f"Deleted rows from {table_name} where create_time is between {start_date} and {end_date}")
                    except Exception as del_err:
                        logger.error(f"Error while deleting data: {str(del_err)}")
                        logger.warning("Continuing with the data load anyway")

                # Build the incremental query.
                if not is_manual_dag_trigger:
                    # For scheduled runs, recompute the date range for the query.
                    start_date, end_date = get_date_range(exec_date, schedule_frequency)

                # Does the source table have the target date column?
                source_inspector = inspect(source_engine)
                source_columns = [col['name'].lower() for col in source_inspector.get_columns(source_table)]
                if target_dt_column.lower() in source_columns:
                    # It does: filter by the date range.
                    query = f"""
                        SELECT * FROM {source_table}
                        WHERE {target_dt_column} >= '{start_date}'
                          AND {target_dt_column} < '{end_date}'
                    """
                else:
                    # It does not: fall back to a full load.
                    logger.warning(f"Source table {source_table} has no target date column {target_dt_column}; loading all data")
                    query = f"SELECT * FROM {source_table}"
            except Exception as date_err:
                logger.error(f"Error while computing the date range: {str(date_err)}")
                logger.error(traceback.format_exc())
                return False

        # Run the query and load the data.
        logger.info(f"Running query: {query}")
        try:
            # Run the query via SQLAlchemy, then build the DataFrame by hand.
            rows = []
            column_names = []
            with source_engine.connect() as connection:
                result_proxy = connection.execute(text(query))
                rows = result_proxy.fetchall()
                column_names = result_proxy.keys()
            df = pd.DataFrame(rows, columns=column_names)

            # Anything to load?
            if df.empty:
                logger.warning("The query returned no rows; nothing to load")
                return True

            # Add a create_time column if it is missing.
            if 'create_time' not in df.columns:
                df['create_time'] = datetime.now()

            # Write the data into the target table.
            logger.info(f"Writing {len(df)} rows into target table {table_name}")
            with target_engine.begin() as connection:  # begin() so the insert is committed
                df.to_sql(
                    name=table_name,
                    con=connection,
                    if_exists='append',
                    index=False,
                    schema=target_db_info.get("schema", "public")
                )
            logger.info(f"Wrote data into target table {table_name}")
            return True
        except Exception as query_err:
            logger.error(f"Error while querying or writing data: {str(query_err)}")
            logger.error(traceback.format_exc())
            raise Exception(f"Data query or write failed: {str(query_err)}")
    except Exception as e:
        logger.error(f"Error during the data load: {str(e)}")
        logger.error(traceback.format_exc())
        raise Exception(f"Data load failed: {str(e)}")
    finally:
        end_time = datetime.now()
        duration = (end_time - start_time).total_seconds()
        logger.info(f"Data load finished; took {int(duration // 60)} min {int(duration % 60)} s")
        logger.info("===== Data load finished =====")
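
# A minimal invocation sketch (hypothetical table name; assumes the table's
# source node exists in Neo4j and both databases are reachable):
#   load_data_from_source("ods_orders", exec_date="2024-01-01",
#                         update_mode="append", schedule_frequency="daily",
#                         is_manual_dag_trigger=True)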

def run(table_name, update_mode, schedule_frequency=None, script_name=None, exec_date=None, is_manual_dag_trigger=False, **kwargs):
    """
    Unified entry point, following the Airflow dynamic-script calling convention.

    Args:
        table_name (str): Table to process.
        update_mode (str): Update mode (append/full_refresh).
        schedule_frequency (str): Schedule frequency.
        script_name (str): Script name.
        exec_date: Execution date.
        is_manual_dag_trigger (bool): Whether this is a manual DAG trigger.
        **kwargs: Any other parameters.

    Returns:
        bool: True on success; otherwise an exception is raised.
    """
    logger.info("Starting script execution...")
    # Default the script name to this file's name if it was not passed in.
    if script_name is None:
        script_name = os.path.basename(__file__)

    # Log every incoming parameter.
    logger.info("===== Incoming parameters =====")
    logger.info(f"table_name: {table_name}")
    logger.info(f"update_mode: {update_mode}")
    logger.info(f"schedule_frequency: {schedule_frequency}")
    logger.info(f"exec_date: {exec_date}")
    logger.info(f"script_name: {script_name}")
    logger.info(f"is_manual_dag_trigger: {is_manual_dag_trigger}")
    # Log any extra parameters as well.
    for key, value in kwargs.items():
        logger.info(f"Extra parameter - {key}: {value}")
    logger.info("========================")

    # Delegate to the internal loader. Exceptions are deliberately not caught
    # here; they propagate to the caller.
    return load_data_from_source(
        table_name=table_name,
        exec_date=exec_date,
        update_mode=update_mode,
        script_name=script_name,
        schedule_frequency=schedule_frequency,
        is_manual_dag_trigger=is_manual_dag_trigger,
        **kwargs
    )

if __name__ == "__main__":
    # When run directly, call the unified entry point with test parameters.
    run(
        table_name="test_table",
        update_mode="append",
        schedule_frequency="daily",
        exec_date=datetime.now().strftime('%Y-%m-%d'),
        script_name=os.path.basename(__file__),
        is_manual_dag_trigger=True
    )
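
# Smoke test (assumes script_utils is importable and both databases are
# reachable):
#   python load_data.py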