taskreschedule.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. #
  2. # Licensed to the Apache Software Foundation (ASF) under one
  3. # or more contributor license agreements. See the NOTICE file
  4. # distributed with this work for additional information
  5. # regarding copyright ownership. The ASF licenses this file
  6. # to you under the Apache License, Version 2.0 (the
  7. # "License"); you may not use this file except in compliance
  8. # with the License. You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing,
  13. # software distributed under the License is distributed on an
  14. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  15. # KIND, either express or implied. See the License for the
  16. # specific language governing permissions and limitations
  17. # under the License.
  18. """TaskReschedule tracks rescheduled task instances."""
  19. from __future__ import annotations
  20. import warnings
  21. from typing import TYPE_CHECKING
  22. from sqlalchemy import Column, ForeignKeyConstraint, Index, Integer, String, asc, desc, select, text
  23. from sqlalchemy.ext.associationproxy import association_proxy
  24. from sqlalchemy.orm import relationship
  25. from airflow.exceptions import RemovedInAirflow3Warning
  26. from airflow.models.base import COLLATION_ARGS, ID_LEN, TaskInstanceDependencies
  27. from airflow.utils.session import NEW_SESSION, provide_session
  28. from airflow.utils.sqlalchemy import UtcDateTime
  29. if TYPE_CHECKING:
  30. import datetime
  31. from sqlalchemy.orm import Query, Session
  32. from sqlalchemy.sql import Select
  33. from airflow.models.taskinstance import TaskInstance
  34. from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
  35. class TaskReschedule(TaskInstanceDependencies):
  36. """TaskReschedule tracks rescheduled task instances."""
  37. __tablename__ = "task_reschedule"
  38. id = Column(Integer, primary_key=True)
  39. task_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False)
  40. dag_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False)
  41. run_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False)
  42. map_index = Column(Integer, nullable=False, server_default=text("-1"))
  43. try_number = Column(Integer, nullable=False)
  44. start_date = Column(UtcDateTime, nullable=False)
  45. end_date = Column(UtcDateTime, nullable=False)
  46. duration = Column(Integer, nullable=False)
  47. reschedule_date = Column(UtcDateTime, nullable=False)
  48. __table_args__ = (
  49. Index("idx_task_reschedule_dag_task_run", dag_id, task_id, run_id, map_index, unique=False),
  50. ForeignKeyConstraint(
  51. [dag_id, task_id, run_id, map_index],
  52. [
  53. "task_instance.dag_id",
  54. "task_instance.task_id",
  55. "task_instance.run_id",
  56. "task_instance.map_index",
  57. ],
  58. name="task_reschedule_ti_fkey",
  59. ondelete="CASCADE",
  60. ),
  61. Index("idx_task_reschedule_dag_run", dag_id, run_id),
  62. ForeignKeyConstraint(
  63. [dag_id, run_id],
  64. ["dag_run.dag_id", "dag_run.run_id"],
  65. name="task_reschedule_dr_fkey",
  66. ondelete="CASCADE",
  67. ),
  68. )
  69. dag_run = relationship("DagRun")
  70. execution_date = association_proxy("dag_run", "execution_date")
  71. def __init__(
  72. self,
  73. task_id: str,
  74. dag_id: str,
  75. run_id: str,
  76. try_number: int,
  77. start_date: datetime.datetime,
  78. end_date: datetime.datetime,
  79. reschedule_date: datetime.datetime,
  80. map_index: int = -1,
  81. ) -> None:
  82. self.dag_id = dag_id
  83. self.task_id = task_id
  84. self.run_id = run_id
  85. self.map_index = map_index
  86. self.try_number = try_number
  87. self.start_date = start_date
  88. self.end_date = end_date
  89. self.reschedule_date = reschedule_date
  90. self.duration = (self.end_date - self.start_date).total_seconds()
  91. @classmethod
  92. def stmt_for_task_instance(
  93. cls,
  94. ti: TaskInstance | TaskInstancePydantic,
  95. *,
  96. try_number: int | None = None,
  97. descending: bool = False,
  98. ) -> Select:
  99. """
  100. Statement for task reschedules for a given the task instance.
  101. :param ti: the task instance to find task reschedules for
  102. :param descending: If True then records are returned in descending order
  103. :param try_number: Look for TaskReschedule of the given try_number. Default is None which
  104. looks for the same try_number of the given task_instance.
  105. :meta private:
  106. """
  107. if try_number is None:
  108. try_number = ti.try_number
  109. return (
  110. select(cls)
  111. .where(
  112. cls.dag_id == ti.dag_id,
  113. cls.task_id == ti.task_id,
  114. cls.run_id == ti.run_id,
  115. cls.map_index == ti.map_index,
  116. cls.try_number == try_number,
  117. )
  118. .order_by(desc(cls.id) if descending else asc(cls.id))
  119. )
  120. @staticmethod
  121. @provide_session
  122. def query_for_task_instance(
  123. task_instance: TaskInstance,
  124. descending: bool = False,
  125. session: Session = NEW_SESSION,
  126. try_number: int | None = None,
  127. ) -> Query:
  128. """
  129. Return query for task reschedules for a given the task instance (deprecated).
  130. :param session: the database session object
  131. :param task_instance: the task instance to find task reschedules for
  132. :param descending: If True then records are returned in descending order
  133. :param try_number: Look for TaskReschedule of the given try_number. Default is None which
  134. looks for the same try_number of the given task_instance.
  135. """
  136. warnings.warn(
  137. "Using this method is no longer advised, and it is expected to be removed in the future.",
  138. category=RemovedInAirflow3Warning,
  139. stacklevel=2,
  140. )
  141. if try_number is None:
  142. try_number = task_instance.try_number
  143. TR = TaskReschedule
  144. qry = session.query(TR).filter(
  145. TR.dag_id == task_instance.dag_id,
  146. TR.task_id == task_instance.task_id,
  147. TR.run_id == task_instance.run_id,
  148. TR.map_index == task_instance.map_index,
  149. TR.try_number == try_number,
  150. )
  151. if descending:
  152. return qry.order_by(desc(TR.id))
  153. else:
  154. return qry.order_by(asc(TR.id))
  155. @staticmethod
  156. @provide_session
  157. def find_for_task_instance(
  158. task_instance: TaskInstance,
  159. session: Session = NEW_SESSION,
  160. try_number: int | None = None,
  161. ) -> list[TaskReschedule]:
  162. """
  163. Return all task reschedules for the task instance and try number, in ascending order.
  164. :param session: the database session object
  165. :param task_instance: the task instance to find task reschedules for
  166. :param try_number: Look for TaskReschedule of the given try_number. Default is None which
  167. looks for the same try_number of the given task_instance.
  168. """
  169. warnings.warn(
  170. "Using this method is no longer advised, and it is expected to be removed in the future.",
  171. category=RemovedInAirflow3Warning,
  172. stacklevel=2,
  173. )
  174. return session.scalars(
  175. TaskReschedule.stmt_for_task_instance(ti=task_instance, try_number=try_number, descending=False)
  176. ).all()