pool_command.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. #
  2. # Licensed to the Apache Software Foundation (ASF) under one
  3. # or more contributor license agreements. See the NOTICE file
  4. # distributed with this work for additional information
  5. # regarding copyright ownership. The ASF licenses this file
  6. # to you under the Apache License, Version 2.0 (the
  7. # "License"); you may not use this file except in compliance
  8. # with the License. You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing,
  13. # software distributed under the License is distributed on an
  14. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  15. # KIND, either express or implied. See the License for the
  16. # specific language governing permissions and limitations
  17. # under the License.
  18. """Pools sub-commands."""
  19. from __future__ import annotations
  20. import json
  21. import os
  22. from json import JSONDecodeError
  23. from airflow.api.client import get_current_api_client
  24. from airflow.cli.simple_table import AirflowConsole
  25. from airflow.exceptions import PoolNotFound
  26. from airflow.utils import cli as cli_utils
  27. from airflow.utils.cli import suppress_logs_and_warning
  28. from airflow.utils.providers_configuration_loader import providers_configuration_loaded
  29. def _show_pools(pools, output):
  30. AirflowConsole().print_as(
  31. data=pools,
  32. output=output,
  33. mapper=lambda x: {
  34. "pool": x[0],
  35. "slots": x[1],
  36. "description": x[2],
  37. "include_deferred": x[3],
  38. },
  39. )
  40. @suppress_logs_and_warning
  41. @providers_configuration_loaded
  42. def pool_list(args):
  43. """Display info of all the pools."""
  44. api_client = get_current_api_client()
  45. pools = api_client.get_pools()
  46. _show_pools(pools=pools, output=args.output)
  47. @suppress_logs_and_warning
  48. @providers_configuration_loaded
  49. def pool_get(args):
  50. """Display pool info by a given name."""
  51. api_client = get_current_api_client()
  52. try:
  53. pools = [api_client.get_pool(name=args.pool)]
  54. _show_pools(pools=pools, output=args.output)
  55. except PoolNotFound:
  56. raise SystemExit(f"Pool {args.pool} does not exist")
  57. @cli_utils.action_cli
  58. @suppress_logs_and_warning
  59. @providers_configuration_loaded
  60. def pool_set(args):
  61. """Create new pool with a given name and slots."""
  62. api_client = get_current_api_client()
  63. api_client.create_pool(
  64. name=args.pool, slots=args.slots, description=args.description, include_deferred=args.include_deferred
  65. )
  66. print(f"Pool {args.pool} created")
  67. @cli_utils.action_cli
  68. @suppress_logs_and_warning
  69. @providers_configuration_loaded
  70. def pool_delete(args):
  71. """Delete pool by a given name."""
  72. api_client = get_current_api_client()
  73. try:
  74. api_client.delete_pool(name=args.pool)
  75. print(f"Pool {args.pool} deleted")
  76. except PoolNotFound:
  77. raise SystemExit(f"Pool {args.pool} does not exist")
  78. @cli_utils.action_cli
  79. @suppress_logs_and_warning
  80. @providers_configuration_loaded
  81. def pool_import(args):
  82. """Import pools from the file."""
  83. if not os.path.exists(args.file):
  84. raise SystemExit(f"Missing pools file {args.file}")
  85. pools, failed = pool_import_helper(args.file)
  86. if failed:
  87. raise SystemExit(f"Failed to update pool(s): {', '.join(failed)}")
  88. print(f"Uploaded {len(pools)} pool(s)")
  89. @providers_configuration_loaded
  90. def pool_export(args):
  91. """Export all the pools to the file."""
  92. pools = pool_export_helper(args.file)
  93. print(f"Exported {len(pools)} pools to {args.file}")
  94. def pool_import_helper(filepath):
  95. """Help import pools from the json file."""
  96. api_client = get_current_api_client()
  97. with open(filepath) as poolfile:
  98. data = poolfile.read()
  99. try:
  100. pools_json = json.loads(data)
  101. except JSONDecodeError as e:
  102. raise SystemExit(f"Invalid json file: {e}")
  103. pools = []
  104. failed = []
  105. for k, v in pools_json.items():
  106. if isinstance(v, dict) and "slots" in v and "description" in v:
  107. pools.append(
  108. api_client.create_pool(
  109. name=k,
  110. slots=v["slots"],
  111. description=v["description"],
  112. include_deferred=v.get("include_deferred", False),
  113. )
  114. )
  115. else:
  116. failed.append(k)
  117. return pools, failed
  118. def pool_export_helper(filepath):
  119. """Help export all the pools to the json file."""
  120. api_client = get_current_api_client()
  121. pool_dict = {}
  122. pools = api_client.get_pools()
  123. for pool in pools:
  124. pool_dict[pool[0]] = {"slots": pool[1], "description": pool[2], "include_deferred": pool[3]}
  125. with open(filepath, "w") as poolfile:
  126. poolfile.write(json.dumps(pool_dict, sort_keys=True, indent=4))
  127. return pools