123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353 |
- #
- # Licensed to the Apache Software Foundation (ASF) under one
- # or more contributor license agreements. See the NOTICE file
- # distributed with this work for additional information
- # regarding copyright ownership. The ASF licenses this file
- # to you under the Apache License, Version 2.0 (the
- # "License"); you may not use this file except in compliance
- # with the License. You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- # KIND, either express or implied. See the License for the
- # specific language governing permissions and limitations
- # under the License.
- from __future__ import annotations
- from typing import TYPE_CHECKING, Any
- from sqlalchemy import Boolean, Column, Integer, String, Text, func, select
- from airflow.exceptions import AirflowException, PoolNotFound
- from airflow.models.base import Base
- from airflow.ti_deps.dependencies_states import EXECUTION_STATES
- from airflow.typing_compat import TypedDict
- from airflow.utils.db import exists_query
- from airflow.utils.session import NEW_SESSION, provide_session
- from airflow.utils.sqlalchemy import with_row_locks
- from airflow.utils.state import TaskInstanceState
- if TYPE_CHECKING:
- from sqlalchemy.orm.session import Session
class PoolStats(TypedDict):
    """Slot counters reported for a single pool by ``Pool.slots_stats``."""

    # Capacity of the pool; ``slots_stats`` stores ``float("inf")`` here when the pool's slots are -1.
    total: int
    # Sum of ``pool_slots`` over task instances in the RUNNING state.
    running: int
    # Sum of ``pool_slots`` over task instances in the DEFERRED state.
    deferred: int
    # Sum of ``pool_slots`` over task instances in the QUEUED state.
    queued: int
    # ``total`` minus running and queued slots (and deferred ones when the pool includes deferred).
    open: int
    # Sum of ``pool_slots`` over task instances in the SCHEDULED state.
    scheduled: int
class Pool(Base):
    """Model of a row in the ``slot_pool`` table, with helpers to inspect pool usage."""

    # Name of the built-in pool; ``delete_pool`` refuses to remove a pool with this name.
    DEFAULT_POOL_NAME = "default_pool"

    __tablename__ = "slot_pool"

    id = Column(Integer, primary_key=True)
    pool = Column(String(length=256), unique=True)
    # Slot capacity of the pool; -1 stands for an unlimited pool.
    slots = Column(Integer, default=0)
    description = Column(Text)
    # When true, deferred task instances are counted as occupying slots.
    include_deferred = Column(Boolean, nullable=False)
- def __repr__(self):
- return str(self.pool)
@staticmethod
@provide_session
def get_pools(session: Session = NEW_SESSION) -> list[Pool]:
    """
    Fetch every pool.

    :param session: SQLAlchemy ORM Session
    :return: all ``Pool`` rows
    """
    every_pool = select(Pool)
    return session.scalars(every_pool).all()
@staticmethod
@provide_session
def get_pool(pool_name: str, session: Session = NEW_SESSION) -> Pool | None:
    """
    Look up a single pool by its name.

    :param pool_name: The pool name of the Pool to get.
    :param session: SQLAlchemy ORM Session
    :return: the matching pool object, or None when no pool has that name
    """
    by_name = select(Pool).where(Pool.pool == pool_name)
    return session.scalar(by_name)
@staticmethod
@provide_session
def get_default_pool(session: Session = NEW_SESSION) -> Pool | None:
    """
    Look up the pool named ``Pool.DEFAULT_POOL_NAME``.

    :param session: SQLAlchemy ORM Session
    :return: the default pool object, or None when it is not found
    """
    default_name = Pool.DEFAULT_POOL_NAME
    return Pool.get_pool(default_name, session=session)
@staticmethod
@provide_session
def is_default_pool(id: int, session: Session = NEW_SESSION) -> bool:
    """
    Tell whether the pool with the given id is the default pool.

    :param id: pool id
    :param session: SQLAlchemy ORM Session
    :return: True if id is default_pool, otherwise False
    """
    # Both conditions must hold for the same row: matching id and the default pool name.
    conditions = (
        Pool.id == id,
        Pool.pool == Pool.DEFAULT_POOL_NAME,
    )
    return exists_query(*conditions, session=session)
