xcom.py 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814
  1. #
  2. # Licensed to the Apache Software Foundation (ASF) under one
  3. # or more contributor license agreements. See the NOTICE file
  4. # distributed with this work for additional information
  5. # regarding copyright ownership. The ASF licenses this file
  6. # to you under the Apache License, Version 2.0 (the
  7. # "License"); you may not use this file except in compliance
  8. # with the License. You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing,
  13. # software distributed under the License is distributed on an
  14. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  15. # KIND, either express or implied. See the License for the
  16. # specific language governing permissions and limitations
  17. # under the License.
  18. from __future__ import annotations
  19. import inspect
  20. import json
  21. import logging
  22. import pickle
  23. import warnings
  24. from functools import wraps
  25. from typing import TYPE_CHECKING, Any, Iterable, cast, overload
  26. from sqlalchemy import (
  27. Column,
  28. ForeignKeyConstraint,
  29. Index,
  30. Integer,
  31. LargeBinary,
  32. PrimaryKeyConstraint,
  33. String,
  34. delete,
  35. select,
  36. text,
  37. )
  38. from sqlalchemy.dialects.mysql import LONGBLOB
  39. from sqlalchemy.ext.associationproxy import association_proxy
  40. from sqlalchemy.orm import Query, reconstructor, relationship
  41. from sqlalchemy.orm.exc import NoResultFound
  42. from airflow.api_internal.internal_api_call import internal_api_call
  43. from airflow.configuration import conf
  44. from airflow.exceptions import RemovedInAirflow3Warning
  45. from airflow.models.base import COLLATION_ARGS, ID_LEN, TaskInstanceDependencies
  46. from airflow.utils import timezone
  47. from airflow.utils.db import LazySelectSequence
  48. from airflow.utils.helpers import exactly_one, is_container
  49. from airflow.utils.json import XComDecoder, XComEncoder
  50. from airflow.utils.log.logging_mixin import LoggingMixin
  51. from airflow.utils.session import NEW_SESSION, provide_session
  52. from airflow.utils.sqlalchemy import UtcDateTime
  53. # XCom constants below are needed for providers backward compatibility,
  54. # which should import the constants directly after apache-airflow>=2.6.0
  55. from airflow.utils.xcom import (
  56. MAX_XCOM_SIZE, # noqa: F401
  57. XCOM_RETURN_KEY,
  58. )
  59. log = logging.getLogger(__name__)
  60. if TYPE_CHECKING:
  61. import datetime
  62. import pendulum
  63. from sqlalchemy.engine import Row
  64. from sqlalchemy.orm import Session
  65. from sqlalchemy.sql.expression import Select, TextClause
  66. from airflow.models.taskinstancekey import TaskInstanceKey
  class BaseXCom(TaskInstanceDependencies, LoggingMixin):
      """
      Base class for XCom objects.

      Rows live in the ``xcom`` table and are keyed by
      ``(dag_run_id, task_id, map_index, key)``. The serialized payload is
      stored in ``value``; custom XCom backends subclass this model and override
      ``serialize_value`` / ``deserialize_value`` / ``orm_deserialize_value`` /
      ``purge`` as needed.
      """

      __tablename__ = "xcom"

      dag_run_id = Column(Integer(), nullable=False, primary_key=True)
      task_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False, primary_key=True)
      map_index = Column(Integer, primary_key=True, nullable=False, server_default=text("-1"))
      key = Column(String(512, **COLLATION_ARGS), nullable=False, primary_key=True)

      # Denormalized for easier lookup.
      dag_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False)
      run_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False)

      value = Column(LargeBinary().with_variant(LONGBLOB, "mysql"))
      timestamp = Column(UtcDateTime, default=timezone.utcnow, nullable=False)

      __table_args__ = (
          # Ideally we should create a unique index over (key, dag_id, task_id, run_id),
          # but it goes over MySQL's index length limit. So we instead index 'key'
          # separately, and enforce uniqueness with DagRun.id instead.
          Index("idx_xcom_key", key),
          Index("idx_xcom_task_instance", dag_id, task_id, run_id, map_index),
          PrimaryKeyConstraint("dag_run_id", "task_id", "map_index", "key", name="xcom_pkey"),
          ForeignKeyConstraint(
              [dag_id, task_id, run_id, map_index],
              [
                  "task_instance.dag_id",
                  "task_instance.task_id",
                  "task_instance.run_id",
                  "task_instance.map_index",
              ],
              name="xcom_task_instance_fkey",
              ondelete="CASCADE",
          ),
      )

      dag_run = relationship(
          "DagRun",
          primaryjoin="BaseXCom.dag_run_id == foreign(DagRun.id)",
          uselist=False,
          lazy="joined",
          passive_deletes="all",
      )
      execution_date = association_proxy("dag_run", "execution_date")

      @reconstructor
      def init_on_load(self):
          """
          Replace the raw stored value with its ORM-deserialized form.

          Called by the ORM after the instance has been loaded from the DB or
          otherwise reconstituted, so ``self.value`` holds the result of
          :meth:`orm_deserialize_value` rather than the stored bytes.
          """
          self.value = self.orm_deserialize_value()

      def __repr__(self):
          # A negative map_index marks a non-mapped task, so the index is omitted.
          if self.map_index < 0:
              return f'<XCom "{self.key}" ({self.task_id} @ {self.run_id})>'
          return f'<XCom "{self.key}" ({self.task_id}[{self.map_index}] @ {self.run_id})>'

      @overload
      @classmethod
      def set(
          cls,
          key: str,
          value: Any,
          *,
          dag_id: str,
          task_id: str,
          run_id: str,
          map_index: int = -1,
          session: Session = NEW_SESSION,
      ) -> None:
          """
          Store an XCom value.

          A deprecated form of this function accepts ``execution_date`` instead of
          ``run_id``. The two arguments are mutually exclusive.

          :param key: Key to store the XCom.
          :param value: XCom value to store.
          :param dag_id: DAG ID.
          :param task_id: Task ID.
          :param run_id: DAG run ID for the task.
          :param map_index: Optional map index to assign XCom for a mapped task.
              The default is ``-1`` (set for a non-mapped task).
          :param session: Database session. If not given, a new session will be
              created for this function.
          """

      @overload
      @classmethod
      def set(
          cls,
          key: str,
          value: Any,
          task_id: str,
          dag_id: str,
          execution_date: datetime.datetime,
          session: Session = NEW_SESSION,
      ) -> None:
          """
          Store an XCom value.

          :sphinx-autoapi-skip:
          """

      @classmethod
      @internal_api_call
      @provide_session
      def set(
          cls,
          key: str,
          value: Any,
          task_id: str,
          dag_id: str,
          execution_date: datetime.datetime | None = None,
          session: Session = NEW_SESSION,
          *,
          run_id: str | None = None,
          map_index: int = -1,
      ) -> None:
          """
          Store an XCom value.

          Any existing row with the same key, run, task, DAG and map index is
          deleted before the new row is inserted and the session is flushed.

          :raises ValueError: If not exactly one of ``run_id`` and
              ``execution_date`` is given, or if no matching DAG run is found.

          :sphinx-autoapi-skip:
          """
          from airflow.models.dagrun import DagRun

          if not exactly_one(execution_date is not None, run_id is not None):
              raise ValueError(
                  f"Exactly one of run_id or execution_date must be passed. "
                  f"Passed execution_date={execution_date}, run_id={run_id}"
              )

          if run_id is None:
              # Deprecated path: resolve the run from its execution date.
              message = "Passing 'execution_date' to 'XCom.set()' is deprecated. Use 'run_id' instead."
              warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3)
              try:
                  dag_run_id, run_id = (
                      session.query(DagRun.id, DagRun.run_id)
                      .filter(DagRun.dag_id == dag_id, DagRun.execution_date == execution_date)
                      .one()
                  )
              except NoResultFound:
                  raise ValueError(f"DAG run not found on DAG {dag_id!r} at {execution_date}") from None
          else:
              dag_run_id = session.query(DagRun.id).filter_by(dag_id=dag_id, run_id=run_id).scalar()
              if dag_run_id is None:
                  raise ValueError(f"DAG run not found on DAG {dag_id!r} with ID {run_id!r}")

          # Seamlessly resolve LazySelectSequence to a list. This intends to work
          # as a "lazy list" to avoid pulling a ton of XComs unnecessarily, but if
          # it's pushed into XCom, the user should be aware of the performance
          # implications, and this avoids leaking the implementation detail.
          if isinstance(value, LazySelectSequence):
              warning_message = (
                  "Coercing mapped lazy proxy %s from task %s (DAG %s, run %s) "
                  "to list, which may degrade performance. Review resource "
                  "requirements for this operation, and call list() to suppress "
                  "this message. See Dynamic Task Mapping documentation for "
                  "more information about lazy proxy objects."
              )
              log.warning(
                  warning_message,
                  "return value" if key == XCOM_RETURN_KEY else f"value {key}",
                  task_id,
                  dag_id,
                  run_id or execution_date,
              )
              value = list(value)

          value = cls.serialize_value(
              value=value,
              key=key,
              task_id=task_id,
              dag_id=dag_id,
              run_id=run_id,
              map_index=map_index,
          )

          # Remove duplicate XComs and insert a new one.
          session.execute(
              delete(cls).where(
                  cls.key == key,
                  cls.run_id == run_id,
                  cls.task_id == task_id,
                  cls.dag_id == dag_id,
                  cls.map_index == map_index,
              )
          )
          new = cast(Any, cls)(  # Work around Mypy complaining model not defining '__init__'.
              dag_run_id=dag_run_id,
              key=key,
              value=value,
              run_id=run_id,
              task_id=task_id,
              dag_id=dag_id,
              map_index=map_index,
          )
          session.add(new)
          session.flush()

      @staticmethod
      @provide_session
      @internal_api_call
      def get_value(
          *,
          ti_key: TaskInstanceKey,
          key: str | None = None,
          session: Session = NEW_SESSION,
      ) -> Any:
          """
          Retrieve an XCom value for a task instance.

          This method returns "full" XCom values (i.e. uses ``deserialize_value``
          from the XCom backend). Use :meth:`get_many` if you want the "shortened"
          value via ``orm_deserialize_value``.

          If there are no results, *None* is returned. If multiple XCom entries
          match the criteria, an arbitrary one is returned.

          :param ti_key: The TaskInstanceKey to look up the XCom for.
          :param key: A key for the XCom. If provided, only XCom with matching
              keys will be returned. Pass *None* (default) to remove the filter.
          :param session: Database session. If not given, a new session will be
              created for this function.
          """
          return BaseXCom.get_one(
              key=key,
              task_id=ti_key.task_id,
              dag_id=ti_key.dag_id,
              run_id=ti_key.run_id,
              map_index=ti_key.map_index,
              session=session,
          )

      @overload
      @staticmethod
      @internal_api_call
      def get_one(
          *,
          key: str | None = None,
          dag_id: str | None = None,
          task_id: str | None = None,
          run_id: str | None = None,
          map_index: int | None = None,
          include_prior_dates: bool = False,
          session: Session = NEW_SESSION,
      ) -> Any | None:
          """
          Retrieve an XCom value, optionally meeting certain criteria.

          This method returns "full" XCom values (i.e. uses ``deserialize_value``
          from the XCom backend). Use :meth:`get_many` if you want the "shortened"
          value via ``orm_deserialize_value``.

          If there are no results, *None* is returned. If multiple XCom entries
          match the criteria, an arbitrary one is returned.

          A deprecated form of this function accepts ``execution_date`` instead of
          ``run_id``. The two arguments are mutually exclusive.

          .. seealso:: ``get_value()`` is a convenience function if you already
              have a structured TaskInstance or TaskInstanceKey object available.

          :param run_id: DAG run ID for the task.
          :param dag_id: Only pull XCom from this DAG. Pass *None* (default) to
              remove the filter.
          :param task_id: Only XCom from task with matching ID will be pulled.
              Pass *None* (default) to remove the filter.
          :param map_index: Only XCom from task with matching map index will be
              pulled. Pass *None* (default) to remove the filter.
          :param key: A key for the XCom. If provided, only XCom with matching
              keys will be returned. Pass *None* (default) to remove the filter.
          :param include_prior_dates: If *False* (default), only XCom from the
              specified DAG run is returned. If *True*, the latest matching XCom is
              returned regardless of the run it belongs to.
          :param session: Database session. If not given, a new session will be
              created for this function.
          """

      @overload
      @staticmethod
      @internal_api_call
      def get_one(
          execution_date: datetime.datetime,
          key: str | None = None,
          task_id: str | None = None,
          dag_id: str | None = None,
          include_prior_dates: bool = False,
          session: Session = NEW_SESSION,
      ) -> Any | None:
          """
          Retrieve an XCom value, optionally meeting certain criteria.

          :sphinx-autoapi-skip:
          """

      @staticmethod
      @provide_session
      @internal_api_call
      def get_one(
          execution_date: datetime.datetime | None = None,
          key: str | None = None,
          task_id: str | None = None,
          dag_id: str | None = None,
          include_prior_dates: bool = False,
          session: Session = NEW_SESSION,
          *,
          run_id: str | None = None,
          map_index: int | None = None,
      ) -> Any | None:
          """
          Retrieve an XCom value, optionally meeting certain criteria.

          Builds a one-row :meth:`get_many` query and returns the deserialized
          value of the first match, or *None* when nothing matches.

          :raises ValueError: If not exactly one of ``run_id`` and
              ``execution_date`` is given.

          :sphinx-autoapi-skip:
          """
          if not exactly_one(execution_date is not None, run_id is not None):
              raise ValueError("Exactly one of run_id or execution_date must be passed")

          # Branch on ``is not None`` (the same test used for validation above) so an
          # empty-string run_id is treated as a run_id rather than falling through.
          if run_id is not None:
              query = BaseXCom.get_many(
                  run_id=run_id,
                  key=key,
                  task_ids=task_id,
                  dag_ids=dag_id,
                  map_indexes=map_index,
                  include_prior_dates=include_prior_dates,
                  limit=1,
                  session=session,
              )
          else:
              # exactly_one() guarantees execution_date is set on this path.
              message = "Passing 'execution_date' to 'XCom.get_one()' is deprecated. Use 'run_id' instead."
              warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3)

              # get_many() emits its own deprecation warning; silence it so the
              # caller only sees the one raised above.
              with warnings.catch_warnings():
                  warnings.simplefilter("ignore", RemovedInAirflow3Warning)
                  query = BaseXCom.get_many(
                      execution_date=execution_date,
                      key=key,
                      task_ids=task_id,
                      dag_ids=dag_id,
                      map_indexes=map_index,
                      include_prior_dates=include_prior_dates,
                      limit=1,
                      session=session,
                  )

          result = query.with_entities(BaseXCom.value).first()
          if result:
              return XCom.deserialize_value(result)
          return None

      @overload
      @staticmethod
      def get_many(
          *,
          run_id: str,
          key: str | None = None,
          task_ids: str | Iterable[str] | None = None,
          dag_ids: str | Iterable[str] | None = None,
          map_indexes: int | Iterable[int] | None = None,
          include_prior_dates: bool = False,
          limit: int | None = None,
          session: Session = NEW_SESSION,
      ) -> Query:
          """
          Composes a query to get one or more XCom entries.

          This function returns an SQLAlchemy query of full XCom objects. If you
          just want one stored value, use :meth:`get_one` instead.

          A deprecated form of this function accepts ``execution_date`` instead of
          ``run_id``. The two arguments are mutually exclusive.

          :param run_id: DAG run ID for the task.
          :param key: A key for the XComs. If provided, only XComs with matching
              keys will be returned. Pass *None* (default) to remove the filter.
          :param task_ids: Only XComs from task with matching IDs will be pulled.
              Pass *None* (default) to remove the filter.
          :param dag_ids: Only pulls XComs from specified DAGs. Pass *None*
              (default) to remove the filter.
          :param map_indexes: Only XComs from matching map indexes will be pulled.
              Pass *None* (default) to remove the filter.
          :param include_prior_dates: If *False* (default), only XComs from the
              specified DAG run are returned. If *True*, all matching XComs are
              returned regardless of the run it belongs to.
          :param session: Database session. If not given, a new session will be
              created for this function.
          :param limit: Maximum number of XComs to return.
          """

      @overload
      @staticmethod
      @internal_api_call
      def get_many(
          execution_date: datetime.datetime,
          key: str | None = None,
          task_ids: str | Iterable[str] | None = None,
          dag_ids: str | Iterable[str] | None = None,
          map_indexes: int | Iterable[int] | None = None,
          include_prior_dates: bool = False,
          limit: int | None = None,
          session: Session = NEW_SESSION,
      ) -> Query:
          """
          Composes a query to get one or more XCom entries.

          :sphinx-autoapi-skip:
          """

      # The `get_many` is not supported via database isolation mode. Attempting to use it in DB isolation
      # mode will result in a crash - Resulting Query object cannot be **really** serialized
      # TODO(potiuk) - document it in AIP-44 docs
      @staticmethod
      @provide_session
      def get_many(
          execution_date: datetime.datetime | None = None,
          key: str | None = None,
          task_ids: str | Iterable[str] | None = None,
          dag_ids: str | Iterable[str] | None = None,
          map_indexes: int | Iterable[int] | None = None,
          include_prior_dates: bool = False,
          limit: int | None = None,
          session: Session = NEW_SESSION,
          *,
          run_id: str | None = None,
      ) -> Query:
          """
          Composes a query to get one or more XCom entries.

          Results are ordered by DAG run execution date, then XCom timestamp,
          both descending.

          :raises ValueError: If not exactly one of ``run_id`` and
              ``execution_date`` is given.

          :sphinx-autoapi-skip:
          """
          from airflow.models.dagrun import DagRun

          if not exactly_one(execution_date is not None, run_id is not None):
              raise ValueError(
                  f"Exactly one of run_id or execution_date must be passed. "
                  f"Passed execution_date={execution_date}, run_id={run_id}"
              )
          if execution_date is not None:
              message = "Passing 'execution_date' to 'XCom.get_many()' is deprecated. Use 'run_id' instead."
              warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3)

          query = session.query(BaseXCom).join(BaseXCom.dag_run)

          if key:
              query = query.filter(BaseXCom.key == key)

          if is_container(task_ids):
              query = query.filter(BaseXCom.task_id.in_(task_ids))
          elif task_ids is not None:
              query = query.filter(BaseXCom.task_id == task_ids)

          if is_container(dag_ids):
              query = query.filter(BaseXCom.dag_id.in_(dag_ids))
          elif dag_ids is not None:
              query = query.filter(BaseXCom.dag_id == dag_ids)

          # A contiguous range becomes a pair of bounds instead of an IN list.
          if isinstance(map_indexes, range) and map_indexes.step == 1:
              query = query.filter(
                  BaseXCom.map_index >= map_indexes.start, BaseXCom.map_index < map_indexes.stop
              )
          elif is_container(map_indexes):
              query = query.filter(BaseXCom.map_index.in_(map_indexes))
          elif map_indexes is not None:
              query = query.filter(BaseXCom.map_index == map_indexes)

          if include_prior_dates:
              if execution_date is not None:
                  query = query.filter(DagRun.execution_date <= execution_date)
              else:
                  # NOTE(review): this subquery matches on run_id alone (no dag_id); if the
                  # same run_id exists under several DAGs it yields several rows - confirm
                  # callers constrain ``dag_ids`` accordingly.
                  dr = session.query(DagRun.execution_date).filter(DagRun.run_id == run_id).subquery()
                  query = query.filter(BaseXCom.execution_date <= dr.c.execution_date)
          elif execution_date is not None:
              query = query.filter(DagRun.execution_date == execution_date)
          else:
              query = query.filter(BaseXCom.run_id == run_id)

          query = query.order_by(DagRun.execution_date.desc(), BaseXCom.timestamp.desc())
          if limit:
              return query.limit(limit)
          return query

      @classmethod
      @provide_session
      def delete(cls, xcoms: XCom | Iterable[XCom], session: Session = NEW_SESSION) -> None:
          """
          Delete one or multiple XCom entries.

          Each entry is first passed to :meth:`purge` and then deleted from the
          session; the session is committed once at the end.

          :param xcoms: A single XCom or an iterable of XComs to delete.
          :param session: Database session. If not given, a new session will be
              created for this function.
          :raises TypeError: If any element is not an ``XCom`` instance.
          """
          if isinstance(xcoms, XCom):
              xcoms = [xcoms]
          for xcom in xcoms:
              if not isinstance(xcom, XCom):
                  raise TypeError(f"Expected XCom; received {xcom.__class__.__name__}")
              XCom.purge(xcom, session)
              session.delete(xcom)
          session.commit()

      @staticmethod
      def purge(xcom: XCom, session: Session) -> None:
          """
          Purge an XCom entry from underlying storage implementations.

          The base implementation does nothing.
          """
          pass

      @overload
      @staticmethod
      @internal_api_call
      def clear(
          *,
          dag_id: str,
          task_id: str,
          run_id: str,
          map_index: int | None = None,
          session: Session = NEW_SESSION,
      ) -> None:
          """
          Clear all XCom data from the database for the given task instance.

          A deprecated form of this function accepts ``execution_date`` instead of
          ``run_id``. The two arguments are mutually exclusive.

          :param dag_id: ID of DAG to clear the XCom for.
          :param task_id: ID of task to clear the XCom for.
          :param run_id: ID of DAG run to clear the XCom for.
          :param map_index: If given, only clear XCom from this particular mapped
              task. The default ``None`` clears *all* XComs from the task.
          :param session: Database session. If not given, a new session will be
              created for this function.
          """

      @overload
      @staticmethod
      @internal_api_call
      def clear(
          execution_date: pendulum.DateTime,
          dag_id: str,
          task_id: str,
          session: Session = NEW_SESSION,
      ) -> None:
          """
          Clear all XCom data from the database for the given task instance.

          :sphinx-autoapi-skip:
          """

      @staticmethod
      @provide_session
      @internal_api_call
      def clear(
          execution_date: pendulum.DateTime | None = None,
          dag_id: str | None = None,
          task_id: str | None = None,
          session: Session = NEW_SESSION,
          *,
          run_id: str | None = None,
          map_index: int | None = None,
      ) -> None:
          """
          Clear all XCom data from the database for the given task instance.

          Each matching row is passed to :meth:`purge` and deleted; the session
          is committed once at the end.

          :raises TypeError: If ``dag_id`` or ``task_id`` is missing.
          :raises ValueError: If not exactly one of ``run_id`` and
              ``execution_date`` is given.

          :sphinx-autoapi-skip:
          """
          from airflow.models import DagRun

          # Given the historic order of this function (execution_date was first argument) to add a new optional
          # param we need to add default values for everything :(
          if dag_id is None:
              raise TypeError("clear() missing required argument: dag_id")
          if task_id is None:
              raise TypeError("clear() missing required argument: task_id")

          if not exactly_one(execution_date is not None, run_id is not None):
              raise ValueError(
                  f"Exactly one of run_id or execution_date must be passed. "
                  f"Passed execution_date={execution_date}, run_id={run_id}"
              )

          if execution_date is not None:
              message = "Passing 'execution_date' to 'XCom.clear()' is deprecated. Use 'run_id' instead."
              warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3)
              run_id = (
                  session.query(DagRun.run_id)
                  .filter(DagRun.dag_id == dag_id, DagRun.execution_date == execution_date)
                  .scalar()
              )

          query = session.query(BaseXCom).filter_by(dag_id=dag_id, task_id=task_id, run_id=run_id)
          if map_index is not None:
              query = query.filter_by(map_index=map_index)

          for xcom in query:
              XCom.purge(xcom, session)
              session.delete(xcom)

          session.commit()

      @staticmethod
      def serialize_value(
          value: Any,
          *,
          key: str | None = None,
          task_id: str | None = None,
          dag_id: str | None = None,
          run_id: str | None = None,
          map_index: int | None = None,
      ) -> Any:
          """
          Serialize XCom value to str or pickled object.

          Uses ``pickle`` when ``[core] enable_xcom_pickling`` is set, otherwise
          UTF-8 encoded JSON produced with ``XComEncoder``. The keyword-only
          arguments are not used by this base implementation.

          :raises ValueError: Re-raised (after logging) when JSON encoding fails.
          :raises TypeError: Re-raised (after logging) when JSON encoding fails.
          """
          if conf.getboolean("core", "enable_xcom_pickling"):
              return pickle.dumps(value)
          try:
              return json.dumps(value, cls=XComEncoder).encode("UTF-8")
          except (ValueError, TypeError) as ex:
              log.error(
                  "%s."
                  " If you are using pickle instead of JSON for XCom,"
                  " then you need to enable pickle support for XCom"
                  " in your airflow config or make sure to decorate your"
                  " object with attr.",
                  ex,
              )
              raise

      @staticmethod
      def _deserialize_value(result: XCom, orm: bool) -> Any:
          """
          Decode ``result.value`` back into a Python object.

          :param result: Object exposing the stored bytes as ``.value``.
          :param orm: If true, decode JSON with ``XComDecoder.orm_object_hook``.
          :return: The decoded value, or *None* if nothing is stored.
          """
          object_hook = None
          if orm:
              object_hook = XComDecoder.orm_object_hook

          if result.value is None:
              return None
          if conf.getboolean("core", "enable_xcom_pickling"):
              # NOTE(security): unpickling can execute arbitrary code, so pickling should
              # only be enabled when the stored XCom data is trusted.
              try:
                  return pickle.loads(result.value)
              except pickle.UnpicklingError:
                  # The stored bytes were not a pickle; fall back to JSON.
                  return json.loads(result.value.decode("UTF-8"), cls=XComDecoder, object_hook=object_hook)
          else:
              # Since xcom_pickling is disabled, we should only try to deserialize with JSON
              return json.loads(result.value.decode("UTF-8"), cls=XComDecoder, object_hook=object_hook)

      @staticmethod
      def deserialize_value(result: XCom) -> Any:
          """Deserialize XCom value from str or pickle object."""
          return BaseXCom._deserialize_value(result, False)

      def orm_deserialize_value(self) -> Any:
          """
          Deserialize method which is used to reconstruct ORM XCom object.

          This method should be overridden in custom XCom backends to avoid
          unnecessary request or other resource consuming operations when
          creating XCom orm model. This is used when viewing XCom listing
          in the webserver, for example.
          """
          return BaseXCom._deserialize_value(self, True)
  class LazyXComSelectSequence(LazySelectSequence[Any]):
      """
      List-like interface to lazily access XCom values.

      :meta private:
      """

      @staticmethod
      def _rebuild_select(stmt: TextClause) -> Select:
          # Select only the XCom value column, sourcing rows from the given statement.
          value_column_select = select(XCom.value)
          return value_column_select.from_statement(stmt)

      @staticmethod
      def _process_row(row: Row) -> Any:
          # Each row is handed to the configured XCom backend for deserialization.
          deserialized = XCom.deserialize_value(row)
          return deserialized
  658. def _patch_outdated_serializer(clazz: type[BaseXCom], params: Iterable[str]) -> None:
  659. """
  660. Patch a custom ``serialize_value`` to accept the modern signature.
  661. To give custom XCom backends more flexibility with how they store values, we
  662. now forward all params passed to ``XCom.set`` to ``XCom.serialize_value``.
  663. In order to maintain compatibility with custom XCom backends written with
  664. the old signature, we check the signature and, if necessary, patch with a
  665. method that ignores kwargs the backend does not accept.
  666. """
  667. old_serializer = clazz.serialize_value
  668. @wraps(old_serializer)
  669. def _shim(**kwargs):
  670. kwargs = {k: kwargs.get(k) for k in params}
  671. warnings.warn(
  672. f"Method `serialize_value` in XCom backend {XCom.__name__} is using outdated signature and"
  673. f"must be updated to accept all params in `BaseXCom.set` except `session`. Support will be "
  674. f"removed in a future release.",
  675. RemovedInAirflow3Warning,
  676. stacklevel=1,
  677. )
  678. return old_serializer(**kwargs)
  679. clazz.serialize_value = _shim # type: ignore[assignment]
  680. def _get_function_params(function) -> list[str]:
  681. """
  682. Return the list of variables names of a function.
  683. :param function: The function to inspect
  684. """
  685. parameters = inspect.signature(function).parameters
  686. bound_arguments = [
  687. name for name, p in parameters.items() if p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
  688. ]
  689. return bound_arguments
  def resolve_xcom_backend() -> type[BaseXCom]:
      """
      Resolve the XCom class configured under ``[core] xcom_backend``.

      Falls back to ``BaseXCom`` when nothing is configured. A configured class
      must extend ``BaseXCom``. If its ``serialize_value`` parameter names differ
      from those of ``BaseXCom.serialize_value``, the method is patched so it
      keeps working with the modern call convention.

      :raises TypeError: If the configured class is not a subclass of ``BaseXCom``.
      """
      default_backend_path = f"airflow.models.xcom.{BaseXCom.__name__}"
      backend = conf.getimport("core", "xcom_backend", fallback=default_backend_path)
      if not backend:
          return BaseXCom

      if not issubclass(backend, BaseXCom):
          raise TypeError(
              f"Your custom XCom class `{backend.__name__}` is not a subclass of `{BaseXCom.__name__}`."
          )

      expected_names = set(_get_function_params(BaseXCom.serialize_value))
      backend_names = _get_function_params(backend.serialize_value)
      # Order does not matter here, only which parameter names are accepted.
      if expected_names != set(backend_names):
          _patch_outdated_serializer(clazz=backend, params=backend_names)
      return backend
  # Bind the public ``XCom`` name. Under static type checking it is simply
  # ``BaseXCom``; at runtime it is whatever ``resolve_xcom_backend()`` returns.
  if TYPE_CHECKING:
      XCom = BaseXCom  # Hack to avoid Mypy "Variable 'XCom' is not valid as a type".
  else:
      XCom = resolve_xcom_backend()