pool.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. #
  2. # Licensed to the Apache Software Foundation (ASF) under one
  3. # or more contributor license agreements. See the NOTICE file
  4. # distributed with this work for additional information
  5. # regarding copyright ownership. The ASF licenses this file
  6. # to you under the Apache License, Version 2.0 (the
  7. # "License"); you may not use this file except in compliance
  8. # with the License. You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing,
  13. # software distributed under the License is distributed on an
  14. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  15. # KIND, either express or implied. See the License for the
  16. # specific language governing permissions and limitations
  17. # under the License.
  18. """Pool APIs."""
  19. from __future__ import annotations
  20. from typing import TYPE_CHECKING
  21. from deprecated import deprecated
  22. from sqlalchemy import select
  23. from airflow.exceptions import AirflowBadRequest, PoolNotFound, RemovedInAirflow3Warning
  24. from airflow.models import Pool
  25. from airflow.utils.session import NEW_SESSION, provide_session
  26. if TYPE_CHECKING:
  27. from sqlalchemy.orm import Session
  28. @deprecated(reason="Use Pool.get_pool() instead", version="2.2.4", category=RemovedInAirflow3Warning)
  29. @provide_session
  30. def get_pool(name, session: Session = NEW_SESSION):
  31. """Get pool by a given name."""
  32. if not (name and name.strip()):
  33. raise AirflowBadRequest("Pool name shouldn't be empty")
  34. pool = session.scalar(select(Pool).filter_by(pool=name).limit(1))
  35. if pool is None:
  36. raise PoolNotFound(f"Pool '{name}' doesn't exist")
  37. return pool
  38. @deprecated(reason="Use Pool.get_pools() instead", version="2.2.4", category=RemovedInAirflow3Warning)
  39. @provide_session
  40. def get_pools(session: Session = NEW_SESSION):
  41. """Get all pools."""
  42. return session.scalars(select(Pool)).all()
  43. @deprecated(reason="Use Pool.create_pool() instead", version="2.2.4", category=RemovedInAirflow3Warning)
  44. @provide_session
  45. def create_pool(name, slots, description, session: Session = NEW_SESSION):
  46. """Create a pool with given parameters."""
  47. if not (name and name.strip()):
  48. raise AirflowBadRequest("Pool name shouldn't be empty")
  49. try:
  50. slots = int(slots)
  51. except ValueError:
  52. raise AirflowBadRequest(f"Bad value for `slots`: {slots}")
  53. # Get the length of the pool column
  54. pool_name_length = Pool.pool.property.columns[0].type.length
  55. if len(name) > pool_name_length:
  56. raise AirflowBadRequest(f"Pool name can't be more than {pool_name_length} characters")
  57. session.expire_on_commit = False
  58. pool = session.scalar(select(Pool).filter_by(pool=name).limit(1))
  59. if pool is None:
  60. pool = Pool(pool=name, slots=slots, description=description, include_deferred=False)
  61. session.add(pool)
  62. else:
  63. pool.slots = slots
  64. pool.description = description
  65. session.commit()
  66. return pool
  67. @deprecated(reason="Use Pool.delete_pool() instead", version="2.2.4", category=RemovedInAirflow3Warning)
  68. @provide_session
  69. def delete_pool(name, session: Session = NEW_SESSION):
  70. """Delete pool by a given name."""
  71. if not (name and name.strip()):
  72. raise AirflowBadRequest("Pool name shouldn't be empty")
  73. if name == Pool.DEFAULT_POOL_NAME:
  74. raise AirflowBadRequest(f"{Pool.DEFAULT_POOL_NAME} cannot be deleted")
  75. pool = session.scalar(select(Pool).filter_by(pool=name).limit(1))
  76. if pool is None:
  77. raise PoolNotFound(f"Pool '{name}' doesn't exist")
  78. session.delete(pool)
  79. session.commit()
  80. return pool