_plugin_wrapping.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. # Copyright 2015 gRPC authors.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import collections
  15. import logging
  16. import threading
  17. from typing import Callable, Optional, Type
  18. import grpc
  19. from grpc import _common
  20. from grpc._cython import cygrpc
  21. from grpc._typing import MetadataType
  22. _LOGGER = logging.getLogger(__name__)
  23. class _AuthMetadataContext(
  24. collections.namedtuple(
  25. "AuthMetadataContext",
  26. (
  27. "service_url",
  28. "method_name",
  29. ),
  30. ),
  31. grpc.AuthMetadataContext,
  32. ):
  33. pass
  34. class _CallbackState(object):
  35. def __init__(self):
  36. self.lock = threading.Lock()
  37. self.called = False
  38. self.exception = None
  39. class _AuthMetadataPluginCallback(grpc.AuthMetadataPluginCallback):
  40. _state: _CallbackState
  41. _callback: Callable
  42. def __init__(self, state: _CallbackState, callback: Callable):
  43. self._state = state
  44. self._callback = callback
  45. def __call__(
  46. self, metadata: MetadataType, error: Optional[Type[BaseException]]
  47. ):
  48. with self._state.lock:
  49. if self._state.exception is None:
  50. if self._state.called:
  51. raise RuntimeError(
  52. "AuthMetadataPluginCallback invoked more than once!"
  53. )
  54. else:
  55. self._state.called = True
  56. else:
  57. raise RuntimeError(
  58. 'AuthMetadataPluginCallback raised exception "{}"!'.format(
  59. self._state.exception
  60. )
  61. )
  62. if error is None:
  63. self._callback(metadata, cygrpc.StatusCode.ok, None)
  64. else:
  65. self._callback(
  66. None, cygrpc.StatusCode.internal, _common.encode(str(error))
  67. )
  68. class _Plugin(object):
  69. _metadata_plugin: grpc.AuthMetadataPlugin
  70. def __init__(self, metadata_plugin: grpc.AuthMetadataPlugin):
  71. self._metadata_plugin = metadata_plugin
  72. self._stored_ctx = None
  73. try:
  74. import contextvars # pylint: disable=wrong-import-position
  75. # The plugin may be invoked on a thread created by Core, which will not
  76. # have the context propagated. This context is stored and installed in
  77. # the thread invoking the plugin.
  78. self._stored_ctx = contextvars.copy_context()
  79. except ImportError:
  80. # Support versions predating contextvars.
  81. pass
  82. def __call__(self, service_url: str, method_name: str, callback: Callable):
  83. context = _AuthMetadataContext(
  84. _common.decode(service_url), _common.decode(method_name)
  85. )
  86. callback_state = _CallbackState()
  87. try:
  88. self._metadata_plugin(
  89. context, _AuthMetadataPluginCallback(callback_state, callback)
  90. )
  91. except Exception as exception: # pylint: disable=broad-except
  92. _LOGGER.exception(
  93. 'AuthMetadataPluginCallback "%s" raised exception!',
  94. self._metadata_plugin,
  95. )
  96. with callback_state.lock:
  97. callback_state.exception = exception
  98. if callback_state.called:
  99. return
  100. callback(
  101. None, cygrpc.StatusCode.internal, _common.encode(str(exception))
  102. )
  103. def metadata_plugin_call_credentials(
  104. metadata_plugin: grpc.AuthMetadataPlugin, name: Optional[str]
  105. ) -> grpc.CallCredentials:
  106. if name is None:
  107. try:
  108. effective_name = metadata_plugin.__name__
  109. except AttributeError:
  110. effective_name = metadata_plugin.__class__.__name__
  111. else:
  112. effective_name = name
  113. return grpc.CallCredentials(
  114. cygrpc.MetadataPluginCallCredentials(
  115. _Plugin(metadata_plugin), _common.encode(effective_name)
  116. )
  117. )