#!/usr/bin/env python
# -*- coding: utf-8 -*-
import logging
import sys
import os
from datetime import datetime, timedelta
import sqlalchemy
from sqlalchemy import create_engine, inspect, Table, Column, MetaData, text
import pandas as pd
import traceback
import pendulum

# Extend the Python import path so the script_utils module in the same
# directory can be found.
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)

# Import the script utilities module.
try:
    import script_utils
    from script_utils import get_pg_config, get_date_range, get_one_day_range, get_neo4j_driver, get_target_dt_column
    logger_utils = logging.getLogger("script_utils")
except ImportError as e:
    logger_utils = None
    print(f"Failed to import the script_utils module: {str(e)}")

# Configure the logger.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger("load_data")
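# NOTE: the helpers imported from script_utils are assumed to behave as
# follows; the signatures are inferred from how they are used in this file,
# not verified against the module itself:
#   get_pg_config()                          -> dict with host/port/user/password/database
#   get_date_range(exec_date, frequency)     -> (start, end) bounds for the schedule period
#   get_one_day_range(exec_date)             -> (start_dt, end_dt) datetimes covering one day
#   get_neo4j_driver()                       -> a neo4j Driver instance
#   get_target_dt_column(table, script_name) -> name of the table's date column, or None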
def get_source_database_info(table_name, script_name=None):
    """
    Look up the source database connection info for a table in Neo4j.

    Args:
        table_name (str): Table name.
        script_name (str, optional): Script name.

    Returns:
        dict: Database connection info.

    Raises:
        Exception: If the connection info cannot be retrieved.
    """
    logger.info(f"Fetching source database connection info for table {table_name}")
    driver = get_neo4j_driver()
    try:
        with driver.session() as session:
            # First, find the source node linked to this table.
            query = """
                MATCH (target {en_name: $table_name})-[rel]->(source:DataSource)
                RETURN source.en_name AS source_name,
                       source.database AS database,
                       source.host AS host,
                       source.port AS port,
                       source.username AS username,
                       source.password AS password,
                       source.type AS db_type,
                       source.schema AS schema,
                       source.table AS source_table,
                       labels(source) AS labels
            """
            result = session.run(query, table_name=table_name)
            record = result.single()
            if not record:
                error_msg = f"No source node found for table {table_name}"
                logger.error(error_msg)
                raise Exception(error_msg)
            # Collect the connection info from the source node.
            database_info = {
                "source_name": record.get("source_name"),
                "database": record.get("database"),
                "host": record.get("host"),
                "port": record.get("port"),
                "username": record.get("username"),
                "password": record.get("password"),
                "db_type": record.get("db_type"),
                "schema": record.get("schema", "public"),
                "source_table": record.get("source_table"),
                "labels": record.get("labels", [])
            }
            # Make sure the node actually carries connection info.
            if not database_info.get("database"):
                error_msg = f"Source node {database_info['source_name']} has no database connection info"
                logger.error(error_msg)
                raise Exception(error_msg)
            logger.info(f"Got source database connection info for table {table_name}: {database_info['host']}:{database_info['port']}/{database_info['database']}")
            return database_info
    except Exception as e:
        logger.error(f"Error fetching source database connection info: {str(e)}")
        logger.error(traceback.format_exc())
        raise Exception(f"Failed to fetch source database connection info: {str(e)}")
    finally:
        driver.close()
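# The Cypher query above implies a graph shaped roughly like the sketch
# below (node label and property names are taken from the query itself;
# the concrete values are illustrative only):
#   (t {en_name: "ods_orders"})-[r]->(s:DataSource {
#       en_name: "crm_mysql", type: "mysql",
#       host: "10.0.0.5", port: 3306,
#       username: "reader", password: "...",
#       database: "crm", schema: "crm", table: "orders"
#   })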
def get_target_database_info():
    """
    Get the target database connection info.

    Returns:
        dict: Database connection info.

    Raises:
        Exception: If the target connection info cannot be retrieved.
    """
    logger.info("Fetching target database connection info")
    try:
        # Try to read PG_CONFIG via script_utils.
        pg_config = get_pg_config()
        if not pg_config:
            raise ValueError("Unable to read the PG_CONFIG configuration")
        # Check the required configuration keys.
        required_keys = ["host", "port", "user", "password", "database"]
        missing_keys = [key for key in required_keys if key not in pg_config]
        if missing_keys:
            raise ValueError(f"PG_CONFIG is missing required keys: {', '.join(missing_keys)}")
        # Build the connection info.
        database_info = {
            "host": pg_config.get("host"),
            "port": pg_config.get("port"),
            "username": pg_config.get("user"),
            "password": pg_config.get("password"),
            "database": pg_config.get("database"),
            "db_type": "postgresql",
            "schema": "public"
        }
        logger.info(f"Got target database connection info: {database_info['host']}:{database_info['port']}/{database_info['database']}")
        return database_info
    except Exception as e:
        logger.error(f"Error fetching target database connection info: {str(e)}")
        logger.error(traceback.format_exc())
        raise  # Re-raise; deliberately no fallback connection info.
def get_sqlalchemy_engine(db_info):
    """
    Create a SQLAlchemy engine from database connection info.

    Args:
        db_info (dict): Database connection info.

    Returns:
        Engine: SQLAlchemy engine, or None on failure.
    """
    if not db_info:
        logger.error("Connection info is empty; cannot create a SQLAlchemy engine")
        return None
    try:
        db_type = db_info.get("db_type", "").lower()
        if db_type == "postgresql":
            url = f"postgresql+psycopg2://{db_info['username']}:{db_info['password']}@{db_info['host']}:{db_info['port']}/{db_info['database']}"
        elif db_type == "mysql":
            url = f"mysql+pymysql://{db_info['username']}:{db_info['password']}@{db_info['host']}:{db_info['port']}/{db_info['database']}"
        elif db_type == "oracle":
            url = f"oracle+cx_oracle://{db_info['username']}:{db_info['password']}@{db_info['host']}:{db_info['port']}/{db_info['database']}"
        elif db_type == "mssql":
            url = f"mssql+pymssql://{db_info['username']}:{db_info['password']}@{db_info['host']}:{db_info['port']}/{db_info['database']}"
        else:
            logger.error(f"Unsupported database type: {db_type}")
            return None
        # Create the engine and let SQLAlchemy handle dialect differences.
        engine = create_engine(url)
        return engine
    except Exception as e:
        logger.error(f"Error creating SQLAlchemy engine: {str(e)}")
        logger.error(traceback.format_exc())
        return None
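# Usage sketch for get_sqlalchemy_engine (placeholder credentials, not real
# project configuration):
#   engine = get_sqlalchemy_engine({
#       "db_type": "postgresql",
#       "username": "etl_user", "password": "secret",
#       "host": "localhost", "port": 5432, "database": "warehouse",
#   })
#   # engine is None if db_type is unsupported or engine creation fails.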
def create_table_if_not_exists(source_engine, target_engine, source_table, target_table, schema="public"):
    """
    Create the target table by copying the source table's structure if the
    target table does not already exist.

    Args:
        source_engine: Source database engine.
        target_engine: Target database engine.
        source_table: Source table name.
        target_table: Target table name.
        schema: Schema name.

    Returns:
        bool: True on success.

    Raises:
        Exception: If the source table does not exist or the target table
            cannot be created.
    """
    logger.info(f"Checking whether target table {target_table} exists; creating it if not")
    try:
        # Check whether the target table exists.
        target_inspector = inspect(target_engine)
        target_exists = target_inspector.has_table(target_table, schema=schema)
        if target_exists:
            logger.info(f"Target table {target_table} already exists; nothing to create")
            return True
        # The target table does not exist: read the structure from the source table.
        source_inspector = inspect(source_engine)
        # The source table name may carry a schema prefix.
        source_schema = None
        if '.' in source_table:
            source_schema, source_table = source_table.split('.', 1)
        if not source_inspector.has_table(source_table, schema=source_schema):
            error_msg = f"Source table {source_table} does not exist"
            if source_schema:
                error_msg = f"Source table {source_schema}.{source_table} does not exist"
            logger.error(error_msg)
            raise Exception(error_msg)
        # Read the source table's column info.
        source_columns = source_inspector.get_columns(source_table, schema=source_schema)
        if not source_columns:
            error_msg = f"Source table {source_table} has no column info"
            logger.error(error_msg)
            raise Exception(error_msg)
        # Create a metadata object.
        metadata = MetaData()
        # Check whether the source table already has create_time / update_time columns.
        existing_column_names = [col['name'].lower() for col in source_columns]
        has_create_time = 'create_time' in existing_column_names
        has_update_time = 'update_time' in existing_column_names
        # Build the column definitions.
        columns = [Column(col['name'], col['type']) for col in source_columns]
        # Add a create_time column if it is missing.
        if not has_create_time:
            from sqlalchemy import TIMESTAMP
            columns.append(Column('create_time', TIMESTAMP, nullable=True))
            logger.info(f"Adding create_time column to table {target_table}")
        # Add an update_time column if it is missing.
        if not has_update_time:
            from sqlalchemy import TIMESTAMP
            columns.append(Column('update_time', TIMESTAMP, nullable=True))
            logger.info(f"Adding update_time column to table {target_table}")
        # Define the target table; registering it on metadata lets
        # SQLAlchemy handle the data-type mapping.
        table_def = Table(
            target_table,
            metadata,
            *columns,
            schema=schema
        )
        # Create the table in the target database.
        metadata.create_all(target_engine)
        logger.info(f"Created table {schema}.{target_table} in the target database")
        return True
    except Exception as e:
        logger.error(f"Error creating table: {str(e)}")
        logger.error(traceback.format_exc())
        raise Exception(f"Failed to create table: {str(e)}")
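# Usage sketch for create_table_if_not_exists (hypothetical table names):
#   create_table_if_not_exists(src_engine, tgt_engine, "crm.orders", "ods_orders")
# This copies crm.orders' column layout into public.ods_orders and appends
# nullable create_time/update_time audit columns when the source lacks them.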
def load_data_from_source(table_name, exec_date=None, update_mode=None, script_name=None,
                          schedule_frequency=None, is_manual_dag_trigger=False, **kwargs):
    """
    Load data from the source database into the target database.
    """
    start_time = datetime.now()
    logger.info(f"===== Starting data load from source =====")
    logger.info(f"Table name: {table_name}")
    logger.info(f"Execution date: {exec_date}")
    logger.info(f"Update mode: {update_mode}")
    logger.info(f"Script name: {script_name}")
    logger.info(f"Schedule frequency: {schedule_frequency}")
    logger.info(f"Manual DAG trigger: {is_manual_dag_trigger}")
    for key, value in kwargs.items():
        logger.info(f"Other parameter - {key}: {value}")
    if exec_date is None:
        exec_date = datetime.now().strftime('%Y-%m-%d')
        logger.info(f"Execution date is empty; using today's date: {exec_date}")
    if schedule_frequency is None:
        schedule_frequency = "daily"
        logger.info(f"Schedule frequency is empty; using the default: {schedule_frequency}")
    try:
        # Get the source and target database connection info.
        source_db_info = get_source_database_info(table_name, script_name)
        target_db_info = get_target_database_info()
        # Create the database engines.
        source_engine = get_sqlalchemy_engine(source_db_info)
        target_engine = get_sqlalchemy_engine(target_db_info)
        if not source_engine or not target_engine:
            raise Exception("Unable to create the database engines; cannot load data")
        # Resolve the source table name.
        source_table = source_db_info.get("source_table", table_name) or table_name
        # Make sure the target table exists.
        if not create_table_if_not_exists(source_engine, target_engine, source_table, table_name):
            raise Exception(f"Unable to create target table {table_name}; cannot load data")
        # Handle the data according to the update mode.
        if update_mode == "full_refresh":
            # Full refresh: truncate the table first.
            logger.info(f"Full refresh: truncating table {table_name}")
            with target_engine.begin() as conn:  # begin() manages the transaction
                conn.execute(text(f"TRUNCATE TABLE {table_name}"))
            logger.info(f"Truncated table {table_name}")
            # Build the full-load query.
            query = f"SELECT * FROM {source_table}"
        else:
            # Incremental update: need the target date column and a date range.
            target_dt_column = get_target_dt_column(table_name, script_name)
            if not target_dt_column:
                logger.error(f"Unable to determine the date column for table {table_name}; cannot do an incremental load")
                return False
            try:
                # The date range depends on whether the DAG was triggered manually.
                if is_manual_dag_trigger:
                    # Manual trigger.
                    start_date, end_date = get_date_range(exec_date, schedule_frequency)
                    logger.info(f"Manual DAG trigger; date range: {start_date} to {end_date}")
                    # Delete the rows that are about to be reloaded.
                    delete_sql = f"""
                        DELETE FROM {table_name}
                        WHERE {target_dt_column} >= '{start_date}'
                        AND {target_dt_column} < '{end_date}'
                    """
                    with target_engine.begin() as conn:  # begin() manages the transaction
                        conn.execute(text(delete_sql))
                    logger.info(f"Deleted rows from {table_name} where {target_dt_column} is between {start_date} and {end_date}")
                else:
                    # Automatic scheduling.
                    start_datetime, end_datetime = get_one_day_range(exec_date)
                    start_date = start_datetime.strftime('%Y-%m-%d %H:%M:%S')
                    end_date = end_datetime.strftime('%Y-%m-%d %H:%M:%S')
                    logger.info(f"Automatic schedule; date range: {start_date} to {end_date}")
                    # Delete the rows that are about to be reloaded.
                    delete_sql = f"""
                        DELETE FROM {table_name}
                        WHERE create_time >= '{start_date}'
                        AND create_time < '{end_date}'
                    """
                    try:
                        with target_engine.begin() as conn:  # begin() manages the transaction
                            conn.execute(text(delete_sql))
                        logger.info(f"Deleted rows from {table_name} where create_time is between {start_date} and {end_date}")
                    except Exception as del_err:
                        logger.error(f"Error deleting data: {str(del_err)}")
                        logger.warning("Continuing with the data load")
                # Build the incremental query.
                if not is_manual_dag_trigger:
                    # For automatic scheduling, recompute the date range for the query.
                    start_date, end_date = get_date_range(exec_date, schedule_frequency)
                # Check whether the source table has the target date column.
                source_inspector = inspect(source_engine)
                source_columns = [col['name'].lower() for col in source_inspector.get_columns(source_table)]
                if target_dt_column.lower() in source_columns:
                    # The source table has the date column: filter by the date range.
                    query = f"""
                        SELECT * FROM {source_table}
                        WHERE {target_dt_column} >= '{start_date}'
                        AND {target_dt_column} < '{end_date}'
                    """
                else:
                    # No date column in the source table: fall back to a full load.
                    logger.warning(f"Source table {source_table} has no date column {target_dt_column}; loading all data")
                    query = f"SELECT * FROM {source_table}"
            except Exception as date_err:
                logger.error(f"Error computing the date range: {str(date_err)}")
                logger.error(traceback.format_exc())
                return False
        # Run the query and load the data.
        logger.info(f"Running query: {query}")
        try:
            # Execute the query via SQLAlchemy, then build the DataFrame by hand.
            rows = []
            column_names = []
            with source_engine.connect() as connection:
                result_proxy = connection.execute(text(query))
                rows = result_proxy.fetchall()
                column_names = list(result_proxy.keys())
            df = pd.DataFrame(rows, columns=column_names)
            # Nothing to do if the result set is empty.
            if df.empty:
                logger.warning(f"The query returned no rows; nothing to load")
                return True
            # Current timestamp.
            current_time = datetime.now()
            # Set create_time to now - it records when the data was loaded.
            df['create_time'] = current_time
            logger.info(f"Set create_time to the current time: {current_time}")
            # Leave update_time as NULL: this is a data load, not an update.
            # update_time should only be set when a row is modified.
            if 'update_time' not in df.columns:
                df['update_time'] = None  # explicitly NULL
                logger.info(f"Added update_time column with initial value NULL (not set during a load)")
            # Write the data to the target table. Use begin() so the insert
            # is committed on success; a plain connect() would require an
            # explicit commit under SQLAlchemy 2.x.
            logger.info(f"Writing {len(df)} rows to target table {table_name}")
            with target_engine.begin() as connection:
                df.to_sql(
                    name=table_name,
                    con=connection,
                    if_exists='append',
                    index=False,
                    schema=target_db_info.get("schema", "public")
                )
            logger.info(f"Wrote data to target table {table_name}")
            return True
        except Exception as query_err:
            logger.error(f"Error running the query or writing data: {str(query_err)}")
            logger.error(traceback.format_exc())
            raise Exception(f"Data query or write failed: {str(query_err)}")
    except Exception as e:
        logger.error(f"Error during the data load: {str(e)}")
        logger.error(traceback.format_exc())
        raise Exception(f"Data load failed: {str(e)}")
    finally:
        end_time = datetime.now()
        duration = (end_time - start_time).total_seconds()
        logger.info(f"Data load finished; elapsed: {int(duration // 60)}m {int(duration % 60)}s")
        logger.info(f"===== Data load finished =====")
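# NOTE: the DELETE and SELECT statements above interpolate table names,
# column names and date bounds directly into the SQL string. That is safe
# only as long as those values come from trusted metadata (Neo4j) and from
# script_utils' date helpers; user-supplied values would need bound
# parameters instead.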
def run(table_name, update_mode, schedule_frequency=None, script_name=None, exec_date=None, is_manual_dag_trigger=False, **kwargs):
    """
    Unified entry point, matching the Airflow dynamic-script calling convention.

    Args:
        table_name (str): Table to process.
        update_mode (str): Update mode (append/full_refresh).
        schedule_frequency (str): Schedule frequency.
        script_name (str): Script name.
        exec_date: Execution date.
        is_manual_dag_trigger (bool): Whether the DAG was triggered manually.
        **kwargs: Any other parameters.

    Returns:
        bool: True on success; raises an exception otherwise.
    """
    logger.info(f"Starting script...")
    # Fall back to this file's name if no script name was passed in.
    if script_name is None:
        script_name = os.path.basename(__file__)
    # Log all incoming parameters.
    logger.info(f"===== Incoming parameters =====")
    logger.info(f"table_name: {table_name}")
    logger.info(f"update_mode: {update_mode}")
    logger.info(f"schedule_frequency: {schedule_frequency}")
    logger.info(f"exec_date: {exec_date}")
    logger.info(f"script_name: {script_name}")
    logger.info(f"is_manual_dag_trigger: {is_manual_dag_trigger}")
    # Log any extra parameters as well.
    for key, value in kwargs.items():
        logger.info(f"Extra parameter - {key}: {value}")
    logger.info(f"========================")
    # Delegate to the internal worker. Exceptions are not caught here; they
    # propagate directly to the caller.
    return load_data_from_source(
        table_name=table_name,
        exec_date=exec_date,
        update_mode=update_mode,
        script_name=script_name,
        schedule_frequency=schedule_frequency,
        is_manual_dag_trigger=is_manual_dag_trigger,
        **kwargs
    )
if __name__ == "__main__":
    # When run directly, call the unified entry point with test parameters.
    run(
        table_name="test_table",
        update_mode="append",
        schedule_frequency="daily",
        exec_date=datetime.now().strftime('%Y-%m-%d'),
        script_name=os.path.basename(__file__),
        is_manual_dag_trigger=True
    )