#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
  18. """
  19. This module is deprecated. Please use :mod:`airflow.utils.task_group`.
  20. The module which provides a way to nest your DAGs and so your levels of complexity.
  21. """

from __future__ import annotations

import warnings
from enum import Enum
from typing import TYPE_CHECKING

from sqlalchemy import select

from airflow.api.common.experimental.get_task_instance import get_task_instance
from airflow.api_internal.internal_api_call import InternalApiConfig
from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, TaskInstanceNotFound
from airflow.models import DagRun
from airflow.models.dag import DagContext
from airflow.models.pool import Pool
from airflow.models.taskinstance import TaskInstance
from airflow.sensors.base import BaseSensorOperator
from airflow.utils.session import NEW_SESSION, create_session, provide_session
from airflow.utils.state import DagRunState, TaskInstanceState
from airflow.utils.types import DagRunType

if TYPE_CHECKING:
    from datetime import datetime

    from sqlalchemy.orm.session import Session

    from airflow.models.dag import DAG
    from airflow.utils.context import Context


class SkippedStatePropagationOptions(Enum):
    """Available options for skipped state propagation of subdag's tasks to parent dag tasks."""

    ALL_LEAVES = "all_leaves"
    ANY_LEAF = "any_leaf"


class SubDagOperator(BaseSensorOperator):
    """
    This class is deprecated. Please use :class:`airflow.utils.task_group.TaskGroup`.

    This runs a sub dag. By convention, a sub dag's dag_id
    should be prefixed by its parent and a dot. As in ``parent.child``.
    Although SubDagOperator can occupy a pool/concurrency slot,
    users can specify ``mode='reschedule'`` so that the slot is
    released periodically to avoid potential deadlocks.

    :param subdag: the DAG object to run as a subdag of the current DAG.
    :param session: sqlalchemy session
    :param conf: Configuration for the subdag
    :param propagate_skipped_state: by setting this argument you can define
        whether the skipped state of leaf task(s) should be propagated to the
        parent dag's downstream task.
    """

    ui_color = "#555"
    ui_fgcolor = "#fff"

    subdag: DAG

    @provide_session
    def __init__(
        self,
        *,
        subdag: DAG,
        session: Session = NEW_SESSION,
        conf: dict | None = None,
        propagate_skipped_state: SkippedStatePropagationOptions | None = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.subdag = subdag
        self.conf = conf
        self.propagate_skipped_state = propagate_skipped_state

        self._validate_dag(kwargs)
        if not InternalApiConfig.get_use_internal_api():
            self._validate_pool(session)

        warnings.warn(
            """This class is deprecated. Please use `airflow.utils.task_group.TaskGroup`.""",
            RemovedInAirflow3Warning,
            stacklevel=4,
        )

    def _validate_dag(self, kwargs):
        dag = kwargs.get("dag") or DagContext.get_current_dag()

        if not dag:
            raise AirflowException("Please pass in the `dag` param or call within a DAG context manager")

        if dag.dag_id + "." + kwargs["task_id"] != self.subdag.dag_id:
            raise AirflowException(
                f"The subdag's dag_id should have the form '{{parent_dag_id}}.{{this_task_id}}'. "
                f"Expected '{dag.dag_id}.{kwargs['task_id']}'; received '{self.subdag.dag_id}'."
            )

    def _validate_pool(self, session):
        if self.pool:
            conflicts = [t for t in self.subdag.tasks if t.pool == self.pool]
            if conflicts:
                # only query for pool conflicts if one may exist
                pool = session.scalar(select(Pool).where(Pool.slots == 1, Pool.pool == self.pool))
                if pool and any(t.pool == self.pool for t in self.subdag.tasks):
                    raise AirflowException(
                        f"SubDagOperator {self.task_id} and subdag task{'s' if len(conflicts) > 1 else ''} "
                        f"{', '.join(t.task_id for t in conflicts)} both use pool {self.pool}, "
                        f"but the pool only has 1 slot. The subdag tasks will never run."
                    )
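
    # The guard above exists because this operator itself occupies a pool slot
    # while it waits on the subdag: if the subdag's tasks need the same
    # single-slot pool, they can never acquire it and the run deadlocks.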

    def _get_dagrun(self, execution_date):
        dag_runs = DagRun.find(
            dag_id=self.subdag.dag_id,
            execution_date=execution_date,
        )
        return dag_runs[0] if dag_runs else None

    def _reset_dag_run_and_task_instances(self, dag_run: DagRun, execution_date: datetime) -> None:
        """
        Set task instance states to allow for execution.

        The state of the DAG run will be set to RUNNING, and failed task
        instances to ``None`` for the scheduler to pick up.

        :param dag_run: DAG run to reset.
        :param execution_date: Execution date to select task instances.
        """
        with create_session() as session:
            dag_run.state = DagRunState.RUNNING
            session.merge(dag_run)
            failed_task_instances = session.scalars(
                select(TaskInstance)
                .where(TaskInstance.dag_id == self.subdag.dag_id)
                .where(TaskInstance.execution_date == execution_date)
                .where(TaskInstance.state.in_((TaskInstanceState.FAILED, TaskInstanceState.UPSTREAM_FAILED)))
            )

            for task_instance in failed_task_instances:
                task_instance.state = None
                session.merge(task_instance)
            session.commit()

    def pre_execute(self, context):
        super().pre_execute(context)
        execution_date = context["execution_date"]
        dag_run = self._get_dagrun(execution_date)

        if dag_run is None:
            if context["data_interval_start"] is None or context["data_interval_end"] is None:
                data_interval: tuple[datetime, datetime] | None = None
            else:
                data_interval = (context["data_interval_start"], context["data_interval_end"])
            dag_run = self.subdag.create_dagrun(
                run_type=DagRunType.SCHEDULED,
                execution_date=execution_date,
                state=DagRunState.RUNNING,
                conf=self.conf,
                external_trigger=True,
                data_interval=data_interval,
            )
            self.log.info("Created DagRun: %s", dag_run.run_id)
        else:
            self.log.info("Found existing DagRun: %s", dag_run.run_id)
            if dag_run.state == DagRunState.FAILED:
                self._reset_dag_run_and_task_instances(dag_run, execution_date)

    def poke(self, context: Context):
        # The sensor is done once the subdag's run has left the RUNNING state.
        execution_date = context["execution_date"]
        dag_run = self._get_dagrun(execution_date=execution_date)
        return dag_run.state != DagRunState.RUNNING

    def post_execute(self, context, result=None):
        super().post_execute(context)
        execution_date = context["execution_date"]
        dag_run = self._get_dagrun(execution_date=execution_date)
        self.log.info("Execution finished. State is %s", dag_run.state)

        if dag_run.state != DagRunState.SUCCESS:
            raise AirflowException(f"Expected state: SUCCESS. Actual state: {dag_run.state}")

        if self.propagate_skipped_state and self._check_skipped_states(context):
            self._skip_downstream_tasks(context)

    def _check_skipped_states(self, context):
        leaves_tis = self._get_leaves_tis(context["execution_date"])

        if self.propagate_skipped_state == SkippedStatePropagationOptions.ANY_LEAF:
            return any(ti.state == TaskInstanceState.SKIPPED for ti in leaves_tis)
        if self.propagate_skipped_state == SkippedStatePropagationOptions.ALL_LEAVES:
            return all(ti.state == TaskInstanceState.SKIPPED for ti in leaves_tis)
        raise AirflowException(
            f"Unimplemented SkippedStatePropagationOptions {self.propagate_skipped_state} used."
        )

    def _get_leaves_tis(self, execution_date):
        leaves_tis = []
        for leaf in self.subdag.leaves:
            try:
                ti = get_task_instance(
                    dag_id=self.subdag.dag_id, task_id=leaf.task_id, execution_date=execution_date
                )
                leaves_tis.append(ti)
            except TaskInstanceNotFound:
                continue
        return leaves_tis

    def _skip_downstream_tasks(self, context):
        self.log.info(
            "Skipping downstream tasks because propagate_skipped_state is set to %s "
            "and skipped task(s) were found.",
            self.propagate_skipped_state,
        )

        downstream_tasks = context["task"].downstream_list
        self.log.debug("Downstream task_ids %s", downstream_tasks)

        if downstream_tasks:
            self.skip(
                context["dag_run"],
                context["execution_date"],
                downstream_tasks,
                map_index=context["ti"].map_index,
            )

        self.log.info("Done.")
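
# A minimal propagate_skipped_state sketch (hypothetical dag/task names, not
# from this module). With ANY_LEAF, downstream tasks of the operator are
# skipped when any leaf task of the subdag ends up SKIPPED; ALL_LEAVES
# requires every leaf to be skipped (see _check_skipped_states above).
#
#     SubDagOperator(
#         task_id="child",
#         subdag=child_dag,
#         propagate_skipped_state=SkippedStatePropagationOptions.ANY_LEAF,
#         dag=parent_dag,
#     )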