variable_command.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. #
  2. # Licensed to the Apache Software Foundation (ASF) under one
  3. # or more contributor license agreements. See the NOTICE file
  4. # distributed with this work for additional information
  5. # regarding copyright ownership. The ASF licenses this file
  6. # to you under the Apache License, Version 2.0 (the
  7. # "License"); you may not use this file except in compliance
  8. # with the License. You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing,
  13. # software distributed under the License is distributed on an
  14. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  15. # KIND, either express or implied. See the License for the
  16. # specific language governing permissions and limitations
  17. # under the License.
  18. """Variable subcommands."""
  19. from __future__ import annotations
  20. import json
  21. import os
  22. from json import JSONDecodeError
  23. from sqlalchemy import select
  24. from airflow.cli.simple_table import AirflowConsole
  25. from airflow.cli.utils import print_export_output
  26. from airflow.models import Variable
  27. from airflow.utils import cli as cli_utils
  28. from airflow.utils.cli import suppress_logs_and_warning
  29. from airflow.utils.providers_configuration_loader import providers_configuration_loaded
  30. from airflow.utils.session import create_session, provide_session
  31. @suppress_logs_and_warning
  32. @providers_configuration_loaded
  33. def variables_list(args):
  34. """Display all the variables."""
  35. with create_session() as session:
  36. variables = session.scalars(select(Variable)).all()
  37. AirflowConsole().print_as(data=variables, output=args.output, mapper=lambda x: {"key": x.key})
  38. @suppress_logs_and_warning
  39. @providers_configuration_loaded
  40. def variables_get(args):
  41. """Display variable by a given name."""
  42. try:
  43. if args.default is None:
  44. var = Variable.get(args.key, deserialize_json=args.json)
  45. print(var)
  46. else:
  47. var = Variable.get(args.key, deserialize_json=args.json, default_var=args.default)
  48. print(var)
  49. except (ValueError, KeyError) as e:
  50. raise SystemExit(str(e).strip("'\""))
  51. @cli_utils.action_cli
  52. @providers_configuration_loaded
  53. def variables_set(args):
  54. """Create new variable with a given name, value and description."""
  55. Variable.set(args.key, args.value, args.description, serialize_json=args.json)
  56. print(f"Variable {args.key} created")
  57. @cli_utils.action_cli
  58. @providers_configuration_loaded
  59. def variables_delete(args):
  60. """Delete variable by a given name."""
  61. Variable.delete(args.key)
  62. print(f"Variable {args.key} deleted")
  63. @cli_utils.action_cli
  64. @providers_configuration_loaded
  65. @provide_session
  66. def variables_import(args, session):
  67. """Import variables from a given file."""
  68. if not os.path.exists(args.file):
  69. raise SystemExit("Missing variables file.")
  70. with open(args.file) as varfile:
  71. try:
  72. var_json = json.load(varfile)
  73. except JSONDecodeError:
  74. raise SystemExit("Invalid variables file.")
  75. suc_count = fail_count = 0
  76. skipped = set()
  77. action_on_existing = args.action_on_existing_key
  78. existing_keys = set()
  79. if action_on_existing != "overwrite":
  80. existing_keys = set(session.scalars(select(Variable.key).where(Variable.key.in_(var_json))))
  81. if action_on_existing == "fail" and existing_keys:
  82. raise SystemExit(f"Failed. These keys: {sorted(existing_keys)} already exists.")
  83. for k, v in var_json.items():
  84. if action_on_existing == "skip" and k in existing_keys:
  85. skipped.add(k)
  86. continue
  87. try:
  88. Variable.set(k, v, serialize_json=not isinstance(v, str))
  89. except Exception as e:
  90. print(f"Variable import failed: {e!r}")
  91. fail_count += 1
  92. else:
  93. suc_count += 1
  94. print(f"{suc_count} of {len(var_json)} variables successfully updated.")
  95. if fail_count:
  96. print(f"{fail_count} variable(s) failed to be updated.")
  97. if skipped:
  98. print(
  99. f"The variables with these keys: {list(sorted(skipped))} "
  100. f"were skipped because they already exists"
  101. )
  102. @providers_configuration_loaded
  103. def variables_export(args):
  104. """Export all the variables to the file."""
  105. var_dict = {}
  106. with create_session() as session:
  107. qry = session.scalars(select(Variable))
  108. data = json.JSONDecoder()
  109. for var in qry:
  110. try:
  111. val = data.decode(var.val)
  112. except Exception:
  113. val = var.val
  114. var_dict[var.key] = val
  115. with args.file as varfile:
  116. json.dump(var_dict, varfile, sort_keys=True, indent=4)
  117. print_export_output("Variables", var_dict, varfile)