#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import datetime
import functools
import hashlib
import time
import traceback
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Callable, Iterable

from sqlalchemy import select

from airflow import settings
from airflow.api_internal.internal_api_call import InternalApiConfig, internal_api_call
from airflow.configuration import conf
from airflow.exceptions import (
    AirflowException,
    AirflowFailException,
    AirflowRescheduleException,
    AirflowSensorTimeout,
    AirflowSkipException,
    AirflowTaskTimeout,
    TaskDeferralError,
)
from airflow.executors.executor_loader import ExecutorLoader
from airflow.models.baseoperator import BaseOperator
from airflow.models.skipmixin import SkipMixin
from airflow.models.taskreschedule import TaskReschedule
from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep
from airflow.utils import timezone

# We need to keep the import here because GCSToLocalFilesystemOperator released in
# Google Provider before 3.0.0 imported apply_defaults from here.
# See https://github.com/apache/airflow/issues/16035
from airflow.utils.decorators import apply_defaults  # noqa: F401
from airflow.utils.session import NEW_SESSION, provide_session

if TYPE_CHECKING:
    from sqlalchemy.orm.session import Session

    from airflow.utils.context import Context

# As documented in https://dev.mysql.com/doc/refman/5.7/en/datetime.html.
_MYSQL_TIMESTAMP_MAX = datetime.datetime(2038, 1, 19, 3, 14, 7, tzinfo=timezone.utc)


@functools.lru_cache(maxsize=None)
def _is_metadatabase_mysql() -> bool:
    if InternalApiConfig.get_use_internal_api():
        return False
    if settings.engine is None:
        raise AirflowException("Must initialize ORM first")
    return settings.engine.url.get_backend_name() == "mysql"


class PokeReturnValue:
    """
    Optional return value for poke methods.

    Sensors can optionally return an instance of the PokeReturnValue class in the poke method.
    If an XCom value is supplied when the sensor is done, then the XCom value will be
    pushed through the operator return value.

    :param is_done: Set to true to indicate the sensor can stop poking.
    :param xcom_value: An optional XCom value to be returned by the operator.
    """

    def __init__(self, is_done: bool, xcom_value: Any | None = None) -> None:
        self.xcom_value = xcom_value
        self.is_done = is_done

    def __bool__(self) -> bool:
        return self.is_done
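

# ---------------------------------------------------------------------------
# Editorial sketch, not part of the upstream module: how a sensor's poke()
# typically uses PokeReturnValue. BaseSensorOperator is only defined further
# down this file, so the example is kept as a comment rather than live code;
# ``_fetch_record`` is a hypothetical helper.
#
#     def poke(self, context: Context) -> PokeReturnValue:
#         record = self._fetch_record()
#         # A truthy PokeReturnValue (is_done=True) stops the poking loop;
#         # xcom_value becomes the operator's return value and is pushed as
#         # the task's XCom.
#         return PokeReturnValue(is_done=record is not None, xcom_value=record)
# ---------------------------------------------------------------------------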


@internal_api_call
@provide_session
def _orig_start_date(
    dag_id: str, task_id: str, run_id: str, map_index: int, try_number: int, session: Session = NEW_SESSION
):
    """
    Get the original start_date for a rescheduled task.

    :meta private:
    """
    return session.scalar(
        select(TaskReschedule)
        .where(
            TaskReschedule.dag_id == dag_id,
            TaskReschedule.task_id == task_id,
            TaskReschedule.run_id == run_id,
            TaskReschedule.map_index == map_index,
            # If the first try's record was not saved because an exception occurred
            # and the transaction was rolled back, fall back to the next available
            # attempt to avoid endless rescheduling.
            TaskReschedule.try_number >= try_number,
        )
        .order_by(TaskReschedule.id.asc())
        .with_only_columns(TaskReschedule.start_date)
        .limit(1)
    )


class BaseSensorOperator(BaseOperator, SkipMixin):
    """
    Sensor operators are derived from this class and inherit these attributes.

    Sensor operators keep executing at a time interval and succeed when
    a criterion is met and fail if and when they time out.

    :param soft_fail: Set to true to mark the task as SKIPPED on failure.
        Mutually exclusive with never_fail.
    :param poke_interval: Time that the job should wait in between each try.
        Can be ``timedelta`` or ``float`` seconds.
    :param timeout: Time elapsed before the task times out and fails.
        Can be ``timedelta`` or ``float`` seconds.
        This should not be confused with ``execution_timeout`` of the
        ``BaseOperator`` class. ``timeout`` measures the time elapsed between the
        first poke and the current time (taking into account any
        reschedule delay between each poke), while ``execution_timeout``
        checks the **running** time of the task (leaving out any reschedule
        delay). When the ``mode`` is ``poke`` (see below), the two are
        equivalent (as the sensor is never rescheduled), which is not
        the case in ``reschedule`` mode.
    :param mode: How the sensor operates.
        Options are: ``{ poke | reschedule }``, default is ``poke``.
        When set to ``poke`` the sensor takes up a worker slot for its
        whole execution time and sleeps between pokes. Use this mode if the
        expected runtime of the sensor is short or if a short poke interval
        is required. Note that the sensor will hold onto a worker slot and
        a pool slot for the duration of the sensor's runtime in this mode.
        When set to ``reschedule`` the sensor task frees the worker slot when
        the criterion is not yet met and is rescheduled at a later time. Use
        this mode if the time before the criterion is met is expected to be
        quite long. The poke interval should be more than one minute to
        avoid placing too much load on the scheduler.
    :param exponential_backoff: allow progressively longer waits between
        pokes by using an exponential backoff algorithm
    :param max_wait: maximum wait interval between pokes, can be ``timedelta`` or ``float`` seconds
    :param silent_fail: If true, and the poke method raises an exception other than
        AirflowSensorTimeout, AirflowTaskTimeout, AirflowSkipException,
        or AirflowFailException, the sensor will log the error and continue
        its execution. Otherwise, the sensor task fails, and it can be retried
        based on the provided `retries` parameter.
    :param never_fail: If true, and the poke method raises an exception, the sensor will be skipped.
        Mutually exclusive with soft_fail.
    """

    ui_color: str = "#e6f1f2"
    valid_modes: Iterable[str] = ["poke", "reschedule"]

    # Adds one additional dependency for all sensor operators that checks if a
    # sensor task instance can be rescheduled.
    deps = BaseOperator.deps | {ReadyToRescheduleDep()}

    def __init__(
        self,
        *,
        poke_interval: timedelta | float = 60,
        timeout: timedelta | float = conf.getfloat("sensors", "default_timeout"),
        soft_fail: bool = False,
        mode: str = "poke",
        exponential_backoff: bool = False,
        max_wait: timedelta | float | None = None,
        silent_fail: bool = False,
        never_fail: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.poke_interval = self._coerce_poke_interval(poke_interval).total_seconds()
        self.soft_fail = soft_fail
        self.timeout = self._coerce_timeout(timeout).total_seconds()
        self.mode = mode
        self.exponential_backoff = exponential_backoff
        self.max_wait = self._coerce_max_wait(max_wait)
        if soft_fail is True and never_fail is True:
            raise ValueError("soft_fail and never_fail are mutually exclusive, you cannot provide both.")
        self.silent_fail = silent_fail
        self.never_fail = never_fail
        self._validate_input_values()

    @staticmethod
    def _coerce_poke_interval(poke_interval: float | timedelta) -> timedelta:
        if isinstance(poke_interval, timedelta):
            return poke_interval
        if isinstance(poke_interval, (int, float)) and poke_interval >= 0:
            return timedelta(seconds=poke_interval)
        raise AirflowException(
            "Operator arg `poke_interval` must be a timedelta object or a non-negative number"
        )

    @staticmethod
    def _coerce_timeout(timeout: float | timedelta) -> timedelta:
        if isinstance(timeout, timedelta):
            return timeout
        if isinstance(timeout, (int, float)) and timeout >= 0:
            return timedelta(seconds=timeout)
        raise AirflowException("Operator arg `timeout` must be a timedelta object or a non-negative number")

    @staticmethod
    def _coerce_max_wait(max_wait: float | timedelta | None) -> timedelta | None:
        if max_wait is None or isinstance(max_wait, timedelta):
            return max_wait
        if isinstance(max_wait, (int, float)) and max_wait >= 0:
            return timedelta(seconds=max_wait)
        raise AirflowException("Operator arg `max_wait` must be a timedelta object or a non-negative number")

    def _validate_input_values(self) -> None:
        if not isinstance(self.poke_interval, (int, float)) or self.poke_interval < 0:
            raise AirflowException("The poke_interval must be a non-negative number")
        if not isinstance(self.timeout, (int, float)) or self.timeout < 0:
            raise AirflowException("The timeout must be a non-negative number")
        if self.mode not in self.valid_modes:
            raise AirflowException(
                f"The mode must be one of {self.valid_modes} for task "
                f"'{self.dag.dag_id if self.has_dag() else ''}.{self.task_id}'; received '{self.mode}'."
            )

        # Quick check that poke_interval isn't immediately over MySQL's TIMESTAMP limit.
        # This check is only rudimentary, to catch trivial user errors such as mistakenly
        # setting the value to milliseconds instead of seconds. There's another check when
        # we actually try to reschedule to ensure database coherence.
        if self.reschedule and _is_metadatabase_mysql():
            if timezone.utcnow() + datetime.timedelta(seconds=self.poke_interval) > _MYSQL_TIMESTAMP_MAX:
                raise AirflowException(
                    f"Cannot set poke_interval to {self.poke_interval} seconds in reschedule "
                    f"mode since it will take the reschedule time over MySQL's TIMESTAMP limit."
                )

    def poke(self, context: Context) -> bool | PokeReturnValue:
        """Override when deriving this class."""
        raise AirflowException("Override me.")

    def execute(self, context: Context) -> Any:
        started_at: datetime.datetime | float

        if self.reschedule:
            ti = context["ti"]
            max_tries: int = ti.max_tries or 0
            retries: int = self.retries or 0
            # If reschedule, use the start date of the first try (first try can be either the very
            # first execution of the task, or the first execution after the task was cleared).
            first_try_number = max_tries - retries + 1
            start_date = _orig_start_date(
                dag_id=ti.dag_id,
                task_id=ti.task_id,
                run_id=ti.run_id,
                map_index=ti.map_index,
                try_number=first_try_number,
            )
            if not start_date:
                start_date = timezone.utcnow()
            started_at = start_date

            def run_duration() -> float:
                # If we are in reschedule mode, we have to compute the diff
                # based on the time in the DB, so we can't use time.monotonic.
                return (timezone.utcnow() - start_date).total_seconds()

        else:
            started_at = start_monotonic = time.monotonic()

            def run_duration() -> float:
                return time.monotonic() - start_monotonic

        poke_count = 1
        log_dag_id = self.dag.dag_id if self.has_dag() else ""
        xcom_value = None
        while True:
            try:
                poke_return = self.poke(context)
            except (
                AirflowSensorTimeout,
                AirflowTaskTimeout,
                AirflowFailException,
            ) as e:
                if self.soft_fail:
                    raise AirflowSkipException("Skipping because soft_fail is set to True.") from e
                elif self.never_fail:
                    raise AirflowSkipException("Skipping because never_fail is set to True.") from e
                raise e
            except AirflowSkipException as e:
                raise e
            except Exception as e:
                if self.silent_fail:
                    self.log.error("Sensor poke failed: \n %s", traceback.format_exc())
                    poke_return = False
                elif self.never_fail:
                    raise AirflowSkipException("Skipping because never_fail is set to True.") from e
                else:
                    raise e

            if poke_return:
                if isinstance(poke_return, PokeReturnValue):
                    xcom_value = poke_return.xcom_value
                break

            if run_duration() > self.timeout:
                # If the sensor is in soft fail mode but times out, raise AirflowSkipException.
                message = (
                    f"Sensor has timed out; run duration of {run_duration()} seconds exceeds "
                    f"the specified timeout of {self.timeout}."
                )
                if self.soft_fail:
                    raise AirflowSkipException(message)
                else:
                    raise AirflowSensorTimeout(message)
            if self.reschedule:
                next_poke_interval = self._get_next_poke_interval(started_at, run_duration, poke_count)
                reschedule_date = timezone.utcnow() + timedelta(seconds=next_poke_interval)
                if _is_metadatabase_mysql() and reschedule_date > _MYSQL_TIMESTAMP_MAX:
                    raise AirflowSensorTimeout(
                        f"Cannot reschedule DAG {log_dag_id} to {reschedule_date.isoformat()} "
                        f"since it is over MySQL's TIMESTAMP storage limit."
                    )
                raise AirflowRescheduleException(reschedule_date)
            else:
                time.sleep(self._get_next_poke_interval(started_at, run_duration, poke_count))
                poke_count += 1
        self.log.info("Success criteria met. Exiting.")
        return xcom_value

    def resume_execution(self, next_method: str, next_kwargs: dict[str, Any] | None, context: Context):
        try:
            return super().resume_execution(next_method, next_kwargs, context)
        except (AirflowException, TaskDeferralError) as e:
            if self.soft_fail:
                raise AirflowSkipException(str(e)) from e
            raise

    def _get_next_poke_interval(
        self,
        started_at: datetime.datetime | float,
        run_duration: Callable[[], float],
        poke_count: int,
    ) -> float:
        """Use logic similar to the exponential backoff retry delay for operators."""
        if not self.exponential_backoff:
            return self.poke_interval

        if self.reschedule:
            # Calculate elapsed time since the sensor started
            elapsed_time = run_duration()
            # Initialize variables for the simulation
            cumulative_time: float = 0.0
            estimated_poke_count: int = 0
            while cumulative_time <= elapsed_time:
                estimated_poke_count += 1
                # Calculate min_backoff for the current try number
                min_backoff = max(int(self.poke_interval * (2 ** (estimated_poke_count - 2))), 1)
                # Compute a deterministic hash for this run and try number
                run_hash = int(
                    hashlib.sha1(
                        f"{self.dag_id}#{self.task_id}#{started_at}#{estimated_poke_count}".encode()
                    ).hexdigest(),
                    16,
                )
                modded_hash = min_backoff + run_hash % min_backoff
                # Apply the jitter, which prevents multiple sensors poking simultaneously
                interval_with_jitter = min(modded_hash, timedelta.max.total_seconds() - 1)
                # Add the interval to the cumulative time
                cumulative_time += interval_with_jitter
            # Now we have an estimated_poke_count based on the elapsed time
            poke_count = estimated_poke_count or poke_count

        # The value of min_backoff should always be greater than or equal to 1.
        min_backoff = max(int(self.poke_interval * (2 ** (poke_count - 2))), 1)
        run_hash = int(
            hashlib.sha1(f"{self.dag_id}#{self.task_id}#{started_at}#{poke_count}".encode()).hexdigest(),
            16,
        )
        modded_hash = min_backoff + run_hash % min_backoff
        delay_backoff_in_seconds = min(modded_hash, timedelta.max.total_seconds() - 1)
        new_interval = min(self.timeout - int(run_duration()), delay_backoff_in_seconds)
        if self.max_wait:
            new_interval = min(self.max_wait.total_seconds(), new_interval)
        self.log.info("new %s interval is %s", self.mode, new_interval)
        return new_interval

    def prepare_for_execution(self) -> BaseOperator:
        task = super().prepare_for_execution()

        # Sensors in `poke` mode can block execution of DAGs when running
        # with a single-process executor, so we change the mode to `reschedule`
        # to allow parallel tasks to be scheduled and executed.
        executor, _ = ExecutorLoader.import_default_executor_cls()
        if executor.change_sensor_mode_to_reschedule:
            self.log.warning("%s changes sensor mode to 'reschedule'.", executor.__name__)
            task.mode = "reschedule"
        return task

    @property
    def reschedule(self):
        """Whether the sensor runs in ``reschedule`` mode."""
        return self.mode == "reschedule"

    @classmethod
    def get_serialized_fields(cls):
        return super().get_serialized_fields() | {"reschedule"}
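

# ---------------------------------------------------------------------------
# Editorial sketch, not part of the upstream module: a minimal concrete
# sensor showing how the parameters documented on BaseSensorOperator are
# typically used. The class, the file path, and the wiring shown below are
# hypothetical, for illustration only.
# ---------------------------------------------------------------------------
class _ExampleFileSensor(BaseSensorOperator):
    """Wait until a file exists at ``path`` (illustrative only)."""

    def __init__(self, *, path: str, **kwargs) -> None:
        super().__init__(**kwargs)
        self.path = path

    def poke(self, context: Context) -> bool:
        import os

        # Returning True ends the sensor; False makes it sleep (poke mode)
        # or free the worker slot and get rescheduled (reschedule mode)
        # until ``timeout`` is exceeded.
        return os.path.exists(self.path)


# A long expected wait suits reschedule mode with exponential backoff:
#
#     wait = _ExampleFileSensor(
#         task_id="wait_for_file",
#         path="/tmp/ready",
#         mode="reschedule",
#         poke_interval=timedelta(minutes=1),
#         timeout=timedelta(hours=6),
#         exponential_backoff=True,
#         max_wait=timedelta(minutes=30),
#     )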


def poke_mode_only(cls):
    """
    Decorate a subclass of BaseSensorOperator with poke.

    Indicate that instances of this class are only safe to use in poke mode.

    Replaces the class's ``mode`` attribute with a property that always
    returns 'poke' and rejects any attempt to set a different value.

    :param cls: BaseSensor class to enforce methods only use 'poke' mode.
    """

    def decorate(cls_type):
        def mode_getter(_):
            return "poke"

        def mode_setter(_, value):
            if value != "poke":
                raise ValueError(f"Cannot set mode to '{value}'. Only 'poke' is acceptable")

        if not issubclass(cls_type, BaseSensorOperator):
            raise ValueError(
                f"poke_mode_only decorator should only be "
                f"applied to subclasses of BaseSensorOperator, "
                f"got: {cls_type}."
            )

        cls_type.mode = property(mode_getter, mode_setter)
        return cls_type

    return decorate(cls)
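

# ---------------------------------------------------------------------------
# Editorial sketch, not part of the upstream module: applying poke_mode_only
# to a sensor that is only safe in poke mode. After decoration, ``mode`` is a
# property pinned to "poke"; assigning any other value raises ValueError.
# ---------------------------------------------------------------------------
@poke_mode_only
class _ExamplePokeOnlySensor(BaseSensorOperator):
    """Illustrative sensor that must never be rescheduled."""

    def poke(self, context: Context) -> bool:
        return True


# _ExamplePokeOnlySensor(task_id="t").mode                  -> "poke"
# _ExamplePokeOnlySensor(task_id="t").mode = "poke"         -> accepted (no-op)
# _ExamplePokeOnlySensor(task_id="t").mode = "reschedule"   -> ValueError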