123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184 |
- # Licensed to the Apache Software Foundation (ASF) under one
- # or more contributor license agreements. See the NOTICE file
- # distributed with this work for additional information
- # regarding copyright ownership. The ASF licenses this file
- # to you under the Apache License, Version 2.0 (the
- # "License"); you may not use this file except in compliance
- # with the License. You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- # KIND, either express or implied. See the License for the
- # specific language governing permissions and limitations
- # under the License.
- from __future__ import annotations
- import logging
- from functools import wraps
- from typing import TYPE_CHECKING, Callable, TypeVar, cast
- from airflow.configuration import conf
- from airflow.exceptions import AirflowConfigException
- from airflow.metrics.protocols import Timer
- from airflow.metrics.validators import (
- AllowListValidator,
- BlockListValidator,
- get_validator,
- validate_stat,
- )
- if TYPE_CHECKING:
- from statsd import StatsClient
- from airflow.metrics.protocols import DeltaType, TimerProtocol
- from airflow.metrics.validators import (
- ListValidator,
- )
# Module logger; emits the dropped-tag errors and custom-client info messages below.
log = logging.getLogger(__name__)
# Type variable for decorated callables, letting decorators hand back the same type.
T = TypeVar("T", bound=Callable)
def prepare_stat_with_tags(fn: T) -> T:
    """
    Add tags to stat with influxdb standard format if influxdb_tags_enabled is True.

    The decorated method receives ``stat`` as its first argument after ``self`` and a
    keyword-only ``tags`` mapping. When ``self.influxdb_tags_enabled`` is truthy and
    both ``stat`` and ``tags`` are provided, each tag is processed as follows:

    * If ``self.metric_tags_validator.test(key)`` is false, the tag is skipped silently.
    * If the key or value contains ``,`` or ``=``, the tag is dropped and logged as
      an error.
    * Otherwise ``,key=value`` is appended to ``stat``.

    The (possibly extended) ``stat`` and the original ``tags`` mapping are then
    forwarded to ``fn``.

    :param fn: method to wrap.
    :return: the wrapper, cast to the type of ``fn``.
    """
    # ``,`` separates tags and ``=`` separates a tag key from its value in the
    # ``stat,key=value`` form built below; a tag containing either would corrupt it.
    reserved_chars = frozenset(",=")

    @wraps(fn)
    def wrapper(
        self, stat: str | None = None, *args, tags: dict[str, str] | None = None, **kwargs
    ) -> object:
        # Annotated as ``object``: the wrapper returns whatever ``fn`` returns, and
        # callers see ``fn``'s real signature through the ``cast`` below.
        if self.influxdb_tags_enabled and stat is not None and tags is not None:
            for k, v in tags.items():
                if not self.metric_tags_validator.test(k):
                    continue
                if reserved_chars.isdisjoint(f"{v}{k}"):
                    stat += f",{k}={v}"
                else:
                    log.error("Dropping invalid tag: %s=%s.", k, v)
        return fn(self, stat, *args, tags=tags, **kwargs)

    return cast(T, wrapper)
