  1. #
  2. # Licensed to the Apache Software Foundation (ASF) under one
  3. # or more contributor license agreements. See the NOTICE file
  4. # distributed with this work for additional information
  5. # regarding copyright ownership. The ASF licenses this file
  6. # to you under the Apache License, Version 2.0 (the
  7. # "License"); you may not use this file except in compliance
  8. # with the License. You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing,
  13. # software distributed under the License is distributed on an
  14. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  15. # KIND, either express or implied. See the License for the
  16. # specific language governing permissions and limitations
  17. # under the License.
from __future__ import annotations

from typing import TYPE_CHECKING

import dill
from sqlalchemy import (
    Column,
    DateTime,
    Float,
    ForeignKeyConstraint,
    Integer,
    String,
    UniqueConstraint,
    func,
    select,
    text,
)
from sqlalchemy.ext.mutable import MutableDict

from airflow.models.base import Base, StringID
from airflow.utils import timezone
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import (
    ExecutorConfigType,
    ExtendedJSON,
    UtcDateTime,
)
from airflow.utils.state import State, TaskInstanceState

if TYPE_CHECKING:
    from sqlalchemy.orm import Session

    from airflow.models.taskinstance import TaskInstance
    from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
  46. class TaskInstanceHistory(Base):
  47. """
  48. Store old tries of TaskInstances.
  49. :meta private:
  50. """
  51. __tablename__ = "task_instance_history"
  52. id = Column(Integer(), primary_key=True, autoincrement=True)
  53. task_id = Column(StringID(), nullable=False)
  54. dag_id = Column(StringID(), nullable=False)
  55. run_id = Column(StringID(), nullable=False)
  56. map_index = Column(Integer, nullable=False, server_default=text("-1"))
  57. try_number = Column(Integer, nullable=False)
  58. start_date = Column(UtcDateTime)
  59. end_date = Column(UtcDateTime)
  60. duration = Column(Float)
  61. state = Column(String(20))
  62. max_tries = Column(Integer, server_default=text("-1"))
  63. hostname = Column(String(1000))
  64. unixname = Column(String(1000))
  65. job_id = Column(Integer)
  66. pool = Column(String(256), nullable=False)
  67. pool_slots = Column(Integer, default=1, nullable=False)
  68. queue = Column(String(256))
  69. priority_weight = Column(Integer)
  70. operator = Column(String(1000))
  71. custom_operator_name = Column(String(1000))
  72. queued_dttm = Column(UtcDateTime)
  73. queued_by_job_id = Column(Integer)
  74. pid = Column(Integer)
  75. executor = Column(String(1000))
  76. executor_config = Column(ExecutorConfigType(pickler=dill))
  77. updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow)
  78. rendered_map_index = Column(String(250))
  79. external_executor_id = Column(StringID())
  80. trigger_id = Column(Integer)
  81. trigger_timeout = Column(DateTime)
  82. next_method = Column(String(1000))
  83. next_kwargs = Column(MutableDict.as_mutable(ExtendedJSON))
  84. task_display_name = Column("task_display_name", String(2000), nullable=True)
  85. def __init__(
  86. self,
  87. ti: TaskInstance | TaskInstancePydantic,
  88. state: str | None = None,
  89. ):
  90. super().__init__()
  91. for column in self.__table__.columns:
  92. if column.name == "id":
  93. continue
  94. setattr(self, column.name, getattr(ti, column.name))
  95. if state:
  96. self.state = state
  97. __table_args__ = (
  98. ForeignKeyConstraint(
  99. [dag_id, task_id, run_id, map_index],
  100. [
  101. "task_instance.dag_id",
  102. "task_instance.task_id",
  103. "task_instance.run_id",
  104. "task_instance.map_index",
  105. ],
  106. name="task_instance_history_ti_fkey",
  107. ondelete="CASCADE",
  108. onupdate="CASCADE",
  109. ),
  110. UniqueConstraint(
  111. "dag_id",
  112. "task_id",
  113. "run_id",
  114. "map_index",
  115. "try_number",
  116. name="task_instance_history_dtrt_uq",
  117. ),
  118. )
  119. @staticmethod
  120. @provide_session
  121. def record_ti(ti: TaskInstance, session: NEW_SESSION = None) -> None:
  122. """Record a TaskInstance to TaskInstanceHistory."""
  123. exists_q = session.scalar(
  124. select(func.count(TaskInstanceHistory.task_id)).where(
  125. TaskInstanceHistory.dag_id == ti.dag_id,
  126. TaskInstanceHistory.task_id == ti.task_id,
  127. TaskInstanceHistory.run_id == ti.run_id,
  128. TaskInstanceHistory.map_index == ti.map_index,
  129. TaskInstanceHistory.try_number == ti.try_number,
  130. )
  131. )
  132. if exists_q:
  133. return
  134. ti_history_state = ti.state
  135. if ti.state not in State.finished:
  136. ti_history_state = TaskInstanceState.FAILED
  137. ti.end_date = timezone.utcnow()
  138. ti.set_duration()
  139. ti_history = TaskInstanceHistory(ti, state=ti_history_state)
  140. session.add(ti_history)