@staticmethod
@provide_session
def create_or_update_pool(
    name: str,
    slots: int,
    description: str,
    include_deferred: bool,
    session: Session = NEW_SESSION,
) -> Pool:
    """
    Create a pool with the given parameters, or update the existing pool of that name.

    :param name: name of the pool; must not be empty
    :param slots: slot capacity to store on the pool
    :param description: description to store on the pool
    :param include_deferred: value to store in the pool's ``include_deferred`` flag
    :param session: SQLAlchemy ORM Session
    :return: the created or updated pool
    :raises ValueError: if ``name`` is empty
    """
    if not name:
        raise ValueError("Pool name must not be empty")

    existing = session.scalar(select(Pool).where(Pool.pool == name))
    if existing is not None:
        # The pool already exists: overwrite its settings in place.
        existing.slots = slots
        existing.description = description
        existing.include_deferred = include_deferred
        pool = existing
    else:
        pool = Pool(pool=name, slots=slots, description=description, include_deferred=include_deferred)
        session.add(pool)

    session.commit()
    return pool
@staticmethod
@provide_session
def delete_pool(name: str, session: Session = NEW_SESSION) -> Pool:
    """
    Delete the pool with the given name.

    :param name: name of the pool to delete
    :param session: SQLAlchemy ORM Session
    :return: the pool object that was deleted
    :raises AirflowException: if ``name`` is the default pool name
    :raises PoolNotFound: if no pool with that name exists
    """
    # The default pool is protected from deletion.
    if name == Pool.DEFAULT_POOL_NAME:
        raise AirflowException(f"{Pool.DEFAULT_POOL_NAME} cannot be deleted")

    target = session.scalar(select(Pool).where(Pool.pool == name))
    if target is None:
        raise PoolNotFound(f"Pool '{name}' doesn't exist")

    session.delete(target)
    session.commit()
    return target
@staticmethod
@provide_session
def slots_stats(
    *,
    lock_rows: bool = False,
    session: Session = NEW_SESSION,
) -> dict[str, PoolStats]:
    """
    Get per-pool slot stats (total, running, queued, deferred, scheduled and open slots).

    If ``lock_rows`` is True, and the database engine in use supports the ``NOWAIT`` syntax, then a
    non-blocking lock will be attempted -- if the lock is not available then SQLAlchemy will throw an
    OperationalError.

    :param lock_rows: Should we attempt to obtain a row-level lock on all the Pool rows returned
    :param session: SQLAlchemy ORM Session
    :return: mapping of pool name to the ``PoolStats`` of that pool
    :raises AirflowException: if a counted task instance state is not one of the handled states
    """
    from airflow.models.taskinstance import TaskInstance  # Avoid circular import

    stats_by_pool: dict[str, PoolStats] = {}
    deferred_counts_as_used: dict[str, bool] = {}

    pool_query = select(Pool.pool, Pool.slots, Pool.include_deferred)
    if lock_rows:
        pool_query = with_row_locks(pool_query, session=session, nowait=True)

    # Seed one zeroed entry per pool; the counters are filled in below.
    for name, capacity, include_deferred in session.execute(pool_query):
        if capacity == -1:
            # -1 marks an unlimited pool.
            capacity = float("inf")  # type: ignore
        stats_by_pool[name] = PoolStats(
            total=capacity, running=0, queued=0, open=0, deferred=0, scheduled=0
        )
        deferred_counts_as_used[name] = include_deferred

    allowed_execution_states = EXECUTION_STATES | {
        TaskInstanceState.DEFERRED,
        TaskInstanceState.SCHEDULED,
    }
    slots_per_state_query = (
        select(TaskInstance.pool, TaskInstance.state, func.sum(TaskInstance.pool_slots))
        .filter(TaskInstance.state.in_(allowed_execution_states))
        .group_by(TaskInstance.pool, TaskInstance.state)
    )

    # Fill in the per-state counters.
    for name, state, slot_sum in session.execute(slots_per_state_query):
        # Some databases return decimal.Decimal here.
        used = int(slot_sum)

        entry: PoolStats | None = stats_by_pool.get(name)
        if not entry:
            # Counts for pool names that the pool query did not return are ignored.
            continue

        # TypedDict key must be a string literal, so we use if-statements to set value
        if state == TaskInstanceState.RUNNING:
            entry["running"] = used
        elif state == TaskInstanceState.QUEUED:
            entry["queued"] = used
        elif state == TaskInstanceState.DEFERRED:
            entry["deferred"] = used
        elif state == TaskInstanceState.SCHEDULED:
            entry["scheduled"] = used
        else:
            raise AirflowException(f"Unexpected state. Expected values: {allowed_execution_states}.")

    # Derive the open slots; deferred slots only count as used for pools that include them.
    for name, pool_stats in stats_by_pool.items():
        occupied = pool_stats["running"] + pool_stats["queued"]
        if deferred_counts_as_used[name]:
            occupied += pool_stats["deferred"]
        pool_stats["open"] = pool_stats["total"] - occupied

    return stats_by_pool
