db_cleanup.py

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
This module took inspiration from the community maintenance dag.

See:
(https://github.com/teamclairvoyant/airflow-maintenance-dags/blob/4e5c7682a808082561d60cbc9cafaa477b0d8c65/db-cleanup/airflow-db-cleanup.py).
"""
from __future__ import annotations

import csv
import logging
import os
from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

from sqlalchemy import and_, column, false, func, inspect, select, table, text
from sqlalchemy.exc import OperationalError, ProgrammingError
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.orm import aliased
from sqlalchemy.sql.expression import ClauseElement, Executable, tuple_

from airflow.cli.simple_table import AirflowConsole
from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.utils import timezone
from airflow.utils.db import reflect_tables
from airflow.utils.helpers import ask_yesno
from airflow.utils.session import NEW_SESSION, provide_session

if TYPE_CHECKING:
    from pendulum import DateTime
    from sqlalchemy.orm import Query, Session

    from airflow.models import Base

logger = logging.getLogger(__name__)

ARCHIVE_TABLE_PREFIX = "_airflow_deleted__"

@dataclass
class _TableConfig:
    """
    Config class for performing cleanup on a table.

    :param table_name: the table
    :param extra_columns: any columns besides recency_column_name that we'll need in queries
    :param recency_column_name: date column to filter by
    :param keep_last: whether the last record should be kept even if it's older than clean_before_timestamp
    :param keep_last_filters: the "keep last" functionality will preserve the most recent record
        in the table. To ignore certain records even if they are the latest in the table, you can
        supply additional filters here (e.g. externally triggered DAG runs)
    :param keep_last_group_by: if keeping the last record, can keep the last record for each group
    """

    table_name: str
    recency_column_name: str
    extra_columns: list[str] | None = None
    keep_last: bool = False
    keep_last_filters: Any | None = None
    keep_last_group_by: Any | None = None

    def __post_init__(self):
        self.recency_column = column(self.recency_column_name)
        self.orm_model: Base = table(
            self.table_name, *[column(x) for x in self.extra_columns or []], self.recency_column
        )

    def __lt__(self, other):
        return self.table_name < other.table_name

    @property
    def readable_config(self):
        return {
            "table": self.orm_model.name,
            "recency_column": str(self.recency_column),
            "keep_last": self.keep_last,
            "keep_last_filters": [str(x) for x in self.keep_last_filters] if self.keep_last_filters else None,
            "keep_last_group_by": str(self.keep_last_group_by),
        }

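# Illustrative: a minimal config renders a human-readable summary via
# ``readable_config``; e.g. _TableConfig(table_name="log", recency_column_name="dttm")
# yields {"table": "log", "recency_column": "dttm", "keep_last": False,
# "keep_last_filters": None, "keep_last_group_by": "None"}.
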
config_list: list[_TableConfig] = [
    _TableConfig(table_name="job", recency_column_name="latest_heartbeat"),
    _TableConfig(table_name="dag", recency_column_name="last_parsed_time"),
    _TableConfig(
        table_name="dag_run",
        recency_column_name="start_date",
        extra_columns=["dag_id", "external_trigger"],
        keep_last=True,
        keep_last_filters=[column("external_trigger") == false()],
        keep_last_group_by=["dag_id"],
    ),
    _TableConfig(table_name="dataset_event", recency_column_name="timestamp"),
    _TableConfig(table_name="import_error", recency_column_name="timestamp"),
    _TableConfig(table_name="log", recency_column_name="dttm"),
    _TableConfig(table_name="sla_miss", recency_column_name="timestamp"),
    _TableConfig(table_name="task_fail", recency_column_name="start_date"),
    _TableConfig(table_name="task_instance", recency_column_name="start_date"),
    _TableConfig(table_name="task_instance_history", recency_column_name="start_date"),
    _TableConfig(table_name="task_reschedule", recency_column_name="start_date"),
    _TableConfig(table_name="xcom", recency_column_name="timestamp"),
    _TableConfig(table_name="callback_request", recency_column_name="created_at"),
    _TableConfig(table_name="celery_taskmeta", recency_column_name="date_done"),
    _TableConfig(table_name="celery_tasksetmeta", recency_column_name="date_done"),
    _TableConfig(table_name="trigger", recency_column_name="created_date"),
]

if conf.get("webserver", "session_backend") == "database":
    config_list.append(_TableConfig(table_name="session", recency_column_name="expiry"))

config_dict: dict[str, _TableConfig] = {x.orm_model.name: x for x in sorted(config_list)}

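# Illustrative: config_dict is keyed by table name, in sorted order, e.g.
# sorted(config_dict)[:3] == ["callback_request", "celery_taskmeta", "celery_tasksetmeta"].
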
def _check_for_rows(*, query: Query, print_rows=False):
    num_entities = query.count()
    print(f"Found {num_entities} rows meeting deletion criteria.")
    if print_rows:
        max_rows_to_print = 100
        if num_entities > 0:
            print(f"Printing first {max_rows_to_print} rows.")
        logger.debug("print entities query: %s", query)
        for entry in query.limit(max_rows_to_print):
            print(entry.__dict__)
    return num_entities

def _dump_table_to_file(*, target_table, file_path, export_format, session):
    if export_format == "csv":
        with open(file_path, "w") as f:
            csv_writer = csv.writer(f)
            cursor = session.execute(text(f"SELECT * FROM {target_table}"))
            csv_writer.writerow(cursor.keys())
            csv_writer.writerows(cursor.fetchall())
    else:
        raise AirflowException(f"Export format {export_format} is not supported.")

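# Note: "csv" is the only export format this helper currently supports; the
# output file starts with a header row of column names followed by all rows.
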
def _do_delete(*, query, orm_model, skip_archive, session):
    import re2

    print("Performing Delete...")
    # Bulk "delete": first create a new archive table and copy the doomed rows into it...
    timestamp_str = re2.sub(r"[^\d]", "", timezone.utcnow().isoformat())[:14]
    target_table_name = f"{ARCHIVE_TABLE_PREFIX}{orm_model.name}__{timestamp_str}"
    print(f"Moving data to table {target_table_name}")
    bind = session.get_bind()
    dialect_name = bind.dialect.name
    if dialect_name == "mysql":
        # MySQL with replication needs this split into two queries, so just do it for all MySQL
        # ERROR 1786 (HY000): Statement violates GTID consistency: CREATE TABLE ... SELECT.
        session.execute(text(f"CREATE TABLE {target_table_name} LIKE {orm_model.name}"))
        metadata = reflect_tables([target_table_name], session)
        target_table = metadata.tables[target_table_name]
        insert_stm = target_table.insert().from_select(target_table.c, query)
        logger.debug("insert statement:\n%s", insert_stm.compile())
        session.execute(insert_stm)
    else:
        stmt = CreateTableAs(target_table_name, query.selectable)
        logger.debug("ctas query:\n%s", stmt.compile())
        session.execute(stmt)
    session.commit()

    # ...then delete the rows from the old table
    metadata = reflect_tables([orm_model.name, target_table_name], session)
    source_table = metadata.tables[orm_model.name]
    target_table = metadata.tables[target_table_name]
    logger.debug("rows moved; purging from %s", source_table.name)
    if dialect_name == "sqlite":
        pk_cols = source_table.primary_key.columns
        delete = source_table.delete().where(
            tuple_(*pk_cols).in_(select(*[target_table.c[x.name] for x in source_table.primary_key.columns]))
        )
    else:
        delete = source_table.delete().where(
            and_(col == target_table.c[col.name] for col in source_table.primary_key.columns)
        )
    logger.debug("delete statement:\n%s", delete.compile())
    session.execute(delete)
    session.commit()
    if skip_archive:
        bind = session.get_bind()
        target_table.drop(bind=bind)
    session.commit()
    print("Finished Performing Delete")

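# Illustrative: archiving "log" at 2024-01-01 12:34:56 UTC produces an archive
# table named "_airflow_deleted__log__20240101123456" (prefix + source table +
# a 14-digit UTC timestamp).
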
def _subquery_keep_last(*, recency_column, keep_last_filters, group_by_columns, max_date_colname, session):
    subquery = select(*group_by_columns, func.max(recency_column).label(max_date_colname))

    if keep_last_filters is not None:
        for entry in keep_last_filters:
            subquery = subquery.filter(entry)

    if group_by_columns is not None:
        subquery = subquery.group_by(*group_by_columns)

    return subquery.subquery(name="latest")

class CreateTableAs(Executable, ClauseElement):
    """Custom sqlalchemy clause element for CTAS operations."""

    inherit_cache = False

    def __init__(self, name, query):
        self.name = name
        self.query = query


@compiles(CreateTableAs)
def _compile_create_table_as__other(element, compiler, **kw):
    return f"CREATE TABLE {element.name} AS {compiler.process(element.query)}"

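# Illustrative: for a query selecting old log rows, the compiled statement looks
# roughly like
#   CREATE TABLE _airflow_deleted__log__20240101123456 AS
#   SELECT base.* FROM log AS base WHERE base.dttm < :dttm_1
# (the exact SELECT text depends on the query passed in).
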
def _build_query(
    *,
    orm_model,
    recency_column,
    keep_last,
    keep_last_filters,
    keep_last_group_by,
    clean_before_timestamp,
    session,
    **kwargs,
):
    base_table_alias = "base"
    base_table = aliased(orm_model, name=base_table_alias)
    query = session.query(base_table).with_entities(text(f"{base_table_alias}.*"))
    base_table_recency_col = base_table.c[recency_column.name]
    conditions = [base_table_recency_col < clean_before_timestamp]
    if keep_last:
        max_date_col_name = "max_date_per_group"
        group_by_columns = [column(x) for x in keep_last_group_by]
        subquery = _subquery_keep_last(
            recency_column=recency_column,
            keep_last_filters=keep_last_filters,
            group_by_columns=group_by_columns,
            max_date_colname=max_date_col_name,
            session=session,
        )
        query = query.select_from(base_table).outerjoin(
            subquery,
            and_(
                *[base_table.c[x] == subquery.c[x] for x in keep_last_group_by],
                base_table_recency_col == column(max_date_col_name),
            ),
        )
        conditions.append(column(max_date_col_name).is_(None))
    query = query.filter(and_(*conditions))
    return query

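# Illustrative: for dag_run (keep_last=True, grouped by dag_id) the generated
# query is roughly equivalent to:
#   SELECT base.* FROM dag_run AS base
#   LEFT OUTER JOIN (
#       SELECT dag_id, max(start_date) AS max_date_per_group
#       FROM dag_run WHERE external_trigger = false GROUP BY dag_id
#   ) AS latest ON base.dag_id = latest.dag_id
#             AND base.start_date = max_date_per_group
#   WHERE base.start_date < :cutoff AND max_date_per_group IS NULL
# i.e. each DAG's most recent scheduled run is never selected for deletion.
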
def _cleanup_table(
    *,
    orm_model,
    recency_column,
    keep_last,
    keep_last_filters,
    keep_last_group_by,
    clean_before_timestamp,
    dry_run=True,
    verbose=False,
    skip_archive=False,
    session,
    **kwargs,
):
    print()
    if dry_run:
        print(f"Performing dry run for table {orm_model.name}")
    query = _build_query(
        orm_model=orm_model,
        recency_column=recency_column,
        keep_last=keep_last,
        keep_last_filters=keep_last_filters,
        keep_last_group_by=keep_last_group_by,
        clean_before_timestamp=clean_before_timestamp,
        session=session,
    )
    logger.debug("old rows query:\n%s", query.selectable.compile())
    print(f"Checking table {orm_model.name}")
    num_rows = _check_for_rows(query=query, print_rows=False)

    if num_rows and not dry_run:
        _do_delete(query=query, orm_model=orm_model, skip_archive=skip_archive, session=session)

    session.commit()

def _confirm_delete(*, date: DateTime, tables: list[str]):
    for_tables = f" for tables {tables!r}" if tables else ""
    question = (
        f"You have requested that we purge all data prior to {date}{for_tables}.\n"
        f"This is irreversible. Consider backing up the tables first and / or doing a dry run "
        f"with option --dry-run.\n"
        f"Enter 'delete rows' (without quotes) to proceed."
    )
    print(question)
    answer = input().strip()
    if answer != "delete rows":
        raise SystemExit("User did not confirm; exiting.")

def _confirm_drop_archives(*, tables: list[str]):
    # If more than 3 tables are selected, show a count instead of listing them all
    if len(tables) > 3:
        text_ = f"{len(tables)} archived tables prefixed with {ARCHIVE_TABLE_PREFIX}"
    else:
        text_ = f"the following archived tables: {', '.join(tables)}"
    question = (
        f"You have requested that we drop {text_}.\n"
        f"This is irreversible. Consider backing up the tables first.\n"
    )
    print(question)
    if len(tables) > 3:
        show_tables = ask_yesno("Show tables that will be dropped? (y/n): ")
        if show_tables:
            for table in tables:
                print(f"  {table}")
            print("\n")
    answer = input("Enter 'drop archived tables' (without quotes) to proceed.\n").strip()
    if answer != "drop archived tables":
        raise SystemExit("User did not confirm; exiting.")

def _print_config(*, configs: dict[str, _TableConfig]):
    data = [x.readable_config for x in configs.values()]
    AirflowConsole().print_as_table(data=data)

@contextmanager
def _suppress_with_logging(table, session):
    """Suppress OperationalError and ProgrammingError, logging them and rolling back the session if needed."""
    try:
        yield
    except (OperationalError, ProgrammingError):
        logger.warning("Encountered error when attempting to clean table '%s'.", table)
        logger.debug("Traceback for table '%s'", table, exc_info=True)
        if session.is_active:
            logger.debug("Rolling back transaction")
            session.rollback()

def _effective_table_names(*, table_names: list[str] | None):
    desired_table_names = set(table_names or config_dict)
    effective_config_dict = {k: v for k, v in config_dict.items() if k in desired_table_names}
    effective_table_names = set(effective_config_dict)
    if desired_table_names != effective_table_names:
        outliers = desired_table_names - effective_table_names
        logger.warning(
            "The following table(s) are not valid choices and will be skipped: %s", sorted(outliers)
        )
    if not effective_table_names:
        raise SystemExit("No tables selected for db cleanup. Please choose valid table names.")
    return effective_table_names, effective_config_dict

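# Illustrative: _effective_table_names(table_names=["log", "nonexistent"]) warns
# that "nonexistent" is not a valid choice and returns
# ({"log"}, {"log": <the log _TableConfig>}).
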
def _get_archived_table_names(table_names, session):
    inspector = inspect(session.bind)
    db_table_names = [x for x in inspector.get_table_names() if x.startswith(ARCHIVE_TABLE_PREFIX)]
    effective_table_names, _ = _effective_table_names(table_names=table_names)
    # Keep only the archive tables whose source table is among the selected table names
    archived_table_names = [
        table_name
        for table_name in db_table_names
        if any("__" + x + "__" in table_name for x in effective_table_names)
    ]
    return archived_table_names

@provide_session
def run_cleanup(
    *,
    clean_before_timestamp: DateTime,
    table_names: list[str] | None = None,
    dry_run: bool = False,
    verbose: bool = False,
    confirm: bool = True,
    skip_archive: bool = False,
    session: Session = NEW_SESSION,
):
    """
    Purges old records in the Airflow metadata database.

    The last non-externally-triggered DAG run will always be kept in order to ensure
    continuity of scheduled DAG runs.

    Where there are foreign key relationships, deletes will cascade, so that for
    example if you clean up old DAG runs, the associated task instances will
    be deleted.

    :param clean_before_timestamp: The timestamp before which data should be purged
    :param table_names: Optional. List of table names to perform maintenance on. If list not provided,
        will perform maintenance on all tables.
    :param dry_run: If true, print rows meeting deletion criteria
    :param verbose: If true, may provide more detailed output.
    :param confirm: Require user input to confirm before processing deletions.
    :param skip_archive: Set to True if you don't want the purged rows preserved in an archive table.
    :param session: Session representing connection to the metadata database.
    """
    clean_before_timestamp = timezone.coerce_datetime(clean_before_timestamp)
    effective_table_names, effective_config_dict = _effective_table_names(table_names=table_names)
    if dry_run:
        print("Performing dry run for db cleanup.")
        print(
            f"Data prior to {clean_before_timestamp} would be purged "
            f"from tables {effective_table_names} with the following config:\n"
        )
        _print_config(configs=effective_config_dict)
    if not dry_run and confirm:
        _confirm_delete(date=clean_before_timestamp, tables=sorted(effective_table_names))
    existing_tables = reflect_tables(tables=None, session=session).tables
    for table_name, table_config in effective_config_dict.items():
        if table_name in existing_tables:
            with _suppress_with_logging(table_name, session):
                _cleanup_table(
                    clean_before_timestamp=clean_before_timestamp,
                    dry_run=dry_run,
                    verbose=verbose,
                    **table_config.__dict__,
                    skip_archive=skip_archive,
                    session=session,
                )
                session.commit()
        else:
            logger.warning("Table %s not found. Skipping.", table_name)

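# Illustrative usage (roughly what the `airflow db clean` CLI command drives):
#
#   import pendulum
#   run_cleanup(
#       clean_before_timestamp=pendulum.now("UTC").subtract(days=90),
#       table_names=["log", "dag_run"],
#       dry_run=True,  # report what would be purged without deleting anything
#   )
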
@provide_session
def export_archived_records(
    export_format,
    output_path,
    table_names=None,
    drop_archives=False,
    needs_confirm=True,
    session: Session = NEW_SESSION,
):
    """Export archived data to the given output path in the given format."""
    archived_table_names = _get_archived_table_names(table_names, session)
    # If the user chose to drop archives, check that archive tables exist before asking for confirmation
    if drop_archives and archived_table_names and needs_confirm:
        _confirm_drop_archives(tables=sorted(archived_table_names))
    export_count = 0
    dropped_count = 0
    for table_name in archived_table_names:
        logger.info("Exporting table %s", table_name)
        _dump_table_to_file(
            target_table=table_name,
            file_path=os.path.join(output_path, f"{table_name}.{export_format}"),
            export_format=export_format,
            session=session,
        )
        export_count += 1
        if drop_archives:
            logger.info("Dropping archived table %s", table_name)
            session.execute(text(f"DROP TABLE {table_name}"))
            dropped_count += 1
    logger.info("Total exported tables: %s, Total dropped tables: %s", export_count, dropped_count)

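# Illustrative: export every archive table as CSV into a directory, dropping
# each archive after it is written out:
#
#   export_archived_records(
#       export_format="csv",
#       output_path="/tmp/airflow_archives",  # hypothetical path
#       drop_archives=True,
#       needs_confirm=False,
#   )
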
@provide_session
def drop_archived_tables(table_names, needs_confirm, session):
    """Drop archived tables."""
    archived_table_names = _get_archived_table_names(table_names, session)
    if needs_confirm and archived_table_names:
        _confirm_drop_archives(tables=sorted(archived_table_names))
    dropped_count = 0
    for table_name in archived_table_names:
        logger.info("Dropping archived table %s", table_name)
        session.execute(text(f"DROP TABLE {table_name}"))
        dropped_count += 1
    logger.info("Total dropped tables: %s", dropped_count)