taskmap.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Table to store information about mapped task instances (AIP-42)."""
  19. from __future__ import annotations
  20. import collections.abc
  21. import enum
  22. from typing import TYPE_CHECKING, Any, Collection
  23. from sqlalchemy import CheckConstraint, Column, ForeignKeyConstraint, Integer, String
  24. from airflow.models.base import COLLATION_ARGS, ID_LEN, TaskInstanceDependencies
  25. from airflow.utils.sqlalchemy import ExtendedJSON
  26. if TYPE_CHECKING:
  27. from airflow.models.taskinstance import TaskInstance
  28. from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
  29. class TaskMapVariant(enum.Enum):
  30. """
  31. Task map variant.
  32. Possible values are **dict** (for a key-value mapping) and **list** (for an
  33. ordered value sequence).
  34. """
  35. DICT = "dict"
  36. LIST = "list"
  37. class TaskMap(TaskInstanceDependencies):
  38. """
  39. Model to track dynamic task-mapping information.
  40. This is currently only populated by an upstream TaskInstance pushing an
  41. XCom that's pulled by a downstream for mapping purposes.
  42. """
  43. __tablename__ = "task_map"
  44. # Link to upstream TaskInstance creating this dynamic mapping information.
  45. dag_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
  46. task_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
  47. run_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
  48. map_index = Column(Integer, primary_key=True)
  49. length = Column(Integer, nullable=False)
  50. keys = Column(ExtendedJSON, nullable=True)
  51. __table_args__ = (
  52. CheckConstraint(length >= 0, name="task_map_length_not_negative"),
  53. ForeignKeyConstraint(
  54. [dag_id, task_id, run_id, map_index],
  55. [
  56. "task_instance.dag_id",
  57. "task_instance.task_id",
  58. "task_instance.run_id",
  59. "task_instance.map_index",
  60. ],
  61. name="task_map_task_instance_fkey",
  62. ondelete="CASCADE",
  63. onupdate="CASCADE",
  64. ),
  65. )
  66. def __init__(
  67. self,
  68. dag_id: str,
  69. task_id: str,
  70. run_id: str,
  71. map_index: int,
  72. length: int,
  73. keys: list[Any] | None,
  74. ) -> None:
  75. self.dag_id = dag_id
  76. self.task_id = task_id
  77. self.run_id = run_id
  78. self.map_index = map_index
  79. self.length = length
  80. self.keys = keys
  81. @classmethod
  82. def from_task_instance_xcom(cls, ti: TaskInstance | TaskInstancePydantic, value: Collection) -> TaskMap:
  83. if ti.run_id is None:
  84. raise ValueError("cannot record task map for unrun task instance")
  85. return cls(
  86. dag_id=ti.dag_id,
  87. task_id=ti.task_id,
  88. run_id=ti.run_id,
  89. map_index=ti.map_index,
  90. length=len(value),
  91. keys=(list(value) if isinstance(value, collections.abc.Mapping) else None),
  92. )
  93. @property
  94. def variant(self) -> TaskMapVariant:
  95. if self.keys is None:
  96. return TaskMapVariant.LIST
  97. return TaskMapVariant.DICT