http.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. # Licensed to the Apache Software Foundation (ASF) under one
  2. # or more contributor license agreements. See the NOTICE file
  3. # distributed with this work for additional information
  4. # regarding copyright ownership. The ASF licenses this file
  5. # to you under the Apache License, Version 2.0 (the
  6. # "License"); you may not use this file except in compliance
  7. # with the License. You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing,
  12. # software distributed under the License is distributed on an
  13. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  14. # KIND, either express or implied. See the License for the
  15. # specific language governing permissions and limitations
  16. # under the License.
  17. from __future__ import annotations
  18. import asyncio
  19. import base64
  20. import pickle
  21. from collections.abc import AsyncIterator
  22. from typing import TYPE_CHECKING, Any
  23. import aiohttp
  24. import requests
  25. from requests.cookies import RequestsCookieJar
  26. from requests.structures import CaseInsensitiveDict
  27. from airflow.exceptions import AirflowException
  28. from airflow.providers.http.hooks.http import HttpAsyncHook
  29. from airflow.triggers.base import BaseTrigger, TriggerEvent
  30. if TYPE_CHECKING:
  31. from aiohttp.client_reqrep import ClientResponse
  32. class HttpTrigger(BaseTrigger):
  33. """
  34. HttpTrigger run on the trigger worker.
  35. :param http_conn_id: http connection id that has the base
  36. API url i.e https://www.google.com/ and optional authentication credentials. Default
  37. headers can also be specified in the Extra field in json format.
  38. :param auth_type: The auth type for the service
  39. :param method: the API method to be called
  40. :param endpoint: Endpoint to be called, i.e. ``resource/v1/query?``.
  41. :param headers: Additional headers to be passed through as a dict.
  42. :param data: Payload to be uploaded or request parameters.
  43. :param extra_options: Additional kwargs to pass when creating a request.
  44. For example, ``run(json=obj)`` is passed as
  45. ``aiohttp.ClientSession().get(json=obj)``.
  46. 2XX or 3XX status codes
  47. """
  48. def __init__(
  49. self,
  50. http_conn_id: str = "http_default",
  51. auth_type: Any = None,
  52. method: str = "POST",
  53. endpoint: str | None = None,
  54. headers: dict[str, str] | None = None,
  55. data: dict[str, Any] | str | None = None,
  56. extra_options: dict[str, Any] | None = None,
  57. ):
  58. super().__init__()
  59. self.http_conn_id = http_conn_id
  60. self.method = method
  61. self.auth_type = auth_type
  62. self.endpoint = endpoint
  63. self.headers = headers
  64. self.data = data
  65. self.extra_options = extra_options
  66. def serialize(self) -> tuple[str, dict[str, Any]]:
  67. """Serialize HttpTrigger arguments and classpath."""
  68. return (
  69. "airflow.providers.http.triggers.http.HttpTrigger",
  70. {
  71. "http_conn_id": self.http_conn_id,
  72. "method": self.method,
  73. "auth_type": self.auth_type,
  74. "endpoint": self.endpoint,
  75. "headers": self.headers,
  76. "data": self.data,
  77. "extra_options": self.extra_options,
  78. },
  79. )
  80. async def run(self) -> AsyncIterator[TriggerEvent]:
  81. """Make a series of asynchronous http calls via a http hook."""
  82. hook = HttpAsyncHook(
  83. method=self.method,
  84. http_conn_id=self.http_conn_id,
  85. auth_type=self.auth_type,
  86. )
  87. try:
  88. async with aiohttp.ClientSession() as session:
  89. client_response = await hook.run(
  90. session=session,
  91. endpoint=self.endpoint,
  92. data=self.data,
  93. headers=self.headers,
  94. extra_options=self.extra_options,
  95. )
  96. response = await self._convert_response(client_response)
  97. yield TriggerEvent(
  98. {
  99. "status": "success",
  100. "response": base64.standard_b64encode(pickle.dumps(response)).decode("ascii"),
  101. }
  102. )
  103. except Exception as e:
  104. yield TriggerEvent({"status": "error", "message": str(e)})
  105. @staticmethod
  106. async def _convert_response(client_response: ClientResponse) -> requests.Response:
  107. """Convert aiohttp.client_reqrep.ClientResponse to requests.Response."""
  108. response = requests.Response()
  109. response._content = await client_response.read()
  110. response.status_code = client_response.status
  111. response.headers = CaseInsensitiveDict(client_response.headers)
  112. response.url = str(client_response.url)
  113. response.history = [await HttpTrigger._convert_response(h) for h in client_response.history]
  114. response.encoding = client_response.get_encoding()
  115. response.reason = str(client_response.reason)
  116. cookies = RequestsCookieJar()
  117. for k, v in client_response.cookies.items():
  118. cookies.set(k, v)
  119. response.cookies = cookies
  120. return response
  121. class HttpSensorTrigger(BaseTrigger):
  122. """
  123. A trigger that fires when the request to a URL returns a non-404 status code.
  124. :param endpoint: The relative part of the full url
  125. :param http_conn_id: The HTTP Connection ID to run the sensor against
  126. :param method: The HTTP request method to use
  127. :param data: payload to be uploaded or aiohttp parameters
  128. :param headers: The HTTP headers to be added to the GET request
  129. :param extra_options: Additional kwargs to pass when creating a request.
  130. For example, ``run(json=obj)`` is passed as ``aiohttp.ClientSession().get(json=obj)``
  131. :param poke_interval: Time to sleep using asyncio
  132. """
  133. def __init__(
  134. self,
  135. endpoint: str | None = None,
  136. http_conn_id: str = "http_default",
  137. method: str = "GET",
  138. data: dict[str, Any] | str | None = None,
  139. headers: dict[str, str] | None = None,
  140. extra_options: dict[str, Any] | None = None,
  141. poke_interval: float = 5.0,
  142. ):
  143. super().__init__()
  144. self.endpoint = endpoint
  145. self.method = method
  146. self.data = data
  147. self.headers = headers
  148. self.extra_options = extra_options or {}
  149. self.http_conn_id = http_conn_id
  150. self.poke_interval = poke_interval
  151. def serialize(self) -> tuple[str, dict[str, Any]]:
  152. """Serialize HttpTrigger arguments and classpath."""
  153. return (
  154. "airflow.providers.http.triggers.http.HttpSensorTrigger",
  155. {
  156. "endpoint": self.endpoint,
  157. "data": self.data,
  158. "method": self.method,
  159. "headers": self.headers,
  160. "extra_options": self.extra_options,
  161. "http_conn_id": self.http_conn_id,
  162. "poke_interval": self.poke_interval,
  163. },
  164. )
  165. async def run(self) -> AsyncIterator[TriggerEvent]:
  166. """Make a series of asynchronous http calls via an http hook."""
  167. hook = self._get_async_hook()
  168. while True:
  169. try:
  170. async with aiohttp.ClientSession() as session:
  171. await hook.run(
  172. session=session,
  173. endpoint=self.endpoint,
  174. data=self.data,
  175. headers=self.headers,
  176. extra_options=self.extra_options,
  177. )
  178. yield TriggerEvent(True)
  179. return
  180. except AirflowException as exc:
  181. if str(exc).startswith("404"):
  182. await asyncio.sleep(self.poke_interval)
  183. def _get_async_hook(self) -> HttpAsyncHook:
  184. return HttpAsyncHook(
  185. method=self.method,
  186. http_conn_id=self.http_conn_id,
  187. )