generic_transfer.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. #
  2. # Licensed to the Apache Software Foundation (ASF) under one
  3. # or more contributor license agreements. See the NOTICE file
  4. # distributed with this work for additional information
  5. # regarding copyright ownership. The ASF licenses this file
  6. # to you under the Apache License, Version 2.0 (the
  7. # "License"); you may not use this file except in compliance
  8. # with the License. You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing,
  13. # software distributed under the License is distributed on an
  14. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  15. # KIND, either express or implied. See the License for the
  16. # specific language governing permissions and limitations
  17. # under the License.
  18. from __future__ import annotations
  19. from typing import TYPE_CHECKING, Sequence
  20. from airflow.hooks.base import BaseHook
  21. from airflow.models import BaseOperator
  22. if TYPE_CHECKING:
  23. from airflow.utils.context import Context
  24. class GenericTransfer(BaseOperator):
  25. """
  26. Moves data from a connection to another.
  27. Assuming that they both provide the required methods in their respective hooks.
  28. The source hook needs to expose a `get_records` method, and the destination a
  29. `insert_rows` method.
  30. This is meant to be used on small-ish datasets that fit in memory.
  31. :param sql: SQL query to execute against the source database. (templated)
  32. :param destination_table: target table. (templated)
  33. :param source_conn_id: source connection
  34. :param destination_conn_id: destination connection
  35. :param preoperator: sql statement or list of statements to be
  36. executed prior to loading the data. (templated)
  37. :param insert_args: extra params for `insert_rows` method.
  38. """
  39. template_fields: Sequence[str] = ("sql", "destination_table", "preoperator")
  40. template_ext: Sequence[str] = (
  41. ".sql",
  42. ".hql",
  43. )
  44. template_fields_renderers = {"preoperator": "sql"}
  45. ui_color = "#b0f07c"
  46. def __init__(
  47. self,
  48. *,
  49. sql: str,
  50. destination_table: str,
  51. source_conn_id: str,
  52. destination_conn_id: str,
  53. preoperator: str | list[str] | None = None,
  54. insert_args: dict | None = None,
  55. **kwargs,
  56. ) -> None:
  57. super().__init__(**kwargs)
  58. self.sql = sql
  59. self.destination_table = destination_table
  60. self.source_conn_id = source_conn_id
  61. self.destination_conn_id = destination_conn_id
  62. self.preoperator = preoperator
  63. self.insert_args = insert_args or {}
  64. def execute(self, context: Context):
  65. source_hook = BaseHook.get_hook(self.source_conn_id)
  66. destination_hook = BaseHook.get_hook(self.destination_conn_id)
  67. self.log.info("Extracting data from %s", self.source_conn_id)
  68. self.log.info("Executing: \n %s", self.sql)
  69. get_records = getattr(source_hook, "get_records", None)
  70. if not callable(get_records):
  71. raise RuntimeError(
  72. f"Hook for connection {self.source_conn_id!r} "
  73. f"({type(source_hook).__name__}) has no `get_records` method"
  74. )
  75. else:
  76. results = get_records(self.sql)
  77. if self.preoperator:
  78. run = getattr(destination_hook, "run", None)
  79. if not callable(run):
  80. raise RuntimeError(
  81. f"Hook for connection {self.destination_conn_id!r} "
  82. f"({type(destination_hook).__name__}) has no `run` method"
  83. )
  84. self.log.info("Running preoperator")
  85. self.log.info(self.preoperator)
  86. run(self.preoperator)
  87. insert_rows = getattr(destination_hook, "insert_rows", None)
  88. if not callable(insert_rows):
  89. raise RuntimeError(
  90. f"Hook for connection {self.destination_conn_id!r} "
  91. f"({type(destination_hook).__name__}) has no `insert_rows` method"
  92. )
  93. self.log.info("Inserting rows into %s", self.destination_conn_id)
  94. insert_rows(table=self.destination_table, rows=results, **self.insert_args)