class SafeStatsdLogger:
    """
    StatsD Logger.

    Wraps a StatsD client and forwards a metric only when its name passes
    ``metrics_validator``. Every metric method is decorated with
    ``prepare_stat_with_tags`` (outer) and ``validate_stat`` (inner). Any InfluxDB
    tags are therefore appended to ``stat`` before ``validate_stat`` and the method
    body receive it.
    """

    def __init__(
        self,
        statsd_client: StatsClient,
        metrics_validator: ListValidator | None = None,
        influxdb_tags_enabled: bool = False,
        metric_tags_validator: ListValidator | None = None,
    ) -> None:
        """
        Store the client, validators and tagging flag.

        :param statsd_client: client that metrics are forwarded to.
        :param metrics_validator: decides which stat names are emitted; a new
            ``AllowListValidator()`` is used when omitted or ``None``.
        :param influxdb_tags_enabled: when true, ``tags`` passed to the metric methods
            are appended to the stat name.
        :param metric_tags_validator: decides which tag keys are kept; a new
            ``AllowListValidator()`` is used when omitted or ``None``.
        """
        self.statsd = statsd_client
        # Default validators are built per instance rather than in the signature,
        # where they would be evaluated once at import time and shared by every
        # logger.
        self.metrics_validator = (
            metrics_validator if metrics_validator is not None else AllowListValidator()
        )
        self.influxdb_tags_enabled = influxdb_tags_enabled
        self.metric_tags_validator = (
            metric_tags_validator if metric_tags_validator is not None else AllowListValidator()
        )

    @prepare_stat_with_tags
    @validate_stat
    def incr(
        self,
        stat: str,
        count: int = 1,
        rate: float = 1,
        *,
        tags: dict[str, str] | None = None,
    ) -> None:
        """
        Increment stat.

        :param stat: metric name (already tagged by the decorator, if enabled).
        :param count: amount forwarded to the client's ``incr``.
        :param rate: forwarded unchanged to the client's ``incr``.
        :param tags: consumed by ``prepare_stat_with_tags``; unused here.
        """
        if self.metrics_validator.test(stat):
            return self.statsd.incr(stat, count, rate)
        return None

    @prepare_stat_with_tags
    @validate_stat
    def decr(
        self,
        stat: str,
        count: int = 1,
        rate: float = 1,
        *,
        tags: dict[str, str] | None = None,
    ) -> None:
        """
        Decrement stat.

        :param stat: metric name (already tagged by the decorator, if enabled).
        :param count: amount forwarded to the client's ``decr``.
        :param rate: forwarded unchanged to the client's ``decr``.
        :param tags: consumed by ``prepare_stat_with_tags``; unused here.
        """
        if self.metrics_validator.test(stat):
            return self.statsd.decr(stat, count, rate)
        return None

    @prepare_stat_with_tags
    @validate_stat
    def gauge(
        self,
        stat: str,
        value: int | float,
        rate: float = 1,
        delta: bool = False,
        *,
        tags: dict[str, str] | None = None,
    ) -> None:
        """
        Gauge stat.

        :param stat: metric name (already tagged by the decorator, if enabled).
        :param value: gauge value forwarded to the client's ``gauge``.
        :param rate: forwarded unchanged to the client's ``gauge``.
        :param delta: forwarded unchanged to the client's ``gauge``.
        :param tags: consumed by ``prepare_stat_with_tags``; unused here.
        """
        if self.metrics_validator.test(stat):
            return self.statsd.gauge(stat, value, rate, delta)
        return None

    @prepare_stat_with_tags
    @validate_stat
    def timing(
        self,
        stat: str,
        dt: DeltaType,
        *,
        tags: dict[str, str] | None = None,
    ) -> None:
        """
        Stats timing.

        :param stat: metric name (already tagged by the decorator, if enabled).
        :param dt: duration forwarded to the client's ``timing``.
        :param tags: consumed by ``prepare_stat_with_tags``; unused here.
        """
        if self.metrics_validator.test(stat):
            return self.statsd.timing(stat, dt)
        return None

    @prepare_stat_with_tags
    @validate_stat
    def timer(
        self,
        stat: str | None = None,
        *args,
        tags: dict[str, str] | None = None,
        **kwargs,
    ) -> TimerProtocol:
        """
        Timer metric that can be cancelled.

        :param stat: metric name; when empty/``None`` or rejected by
            ``metrics_validator``, a ``Timer()`` without a client timer is returned
            (presumably a no-op — confirm in ``airflow.metrics.protocols``).
        :param args: forwarded to the client's ``timer``.
        :param tags: consumed by ``prepare_stat_with_tags``; unused here.
        :param kwargs: forwarded to the client's ``timer``.
        """
        if stat and self.metrics_validator.test(stat):
            return Timer(self.statsd.timer(stat, *args, **kwargs))
        return Timer()
def get_statsd_logger(cls) -> SafeStatsdLogger:
    """
    Return logger for StatsD.

    All settings come from the ``[metrics]`` config section.

    * Client class: the class named by ``statsd_custom_client_path``, which must
      subclass ``statsd.StatsClient``; ``statsd.StatsClient`` itself when that
      option is unset.
    * Connection: the client is built from ``statsd_host``, ``statsd_port`` and
      ``statsd_prefix``.
    * Stat names: filtered with ``get_validator()``.
    * Tag keys: filtered with a ``BlockListValidator`` built from
      ``statsd_disabled_tags``.
    * InfluxDB-style tagging: controlled by ``statsd_influxdb_enabled``.

    :param cls: not used by the body.
    :raises AirflowConfigException: if ``statsd_custom_client_path`` resolves to
        something that is not a subclass of ``statsd.StatsClient``.
    """
    # no need to check for the scheduler/statsd_on -> this method is only called when it is set
    # and previously it would crash with None is callable if it was called without it.
    from statsd import StatsClient

    stats_class = conf.getimport("metrics", "statsd_custom_client_path", fallback=None)
    if stats_class:
        # issubclass() raises TypeError for non-class objects (e.g. the path points to a
        # function), so check for a class first to report every misconfiguration as
        # AirflowConfigException.
        if not (isinstance(stats_class, type) and issubclass(stats_class, StatsClient)):
            raise AirflowConfigException(
                "Your custom StatsD client must extend the statsd.StatsClient in order to ensure "
                "backwards compatibility."
            )
        log.info("Successfully loaded custom StatsD client")
    else:
        stats_class = StatsClient

    statsd = stats_class(
        host=conf.get("metrics", "statsd_host"),
        port=conf.getint("metrics", "statsd_port"),
        prefix=conf.get("metrics", "statsd_prefix"),
    )
    influxdb_tags_enabled = conf.getboolean("metrics", "statsd_influxdb_enabled", fallback=False)
    metric_tags_validator = BlockListValidator(conf.get("metrics", "statsd_disabled_tags", fallback=None))
    return SafeStatsdLogger(statsd, get_validator(), influxdb_tags_enabled, metric_tags_validator)
|