# base.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import abc
import logging
from dataclasses import dataclass
from datetime import timedelta
from typing import TYPE_CHECKING, Any, AsyncIterator

from airflow.callbacks.callback_requests import TaskCallbackRequest
from airflow.callbacks.database_callback_sink import DatabaseCallbackSink
from airflow.models.taskinstance import SimpleTaskInstance
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import TaskInstanceState

# Imported only for type annotations; avoids a circular import at runtime.
if TYPE_CHECKING:
    from sqlalchemy.orm import Session

    from airflow.models import TaskInstance

log = logging.getLogger(__name__)
@dataclass
class StartTriggerArgs:
    """Arguments required for start task execution from triggerer."""

    # Import path of the trigger class to instantiate.
    trigger_cls: str
    # Name of the method to call on the operator when the trigger fires.
    next_method: str
    # Keyword arguments used to construct the trigger instance.
    trigger_kwargs: dict[str, Any] | None = None
    # Keyword arguments passed to ``next_method`` on resume.
    next_kwargs: dict[str, Any] | None = None
    # Optional time limit for the deferral — presumably the task's
    # deferral timeout; confirm against the triggerer's handling.
    timeout: timedelta | None = None
  41. class BaseTrigger(abc.ABC, LoggingMixin):
  42. """
  43. Base class for all triggers.
  44. A trigger has two contexts it can exist in:
  45. - Inside an Operator, when it's passed to TaskDeferred
  46. - Actively running in a trigger worker
  47. We use the same class for both situations, and rely on all Trigger classes
  48. to be able to return the arguments (possible to encode with Airflow-JSON) that will
  49. let them be re-instantiated elsewhere.
  50. """
  51. def __init__(self, **kwargs):
  52. # these values are set by triggerer when preparing to run the instance
  53. # when run, they are injected into logger record.
  54. self.task_instance = None
  55. self.trigger_id = None
  56. def _set_context(self, context):
  57. """Part of LoggingMixin and used mainly for configuration of task logging; not used for triggers."""
  58. raise NotImplementedError
  59. @abc.abstractmethod
  60. def serialize(self) -> tuple[str, dict[str, Any]]:
  61. """
  62. Return the information needed to reconstruct this Trigger.
  63. :return: Tuple of (class path, keyword arguments needed to re-instantiate).
  64. """
  65. raise NotImplementedError("Triggers must implement serialize()")
  66. @abc.abstractmethod
  67. async def run(self) -> AsyncIterator[TriggerEvent]:
  68. """
  69. Run the trigger in an asynchronous context.
  70. The trigger should yield an Event whenever it wants to fire off
  71. an event, and return None if it is finished. Single-event triggers
  72. should thus yield and then immediately return.
  73. If it yields, it is likely that it will be resumed very quickly,
  74. but it may not be (e.g. if the workload is being moved to another
  75. triggerer process, or a multi-event trigger was being used for a
  76. single-event task defer).
  77. In either case, Trigger classes should assume they will be persisted,
  78. and then rely on cleanup() being called when they are no longer needed.
  79. """
  80. raise NotImplementedError("Triggers must implement run()")
  81. yield # To convince Mypy this is an async iterator.
  82. async def cleanup(self) -> None:
  83. """
  84. Cleanup the trigger.
  85. Called when the trigger is no longer needed, and it's being removed
  86. from the active triggerer process.
  87. This method follows the async/await pattern to allow to run the cleanup
  88. in triggerer main event loop. Exceptions raised by the cleanup method
  89. are ignored, so if you would like to be able to debug them and be notified
  90. that cleanup method failed, you should wrap your code with try/except block
  91. and handle it appropriately (in async-compatible way).
  92. """
  93. def __repr__(self) -> str:
  94. classpath, kwargs = self.serialize()
  95. kwargs_str = ", ".join(f"{k}={v}" for k, v in kwargs.items())
  96. return f"<{classpath} {kwargs_str}>"
  97. class TriggerEvent:
  98. """
  99. Something that a trigger can fire when its conditions are met.
  100. Events must have a uniquely identifying value that would be the same
  101. wherever the trigger is run; this is to ensure that if the same trigger
  102. is being run in two locations (for HA reasons) that we can deduplicate its
  103. events.
  104. """
  105. def __init__(self, payload: Any):
  106. self.payload = payload
  107. def __repr__(self) -> str:
  108. return f"TriggerEvent<{self.payload!r}>"
  109. def __eq__(self, other):
  110. if isinstance(other, TriggerEvent):
  111. return other.payload == self.payload
  112. return False
  113. @provide_session
  114. def handle_submit(self, *, task_instance: TaskInstance, session: Session = NEW_SESSION) -> None:
  115. """
  116. Handle the submit event for a given task instance.
  117. This function sets the next method and next kwargs of the task instance,
  118. as well as its state to scheduled. It also adds the event's payload
  119. into the kwargs for the task.
  120. :param task_instance: The task instance to handle the submit event for.
  121. :param session: The session to be used for the database callback sink.
  122. """
  123. # Get the next kwargs of the task instance, or an empty dictionary if it doesn't exist
  124. next_kwargs = task_instance.next_kwargs or {}
  125. # Add the event's payload into the kwargs for the task
  126. next_kwargs["event"] = self.payload
  127. # Update the next kwargs of the task instance
  128. task_instance.next_kwargs = next_kwargs
  129. # Remove ourselves as its trigger
  130. task_instance.trigger_id = None
  131. # Set the state of the task instance to scheduled
  132. task_instance.state = TaskInstanceState.SCHEDULED
  133. class BaseTaskEndEvent(TriggerEvent):
  134. """
  135. Base event class to end the task without resuming on worker.
  136. :meta private:
  137. """
  138. task_instance_state: TaskInstanceState
  139. def __init__(self, *, xcoms: dict[str, Any] | None = None, **kwargs) -> None:
  140. """
  141. Initialize the class with the specified parameters.
  142. :param xcoms: A dictionary of XComs or None.
  143. :param kwargs: Additional keyword arguments.
  144. """
  145. if "payload" in kwargs:
  146. raise ValueError("Param 'payload' not supported for this class.")
  147. super().__init__(payload=self.task_instance_state)
  148. self.xcoms = xcoms
  149. @provide_session
  150. def handle_submit(self, *, task_instance: TaskInstance, session: Session = NEW_SESSION) -> None:
  151. """
  152. Submit event for the given task instance.
  153. Marks the task with the state `task_instance_state` and optionally pushes xcom if applicable.
  154. :param task_instance: The task instance to be submitted.
  155. :param session: The session to be used for the database callback sink.
  156. """
  157. # Mark the task with terminal state and prevent it from resuming on worker
  158. task_instance.trigger_id = None
  159. task_instance.set_state(self.task_instance_state, session=session)
  160. self._submit_callback_if_necessary(task_instance=task_instance, session=session)
  161. self._push_xcoms_if_necessary(task_instance=task_instance)
  162. def _submit_callback_if_necessary(self, *, task_instance: TaskInstance, session) -> None:
  163. """Submit a callback request if the task state is SUCCESS or FAILED."""
  164. if self.task_instance_state in (TaskInstanceState.SUCCESS, TaskInstanceState.FAILED):
  165. request = TaskCallbackRequest(
  166. full_filepath=task_instance.dag_model.fileloc,
  167. simple_task_instance=SimpleTaskInstance.from_ti(task_instance),
  168. task_callback_type=self.task_instance_state,
  169. )
  170. log.info("Sending callback: %s", request)
  171. try:
  172. DatabaseCallbackSink().send(callback=request, session=session)
  173. except Exception:
  174. log.exception("Failed to send callback.")
  175. def _push_xcoms_if_necessary(self, *, task_instance: TaskInstance) -> None:
  176. """Pushes XComs to the database if they are provided."""
  177. if self.xcoms:
  178. for key, value in self.xcoms.items():
  179. task_instance.xcom_push(key=key, value=value)
  180. class TaskSuccessEvent(BaseTaskEndEvent):
  181. """Yield this event in order to end the task successfully."""
  182. task_instance_state = TaskInstanceState.SUCCESS
  183. class TaskFailedEvent(BaseTaskEndEvent):
  184. """Yield this event in order to end the task with failure."""
  185. task_instance_state = TaskInstanceState.FAILED
  186. class TaskSkippedEvent(BaseTaskEndEvent):
  187. """Yield this event in order to end the task with status 'skipped'."""
  188. task_instance_state = TaskInstanceState.SKIPPED