123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116 |
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
- """Table to store information about mapped task instances (AIP-42)."""
- from __future__ import annotations
- import collections.abc
- import enum
- from typing import TYPE_CHECKING, Any, Collection
- from sqlalchemy import CheckConstraint, Column, ForeignKeyConstraint, Integer, String
- from airflow.models.base import COLLATION_ARGS, ID_LEN, TaskInstanceDependencies
- from airflow.utils.sqlalchemy import ExtendedJSON
- if TYPE_CHECKING:
- from airflow.models.taskinstance import TaskInstance
- from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
class TaskMapVariant(enum.Enum):
    """
    Kind of collection a task map was built from.

    There are exactly two members: ``DICT`` for a key-value mapping and
    ``LIST`` for an ordered sequence of values.
    """

    # Key-value mapping.
    DICT = "dict"
    # Ordered value sequence.
    LIST = "list"
class TaskMap(TaskInstanceDependencies):
    """
    Model holding dynamic task-mapping information.

    At present a row is only written when an upstream TaskInstance pushes an
    XCom that a downstream task pulls for mapping purposes.
    """

    __tablename__ = "task_map"

    # Composite primary key; it references the upstream TaskInstance that
    # produced this mapping information (see the foreign key below).
    dag_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
    task_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
    run_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
    map_index = Column(Integer, primary_key=True)

    # Number of items in the recorded value; the check constraint below
    # keeps it from being negative.
    length = Column(Integer, nullable=False)
    # Keys of the recorded value when it is a mapping, otherwise NULL.
    keys = Column(ExtendedJSON, nullable=True)

    __table_args__ = (
        CheckConstraint(length >= 0, name="task_map_length_not_negative"),
        # Deletes and updates of the referenced task_instance row cascade
        # to this table.
        ForeignKeyConstraint(
            [dag_id, task_id, run_id, map_index],
            [
                "task_instance.dag_id",
                "task_instance.task_id",
                "task_instance.run_id",
                "task_instance.map_index",
            ],
            name="task_map_task_instance_fkey",
            ondelete="CASCADE",
            onupdate="CASCADE",
        ),
    )

    def __init__(
        self,
        dag_id: str,
        task_id: str,
        run_id: str,
        map_index: int,
        length: int,
        keys: list[Any] | None,
    ) -> None:
        """
        Store the given column values on the new instance.

        :param dag_id: DAG id of the upstream task instance.
        :param task_id: Task id of the upstream task instance.
        :param run_id: Run id of the upstream task instance.
        :param map_index: Map index of the upstream task instance.
        :param length: Number of items in the recorded value.
        :param keys: Keys of the recorded value, or ``None`` when the value
            is not a mapping.
        """
        self.dag_id = dag_id
        self.task_id = task_id
        self.run_id = run_id
        self.map_index = map_index
        self.length = length
        self.keys = keys

    @classmethod
    def from_task_instance_xcom(cls, ti: TaskInstance | TaskInstancePydantic, value: Collection) -> TaskMap:
        """
        Create a ``TaskMap`` from a task instance and the value it pushed.

        :param ti: Task instance supplying ``dag_id``, ``task_id``,
            ``run_id`` and ``map_index``.
        :param value: Collection whose size is recorded as ``length``; when
            it is a mapping its keys are recorded too.
        :raises ValueError: If ``ti.run_id`` is ``None``.
        """
        run_id = ti.run_id
        if run_id is None:
            raise ValueError("cannot record task map for unrun task instance")

        dag_id = ti.dag_id
        task_id = ti.task_id
        map_index = ti.map_index
        size = len(value)
        # Only mappings carry keys; any other collection is stored with None.
        if isinstance(value, collections.abc.Mapping):
            mapping_keys = list(value)
        else:
            mapping_keys = None

        return cls(
            dag_id=dag_id,
            task_id=task_id,
            run_id=run_id,
            map_index=map_index,
            length=size,
            keys=mapping_keys,
        )

    @property
    def variant(self) -> TaskMapVariant:
        """Return ``LIST`` when ``keys`` is ``None`` and ``DICT`` otherwise."""
        return TaskMapVariant.LIST if self.keys is None else TaskMapVariant.DICT
|