def to_json(self) -> dict[str, Any]:
    """
    Build a JSON-style dictionary describing this pool.

    :return: dict with the ``id``, ``pool``, ``slots``, ``description`` and ``include_deferred`` values
    """
    exported_fields = ("id", "pool", "slots", "description", "include_deferred")
    return {field: getattr(self, field) for field in exported_fields}
@provide_session
def occupied_slots(self, session: Session = NEW_SESSION) -> int:
    """
    Get the number of slots currently used by task instances in an occupying state.

    The occupying states come from ``get_occupied_states``, so deferred task instances are
    counted only when the pool includes deferred.

    :param session: SQLAlchemy ORM Session
    :return: the used number of slots
    """
    from airflow.models.taskinstance import TaskInstance  # Avoid circular import

    counted_states = self.get_occupied_states()
    used_query = select(func.sum(TaskInstance.pool_slots)).where(
        TaskInstance.pool == self.pool,
        TaskInstance.state.in_(counted_states),
    )
    # The sum is empty (falsy) when nothing matches; report that as 0.
    used = session.scalar(used_query)
    return int(used or 0)
def get_occupied_states(self):
    """
    Get the task instance states that count as occupying a slot of this pool.

    :return: ``EXECUTION_STATES``, extended with the DEFERRED state when ``include_deferred`` is set
    """
    if not self.include_deferred:
        return EXECUTION_STATES
    return EXECUTION_STATES | {
        TaskInstanceState.DEFERRED,
    }
@provide_session
def running_slots(self, session: Session = NEW_SESSION) -> int:
    """
    Get the number of slots currently used by running task instances of this pool.

    :param session: SQLAlchemy ORM Session
    :return: the used number of slots
    """
    from airflow.models.taskinstance import TaskInstance  # Avoid circular import

    running_query = select(func.sum(TaskInstance.pool_slots)).where(
        TaskInstance.pool == self.pool,
        TaskInstance.state == TaskInstanceState.RUNNING,
    )
    # The sum is empty (falsy) when nothing matches; report that as 0.
    slot_sum = session.scalar(running_query)
    return int(slot_sum or 0)
@provide_session
def queued_slots(self, session: Session = NEW_SESSION) -> int:
    """
    Get the number of slots currently used by queued task instances of this pool.

    :param session: SQLAlchemy ORM Session
    :return: the used number of slots
    """
    from airflow.models.taskinstance import TaskInstance  # Avoid circular import

    queued_query = select(func.sum(TaskInstance.pool_slots)).where(
        TaskInstance.pool == self.pool,
        TaskInstance.state == TaskInstanceState.QUEUED,
    )
    # The sum is empty (falsy) when nothing matches; report that as 0.
    slot_sum = session.scalar(queued_query)
    return int(slot_sum or 0)
@provide_session
def scheduled_slots(self, session: Session = NEW_SESSION) -> int:
    """
    Get the number of slots requested by scheduled task instances of this pool.

    :param session: SQLAlchemy ORM Session
    :return: the number of scheduled slots
    """
    from airflow.models.taskinstance import TaskInstance  # Avoid circular import

    scheduled_query = select(func.sum(TaskInstance.pool_slots)).where(
        TaskInstance.pool == self.pool,
        TaskInstance.state == TaskInstanceState.SCHEDULED,
    )
    # The sum is empty (falsy) when nothing matches; report that as 0.
    slot_sum = session.scalar(scheduled_query)
    return int(slot_sum or 0)
@provide_session
def deferred_slots(self, session: Session = NEW_SESSION) -> int:
    """
    Get the number of slots held by deferred task instances of this pool.

    :param session: SQLAlchemy ORM Session
    :return: the number of deferred slots
    """
    from airflow.models.taskinstance import TaskInstance  # Avoid circular import

    in_this_pool = TaskInstance.pool == self.pool
    is_deferred = TaskInstance.state == TaskInstanceState.DEFERRED
    deferred_query = select(func.sum(TaskInstance.pool_slots)).where(in_this_pool, is_deferred)
    # The sum is empty (falsy) when nothing matches; report that as 0.
    slot_sum = session.scalar(deferred_query)
    return int(slot_sum or 0)
@provide_session
def open_slots(self, session: Session = NEW_SESSION) -> float:
    """
    Get the number of slots still open in this pool.

    :param session: SQLAlchemy ORM Session
    :return: ``float("inf")`` when ``slots`` is -1, otherwise ``slots`` minus the occupied slots
    """
    has_no_limit = self.slots == -1
    if has_no_limit:
        return float("inf")

    remaining = self.slots - self.occupied_slots(session)
    return remaining
|