package_index.py 3.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. #
  2. # Licensed to the Apache Software Foundation (ASF) under one
  3. # or more contributor license agreements. See the NOTICE file
  4. # distributed with this work for additional information
  5. # regarding copyright ownership. The ASF licenses this file
  6. # to you under the Apache License, Version 2.0 (the
  7. # "License"); you may not use this file except in compliance
  8. # with the License. You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing,
  13. # software distributed under the License is distributed on an
  14. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  15. # KIND, either express or implied. See the License for the
  16. # specific language governing permissions and limitations
  17. # under the License.
  18. """Hook for additional Package Indexes (Python)."""
  19. from __future__ import annotations
  20. import subprocess
  21. from typing import Any
  22. from urllib.parse import quote, urlparse
  23. from airflow.hooks.base import BaseHook
  24. class PackageIndexHook(BaseHook):
  25. """Specify package indexes/Python package sources using Airflow connections."""
  26. conn_name_attr = "pi_conn_id"
  27. default_conn_name = "package_index_default"
  28. conn_type = "package_index"
  29. hook_name = "Package Index (Python)"
  30. def __init__(self, pi_conn_id: str = default_conn_name, **kwargs) -> None:
  31. super().__init__(**kwargs)
  32. self.pi_conn_id = pi_conn_id
  33. self.conn = None
  34. @staticmethod
  35. def get_ui_field_behaviour() -> dict[str, Any]:
  36. """Return custom field behaviour."""
  37. return {
  38. "hidden_fields": ["schema", "port", "extra"],
  39. "relabeling": {"host": "Package Index URL"},
  40. "placeholders": {
  41. "host": "Example: https://my-package-mirror.net/pypi/repo-name/simple",
  42. "login": "Username for package index",
  43. "password": "Password for package index (will be masked)",
  44. },
  45. }
  46. @staticmethod
  47. def _get_basic_auth_conn_url(index_url: str, user: str | None, password: str | None) -> str:
  48. """Return a connection URL with basic auth credentials based on connection config."""
  49. url = urlparse(index_url)
  50. host = url.netloc.split("@")[-1]
  51. if user:
  52. if password:
  53. host = f"{quote(user)}:{quote(password)}@{host}"
  54. else:
  55. host = f"{quote(user)}@{host}"
  56. return url._replace(netloc=host).geturl()
  57. def get_conn(self) -> Any:
  58. """Return connection for the hook."""
  59. return self.get_connection_url()
  60. def get_connection_url(self) -> Any:
  61. """Return a connection URL with embedded credentials."""
  62. conn = self.get_connection(self.pi_conn_id)
  63. index_url = conn.host
  64. if not index_url:
  65. raise ValueError("Please provide an index URL.")
  66. return self._get_basic_auth_conn_url(index_url, conn.login, conn.password)
  67. def test_connection(self) -> tuple[bool, str]:
  68. """Test connection to package index url."""
  69. conn_url = self.get_connection_url()
  70. proc = subprocess.run(
  71. ["pip", "search", "not-existing-test-package", "--no-input", "--index", conn_url],
  72. check=False,
  73. capture_output=True,
  74. )
  75. conn = self.get_connection(self.pi_conn_id)
  76. if proc.returncode not in [
  77. 0, # executed successfully, found package
  78. 23, # executed successfully, didn't find any packages
  79. # (but we do not expect it to find 'not-existing-test-package')
  80. ]:
  81. return False, f"Connection test to {conn.host} failed. Error: {str(proc.stderr)}"
  82. return True, f"Connection to {conn.host} tested successfully!"