taskinstance.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556
  1. # Licensed to the Apache Software Foundation (ASF) under one
  2. # or more contributor license agreements. See the NOTICE file
  3. # distributed with this work for additional information
  4. # regarding copyright ownership. The ASF licenses this file
  5. # to you under the Apache License, Version 2.0 (the
  6. # "License"); you may not use this file except in compliance
  7. # with the License. You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing,
  12. # software distributed under the License is distributed on an
  13. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  14. # KIND, either express or implied. See the License for the
  15. # specific language governing permissions and limitations
  16. # under the License.
  17. from __future__ import annotations
  18. from datetime import datetime
  19. from typing import TYPE_CHECKING, Any, Iterable, Optional
  20. from typing_extensions import Annotated
  21. from airflow.exceptions import AirflowRescheduleException, TaskDeferred
  22. from airflow.models import Operator
  23. from airflow.models.baseoperator import BaseOperator
  24. from airflow.models.taskinstance import (
  25. TaskInstance,
  26. TaskReturnCode,
  27. _defer_task,
  28. _handle_reschedule,
  29. _run_raw_task,
  30. _set_ti_attrs,
  31. )
  32. from airflow.serialization.pydantic.dag import DagModelPydantic
  33. from airflow.serialization.pydantic.dag_run import DagRunPydantic
  34. from airflow.utils.log.logging_mixin import LoggingMixin
  35. from airflow.utils.net import get_hostname
  36. from airflow.utils.pydantic import (
  37. BaseModel as BaseModelPydantic,
  38. ConfigDict,
  39. PlainSerializer,
  40. PlainValidator,
  41. is_pydantic_2_installed,
  42. )
  43. from airflow.utils.xcom import XCOM_RETURN_KEY
  44. if TYPE_CHECKING:
  45. import pendulum
  46. from sqlalchemy.orm import Session
  47. from airflow.models.dagrun import DagRun
  48. from airflow.utils.context import Context
  49. from airflow.utils.pydantic import ValidationInfo
  50. from airflow.utils.state import DagRunState
  51. def serialize_operator(x: Operator | None) -> dict | None:
  52. if x:
  53. from airflow.serialization.serialized_objects import BaseSerialization
  54. return BaseSerialization.serialize(x, use_pydantic_models=True)
  55. return None
  56. def validated_operator(x: dict[str, Any] | Operator, _info: ValidationInfo) -> Any:
  57. from airflow.models.baseoperator import BaseOperator
  58. from airflow.models.mappedoperator import MappedOperator
  59. if isinstance(x, BaseOperator) or isinstance(x, MappedOperator) or x is None:
  60. return x
  61. from airflow.serialization.serialized_objects import BaseSerialization
  62. return BaseSerialization.deserialize(x, use_pydantic_models=True)
  63. PydanticOperator = Annotated[
  64. Operator,
  65. PlainValidator(validated_operator),
  66. PlainSerializer(serialize_operator, return_type=dict),
  67. ]
  68. class TaskInstancePydantic(BaseModelPydantic, LoggingMixin):
  69. """Serializable representation of the TaskInstance ORM SqlAlchemyModel used by internal API."""
  70. task_id: str
  71. dag_id: str
  72. run_id: str
  73. map_index: int
  74. start_date: Optional[datetime]
  75. end_date: Optional[datetime]
  76. execution_date: Optional[datetime]
  77. duration: Optional[float]
  78. state: Optional[str]
  79. try_number: int
  80. max_tries: int
  81. hostname: str
  82. unixname: str
  83. job_id: Optional[int]
  84. pool: str
  85. pool_slots: int
  86. queue: str
  87. priority_weight: Optional[int]
  88. operator: str
  89. custom_operator_name: Optional[str]
  90. queued_dttm: Optional[datetime]
  91. queued_by_job_id: Optional[int]
  92. pid: Optional[int]
  93. executor: Optional[str]
  94. executor_config: Any
  95. updated_at: Optional[datetime]
  96. rendered_map_index: Optional[str]
  97. external_executor_id: Optional[str]
  98. trigger_id: Optional[int]
  99. trigger_timeout: Optional[datetime]
  100. next_method: Optional[str]
  101. next_kwargs: Optional[dict]
  102. run_as_user: Optional[str]
  103. task: Optional[PydanticOperator]
  104. test_mode: bool
  105. dag_run: Optional[DagRunPydantic]
  106. dag_model: Optional[DagModelPydantic]
  107. raw: Optional[bool]
  108. is_trigger_log_context: Optional[bool]
  109. model_config = ConfigDict(from_attributes=True, arbitrary_types_allowed=True)
  110. @property
  111. def _logger_name(self):
  112. return "airflow.task"
  113. def clear_xcom_data(self, session: Session | None = None):
  114. TaskInstance._clear_xcom_data(ti=self, session=session)
  115. def set_state(self, state, session: Session | None = None) -> bool:
  116. return TaskInstance._set_state(ti=self, state=state, session=session)
  117. def _run_raw_task(
  118. self,
  119. mark_success: bool = False,
  120. test_mode: bool = False,
  121. job_id: str | None = None,
  122. pool: str | None = None,
  123. raise_on_defer: bool = False,
  124. session: Session | None = None,
  125. ) -> TaskReturnCode | None:
  126. return _run_raw_task(
  127. ti=self,
  128. mark_success=mark_success,
  129. test_mode=test_mode,
  130. job_id=job_id,
  131. pool=pool,
  132. raise_on_defer=raise_on_defer,
  133. session=session,
  134. )
  135. def _run_execute_callback(self, context, task):
  136. TaskInstance._run_execute_callback(self=self, context=context, task=task) # type: ignore[arg-type]
  137. def render_templates(self, context: Context | None = None, jinja_env=None):
  138. return TaskInstance.render_templates(self=self, context=context, jinja_env=jinja_env) # type: ignore[arg-type]
  139. def init_run_context(self, raw: bool = False) -> None:
  140. """Set the log context."""
  141. self.raw = raw
  142. self._set_context(self)
  143. def xcom_pull(
  144. self,
  145. task_ids: str | Iterable[str] | None = None,
  146. dag_id: str | None = None,
  147. key: str = XCOM_RETURN_KEY,
  148. include_prior_dates: bool = False,
  149. session: Session | None = None,
  150. *,
  151. map_indexes: int | Iterable[int] | None = None,
  152. default: Any = None,
  153. ) -> Any:
  154. """
  155. Pull an XCom value for this task instance.
  156. :param task_ids: task id or list of task ids, if None, the task_id of the current task is used
  157. :param dag_id: dag id, if None, the dag_id of the current task is used
  158. :param key: the key to identify the XCom value
  159. :param include_prior_dates: whether to include prior execution dates
  160. :param session: the sqlalchemy session
  161. :param map_indexes: map index or list of map indexes, if None, the map_index of the current task
  162. is used
  163. :param default: the default value to return if the XCom value does not exist
  164. :return: Xcom value
  165. """
  166. return TaskInstance.xcom_pull(
  167. self=self, # type: ignore[arg-type]
  168. task_ids=task_ids,
  169. dag_id=dag_id,
  170. key=key,
  171. include_prior_dates=include_prior_dates,
  172. map_indexes=map_indexes,
  173. default=default,
  174. session=session,
  175. )
  176. def xcom_push(
  177. self,
  178. key: str,
  179. value: Any,
  180. execution_date: datetime | None = None,
  181. session: Session | None = None,
  182. ) -> None:
  183. """
  184. Push an XCom value for this task instance.
  185. :param key: the key to identify the XCom value
  186. :param value: the value of the XCom
  187. :param execution_date: the execution date to push the XCom for
  188. """
  189. return TaskInstance.xcom_push(
  190. self=self, # type: ignore[arg-type]
  191. key=key,
  192. value=value,
  193. execution_date=execution_date,
  194. session=session,
  195. )
  196. def get_dagrun(self, session: Session | None = None) -> DagRunPydantic:
  197. """
  198. Return the DagRun for this TaskInstance.
  199. :param session: SQLAlchemy ORM Session
  200. :return: Pydantic serialized version of DagRun
  201. """
  202. return TaskInstance._get_dagrun(dag_id=self.dag_id, run_id=self.run_id, session=session)
  203. def _execute_task(self, context, task_orig):
  204. """
  205. Execute Task (optionally with a Timeout) and push Xcom results.
  206. :param context: Jinja2 context
  207. :param task_orig: origin task
  208. """
  209. from airflow.models.taskinstance import _execute_task
  210. return _execute_task(task_instance=self, context=context, task_orig=task_orig)
  211. def refresh_from_db(self, session: Session | None = None, lock_for_update: bool = False) -> None:
  212. """
  213. Refresh the task instance from the database based on the primary key.
  214. :param session: SQLAlchemy ORM Session
  215. :param lock_for_update: if True, indicates that the database should
  216. lock the TaskInstance (issuing a FOR UPDATE clause) until the
  217. session is committed.
  218. """
  219. from airflow.models.taskinstance import _refresh_from_db
  220. _refresh_from_db(task_instance=self, session=session, lock_for_update=lock_for_update)
  221. def set_duration(self) -> None:
  222. """Set task instance duration."""
  223. from airflow.models.taskinstance import _set_duration
  224. _set_duration(task_instance=self)
  225. @property
  226. def stats_tags(self) -> dict[str, str]:
  227. """Return task instance tags."""
  228. from airflow.models.taskinstance import _stats_tags
  229. return _stats_tags(task_instance=self)
  230. def clear_next_method_args(self) -> None:
  231. """Ensure we unset next_method and next_kwargs to ensure that any retries don't reuse them."""
  232. from airflow.models.taskinstance import _clear_next_method_args
  233. _clear_next_method_args(task_instance=self)
  234. def get_template_context(
  235. self,
  236. session: Session | None = None,
  237. ignore_param_exceptions: bool = True,
  238. ) -> Context:
  239. """
  240. Return TI Context.
  241. :param session: SQLAlchemy ORM Session
  242. :param ignore_param_exceptions: flag to suppress value exceptions while initializing the ParamsDict
  243. """
  244. from airflow.models.taskinstance import _get_template_context
  245. if TYPE_CHECKING:
  246. assert self.task
  247. assert self.task.dag
  248. return _get_template_context(
  249. task_instance=self,
  250. dag=self.task.dag,
  251. session=session,
  252. ignore_param_exceptions=ignore_param_exceptions,
  253. )
  254. def is_eligible_to_retry(self):
  255. """Is task instance is eligible for retry."""
  256. from airflow.models.taskinstance import _is_eligible_to_retry
  257. return _is_eligible_to_retry(task_instance=self)
  258. def handle_failure(
  259. self,
  260. error: None | str | BaseException,
  261. test_mode: bool | None = None,
  262. context: Context | None = None,
  263. force_fail: bool = False,
  264. session: Session | None = None,
  265. ) -> None:
  266. """
  267. Handle Failure for a task instance.
  268. :param error: if specified, log the specific exception if thrown
  269. :param session: SQLAlchemy ORM Session
  270. :param test_mode: doesn't record success or failure in the DB if True
  271. :param context: Jinja2 context
  272. :param force_fail: if True, task does not retry
  273. """
  274. from airflow.models.taskinstance import _handle_failure
  275. if TYPE_CHECKING:
  276. assert self.task
  277. assert self.task.dag
  278. try:
  279. fail_stop = self.task.dag.fail_stop
  280. except Exception:
  281. fail_stop = False
  282. _handle_failure(
  283. task_instance=self,
  284. error=error,
  285. session=session,
  286. test_mode=test_mode,
  287. context=context,
  288. force_fail=force_fail,
  289. fail_stop=fail_stop,
  290. )
  291. def refresh_from_task(self, task: Operator, pool_override: str | None = None) -> None:
  292. """
  293. Copy common attributes from the given task.
  294. :param task: The task object to copy from
  295. :param pool_override: Use the pool_override instead of task's pool
  296. """
  297. from airflow.models.taskinstance import _refresh_from_task
  298. _refresh_from_task(task_instance=self, task=task, pool_override=pool_override)
  299. def get_previous_dagrun(
  300. self,
  301. state: DagRunState | None = None,
  302. session: Session | None = None,
  303. ) -> DagRun | None:
  304. """
  305. Return the DagRun that ran before this task instance's DagRun.
  306. :param state: If passed, it only take into account instances of a specific state.
  307. :param session: SQLAlchemy ORM Session.
  308. """
  309. from airflow.models.taskinstance import _get_previous_dagrun
  310. return _get_previous_dagrun(task_instance=self, state=state, session=session)
  311. def get_previous_execution_date(
  312. self,
  313. state: DagRunState | None = None,
  314. session: Session | None = None,
  315. ) -> pendulum.DateTime | None:
  316. """
  317. Return the execution date from property previous_ti_success.
  318. :param state: If passed, it only take into account instances of a specific state.
  319. :param session: SQLAlchemy ORM Session
  320. """
  321. from airflow.models.taskinstance import _get_previous_execution_date
  322. return _get_previous_execution_date(task_instance=self, state=state, session=session)
  323. def get_previous_start_date(
  324. self,
  325. state: DagRunState | None = None,
  326. session: Session | None = None,
  327. ) -> pendulum.DateTime | None:
  328. """
  329. Return the execution date from property previous_ti_success.
  330. :param state: If passed, it only take into account instances of a specific state.
  331. :param session: SQLAlchemy ORM Session
  332. """
  333. from airflow.models.taskinstance import _get_previous_start_date
  334. return _get_previous_start_date(task_instance=self, state=state, session=session)
  335. def email_alert(self, exception, task: BaseOperator) -> None:
  336. """
  337. Send alert email with exception information.
  338. :param exception: the exception
  339. :param task: task related to the exception
  340. """
  341. from airflow.models.taskinstance import _email_alert
  342. _email_alert(task_instance=self, exception=exception, task=task)
  343. def get_email_subject_content(
  344. self, exception: BaseException, task: BaseOperator | None = None
  345. ) -> tuple[str, str, str]:
  346. """
  347. Get the email subject content for exceptions.
  348. :param exception: the exception sent in the email
  349. :param task:
  350. """
  351. from airflow.models.taskinstance import _get_email_subject_content
  352. return _get_email_subject_content(task_instance=self, exception=exception, task=task)
  353. def get_previous_ti(
  354. self,
  355. state: DagRunState | None = None,
  356. session: Session | None = None,
  357. ) -> TaskInstance | TaskInstancePydantic | None:
  358. """
  359. Return the task instance for the task that ran before this task instance.
  360. :param session: SQLAlchemy ORM Session
  361. :param state: If passed, it only take into account instances of a specific state.
  362. """
  363. from airflow.models.taskinstance import _get_previous_ti
  364. return _get_previous_ti(task_instance=self, state=state, session=session)
  365. def check_and_change_state_before_execution(
  366. self,
  367. verbose: bool = True,
  368. ignore_all_deps: bool = False,
  369. ignore_depends_on_past: bool = False,
  370. wait_for_past_depends_before_skipping: bool = False,
  371. ignore_task_deps: bool = False,
  372. ignore_ti_state: bool = False,
  373. mark_success: bool = False,
  374. test_mode: bool = False,
  375. job_id: str | None = None,
  376. pool: str | None = None,
  377. external_executor_id: str | None = None,
  378. session: Session | None = None,
  379. ) -> bool:
  380. return TaskInstance._check_and_change_state_before_execution(
  381. task_instance=self,
  382. verbose=verbose,
  383. ignore_all_deps=ignore_all_deps,
  384. ignore_depends_on_past=ignore_depends_on_past,
  385. wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
  386. ignore_task_deps=ignore_task_deps,
  387. ignore_ti_state=ignore_ti_state,
  388. mark_success=mark_success,
  389. test_mode=test_mode,
  390. hostname=get_hostname(),
  391. job_id=job_id,
  392. pool=pool,
  393. external_executor_id=external_executor_id,
  394. session=session,
  395. )
  396. def schedule_downstream_tasks(self, session: Session | None = None, max_tis_per_query: int | None = None):
  397. """
  398. Schedule downstream tasks of this task instance.
  399. :meta: private
  400. """
  401. # we should not schedule downstream tasks with Pydantic model because it will not be able to
  402. # get the DAG object (we do not serialize it currently).
  403. return
  404. def command_as_list(
  405. self,
  406. mark_success: bool = False,
  407. ignore_all_deps: bool = False,
  408. ignore_task_deps: bool = False,
  409. ignore_depends_on_past: bool = False,
  410. wait_for_past_depends_before_skipping: bool = False,
  411. ignore_ti_state: bool = False,
  412. local: bool = False,
  413. pickle_id: int | None = None,
  414. raw: bool = False,
  415. job_id: str | None = None,
  416. pool: str | None = None,
  417. cfg_path: str | None = None,
  418. ) -> list[str]:
  419. """
  420. Return a command that can be executed anywhere where airflow is installed.
  421. This command is part of the message sent to executors by the orchestrator.
  422. """
  423. return TaskInstance._command_as_list(
  424. ti=self,
  425. mark_success=mark_success,
  426. ignore_all_deps=ignore_all_deps,
  427. ignore_task_deps=ignore_task_deps,
  428. ignore_depends_on_past=ignore_depends_on_past,
  429. wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
  430. ignore_ti_state=ignore_ti_state,
  431. local=local,
  432. pickle_id=pickle_id,
  433. raw=raw,
  434. job_id=job_id,
  435. pool=pool,
  436. cfg_path=cfg_path,
  437. )
  438. def _register_dataset_changes(self, *, events, session: Session | None = None) -> None:
  439. TaskInstance._register_dataset_changes(self=self, events=events, session=session) # type: ignore[arg-type]
  440. def defer_task(self, exception: TaskDeferred, session: Session | None = None):
  441. """Defer task."""
  442. updated_ti = _defer_task(ti=self, exception=exception, session=session)
  443. _set_ti_attrs(self, updated_ti)
  444. def _handle_reschedule(
  445. self,
  446. actual_start_date: datetime,
  447. reschedule_exception: AirflowRescheduleException,
  448. test_mode: bool = False,
  449. session: Session | None = None,
  450. ):
  451. updated_ti = _handle_reschedule(
  452. ti=self,
  453. actual_start_date=actual_start_date,
  454. reschedule_exception=reschedule_exception,
  455. test_mode=test_mode,
  456. session=session,
  457. )
  458. _set_ti_attrs(self, updated_ti) # _handle_reschedule is a remote call that mutates the TI
  459. def get_relevant_upstream_map_indexes(
  460. self,
  461. upstream: Operator,
  462. ti_count: int | None,
  463. *,
  464. session: Session | None = None,
  465. ) -> int | range | None:
  466. return TaskInstance.get_relevant_upstream_map_indexes(
  467. self=self, # type: ignore[arg-type]
  468. upstream=upstream,
  469. ti_count=ti_count,
  470. session=session,
  471. )
  472. if is_pydantic_2_installed():
  473. TaskInstancePydantic.model_rebuild()