123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108 |
- #
- # Licensed to the Apache Software Foundation (ASF) under one
- # or more contributor license agreements. See the NOTICE file
- # distributed with this work for additional information
- # regarding copyright ownership. The ASF licenses this file
- # to you under the Apache License, Version 2.0 (the
- # "License"); you may not use this file except in compliance
- # with the License. You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- # KIND, either express or implied. See the License for the
- # specific language governing permissions and limitations
- # under the License.
- from __future__ import annotations
- from typing import TYPE_CHECKING, Sequence
- from airflow.hooks.base import BaseHook
- from airflow.models import BaseOperator
- if TYPE_CHECKING:
- from airflow.utils.context import Context
class GenericTransfer(BaseOperator):
    """
    Moves data from one connection to another.

    Both connections are resolved to hooks via ``BaseHook.get_hook``. The source
    hook needs to expose a ``get_records`` method and the destination hook an
    ``insert_rows`` method; when a ``preoperator`` is given, the destination hook
    must also expose a ``run`` method.

    All records returned by the source query are held in memory before being
    inserted, so this is meant to be used on small-ish datasets that fit in memory.

    :param sql: SQL query to execute against the source database. (templated)
    :param destination_table: target table. (templated)
    :param source_conn_id: source connection
    :param destination_conn_id: destination connection
    :param preoperator: sql statement or list of statements to be
        executed prior to loading the data. (templated)
    :param insert_args: extra params for `insert_rows` method.
    """

    template_fields: Sequence[str] = ("sql", "destination_table", "preoperator")
    template_ext: Sequence[str] = (
        ".sql",
        ".hql",
    )
    template_fields_renderers = {"preoperator": "sql"}
    ui_color = "#b0f07c"

    def __init__(
        self,
        *,
        sql: str,
        destination_table: str,
        source_conn_id: str,
        destination_conn_id: str,
        preoperator: str | list[str] | None = None,
        insert_args: dict | None = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.sql = sql
        self.destination_table = destination_table
        self.source_conn_id = source_conn_id
        self.destination_conn_id = destination_conn_id
        self.preoperator = preoperator
        # Normalise ``None`` to a fresh dict so it can always be splatted into
        # ``insert_rows`` without sharing a mutable default between instances.
        self.insert_args = insert_args or {}

    @staticmethod
    def _require_method(hook: object, conn_id: str, method_name: str):
        """
        Return the callable ``method_name`` of ``hook``.

        :param hook: hook instance resolved for ``conn_id``.
        :param conn_id: connection id, used only for the error message.
        :param method_name: name of the method the hook has to provide.
        :raises RuntimeError: if the attribute is missing or is not callable.
        """
        method = getattr(hook, method_name, None)
        if not callable(method):
            raise RuntimeError(
                f"Hook for connection {conn_id!r} "
                f"({type(hook).__name__}) has no `{method_name}` method"
            )
        return method

    def execute(self, context: Context):
        """
        Read the rows produced by ``sql`` on the source and insert them into the destination.

        Every hook method that will be needed is resolved and validated before
        any statement is executed, so an unsuitable hook is reported before the
        source query is run and before the preoperator can modify the destination.

        :param context: task execution context (not used by this operator).
        :raises RuntimeError: if the source hook has no ``get_records`` method, or
            the destination hook has no ``insert_rows`` method (or no ``run``
            method while a ``preoperator`` is set).
        """
        source_hook = BaseHook.get_hook(self.source_conn_id)
        destination_hook = BaseHook.get_hook(self.destination_conn_id)

        # Fail fast: check capabilities in the same precedence order in which
        # they are used (source first, then preoperator, then insert).
        get_records = self._require_method(source_hook, self.source_conn_id, "get_records")
        run = None
        if self.preoperator:
            run = self._require_method(destination_hook, self.destination_conn_id, "run")
        insert_rows = self._require_method(destination_hook, self.destination_conn_id, "insert_rows")

        self.log.info("Extracting data from %s", self.source_conn_id)
        self.log.info("Executing: \n %s", self.sql)
        results = get_records(self.sql)

        if run is not None:
            self.log.info("Running preoperator")
            self.log.info(self.preoperator)
            run(self.preoperator)

        self.log.info("Inserting rows into %s", self.destination_conn_id)
        insert_rows(table=self.destination_table, rows=results, **self.insert_args)
|