123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299 |
- # Copyright 2023 The gRPC authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from __future__ import annotations
- import abc
- import contextlib
- import logging
- import threading
- from typing import Any, Generator, Generic, List, Optional, TypeVar
- from grpc._cython import cygrpc as _cygrpc
- from grpc._typing import ChannelArgumentType
_LOGGER = logging.getLogger(__name__)

# Stand-in for the grpc._channel module: _channel.py imports this module, so it
# is not imported back here. It only backs the "_channel._RPCState" annotation.
_channel = Any

# Type parameters for the opaque capsule objects produced by an
# ObservabilityPlugin.
ClientCallTracerCapsule = TypeVar("ClientCallTracerCapsule")
ServerCallTracerFactoryCapsule = TypeVar("ServerCallTracerFactoryCapsule")

# The registered plugin (None when nothing is registered) and the reentrant
# lock that get_plugin/set_plugin hold while reading or replacing it.
_plugin_lock: threading.RLock = threading.RLock()
_OBSERVABILITY_PLUGIN: Optional["ObservabilityPlugin"] = None

# Byte strings searched for inside an RPC's encoded method name;
# maybe_record_rpc_latency skips recording for any match.
_SERVICES_TO_EXCLUDE: List[bytes] = [
    b"google.monitoring.v3.MetricService",
    b"google.devtools.cloudtrace.v2.TraceService",
]
class ServerCallTracerFactory:
    """Opaque handle around a native ServerCallTracerFactory address.

    create_server_call_tracer_factory_option stores an instance of this class
    as the value of the ``grpc.experimental.server_call_tracer_factory``
    option. ``int(instance)`` returns the address the handle was built with.
    """

    def __init__(self, address):
        # Address obtained from _cygrpc.get_server_call_tracer_factory_address.
        self._address = address

    def __int__(self):
        # Hand the stored address back unchanged.
        return self._address
