basenotifier.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
  17. from __future__ import annotations
  18. from abc import abstractmethod
  19. from typing import TYPE_CHECKING, Sequence
  20. from airflow.template.templater import Templater
  21. from airflow.utils.context import context_merge
  22. if TYPE_CHECKING:
  23. import jinja2
  24. from airflow import DAG
  25. from airflow.utils.context import Context
  26. class BaseNotifier(Templater):
  27. """BaseNotifier class for sending notifications."""
  28. template_fields: Sequence[str] = ()
  29. template_ext: Sequence[str] = ()
  30. def __init__(self):
  31. super().__init__()
  32. self.resolve_template_files()
  33. def _update_context(self, context: Context) -> Context:
  34. """
  35. Add additional context to the context.
  36. :param context: The airflow context
  37. """
  38. additional_context = ((f, getattr(self, f)) for f in self.template_fields)
  39. context_merge(context, additional_context)
  40. return context
  41. def _render(self, template, context, dag: DAG | None = None):
  42. dag = dag or context["dag"]
  43. return super()._render(template, context, dag)
  44. def render_template_fields(
  45. self,
  46. context: Context,
  47. jinja_env: jinja2.Environment | None = None,
  48. ) -> None:
  49. """
  50. Template all attributes listed in *self.template_fields*.
  51. This mutates the attributes in-place and is irreversible.
  52. :param context: Context dict with values to apply on content.
  53. :param jinja_env: Jinja environment to use for rendering.
  54. """
  55. dag = context["dag"]
  56. if not jinja_env:
  57. jinja_env = self.get_template_env(dag=dag)
  58. self._do_render_template_fields(self, self.template_fields, context, jinja_env, set())
  59. @abstractmethod
  60. def notify(self, context: Context) -> None:
  61. """
  62. Send a notification.
  63. :param context: The airflow context
  64. """
  65. ...
  66. def __call__(self, *args) -> None:
  67. """
  68. Send a notification.
  69. :param context: The airflow context
  70. """
  71. # Currently, there are two ways a callback is invoked
  72. # 1. callback(context) - for on_*_callbacks
  73. # 2. callback(dag, task_list, blocking_task_list, slas, blocking_tis) - for sla_miss_callback
  74. # we have to distinguish between the two calls so that we can prepare the correct context,
  75. if len(args) == 1:
  76. context = args[0]
  77. else:
  78. context = {
  79. "dag": args[0],
  80. "task_list": args[1],
  81. "blocking_task_list": args[2],
  82. "slas": args[3],
  83. "blocking_tis": args[4],
  84. }
  85. self._update_context(context)
  86. self.render_template_fields(context)
  87. try:
  88. self.notify(context)
  89. except Exception as e:
  90. self.log.exception("Failed to send notification: %s", e)