_auth.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. # Copyright 2016 gRPC authors.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """GRPCAuthMetadataPlugins for standard authentication."""
  15. import inspect
  16. from typing import Any, Optional
  17. import grpc
  18. def _sign_request(
  19. callback: grpc.AuthMetadataPluginCallback,
  20. token: Optional[str],
  21. error: Optional[Exception],
  22. ):
  23. metadata = (("authorization", "Bearer {}".format(token)),)
  24. callback(metadata, error)
  25. class GoogleCallCredentials(grpc.AuthMetadataPlugin):
  26. """Metadata wrapper for GoogleCredentials from the oauth2client library."""
  27. _is_jwt: bool
  28. _credentials: Any
  29. # TODO(xuanwn): Give credentials an actual type.
  30. def __init__(self, credentials: Any):
  31. self._credentials = credentials
  32. # Hack to determine if these are JWT creds and we need to pass
  33. # additional_claims when getting a token
  34. self._is_jwt = (
  35. "additional_claims"
  36. in inspect.getfullargspec(credentials.get_access_token).args
  37. )
  38. def __call__(
  39. self,
  40. context: grpc.AuthMetadataContext,
  41. callback: grpc.AuthMetadataPluginCallback,
  42. ):
  43. try:
  44. if self._is_jwt:
  45. access_token = self._credentials.get_access_token(
  46. additional_claims={
  47. "aud": context.service_url # pytype: disable=attribute-error
  48. }
  49. ).access_token
  50. else:
  51. access_token = self._credentials.get_access_token().access_token
  52. except Exception as exception: # pylint: disable=broad-except
  53. _sign_request(callback, None, exception)
  54. else:
  55. _sign_request(callback, access_token, None)
  56. class AccessTokenAuthMetadataPlugin(grpc.AuthMetadataPlugin):
  57. """Metadata wrapper for raw access token credentials."""
  58. _access_token: str
  59. def __init__(self, access_token: str):
  60. self._access_token = access_token
  61. def __call__(
  62. self,
  63. context: grpc.AuthMetadataContext,
  64. callback: grpc.AuthMetadataPluginCallback,
  65. ):
  66. _sign_request(callback, self._access_token, None)