_channel.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627
  1. # Copyright 2019 gRPC authors.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Invocation-side implementation of gRPC Asyncio Python."""
  15. import asyncio
  16. import sys
  17. from typing import Any, Iterable, List, Optional, Sequence
  18. import grpc
  19. from grpc import _common
  20. from grpc import _compression
  21. from grpc import _grpcio_metadata
  22. from grpc._cython import cygrpc
  23. from . import _base_call
  24. from . import _base_channel
  25. from ._call import StreamStreamCall
  26. from ._call import StreamUnaryCall
  27. from ._call import UnaryStreamCall
  28. from ._call import UnaryUnaryCall
  29. from ._interceptor import ClientInterceptor
  30. from ._interceptor import InterceptedStreamStreamCall
  31. from ._interceptor import InterceptedStreamUnaryCall
  32. from ._interceptor import InterceptedUnaryStreamCall
  33. from ._interceptor import InterceptedUnaryUnaryCall
  34. from ._interceptor import StreamStreamClientInterceptor
  35. from ._interceptor import StreamUnaryClientInterceptor
  36. from ._interceptor import UnaryStreamClientInterceptor
  37. from ._interceptor import UnaryUnaryClientInterceptor
  38. from ._metadata import Metadata
  39. from ._typing import ChannelArgumentType
  40. from ._typing import DeserializingFunction
  41. from ._typing import MetadataType
  42. from ._typing import RequestIterableType
  43. from ._typing import RequestType
  44. from ._typing import ResponseType
  45. from ._typing import SerializingFunction
  46. from ._utils import _timeout_to_deadline
  47. _USER_AGENT = "grpc-python-asyncio/{}".format(_grpcio_metadata.__version__)
  48. if sys.version_info[1] < 7:
  49. def _all_tasks() -> Iterable[asyncio.Task]:
  50. return asyncio.Task.all_tasks() # pylint: disable=no-member
  51. else:
  52. def _all_tasks() -> Iterable[asyncio.Task]:
  53. return asyncio.all_tasks()
  54. def _augment_channel_arguments(
  55. base_options: ChannelArgumentType, compression: Optional[grpc.Compression]
  56. ):
  57. compression_channel_argument = _compression.create_channel_option(
  58. compression
  59. )
  60. user_agent_channel_argument = (
  61. (
  62. cygrpc.ChannelArgKey.primary_user_agent_string,
  63. _USER_AGENT,
  64. ),
  65. )
  66. return (
  67. tuple(base_options)
  68. + compression_channel_argument
  69. + user_agent_channel_argument
  70. )
  71. class _BaseMultiCallable:
  72. """Base class of all multi callable objects.
  73. Handles the initialization logic and stores common attributes.
  74. """
  75. _loop: asyncio.AbstractEventLoop
  76. _channel: cygrpc.AioChannel
  77. _method: bytes
  78. _request_serializer: SerializingFunction
  79. _response_deserializer: DeserializingFunction
  80. _interceptors: Optional[Sequence[ClientInterceptor]]
  81. _references: List[Any]
  82. _loop: asyncio.AbstractEventLoop
  83. # pylint: disable=too-many-arguments
  84. def __init__(
  85. self,
  86. channel: cygrpc.AioChannel,
  87. method: bytes,
  88. request_serializer: SerializingFunction,
  89. response_deserializer: DeserializingFunction,
  90. interceptors: Optional[Sequence[ClientInterceptor]],
  91. references: List[Any],
  92. loop: asyncio.AbstractEventLoop,
  93. ) -> None:
  94. self._loop = loop
  95. self._channel = channel
  96. self._method = method
  97. self._request_serializer = request_serializer
  98. self._response_deserializer = response_deserializer
  99. self._interceptors = interceptors
  100. self._references = references
  101. @staticmethod
  102. def _init_metadata(
  103. metadata: Optional[MetadataType] = None,
  104. compression: Optional[grpc.Compression] = None,
  105. ) -> Metadata:
  106. """Based on the provided values for <metadata> or <compression> initialise the final
  107. metadata, as it should be used for the current call.
  108. """
  109. metadata = metadata or Metadata()
  110. if not isinstance(metadata, Metadata) and isinstance(metadata, tuple):
  111. metadata = Metadata.from_tuple(metadata)
  112. if compression:
  113. metadata = Metadata(
  114. *_compression.augment_metadata(metadata, compression)
  115. )
  116. return metadata
  117. class UnaryUnaryMultiCallable(
  118. _BaseMultiCallable, _base_channel.UnaryUnaryMultiCallable
  119. ):
  120. def __call__(
  121. self,
  122. request: RequestType,
  123. *,
  124. timeout: Optional[float] = None,
  125. metadata: Optional[MetadataType] = None,
  126. credentials: Optional[grpc.CallCredentials] = None,
  127. wait_for_ready: Optional[bool] = None,
  128. compression: Optional[grpc.Compression] = None,
  129. ) -> _base_call.UnaryUnaryCall[RequestType, ResponseType]:
  130. metadata = self._init_metadata(metadata, compression)
  131. if not self._interceptors:
  132. call = UnaryUnaryCall(
  133. request,
  134. _timeout_to_deadline(timeout),
  135. metadata,
  136. credentials,
  137. wait_for_ready,
  138. self._channel,
  139. self._method,
  140. self._request_serializer,
  141. self._response_deserializer,
  142. self._loop,
  143. )
  144. else:
  145. call = InterceptedUnaryUnaryCall(
  146. self._interceptors,
  147. request,
  148. timeout,
  149. metadata,
  150. credentials,
  151. wait_for_ready,
  152. self._channel,
  153. self._method,
  154. self._request_serializer,
  155. self._response_deserializer,
  156. self._loop,
  157. )
  158. return call
  159. class UnaryStreamMultiCallable(
  160. _BaseMultiCallable, _base_channel.UnaryStreamMultiCallable
  161. ):
  162. def __call__(
  163. self,
  164. request: RequestType,
  165. *,
  166. timeout: Optional[float] = None,
  167. metadata: Optional[MetadataType] = None,
  168. credentials: Optional[grpc.CallCredentials] = None,
  169. wait_for_ready: Optional[bool] = None,
  170. compression: Optional[grpc.Compression] = None,
  171. ) -> _base_call.UnaryStreamCall[RequestType, ResponseType]:
  172. metadata = self._init_metadata(metadata, compression)
  173. if not self._interceptors:
  174. call = UnaryStreamCall(
  175. request,
  176. _timeout_to_deadline(timeout),
  177. metadata,
  178. credentials,
  179. wait_for_ready,
  180. self._channel,
  181. self._method,
  182. self._request_serializer,
  183. self._response_deserializer,
  184. self._loop,
  185. )
  186. else:
  187. call = InterceptedUnaryStreamCall(
  188. self._interceptors,
  189. request,
  190. timeout,
  191. metadata,
  192. credentials,
  193. wait_for_ready,
  194. self._channel,
  195. self._method,
  196. self._request_serializer,
  197. self._response_deserializer,
  198. self._loop,
  199. )
  200. return call
  201. class StreamUnaryMultiCallable(
  202. _BaseMultiCallable, _base_channel.StreamUnaryMultiCallable
  203. ):
  204. def __call__(
  205. self,
  206. request_iterator: Optional[RequestIterableType] = None,
  207. timeout: Optional[float] = None,
  208. metadata: Optional[MetadataType] = None,
  209. credentials: Optional[grpc.CallCredentials] = None,
  210. wait_for_ready: Optional[bool] = None,
  211. compression: Optional[grpc.Compression] = None,
  212. ) -> _base_call.StreamUnaryCall:
  213. metadata = self._init_metadata(metadata, compression)
  214. if not self._interceptors:
  215. call = StreamUnaryCall(
  216. request_iterator,
  217. _timeout_to_deadline(timeout),
  218. metadata,
  219. credentials,
  220. wait_for_ready,
  221. self._channel,
  222. self._method,
  223. self._request_serializer,
  224. self._response_deserializer,
  225. self._loop,
  226. )
  227. else:
  228. call = InterceptedStreamUnaryCall(
  229. self._interceptors,
  230. request_iterator,
  231. timeout,
  232. metadata,
  233. credentials,
  234. wait_for_ready,
  235. self._channel,
  236. self._method,
  237. self._request_serializer,
  238. self._response_deserializer,
  239. self._loop,
  240. )
  241. return call
  242. class StreamStreamMultiCallable(
  243. _BaseMultiCallable, _base_channel.StreamStreamMultiCallable
  244. ):
  245. def __call__(
  246. self,
  247. request_iterator: Optional[RequestIterableType] = None,
  248. timeout: Optional[float] = None,
  249. metadata: Optional[MetadataType] = None,
  250. credentials: Optional[grpc.CallCredentials] = None,
  251. wait_for_ready: Optional[bool] = None,
  252. compression: Optional[grpc.Compression] = None,
  253. ) -> _base_call.StreamStreamCall:
  254. metadata = self._init_metadata(metadata, compression)
  255. if not self._interceptors:
  256. call = StreamStreamCall(
  257. request_iterator,
  258. _timeout_to_deadline(timeout),
  259. metadata,
  260. credentials,
  261. wait_for_ready,
  262. self._channel,
  263. self._method,
  264. self._request_serializer,
  265. self._response_deserializer,
  266. self._loop,
  267. )
  268. else:
  269. call = InterceptedStreamStreamCall(
  270. self._interceptors,
  271. request_iterator,
  272. timeout,
  273. metadata,
  274. credentials,
  275. wait_for_ready,
  276. self._channel,
  277. self._method,
  278. self._request_serializer,
  279. self._response_deserializer,
  280. self._loop,
  281. )
  282. return call
  283. class Channel(_base_channel.Channel):
  284. _loop: asyncio.AbstractEventLoop
  285. _channel: cygrpc.AioChannel
  286. _unary_unary_interceptors: List[UnaryUnaryClientInterceptor]
  287. _unary_stream_interceptors: List[UnaryStreamClientInterceptor]
  288. _stream_unary_interceptors: List[StreamUnaryClientInterceptor]
  289. _stream_stream_interceptors: List[StreamStreamClientInterceptor]
  290. def __init__(
  291. self,
  292. target: str,
  293. options: ChannelArgumentType,
  294. credentials: Optional[grpc.ChannelCredentials],
  295. compression: Optional[grpc.Compression],
  296. interceptors: Optional[Sequence[ClientInterceptor]],
  297. ):
  298. """Constructor.
  299. Args:
  300. target: The target to which to connect.
  301. options: Configuration options for the channel.
  302. credentials: A cygrpc.ChannelCredentials or None.
  303. compression: An optional value indicating the compression method to be
  304. used over the lifetime of the channel.
  305. interceptors: An optional list of interceptors that would be used for
  306. intercepting any RPC executed with that channel.
  307. """
  308. self._unary_unary_interceptors = []
  309. self._unary_stream_interceptors = []
  310. self._stream_unary_interceptors = []
  311. self._stream_stream_interceptors = []
  312. if interceptors is not None:
  313. for interceptor in interceptors:
  314. if isinstance(interceptor, UnaryUnaryClientInterceptor):
  315. self._unary_unary_interceptors.append(interceptor)
  316. elif isinstance(interceptor, UnaryStreamClientInterceptor):
  317. self._unary_stream_interceptors.append(interceptor)
  318. elif isinstance(interceptor, StreamUnaryClientInterceptor):
  319. self._stream_unary_interceptors.append(interceptor)
  320. elif isinstance(interceptor, StreamStreamClientInterceptor):
  321. self._stream_stream_interceptors.append(interceptor)
  322. else:
  323. raise ValueError(
  324. "Interceptor {} must be ".format(interceptor)
  325. + "{} or ".format(UnaryUnaryClientInterceptor.__name__)
  326. + "{} or ".format(UnaryStreamClientInterceptor.__name__)
  327. + "{} or ".format(StreamUnaryClientInterceptor.__name__)
  328. + "{}. ".format(StreamStreamClientInterceptor.__name__)
  329. )
  330. self._loop = cygrpc.get_working_loop()
  331. self._channel = cygrpc.AioChannel(
  332. _common.encode(target),
  333. _augment_channel_arguments(options, compression),
  334. credentials,
  335. self._loop,
  336. )
  337. async def __aenter__(self):
  338. return self
  339. async def __aexit__(self, exc_type, exc_val, exc_tb):
  340. await self._close(None)
  341. async def _close(self, grace): # pylint: disable=too-many-branches
  342. if self._channel.closed():
  343. return
  344. # No new calls will be accepted by the Cython channel.
  345. self._channel.closing()
  346. # Iterate through running tasks
  347. tasks = _all_tasks()
  348. calls = []
  349. call_tasks = []
  350. for task in tasks:
  351. try:
  352. stack = task.get_stack(limit=1)
  353. except AttributeError as attribute_error:
  354. # NOTE(lidiz) tl;dr: If the Task is created with a CPython
  355. # object, it will trigger AttributeError.
  356. #
  357. # In the global finalizer, the event loop schedules
  358. # a CPython PyAsyncGenAThrow object.
  359. # https://github.com/python/cpython/blob/00e45877e33d32bb61aa13a2033e3bba370bda4d/Lib/asyncio/base_events.py#L484
  360. #
  361. # However, the PyAsyncGenAThrow object is written in C and
  362. # failed to include the normal Python frame objects. Hence,
  363. # this exception is a false negative, and it is safe to ignore
  364. # the failure. It is fixed by https://github.com/python/cpython/pull/18669,
  365. # but not available until 3.9 or 3.8.3. So, we have to keep it
  366. # for a while.
  367. # TODO(lidiz) drop this hack after 3.8 deprecation
  368. if "frame" in str(attribute_error):
  369. continue
  370. else:
  371. raise
  372. # If the Task is created by a C-extension, the stack will be empty.
  373. if not stack:
  374. continue
  375. # Locate ones created by `aio.Call`.
  376. frame = stack[0]
  377. candidate = frame.f_locals.get("self")
  378. # Explicitly check for a non-null candidate instead of the more pythonic 'if candidate:'
  379. # because doing 'if candidate:' assumes that the coroutine implements '__bool__' which
  380. # might not always be the case.
  381. if candidate is not None:
  382. if isinstance(candidate, _base_call.Call):
  383. if hasattr(candidate, "_channel"):
  384. # For intercepted Call object
  385. if candidate._channel is not self._channel:
  386. continue
  387. elif hasattr(candidate, "_cython_call"):
  388. # For normal Call object
  389. if candidate._cython_call._channel is not self._channel:
  390. continue
  391. else:
  392. # Unidentified Call object
  393. raise cygrpc.InternalError(
  394. f"Unrecognized call object: {candidate}"
  395. )
  396. calls.append(candidate)
  397. call_tasks.append(task)
  398. # If needed, try to wait for them to finish.
  399. # Call objects are not always awaitables.
  400. if grace and call_tasks:
  401. await asyncio.wait(call_tasks, timeout=grace)
  402. # Time to cancel existing calls.
  403. for call in calls:
  404. call.cancel()
  405. # Destroy the channel
  406. self._channel.close()
  407. async def close(self, grace: Optional[float] = None):
  408. await self._close(grace)
  409. def __del__(self):
  410. if hasattr(self, "_channel"):
  411. if not self._channel.closed():
  412. self._channel.close()
  413. def get_state(
  414. self, try_to_connect: bool = False
  415. ) -> grpc.ChannelConnectivity:
  416. result = self._channel.check_connectivity_state(try_to_connect)
  417. return _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[result]
  418. async def wait_for_state_change(
  419. self,
  420. last_observed_state: grpc.ChannelConnectivity,
  421. ) -> None:
  422. assert await self._channel.watch_connectivity_state(
  423. last_observed_state.value[0], None
  424. )
  425. async def channel_ready(self) -> None:
  426. state = self.get_state(try_to_connect=True)
  427. while state != grpc.ChannelConnectivity.READY:
  428. await self.wait_for_state_change(state)
  429. state = self.get_state(try_to_connect=True)
  430. # TODO(xuanwn): Implement this method after we have
  431. # observability for Asyncio.
  432. def _get_registered_call_handle(self, method: str) -> int:
  433. pass
  434. # TODO(xuanwn): Implement _registered_method after we have
  435. # observability for Asyncio.
  436. # pylint: disable=arguments-differ,unused-argument
  437. def unary_unary(
  438. self,
  439. method: str,
  440. request_serializer: Optional[SerializingFunction] = None,
  441. response_deserializer: Optional[DeserializingFunction] = None,
  442. _registered_method: Optional[bool] = False,
  443. ) -> UnaryUnaryMultiCallable:
  444. return UnaryUnaryMultiCallable(
  445. self._channel,
  446. _common.encode(method),
  447. request_serializer,
  448. response_deserializer,
  449. self._unary_unary_interceptors,
  450. [self],
  451. self._loop,
  452. )
  453. # TODO(xuanwn): Implement _registered_method after we have
  454. # observability for Asyncio.
  455. # pylint: disable=arguments-differ,unused-argument
  456. def unary_stream(
  457. self,
  458. method: str,
  459. request_serializer: Optional[SerializingFunction] = None,
  460. response_deserializer: Optional[DeserializingFunction] = None,
  461. _registered_method: Optional[bool] = False,
  462. ) -> UnaryStreamMultiCallable:
  463. return UnaryStreamMultiCallable(
  464. self._channel,
  465. _common.encode(method),
  466. request_serializer,
  467. response_deserializer,
  468. self._unary_stream_interceptors,
  469. [self],
  470. self._loop,
  471. )
  472. # TODO(xuanwn): Implement _registered_method after we have
  473. # observability for Asyncio.
  474. # pylint: disable=arguments-differ,unused-argument
  475. def stream_unary(
  476. self,
  477. method: str,
  478. request_serializer: Optional[SerializingFunction] = None,
  479. response_deserializer: Optional[DeserializingFunction] = None,
  480. _registered_method: Optional[bool] = False,
  481. ) -> StreamUnaryMultiCallable:
  482. return StreamUnaryMultiCallable(
  483. self._channel,
  484. _common.encode(method),
  485. request_serializer,
  486. response_deserializer,
  487. self._stream_unary_interceptors,
  488. [self],
  489. self._loop,
  490. )
  491. # TODO(xuanwn): Implement _registered_method after we have
  492. # observability for Asyncio.
  493. # pylint: disable=arguments-differ,unused-argument
  494. def stream_stream(
  495. self,
  496. method: str,
  497. request_serializer: Optional[SerializingFunction] = None,
  498. response_deserializer: Optional[DeserializingFunction] = None,
  499. _registered_method: Optional[bool] = False,
  500. ) -> StreamStreamMultiCallable:
  501. return StreamStreamMultiCallable(
  502. self._channel,
  503. _common.encode(method),
  504. request_serializer,
  505. response_deserializer,
  506. self._stream_stream_interceptors,
  507. [self],
  508. self._loop,
  509. )
  510. def insecure_channel(
  511. target: str,
  512. options: Optional[ChannelArgumentType] = None,
  513. compression: Optional[grpc.Compression] = None,
  514. interceptors: Optional[Sequence[ClientInterceptor]] = None,
  515. ):
  516. """Creates an insecure asynchronous Channel to a server.
  517. Args:
  518. target: The server address
  519. options: An optional list of key-value pairs (:term:`channel_arguments`
  520. in gRPC Core runtime) to configure the channel.
  521. compression: An optional value indicating the compression method to be
  522. used over the lifetime of the channel.
  523. interceptors: An optional sequence of interceptors that will be executed for
  524. any call executed with this channel.
  525. Returns:
  526. A Channel.
  527. """
  528. return Channel(
  529. target,
  530. () if options is None else options,
  531. None,
  532. compression,
  533. interceptors,
  534. )
  535. def secure_channel(
  536. target: str,
  537. credentials: grpc.ChannelCredentials,
  538. options: Optional[ChannelArgumentType] = None,
  539. compression: Optional[grpc.Compression] = None,
  540. interceptors: Optional[Sequence[ClientInterceptor]] = None,
  541. ):
  542. """Creates a secure asynchronous Channel to a server.
  543. Args:
  544. target: The server address.
  545. credentials: A ChannelCredentials instance.
  546. options: An optional list of key-value pairs (:term:`channel_arguments`
  547. in gRPC Core runtime) to configure the channel.
  548. compression: An optional value indicating the compression method to be
  549. used over the lifetime of the channel.
  550. interceptors: An optional sequence of interceptors that will be executed for
  551. any call executed with this channel.
  552. Returns:
  553. An aio.Channel.
  554. """
  555. return Channel(
  556. target,
  557. () if options is None else options,
  558. credentials._credentials,
  559. compression,
  560. interceptors,
  561. )