pool.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353
  1. #
  2. # Licensed to the Apache Software Foundation (ASF) under one
  3. # or more contributor license agreements. See the NOTICE file
  4. # distributed with this work for additional information
  5. # regarding copyright ownership. The ASF licenses this file
  6. # to you under the Apache License, Version 2.0 (the
  7. # "License"); you may not use this file except in compliance
  8. # with the License. You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing,
  13. # software distributed under the License is distributed on an
  14. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  15. # KIND, either express or implied. See the License for the
  16. # specific language governing permissions and limitations
  17. # under the License.
  18. from __future__ import annotations
  19. from typing import TYPE_CHECKING, Any
  20. from sqlalchemy import Boolean, Column, Integer, String, Text, func, select
  21. from airflow.exceptions import AirflowException, PoolNotFound
  22. from airflow.models.base import Base
  23. from airflow.ti_deps.dependencies_states import EXECUTION_STATES
  24. from airflow.typing_compat import TypedDict
  25. from airflow.utils.db import exists_query
  26. from airflow.utils.session import NEW_SESSION, provide_session
  27. from airflow.utils.sqlalchemy import with_row_locks
  28. from airflow.utils.state import TaskInstanceState
  29. if TYPE_CHECKING:
  30. from sqlalchemy.orm.session import Session
  31. class PoolStats(TypedDict):
  32. """Dictionary containing Pool Stats."""
  33. total: int
  34. running: int
  35. deferred: int
  36. queued: int
  37. open: int
  38. scheduled: int
  39. class Pool(Base):
  40. """the class to get Pool info."""
  41. __tablename__ = "slot_pool"
  42. id = Column(Integer, primary_key=True)
  43. pool = Column(String(256), unique=True)
  44. # -1 for infinite
  45. slots = Column(Integer, default=0)
  46. description = Column(Text)
  47. include_deferred = Column(Boolean, nullable=False)
  48. DEFAULT_POOL_NAME = "default_pool"
  49. def __repr__(self):
  50. return str(self.pool)
  51. @staticmethod
  52. @provide_session
  53. def get_pools(session: Session = NEW_SESSION) -> list[Pool]:
  54. """Get all pools."""
  55. return session.scalars(select(Pool)).all()
  56. @staticmethod
  57. @provide_session
  58. def get_pool(pool_name: str, session: Session = NEW_SESSION) -> Pool | None:
  59. """
  60. Get the Pool with specific pool name from the Pools.
  61. :param pool_name: The pool name of the Pool to get.
  62. :param session: SQLAlchemy ORM Session
  63. :return: the pool object
  64. """
  65. return session.scalar(select(Pool).where(Pool.pool == pool_name))
  66. @staticmethod
  67. @provide_session
  68. def get_default_pool(session: Session = NEW_SESSION) -> Pool | None:
  69. """
  70. Get the Pool of the default_pool from the Pools.
  71. :param session: SQLAlchemy ORM Session
  72. :return: the pool object
  73. """
  74. return Pool.get_pool(Pool.DEFAULT_POOL_NAME, session=session)
  75. @staticmethod
  76. @provide_session
  77. def is_default_pool(id: int, session: Session = NEW_SESSION) -> bool:
  78. """
  79. Check id if is the default_pool.
  80. :param id: pool id
  81. :param session: SQLAlchemy ORM Session
  82. :return: True if id is default_pool, otherwise False
  83. """
  84. return exists_query(
  85. Pool.id == id,
  86. Pool.pool == Pool.DEFAULT_POOL_NAME,
  87. session=session,
  88. )
  89. @staticmethod
  90. @provide_session
  91. def create_or_update_pool(
  92. name: str,
  93. slots: int,
  94. description: str,
  95. include_deferred: bool,
  96. session: Session = NEW_SESSION,
  97. ) -> Pool:
  98. """Create a pool with given parameters or update it if it already exists."""
  99. if not name:
  100. raise ValueError("Pool name must not be empty")
  101. pool = session.scalar(select(Pool).filter_by(pool=name))
  102. if pool is None:
  103. pool = Pool(pool=name, slots=slots, description=description, include_deferred=include_deferred)
  104. session.add(pool)
  105. else:
  106. pool.slots = slots
  107. pool.description = description
  108. pool.include_deferred = include_deferred
  109. session.commit()
  110. return pool
  111. @staticmethod
  112. @provide_session
  113. def delete_pool(name: str, session: Session = NEW_SESSION) -> Pool:
  114. """Delete pool by a given name."""
  115. if name == Pool.DEFAULT_POOL_NAME:
  116. raise AirflowException(f"{Pool.DEFAULT_POOL_NAME} cannot be deleted")
  117. pool = session.scalar(select(Pool).filter_by(pool=name))
  118. if pool is None:
  119. raise PoolNotFound(f"Pool '{name}' doesn't exist")
  120. session.delete(pool)
  121. session.commit()
  122. return pool
  123. @staticmethod
  124. @provide_session
  125. def slots_stats(
  126. *,
  127. lock_rows: bool = False,
  128. session: Session = NEW_SESSION,
  129. ) -> dict[str, PoolStats]:
  130. """
  131. Get Pool stats (Number of Running, Queued, Open & Total tasks).
  132. If ``lock_rows`` is True, and the database engine in use supports the ``NOWAIT`` syntax, then a
  133. non-blocking lock will be attempted -- if the lock is not available then SQLAlchemy will throw an
  134. OperationalError.
  135. :param lock_rows: Should we attempt to obtain a row-level lock on all the Pool rows returns
  136. :param session: SQLAlchemy ORM Session
  137. """
  138. from airflow.models.taskinstance import TaskInstance # Avoid circular import
  139. pools: dict[str, PoolStats] = {}
  140. pool_includes_deferred: dict[str, bool] = {}
  141. query = select(Pool.pool, Pool.slots, Pool.include_deferred)
  142. if lock_rows:
  143. query = with_row_locks(query, session=session, nowait=True)
  144. pool_rows = session.execute(query)
  145. for pool_name, total_slots, include_deferred in pool_rows:
  146. if total_slots == -1:
  147. total_slots = float("inf") # type: ignore
  148. pools[pool_name] = PoolStats(
  149. total=total_slots, running=0, queued=0, open=0, deferred=0, scheduled=0
  150. )
  151. pool_includes_deferred[pool_name] = include_deferred
  152. allowed_execution_states = EXECUTION_STATES | {
  153. TaskInstanceState.DEFERRED,
  154. TaskInstanceState.SCHEDULED,
  155. }
  156. state_count_by_pool = session.execute(
  157. select(TaskInstance.pool, TaskInstance.state, func.sum(TaskInstance.pool_slots))
  158. .filter(TaskInstance.state.in_(allowed_execution_states))
  159. .group_by(TaskInstance.pool, TaskInstance.state)
  160. )
  161. # calculate queued and running metrics
  162. for pool_name, state, count in state_count_by_pool:
  163. # Some databases return decimal.Decimal here.
  164. count = int(count)
  165. stats_dict: PoolStats | None = pools.get(pool_name)
  166. if not stats_dict:
  167. continue
  168. # TypedDict key must be a string literal, so we use if-statements to set value
  169. if state == TaskInstanceState.RUNNING:
  170. stats_dict["running"] = count
  171. elif state == TaskInstanceState.QUEUED:
  172. stats_dict["queued"] = count
  173. elif state == TaskInstanceState.DEFERRED:
  174. stats_dict["deferred"] = count
  175. elif state == TaskInstanceState.SCHEDULED:
  176. stats_dict["scheduled"] = count
  177. else:
  178. raise AirflowException(f"Unexpected state. Expected values: {allowed_execution_states}.")
  179. # calculate open metric
  180. for pool_name, stats_dict in pools.items():
  181. stats_dict["open"] = stats_dict["total"] - stats_dict["running"] - stats_dict["queued"]
  182. if pool_includes_deferred[pool_name]:
  183. stats_dict["open"] -= stats_dict["deferred"]
  184. return pools
  185. def to_json(self) -> dict[str, Any]:
  186. """
  187. Get the Pool in a json structure.
  188. :return: the pool object in json format
  189. """
  190. return {
  191. "id": self.id,
  192. "pool": self.pool,
  193. "slots": self.slots,
  194. "description": self.description,
  195. "include_deferred": self.include_deferred,
  196. }
  197. @provide_session
  198. def occupied_slots(self, session: Session = NEW_SESSION) -> int:
  199. """
  200. Get the number of slots used by running/queued tasks at the moment.
  201. :param session: SQLAlchemy ORM Session
  202. :return: the used number of slots
  203. """
  204. from airflow.models.taskinstance import TaskInstance # Avoid circular import
  205. occupied_states = self.get_occupied_states()
  206. return int(
  207. session.scalar(
  208. select(func.sum(TaskInstance.pool_slots))
  209. .filter(TaskInstance.pool == self.pool)
  210. .filter(TaskInstance.state.in_(occupied_states))
  211. )
  212. or 0
  213. )
  214. def get_occupied_states(self):
  215. if self.include_deferred:
  216. return EXECUTION_STATES | {
  217. TaskInstanceState.DEFERRED,
  218. }
  219. return EXECUTION_STATES
  220. @provide_session
  221. def running_slots(self, session: Session = NEW_SESSION) -> int:
  222. """
  223. Get the number of slots used by running tasks at the moment.
  224. :param session: SQLAlchemy ORM Session
  225. :return: the used number of slots
  226. """
  227. from airflow.models.taskinstance import TaskInstance # Avoid circular import
  228. return int(
  229. session.scalar(
  230. select(func.sum(TaskInstance.pool_slots))
  231. .filter(TaskInstance.pool == self.pool)
  232. .filter(TaskInstance.state == TaskInstanceState.RUNNING)
  233. )
  234. or 0
  235. )
  236. @provide_session
  237. def queued_slots(self, session: Session = NEW_SESSION) -> int:
  238. """
  239. Get the number of slots used by queued tasks at the moment.
  240. :param session: SQLAlchemy ORM Session
  241. :return: the used number of slots
  242. """
  243. from airflow.models.taskinstance import TaskInstance # Avoid circular import
  244. return int(
  245. session.scalar(
  246. select(func.sum(TaskInstance.pool_slots))
  247. .filter(TaskInstance.pool == self.pool)
  248. .filter(TaskInstance.state == TaskInstanceState.QUEUED)
  249. )
  250. or 0
  251. )
  252. @provide_session
  253. def scheduled_slots(self, session: Session = NEW_SESSION) -> int:
  254. """
  255. Get the number of slots scheduled at the moment.
  256. :param session: SQLAlchemy ORM Session
  257. :return: the number of scheduled slots
  258. """
  259. from airflow.models.taskinstance import TaskInstance # Avoid circular import
  260. return int(
  261. session.scalar(
  262. select(func.sum(TaskInstance.pool_slots))
  263. .filter(TaskInstance.pool == self.pool)
  264. .filter(TaskInstance.state == TaskInstanceState.SCHEDULED)
  265. )
  266. or 0
  267. )
  268. @provide_session
  269. def deferred_slots(self, session: Session = NEW_SESSION) -> int:
  270. """
  271. Get the number of slots deferred at the moment.
  272. :param session: SQLAlchemy ORM Session
  273. :return: the number of deferred slots
  274. """
  275. from airflow.models.taskinstance import TaskInstance # Avoid circular import
  276. return int(
  277. session.scalar(
  278. select(func.sum(TaskInstance.pool_slots)).where(
  279. TaskInstance.pool == self.pool, TaskInstance.state == TaskInstanceState.DEFERRED
  280. )
  281. )
  282. or 0
  283. )
  284. @provide_session
  285. def open_slots(self, session: Session = NEW_SESSION) -> float:
  286. """
  287. Get the number of slots open at the moment.
  288. :param session: SQLAlchemy ORM Session
  289. :return: the number of slots
  290. """
  291. if self.slots == -1:
  292. return float("inf")
  293. return self.slots - self.occupied_slots(session)