load_data.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. import logging
  4. import sys
  5. import os
  6. from datetime import datetime, timedelta
  7. import sqlalchemy
  8. from sqlalchemy import create_engine, inspect, Table, Column, MetaData, text
  9. import pandas as pd
  10. import traceback
  11. import pendulum
  12. # 修改Python导入路径,确保能找到同目录下的script_utils模块
  13. current_dir = os.path.dirname(os.path.abspath(__file__))
  14. if current_dir not in sys.path:
  15. sys.path.insert(0, current_dir)
  16. # 导入脚本工具模块
  17. try:
  18. import script_utils
  19. from script_utils import get_pg_config, get_date_range, get_one_day_range, get_neo4j_driver, get_target_dt_column
  20. logger_utils = logging.getLogger("script_utils")
  21. except ImportError as e:
  22. logger_utils = None
  23. print(f"导入script_utils模块失败: {str(e)}")
  24. # 配置日志记录器
  25. logging.basicConfig(
  26. level=logging.INFO,
  27. format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
  28. handlers=[
  29. logging.StreamHandler(sys.stdout)
  30. ]
  31. )
  32. logger = logging.getLogger("load_data")
  33. def get_source_database_info(table_name, script_name=None):
  34. """
  35. 根据表名和脚本名从Neo4j获取源数据库连接信息
  36. 参数:
  37. table_name (str): 表名
  38. script_name (str, optional): 脚本名称
  39. 返回:
  40. dict: 数据库连接信息字典
  41. 异常:
  42. Exception: 当无法获取数据库连接信息时抛出异常
  43. """
  44. logger.info(f"获取表 {table_name} 的源数据库连接信息")
  45. driver = get_neo4j_driver()
  46. try:
  47. with driver.session() as session:
  48. # 首先查询表对应的源节点
  49. query = """
  50. MATCH (target {en_name: $table_name})-[rel]->(source:DataSource)
  51. RETURN source.en_name AS source_name,
  52. source.database AS database,
  53. source.host AS host,
  54. source.port AS port,
  55. source.username AS username,
  56. source.password AS password,
  57. source.type AS db_type,
  58. source.schema AS schema,
  59. source.table AS source_table,
  60. labels(source) AS labels
  61. """
  62. result = session.run(query, table_name=table_name)
  63. record = result.single()
  64. if not record:
  65. error_msg = f"未找到表 {table_name} 对应的源节点"
  66. logger.error(error_msg)
  67. raise Exception(error_msg)
  68. # 获取源节点的数据库连接信息
  69. database_info = {
  70. "source_name": record.get("source_name"),
  71. "database": record.get("database"),
  72. "host": record.get("host"),
  73. "port": record.get("port"),
  74. "username": record.get("username"),
  75. "password": record.get("password"),
  76. "db_type": record.get("db_type"),
  77. "schema": record.get("schema", "public"),
  78. "source_table": record.get("source_table"),
  79. "labels": record.get("labels", [])
  80. }
  81. # 检查是否包含数据库连接信息
  82. if not database_info.get("database"):
  83. error_msg = f"源节点 {database_info['source_name']} 没有数据库连接信息"
  84. logger.error(error_msg)
  85. raise Exception(error_msg)
  86. logger.info(f"成功获取表 {table_name} 的源数据库连接信息: {database_info['host']}:{database_info['port']}/{database_info['database']}")
  87. return database_info
  88. except Exception as e:
  89. logger.error(f"获取源数据库连接信息时出错: {str(e)}")
  90. logger.error(traceback.format_exc())
  91. raise Exception(f"获取源数据库连接信息失败: {str(e)}")
  92. finally:
  93. driver.close()
  94. def get_target_database_info():
  95. """
  96. 获取目标数据库连接信息
  97. 返回:
  98. dict: 数据库连接信息字典
  99. 异常:
  100. Exception: 无法获取目标数据库连接信息时抛出异常
  101. """
  102. logger.info("获取目标数据库连接信息")
  103. try:
  104. # 尝试从script_utils中获取PG_CONFIG
  105. pg_config = get_pg_config()
  106. if not pg_config:
  107. raise ValueError("无法获取PG_CONFIG配置")
  108. # 检查必要的配置项
  109. required_keys = ["host", "port", "user", "password", "database"]
  110. missing_keys = [key for key in required_keys if key not in pg_config]
  111. if missing_keys:
  112. raise ValueError(f"PG_CONFIG缺少必要的配置项: {', '.join(missing_keys)}")
  113. # 构建连接信息
  114. database_info = {
  115. "host": pg_config.get("host"),
  116. "port": pg_config.get("port"),
  117. "username": pg_config.get("user"),
  118. "password": pg_config.get("password"),
  119. "database": pg_config.get("database"),
  120. "db_type": "postgresql",
  121. "schema": "public"
  122. }
  123. logger.info(f"成功获取目标数据库连接信息: {database_info['host']}:{database_info['port']}/{database_info['database']}")
  124. return database_info
  125. except Exception as e:
  126. logger.error(f"获取目标数据库连接信息时出错: {str(e)}")
  127. logger.error(traceback.format_exc())
  128. raise # 直接抛出异常,不提供默认连接信息
  129. def get_sqlalchemy_engine(db_info):
  130. """
  131. 根据数据库连接信息创建SQLAlchemy引擎
  132. 参数:
  133. db_info (dict): 数据库连接信息
  134. 返回:
  135. Engine: SQLAlchemy引擎对象
  136. """
  137. if not db_info:
  138. logger.error("数据库连接信息为空,无法创建SQLAlchemy引擎")
  139. return None
  140. try:
  141. db_type = db_info.get("db_type", "").lower()
  142. if db_type == "postgresql":
  143. url = f"postgresql://{db_info['username']}:{db_info['password']}@{db_info['host']}:{db_info['port']}/{db_info['database']}"
  144. elif db_type == "mysql":
  145. url = f"mysql+pymysql://{db_info['username']}:{db_info['password']}@{db_info['host']}:{db_info['port']}/{db_info['database']}"
  146. elif db_type == "oracle":
  147. url = f"oracle+cx_oracle://{db_info['username']}:{db_info['password']}@{db_info['host']}:{db_info['port']}/{db_info['database']}"
  148. elif db_type == "mssql":
  149. url = f"mssql+pymssql://{db_info['username']}:{db_info['password']}@{db_info['host']}:{db_info['port']}/{db_info['database']}"
  150. else:
  151. logger.error(f"不支持的数据库类型: {db_type}")
  152. return None
  153. # 创建数据库引擎
  154. engine = create_engine(url)
  155. return engine
  156. except Exception as e:
  157. logger.error(f"创建SQLAlchemy引擎时出错: {str(e)}")
  158. logger.error(traceback.format_exc())
  159. return None
  160. def create_table_if_not_exists(source_engine, target_engine, source_table, target_table, schema="public"):
  161. """
  162. 如果目标表不存在,则从源表复制表结构创建目标表
  163. 参数:
  164. source_engine: 源数据库引擎
  165. target_engine: 目标数据库引擎
  166. source_table: 源表名
  167. target_table: 目标表名
  168. schema: 模式名称
  169. 返回:
  170. bool: 操作是否成功
  171. 异常:
  172. Exception: 当源表不存在或无法创建目标表时抛出异常
  173. """
  174. logger.info(f"检查目标表 {target_table} 是否存在,不存在则创建")
  175. try:
  176. # 检查目标表是否存在
  177. target_inspector = inspect(target_engine)
  178. target_exists = target_inspector.has_table(target_table, schema=schema)
  179. if target_exists:
  180. logger.info(f"目标表 {target_table} 已存在,无需创建")
  181. return True
  182. # 目标表不存在,从源表获取表结构
  183. source_inspector = inspect(source_engine)
  184. if not source_inspector.has_table(source_table):
  185. error_msg = f"源表 {source_table} 不存在"
  186. logger.error(error_msg)
  187. raise Exception(error_msg)
  188. # 获取源表的列信息
  189. source_columns = source_inspector.get_columns(source_table)
  190. if not source_columns:
  191. error_msg = f"源表 {source_table} 没有列信息"
  192. logger.error(error_msg)
  193. raise Exception(error_msg)
  194. # 创建元数据对象
  195. metadata = MetaData()
  196. # 定义目标表结构
  197. table_def = Table(
  198. target_table,
  199. metadata,
  200. *[Column(col['name'], col['type']) for col in source_columns],
  201. schema=schema
  202. )
  203. # 在目标数据库中创建表
  204. metadata.create_all(target_engine)
  205. logger.info(f"成功在目标数据库中创建表 {target_table}")
  206. return True
  207. except Exception as e:
  208. logger.error(f"创建表时出错: {str(e)}")
  209. logger.error(traceback.format_exc())
  210. raise Exception(f"创建表失败: {str(e)}")
  211. def load_data_from_source(table_name, exec_date=None, update_mode=None, script_name=None,
  212. schedule_frequency=None, is_manual_dag_trigger=False, **kwargs):
  213. """
  214. 从源数据库加载数据到目标数据库
  215. """
  216. start_time = datetime.now()
  217. logger.info(f"===== 开始从源加载数据 =====")
  218. logger.info(f"表名: {table_name}")
  219. logger.info(f"执行日期: {exec_date}")
  220. logger.info(f"更新模式: {update_mode}")
  221. logger.info(f"脚本名称: {script_name}")
  222. logger.info(f"调度频率: {schedule_frequency}")
  223. logger.info(f"是否手动DAG触发: {is_manual_dag_trigger}")
  224. for key, value in kwargs.items():
  225. logger.info(f"其他参数 - {key}: {value}")
  226. if exec_date is None:
  227. exec_date = datetime.now().strftime('%Y-%m-%d')
  228. logger.info(f"执行日期为空,使用当前日期: {exec_date}")
  229. if schedule_frequency is None:
  230. schedule_frequency = "daily"
  231. logger.info(f"调度频率为空,使用默认值: {schedule_frequency}")
  232. try:
  233. source_db_info = get_source_database_info(table_name, script_name)
  234. target_db_info = get_target_database_info()
  235. source_engine = get_sqlalchemy_engine(source_db_info)
  236. target_engine = get_sqlalchemy_engine(target_db_info)
  237. if not source_engine or not target_engine:
  238. raise Exception("无法创建数据库引擎,无法加载数据")
  239. source_table = source_db_info.get("source_table", table_name) or table_name
  240. if not create_table_if_not_exists(source_engine, target_engine, source_table, table_name):
  241. raise Exception(f"无法创建目标表 {table_name},无法加载数据")
  242. if update_mode == "full_refresh":
  243. logger.info(f"执行全量刷新,清空表 {table_name}")
  244. with target_engine.connect() as conn:
  245. conn.execute(text(f"TRUNCATE TABLE {table_name}"))
  246. logger.info(f"成功清空表 {table_name}")
  247. else:
  248. target_dt_column = get_target_dt_column(table_name, script_name)
  249. if not target_dt_column:
  250. logger.error(f"无法获取表 {table_name} 的目标日期列,无法执行增量加载")
  251. return False
  252. try:
  253. if is_manual_dag_trigger:
  254. start_date, end_date = get_date_range(exec_date, schedule_frequency)
  255. logger.info(f"手动DAG触发,日期范围: {start_date} 到 {end_date}")
  256. delete_sql = f"""
  257. DELETE FROM {table_name}
  258. WHERE {target_dt_column} >= '{start_date}'
  259. AND {target_dt_column} < '{end_date}'
  260. """
  261. with target_engine.connect() as conn:
  262. conn.execute(text(delete_sql))
  263. logger.info(f"成功删除表 {table_name} 中 {target_dt_column} 从 {start_date} 到 {end_date} 的数据")
  264. else:
  265. start_datetime, end_datetime = get_one_day_range(exec_date)
  266. start_date = start_datetime.strftime('%Y-%m-%d %H:%M:%S')
  267. end_date = end_datetime.strftime('%Y-%m-%d %H:%M:%S')
  268. logger.info(f"自动调度,日期范围: {start_date} 到 {end_date}")
  269. delete_sql = f"""
  270. DELETE FROM {table_name}
  271. WHERE create_time >= '{start_date}'
  272. AND create_time < '{end_date}'
  273. """
  274. try:
  275. with target_engine.connect() as conn:
  276. conn.execute(text(delete_sql))
  277. logger.info(f"成功删除表 {table_name} 中 create_time 从 {start_date} 到 {end_date} 的数据")
  278. except Exception as del_err:
  279. logger.error(f"删除数据时出错: {str(del_err)}")
  280. logger.warning("继续执行数据加载")
  281. except Exception as date_err:
  282. logger.error(f"计算日期范围时出错: {str(date_err)}")
  283. logger.error(traceback.format_exc())
  284. return False
  285. # 第6步:加载数据(这是你最关键的修复点)
  286. logger.info(f"执行查询构建")
  287. if update_mode == "full_refresh":
  288. query = f"SELECT * FROM {source_table}"
  289. else:
  290. start_date, end_date = get_date_range(exec_date, schedule_frequency)
  291. source_inspector = inspect(source_engine)
  292. source_columns = [col['name'].lower() for col in source_inspector.get_columns(source_table)]
  293. if target_dt_column.lower() in source_columns:
  294. query = f"""
  295. SELECT * FROM {source_table}
  296. WHERE {target_dt_column} >= '{start_date}'
  297. AND {target_dt_column} < '{end_date}'
  298. """
  299. else:
  300. logger.warning(f"源表 {source_table} 没有目标日期列 {target_dt_column},将加载全部数据")
  301. query = f"SELECT * FROM {source_table}"
  302. logger.info(f"执行查询: {query}")
  303. # with source_engine.connect() as conn:
  304. # df = pd.read_sql(query, conn)
  305. #df = pd.read_sql(query, source_engine)
  306. from sqlalchemy.orm import sessionmaker
  307. Session = sessionmaker(bind=source_engine)
  308. session = Session()
  309. try:
  310. df = pd.read_sql(query, session.connection())
  311. finally:
  312. session.close()
  313. if df.empty:
  314. logger.warning(f"查询结果为空,没有数据需要加载")
  315. return True
  316. if 'create_time' not in df.columns:
  317. df['create_time'] = datetime.now()
  318. logger.info(f"开始写入数据到目标表 {table_name},共 {len(df)} 行")
  319. df.to_sql(
  320. name=table_name,
  321. con=target_engine,
  322. if_exists='append',
  323. index=False,
  324. schema=target_db_info.get("schema", "public")
  325. )
  326. logger.info(f"成功写入数据到目标表 {table_name}")
  327. return True
  328. except Exception as e:
  329. logger.error(f"执行数据加载过程时出错: {str(e)}")
  330. logger.error(traceback.format_exc())
  331. raise Exception(f"数据加载失败: {str(e)}")
  332. finally:
  333. end_time = datetime.now()
  334. duration = (end_time - start_time).total_seconds()
  335. logger.info(f"数据加载过程结束,耗时: {int(duration // 60)}分钟 {int(duration % 60)}秒")
  336. logger.info(f"===== 数据加载结束 =====")
  337. def run(table_name, update_mode, schedule_frequency=None, script_name=None, exec_date=None, is_manual_dag_trigger=False, **kwargs):
  338. """
  339. 统一入口函数,符合Airflow动态脚本调用规范
  340. 参数:
  341. table_name (str): 要处理的表名
  342. update_mode (str): 更新模式 (append/full_refresh)
  343. schedule_frequency (str): 调度频率
  344. script_name (str): 脚本名称
  345. exec_date: 执行日期
  346. is_manual_dag_trigger (bool): 是否手动DAG触发
  347. **kwargs: 其他可能的参数
  348. 返回:
  349. bool: 执行成功返回True,否则抛出异常
  350. """
  351. logger.info(f"开始执行脚本...")
  352. # 获取当前脚本的文件名(如果没有传入)
  353. if script_name is None:
  354. script_name = os.path.basename(__file__)
  355. # 打印所有传入的参数
  356. logger.info(f"===== 传入参数信息 =====")
  357. logger.info(f"table_name: {table_name}")
  358. logger.info(f"update_mode: {update_mode}")
  359. logger.info(f"schedule_frequency: {schedule_frequency}")
  360. logger.info(f"exec_date: {exec_date}")
  361. logger.info(f"script_name: {script_name}")
  362. logger.info(f"is_manual_dag_trigger: {is_manual_dag_trigger}")
  363. # 打印所有可能的额外参数
  364. for key, value in kwargs.items():
  365. logger.info(f"额外参数 - {key}: {value}")
  366. logger.info(f"========================")
  367. # 实际调用内部处理函数,不再捕获异常,让异常直接传递给上层调用者
  368. return load_data_from_source(
  369. table_name=table_name,
  370. exec_date=exec_date,
  371. update_mode=update_mode,
  372. script_name=script_name,
  373. schedule_frequency=schedule_frequency,
  374. is_manual_dag_trigger=is_manual_dag_trigger,
  375. **kwargs
  376. )
  377. if __name__ == "__main__":
  378. # 直接执行时调用统一入口函数,传入测试参数
  379. run(
  380. table_name="test_table",
  381. update_mode="append",
  382. schedule_frequency="daily",
  383. exec_date=datetime.now().strftime('%Y-%m-%d'),
  384. script_name=os.path.basename(__file__),
  385. is_manual_dag_trigger=True
  386. )