#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import datetime
import inspect
from abc import abstractmethod
from functools import cached_property
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Collection, Iterable, Iterator, Sequence

import methodtools
from sqlalchemy import select

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.models.expandinput import NotFullyPopulated
from airflow.models.taskmixin import DAGNode, DependencyMixin
from airflow.template.templater import Templater
from airflow.utils.context import Context
from airflow.utils.db import exists_query
from airflow.utils.log.secrets_masker import redact
from airflow.utils.setup_teardown import SetupTeardownContext
from airflow.utils.sqlalchemy import with_row_locks
from airflow.utils.state import State, TaskInstanceState
from airflow.utils.task_group import MappedTaskGroup
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.types import NOTSET, ArgNotSet
from airflow.utils.weight_rule import WeightRule, db_safe_priority

TaskStateChangeCallback = Callable[[Context], None]

if TYPE_CHECKING:
    import jinja2  # Slow import.
    from sqlalchemy.orm import Session

    from airflow.models.baseoperator import BaseOperator
    from airflow.models.baseoperatorlink import BaseOperatorLink
    from airflow.models.dag import DAG
    from airflow.models.mappedoperator import MappedOperator
    from airflow.models.operator import Operator
    from airflow.models.taskinstance import TaskInstance
    from airflow.task.priority_strategy import PriorityWeightStrategy
    from airflow.triggers.base import StartTriggerArgs
    from airflow.utils.task_group import TaskGroup

DEFAULT_OWNER: str = conf.get_mandatory_value("operators", "default_owner")
DEFAULT_POOL_SLOTS: int = 1
DEFAULT_PRIORITY_WEIGHT: int = 1
DEFAULT_EXECUTOR: str | None = None
DEFAULT_QUEUE: str = conf.get_mandatory_value("operators", "default_queue")
DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST: bool = conf.getboolean(
    "scheduler", "ignore_first_depends_on_past_by_default"
)
DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING: bool = False
DEFAULT_RETRIES: int = conf.getint("core", "default_task_retries", fallback=0)
DEFAULT_RETRY_DELAY: datetime.timedelta = datetime.timedelta(
    seconds=conf.getint("core", "default_task_retry_delay", fallback=300)
)
MAX_RETRY_DELAY: int = conf.getint("core", "max_task_retry_delay", fallback=24 * 60 * 60)
DEFAULT_WEIGHT_RULE: WeightRule = WeightRule(
    conf.get("core", "default_task_weight_rule", fallback=WeightRule.DOWNSTREAM)
)
DEFAULT_TRIGGER_RULE: TriggerRule = TriggerRule.ALL_SUCCESS
DEFAULT_TASK_EXECUTION_TIMEOUT: datetime.timedelta | None = conf.gettimedelta(
    "core", "default_task_execution_timeout"
)
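
# For reference, these defaults map onto ``airflow.cfg`` entries along these
# lines (values shown are illustrative, not taken from this file):
#
#     [operators]
#     default_owner = airflow
#     default_queue = default
#
#     [core]
#     default_task_retries = 0
#     default_task_retry_delay = 300
#
# Where a ``fallback=`` argument is given, a missing key falls back to that
# value; ``conf.get_mandatory_value`` raises instead of falling back.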


class NotMapped(Exception):
    """Raised if a task is not mapped and has no parent mapped task groups."""


class AbstractOperator(Templater, DAGNode):
    """
    Common implementation for operators, including unmapped and mapped.

    This base class is more about sharing implementations, not defining a common
    interface. Unfortunately it's difficult to use this as the common base class
    for typing due to BaseOperator carrying too much historical baggage.

    The union type ``from airflow.models.operator import Operator`` is easier
    to use for typing purposes.

    :meta private:
    """

    operator_class: type[BaseOperator] | dict[str, Any]

    weight_rule: PriorityWeightStrategy
    priority_weight: int

    # Defines the operator-level extra links.
    operator_extra_links: Collection[BaseOperatorLink]

    owner: str
    task_id: str

    outlets: list
    inlets: list
    trigger_rule: TriggerRule
    _needs_expansion: bool | None = None
    _on_failure_fail_dagrun = False

    HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = frozenset(
        (
            "log",
            "dag",  # We show dag_id, don't need to show this too.
            "node_id",  # Duplicates task_id.
            "task_group",  # Doesn't have a useful repr, no point showing it in the UI.
            "inherits_from_empty_operator",  # Implementation detail.
            # Decide whether to start task execution from the triggerer.
            "start_trigger_args",
            "start_from_trigger",
            # For compatibility with TaskGroup; for operators these are just the
            # current task, so there is no point showing them.
            "roots",
            "leaves",
            # These lists are already shown via *_task_ids.
            "upstream_list",
            "downstream_list",
            # Not useful, implementation detail, already shown elsewhere.
            "global_operator_extra_link_dict",
            "operator_extra_link_dict",
        )
    )

    def get_dag(self) -> DAG | None:
        raise NotImplementedError()

    @property
    def task_type(self) -> str:
        raise NotImplementedError()

    @property
    def operator_name(self) -> str:
        raise NotImplementedError()

    @property
    def inherits_from_empty_operator(self) -> bool:
        raise NotImplementedError()

    @property
    def dag_id(self) -> str:
        """Return the DAG ID if the task has a DAG, or an ad-hoc ID derived from the owner."""
        dag = self.get_dag()
        if dag:
            return dag.dag_id
        return f"adhoc_{self.owner}"

    @property
    def node_id(self) -> str:
        return self.task_id

    @property
    @abstractmethod
    def task_display_name(self) -> str: ...

    @property
    def label(self) -> str | None:
        if self.task_display_name and self.task_display_name != self.task_id:
            return self.task_display_name
        # Prefix handling when no display name is given is cloned from taskmixin
        # for compatibility.
        tg = self.task_group
        if tg and tg.node_id and tg.prefix_group_id:
            # "task_group_id.task_id" -> "task_id"
            return self.task_id[len(tg.node_id) + 1 :]
        return self.task_id
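
    # For illustration (assumed names, not part of this module): a task created
    # as ``EmptyOperator(task_id="t")`` inside ``TaskGroup("g1")`` gets the
    # prefixed ``task_id`` "g1.t", so ``label`` strips ``len("g1") + 1``
    # characters and returns "t". With ``task_display_name="My Task"`` set,
    # ``label`` returns "My Task" instead.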

    @property
    def is_setup(self) -> bool:
        raise NotImplementedError()

    @is_setup.setter
    def is_setup(self, value: bool) -> None:
        raise NotImplementedError()

    @property
    def is_teardown(self) -> bool:
        raise NotImplementedError()

    @is_teardown.setter
    def is_teardown(self, value: bool) -> None:
        raise NotImplementedError()

    @property
    def on_failure_fail_dagrun(self):
        """
        Whether the operator should fail the dagrun on failure.

        :meta private:
        """
        return self._on_failure_fail_dagrun

    @on_failure_fail_dagrun.setter
    def on_failure_fail_dagrun(self, value):
        """
        Setter for the on_failure_fail_dagrun property.

        :meta private:
        """
        if value is True and self.is_teardown is not True:
            raise ValueError(
                f"Cannot set task on_failure_fail_dagrun for "
                f"'{self.task_id}' because it is not a teardown task."
            )
        self._on_failure_fail_dagrun = value

    def as_setup(self):
        self.is_setup = True
        return self

    def as_teardown(
        self,
        *,
        setups: BaseOperator | Iterable[BaseOperator] | ArgNotSet = NOTSET,
        on_failure_fail_dagrun=NOTSET,
    ):
        self.is_teardown = True
        self.trigger_rule = TriggerRule.ALL_DONE_SETUP_SUCCESS
        if on_failure_fail_dagrun is not NOTSET:
            self.on_failure_fail_dagrun = on_failure_fail_dagrun
        if not isinstance(setups, ArgNotSet):
            setups = [setups] if isinstance(setups, DependencyMixin) else setups
            for s in setups:
                s.is_setup = True
                s >> self
        return self
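
    # A minimal usage sketch (task names are assumptions for illustration):
    #
    #     create_cluster = BashOperator(task_id="create", bash_command="...")
    #     delete_cluster = BashOperator(task_id="delete", bash_command="...")
    #     work = BashOperator(task_id="work", bash_command="...")
    #
    #     create_cluster >> work >> delete_cluster.as_teardown(setups=create_cluster)
    #
    # ``as_teardown`` marks ``delete_cluster`` as a teardown, switches its
    # trigger rule to ALL_DONE_SETUP_SUCCESS, flags ``create_cluster`` as its
    # setup, and wires ``create_cluster >> delete_cluster``.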

    def get_direct_relative_ids(self, upstream: bool = False) -> set[str]:
        """Get direct relative IDs to the current task, upstream or downstream."""
        if upstream:
            return self.upstream_task_ids
        return self.downstream_task_ids

    def get_flat_relative_ids(self, *, upstream: bool = False) -> set[str]:
        """
        Get a flat set of relative IDs, upstream or downstream.

        Will recurse each relative found in the direction specified.

        :param upstream: Whether to look for upstream or downstream relatives.
        """
        dag = self.get_dag()
        if not dag:
            return set()

        relatives: set[str] = set()

        # This is intentionally implemented as a loop, instead of calling
        # get_direct_relative_ids() recursively, since Python has a significant
        # limit on stack depth, and a recursive implementation can blow up
        # if a DAG contains very long routes.
        task_ids_to_trace = self.get_direct_relative_ids(upstream)
        while task_ids_to_trace:
            task_ids_to_trace_next: set[str] = set()
            for task_id in task_ids_to_trace:
                if task_id in relatives:
                    continue
                task_ids_to_trace_next.update(dag.task_dict[task_id].get_direct_relative_ids(upstream))
                relatives.add(task_id)
            task_ids_to_trace = task_ids_to_trace_next

        return relatives
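
    # For illustration (hypothetical DAG): with dependencies ``a >> b >> c``
    # and ``a >> d``, ``a.get_flat_relative_ids()`` performs a breadth-first
    # walk: the first pass traces {"b", "d"}, the second {"c"}, and the result
    # is {"b", "c", "d"}. The ``relatives`` set doubles as a visited set, so
    # diamond-shaped graphs do not cause repeated work.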

    def get_flat_relatives(self, upstream: bool = False) -> Collection[Operator]:
        """Get a flat list of relatives, either upstream or downstream."""
        dag = self.get_dag()
        if not dag:
            return set()
        return [dag.task_dict[task_id] for task_id in self.get_flat_relative_ids(upstream=upstream)]

    def get_upstreams_follow_setups(self) -> Iterable[Operator]:
        """All upstreams and, for each upstream setup, its respective teardowns."""
        for task in self.get_flat_relatives(upstream=True):
            yield task
            if task.is_setup:
                for t in task.downstream_list:
                    if t.is_teardown and t != self:
                        yield t

    def get_upstreams_only_setups_and_teardowns(self) -> Iterable[Operator]:
        """
        Only *relevant* upstream setups and their teardowns.

        This method is meant to be used when we are clearing the task (non-upstream) and we need
        to add in the *relevant* setups and their teardowns.

        Relevant in this case means the setup has a teardown that is downstream of ``self``,
        or the setup has no teardowns at all.
        """
        downstream_teardown_ids = {
            x.task_id for x in self.get_flat_relatives(upstream=False) if x.is_teardown
        }
        for task in self.get_flat_relatives(upstream=True):
            if not task.is_setup:
                continue
            has_no_teardowns = not any(True for x in task.downstream_list if x.is_teardown)
            # If the setup has no teardowns, or has teardowns downstream of self.
            if has_no_teardowns or task.downstream_task_ids.intersection(downstream_teardown_ids):
                yield task
                for t in task.downstream_list:
                    if t.is_teardown and t != self:
                        yield t
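
    # For illustration (hypothetical tasks): given ``s1 >> work >> t1`` with
    # ``s1`` a setup and ``t1`` its teardown, clearing ``work`` should also
    # clear ``s1`` and ``t1``, because ``t1`` is downstream of ``work``. A
    # setup ``s2`` whose only teardown sits on an unrelated branch is *not*
    # relevant here and is skipped.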

    def get_upstreams_only_setups(self) -> Iterable[Operator]:
        """
        Return relevant upstream setups.

        This method is meant to be used when we are checking task dependencies where we need
        to wait for all the upstream setups to complete before we can run the task.
        """
        for task in self.get_upstreams_only_setups_and_teardowns():
            if task.is_setup:
                yield task

    def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator | MappedTaskGroup]:
        """
        Return mapped nodes that are direct dependencies of the current task.

        For now, this walks the entire DAG to find mapped nodes that have this
        current task as an upstream. We cannot use ``downstream_list`` since it
        only contains operators, not task groups. In the future, we should
        provide a way to record all of a DAG node's downstream nodes instead.

        Note that this does not guarantee the returned tasks actually use the
        current task for task mapping; it only checks that those tasks are
        mapped operators and are downstream of the current task.

        To get a list of tasks that use the current task for task mapping, use
        :meth:`iter_mapped_dependants` instead.
        """
        from airflow.models.mappedoperator import MappedOperator
        from airflow.utils.task_group import TaskGroup

        def _walk_group(group: TaskGroup) -> Iterable[tuple[str, DAGNode]]:
            """
            Recursively walk children in a task group.

            This yields all direct children (including both tasks and task
            groups), and all children of any task groups.
            """
            for key, child in group.children.items():
                yield key, child
                if isinstance(child, TaskGroup):
                    yield from _walk_group(child)

        dag = self.get_dag()
        if not dag:
            raise RuntimeError("Cannot check for mapped dependants when not attached to a DAG")
        for key, child in _walk_group(dag.task_group):
            if key == self.node_id:
                continue
            if not isinstance(child, (MappedOperator, MappedTaskGroup)):
                continue
            if self.node_id in child.upstream_task_ids:
                yield child

    def iter_mapped_dependants(self) -> Iterator[MappedOperator | MappedTaskGroup]:
        """
        Return mapped nodes that depend on the current task for expansion.

        For now, this walks the entire DAG to find mapped nodes that have this
        current task as an upstream. We cannot use ``downstream_list`` since it
        only contains operators, not task groups. In the future, we should
        provide a way to record all of a DAG node's downstream nodes instead.
        """
        return (
            downstream
            for downstream in self._iter_all_mapped_downstreams()
            if any(p.node_id == self.node_id for p in downstream.iter_mapped_dependencies())
        )

    def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]:
        """
        Return mapped task groups this task belongs to.

        Groups are returned from the innermost to the outermost.

        :meta private:
        """
        if (group := self.task_group) is None:
            return
        yield from group.iter_mapped_task_groups()

    def get_closest_mapped_task_group(self) -> MappedTaskGroup | None:
        """
        Get the mapped task group "closest" to this task in the DAG.

        :meta private:
        """
        return next(self.iter_mapped_task_groups(), None)
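
    # For illustration (hypothetical structure): with nested mapped task
    # groups ``outer`` and ``inner`` and a task ``t`` inside ``inner``,
    # ``t.iter_mapped_task_groups()`` yields ``inner`` then ``outer``, so
    # ``t.get_closest_mapped_task_group()`` returns ``inner``.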

    def get_needs_expansion(self) -> bool:
        """
        Return True if the task is a MappedOperator, or is in a mapped task group.

        :meta private:
        """
        if self._needs_expansion is None:
            if self.get_closest_mapped_task_group() is not None:
                self._needs_expansion = True
            else:
                self._needs_expansion = False
        return self._needs_expansion

    def unmap(self, resolve: None | dict[str, Any] | tuple[Context, Session]) -> BaseOperator:
        """
        Get the "normal" operator from the current abstract operator.

        MappedOperator uses this to unmap itself based on the map index. A
        non-mapped operator (i.e. a BaseOperator subclass) simply returns itself.

        :meta private:
        """
        raise NotImplementedError()

    def expand_start_from_trigger(self, *, context: Context, session: Session) -> bool:
        """
        Get the start_from_trigger value of the current abstract operator.

        MappedOperator uses this to unmap start_from_trigger to decide whether to start the task
        execution directly from the triggerer.

        :meta private:
        """
        raise NotImplementedError()

    def expand_start_trigger_args(self, *, context: Context, session: Session) -> StartTriggerArgs | None:
        """
        Get the start_trigger_args value of the current abstract operator.

        MappedOperator uses this to unmap start_trigger_args to decide how to start a task from the triggerer.

        :meta private:
        """
        raise NotImplementedError()

    @property
    def priority_weight_total(self) -> int:
        """
        Total priority weight for the task. It might include all upstream or downstream
        tasks, depending on the weight rule:

        - WeightRule.ABSOLUTE - only own weight
        - WeightRule.DOWNSTREAM - adds priority weight of all downstream tasks
        - WeightRule.UPSTREAM - adds priority weight of all upstream tasks
        """
        from airflow.task.priority_strategy import (
            _AbsolutePriorityWeightStrategy,
            _DownstreamPriorityWeightStrategy,
            _UpstreamPriorityWeightStrategy,
        )

        if isinstance(self.weight_rule, _AbsolutePriorityWeightStrategy):
            return db_safe_priority(self.priority_weight)
        elif isinstance(self.weight_rule, _DownstreamPriorityWeightStrategy):
            upstream = False
        elif isinstance(self.weight_rule, _UpstreamPriorityWeightStrategy):
            upstream = True
        else:
            upstream = False
        dag = self.get_dag()
        if dag is None:
            return db_safe_priority(self.priority_weight)
        return db_safe_priority(
            self.priority_weight
            + sum(
                dag.task_dict[task_id].priority_weight
                for task_id in self.get_flat_relative_ids(upstream=upstream)
            )
        )
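
    # A worked example (hypothetical DAG): tasks ``a >> b >> c``, each with
    # priority_weight=1. Under WeightRule.DOWNSTREAM, ``a`` sums its own weight
    # plus its downstream tasks {b, c}: 1 + 1 + 1 = 3, while ``c`` gets 1.
    # Under WeightRule.UPSTREAM the numbers reverse, and under
    # WeightRule.ABSOLUTE every task keeps its own weight of 1.
    # ``db_safe_priority`` then clamps the result so it fits the database
    # column.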

    @cached_property
    def operator_extra_link_dict(self) -> dict[str, Any]:
        """Returns dictionary of all extra links for the operator."""
        op_extra_links_from_plugin: dict[str, Any] = {}
        from airflow import plugins_manager

        plugins_manager.initialize_extra_operators_links_plugins()
        if plugins_manager.operator_extra_links is None:
            raise AirflowException("Can't load operators")
        for ope in plugins_manager.operator_extra_links:
            if ope.operators and self.operator_class in ope.operators:
                op_extra_links_from_plugin.update({ope.name: ope})

        operator_extra_links_all = {link.name: link for link in self.operator_extra_links}
        # Extra links defined in plugins override operator links defined in the operator.
        operator_extra_links_all.update(op_extra_links_from_plugin)

        return operator_extra_links_all

    @cached_property
    def global_operator_extra_link_dict(self) -> dict[str, Any]:
        """Returns dictionary of all global extra links."""
        from airflow import plugins_manager

        plugins_manager.initialize_extra_operators_links_plugins()
        if plugins_manager.global_operator_extra_links is None:
            raise AirflowException("Can't load operators")
        return {link.name: link for link in plugins_manager.global_operator_extra_links}

    @cached_property
    def extra_links(self) -> list[str]:
        return sorted(set(self.operator_extra_link_dict).union(self.global_operator_extra_link_dict))

    def get_extra_links(self, ti: TaskInstance, link_name: str) -> str | None:
        """
        For an operator, gets the URLs that the ``extra_links`` entry points to.

        :meta private:

        :raise ValueError: The error message of a ValueError will be passed on through to
            the frontend to show up as a tooltip on the disabled link.
        :param ti: The TaskInstance for the URL being searched for.
        :param link_name: The name of the link we're looking for the URL for. Should be
            one of the options specified in ``extra_links``.
        """
        link: BaseOperatorLink | None = self.operator_extra_link_dict.get(link_name)
        if not link:
            link = self.global_operator_extra_link_dict.get(link_name)
            if not link:
                return None
        parameters = inspect.signature(link.get_link).parameters
        old_signature = all(name != "ti_key" for name, p in parameters.items() if p.kind != p.VAR_KEYWORD)
        if old_signature:
            return link.get_link(self.unmap(None), ti.dag_run.logical_date)  # type: ignore[misc]
        return link.get_link(self.unmap(None), ti_key=ti.key)
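
    # Signature compatibility, for illustration: a legacy link implementing
    # ``def get_link(self, operator, dttm)`` has no ``ti_key`` parameter, so
    # the inspection above takes the old path and passes the run's logical
    # date positionally; a current link implementing
    # ``def get_link(self, operator, *, ti_key)`` is called with
    # ``ti_key=ti.key`` instead.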

    @methodtools.lru_cache(maxsize=None)
    def get_parse_time_mapped_ti_count(self) -> int:
        """
        Return the number of mapped task instances that can be created on DAG run creation.

        This only considers literal mapped arguments; *NotFullyPopulated* is
        raised when any non-literal values are used for mapping.

        :raise NotFullyPopulated: If non-literal mapped arguments are encountered.
        :raise NotMapped: If the operator is neither mapped, nor has any parent
            mapped task groups.
        :return: Total number of mapped TIs this task should have.
        """
        group = self.get_closest_mapped_task_group()
        if group is None:
            raise NotMapped
        return group.get_parse_time_mapped_ti_count()
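
    # For illustration: ``my_task.expand(x=[1, 2, 3])`` maps over a literal
    # list, so the count (3) is known at parse time. If ``x`` comes from an
    # upstream XCom, e.g. ``my_task.expand(x=other.output)``, the length is
    # only known at run time and NotFullyPopulated is raised here; use
    # get_mapped_ti_count() below once the upstream tasks have finished.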

    def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int:
        """
        Return the number of mapped TaskInstances that can be created at run time.

        This considers both literal and non-literal mapped arguments, so the
        result is only available once all the tasks this depends on have
        finished. The return value should be identical to
        ``parse_time_mapped_ti_count`` if all mapped arguments are literal.

        :raise NotFullyPopulated: If upstream tasks are not all complete yet.
        :raise NotMapped: If the operator is neither mapped, nor has any parent
            mapped task groups.
        :return: Total number of mapped TIs this task should have.
        """
        group = self.get_closest_mapped_task_group()
        if group is None:
            raise NotMapped
        return group.get_mapped_ti_count(run_id, session=session)

    def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence[TaskInstance], int]:
        """
        Create the mapped task instances for a mapped task.

        :raise NotMapped: If this task does not need expansion.
        :return: The newly created mapped task instances (if any) in ascending
            order by map index, and the maximum map index value.
        """
        from sqlalchemy import func, or_

        from airflow.models.baseoperator import BaseOperator
        from airflow.models.mappedoperator import MappedOperator
        from airflow.models.taskinstance import TaskInstance
        from airflow.settings import task_instance_mutation_hook

        if not isinstance(self, (BaseOperator, MappedOperator)):
            raise RuntimeError(f"cannot expand unrecognized operator type {type(self).__name__}")

        try:
            total_length: int | None = self.get_mapped_ti_count(run_id, session=session)
        except NotFullyPopulated as e:
            # It's possible that the upstream tasks are not yet done, but we
            # don't have upstream of upstreams in partial DAGs (possible in the
            # mini-scheduler), so we ignore this exception.
            if not self.dag or not self.dag.partial:
                self.log.error(
                    "Cannot expand %r for run %s; missing upstream values: %s",
                    self,
                    run_id,
                    sorted(e.missing),
                )
            total_length = None

        state: TaskInstanceState | None = None
        unmapped_ti: TaskInstance | None = session.scalars(
            select(TaskInstance).where(
                TaskInstance.dag_id == self.dag_id,
                TaskInstance.task_id == self.task_id,
                TaskInstance.run_id == run_id,
                TaskInstance.map_index == -1,
                or_(TaskInstance.state.in_(State.unfinished), TaskInstance.state.is_(None)),
            )
        ).one_or_none()

        all_expanded_tis: list[TaskInstance] = []

        if unmapped_ti:
            # The unmapped task instance still exists and is unfinished, i.e. we
            # haven't tried to run it before.
            if total_length is None:
                # If the DAG is partial, it's likely that the upstream tasks
                # are not done yet, so the task can't fail yet.
                if not self.dag or not self.dag.partial:
                    unmapped_ti.state = TaskInstanceState.UPSTREAM_FAILED
            elif total_length < 1:
                # If the upstream maps this to a zero-length value, simply mark
                # the unmapped task instance as SKIPPED (if needed).
                self.log.info(
                    "Marking %s as SKIPPED since the map has %d values to expand",
                    unmapped_ti,
                    total_length,
                )
                unmapped_ti.state = TaskInstanceState.SKIPPED
            else:
                zero_index_ti_exists = exists_query(
                    TaskInstance.dag_id == self.dag_id,
                    TaskInstance.task_id == self.task_id,
                    TaskInstance.run_id == run_id,
                    TaskInstance.map_index == 0,
                    session=session,
                )
                if not zero_index_ti_exists:
                    # Otherwise convert this into the first mapped index, and create
                    # TaskInstance for other indexes.
                    unmapped_ti.map_index = 0
                    self.log.debug("Updated in place to become %s", unmapped_ti)
                    all_expanded_tis.append(unmapped_ti)
                    # Execute the hook for the task instance with map index 0.
                    task_instance_mutation_hook(unmapped_ti)
                    session.flush()
                else:
                    self.log.debug("Deleting the original task instance: %s", unmapped_ti)
                    session.delete(unmapped_ti)
                state = unmapped_ti.state

        if total_length is None or total_length < 1:
            # Nothing to fix up.
            indexes_to_map: Iterable[int] = ()
        else:
            # Only create "missing" ones.
            current_max_mapping = session.scalar(
                select(func.max(TaskInstance.map_index)).where(
                    TaskInstance.dag_id == self.dag_id,
                    TaskInstance.task_id == self.task_id,
                    TaskInstance.run_id == run_id,
                )
            )
            indexes_to_map = range(current_max_mapping + 1, total_length)

        for index in indexes_to_map:
            # TODO: Make more efficient with bulk_insert_mappings/bulk_save_mappings.
            ti = TaskInstance(self, run_id=run_id, map_index=index, state=state)
            self.log.debug("Expanding TIs upserted %s", ti)
            task_instance_mutation_hook(ti)
            ti = session.merge(ti)
            ti.refresh_from_task(self)  # session.merge() loses task information.
            all_expanded_tis.append(ti)

        # Coerce the None case to 0 -- these two are almost treated identically,
        # except the unmapped ti (if it exists) is marked with different states.
        total_expanded_ti_count = total_length or 0

        # Any (old) task instances with inapplicable indexes (>= the total
        # number we need) are set to "REMOVED".
        query = select(TaskInstance).where(
            TaskInstance.dag_id == self.dag_id,
            TaskInstance.task_id == self.task_id,
            TaskInstance.run_id == run_id,
            TaskInstance.map_index >= total_expanded_ti_count,
        )
        query = with_row_locks(query, of=TaskInstance, session=session, skip_locked=True)
        to_update = session.scalars(query)
        for ti in to_update:
            ti.state = TaskInstanceState.REMOVED
        session.flush()
        return all_expanded_tis, total_expanded_ti_count - 1
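
    # Outcome summary, for illustration: with a literal ``expand(x=[1, 2])``
    # the unmapped TI (map_index -1) becomes map_index 0 and one new TI is
    # created for index 1; with a zero-length map the unmapped TI is SKIPPED;
    # with missing upstream values in a non-partial DAG it is marked
    # UPSTREAM_FAILED. Stale TIs whose index is >= the new total are REMOVED.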

    def render_template_fields(
        self,
        context: Context,
        jinja_env: jinja2.Environment | None = None,
    ) -> None:
        """
        Template all attributes listed in *self.template_fields*.

        If the operator is mapped, this should return the unmapped, fully
        rendered, and map-expanded operator. The mapped operator should not be
        modified. However, *context* may be modified in-place to reference the
        unmapped operator for template rendering.

        If the operator is not mapped, this should modify the operator in-place.
        """
        raise NotImplementedError()
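
    # For illustration (assumed operator): ``BashOperator`` lists
    # ``bash_command`` in its ``template_fields``, so a task defined with
    # ``bash_command="echo {{ ds }}"`` has that attribute replaced by the
    # rendered string (e.g. "echo 2024-01-01") when this method runs.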

    def _render(self, template, context, dag: DAG | None = None):
        if dag is None:
            dag = self.get_dag()
        return super()._render(template, context, dag=dag)

    def get_template_env(self, dag: DAG | None = None) -> jinja2.Environment:
        """Get the template environment for rendering templates."""
        if dag is None:
            dag = self.get_dag()
        return super().get_template_env(dag=dag)

    def _do_render_template_fields(
        self,
        parent: Any,
        template_fields: Iterable[str],
        context: Context,
        jinja_env: jinja2.Environment,
        seen_oids: set[int],
    ) -> None:
        """Override the base to use custom error logging."""
        for attr_name in template_fields:
            try:
                value = getattr(parent, attr_name)
            except AttributeError:
                raise AttributeError(
                    f"{attr_name!r} is configured as a template field "
                    f"but {parent.task_type} does not have this attribute."
                )
            try:
                if not value:
                    continue
            except Exception:
                # This may happen if the templated field points to a class which does not support `__bool__`,
                # such as Pandas DataFrames:
                # https://github.com/pandas-dev/pandas/blob/9135c3aaf12d26f857fcc787a5b64d521c51e379/pandas/core/generic.py#L1465
                self.log.info(
                    "Unable to check if the value of type '%s' is False for task '%s', field '%s'.",
                    type(value).__name__,
                    self.task_id,
                    attr_name,
                )
                # We may still want to render custom classes which do not support __bool__.
                pass

            try:
                if callable(value):
                    rendered_content = value(context=context, jinja_env=jinja_env)
                else:
                    rendered_content = self.render_template(
                        value,
                        context,
                        jinja_env,
                        seen_oids,
                    )
            except Exception:
                value_masked = redact(name=attr_name, value=value)
                self.log.exception(
                    "Exception rendering Jinja template for task '%s', field '%s'. Template: %r",
                    self.task_id,
                    attr_name,
                    value_masked,
                )
                raise
            else:
                setattr(parent, attr_name, rendered_content)

    def __enter__(self):
        if not self.is_setup and not self.is_teardown:
            raise AirflowException("Only setup/teardown tasks can be used as context managers.")
        SetupTeardownContext.push_setup_teardown_task(self)
        return SetupTeardownContext

    def __exit__(self, exc_type, exc_val, exc_tb):
        SetupTeardownContext.set_work_task_roots_and_leaves()