connection.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588
  1. #
  2. # Licensed to the Apache Software Foundation (ASF) under one
  3. # or more contributor license agreements. See the NOTICE file
  4. # distributed with this work for additional information
  5. # regarding copyright ownership. The ASF licenses this file
  6. # to you under the Apache License, Version 2.0 (the
  7. # "License"); you may not use this file except in compliance
  8. # with the License. You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing,
  13. # software distributed under the License is distributed on an
  14. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  15. # KIND, either express or implied. See the License for the
  16. # specific language governing permissions and limitations
  17. # under the License.
  18. from __future__ import annotations
  19. import json
  20. import logging
  21. import warnings
  22. from contextlib import suppress
  23. from json import JSONDecodeError
  24. from typing import Any
  25. from urllib.parse import parse_qsl, quote, unquote, urlencode, urlsplit
  26. import re2
  27. from sqlalchemy import Boolean, Column, Integer, String, Text
  28. from sqlalchemy.orm import declared_attr, reconstructor, synonym
  29. from airflow.configuration import ensure_secrets_loaded
  30. from airflow.exceptions import AirflowException, AirflowNotFoundException, RemovedInAirflow3Warning
  31. from airflow.models.base import ID_LEN, Base
  32. from airflow.models.crypto import get_fernet
  33. from airflow.secrets.cache import SecretCache
  34. from airflow.utils.helpers import prune_dict
  35. from airflow.utils.log.logging_mixin import LoggingMixin
  36. from airflow.utils.log.secrets_masker import mask_secret
  37. from airflow.utils.module_loading import import_string
  38. log = logging.getLogger(__name__)
  39. # sanitize the `conn_id` pattern by allowing alphanumeric characters plus
  40. # the symbols #,!,-,_,.,:,\,/ and () requiring at least one match.
  41. #
  42. # You can try the regex here: https://regex101.com/r/69033B/1
  43. RE_SANITIZE_CONN_ID = re2.compile(r"^[\w\#\!\(\)\-\.\:\/\\]{1,}$")
  44. # the conn ID max len should be 250
  45. CONN_ID_MAX_LEN: int = 250
  46. def parse_netloc_to_hostname(*args, **kwargs):
  47. """Do not use, this method is deprecated."""
  48. warnings.warn("This method is deprecated.", RemovedInAirflow3Warning, stacklevel=2)
  49. return _parse_netloc_to_hostname(*args, **kwargs)
  50. def sanitize_conn_id(conn_id: str | None, max_length=CONN_ID_MAX_LEN) -> str | None:
  51. r"""
  52. Sanitizes the connection id and allows only specific characters to be within.
  53. Namely, it allows alphanumeric characters plus the symbols #,!,-,_,.,:,\,/ and () from 1 and up to
  54. 250 consecutive matches. If desired, the max length can be adjusted by setting `max_length`.
  55. You can try to play with the regex here: https://regex101.com/r/69033B/1
  56. The character selection is such that it prevents the injection of javascript or
  57. executable bits to avoid any awkward behaviour in the front-end.
  58. :param conn_id: The connection id to sanitize.
  59. :param max_length: The max length of the connection ID, by default it is 250.
  60. :return: the sanitized string, `None` otherwise.
  61. """
  62. # check if `conn_id` or our match group is `None` and the `conn_id` is within the specified length.
  63. if (not isinstance(conn_id, str) or len(conn_id) > max_length) or (
  64. res := re2.match(RE_SANITIZE_CONN_ID, conn_id)
  65. ) is None:
  66. return None
  67. # if we reach here, then we matched something, return the first match
  68. return res.group(0)
  69. def _parse_netloc_to_hostname(uri_parts):
  70. """
  71. Parse a URI string to get the correct Hostname.
  72. ``urlparse(...).hostname`` or ``urlsplit(...).hostname`` returns value into the lowercase in most cases,
  73. there are some exclusion exists for specific cases such as https://bugs.python.org/issue32323
  74. In case if expected to get a path as part of hostname path,
  75. then default behavior ``urlparse``/``urlsplit`` is unexpected.
  76. """
  77. hostname = unquote(uri_parts.hostname or "")
  78. if "/" in hostname:
  79. hostname = uri_parts.netloc
  80. if "@" in hostname:
  81. hostname = hostname.rsplit("@", 1)[1]
  82. if ":" in hostname:
  83. hostname = hostname.split(":", 1)[0]
  84. hostname = unquote(hostname)
  85. return hostname
  86. class Connection(Base, LoggingMixin):
  87. """
  88. Placeholder to store information about different database instances connection information.
  89. The idea here is that scripts use references to database instances (conn_id)
  90. instead of hard coding hostname, logins and passwords when using operators or hooks.
  91. .. seealso::
  92. For more information on how to use this class, see: :doc:`/howto/connection`
  93. :param conn_id: The connection ID.
  94. :param conn_type: The connection type.
  95. :param description: The connection description.
  96. :param host: The host.
  97. :param login: The login.
  98. :param password: The password.
  99. :param schema: The schema.
  100. :param port: The port number.
  101. :param extra: Extra metadata. Non-standard data such as private/SSH keys can be saved here. JSON
  102. encoded object.
  103. :param uri: URI address describing connection parameters.
  104. """
  105. EXTRA_KEY = "__extra__"
  106. __tablename__ = "connection"
  107. id = Column(Integer(), primary_key=True)
  108. conn_id = Column(String(ID_LEN), unique=True, nullable=False)
  109. conn_type = Column(String(500), nullable=False)
  110. description = Column(Text().with_variant(Text(5000), "mysql").with_variant(String(5000), "sqlite"))
  111. host = Column(String(500))
  112. schema = Column(String(500))
  113. login = Column(Text())
  114. _password = Column("password", Text())
  115. port = Column(Integer())
  116. is_encrypted = Column(Boolean, unique=False, default=False)
  117. is_extra_encrypted = Column(Boolean, unique=False, default=False)
  118. _extra = Column("extra", Text())
  119. def __init__(
  120. self,
  121. conn_id: str | None = None,
  122. conn_type: str | None = None,
  123. description: str | None = None,
  124. host: str | None = None,
  125. login: str | None = None,
  126. password: str | None = None,
  127. schema: str | None = None,
  128. port: int | None = None,
  129. extra: str | dict | None = None,
  130. uri: str | None = None,
  131. ):
  132. super().__init__()
  133. self.conn_id = sanitize_conn_id(conn_id)
  134. self.description = description
  135. if extra and not isinstance(extra, str):
  136. extra = json.dumps(extra)
  137. if uri and (conn_type or host or login or password or schema or port or extra):
  138. raise AirflowException(
  139. "You must create an object using the URI or individual values "
  140. "(conn_type, host, login, password, schema, port or extra)."
  141. "You can't mix these two ways to create this object."
  142. )
  143. if uri:
  144. self._parse_from_uri(uri)
  145. else:
  146. self.conn_type = conn_type
  147. self.host = host
  148. self.login = login
  149. self.password = password
  150. self.schema = schema
  151. self.port = port
  152. self.extra = extra
  153. if self.extra:
  154. self._validate_extra(self.extra, self.conn_id)
  155. if self.password:
  156. mask_secret(self.password)
  157. mask_secret(quote(self.password))
  158. @staticmethod
  159. def _validate_extra(extra, conn_id) -> None:
  160. """
  161. Verify that ``extra`` is a JSON-encoded Python dict.
  162. From Airflow 3.0, we should no longer suppress these errors but raise instead.
  163. """
  164. if extra is None:
  165. return None
  166. try:
  167. extra_parsed = json.loads(extra)
  168. if not isinstance(extra_parsed, dict):
  169. warnings.warn(
  170. "Encountered JSON value in `extra` which does not parse as a dictionary in "
  171. f"connection {conn_id!r}. From Airflow 3.0, the `extra` field must contain a JSON "
  172. "representation of a Python dict.",
  173. RemovedInAirflow3Warning,
  174. stacklevel=3,
  175. )
  176. except json.JSONDecodeError:
  177. warnings.warn(
  178. f"Encountered non-JSON in `extra` field for connection {conn_id!r}. Support for "
  179. "non-JSON `extra` will be removed in Airflow 3.0",
  180. RemovedInAirflow3Warning,
  181. stacklevel=2,
  182. )
  183. return None
  184. @reconstructor
  185. def on_db_load(self):
  186. if self.password:
  187. mask_secret(self.password)
  188. mask_secret(quote(self.password))
  189. def parse_from_uri(self, **uri):
  190. """Use uri parameter in constructor, this method is deprecated."""
  191. warnings.warn(
  192. "This method is deprecated. Please use uri parameter in constructor.",
  193. RemovedInAirflow3Warning,
  194. stacklevel=2,
  195. )
  196. self._parse_from_uri(**uri)
  197. @staticmethod
  198. def _normalize_conn_type(conn_type):
  199. if conn_type == "postgresql":
  200. conn_type = "postgres"
  201. elif "-" in conn_type:
  202. conn_type = conn_type.replace("-", "_")
  203. return conn_type
  204. def _parse_from_uri(self, uri: str):
  205. schemes_count_in_uri = uri.count("://")
  206. if schemes_count_in_uri > 2:
  207. raise AirflowException(f"Invalid connection string: {uri}.")
  208. host_with_protocol = schemes_count_in_uri == 2
  209. uri_parts = urlsplit(uri)
  210. conn_type = uri_parts.scheme
  211. self.conn_type = self._normalize_conn_type(conn_type)
  212. rest_of_the_url = uri.replace(f"{conn_type}://", ("" if host_with_protocol else "//"))
  213. if host_with_protocol:
  214. uri_splits = rest_of_the_url.split("://", 1)
  215. if "@" in uri_splits[0] or ":" in uri_splits[0]:
  216. raise AirflowException(f"Invalid connection string: {uri}.")
  217. uri_parts = urlsplit(rest_of_the_url)
  218. protocol = uri_parts.scheme if host_with_protocol else None
  219. host = _parse_netloc_to_hostname(uri_parts)
  220. self.host = self._create_host(protocol, host)
  221. quoted_schema = uri_parts.path[1:]
  222. self.schema = unquote(quoted_schema) if quoted_schema else quoted_schema
  223. self.login = unquote(uri_parts.username) if uri_parts.username else uri_parts.username
  224. self.password = unquote(uri_parts.password) if uri_parts.password else uri_parts.password
  225. self.port = uri_parts.port
  226. if uri_parts.query:
  227. query = dict(parse_qsl(uri_parts.query, keep_blank_values=True))
  228. if self.EXTRA_KEY in query:
  229. self.extra = query[self.EXTRA_KEY]
  230. else:
  231. self.extra = json.dumps(query)
  232. @staticmethod
  233. def _create_host(protocol, host) -> str | None:
  234. """Return the connection host with the protocol."""
  235. if not host:
  236. return host
  237. if protocol:
  238. return f"{protocol}://{host}"
  239. return host
  240. def get_uri(self) -> str:
  241. """Return connection in URI format."""
  242. if self.conn_type and "_" in self.conn_type:
  243. self.log.warning(
  244. "Connection schemes (type: %s) shall not contain '_' according to RFC3986.",
  245. self.conn_type,
  246. )
  247. if self.conn_type:
  248. uri = f"{self.conn_type.lower().replace('_', '-')}://"
  249. else:
  250. uri = "//"
  251. if self.host and "://" in self.host:
  252. protocol, host = self.host.split("://", 1)
  253. else:
  254. protocol, host = None, self.host
  255. if protocol:
  256. uri += f"{protocol}://"
  257. authority_block = ""
  258. if self.login is not None:
  259. authority_block += quote(self.login, safe="")
  260. if self.password is not None:
  261. authority_block += ":" + quote(self.password, safe="")
  262. if authority_block > "":
  263. authority_block += "@"
  264. uri += authority_block
  265. host_block = ""
  266. if host:
  267. host_block += quote(host, safe="")
  268. if self.port:
  269. if host_block == "" and authority_block == "":
  270. host_block += f"@:{self.port}"
  271. else:
  272. host_block += f":{self.port}"
  273. if self.schema:
  274. host_block += f"/{quote(self.schema, safe='')}"
  275. uri += host_block
  276. if self.extra:
  277. try:
  278. query: str | None = urlencode(self.extra_dejson)
  279. except TypeError:
  280. query = None
  281. if query and self.extra_dejson == dict(parse_qsl(query, keep_blank_values=True)):
  282. uri += ("?" if self.schema else "/?") + query
  283. else:
  284. uri += ("?" if self.schema else "/?") + urlencode({self.EXTRA_KEY: self.extra})
  285. return uri
  286. def get_password(self) -> str | None:
  287. """Return encrypted password."""
  288. if self._password and self.is_encrypted:
  289. fernet = get_fernet()
  290. if not fernet.is_encrypted:
  291. raise AirflowException(
  292. f"Can't decrypt encrypted password for login={self.login} "
  293. f"FERNET_KEY configuration is missing"
  294. )
  295. return fernet.decrypt(bytes(self._password, "utf-8")).decode()
  296. else:
  297. return self._password
  298. def set_password(self, value: str | None):
  299. """Encrypt password and set in object attribute."""
  300. if value:
  301. fernet = get_fernet()
  302. self._password = fernet.encrypt(bytes(value, "utf-8")).decode()
  303. self.is_encrypted = fernet.is_encrypted
  304. @declared_attr
  305. def password(cls):
  306. """Password. The value is decrypted/encrypted when reading/setting the value."""
  307. return synonym("_password", descriptor=property(cls.get_password, cls.set_password))
  308. def get_extra(self) -> str:
  309. """Return encrypted extra-data."""
  310. if self._extra and self.is_extra_encrypted:
  311. fernet = get_fernet()
  312. if not fernet.is_encrypted:
  313. raise AirflowException(
  314. f"Can't decrypt `extra` params for login={self.login}, "
  315. f"FERNET_KEY configuration is missing"
  316. )
  317. extra_val = fernet.decrypt(bytes(self._extra, "utf-8")).decode()
  318. else:
  319. extra_val = self._extra
  320. if extra_val:
  321. self._validate_extra(extra_val, self.conn_id)
  322. return extra_val
  323. def set_extra(self, value: str):
  324. """Encrypt extra-data and save in object attribute to object."""
  325. if value:
  326. self._validate_extra(value, self.conn_id)
  327. fernet = get_fernet()
  328. self._extra = fernet.encrypt(bytes(value, "utf-8")).decode()
  329. self.is_extra_encrypted = fernet.is_encrypted
  330. else:
  331. self._extra = value
  332. self.is_extra_encrypted = False
  333. @declared_attr
  334. def extra(cls):
  335. """Extra data. The value is decrypted/encrypted when reading/setting the value."""
  336. return synonym("_extra", descriptor=property(cls.get_extra, cls.set_extra))
  337. def rotate_fernet_key(self):
  338. """Encrypts data with a new key. See: :ref:`security/fernet`."""
  339. fernet = get_fernet()
  340. if self._password and self.is_encrypted:
  341. self._password = fernet.rotate(self._password.encode("utf-8")).decode()
  342. if self._extra and self.is_extra_encrypted:
  343. self._extra = fernet.rotate(self._extra.encode("utf-8")).decode()
  344. def get_hook(self, *, hook_params=None):
  345. """Return hook based on conn_type."""
  346. from airflow.providers_manager import ProvidersManager
  347. hook = ProvidersManager().hooks.get(self.conn_type, None)
  348. if hook is None:
  349. raise AirflowException(f'Unknown hook type "{self.conn_type}"')
  350. try:
  351. hook_class = import_string(hook.hook_class_name)
  352. except ImportError:
  353. log.error(
  354. "Could not import %s when discovering %s %s",
  355. hook.hook_class_name,
  356. hook.hook_name,
  357. hook.package_name,
  358. )
  359. raise
  360. if hook_params is None:
  361. hook_params = {}
  362. return hook_class(**{hook.connection_id_attribute_name: self.conn_id}, **hook_params)
  363. def __repr__(self):
  364. return self.conn_id or ""
  365. def log_info(self):
  366. """
  367. Read each field individually or use the default representation (`__repr__`).
  368. This method is deprecated.
  369. """
  370. warnings.warn(
  371. "This method is deprecated. You can read each field individually or "
  372. "use the default representation (__repr__).",
  373. RemovedInAirflow3Warning,
  374. stacklevel=2,
  375. )
  376. return (
  377. f"id: {self.conn_id}. Host: {self.host}, Port: {self.port}, Schema: {self.schema}, "
  378. f"Login: {self.login}, Password: {'XXXXXXXX' if self.password else None}, "
  379. f"extra: {'XXXXXXXX' if self.extra_dejson else None}"
  380. )
  381. def debug_info(self):
  382. """
  383. Read each field individually or use the default representation (`__repr__`).
  384. This method is deprecated.
  385. """
  386. warnings.warn(
  387. "This method is deprecated. You can read each field individually or "
  388. "use the default representation (__repr__).",
  389. RemovedInAirflow3Warning,
  390. stacklevel=2,
  391. )
  392. return (
  393. f"id: {self.conn_id}. Host: {self.host}, Port: {self.port}, Schema: {self.schema}, "
  394. f"Login: {self.login}, Password: {'XXXXXXXX' if self.password else None}, "
  395. f"extra: {self.extra_dejson}"
  396. )
  397. def test_connection(self):
  398. """Calls out get_hook method and executes test_connection method on that."""
  399. status, message = False, ""
  400. try:
  401. hook = self.get_hook()
  402. if getattr(hook, "test_connection", False):
  403. status, message = hook.test_connection()
  404. else:
  405. message = (
  406. f"Hook {hook.__class__.__name__} doesn't implement or inherit test_connection method"
  407. )
  408. except Exception as e:
  409. message = str(e)
  410. return status, message
  411. def get_extra_dejson(self, nested: bool = False) -> dict:
  412. """
  413. Deserialize extra property to JSON.
  414. :param nested: Determines whether nested structures are also deserialized into JSON (default False).
  415. """
  416. extra = {}
  417. if self.extra:
  418. try:
  419. if nested:
  420. for key, value in json.loads(self.extra).items():
  421. extra[key] = value
  422. if isinstance(value, str):
  423. with suppress(JSONDecodeError):
  424. extra[key] = json.loads(value)
  425. else:
  426. extra = json.loads(self.extra)
  427. except JSONDecodeError:
  428. self.log.exception("Failed parsing the json for conn_id %s", self.conn_id)
  429. # Mask sensitive keys from this list
  430. mask_secret(extra)
  431. return extra
  432. @property
  433. def extra_dejson(self) -> dict:
  434. """Returns the extra property by deserializing json."""
  435. return self.get_extra_dejson()
  436. @classmethod
  437. def get_connection_from_secrets(cls, conn_id: str) -> Connection:
  438. """
  439. Get connection by conn_id.
  440. :param conn_id: connection id
  441. :return: connection
  442. """
  443. # check cache first
  444. # enabled only if SecretCache.init() has been called first
  445. try:
  446. uri = SecretCache.get_connection_uri(conn_id)
  447. return Connection(conn_id=conn_id, uri=uri)
  448. except SecretCache.NotPresentException:
  449. pass # continue business
  450. # iterate over backends if not in cache (or expired)
  451. for secrets_backend in ensure_secrets_loaded():
  452. try:
  453. conn = secrets_backend.get_connection(conn_id=conn_id)
  454. if conn:
  455. SecretCache.save_connection_uri(conn_id, conn.get_uri())
  456. return conn
  457. except Exception:
  458. log.exception(
  459. "Unable to retrieve connection from secrets backend (%s). "
  460. "Checking subsequent secrets backend.",
  461. type(secrets_backend).__name__,
  462. )
  463. raise AirflowNotFoundException(f"The conn_id `{conn_id}` isn't defined")
  464. def to_dict(self, *, prune_empty: bool = False, validate: bool = True) -> dict[str, Any]:
  465. """
  466. Convert Connection to json-serializable dictionary.
  467. :param prune_empty: Whether or not remove empty values.
  468. :param validate: Validate dictionary is JSON-serializable
  469. :meta private:
  470. """
  471. conn = {
  472. "conn_id": self.conn_id,
  473. "conn_type": self.conn_type,
  474. "description": self.description,
  475. "host": self.host,
  476. "login": self.login,
  477. "password": self.password,
  478. "schema": self.schema,
  479. "port": self.port,
  480. }
  481. if prune_empty:
  482. conn = prune_dict(val=conn, mode="strict")
  483. if (extra := self.extra_dejson) or not prune_empty:
  484. conn["extra"] = extra
  485. if validate:
  486. json.dumps(conn)
  487. return conn
  488. @classmethod
  489. def from_json(cls, value, conn_id=None) -> Connection:
  490. kwargs = json.loads(value)
  491. extra = kwargs.pop("extra", None)
  492. if extra:
  493. kwargs["extra"] = extra if isinstance(extra, str) else json.dumps(extra)
  494. conn_type = kwargs.pop("conn_type", None)
  495. if conn_type:
  496. kwargs["conn_type"] = cls._normalize_conn_type(conn_type)
  497. port = kwargs.pop("port", None)
  498. if port:
  499. try:
  500. kwargs["port"] = int(port)
  501. except ValueError:
  502. raise ValueError(f"Expected integer value for `port`, but got {port!r} instead.")
  503. return Connection(conn_id=conn_id, **kwargs)
  504. def as_json(self) -> str:
  505. """Convert Connection to JSON-string object."""
  506. conn_repr = self.to_dict(prune_empty=True, validate=False)
  507. conn_repr.pop("conn_id", None)
  508. return json.dumps(conn_repr)