setup_teardown.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
  1. # Licensed to the Apache Software Foundation (ASF) under one
  2. # or more contributor license agreements. See the NOTICE file
  3. # distributed with this work for additional information
  4. # regarding copyright ownership. The ASF licenses this file
  5. # to you under the Apache License, Version 2.0 (the
  6. # "License"); you may not use this file except in compliance
  7. # with the License. You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing,
  12. # software distributed under the License is distributed on an
  13. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  14. # KIND, either express or implied. See the License for the
  15. # specific language governing permissions and limitations
  16. # under the License.
  17. from __future__ import annotations
  18. from typing import TYPE_CHECKING, cast
  19. from airflow.exceptions import AirflowException
  20. if TYPE_CHECKING:
  21. from airflow.models.abstractoperator import AbstractOperator
  22. from airflow.models.taskmixin import DependencyMixin
  23. from airflow.models.xcom_arg import PlainXComArg
  class BaseSetupTeardownContext:
      """
      Context manager for setup/teardown tasks.

      All state is kept in class attributes, so it is shared by every user of the
      class. Nested contexts are supported through the ``_previous_*`` stacks: entering
      an inner context pushes the outer value, leaving it pops the value back.

      :meta private:
      """

      # True while at least one setup/teardown context is being managed.
      active: bool = False
      # Work tasks registered per context, keyed by the managing task
      # (or by a tuple of tasks when a list of tasks manages the context).
      context_map: dict[AbstractOperator | tuple[AbstractOperator], list[AbstractOperator]] = {}
      # Setup task(s) managing the current context, and the stack of enclosing ones.
      _context_managed_setup_task: AbstractOperator | list[AbstractOperator] = []
      _previous_context_managed_setup_task: list[AbstractOperator | list[AbstractOperator]] = []
      # Teardown task(s) managing the current context, and the stack of enclosing ones.
      _context_managed_teardown_task: AbstractOperator | list[AbstractOperator] = []
      _previous_context_managed_teardown_task: list[AbstractOperator | list[AbstractOperator]] = []
      # Teardowns discovered from the managing setup, and the stack of enclosing ones.
      _teardown_downstream_of_setup: AbstractOperator | list[AbstractOperator] = []
      _previous_teardown_downstream_of_setup: list[AbstractOperator | list[AbstractOperator]] = []
      # Setups discovered from the managing teardown, and the stack of enclosing ones.
      _setup_upstream_of_teardown: AbstractOperator | list[AbstractOperator] = []
      _previous_setup_upstream_of_teardown: list[AbstractOperator | list[AbstractOperator]] = []

      @classmethod
      def push_context_managed_setup_task(cls, task: AbstractOperator | list[AbstractOperator]):
          """
          Make ``task`` the setup managing the current context.

          A different, already active setup is pushed onto the stack of enclosing setups.

          :param task: setup task, or list of setup tasks, entering the context
          """
          setup_task = cls._context_managed_setup_task
          if setup_task and setup_task != task:
              cls._previous_context_managed_setup_task.append(cls._context_managed_setup_task)
          cls._context_managed_setup_task = task

      @classmethod
      def push_context_managed_teardown_task(cls, task: AbstractOperator | list[AbstractOperator]):
          """
          Make ``task`` the teardown managing the current context.

          A different, already active teardown is pushed onto the stack of enclosing teardowns.

          :param task: teardown task, or list of teardown tasks, entering the context
          """
          teardown_task = cls._context_managed_teardown_task
          if teardown_task and teardown_task != task:
              cls._previous_context_managed_teardown_task.append(cls._context_managed_teardown_task)
          cls._context_managed_teardown_task = task

      @classmethod
      def pop_context_managed_setup_task(cls) -> AbstractOperator | list[AbstractOperator]:
          """
          Leave the current setup context and restore the enclosing one, if any.

          When an enclosing setup is restored, it is set upstream of the setup being left.

          :return: the setup task(s) that managed the context being left
          """
          old_setup_task = cls._context_managed_setup_task
          if cls._previous_context_managed_setup_task:
              cls._context_managed_setup_task = cls._previous_context_managed_setup_task.pop()
              setup_task = cls._context_managed_setup_task
              if setup_task and old_setup_task:
                  cls.set_dependency(old_setup_task, setup_task, upstream=False)
          else:
              cls._context_managed_setup_task = []
          return old_setup_task

      @classmethod
      def pop_context_managed_teardown_task(cls) -> AbstractOperator | list[AbstractOperator]:
          """
          Leave the current teardown context and restore the enclosing one, if any.

          When an enclosing teardown is restored, the teardown being left is set upstream of it.

          :return: the teardown task(s) that managed the context being left
          """
          old_teardown_task = cls._context_managed_teardown_task
          if cls._previous_context_managed_teardown_task:
              cls._context_managed_teardown_task = cls._previous_context_managed_teardown_task.pop()
              teardown_task = cls._context_managed_teardown_task
              if teardown_task and old_teardown_task:
                  cls.set_dependency(old_teardown_task, teardown_task)
          else:
              cls._context_managed_teardown_task = []
          return old_teardown_task

      @classmethod
      def pop_teardown_downstream_of_setup(cls) -> AbstractOperator | list[AbstractOperator]:
          """
          Restore the discovered teardowns of the enclosing context, if any.

          When an enclosing value is restored, the teardowns being left are set upstream of it.

          :return: the teardowns that were current before the pop
          """
          old_teardown_task = cls._teardown_downstream_of_setup
          if cls._previous_teardown_downstream_of_setup:
              cls._teardown_downstream_of_setup = cls._previous_teardown_downstream_of_setup.pop()
              teardown_task = cls._teardown_downstream_of_setup
              if teardown_task and old_teardown_task:
                  cls.set_dependency(old_teardown_task, teardown_task)
          else:
              cls._teardown_downstream_of_setup = []
          return old_teardown_task

      @classmethod
      def pop_setup_upstream_of_teardown(cls) -> AbstractOperator | list[AbstractOperator]:
          """
          Restore the discovered setups of the enclosing context, if any.

          When an enclosing value is restored, it is set upstream of the setups being left.

          :return: the setups that were current before the pop
          """
          old_setup_task = cls._setup_upstream_of_teardown
          if cls._previous_setup_upstream_of_teardown:
              cls._setup_upstream_of_teardown = cls._previous_setup_upstream_of_teardown.pop()
              setup_task = cls._setup_upstream_of_teardown
              if setup_task and old_setup_task:
                  cls.set_dependency(old_setup_task, setup_task, upstream=False)
          else:
              cls._setup_upstream_of_teardown = []
          return old_setup_task

      @classmethod
      def set_dependency(
          cls,
          receiving_task: AbstractOperator | list[AbstractOperator],
          new_task: AbstractOperator | list[AbstractOperator],
          upstream=True,
      ):
          """
          Relate ``new_task`` (one task, or each task of a list/tuple) to ``receiving_task``.

          :param receiving_task: task(s) passed to ``set_upstream`` / ``set_downstream``
          :param new_task: task(s) on which ``set_upstream`` / ``set_downstream`` is called
          :param upstream: if True ``receiving_task`` becomes upstream of ``new_task``,
              otherwise it becomes downstream of it
          """
          if isinstance(new_task, (list, tuple)):
              for task in new_task:
                  cls._set_dependency(task, receiving_task, upstream)
          else:
              cls._set_dependency(new_task, receiving_task, upstream)

      @staticmethod
      def _set_dependency(task, receiving_task, upstream):
          """Set ``receiving_task`` upstream (or downstream) of the single ``task``."""
          if upstream:
              task.set_upstream(receiving_task)
          else:
              task.set_downstream(receiving_task)

      @classmethod
      def update_context_map(cls, task: DependencyMixin):
          """
          Register a work task under the setup and/or teardown managing the current context.

          Setup and teardown tasks themselves are never registered.

          :param task: the task created inside the context
          """
          # A string forward reference is enough for ``cast``; no runtime import needed.
          task_ = cast("AbstractOperator", task)
          if task_.is_setup or task_.is_teardown:
              return
          for managing in (cls._context_managed_setup_task, cls._context_managed_teardown_task):
              if not managing:
                  continue
              # Lists are not hashable, so a list of managing tasks is keyed by its tuple.
              key = tuple(managing) if isinstance(managing, list) else managing
              cls.context_map.setdefault(key, []).append(task_)

      @classmethod
      def push_setup_teardown_task(cls, operator: AbstractOperator | list[AbstractOperator]):
          """
          Enter a context managed by ``operator`` and mark the context manager active.

          For a list, the kind (setup or teardown) is decided by its first element, so
          the list must not be empty. A task that is neither setup nor teardown pushes
          nothing but still activates the context.

          :param operator: setup/teardown task, or list of them, used in the ``with`` statement
          """
          first = operator[0] if isinstance(operator, list) else operator
          if first.is_teardown:
              cls._push_tasks(operator)
          elif first.is_setup:
              cls._push_tasks(operator, setup=True)
          cls.active = True

      @classmethod
      def _push_tasks(cls, operator: AbstractOperator | list[AbstractOperator], setup: bool = False):
          """
          Push ``operator`` as managing setup or teardown and discover its counterpart.

          :param operator: task(s) entering the context
          :param setup: True when ``operator`` is a setup, False when it is a teardown
          :raises ValueError: if a list mixes tasks with different ``is_setup`` values
          """
          if isinstance(operator, list):
              if any(task.is_setup != operator[0].is_setup for task in operator):
                  cls.error("All tasks in the list must be either setup or teardown tasks")
          if setup:
              cls.push_context_managed_setup_task(operator)
              # workout the teardown
              cls._update_teardown_downstream(operator)
          else:
              cls.push_context_managed_teardown_task(operator)
              # workout the setups
              cls._update_setup_upstream(operator)

      @staticmethod
      def _find_nearest_tasks(origin, start_tasks, predicate):
          """
          Search outwards from ``start_tasks`` and return the nearest tasks matching ``predicate``.

          The search expands level by level through both ``downstream_list`` and
          ``upstream_list`` and stops at the first level containing a match. ``origin``
          is never expanded into, and every task is visited at most once, so the search
          terminates even when no matching task is reachable.

          :return: matching tasks of the nearest level (each listed once), or ``[]``
          """
          seen: set[int] = set()
          frontier = []
          for task in start_tasks:
              if id(task) not in seen:
                  seen.add(id(task))
                  frontier.append(task)
          while frontier:
              matches = [task for task in frontier if predicate(task)]
              if matches:
                  return matches
              next_frontier = []
              for task in frontier:
                  for neighbour in task.downstream_list + task.upstream_list:
                      if neighbour != origin and id(neighbour) not in seen:
                          seen.add(id(neighbour))
                          next_frontier.append(neighbour)
              frontier = next_frontier
          return []

      @classmethod
      def _update_teardown_downstream(cls, operator: AbstractOperator | list[AbstractOperator]):
          """
          Go through the tasks downstream of the setup in the context manager.

          The nearest teardowns found become ``_teardown_downstream_of_setup``; a
          different, non-empty previous value is pushed onto its stack first. For a
          list of setups only the first one is used as the starting point.
          """
          origin = operator[0] if isinstance(operator, list) else operator
          teardowns = cls._find_nearest_tasks(origin, origin.downstream_list, lambda task: task.is_teardown)
          teardown_task = cls._teardown_downstream_of_setup
          if teardown_task and teardown_task != teardowns:
              cls._previous_teardown_downstream_of_setup.append(cls._teardown_downstream_of_setup)
          cls._teardown_downstream_of_setup = teardowns

      @classmethod
      def _update_setup_upstream(cls, operator: AbstractOperator | list[AbstractOperator]):
          """
          Go through the tasks upstream of the teardown task in the context manager.

          The nearest setups found become ``_setup_upstream_of_teardown``; a different,
          non-empty previous value is pushed onto its stack first. For a list of
          teardowns only the first one is used as the starting point.
          """
          origin = operator[0] if isinstance(operator, list) else operator
          setups = cls._find_nearest_tasks(origin, origin.upstream_list, lambda task: task.is_setup)
          setup_task = cls._setup_upstream_of_teardown
          if setup_task and setup_task != setups:
              cls._previous_setup_upstream_of_teardown.append(cls._setup_upstream_of_teardown)
          cls._setup_upstream_of_teardown = setups

      @classmethod
      def set_teardown_task_as_leaves(cls, leaves):
          """
          Put the discovered teardowns downstream of the context.

          They follow the managing teardown when there is one, otherwise ``leaves``.

          :param leaves: work task(s) with nothing downstream inside the context
          """
          teardown_task = cls._teardown_downstream_of_setup
          if cls._context_managed_teardown_task:
              cls.set_dependency(cls._context_managed_teardown_task, teardown_task)
          else:
              cls.set_dependency(leaves, teardown_task)

      @classmethod
      def set_setup_task_as_roots(cls, roots):
          """
          Put the discovered setups upstream of the context.

          They precede the managing setup when there is one, otherwise ``roots``.

          :param roots: work task(s) with nothing upstream inside the context
          """
          setup_task = cls._setup_upstream_of_teardown
          if cls._context_managed_setup_task:
              cls.set_dependency(cls._context_managed_setup_task, setup_task, upstream=False)
          else:
              cls.set_dependency(roots, setup_task, upstream=False)

      @classmethod
      def set_work_task_roots_and_leaves(cls):
          """
          Set the work task roots and leaves.

          Work tasks registered for the current context are wired after the managing
          setup and before the managing teardown, the context is then popped and the
          context manager is marked inactive.
          """
          if setup_task := cls._context_managed_setup_task:
              if isinstance(setup_task, list):
                  setup_task = tuple(setup_task)
              tasks_in_context = [
                  x for x in cls.context_map.get(setup_task, []) if not x.is_teardown and not x.is_setup
              ]
              if tasks_in_context:
                  roots = [task for task in tasks_in_context if not task.upstream_list]
                  if not roots:
                      # Every work task already has an upstream: use the first registered one.
                      setup_task >> tasks_in_context[0]
                  else:
                      cls.set_dependency(roots, setup_task, upstream=False)
                  leaves = [task for task in tasks_in_context if not task.downstream_list]
                  if not leaves:
                      # Every work task already has a downstream: use the last registered one.
                      leaves = tasks_in_context[-1]
                  cls.set_teardown_task_as_leaves(leaves)

          if teardown_task := cls._context_managed_teardown_task:
              if isinstance(teardown_task, list):
                  teardown_task = tuple(teardown_task)
              tasks_in_context = [
                  x for x in cls.context_map.get(teardown_task, []) if not x.is_teardown and not x.is_setup
              ]
              if tasks_in_context:
                  leaves = [task for task in tasks_in_context if not task.downstream_list]
                  if not leaves:
                      # Every work task already has a downstream: use the last registered one.
                      teardown_task << tasks_in_context[-1]
                  else:
                      cls.set_dependency(leaves, teardown_task)
                  roots = [task for task in tasks_in_context if not task.upstream_list]
                  if not roots:
                      # Every work task already has an upstream: use the first registered one.
                      roots = tasks_in_context[0]
                  cls.set_setup_task_as_roots(roots)
          cls.set_setup_teardown_relationships()
          cls.active = False

      @classmethod
      def set_setup_teardown_relationships(cls):
          """
          Set relationship between setup to setup and teardown to teardown.

          .. code-block:: python

              with setuptask >> teardowntask:
                  with setuptask2 >> teardowntask2:
                      ...

          We set setuptask >> setuptask2, teardowntask2 >> teardowntask

          The context being left is popped from every stack and its entries are
          removed from ``context_map``.
          """
          setup_task = cls.pop_context_managed_setup_task()
          teardown_task = cls.pop_context_managed_teardown_task()
          # ``context_map`` keys lists of tasks by their tuple.
          if isinstance(setup_task, list):
              setup_task = tuple(setup_task)
          if isinstance(teardown_task, list):
              teardown_task = tuple(teardown_task)
          cls.pop_teardown_downstream_of_setup()
          cls.pop_setup_upstream_of_teardown()
          cls.context_map.pop(setup_task, None)
          cls.context_map.pop(teardown_task, None)

      @classmethod
      def error(cls, message: str):
          """
          Reset the whole context state and raise.

          Every state holder is reset, including the discovered setups/teardowns and
          their stacks, so nothing stale leaks into the next context.

          :param message: text of the raised error
          :raises ValueError: always
          """
          cls.active = False
          cls.context_map.clear()
          cls._context_managed_setup_task = []
          cls._context_managed_teardown_task = []
          cls._previous_context_managed_setup_task = []
          cls._previous_context_managed_teardown_task = []
          cls._teardown_downstream_of_setup = []
          cls._previous_teardown_downstream_of_setup = []
          cls._setup_upstream_of_teardown = []
          cls._previous_setup_upstream_of_teardown = []
          raise ValueError(message)
  class SetupTeardownContext(BaseSetupTeardownContext):
      """Context manager for setup and teardown tasks."""

      @staticmethod
      def add_task(task: AbstractOperator | PlainXComArg):
          """
          Add task to context manager.

          A ``PlainXComArg`` is unwrapped to its ``operator`` before being registered.

          :param task: operator, or XCom argument wrapping one, created inside the context
          :raises AirflowException: if no context is currently active
          """
          from airflow.models.xcom_arg import PlainXComArg

          context = SetupTeardownContext
          if not context.active:
              raise AirflowException("Cannot add task to context outside the context manager.")
          operator = task.operator if isinstance(task, PlainXComArg) else task
          context.update_context_map(operator)