connection_command.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382
  1. # Licensed to the Apache Software Foundation (ASF) under one
  2. # or more contributor license agreements. See the NOTICE file
  3. # distributed with this work for additional information
  4. # regarding copyright ownership. The ASF licenses this file
  5. # to you under the Apache License, Version 2.0 (the
  6. # "License"); you may not use this file except in compliance
  7. # with the License. You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing,
  12. # software distributed under the License is distributed on an
  13. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  14. # KIND, either express or implied. See the License for the
  15. # specific language governing permissions and limitations
  16. # under the License.
  17. """Connection sub-commands."""
  18. from __future__ import annotations
  19. import json
  20. import os
  21. import warnings
  22. from pathlib import Path
  23. from typing import Any
  24. from urllib.parse import urlsplit, urlunsplit
  25. from sqlalchemy import select
  26. from sqlalchemy.orm import exc
  27. from airflow.cli.simple_table import AirflowConsole
  28. from airflow.cli.utils import is_stdout, print_export_output
  29. from airflow.compat.functools import cache
  30. from airflow.configuration import conf
  31. from airflow.exceptions import AirflowNotFoundException
  32. from airflow.hooks.base import BaseHook
  33. from airflow.models import Connection
  34. from airflow.providers_manager import ProvidersManager
  35. from airflow.secrets.local_filesystem import load_connections_dict
  36. from airflow.utils import cli as cli_utils, helpers, yaml
  37. from airflow.utils.cli import suppress_logs_and_warning
  38. from airflow.utils.db import create_default_connections as db_create_default_connections
  39. from airflow.utils.providers_configuration_loader import providers_configuration_loaded
  40. from airflow.utils.session import create_session
  41. def _connection_mapper(conn: Connection) -> dict[str, Any]:
  42. return {
  43. "id": conn.id,
  44. "conn_id": conn.conn_id,
  45. "conn_type": conn.conn_type,
  46. "description": conn.description,
  47. "host": conn.host,
  48. "schema": conn.schema,
  49. "login": conn.login,
  50. "password": conn.password,
  51. "port": conn.port,
  52. "is_encrypted": conn.is_encrypted,
  53. "is_extra_encrypted": conn.is_encrypted,
  54. "extra_dejson": conn.extra_dejson,
  55. "get_uri": conn.get_uri(),
  56. }
  57. @suppress_logs_and_warning
  58. @providers_configuration_loaded
  59. def connections_get(args):
  60. """Get a connection."""
  61. try:
  62. conn = BaseHook.get_connection(args.conn_id)
  63. except AirflowNotFoundException:
  64. raise SystemExit("Connection not found.")
  65. AirflowConsole().print_as(
  66. data=[conn],
  67. output=args.output,
  68. mapper=_connection_mapper,
  69. )
  70. @suppress_logs_and_warning
  71. @providers_configuration_loaded
  72. def connections_list(args):
  73. """List all connections at the command line."""
  74. with create_session() as session:
  75. query = select(Connection)
  76. if args.conn_id:
  77. query = query.where(Connection.conn_id == args.conn_id)
  78. query = session.scalars(query)
  79. conns = query.all()
  80. AirflowConsole().print_as(
  81. data=conns,
  82. output=args.output,
  83. mapper=_connection_mapper,
  84. )
  85. def _connection_to_dict(conn: Connection) -> dict:
  86. return {
  87. "conn_type": conn.conn_type,
  88. "description": conn.description,
  89. "login": conn.login,
  90. "password": conn.password,
  91. "host": conn.host,
  92. "port": conn.port,
  93. "schema": conn.schema,
  94. "extra": conn.extra,
  95. }
  96. def create_default_connections(args):
  97. db_create_default_connections()
  98. def _format_connections(conns: list[Connection], file_format: str, serialization_format: str) -> str:
  99. if serialization_format == "json":
  100. def serializer_func(x):
  101. return json.dumps(_connection_to_dict(x))
  102. elif serialization_format == "uri":
  103. serializer_func = Connection.get_uri
  104. else:
  105. raise SystemExit(f"Received unexpected value for `--serialization-format`: {serialization_format!r}")
  106. if file_format == ".env":
  107. connections_env = ""
  108. for conn in conns:
  109. connections_env += f"{conn.conn_id}={serializer_func(conn)}\n"
  110. return connections_env
  111. connections_dict = {}
  112. for conn in conns:
  113. connections_dict[conn.conn_id] = _connection_to_dict(conn)
  114. if file_format == ".yaml":
  115. return yaml.dump(connections_dict)
  116. if file_format == ".json":
  117. return json.dumps(connections_dict, indent=2)
  118. return json.dumps(connections_dict)
  119. def _valid_uri(uri: str) -> bool:
  120. """Check if a URI is valid, by checking if scheme (conn_type) provided."""
  121. return urlsplit(uri).scheme != ""
  122. @cache
  123. def _get_connection_types() -> list[str]:
  124. """Return connection types available."""
  125. _connection_types = []
  126. providers_manager = ProvidersManager()
  127. for connection_type, provider_info in providers_manager.hooks.items():
  128. if provider_info:
  129. _connection_types.append(connection_type)
  130. return _connection_types
  131. @providers_configuration_loaded
  132. def connections_export(args):
  133. """Export all connections to a file."""
  134. file_formats = [".yaml", ".json", ".env"]
  135. if args.format:
  136. warnings.warn(
  137. "Option `--format` is deprecated. Use `--file-format` instead.", DeprecationWarning, stacklevel=3
  138. )
  139. if args.format and args.file_format:
  140. raise SystemExit("Option `--format` is deprecated. Use `--file-format` instead.")
  141. default_format = ".json"
  142. provided_file_format = None
  143. if args.format or args.file_format:
  144. provided_file_format = f".{(args.format or args.file_format).lower()}"
  145. with args.file as f:
  146. if is_stdout(f):
  147. filetype = provided_file_format or default_format
  148. elif provided_file_format:
  149. filetype = provided_file_format
  150. else:
  151. filetype = Path(args.file.name).suffix.lower()
  152. if filetype not in file_formats:
  153. raise SystemExit(
  154. f"Unsupported file format. The file must have the extension {', '.join(file_formats)}."
  155. )
  156. if args.serialization_format and filetype != ".env":
  157. raise SystemExit("Option `--serialization-format` may only be used with file type `env`.")
  158. with create_session() as session:
  159. connections = session.scalars(select(Connection).order_by(Connection.conn_id)).all()
  160. msg = _format_connections(
  161. conns=connections,
  162. file_format=filetype,
  163. serialization_format=args.serialization_format or "uri",
  164. )
  165. f.write(msg)
  166. print_export_output("Connections", connections, f)
  167. alternative_conn_specs = ["conn_type", "conn_host", "conn_login", "conn_password", "conn_schema", "conn_port"]
  168. @cli_utils.action_cli
  169. @providers_configuration_loaded
  170. def connections_add(args):
  171. """Add new connection."""
  172. has_uri = bool(args.conn_uri)
  173. has_json = bool(args.conn_json)
  174. has_type = bool(args.conn_type)
  175. # Validate connection-id
  176. try:
  177. helpers.validate_key(args.conn_id, max_length=200)
  178. except Exception as e:
  179. raise SystemExit(f"Could not create connection. {e}")
  180. if not has_type and not (has_json or has_uri):
  181. raise SystemExit("Must supply either conn-uri or conn-json if not supplying conn-type")
  182. if has_json and has_uri:
  183. raise SystemExit("Cannot supply both conn-uri and conn-json")
  184. if has_type and args.conn_type not in _get_connection_types():
  185. warnings.warn(
  186. f"The type provided to --conn-type is invalid: {args.conn_type}", UserWarning, stacklevel=4
  187. )
  188. warnings.warn(
  189. f"Supported --conn-types are:{_get_connection_types()}."
  190. "Hence overriding the conn-type with generic",
  191. UserWarning,
  192. stacklevel=4,
  193. )
  194. args.conn_type = "generic"
  195. if has_uri or has_json:
  196. invalid_args = []
  197. if has_uri and not _valid_uri(args.conn_uri):
  198. raise SystemExit(f"The URI provided to --conn-uri is invalid: {args.conn_uri}")
  199. for arg in alternative_conn_specs:
  200. if getattr(args, arg) is not None:
  201. invalid_args.append(arg)
  202. if has_json and args.conn_extra:
  203. invalid_args.append("--conn-extra")
  204. if invalid_args:
  205. raise SystemExit(
  206. "The following args are not compatible with "
  207. f"the --conn-{'uri' if has_uri else 'json'} flag: {invalid_args!r}"
  208. )
  209. if args.conn_uri:
  210. new_conn = Connection(conn_id=args.conn_id, description=args.conn_description, uri=args.conn_uri)
  211. if args.conn_extra is not None:
  212. new_conn.set_extra(args.conn_extra)
  213. elif args.conn_json:
  214. new_conn = Connection.from_json(conn_id=args.conn_id, value=args.conn_json)
  215. if not new_conn.conn_type:
  216. raise SystemExit("conn-json is invalid; must supply conn-type")
  217. else:
  218. new_conn = Connection(
  219. conn_id=args.conn_id,
  220. conn_type=args.conn_type,
  221. description=args.conn_description,
  222. host=args.conn_host,
  223. login=args.conn_login,
  224. password=args.conn_password,
  225. schema=args.conn_schema,
  226. port=args.conn_port,
  227. )
  228. if args.conn_extra is not None:
  229. new_conn.set_extra(args.conn_extra)
  230. with create_session() as session:
  231. if not session.scalar(select(Connection).where(Connection.conn_id == new_conn.conn_id).limit(1)):
  232. session.add(new_conn)
  233. msg = "Successfully added `conn_id`={conn_id} : {uri}"
  234. msg = msg.format(
  235. conn_id=new_conn.conn_id,
  236. uri=args.conn_uri
  237. or urlunsplit(
  238. (
  239. new_conn.conn_type,
  240. f"{new_conn.login or ''}:{'******' if new_conn.password else ''}"
  241. f"@{new_conn.host or ''}:{new_conn.port or ''}",
  242. new_conn.schema or "",
  243. "",
  244. "",
  245. )
  246. ),
  247. )
  248. print(msg)
  249. else:
  250. msg = f"A connection with `conn_id`={new_conn.conn_id} already exists."
  251. raise SystemExit(msg)
  252. @cli_utils.action_cli
  253. @providers_configuration_loaded
  254. def connections_delete(args):
  255. """Delete connection from DB."""
  256. with create_session() as session:
  257. try:
  258. to_delete = session.scalars(select(Connection).where(Connection.conn_id == args.conn_id)).one()
  259. except exc.NoResultFound:
  260. raise SystemExit(f"Did not find a connection with `conn_id`={args.conn_id}")
  261. except exc.MultipleResultsFound:
  262. raise SystemExit(f"Found more than one connection with `conn_id`={args.conn_id}")
  263. else:
  264. session.delete(to_delete)
  265. print(f"Successfully deleted connection with `conn_id`={to_delete.conn_id}")
  266. @cli_utils.action_cli(check_db=False)
  267. @providers_configuration_loaded
  268. def connections_import(args):
  269. """Import connections from a file."""
  270. if os.path.exists(args.file):
  271. _import_helper(args.file, args.overwrite)
  272. else:
  273. raise SystemExit("Missing connections file.")
  274. def _import_helper(file_path: str, overwrite: bool) -> None:
  275. """
  276. Load connections from a file and save them to the DB.
  277. :param overwrite: Whether to skip or overwrite on collision.
  278. """
  279. connections_dict = load_connections_dict(file_path)
  280. with create_session() as session:
  281. for conn_id, conn in connections_dict.items():
  282. try:
  283. helpers.validate_key(conn_id, max_length=200)
  284. except Exception as e:
  285. print(f"Could not import connection. {e}")
  286. continue
  287. existing_conn_id = session.scalar(select(Connection.id).where(Connection.conn_id == conn_id))
  288. if existing_conn_id is not None:
  289. if not overwrite:
  290. print(f"Could not import connection {conn_id}: connection already exists.")
  291. continue
  292. # The conn_ids match, but the PK of the new entry must also be the same as the old
  293. conn.id = existing_conn_id
  294. session.merge(conn)
  295. session.commit()
  296. print(f"Imported connection {conn_id}")
  297. @suppress_logs_and_warning
  298. @providers_configuration_loaded
  299. def connections_test(args) -> None:
  300. """Test an Airflow connection."""
  301. console = AirflowConsole()
  302. if conf.get("core", "test_connection", fallback="Disabled").lower().strip() != "enabled":
  303. console.print(
  304. "[bold yellow]\nTesting connections is disabled in Airflow configuration. "
  305. "Contact your deployment admin to enable it.\n"
  306. )
  307. raise SystemExit(1)
  308. print(f"Retrieving connection: {args.conn_id!r}")
  309. try:
  310. conn = BaseHook.get_connection(args.conn_id)
  311. except AirflowNotFoundException:
  312. console.print("[bold yellow]\nConnection not found.\n")
  313. raise SystemExit(1)
  314. print("\nTesting...")
  315. status, message = conn.test_connection()
  316. if status is True:
  317. console.print("[bold green]\nConnection success!\n")
  318. else:
  319. console.print(f"[bold][red]\nConnection failed![/bold]\n{message}\n")