renderedtifields.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
  18. """Save Rendered Template Fields."""
  19. from __future__ import annotations
  20. import os
  21. from typing import TYPE_CHECKING
  22. import sqlalchemy_jsonfield
  23. from sqlalchemy import (
  24. Column,
  25. ForeignKeyConstraint,
  26. Integer,
  27. PrimaryKeyConstraint,
  28. delete,
  29. exists,
  30. select,
  31. text,
  32. )
  33. from sqlalchemy.ext.associationproxy import association_proxy
  34. from sqlalchemy.orm import relationship
  35. from airflow.api_internal.internal_api_call import internal_api_call
  36. from airflow.configuration import conf
  37. from airflow.models.base import StringID, TaskInstanceDependencies
  38. from airflow.serialization.helpers import serialize_template_field
  39. from airflow.settings import json
  40. from airflow.utils.retries import retry_db_transaction
  41. from airflow.utils.session import NEW_SESSION, provide_session
  42. if TYPE_CHECKING:
  43. from sqlalchemy.orm import Session
  44. from sqlalchemy.sql import FromClause
  45. from airflow.models import Operator
  46. from airflow.models.taskinstance import TaskInstance, TaskInstancePydantic
  47. def get_serialized_template_fields(task: Operator):
  48. """
  49. Get and serialize the template fields for a task.
  50. Used in preparing to store them in RTIF table.
  51. :param task: Operator instance with rendered template fields
  52. :meta private:
  53. """
  54. return {field: serialize_template_field(getattr(task, field), field) for field in task.template_fields}
  55. class RenderedTaskInstanceFields(TaskInstanceDependencies):
  56. """Save Rendered Template Fields."""
  57. __tablename__ = "rendered_task_instance_fields"
  58. dag_id = Column(StringID(), primary_key=True)
  59. task_id = Column(StringID(), primary_key=True)
  60. run_id = Column(StringID(), primary_key=True)
  61. map_index = Column(Integer, primary_key=True, server_default=text("-1"))
  62. rendered_fields = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False)
  63. k8s_pod_yaml = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=True)
  64. __table_args__ = (
  65. PrimaryKeyConstraint(
  66. "dag_id",
  67. "task_id",
  68. "run_id",
  69. "map_index",
  70. name="rendered_task_instance_fields_pkey",
  71. ),
  72. ForeignKeyConstraint(
  73. [dag_id, task_id, run_id, map_index],
  74. [
  75. "task_instance.dag_id",
  76. "task_instance.task_id",
  77. "task_instance.run_id",
  78. "task_instance.map_index",
  79. ],
  80. name="rtif_ti_fkey",
  81. ondelete="CASCADE",
  82. ),
  83. )
  84. task_instance = relationship(
  85. "TaskInstance",
  86. lazy="joined",
  87. back_populates="rendered_task_instance_fields",
  88. )
  89. # We don't need a DB level FK here, as we already have that to TI (which has one to DR) but by defining
  90. # the relationship we can more easily find the execution date for these rows
  91. dag_run = relationship(
  92. "DagRun",
  93. primaryjoin="""and_(
  94. RenderedTaskInstanceFields.dag_id == foreign(DagRun.dag_id),
  95. RenderedTaskInstanceFields.run_id == foreign(DagRun.run_id),
  96. )""",
  97. viewonly=True,
  98. )
  99. execution_date = association_proxy("dag_run", "execution_date")
  100. def __init__(self, ti: TaskInstance, render_templates=True, rendered_fields=None):
  101. self.dag_id = ti.dag_id
  102. self.task_id = ti.task_id
  103. self.run_id = ti.run_id
  104. self.map_index = ti.map_index
  105. self.ti = ti
  106. if render_templates:
  107. ti.render_templates()
  108. if TYPE_CHECKING:
  109. assert ti.task
  110. self.task = ti.task
  111. if os.environ.get("AIRFLOW_IS_K8S_EXECUTOR_POD", None):
  112. # we can safely import it here from provider. In Airflow 2.7.0+ you need to have new version
  113. # of kubernetes provider installed to reach this place
  114. from airflow.providers.cncf.kubernetes.template_rendering import render_k8s_pod_yaml
  115. self.k8s_pod_yaml = render_k8s_pod_yaml(ti)
  116. self.rendered_fields = rendered_fields or get_serialized_template_fields(task=ti.task)
  117. self._redact()
  118. def __repr__(self):
  119. prefix = f"<{self.__class__.__name__}: {self.dag_id}.{self.task_id} {self.run_id}"
  120. if self.map_index != -1:
  121. prefix += f" map_index={self.map_index}"
  122. return prefix + ">"
  123. def _redact(self):
  124. from airflow.utils.log.secrets_masker import redact
  125. if self.k8s_pod_yaml:
  126. self.k8s_pod_yaml = redact(self.k8s_pod_yaml)
  127. for field, rendered in self.rendered_fields.items():
  128. self.rendered_fields[field] = redact(rendered, field)
  129. @classmethod
  130. @internal_api_call
  131. @provide_session
  132. def _update_runtime_evaluated_template_fields(
  133. cls, ti: TaskInstance, session: Session = NEW_SESSION
  134. ) -> None:
  135. """Update rendered task instance fields for cases where runtime evaluated, not templated."""
  136. # Note: Need lazy import to break the partly loaded class loop
  137. from airflow.models.taskinstance import TaskInstance
  138. # If called via remote API the DAG needs to be re-loaded
  139. TaskInstance.ensure_dag(ti, session=session)
  140. rtif = RenderedTaskInstanceFields(ti)
  141. RenderedTaskInstanceFields.write(rtif, session=session)
  142. RenderedTaskInstanceFields.delete_old_records(ti.task_id, ti.dag_id, session=session)
  143. @classmethod
  144. @provide_session
  145. def get_templated_fields(
  146. cls, ti: TaskInstance | TaskInstancePydantic, session: Session = NEW_SESSION
  147. ) -> dict | None:
  148. """
  149. Get templated field for a TaskInstance from the RenderedTaskInstanceFields table.
  150. :param ti: Task Instance
  151. :param session: SqlAlchemy Session
  152. :return: Rendered Templated TI field
  153. """
  154. result = session.scalar(
  155. select(cls).where(
  156. cls.dag_id == ti.dag_id,
  157. cls.task_id == ti.task_id,
  158. cls.run_id == ti.run_id,
  159. cls.map_index == ti.map_index,
  160. )
  161. )
  162. if result:
  163. rendered_fields = result.rendered_fields
  164. return rendered_fields
  165. else:
  166. return None
  167. @classmethod
  168. @provide_session
  169. def get_k8s_pod_yaml(cls, ti: TaskInstance, session: Session = NEW_SESSION) -> dict | None:
  170. """
  171. Get rendered Kubernetes Pod Yaml for a TaskInstance from the RenderedTaskInstanceFields table.
  172. :param ti: Task Instance
  173. :param session: SqlAlchemy Session
  174. :return: Kubernetes Pod Yaml
  175. """
  176. result = session.scalar(
  177. select(cls).where(
  178. cls.dag_id == ti.dag_id,
  179. cls.task_id == ti.task_id,
  180. cls.run_id == ti.run_id,
  181. cls.map_index == ti.map_index,
  182. )
  183. )
  184. return result.k8s_pod_yaml if result else None
  185. @provide_session
  186. @retry_db_transaction
  187. def write(self, session: Session = None):
  188. """
  189. Write instance to database.
  190. :param session: SqlAlchemy Session
  191. """
  192. session.merge(self)
  193. @classmethod
  194. @provide_session
  195. def delete_old_records(
  196. cls,
  197. task_id: str,
  198. dag_id: str,
  199. num_to_keep: int = conf.getint("core", "max_num_rendered_ti_fields_per_task", fallback=0),
  200. session: Session = NEW_SESSION,
  201. ) -> None:
  202. """
  203. Keep only Last X (num_to_keep) number of records for a task by deleting others.
  204. In the case of data for a mapped task either all of the rows or none of the rows will be deleted, so
  205. we don't end up with partial data for a set of mapped Task Instances left in the database.
  206. :param task_id: Task ID
  207. :param dag_id: Dag ID
  208. :param num_to_keep: Number of Records to keep
  209. :param session: SqlAlchemy Session
  210. """
  211. if num_to_keep <= 0:
  212. return
  213. from airflow.models.dagrun import DagRun
  214. tis_to_keep_query = (
  215. select(cls.dag_id, cls.task_id, cls.run_id, DagRun.execution_date)
  216. .where(cls.dag_id == dag_id, cls.task_id == task_id)
  217. .join(cls.dag_run)
  218. .distinct()
  219. .order_by(DagRun.execution_date.desc())
  220. .limit(num_to_keep)
  221. )
  222. cls._do_delete_old_records(
  223. dag_id=dag_id,
  224. task_id=task_id,
  225. ti_clause=tis_to_keep_query.subquery(),
  226. session=session,
  227. )
  228. session.flush()
  229. @classmethod
  230. @retry_db_transaction
  231. def _do_delete_old_records(
  232. cls,
  233. *,
  234. task_id: str,
  235. dag_id: str,
  236. ti_clause: FromClause,
  237. session: Session,
  238. ) -> None:
  239. # This query might deadlock occasionally and it should be retried if fails (see decorator)
  240. stmt = (
  241. delete(cls)
  242. .where(
  243. cls.dag_id == dag_id,
  244. cls.task_id == task_id,
  245. ~exists(1).where(
  246. ti_clause.c.dag_id == cls.dag_id,
  247. ti_clause.c.task_id == cls.task_id,
  248. ti_clause.c.run_id == cls.run_id,
  249. ),
  250. )
  251. .execution_options(synchronize_session=False)
  252. )
  253. session.execute(stmt)