trigger.py

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import datetime
from traceback import format_exception
from typing import TYPE_CHECKING, Any, Iterable

from sqlalchemy import Column, Integer, String, Text, delete, func, or_, select, update
from sqlalchemy.orm import relationship, selectinload
from sqlalchemy.sql.functions import coalesce

from airflow.api_internal.internal_api_call import internal_api_call
from airflow.models.base import Base
from airflow.models.taskinstance import TaskInstance
from airflow.utils import timezone
from airflow.utils.retries import run_with_db_retries
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import UtcDateTime, with_row_locks
from airflow.utils.state import TaskInstanceState

if TYPE_CHECKING:
    from sqlalchemy.orm import Session
    from sqlalchemy.sql import Select

    from airflow.serialization.pydantic.trigger import TriggerPydantic
    from airflow.triggers.base import BaseTrigger


class Trigger(Base):
    """
    Base Trigger class.

    Triggers are a workload that run in an asynchronous event loop shared with
    other Triggers, and fire off events that will unpause deferred Tasks,
    start linked DAGs, etc.

    They are persisted into the database and then re-hydrated into a
    "triggerer" process, where many are run at once. We model it so that
    there is a many-to-one relationship between Task and Trigger, for future
    deduplication logic to use.

    Rows will be evicted from the database when the triggerer detects no
    active Tasks/DAGs using them. Events are not stored in the database;
    when an Event is fired, the triggerer will directly push its data to the
    appropriate Task/DAG.
    """
    __tablename__ = "trigger"

    id = Column(Integer, primary_key=True)
    classpath = Column(String(1000), nullable=False)
    encrypted_kwargs = Column("kwargs", Text, nullable=False)
    created_date = Column(UtcDateTime, nullable=False)
    triggerer_id = Column(Integer, nullable=True)

    triggerer_job = relationship(
        "Job",
        primaryjoin="Job.id == Trigger.triggerer_id",
        foreign_keys=triggerer_id,
        uselist=False,
    )

    task_instance = relationship("TaskInstance", back_populates="trigger", lazy="selectin", uselist=False)
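
    # Note on the schema: the encrypted payload lives in a DB column named
    # "kwargs" (mapped here as `encrypted_kwargs`), and `triggerer_id` points at
    # a Job row only through the relationship's `primaryjoin` -- there is no
    # database-level foreign key constraint behind it.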

    def __init__(
        self,
        classpath: str,
        kwargs: dict[str, Any],
        created_date: datetime.datetime | None = None,
    ) -> None:
        super().__init__()
        self.classpath = classpath
        self.encrypted_kwargs = self._encrypt_kwargs(kwargs)
        self.created_date = created_date or timezone.utcnow()

    @property
    def kwargs(self) -> dict[str, Any]:
        """Return the decrypted kwargs of the trigger."""
        return self._decrypt_kwargs(self.encrypted_kwargs)

    @kwargs.setter
    def kwargs(self, kwargs: dict[str, Any]) -> None:
        """Set the encrypted kwargs of the trigger."""
        self.encrypted_kwargs = self._encrypt_kwargs(kwargs)

    @staticmethod
    def _encrypt_kwargs(kwargs: dict[str, Any]) -> str:
        """Encrypt the kwargs of the trigger."""
        import json

        from airflow.models.crypto import get_fernet
        from airflow.serialization.serialized_objects import BaseSerialization

        serialized_kwargs = BaseSerialization.serialize(kwargs)
        return get_fernet().encrypt(json.dumps(serialized_kwargs).encode("utf-8")).decode("utf-8")

    @staticmethod
    def _decrypt_kwargs(encrypted_kwargs: str) -> dict[str, Any]:
        """Decrypt the kwargs of the trigger."""
        import json

        from airflow.models.crypto import get_fernet
        from airflow.serialization.serialized_objects import BaseSerialization

        # We weren't able to encrypt the kwargs in all migration paths, so we
        # need to handle the case where they are still plaintext. Triggers
        # aren't long-lasting, so we can leave those legacy rows unencrypted.
        if encrypted_kwargs.startswith("{"):
            decrypted_kwargs = json.loads(encrypted_kwargs)
        else:
            decrypted_kwargs = json.loads(
                get_fernet().decrypt(encrypted_kwargs.encode("utf-8")).decode("utf-8")
            )

        return BaseSerialization.deserialize(decrypted_kwargs)
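
    # The two helpers above are inverses: for any kwargs that Airflow's
    # BaseSerialization can handle, _decrypt_kwargs(_encrypt_kwargs(kwargs))
    # should return an equal dict, provided the same Fernet key is configured.
    # (A sketch of the expected round-trip, not an enforced invariant.)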

    def rotate_fernet_key(self):
        """Encrypts data with a new key. See: :ref:`security/fernet`."""
        from airflow.models.crypto import get_fernet

        self.encrypted_kwargs = get_fernet().rotate(self.encrypted_kwargs.encode("utf-8")).decode("utf-8")
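
    # Hedged note: this is typically reached via the `airflow rotate-fernet-key`
    # CLI command, which re-encrypts stored secrets with the newest key in the
    # Fernet key list; rotate() decrypts with an old key and re-encrypts in one
    # call.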

    @classmethod
    @internal_api_call
    @provide_session
    def from_object(cls, trigger: BaseTrigger, session=NEW_SESSION) -> Trigger | TriggerPydantic:
        """Alternative constructor that creates a trigger row based directly off of a Trigger object."""
        classpath, kwargs = trigger.serialize()
        return cls(classpath=classpath, kwargs=kwargs)
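
    # Illustrative usage (assumptions: DateTimeTrigger from
    # airflow.triggers.temporal, and an open SQLAlchemy session named `session`):
    #
    #     from airflow.triggers.temporal import DateTimeTrigger
    #
    #     trigger_row = Trigger.from_object(DateTimeTrigger(moment=timezone.utcnow()))
    #     session.add(trigger_row)
    #
    # Any BaseTrigger subclass works the same way, since serialize() returns the
    # (classpath, kwargs) pair this constructor needs.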

    @classmethod
    @internal_api_call
    @provide_session
    def bulk_fetch(cls, ids: Iterable[int], session: Session = NEW_SESSION) -> dict[int, Trigger]:
        """Fetch all the Triggers by ID and return a dict mapping ID -> Trigger instance."""
        stmt = (
            select(cls)
            .where(cls.id.in_(ids))
            .options(
                selectinload(cls.task_instance)
                .joinedload(TaskInstance.trigger)
                .joinedload(Trigger.triggerer_job)
            )
        )
        return {obj.id: obj for obj in session.scalars(stmt)}
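
    # The chained loader options above eagerly fetch each trigger's task
    # instance (and, through it, the trigger and triggerer_job relationships)
    # up front, so callers iterating the returned dict don't fire one lazy-load
    # query per trigger.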

    @classmethod
    @internal_api_call
    @provide_session
    def clean_unused(cls, session: Session = NEW_SESSION) -> None:
        """
        Delete all triggers that have no tasks dependent on them.

        Triggers have a one-to-many relationship to task instances, so we need
        to clean those up first. Afterwards we can drop the triggers not
        referenced by anyone.
        """
        # Clear trigger_id on all task instances that are no longer DEFERRED
        for attempt in run_with_db_retries():
            with attempt:
                session.execute(
                    update(TaskInstance)
                    .where(
                        TaskInstance.state != TaskInstanceState.DEFERRED,
                        TaskInstance.trigger_id.is_not(None),
                    )
                    .values(trigger_id=None)
                )

        # Get all triggers that have no task instances depending on them and delete them
        ids = (
            select(cls.id)
            .join(TaskInstance, cls.id == TaskInstance.trigger_id, isouter=True)
            .group_by(cls.id)
            .having(func.count(TaskInstance.trigger_id) == 0)
        )
        if session.bind.dialect.name == "mysql":
            # MySQL can't delete from a table that the subquery also selects
            # from, so materialize the IDs first and delete in a second step
            ids = session.scalars(ids).all()
        session.execute(
            delete(Trigger).where(Trigger.id.in_(ids)).execution_options(synchronize_session=False)
        )
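
    # The UPDATE above is wrapped in run_with_db_retries() because it can
    # contend with task instances changing state concurrently; that helper
    # retries on transient database errors (e.g. deadlocks). The final DELETE
    # passes synchronize_session=False since no loaded ORM state needs to stay
    # in sync with it.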

    @classmethod
    @internal_api_call
    @provide_session
    def submit_event(cls, trigger_id, event, session: Session = NEW_SESSION) -> None:
        """Take an event fired by this trigger and resume every task instance deferred on it."""
        for task_instance in session.scalars(
            select(TaskInstance).where(
                TaskInstance.trigger_id == trigger_id, TaskInstance.state == TaskInstanceState.DEFERRED
            )
        ):
            event.handle_submit(task_instance=task_instance)
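
    # event.handle_submit() is expected to hand the event's payload to the task
    # instance and move it back toward a worker; the exact behavior lives on
    # the event class (TriggerEvent in airflow.triggers.base), not here.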

    @classmethod
    @internal_api_call
    @provide_session
    def submit_failure(cls, trigger_id, exc=None, session: Session = NEW_SESSION) -> None:
        """
        When a trigger has failed unexpectedly, mark everything that depended on it as failed.

        Notably, we have to actually run the failure code from a worker as it may
        have linked callbacks, so hilariously we have to re-schedule the task
        instances to a worker just so they can then fail.

        We use a special ``__fail__`` value for next_method, which the runtime
        code understands as immediate-fail, and pack the error into next_kwargs.

        TODO: Once we have shifted callback (and email) handling to run on
        workers as first-class concepts, we can run the failure code here
        in-process, but we can't do that right now.
        """
        for task_instance in session.scalars(
            select(TaskInstance).where(
                TaskInstance.trigger_id == trigger_id, TaskInstance.state == TaskInstanceState.DEFERRED
            )
        ):
            # Add the error and set the next_method to the fail state
            traceback = format_exception(type(exc), exc, exc.__traceback__) if exc else None
            task_instance.next_method = "__fail__"
            task_instance.next_kwargs = {"error": "Trigger failure", "traceback": traceback}
            # Remove ourselves as its trigger
            task_instance.trigger_id = None
            # Finally, mark it as scheduled so it gets re-queued
            task_instance.state = TaskInstanceState.SCHEDULED
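
    # Worker-side context (an assumption about code that lives in
    # airflow.models.taskinstance, not here): on resume, seeing
    # next_method == "__fail__" makes the task raise immediately with the error
    # packed into next_kwargs instead of calling a real method, so failure
    # callbacks run where they must -- on the worker.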

    @classmethod
    @internal_api_call
    @provide_session
    def ids_for_triggerer(cls, triggerer_id, session: Session = NEW_SESSION) -> list[int]:
        """Return the IDs of all triggers currently assigned to the given triggerer."""
        return session.scalars(select(cls.id).where(cls.triggerer_id == triggerer_id)).all()

    @classmethod
    @internal_api_call
    @provide_session
    def assign_unassigned(
        cls, triggerer_id, capacity, health_check_threshold, session: Session = NEW_SESSION
    ) -> None:
        """
        Assign unassigned triggers to the requesting triggerer, up to its spare capacity.

        Takes a triggerer_id, the capacity for that triggerer, and the triggerer job's
        heartbeat health check threshold, and assigns unassigned triggers until that
        capacity is reached or there are no more unassigned triggers.
        """
        from airflow.jobs.job import Job  # To avoid circular import

        count = session.scalar(select(func.count(cls.id)).filter(cls.triggerer_id == triggerer_id))
        capacity -= count

        if capacity <= 0:
            return

        alive_triggerer_ids = select(Job.id).where(
            Job.end_date.is_(None),
            Job.latest_heartbeat > timezone.utcnow() - datetime.timedelta(seconds=health_check_threshold),
            Job.job_type == "TriggererJob",
        )

        # Find triggers that do NOT have an alive triggerer_id, and then assign
        # up to `capacity` of those to us.
        trigger_ids_query = cls.get_sorted_triggers(
            capacity=capacity, alive_triggerer_ids=alive_triggerer_ids, session=session
        )
        if trigger_ids_query:
            session.execute(
                update(cls)
                .where(cls.id.in_([i.id for i in trigger_ids_query]))
                .values(triggerer_id=triggerer_id)
                .execution_options(synchronize_session=False)
            )

        session.commit()
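
    # Worked example of the capacity arithmetic: with capacity=1000 and 990
    # triggers already carrying this triggerer_id, only 10 new triggers are
    # claimed this round; if 1000 or more are already assigned, the method
    # returns without touching the database further.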

    @classmethod
    def get_sorted_triggers(cls, capacity: int, alive_triggerer_ids: list[int] | Select, session: Session):
        """
        Get sorted triggers based on capacity and alive triggerer ids.

        :param capacity: The capacity of the triggerer.
        :param alive_triggerer_ids: The alive triggerer ids as a list or a select query.
        :param session: The database session.
        """
        query = with_row_locks(
            select(cls.id)
            .join(TaskInstance, cls.id == TaskInstance.trigger_id, isouter=False)
            .where(or_(cls.triggerer_id.is_(None), cls.triggerer_id.not_in(alive_triggerer_ids)))
            .order_by(coalesce(TaskInstance.priority_weight, 0).desc(), cls.created_date)
            .limit(capacity),
            session,
            skip_locked=True,
        )
        return session.execute(query).all()
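
    # with_row_locks() emits SELECT ... FOR UPDATE SKIP LOCKED on backends that
    # support it (and falls back to a plain SELECT elsewhere), so two triggerers
    # running assign_unassigned() concurrently should not claim the same trigger
    # rows; highest-priority, oldest triggers are handed out first.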