123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627 |
- # Copyright 2019 gRPC authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """Invocation-side implementation of gRPC Asyncio Python."""
- import asyncio
- import sys
- from typing import Any, Iterable, List, Optional, Sequence
- import grpc
- from grpc import _common
- from grpc import _compression
- from grpc import _grpcio_metadata
- from grpc._cython import cygrpc
- from . import _base_call
- from . import _base_channel
- from ._call import StreamStreamCall
- from ._call import StreamUnaryCall
- from ._call import UnaryStreamCall
- from ._call import UnaryUnaryCall
- from ._interceptor import ClientInterceptor
- from ._interceptor import InterceptedStreamStreamCall
- from ._interceptor import InterceptedStreamUnaryCall
- from ._interceptor import InterceptedUnaryStreamCall
- from ._interceptor import InterceptedUnaryUnaryCall
- from ._interceptor import StreamStreamClientInterceptor
- from ._interceptor import StreamUnaryClientInterceptor
- from ._interceptor import UnaryStreamClientInterceptor
- from ._interceptor import UnaryUnaryClientInterceptor
- from ._metadata import Metadata
- from ._typing import ChannelArgumentType
- from ._typing import DeserializingFunction
- from ._typing import MetadataType
- from ._typing import RequestIterableType
- from ._typing import RequestType
- from ._typing import ResponseType
- from ._typing import SerializingFunction
- from ._utils import _timeout_to_deadline
- _USER_AGENT = "grpc-python-asyncio/{}".format(_grpcio_metadata.__version__)
- if sys.version_info[1] < 7:
- def _all_tasks() -> Iterable[asyncio.Task]:
- return asyncio.Task.all_tasks() # pylint: disable=no-member
- else:
- def _all_tasks() -> Iterable[asyncio.Task]:
- return asyncio.all_tasks()
- def _augment_channel_arguments(
- base_options: ChannelArgumentType, compression: Optional[grpc.Compression]
- ):
- compression_channel_argument = _compression.create_channel_option(
- compression
- )
- user_agent_channel_argument = (
- (
- cygrpc.ChannelArgKey.primary_user_agent_string,
- _USER_AGENT,
- ),
- )
- return (
- tuple(base_options)
- + compression_channel_argument
- + user_agent_channel_argument
- )
- class _BaseMultiCallable:
- """Base class of all multi callable objects.
- Handles the initialization logic and stores common attributes.
- """
- _loop: asyncio.AbstractEventLoop
- _channel: cygrpc.AioChannel
- _method: bytes
- _request_serializer: SerializingFunction
- _response_deserializer: DeserializingFunction
- _interceptors: Optional[Sequence[ClientInterceptor]]
- _references: List[Any]
- _loop: asyncio.AbstractEventLoop
- # pylint: disable=too-many-arguments
- def __init__(
- self,
- channel: cygrpc.AioChannel,
- method: bytes,
- request_serializer: SerializingFunction,
- response_deserializer: DeserializingFunction,
- interceptors: Optional[Sequence[ClientInterceptor]],
- references: List[Any],
- loop: asyncio.AbstractEventLoop,
- ) -> None:
- self._loop = loop
- self._channel = channel
- self._method = method
- self._request_serializer = request_serializer
- self._response_deserializer = response_deserializer
- self._interceptors = interceptors
- self._references = references
- @staticmethod
- def _init_metadata(
- metadata: Optional[MetadataType] = None,
- compression: Optional[grpc.Compression] = None,
- ) -> Metadata:
- """Based on the provided values for <metadata> or <compression> initialise the final
- metadata, as it should be used for the current call.
- """
- metadata = metadata or Metadata()
- if not isinstance(metadata, Metadata) and isinstance(metadata, tuple):
- metadata = Metadata.from_tuple(metadata)
- if compression:
- metadata = Metadata(
- *_compression.augment_metadata(metadata, compression)
- )
- return metadata
- class UnaryUnaryMultiCallable(
- _BaseMultiCallable, _base_channel.UnaryUnaryMultiCallable
- ):
- def __call__(
- self,
- request: RequestType,
- *,
- timeout: Optional[float] = None,
- metadata: Optional[MetadataType] = None,
- credentials: Optional[grpc.CallCredentials] = None,
- wait_for_ready: Optional[bool] = None,
- compression: Optional[grpc.Compression] = None,
- ) -> _base_call.UnaryUnaryCall[RequestType, ResponseType]:
- metadata = self._init_metadata(metadata, compression)
- if not self._interceptors:
- call = UnaryUnaryCall(
- request,
- _timeout_to_deadline(timeout),
- metadata,
- credentials,
- wait_for_ready,
- self._channel,
- self._method,
- self._request_serializer,
- self._response_deserializer,
- self._loop,
- )
- else:
- call = InterceptedUnaryUnaryCall(
- self._interceptors,
- request,
- timeout,
- metadata,
- credentials,
- wait_for_ready,
- self._channel,
- self._method,
- self._request_serializer,
- self._response_deserializer,
- self._loop,
- )
- return call
- class UnaryStreamMultiCallable(
- _BaseMultiCallable, _base_channel.UnaryStreamMultiCallable
- ):
- def __call__(
- self,
- request: RequestType,
- *,
- timeout: Optional[float] = None,
- metadata: Optional[MetadataType] = None,
- credentials: Optional[grpc.CallCredentials] = None,
- wait_for_ready: Optional[bool] = None,
- compression: Optional[grpc.Compression] = None,
- ) -> _base_call.UnaryStreamCall[RequestType, ResponseType]:
- metadata = self._init_metadata(metadata, compression)
- if not self._interceptors:
- call = UnaryStreamCall(
- request,
- _timeout_to_deadline(timeout),
- metadata,
- credentials,
- wait_for_ready,
- self._channel,
- self._method,
- self._request_serializer,
- self._response_deserializer,
- self._loop,
- )
- else:
- call = InterceptedUnaryStreamCall(
- self._interceptors,
- request,
- timeout,
- metadata,
- credentials,
- wait_for_ready,
- self._channel,
- self._method,
- self._request_serializer,
- self._response_deserializer,
- self._loop,
- )
- return call
- class StreamUnaryMultiCallable(
- _BaseMultiCallable, _base_channel.StreamUnaryMultiCallable
- ):
- def __call__(
- self,
- request_iterator: Optional[RequestIterableType] = None,
- timeout: Optional[float] = None,
- metadata: Optional[MetadataType] = None,
- credentials: Optional[grpc.CallCredentials] = None,
- wait_for_ready: Optional[bool] = None,
- compression: Optional[grpc.Compression] = None,
- ) -> _base_call.StreamUnaryCall:
- metadata = self._init_metadata(metadata, compression)
- if not self._interceptors:
- call = StreamUnaryCall(
- request_iterator,
- _timeout_to_deadline(timeout),
- metadata,
- credentials,
- wait_for_ready,
- self._channel,
- self._method,
- self._request_serializer,
- self._response_deserializer,
- self._loop,
- )
- else:
- call = InterceptedStreamUnaryCall(
- self._interceptors,
- request_iterator,
- timeout,
- metadata,
- credentials,
- wait_for_ready,
- self._channel,
- self._method,
- self._request_serializer,
- self._response_deserializer,
- self._loop,
- )
- return call
- class StreamStreamMultiCallable(
- _BaseMultiCallable, _base_channel.StreamStreamMultiCallable
- ):
- def __call__(
- self,
- request_iterator: Optional[RequestIterableType] = None,
- timeout: Optional[float] = None,
- metadata: Optional[MetadataType] = None,
- credentials: Optional[grpc.CallCredentials] = None,
- wait_for_ready: Optional[bool] = None,
- compression: Optional[grpc.Compression] = None,
- ) -> _base_call.StreamStreamCall:
- metadata = self._init_metadata(metadata, compression)
- if not self._interceptors:
- call = StreamStreamCall(
- request_iterator,
- _timeout_to_deadline(timeout),
- metadata,
- credentials,
- wait_for_ready,
- self._channel,
- self._method,
- self._request_serializer,
- self._response_deserializer,
- self._loop,
- )
- else:
- call = InterceptedStreamStreamCall(
- self._interceptors,
- request_iterator,
- timeout,
- metadata,
- credentials,
- wait_for_ready,
- self._channel,
- self._method,
- self._request_serializer,
- self._response_deserializer,
- self._loop,
- )
- return call
- class Channel(_base_channel.Channel):
- _loop: asyncio.AbstractEventLoop
- _channel: cygrpc.AioChannel
- _unary_unary_interceptors: List[UnaryUnaryClientInterceptor]
- _unary_stream_interceptors: List[UnaryStreamClientInterceptor]
- _stream_unary_interceptors: List[StreamUnaryClientInterceptor]
- _stream_stream_interceptors: List[StreamStreamClientInterceptor]
- def __init__(
- self,
- target: str,
- options: ChannelArgumentType,
- credentials: Optional[grpc.ChannelCredentials],
- compression: Optional[grpc.Compression],
- interceptors: Optional[Sequence[ClientInterceptor]],
- ):
- """Constructor.
- Args:
- target: The target to which to connect.
- options: Configuration options for the channel.
- credentials: A cygrpc.ChannelCredentials or None.
- compression: An optional value indicating the compression method to be
- used over the lifetime of the channel.
- interceptors: An optional list of interceptors that would be used for
- intercepting any RPC executed with that channel.
- """
- self._unary_unary_interceptors = []
- self._unary_stream_interceptors = []
- self._stream_unary_interceptors = []
- self._stream_stream_interceptors = []
- if interceptors is not None:
- for interceptor in interceptors:
- if isinstance(interceptor, UnaryUnaryClientInterceptor):
- self._unary_unary_interceptors.append(interceptor)
- elif isinstance(interceptor, UnaryStreamClientInterceptor):
- self._unary_stream_interceptors.append(interceptor)
- elif isinstance(interceptor, StreamUnaryClientInterceptor):
- self._stream_unary_interceptors.append(interceptor)
- elif isinstance(interceptor, StreamStreamClientInterceptor):
- self._stream_stream_interceptors.append(interceptor)
- else:
- raise ValueError(
- "Interceptor {} must be ".format(interceptor)
- + "{} or ".format(UnaryUnaryClientInterceptor.__name__)
- + "{} or ".format(UnaryStreamClientInterceptor.__name__)
- + "{} or ".format(StreamUnaryClientInterceptor.__name__)
- + "{}. ".format(StreamStreamClientInterceptor.__name__)
- )
- self._loop = cygrpc.get_working_loop()
- self._channel = cygrpc.AioChannel(
- _common.encode(target),
- _augment_channel_arguments(options, compression),
- credentials,
- self._loop,
- )
- async def __aenter__(self):
- return self
- async def __aexit__(self, exc_type, exc_val, exc_tb):
- await self._close(None)
- async def _close(self, grace): # pylint: disable=too-many-branches
- if self._channel.closed():
- return
- # No new calls will be accepted by the Cython channel.
- self._channel.closing()
- # Iterate through running tasks
- tasks = _all_tasks()
- calls = []
- call_tasks = []
- for task in tasks:
- try:
- stack = task.get_stack(limit=1)
- except AttributeError as attribute_error:
- # NOTE(lidiz) tl;dr: If the Task is created with a CPython
- # object, it will trigger AttributeError.
- #
- # In the global finalizer, the event loop schedules
- # a CPython PyAsyncGenAThrow object.
- # https://github.com/python/cpython/blob/00e45877e33d32bb61aa13a2033e3bba370bda4d/Lib/asyncio/base_events.py#L484
- #
- # However, the PyAsyncGenAThrow object is written in C and
- # failed to include the normal Python frame objects. Hence,
- # this exception is a false negative, and it is safe to ignore
- # the failure. It is fixed by https://github.com/python/cpython/pull/18669,
- # but not available until 3.9 or 3.8.3. So, we have to keep it
- # for a while.
- # TODO(lidiz) drop this hack after 3.8 deprecation
- if "frame" in str(attribute_error):
- continue
- else:
- raise
- # If the Task is created by a C-extension, the stack will be empty.
- if not stack:
- continue
- # Locate ones created by `aio.Call`.
- frame = stack[0]
- candidate = frame.f_locals.get("self")
- # Explicitly check for a non-null candidate instead of the more pythonic 'if candidate:'
- # because doing 'if candidate:' assumes that the coroutine implements '__bool__' which
- # might not always be the case.
- if candidate is not None:
- if isinstance(candidate, _base_call.Call):
- if hasattr(candidate, "_channel"):
- # For intercepted Call object
- if candidate._channel is not self._channel:
- continue
- elif hasattr(candidate, "_cython_call"):
- # For normal Call object
- if candidate._cython_call._channel is not self._channel:
- continue
- else:
- # Unidentified Call object
- raise cygrpc.InternalError(
- f"Unrecognized call object: {candidate}"
- )
- calls.append(candidate)
- call_tasks.append(task)
- # If needed, try to wait for them to finish.
- # Call objects are not always awaitables.
- if grace and call_tasks:
- await asyncio.wait(call_tasks, timeout=grace)
- # Time to cancel existing calls.
- for call in calls:
- call.cancel()
- # Destroy the channel
- self._channel.close()
- async def close(self, grace: Optional[float] = None):
- await self._close(grace)
- def __del__(self):
- if hasattr(self, "_channel"):
- if not self._channel.closed():
- self._channel.close()
- def get_state(
- self, try_to_connect: bool = False
- ) -> grpc.ChannelConnectivity:
- result = self._channel.check_connectivity_state(try_to_connect)
- return _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[result]
- async def wait_for_state_change(
- self,
- last_observed_state: grpc.ChannelConnectivity,
- ) -> None:
- assert await self._channel.watch_connectivity_state(
- last_observed_state.value[0], None
- )
- async def channel_ready(self) -> None:
- state = self.get_state(try_to_connect=True)
- while state != grpc.ChannelConnectivity.READY:
- await self.wait_for_state_change(state)
- state = self.get_state(try_to_connect=True)
- # TODO(xuanwn): Implement this method after we have
- # observability for Asyncio.
- def _get_registered_call_handle(self, method: str) -> int:
- pass
- # TODO(xuanwn): Implement _registered_method after we have
- # observability for Asyncio.
- # pylint: disable=arguments-differ,unused-argument
- def unary_unary(
- self,
- method: str,
- request_serializer: Optional[SerializingFunction] = None,
- response_deserializer: Optional[DeserializingFunction] = None,
- _registered_method: Optional[bool] = False,
- ) -> UnaryUnaryMultiCallable:
- return UnaryUnaryMultiCallable(
- self._channel,
- _common.encode(method),
- request_serializer,
- response_deserializer,
- self._unary_unary_interceptors,
- [self],
- self._loop,
- )
- # TODO(xuanwn): Implement _registered_method after we have
- # observability for Asyncio.
- # pylint: disable=arguments-differ,unused-argument
- def unary_stream(
- self,
- method: str,
- request_serializer: Optional[SerializingFunction] = None,
- response_deserializer: Optional[DeserializingFunction] = None,
- _registered_method: Optional[bool] = False,
- ) -> UnaryStreamMultiCallable:
- return UnaryStreamMultiCallable(
- self._channel,
- _common.encode(method),
- request_serializer,
- response_deserializer,
- self._unary_stream_interceptors,
- [self],
- self._loop,
- )
- # TODO(xuanwn): Implement _registered_method after we have
- # observability for Asyncio.
- # pylint: disable=arguments-differ,unused-argument
- def stream_unary(
- self,
- method: str,
- request_serializer: Optional[SerializingFunction] = None,
- response_deserializer: Optional[DeserializingFunction] = None,
- _registered_method: Optional[bool] = False,
- ) -> StreamUnaryMultiCallable:
- return StreamUnaryMultiCallable(
- self._channel,
- _common.encode(method),
- request_serializer,
- response_deserializer,
- self._stream_unary_interceptors,
- [self],
- self._loop,
- )
- # TODO(xuanwn): Implement _registered_method after we have
- # observability for Asyncio.
- # pylint: disable=arguments-differ,unused-argument
- def stream_stream(
- self,
- method: str,
- request_serializer: Optional[SerializingFunction] = None,
- response_deserializer: Optional[DeserializingFunction] = None,
- _registered_method: Optional[bool] = False,
- ) -> StreamStreamMultiCallable:
- return StreamStreamMultiCallable(
- self._channel,
- _common.encode(method),
- request_serializer,
- response_deserializer,
- self._stream_stream_interceptors,
- [self],
- self._loop,
- )
- def insecure_channel(
- target: str,
- options: Optional[ChannelArgumentType] = None,
- compression: Optional[grpc.Compression] = None,
- interceptors: Optional[Sequence[ClientInterceptor]] = None,
- ):
- """Creates an insecure asynchronous Channel to a server.
- Args:
- target: The server address
- options: An optional list of key-value pairs (:term:`channel_arguments`
- in gRPC Core runtime) to configure the channel.
- compression: An optional value indicating the compression method to be
- used over the lifetime of the channel.
- interceptors: An optional sequence of interceptors that will be executed for
- any call executed with this channel.
- Returns:
- A Channel.
- """
- return Channel(
- target,
- () if options is None else options,
- None,
- compression,
- interceptors,
- )
- def secure_channel(
- target: str,
- credentials: grpc.ChannelCredentials,
- options: Optional[ChannelArgumentType] = None,
- compression: Optional[grpc.Compression] = None,
- interceptors: Optional[Sequence[ClientInterceptor]] = None,
- ):
- """Creates a secure asynchronous Channel to a server.
- Args:
- target: The server address.
- credentials: A ChannelCredentials instance.
- options: An optional list of key-value pairs (:term:`channel_arguments`
- in gRPC Core runtime) to configure the channel.
- compression: An optional value indicating the compression method to be
- used over the lifetime of the channel.
- interceptors: An optional sequence of interceptors that will be executed for
- any call executed with this channel.
- Returns:
- An aio.Channel.
- """
- return Channel(
- target,
- () if options is None else options,
- credentials._credentials,
- compression,
- interceptors,
- )
|