external_task.py

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import asyncio
import typing
from typing import Any

from asgiref.sync import sync_to_async
from deprecated import deprecated
from sqlalchemy import func

from airflow.exceptions import RemovedInAirflow3Warning
from airflow.models import DagRun, TaskInstance
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.utils.sensor_helper import _get_count
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import TaskInstanceState
from airflow.utils.timezone import utcnow

if typing.TYPE_CHECKING:
    from datetime import datetime

    from sqlalchemy.orm import Session

    from airflow.utils.state import DagRunState


class WorkflowTrigger(BaseTrigger):
    """
    A trigger to monitor task, task group, and DAG execution in Apache Airflow.

    :param external_dag_id: The ID of the external DAG.
    :param execution_dates: A list of execution dates for the external DAG.
    :param external_task_ids: A collection of external task IDs to wait for.
    :param external_task_group_id: The ID of the external task group to wait for.
    :param failed_states: States considered as failed for external tasks.
    :param skipped_states: States considered as skipped for external tasks.
    :param allowed_states: States considered as successful for external tasks.
    :param poke_interval: The interval (in seconds) for poking the external tasks.
    :param soft_fail: If True, the trigger will not fail the entire DAG on external task failure.
    """

    def __init__(
        self,
        external_dag_id: str,
        execution_dates: list,
        external_task_ids: typing.Collection[str] | None = None,
        external_task_group_id: str | None = None,
        failed_states: typing.Iterable[str] | None = None,
        skipped_states: typing.Iterable[str] | None = None,
        allowed_states: typing.Iterable[str] | None = None,
        poke_interval: float = 2.0,
        soft_fail: bool = False,
        **kwargs,
    ):
        self.external_dag_id = external_dag_id
        self.external_task_ids = external_task_ids
        self.external_task_group_id = external_task_group_id
        self.failed_states = failed_states
        self.skipped_states = skipped_states
        self.allowed_states = allowed_states
        self.execution_dates = execution_dates
        self.poke_interval = poke_interval
        self.soft_fail = soft_fail
        super().__init__(**kwargs)

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serialize the trigger params and module path."""
        return (
            "airflow.triggers.external_task.WorkflowTrigger",
            {
                "external_dag_id": self.external_dag_id,
                "external_task_ids": self.external_task_ids,
                "external_task_group_id": self.external_task_group_id,
                "failed_states": self.failed_states,
                "skipped_states": self.skipped_states,
                "allowed_states": self.allowed_states,
                "execution_dates": self.execution_dates,
                "poke_interval": self.poke_interval,
                "soft_fail": self.soft_fail,
            },
        )

    async def run(self) -> typing.AsyncIterator[TriggerEvent]:
        """Periodically check the status of the tasks, task group, or DAG."""
        while True:
            if self.failed_states:
                failed_count = await self._get_count(self.failed_states)
                if failed_count > 0:
                    yield TriggerEvent({"status": "failed"})
                    return
            if self.skipped_states:
                skipped_count = await self._get_count(self.skipped_states)
                if skipped_count > 0:
                    yield TriggerEvent({"status": "skipped"})
                    return
            # A zero failed count is not success on its own: fall through and
            # verify that every execution date has reached an allowed state.
            allowed_count = await self._get_count(self.allowed_states)
            if allowed_count == len(self.execution_dates):
                yield TriggerEvent({"status": "success"})
                return
            self.log.info("Sleeping for %s seconds", self.poke_interval)
            await asyncio.sleep(self.poke_interval)

    @sync_to_async
    def _get_count(self, states: typing.Iterable[str] | None) -> int:
        """
        Get the count of records against dttm filter and states. Async wrapper for _get_count.

        :param states: task or dag states
        :return: The count of records.
        """
        return _get_count(
            dttm_filter=self.execution_dates,
            external_task_ids=self.external_task_ids,
            external_task_group_id=self.external_task_group_id,
            external_dag_id=self.external_dag_id,
            states=states,
        )
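
# A minimal usage sketch, not part of this module: a deferrable sensor hands
# control to WorkflowTrigger via BaseOperator.defer(). The DAG id, dates, and
# the ``execute_complete`` method name below are illustrative assumptions.
#
#     self.defer(
#         trigger=WorkflowTrigger(
#             external_dag_id="upstream_dag",          # hypothetical DAG id
#             execution_dates=[utcnow()],              # dates to wait on
#             allowed_states=["success"],
#         ),
#         method_name="execute_complete",              # task resumes here
#     )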


@deprecated(
    reason="TaskStateTrigger has been deprecated and will be removed in future.",
    category=RemovedInAirflow3Warning,
)
class TaskStateTrigger(BaseTrigger):
    """
    Waits asynchronously for a task in a different DAG to complete for a specific logical date.

    :param dag_id: The dag_id that contains the task you want to wait for.
    :param task_id: The task_id of the task you want to wait for.
    :param states: allowed states, default is ``['success']``
    :param execution_dates: task execution time interval
    :param poll_interval: The time interval in seconds to check the state.
        The default value is 2.0 sec.
    :param trigger_start_time: time in Datetime format when the trigger was started. Used
        to bound the trigger's execution and prevent an infinite loop when the specified
        DAG does not exist in the database: the trigger waits up to ``_timeout_sec``
        seconds from its start time, and if execution takes longer than expected it
        terminates with a 'timeout' status.
    """

    def __init__(
        self,
        dag_id: str,
        execution_dates: list[datetime],
        trigger_start_time: datetime,
        states: list[str] | None = None,
        task_id: str | None = None,
        poll_interval: float = 2.0,
    ):
        super().__init__()
        self.dag_id = dag_id
        self.task_id = task_id
        # Default to SUCCESS when no states are given.
        self.states = states or [TaskInstanceState.SUCCESS.value]
        self.execution_dates = execution_dates
        self.poll_interval = poll_interval
        self.trigger_start_time = trigger_start_time
        self._timeout_sec = 60

    def serialize(self) -> tuple[str, dict[str, typing.Any]]:
        """Serialize TaskStateTrigger arguments and classpath."""
        return (
            "airflow.triggers.external_task.TaskStateTrigger",
            {
                "dag_id": self.dag_id,
                "task_id": self.task_id,
                "states": self.states,
                "execution_dates": self.execution_dates,
                "poll_interval": self.poll_interval,
                "trigger_start_time": self.trigger_start_time,
            },
        )

    async def run(self) -> typing.AsyncIterator[TriggerEvent]:
        """
        Check periodically in the database to see if the DAG exists and is in a running state.

        If found, wait until the specified task reaches one of the expected states.
        If the named DAG does not reach a running state within ``_timeout_sec`` seconds
        of the trigger starting, terminate with a 'timeout' status.
        """
        try:
            while True:
                delta = utcnow() - self.trigger_start_time
                if delta.total_seconds() < self._timeout_sec:
                    # mypy confuses typing here
                    if await self.count_running_dags() == 0:  # type: ignore[call-arg]
                        self.log.info("Waiting for DAG to start execution...")
                        await asyncio.sleep(self.poll_interval)
                else:
                    yield TriggerEvent({"status": "timeout"})
                    return
                # mypy confuses typing here
                if await self.count_tasks() == len(self.execution_dates):  # type: ignore[call-arg]
                    yield TriggerEvent({"status": "success"})
                    return
                self.log.info("Task is still running, sleeping for %s seconds...", self.poll_interval)
                await asyncio.sleep(self.poll_interval)
        except Exception:
            yield TriggerEvent({"status": "failed"})

    @sync_to_async
    @provide_session
    def count_running_dags(self, *, session: Session = NEW_SESSION) -> int:
        """Count how many DAG runs are in a running (or already successful) state."""
        dags = (
            session.query(func.count("*"))
            .filter(
                TaskInstance.dag_id == self.dag_id,
                TaskInstance.execution_date.in_(self.execution_dates),
                # "success" is counted too, so a run that already finished
                # does not trip the startup timeout.
                TaskInstance.state.in_(["running", "success"]),
            )
            .scalar()
        )
        return dags

    @sync_to_async
    @provide_session
    def count_tasks(self, *, session: Session = NEW_SESSION) -> int | None:
        """Count how many task instances in the database match our criteria."""
        count = (
            session.query(func.count("*"))  # .count() is inefficient
            .filter(
                TaskInstance.dag_id == self.dag_id,
                TaskInstance.task_id == self.task_id,
                TaskInstance.state.in_(self.states),
                TaskInstance.execution_date.in_(self.execution_dates),
            )
            .scalar()
        )
        return typing.cast(int, count)
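
# Usage sketch (illustrative, not part of this module): instantiating the
# deprecated class emits RemovedInAirflow3Warning; WorkflowTrigger is the
# intended replacement. The ids below are made-up examples.
#
#     import warnings
#
#     with warnings.catch_warnings():
#         warnings.simplefilter("error", RemovedInAirflow3Warning)
#         TaskStateTrigger(                  # raises RemovedInAirflow3Warning
#             dag_id="upstream_dag",
#             execution_dates=[utcnow()],
#             trigger_start_time=utcnow(),
#             task_id="upstream_task",
#         )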


class DagStateTrigger(BaseTrigger):
    """
    Waits asynchronously for a DAG run to complete for a specific logical date.

    :param dag_id: The dag_id of the DAG you want to wait for.
    :param states: the states the DAG run must reach to fire the trigger,
        e.g. ``[DagRunState.SUCCESS]``; there is no default.
    :param execution_dates: The logical dates of the DAG runs to check.
    :param poll_interval: The time interval in seconds to check the state.
        The default value is 5.0 sec.
    """

    def __init__(
        self,
        dag_id: str,
        states: list[DagRunState],
        execution_dates: list[datetime],
        poll_interval: float = 5.0,
    ):
        super().__init__()
        self.dag_id = dag_id
        self.states = states
        self.execution_dates = execution_dates
        self.poll_interval = poll_interval

    def serialize(self) -> tuple[str, dict[str, typing.Any]]:
        """Serialize DagStateTrigger arguments and classpath."""
        return (
            "airflow.triggers.external_task.DagStateTrigger",
            {
                "dag_id": self.dag_id,
                "states": self.states,
                "execution_dates": self.execution_dates,
                "poll_interval": self.poll_interval,
            },
        )

    async def run(self) -> typing.AsyncIterator[TriggerEvent]:
        """Periodically check whether the DAG run exists and has reached one of the expected states."""
        while True:
            # mypy confuses typing here
            num_dags = await self.count_dags()  # type: ignore[call-arg]
            if num_dags == len(self.execution_dates):
                # Re-yield the serialized form so the waiting task can
                # reconstruct which trigger fired.
                yield TriggerEvent(self.serialize())
                return
            await asyncio.sleep(self.poll_interval)

    @sync_to_async
    @provide_session
    def count_dags(self, *, session: Session = NEW_SESSION) -> int | None:
        """Count how many dag runs in the database match our criteria."""
        count = (
            session.query(func.count("*"))  # .count() is inefficient
            .filter(
                DagRun.dag_id == self.dag_id,
                DagRun.state.in_(self.states),
                DagRun.execution_date.in_(self.execution_dates),
            )
            .scalar()
        )
        return typing.cast(int, count)
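
# A minimal deferral sketch (assumption: called from inside a deferrable
# operator's execute(), similar to how TriggerDagRunOperator waits for
# completion). ``logical_date`` and the ``execute_complete`` method name are
# illustrative placeholders.
#
#     from airflow.utils.state import DagRunState
#
#     self.defer(
#         trigger=DagStateTrigger(
#             dag_id="downstream_dag",       # hypothetical DAG id
#             states=[DagRunState.SUCCESS, DagRunState.FAILED],
#             execution_dates=[logical_date],
#             poll_interval=15.0,
#         ),
#         method_name="execute_complete",
#     )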