statsd_logger.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. # Licensed to the Apache Software Foundation (ASF) under one
  2. # or more contributor license agreements. See the NOTICE file
  3. # distributed with this work for additional information
  4. # regarding copyright ownership. The ASF licenses this file
  5. # to you under the Apache License, Version 2.0 (the
  6. # "License"); you may not use this file except in compliance
  7. # with the License. You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing,
  12. # software distributed under the License is distributed on an
  13. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  14. # KIND, either express or implied. See the License for the
  15. # specific language governing permissions and limitations
  16. # under the License.
  17. from __future__ import annotations
  18. import logging
  19. from functools import wraps
  20. from typing import TYPE_CHECKING, Callable, TypeVar, cast
  21. from airflow.configuration import conf
  22. from airflow.exceptions import AirflowConfigException
  23. from airflow.metrics.protocols import Timer
  24. from airflow.metrics.validators import (
  25. AllowListValidator,
  26. BlockListValidator,
  27. get_validator,
  28. validate_stat,
  29. )
  30. if TYPE_CHECKING:
  31. from statsd import StatsClient
  32. from airflow.metrics.protocols import DeltaType, TimerProtocol
  33. from airflow.metrics.validators import (
  34. ListValidator,
  35. )
  36. T = TypeVar("T", bound=Callable)
  37. log = logging.getLogger(__name__)
  38. def prepare_stat_with_tags(fn: T) -> T:
  39. """Add tags to stat with influxdb standard format if influxdb_tags_enabled is True."""
  40. @wraps(fn)
  41. def wrapper(
  42. self, stat: str | None = None, *args, tags: dict[str, str] | None = None, **kwargs
  43. ) -> Callable[[str], str]:
  44. if self.influxdb_tags_enabled:
  45. if stat is not None and tags is not None:
  46. for k, v in tags.items():
  47. if self.metric_tags_validator.test(k):
  48. if all(c not in [",", "="] for c in f"{v}{k}"):
  49. stat += f",{k}={v}"
  50. else:
  51. log.error("Dropping invalid tag: %s=%s.", k, v)
  52. return fn(self, stat, *args, tags=tags, **kwargs)
  53. return cast(T, wrapper)
  54. class SafeStatsdLogger:
  55. """StatsD Logger."""
  56. def __init__(
  57. self,
  58. statsd_client: StatsClient,
  59. metrics_validator: ListValidator = AllowListValidator(),
  60. influxdb_tags_enabled: bool = False,
  61. metric_tags_validator: ListValidator = AllowListValidator(),
  62. ) -> None:
  63. self.statsd = statsd_client
  64. self.metrics_validator = metrics_validator
  65. self.influxdb_tags_enabled = influxdb_tags_enabled
  66. self.metric_tags_validator = metric_tags_validator
  67. @prepare_stat_with_tags
  68. @validate_stat
  69. def incr(
  70. self,
  71. stat: str,
  72. count: int = 1,
  73. rate: float = 1,
  74. *,
  75. tags: dict[str, str] | None = None,
  76. ) -> None:
  77. """Increment stat."""
  78. if self.metrics_validator.test(stat):
  79. return self.statsd.incr(stat, count, rate)
  80. return None
  81. @prepare_stat_with_tags
  82. @validate_stat
  83. def decr(
  84. self,
  85. stat: str,
  86. count: int = 1,
  87. rate: float = 1,
  88. *,
  89. tags: dict[str, str] | None = None,
  90. ) -> None:
  91. """Decrement stat."""
  92. if self.metrics_validator.test(stat):
  93. return self.statsd.decr(stat, count, rate)
  94. return None
  95. @prepare_stat_with_tags
  96. @validate_stat
  97. def gauge(
  98. self,
  99. stat: str,
  100. value: int | float,
  101. rate: float = 1,
  102. delta: bool = False,
  103. *,
  104. tags: dict[str, str] | None = None,
  105. ) -> None:
  106. """Gauge stat."""
  107. if self.metrics_validator.test(stat):
  108. return self.statsd.gauge(stat, value, rate, delta)
  109. return None
  110. @prepare_stat_with_tags
  111. @validate_stat
  112. def timing(
  113. self,
  114. stat: str,
  115. dt: DeltaType,
  116. *,
  117. tags: dict[str, str] | None = None,
  118. ) -> None:
  119. """Stats timing."""
  120. if self.metrics_validator.test(stat):
  121. return self.statsd.timing(stat, dt)
  122. return None
  123. @prepare_stat_with_tags
  124. @validate_stat
  125. def timer(
  126. self,
  127. stat: str | None = None,
  128. *args,
  129. tags: dict[str, str] | None = None,
  130. **kwargs,
  131. ) -> TimerProtocol:
  132. """Timer metric that can be cancelled."""
  133. if stat and self.metrics_validator.test(stat):
  134. return Timer(self.statsd.timer(stat, *args, **kwargs))
  135. return Timer()
  136. def get_statsd_logger(cls) -> SafeStatsdLogger:
  137. """Return logger for StatsD."""
  138. # no need to check for the scheduler/statsd_on -> this method is only called when it is set
  139. # and previously it would crash with None is callable if it was called without it.
  140. from statsd import StatsClient
  141. stats_class = conf.getimport("metrics", "statsd_custom_client_path", fallback=None)
  142. if stats_class:
  143. if not issubclass(stats_class, StatsClient):
  144. raise AirflowConfigException(
  145. "Your custom StatsD client must extend the statsd.StatsClient in order to ensure "
  146. "backwards compatibility."
  147. )
  148. else:
  149. log.info("Successfully loaded custom StatsD client")
  150. else:
  151. stats_class = StatsClient
  152. statsd = stats_class(
  153. host=conf.get("metrics", "statsd_host"),
  154. port=conf.getint("metrics", "statsd_port"),
  155. prefix=conf.get("metrics", "statsd_prefix"),
  156. )
  157. influxdb_tags_enabled = conf.getboolean("metrics", "statsd_influxdb_enabled", fallback=False)
  158. metric_tags_validator = BlockListValidator(conf.get("metrics", "statsd_disabled_tags", fallback=None))
  159. return SafeStatsdLogger(statsd, get_validator(), influxdb_tags_enabled, metric_tags_validator)