1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495 |
- #
- # Licensed to the Apache Software Foundation (ASF) under one
- # or more contributor license agreements. See the NOTICE file
- # distributed with this work for additional information
- # regarding copyright ownership. The ASF licenses this file
- # to you under the Apache License, Version 2.0 (the
- # "License"); you may not use this file except in compliance
- # with the License. You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- # KIND, either express or implied. See the License for the
- # specific language governing permissions and limitations
- # under the License.
- """Hook for additional Package Indexes (Python)."""
- from __future__ import annotations
- import subprocess
- from typing import Any
- from urllib.parse import quote, urlparse
- from airflow.hooks.base import BaseHook
class PackageIndexHook(BaseHook):
    """Specify package indexes/Python package sources using Airflow connections.

    The connection's ``host`` field holds the index URL. ``login`` and ``password`` are
    optional; when a login is set, the credentials are embedded into the URL as basic-auth
    user info.
    """

    conn_name_attr = "pi_conn_id"
    default_conn_name = "package_index_default"
    conn_type = "package_index"
    hook_name = "Package Index (Python)"

    def __init__(self, pi_conn_id: str = default_conn_name, **kwargs) -> None:
        """Store the connection id; the connection itself is looked up on demand.

        :param pi_conn_id: id of the Airflow connection describing the package index.
        :param kwargs: forwarded unchanged to the ``BaseHook`` constructor.
        """
        super().__init__(**kwargs)
        self.pi_conn_id = pi_conn_id
        self.conn = None

    @staticmethod
    def get_ui_field_behaviour() -> dict[str, Any]:
        """Return custom field behaviour."""
        return {
            "hidden_fields": ["schema", "port", "extra"],
            "relabeling": {"host": "Package Index URL"},
            "placeholders": {
                "host": "Example: https://my-package-mirror.net/pypi/repo-name/simple",
                "login": "Username for package index",
                "password": "Password for package index (will be masked)",
            },
        }

    @staticmethod
    def _get_basic_auth_conn_url(index_url: str, user: str | None, password: str | None) -> str:
        """Return a connection URL with basic auth credentials based on connection config.

        Any user info already present in ``index_url`` is discarded. If ``user`` is empty the
        URL is returned without credentials, even when a password is given.

        :param index_url: URL of the package index.
        :param user: optional user name to embed.
        :param password: optional password to embed; only used together with ``user``.
        :return: ``index_url`` with its network location replaced by ``[user[:password]@]host``.
        """
        url = urlparse(index_url)
        # Keep only the part after the last "@", i.e. drop credentials already in the URL.
        host = url.netloc.rpartition("@")[2]
        if user:
            # safe="" also percent-encodes "/", which quote() leaves alone by default; an
            # unescaped "/" inside a credential would terminate the authority part early.
            credentials = quote(user, safe="")
            if password:
                credentials = f"{credentials}:{quote(password, safe='')}"
            host = f"{credentials}@{host}"
        return url._replace(netloc=host).geturl()

    def get_conn(self) -> Any:
        """Return connection for the hook."""
        return self.get_connection_url()

    def get_connection_url(self) -> str:
        """Return a connection URL with embedded credentials.

        :raises ValueError: if the connection has no host (index URL) configured.
        """
        conn = self.get_connection(self.pi_conn_id)
        index_url = conn.host
        if not index_url:
            raise ValueError("Please provide an index URL.")
        return self._get_basic_auth_conn_url(index_url, conn.login, conn.password)

    def test_connection(self) -> tuple[bool, str]:
        """Test connection to package index url.

        Runs ``pip search`` for a package name that is not expected to exist against the
        configured index and interprets the exit code.

        :return: ``(True, message)`` if pip exited with an accepted code, otherwise
            ``(False, message)`` containing pip's stderr output (or the OS error if pip could
            not be started).
        :raises ValueError: if the connection has no host (index URL) configured.
        """
        conn_url = self.get_connection_url()
        conn = self.get_connection(self.pi_conn_id)
        try:
            proc = subprocess.run(
                ["pip", "search", "not-existing-test-package", "--no-input", "--index", conn_url],
                check=False,
                capture_output=True,
            )
        except OSError as err:
            # e.g. the pip executable is missing: report a failed test instead of raising.
            return False, f"Connection test to {conn.host} failed. Error: {err}"

        accepted_return_codes = (
            0,  # executed successfully, found package
            23,  # executed successfully, didn't find any packages
            # (but we do not expect it to find 'not-existing-test-package')
        )
        if proc.returncode not in accepted_return_codes:
            stderr = proc.stderr
            if isinstance(stderr, bytes):
                # capture_output without text mode yields bytes; decode so the message does
                # not show a b'...' repr.
                stderr = stderr.decode(errors="replace")
            return False, f"Connection test to {conn.host} failed. Error: {stderr}"
        return True, f"Connection to {conn.host} tested successfully!"
|