task_group.py 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787
  1. #
  2. # Licensed to the Apache Software Foundation (ASF) under one
  3. # or more contributor license agreements. See the NOTICE file
  4. # distributed with this work for additional information
  5. # regarding copyright ownership. The ASF licenses this file
  6. # to you under the Apache License, Version 2.0 (the
  7. # "License"); you may not use this file except in compliance
  8. # with the License. You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing,
  13. # software distributed under the License is distributed on an
  14. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  15. # KIND, either express or implied. See the License for the
  16. # specific language governing permissions and limitations
  17. # under the License.
  18. """A collection of closely related tasks on the same DAG that should be grouped together visually."""
  19. from __future__ import annotations
  20. import copy
  21. import functools
  22. import operator
  23. import weakref
  24. from typing import TYPE_CHECKING, Any, Generator, Iterator, Sequence
  25. import methodtools
  26. import re2
  27. from airflow.exceptions import (
  28. AirflowDagCycleException,
  29. AirflowException,
  30. DuplicateTaskIdFound,
  31. TaskAlreadyInTaskGroup,
  32. )
  33. from airflow.models.taskmixin import DAGNode
  34. from airflow.serialization.enums import DagAttributeTypes
  35. from airflow.utils.helpers import validate_group_key, validate_instance_args
  36. from airflow.utils.trigger_rule import TriggerRule
  37. if TYPE_CHECKING:
  38. from sqlalchemy.orm import Session
  39. from airflow.models.abstractoperator import AbstractOperator
  40. from airflow.models.baseoperator import BaseOperator
  41. from airflow.models.dag import DAG
  42. from airflow.models.expandinput import ExpandInput
  43. from airflow.models.operator import Operator
  44. from airflow.models.taskmixin import DependencyMixin
  45. from airflow.utils.edgemodifier import EdgeModifier
  46. # TODO: The following mapping is used to validate that the arguments passed to the TaskGroup are of the
  47. # correct type. This is a temporary solution until we find a more sophisticated method for argument
  48. # validation. One potential method is to use get_type_hints from the typing module. However, this is not
  49. # fully compatible with future annotations for Python versions below 3.10. Once we require a minimum Python
  50. # version that supports `get_type_hints` effectively or find a better approach, we can replace this
  51. # manual type-checking method.
  52. TASKGROUP_ARGS_EXPECTED_TYPES = {
  53. "group_id": str,
  54. "prefix_group_id": bool,
  55. "tooltip": str,
  56. "ui_color": str,
  57. "ui_fgcolor": str,
  58. "add_suffix_on_collision": bool,
  59. }
  60. class TaskGroup(DAGNode):
  61. """
  62. A collection of tasks.
  63. When set_downstream() or set_upstream() are called on the TaskGroup, it is applied across
  64. all tasks within the group if necessary.
  65. :param group_id: a unique, meaningful id for the TaskGroup. group_id must not conflict
  66. with group_id of TaskGroup or task_id of tasks in the DAG. Root TaskGroup has group_id
  67. set to None.
  68. :param prefix_group_id: If set to True, child task_id and group_id will be prefixed with
  69. this TaskGroup's group_id. If set to False, child task_id and group_id are not prefixed.
  70. Default is True.
  71. :param parent_group: The parent TaskGroup of this TaskGroup. parent_group is set to None
  72. for the root TaskGroup.
  73. :param dag: The DAG that this TaskGroup belongs to.
  74. :param default_args: A dictionary of default parameters to be used
  75. as constructor keyword parameters when initialising operators,
  76. will override default_args defined in the DAG level.
  77. Note that operators have the same hook, and precede those defined
  78. here, meaning that if your dict contains `'depends_on_past': True`
  79. here and `'depends_on_past': False` in the operator's call
  80. `default_args`, the actual value will be `False`.
  81. :param tooltip: The tooltip of the TaskGroup node when displayed in the UI
  82. :param ui_color: The fill color of the TaskGroup node when displayed in the UI
  83. :param ui_fgcolor: The label color of the TaskGroup node when displayed in the UI
  84. :param add_suffix_on_collision: If this task group name already exists,
  85. automatically add `__1` etc suffixes
  86. """
  87. used_group_ids: set[str | None]
  88. def __init__(
  89. self,
  90. group_id: str | None,
  91. prefix_group_id: bool = True,
  92. parent_group: TaskGroup | None = None,
  93. dag: DAG | None = None,
  94. default_args: dict[str, Any] | None = None,
  95. tooltip: str = "",
  96. ui_color: str = "CornflowerBlue",
  97. ui_fgcolor: str = "#000",
  98. add_suffix_on_collision: bool = False,
  99. ):
  100. from airflow.models.dag import DagContext
  101. self.prefix_group_id = prefix_group_id
  102. self.default_args = copy.deepcopy(default_args or {})
  103. dag = dag or DagContext.get_current_dag()
  104. if group_id is None:
  105. # This creates a root TaskGroup.
  106. if parent_group:
  107. raise AirflowException("Root TaskGroup cannot have parent_group")
  108. # used_group_ids is shared across all TaskGroups in the same DAG to keep track
  109. # of used group_id to avoid duplication.
  110. self.used_group_ids = set()
  111. self.dag = dag
  112. else:
  113. if prefix_group_id:
  114. # If group id is used as prefix, it should not contain spaces nor dots
  115. # because it is used as prefix in the task_id
  116. validate_group_key(group_id)
  117. else:
  118. if not isinstance(group_id, str):
  119. raise ValueError("group_id must be str")
  120. if not group_id:
  121. raise ValueError("group_id must not be empty")
  122. if not parent_group and not dag:
  123. raise AirflowException("TaskGroup can only be used inside a dag")
  124. parent_group = parent_group or TaskGroupContext.get_current_task_group(dag)
  125. if not parent_group:
  126. raise AirflowException("TaskGroup must have a parent_group except for the root TaskGroup")
  127. if dag is not parent_group.dag:
  128. raise RuntimeError(
  129. "Cannot mix TaskGroups from different DAGs: %s and %s", dag, parent_group.dag
  130. )
  131. self.used_group_ids = parent_group.used_group_ids
  132. # if given group_id already used assign suffix by incrementing largest used suffix integer
  133. # Example : task_group ==> task_group__1 -> task_group__2 -> task_group__3
  134. self._group_id = group_id
  135. self._check_for_group_id_collisions(add_suffix_on_collision)
  136. self.children: dict[str, DAGNode] = {}
  137. if parent_group:
  138. parent_group.add(self)
  139. self._update_default_args(parent_group)
  140. self.used_group_ids.add(self.group_id)
  141. if self.group_id:
  142. self.used_group_ids.add(self.downstream_join_id)
  143. self.used_group_ids.add(self.upstream_join_id)
  144. self.tooltip = tooltip
  145. self.ui_color = ui_color
  146. self.ui_fgcolor = ui_fgcolor
  147. # Keep track of TaskGroups or tasks that depend on this entire TaskGroup separately
  148. # so that we can optimize the number of edges when entire TaskGroups depend on each other.
  149. self.upstream_group_ids: set[str | None] = set()
  150. self.downstream_group_ids: set[str | None] = set()
  151. self.upstream_task_ids = set()
  152. self.downstream_task_ids = set()
  153. validate_instance_args(self, TASKGROUP_ARGS_EXPECTED_TYPES)
  154. def _check_for_group_id_collisions(self, add_suffix_on_collision: bool):
  155. if self._group_id is None:
  156. return
  157. # if given group_id already used assign suffix by incrementing largest used suffix integer
  158. # Example : task_group ==> task_group__1 -> task_group__2 -> task_group__3
  159. if self._group_id in self.used_group_ids:
  160. if not add_suffix_on_collision:
  161. raise DuplicateTaskIdFound(f"group_id '{self._group_id}' has already been added to the DAG")
  162. base = re2.split(r"__\d+$", self._group_id)[0]
  163. suffixes = sorted(
  164. int(re2.split(r"^.+__", used_group_id)[1])
  165. for used_group_id in self.used_group_ids
  166. if used_group_id is not None and re2.match(rf"^{base}__\d+$", used_group_id)
  167. )
  168. if not suffixes:
  169. self._group_id += "__1"
  170. else:
  171. self._group_id = f"{base}__{suffixes[-1] + 1}"
  172. def _update_default_args(self, parent_group: TaskGroup):
  173. if parent_group.default_args:
  174. self.default_args = {**parent_group.default_args, **self.default_args}
  175. @classmethod
  176. def create_root(cls, dag: DAG) -> TaskGroup:
  177. """Create a root TaskGroup with no group_id or parent."""
  178. return cls(group_id=None, dag=dag)
  179. @property
  180. def node_id(self):
  181. return self.group_id
  182. @property
  183. def is_root(self) -> bool:
  184. """Returns True if this TaskGroup is the root TaskGroup. Otherwise False."""
  185. return not self.group_id
  186. @property
  187. def parent_group(self) -> TaskGroup | None:
  188. return self.task_group
  189. def __iter__(self):
  190. for child in self.children.values():
  191. yield from self._iter_child(child)
  192. @staticmethod
  193. def _iter_child(child):
  194. """Iterate over the children of this TaskGroup."""
  195. if isinstance(child, TaskGroup):
  196. yield from child
  197. else:
  198. yield child
  199. def add(self, task: DAGNode) -> DAGNode:
  200. """
  201. Add a task to this TaskGroup.
  202. :meta private:
  203. """
  204. from airflow.models.abstractoperator import AbstractOperator
  205. if TaskGroupContext.active:
  206. if task.task_group and task.task_group != self:
  207. task.task_group.children.pop(task.node_id, None)
  208. task.task_group = self
  209. existing_tg = task.task_group
  210. if isinstance(task, AbstractOperator) and existing_tg is not None and existing_tg != self:
  211. raise TaskAlreadyInTaskGroup(task.node_id, existing_tg.node_id, self.node_id)
  212. # Set the TG first, as setting it might change the return value of node_id!
  213. task.task_group = weakref.proxy(self)
  214. key = task.node_id
  215. if key in self.children:
  216. node_type = "Task" if hasattr(task, "task_id") else "Task Group"
  217. raise DuplicateTaskIdFound(f"{node_type} id '{key}' has already been added to the DAG")
  218. if isinstance(task, TaskGroup):
  219. if self.dag:
  220. if task.dag is not None and self.dag is not task.dag:
  221. raise RuntimeError(
  222. "Cannot mix TaskGroups from different DAGs: %s and %s", self.dag, task.dag
  223. )
  224. task.dag = self.dag
  225. if task.children:
  226. raise AirflowException("Cannot add a non-empty TaskGroup")
  227. self.children[key] = task
  228. return task
  229. def _remove(self, task: DAGNode) -> None:
  230. key = task.node_id
  231. if key not in self.children:
  232. raise KeyError(f"Node id {key!r} not part of this task group")
  233. self.used_group_ids.remove(key)
  234. del self.children[key]
  235. @property
  236. def group_id(self) -> str | None:
  237. """group_id of this TaskGroup."""
  238. if self.task_group and self.task_group.prefix_group_id and self.task_group._group_id:
  239. # defer to parent whether it adds a prefix
  240. return self.task_group.child_id(self._group_id)
  241. return self._group_id
  242. @property
  243. def label(self) -> str | None:
  244. """group_id excluding parent's group_id used as the node label in UI."""
  245. return self._group_id
  246. def update_relative(
  247. self, other: DependencyMixin, upstream: bool = True, edge_modifier: EdgeModifier | None = None
  248. ) -> None:
  249. """
  250. Override TaskMixin.update_relative.
  251. Update upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids
  252. accordingly so that we can reduce the number of edges when displaying Graph view.
  253. """
  254. if isinstance(other, TaskGroup):
  255. # Handles setting relationship between a TaskGroup and another TaskGroup
  256. if upstream:
  257. parent, child = (self, other)
  258. if edge_modifier:
  259. edge_modifier.add_edge_info(self.dag, other.downstream_join_id, self.upstream_join_id)
  260. else:
  261. parent, child = (other, self)
  262. if edge_modifier:
  263. edge_modifier.add_edge_info(self.dag, self.downstream_join_id, other.upstream_join_id)
  264. parent.upstream_group_ids.add(child.group_id)
  265. child.downstream_group_ids.add(parent.group_id)
  266. else:
  267. # Handles setting relationship between a TaskGroup and a task
  268. for task in other.roots:
  269. if not isinstance(task, DAGNode):
  270. raise AirflowException(
  271. "Relationships can only be set between TaskGroup "
  272. f"or operators; received {task.__class__.__name__}"
  273. )
  274. # Do not set a relationship between a TaskGroup and a Label's roots
  275. if self == task:
  276. continue
  277. if upstream:
  278. self.upstream_task_ids.add(task.node_id)
  279. if edge_modifier:
  280. edge_modifier.add_edge_info(self.dag, task.node_id, self.upstream_join_id)
  281. else:
  282. self.downstream_task_ids.add(task.node_id)
  283. if edge_modifier:
  284. edge_modifier.add_edge_info(self.dag, self.downstream_join_id, task.node_id)
  285. def _set_relatives(
  286. self,
  287. task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
  288. upstream: bool = False,
  289. edge_modifier: EdgeModifier | None = None,
  290. ) -> None:
  291. """
  292. Call set_upstream/set_downstream for all root/leaf tasks within this TaskGroup.
  293. Update upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids.
  294. """
  295. if not isinstance(task_or_task_list, Sequence):
  296. task_or_task_list = [task_or_task_list]
  297. for task_like in task_or_task_list:
  298. self.update_relative(task_like, upstream, edge_modifier=edge_modifier)
  299. if upstream:
  300. for task in self.get_roots():
  301. task.set_upstream(task_or_task_list)
  302. else:
  303. for task in self.get_leaves():
  304. task.set_downstream(task_or_task_list)
  305. def __enter__(self) -> TaskGroup:
  306. TaskGroupContext.push_context_managed_task_group(self)
  307. return self
  308. def __exit__(self, _type, _value, _tb):
  309. TaskGroupContext.pop_context_managed_task_group()
  310. def has_task(self, task: BaseOperator) -> bool:
  311. """Return True if this TaskGroup or its children TaskGroups contains the given task."""
  312. if task.task_id in self.children:
  313. return True
  314. return any(child.has_task(task) for child in self.children.values() if isinstance(child, TaskGroup))
  315. @property
  316. def roots(self) -> list[BaseOperator]:
  317. """Required by TaskMixin."""
  318. return list(self.get_roots())
  319. @property
  320. def leaves(self) -> list[BaseOperator]:
  321. """Required by TaskMixin."""
  322. return list(self.get_leaves())
  323. def get_roots(self) -> Generator[BaseOperator, None, None]:
  324. """Return a generator of tasks with no upstream dependencies within the TaskGroup."""
  325. tasks = list(self)
  326. ids = {x.task_id for x in tasks}
  327. for task in tasks:
  328. if task.upstream_task_ids.isdisjoint(ids):
  329. yield task
  330. def get_leaves(self) -> Generator[BaseOperator, None, None]:
  331. """Return a generator of tasks with no downstream dependencies within the TaskGroup."""
  332. tasks = list(self)
  333. ids = {x.task_id for x in tasks}
  334. def has_non_teardown_downstream(task, exclude: str):
  335. for down_task in task.downstream_list:
  336. if down_task.task_id == exclude:
  337. continue
  338. elif down_task.task_id not in ids:
  339. continue
  340. elif not down_task.is_teardown:
  341. return True
  342. return False
  343. def recurse_for_first_non_teardown(task):
  344. for upstream_task in task.upstream_list:
  345. if upstream_task.task_id not in ids:
  346. # upstream task is not in task group
  347. continue
  348. elif upstream_task.is_teardown:
  349. yield from recurse_for_first_non_teardown(upstream_task)
  350. elif task.is_teardown and upstream_task.is_setup:
  351. # don't go through the teardown-to-setup path
  352. continue
  353. # return unless upstream task already has non-teardown downstream in group
  354. elif not has_non_teardown_downstream(upstream_task, exclude=task.task_id):
  355. yield upstream_task
  356. for task in tasks:
  357. if task.downstream_task_ids.isdisjoint(ids):
  358. if not task.is_teardown:
  359. yield task
  360. else:
  361. yield from recurse_for_first_non_teardown(task)
  362. def child_id(self, label):
  363. """Prefix label with group_id if prefix_group_id is True. Otherwise return the label as-is."""
  364. if self.prefix_group_id:
  365. group_id = self.group_id
  366. if group_id:
  367. return f"{group_id}.{label}"
  368. return label
  369. @property
  370. def upstream_join_id(self) -> str:
  371. """
  372. Creates a unique ID for upstream dependencies of this TaskGroup.
  373. If this TaskGroup has immediate upstream TaskGroups or tasks, a proxy node called
  374. upstream_join_id will be created in Graph view to join the outgoing edges from this
  375. TaskGroup to reduce the total number of edges needed to be displayed.
  376. """
  377. return f"{self.group_id}.upstream_join_id"
  378. @property
  379. def downstream_join_id(self) -> str:
  380. """
  381. Creates a unique ID for downstream dependencies of this TaskGroup.
  382. If this TaskGroup has immediate downstream TaskGroups or tasks, a proxy node called
  383. downstream_join_id will be created in Graph view to join the outgoing edges from this
  384. TaskGroup to reduce the total number of edges needed to be displayed.
  385. """
  386. return f"{self.group_id}.downstream_join_id"
  387. def get_task_group_dict(self) -> dict[str, TaskGroup]:
  388. """Return a flat dictionary of group_id: TaskGroup."""
  389. task_group_map = {}
  390. def build_map(task_group):
  391. if not isinstance(task_group, TaskGroup):
  392. return
  393. task_group_map[task_group.group_id] = task_group
  394. for child in task_group.children.values():
  395. build_map(child)
  396. build_map(self)
  397. return task_group_map
  398. def get_child_by_label(self, label: str) -> DAGNode:
  399. """Get a child task/TaskGroup by its label (i.e. task_id/group_id without the group_id prefix)."""
  400. return self.children[self.child_id(label)]
  401. def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
  402. """Serialize task group; required by DAGNode."""
  403. from airflow.serialization.serialized_objects import TaskGroupSerialization
  404. return DagAttributeTypes.TASK_GROUP, TaskGroupSerialization.serialize_task_group(self)
  405. def hierarchical_alphabetical_sort(self):
  406. """
  407. Sort children in hierarchical alphabetical order.
  408. - groups in alphabetical order first
  409. - tasks in alphabetical order after them.
  410. :return: list of tasks in hierarchical alphabetical order
  411. """
  412. return sorted(
  413. self.children.values(), key=lambda node: (not isinstance(node, TaskGroup), node.node_id)
  414. )
  415. def topological_sort(self, _include_subdag_tasks: bool = False):
  416. """
  417. Sorts children in topographical order, such that a task comes after any of its upstream dependencies.
  418. :return: list of tasks in topological order
  419. """
  420. # This uses a modified version of Kahn's Topological Sort algorithm to
  421. # not have to pre-compute the "in-degree" of the nodes.
  422. from airflow.operators.subdag import SubDagOperator # Avoid circular import
  423. graph_unsorted = copy.copy(self.children)
  424. graph_sorted: list[DAGNode] = []
  425. # special case
  426. if not self.children:
  427. return graph_sorted
  428. # Run until the unsorted graph is empty.
  429. while graph_unsorted:
  430. # Go through each of the node/edges pairs in the unsorted graph. If a set of edges doesn't contain
  431. # any nodes that haven't been resolved, that is, that are still in the unsorted graph, remove the
  432. # pair from the unsorted graph, and append it to the sorted graph. Note here that by using
  433. # the values() method for iterating, a copy of the unsorted graph is used, allowing us to modify
  434. # the unsorted graph as we move through it.
  435. #
  436. # We also keep a flag for checking that graph is acyclic, which is true if any nodes are resolved
  437. # during each pass through the graph. If not, we need to exit as the graph therefore can't be
  438. # sorted.
  439. acyclic = False
  440. for node in list(graph_unsorted.values()):
  441. for edge in node.upstream_list:
  442. if edge.node_id in graph_unsorted:
  443. break
  444. # Check for task's group is a child (or grand child) of this TG,
  445. tg = edge.task_group
  446. while tg:
  447. if tg.node_id in graph_unsorted:
  448. break
  449. tg = tg.task_group
  450. if tg:
  451. # We are already going to visit that TG
  452. break
  453. else:
  454. acyclic = True
  455. del graph_unsorted[node.node_id]
  456. graph_sorted.append(node)
  457. if _include_subdag_tasks and isinstance(node, SubDagOperator):
  458. graph_sorted.extend(
  459. node.subdag.task_group.topological_sort(_include_subdag_tasks=True)
  460. )
  461. if not acyclic:
  462. raise AirflowDagCycleException(f"A cyclic dependency occurred in dag: {self.dag_id}")
  463. return graph_sorted
  464. def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]:
  465. """
  466. Return mapped task groups in the hierarchy.
  467. Groups are returned from the closest to the outmost. If *self* is a
  468. mapped task group, it is returned first.
  469. :meta private:
  470. """
  471. group: TaskGroup | None = self
  472. while group is not None:
  473. if isinstance(group, MappedTaskGroup):
  474. yield group
  475. group = group.task_group
  476. def iter_tasks(self) -> Iterator[AbstractOperator]:
  477. """Return an iterator of the child tasks."""
  478. from airflow.models.abstractoperator import AbstractOperator
  479. groups_to_visit = [self]
  480. while groups_to_visit:
  481. visiting = groups_to_visit.pop(0)
  482. for child in visiting.children.values():
  483. if isinstance(child, AbstractOperator):
  484. yield child
  485. elif isinstance(child, TaskGroup):
  486. groups_to_visit.append(child)
  487. else:
  488. raise ValueError(
  489. f"Encountered a DAGNode that is not a TaskGroup or an AbstractOperator: {type(child)}"
  490. )
  491. class MappedTaskGroup(TaskGroup):
  492. """
  493. A mapped task group.
  494. This doesn't really do anything special, just holds some additional metadata
  495. for expansion later.
  496. Don't instantiate this class directly; call *expand* or *expand_kwargs* on
  497. a ``@task_group`` function instead.
  498. """
  499. def __init__(self, *, expand_input: ExpandInput, **kwargs: Any) -> None:
  500. super().__init__(**kwargs)
  501. self._expand_input = expand_input
  502. def __iter__(self):
  503. from airflow.models.abstractoperator import AbstractOperator
  504. for child in self.children.values():
  505. if isinstance(child, AbstractOperator) and child.trigger_rule == TriggerRule.ALWAYS:
  506. raise ValueError(
  507. "Task-generated mapping within a mapped task group is not allowed with trigger rule 'always'"
  508. )
  509. yield from self._iter_child(child)
  510. def iter_mapped_dependencies(self) -> Iterator[Operator]:
  511. """Upstream dependencies that provide XComs used by this mapped task group."""
  512. from airflow.models.xcom_arg import XComArg
  513. for op, _ in XComArg.iter_xcom_references(self._expand_input):
  514. yield op
  515. @methodtools.lru_cache(maxsize=None)
  516. def get_parse_time_mapped_ti_count(self) -> int:
  517. """
  518. Return the Number of instances a task in this group should be mapped to, when a DAG run is created.
  519. This only considers literal mapped arguments, and would return *None*
  520. when any non-literal values are used for mapping.
  521. If this group is inside mapped task groups, all the nested counts are
  522. multiplied and accounted.
  523. :meta private:
  524. :raise NotFullyPopulated: If any non-literal mapped arguments are encountered.
  525. :return: The total number of mapped instances each task should have.
  526. """
  527. return functools.reduce(
  528. operator.mul,
  529. (g._expand_input.get_parse_time_mapped_ti_count() for g in self.iter_mapped_task_groups()),
  530. )
  531. def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int:
  532. """
  533. Return the number of instances a task in this group should be mapped to at run time.
  534. This considers both literal and non-literal mapped arguments, and the
  535. result is therefore available when all depended tasks have finished. The
  536. return value should be identical to ``parse_time_mapped_ti_count`` if
  537. all mapped arguments are literal.
  538. If this group is inside mapped task groups, all the nested counts are
  539. multiplied and accounted.
  540. :meta private:
  541. :raise NotFullyPopulated: If upstream tasks are not all complete yet.
  542. :return: Total number of mapped TIs this task should have.
  543. """
  544. groups = self.iter_mapped_task_groups()
  545. return functools.reduce(
  546. operator.mul,
  547. (g._expand_input.get_total_map_length(run_id, session=session) for g in groups),
  548. )
  549. def __exit__(self, exc_type, exc_val, exc_tb):
  550. for op, _ in self._expand_input.iter_references():
  551. self.set_upstream(op)
  552. super().__exit__(exc_type, exc_val, exc_tb)
  553. class TaskGroupContext:
  554. """TaskGroup context is used to keep the current TaskGroup when TaskGroup is used as ContextManager."""
  555. active: bool = False
  556. _context_managed_task_group: TaskGroup | None = None
  557. _previous_context_managed_task_groups: list[TaskGroup] = []
  558. @classmethod
  559. def push_context_managed_task_group(cls, task_group: TaskGroup):
  560. """Push a TaskGroup into the list of managed TaskGroups."""
  561. if cls._context_managed_task_group:
  562. cls._previous_context_managed_task_groups.append(cls._context_managed_task_group)
  563. cls._context_managed_task_group = task_group
  564. cls.active = True
  565. @classmethod
  566. def pop_context_managed_task_group(cls) -> TaskGroup | None:
  567. """Pops the last TaskGroup from the list of managed TaskGroups and update the current TaskGroup."""
  568. old_task_group = cls._context_managed_task_group
  569. if cls._previous_context_managed_task_groups:
  570. cls._context_managed_task_group = cls._previous_context_managed_task_groups.pop()
  571. else:
  572. cls._context_managed_task_group = None
  573. cls.active = False
  574. return old_task_group
  575. @classmethod
  576. def get_current_task_group(cls, dag: DAG | None) -> TaskGroup | None:
  577. """Get the current TaskGroup."""
  578. from airflow.models.dag import DagContext
  579. if not cls._context_managed_task_group:
  580. dag = dag or DagContext.get_current_dag()
  581. if dag:
  582. # If there's currently a DAG but no TaskGroup, return the root TaskGroup of the dag.
  583. return dag.task_group
  584. return cls._context_managed_task_group
  585. def task_group_to_dict(task_item_or_group):
  586. """Create a nested dict representation of this TaskGroup and its children used to construct the Graph."""
  587. from airflow.models.abstractoperator import AbstractOperator
  588. from airflow.models.mappedoperator import MappedOperator
  589. if isinstance(task := task_item_or_group, AbstractOperator):
  590. setup_teardown_type = {}
  591. is_mapped = {}
  592. if task.is_setup is True:
  593. setup_teardown_type["setupTeardownType"] = "setup"
  594. elif task.is_teardown is True:
  595. setup_teardown_type["setupTeardownType"] = "teardown"
  596. if isinstance(task, MappedOperator):
  597. is_mapped["isMapped"] = True
  598. return {
  599. "id": task.task_id,
  600. "value": {
  601. "label": task.label,
  602. "labelStyle": f"fill:{task.ui_fgcolor};",
  603. "style": f"fill:{task.ui_color};",
  604. "rx": 5,
  605. "ry": 5,
  606. **is_mapped,
  607. **setup_teardown_type,
  608. },
  609. }
  610. task_group = task_item_or_group
  611. is_mapped = isinstance(task_group, MappedTaskGroup)
  612. children = [
  613. task_group_to_dict(child) for child in sorted(task_group.children.values(), key=lambda t: t.label)
  614. ]
  615. if task_group.upstream_group_ids or task_group.upstream_task_ids:
  616. children.append(
  617. {
  618. "id": task_group.upstream_join_id,
  619. "value": {
  620. "label": "",
  621. "labelStyle": f"fill:{task_group.ui_fgcolor};",
  622. "style": f"fill:{task_group.ui_color};",
  623. "shape": "circle",
  624. },
  625. }
  626. )
  627. if task_group.downstream_group_ids or task_group.downstream_task_ids:
  628. # This is the join node used to reduce the number of edges between two TaskGroup.
  629. children.append(
  630. {
  631. "id": task_group.downstream_join_id,
  632. "value": {
  633. "label": "",
  634. "labelStyle": f"fill:{task_group.ui_fgcolor};",
  635. "style": f"fill:{task_group.ui_color};",
  636. "shape": "circle",
  637. },
  638. }
  639. )
  640. return {
  641. "id": task_group.group_id,
  642. "value": {
  643. "label": task_group.label,
  644. "labelStyle": f"fill:{task_group.ui_fgcolor};",
  645. "style": f"fill:{task_group.ui_color}",
  646. "rx": 5,
  647. "ry": 5,
  648. "clusterLabelPos": "top",
  649. "tooltip": task_group.tooltip,
  650. "isMapped": is_mapped,
  651. },
  652. "children": children,
  653. }