class ObservabilityPlugin(
    Generic[ClientCallTracerCapsule, ServerCallTracerFactoryCapsule],
    metaclass=abc.ABCMeta,
):
    """Abstract base class for observability plugin.

    *This is a semi-private class that was intended for the exclusive use of
    the gRPC team.*

    A plugin produces the ClientCallTracerCapsule and
    ServerCallTracerFactoryCapsule objects used by gRPC core. It should be
    registered with observability_init at the start of a program, before any
    channels/servers are built.

    Any future methods added to this interface cannot have the
    @abc.abstractmethod annotation: existing subclasses would not implement
    them and could then no longer be instantiated.

    Attributes:
      _tracing_enabled: A bool indicating whether tracing is enabled.
      _stats_enabled: A bool indicating whether stats (metrics) is enabled.
    """

    # Class-level defaults; set_tracing/set_stats shadow them per instance.
    _tracing_enabled: bool = False
    _stats_enabled: bool = False

    @abc.abstractmethod
    def create_client_call_tracer(
        self, method_name: bytes, target: bytes
    ) -> ClientCallTracerCapsule:
        """Creates a ClientCallTracerCapsule.

        After the plugin is registered, if tracing or stats is enabled, this
        method will be called after a call was created, and the
        ClientCallTracer it creates will be saved to the call context.

        The ClientCallTracer is an object which implements the
        `grpc_core::ClientCallTracer` interface, wrapped in a PyCapsule using
        `client_call_tracer` as name.

        Args:
          method_name: The method name of the call in byte format.
          target: The channel target of the call in byte format.

        Returns:
          A PyCapsule which stores a ClientCallTracer object.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def save_trace_context(
        self, trace_id: str, span_id: str, is_sampled: bool
    ) -> None:
        """Saves the trace_id and span_id related to the current span.

        After the plugin is registered, if tracing is enabled, this method
        will be called after the server finished sending response.

        This method can be used to propagate census context.

        Args:
          trace_id: The identifier for the trace associated with the span as a
            32-character hexadecimal encoded string,
            e.g. 26ed0036f2eff2b7317bccce3e28d01f
          span_id: The identifier for the span as a 16-character hexadecimal
            encoded string. e.g. 113ec879e62583bc
          is_sampled: A bool indicating whether the span is sampled.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def create_server_call_tracer_factory(
        self,
        *,
        xds: bool = False,
    ) -> Optional[ServerCallTracerFactoryCapsule]:
        """Creates a ServerCallTracerFactoryCapsule.

        This method will be called at server initialization time to create a
        ServerCallTracerFactory, which will be registered to gRPC core.

        The ServerCallTracerFactory is an object which implements the
        `grpc_core::ServerCallTracerFactory` interface, wrapped in a PyCapsule
        using `server_call_tracer_factory` as name.

        Args:
          xds: Whether the server is an xDS server.

        Returns:
          A PyCapsule which stores a ServerCallTracerFactory object, or None if
          the plugin decides not to create a ServerCallTracerFactory.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def record_rpc_latency(
        self, method: str, target: str, rpc_latency: float, status_code: Any
    ) -> None:
        """Record the latency of the RPC.

        After the plugin is registered, if stats is enabled, this method will
        be called at the end of each RPC.

        Args:
          method: The fully-qualified name of the RPC method being invoked.
          target: The target name of the RPC method being invoked.
          rpc_latency: The latency for the RPC in milliseconds (the in-module
            caller, maybe_record_rpc_latency, multiplies the elapsed time by
            1000). It is the time between when the client invokes the RPC and
            when the client receives the status.
          status_code: The final status for the RPC, taken from the RPC
            state's ``code`` (presumably an element of grpc.StatusCode —
            confirm the exact form plugins receive).
        """
        raise NotImplementedError()

    def set_tracing(self, enable: bool) -> None:
        """Enable or disable tracing.

        Args:
          enable: A bool indicating whether tracing should be enabled.
        """
        self._tracing_enabled = enable

    def set_stats(self, enable: bool) -> None:
        """Enable or disable stats (metrics).

        Args:
          enable: A bool indicating whether stats should be enabled.
        """
        self._stats_enabled = enable

    def save_registered_method(self, method_name: bytes) -> None:
        """Saves the method name as a registered method.

        When exporting metrics, method names for unregistered methods will be
        replaced with 'other' by default.

        This method is deliberately not abstract (see the class docstring);
        the base implementation raises NotImplementedError.

        Args:
          method_name: The method name in bytes.
        """
        raise NotImplementedError()

    @property
    def tracing_enabled(self) -> bool:
        """Whether tracing is currently enabled."""
        return self._tracing_enabled

    @property
    def stats_enabled(self) -> bool:
        """Whether stats (metrics) collection is currently enabled."""
        return self._stats_enabled

    @property
    def observability_enabled(self) -> bool:
        """True when either tracing or stats is enabled."""
        return self.tracing_enabled or self.stats_enabled
@contextlib.contextmanager
def get_plugin() -> Generator[Optional[ObservabilityPlugin], None, None]:
    """Yield the currently registered ObservabilityPlugin under the plugin lock.

    Intended for use as ``with get_plugin() as plugin: ...``. ``_plugin_lock``
    stays held for the whole ``with`` body, so a set_plugin call from another
    thread blocks until the caller is done with the plugin.

    Yields:
      The ObservabilityPlugin registered with this module, or None if no
      plugin is registered at the time of calling.
    """
    _plugin_lock.acquire()
    try:
        yield _OBSERVABILITY_PLUGIN
    finally:
        # Runs on normal exit and when the with-body raises.
        _plugin_lock.release()
def set_plugin(observability_plugin: Optional[ObservabilityPlugin]) -> None:
    """Save ObservabilityPlugin to _observability module.

    Args:
      observability_plugin: The ObservabilityPlugin to save, or None to clear
        the currently registered plugin.

    Raises:
      ValueError: If observability_plugin is not None and an
        ObservabilityPlugin was already registered at the time of calling
        this method.
    """
    global _OBSERVABILITY_PLUGIN  # pylint: disable=global-statement
    with _plugin_lock:
        # Registering over an existing plugin is an error. Clearing (None) is
        # always allowed so observability_deinit can reset the module. Compare
        # against None rather than relying on truthiness, so a plugin object
        # that happens to be falsy cannot bypass the check.
        if (
            observability_plugin is not None
            and _OBSERVABILITY_PLUGIN is not None
        ):
            raise ValueError("observability_plugin was already set!")
        _OBSERVABILITY_PLUGIN = observability_plugin
def observability_init(observability_plugin: ObservabilityPlugin) -> None:
    """Register ``observability_plugin`` as this module's observability plugin.

    Call this at the start of a program, before any channel or server is
    built.

    Args:
      observability_plugin: The ObservabilityPlugin to register.

    Raises:
      ValueError: If an ObservabilityPlugin is already registered.
    """
    # Registration and the duplicate check both live in set_plugin.
    set_plugin(observability_plugin=observability_plugin)
def observability_deinit() -> None:
    """Undo observability_init so that it can be called again.

    Unregisters the ObservabilityPlugin and then asks _cygrpc to clear its
    server call tracer factory. Call this after leaving the observability
    context.
    """
    # Unregistering cannot raise: set_plugin only rejects a non-None plugin
    # when one is already registered.
    set_plugin(observability_plugin=None)
    _cygrpc.clear_server_call_tracer_factory()
def maybe_record_rpc_latency(state: "_channel._RPCState") -> None:
    """Record the latency of the RPC, if the plugin is registered and stats is enabled.

    This method will be called at the end of each RPC. RPCs whose encoded
    method name contains any entry of _SERVICES_TO_EXCLUDE are skipped.

    The latency passed to ObservabilityPlugin.record_rpc_latency is
    ``(state.rpc_end_time - state.rpc_start_time) * 1000``. This is in
    milliseconds, assuming both timestamps are in seconds as the local naming
    implies.

    Args:
      state: a grpc._channel._RPCState object which contains the stats
        related to the RPC.
    """
    # TODO(xuanwn): use channel args to exclude those metrics.
    # Encode once rather than once per excluded service. Note that the check
    # is a substring match (``in``), not a prefix match.
    method_bytes = state.method.encode("utf8")
    if any(service in method_bytes for service in _SERVICES_TO_EXCLUDE):
        return
    with get_plugin() as plugin:
        if plugin and plugin.stats_enabled:
            rpc_latency_s = state.rpc_end_time - state.rpc_start_time
            rpc_latency_ms = rpc_latency_s * 1000
            plugin.record_rpc_latency(
                state.method, state.target, rpc_latency_ms, state.code
            )
def create_server_call_tracer_factory_option(xds: bool) -> ChannelArgumentType:
    """Build the option that carries the plugin's ServerCallTracerFactory.

    Args:
      xds: Whether the server is an xDS server; forwarded to
        ``_cygrpc.get_server_call_tracer_factory_address``.

    Returns:
      A one-element tuple holding the
      ``("grpc.experimental.server_call_tracer_factory", ServerCallTracerFactory)``
      option. That tuple is returned when a plugin is registered, stats is
      enabled, and the factory address obtained from _cygrpc is truthy.
      Otherwise an empty tuple is returned.
    """
    with get_plugin() as plugin:
        # Only a registered plugin with stats turned on gets a factory.
        if not plugin or not plugin.stats_enabled:
            return ()
        factory_address = _cygrpc.get_server_call_tracer_factory_address(
            plugin, xds
        )
        if not factory_address:
            return ()
        option_key = "grpc.experimental.server_call_tracer_factory"
        return ((option_key, ServerCallTracerFactory(factory_address)),)
|