Skip to content
This repository has been archived by the owner on Feb 23, 2024. It is now read-only.

feat: support self-signed jwt #107

Closed
wants to merge 12 commits into from
Prev Previous commit
Next Next commit
fix: move api_core, auth version detection into base transport
  • Loading branch information
busunkim96 committed Feb 11, 2021
commit 2f16d21ddc6366caccb13016f5d17785d90dbccf
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@

import abc
import typing
import packaging.version
import pkg_resources

from google import auth # type: ignore
import google.api_core
from google.api_core import exceptions # type: ignore
from google.api_core import gapic_v1 # type: ignore
from google.api_core import retry as retries # type: ignore
Expand All @@ -37,6 +39,17 @@
except pkg_resources.DistributionNotFound:
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo()

try:
# google.auth.__version__ was added in 1.26.0
_GOOGLE_AUTH_VERSION = auth.__version__
except AttributeError:
try: # try pkg_resources if it is available
_GOOGLE_AUTH_VERSION = pkg_resources.get_distribution("google-auth").version
except pkg_resources.DistributionNotFound: # pragma: NO COVER
_GOOGLE_AUTH_VERSION = None

_API_CORE_VERSION = google.api_core.__version__


class TranslationServiceTransport(abc.ABC):
"""Abstract transport class for TranslationService."""
Expand All @@ -53,7 +66,7 @@ def __init__(
host: str = DEFAULT_HOST,
credentials: credentials.Credentials = None,
credentials_file: typing.Optional[str] = None,
scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES,
scopes: typing.Optional[typing.Sequence[str]] = None,
quota_project_id: typing.Optional[str] = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
**kwargs,
Expand All @@ -70,7 +83,7 @@ def __init__(
credentials_file (Optional[str]): A file with credentials that can
be loaded with :func:`google.auth.load_credentials_from_file`.
This argument is mutually exclusive with credentials.
scope (Optional[Sequence[str]]): A list of scopes.
scopes (Optional[Sequence[str]]): A list of scopes.
quota_project_id (Optional[str]): An optional project to use for billing
and quota.
client_info (google.api_core.gapic_v1.client_info.ClientInfo):
Expand All @@ -84,6 +97,21 @@ def __init__(
host += ":443"
self._host = host

# If a custom API endpoint is set, set scopes to ensure the auth
# library does not used the self-signed JWT flow for service
# accounts
if host.split(":")[0] != self.DEFAULT_HOST and not scopes:
scopes = self.AUTH_SCOPES

# TODO: Remove this if/else once google-auth >= 1.25.0 is required
if _GOOGLE_AUTH_VERSION and (
packaging.version.parse(_GOOGLE_AUTH_VERSION)
>= packaging.version.parse("1.25.0")
):
scopes_kwargs = {"scopes": scopes, "default_scopes": self.AUTH_SCOPES}
else:
scopes_kwargs = {"scopes": scopes or self.AUTH_SCOPES}

# If no credentials are provided, then determine the appropriate
# defaults.
if credentials and credentials_file:
Expand All @@ -93,12 +121,12 @@ def __init__(

if credentials_file is not None:
credentials, _ = auth.load_credentials_from_file(
credentials_file, scopes=scopes, quota_project_id=quota_project_id
credentials_file, **scopes_kwargs, quota_project_id=quota_project_id
)

elif credentials is None:
credentials, _ = auth.default(
scopes=scopes, quota_project_id=quota_project_id
**scopes_kwargs, quota_project_id=quota_project_id
)

# Save the credentials.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,13 @@
import warnings
from typing import Callable, Dict, Optional, Sequence, Tuple

import google.api_core
from google.api_core import grpc_helpers # type: ignore
from google.api_core import operations_v1 # type: ignore
from google.api_core import gapic_v1 # type: ignore
from google import auth # type: ignore
from google.auth import credentials # type: ignore
from google.auth.transport.grpc import SslCredentials # type: ignore
import packaging.version
import packaging
import pkg_resources

import grpc # type: ignore
Expand All @@ -34,19 +33,9 @@
from google.longrunning import operations_pb2 as operations # type: ignore

from .base import TranslationServiceTransport, DEFAULT_CLIENT_INFO
from .base import _GOOGLE_AUTH_VERSION, _API_CORE_VERSION


try:
# google.auth.__version__ was added in 1.26.0
_GOOGLE_AUTH_VERSION = auth.__version__
except AttributeError:
try: # try pkg_resources if it is available
_GOOGLE_AUTH_VERSION = pkg_resources.get_distribution("google-auth").version
except pkg_resources.DistributionNotFound: # pragma: NO COVER
_GOOGLE_AUTH_VERSION = None

_API_CORE_VERSION = google.api_core.__version__


class TranslationServiceGrpcTransport(TranslationServiceTransport):
"""gRPC backend transport for TranslationService.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-

# Copyright 2020 Google LLC
#
Expand Down Expand Up @@ -33,10 +33,8 @@
from google.longrunning import operations_pb2 as operations # type: ignore

from .base import TranslationServiceTransport, DEFAULT_CLIENT_INFO
from .base import _GOOGLE_AUTH_VERSION, _API_CORE_VERSION
from .grpc import TranslationServiceGrpcTransport
from .grpc import _API_CORE_VERSION
from .grpc import _GOOGLE_AUTH_VERSION


class TranslationServiceGrpcAsyncIOTransport(TranslationServiceTransport):
"""gRPC AsyncIO backend transport for TranslationService.
Expand Down
41 changes: 36 additions & 5 deletions tests/unit/gapic/translate_v3/test_translation_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@
)
from google.cloud.translate_v3.services.translation_service import pagers
from google.cloud.translate_v3.services.translation_service import transports
from google.cloud.translate_v3.services.translation_service.transports.grpc import (
from google.cloud.translate_v3.services.translation_service.transports.base import (
_GOOGLE_AUTH_VERSION,
)
from google.cloud.translate_v3.services.translation_service.transports.grpc import (
from google.cloud.translate_v3.services.translation_service.transports.base import (
_API_CORE_VERSION,
)
from google.cloud.translate_v3.types import translation_service
Expand Down Expand Up @@ -2396,8 +2396,32 @@ def test_translation_service_base_transport():
with pytest.raises(NotImplementedError):
transport.operations_client


@requires_google_auth_gte_1_25_0
def test_translation_service_base_transport_with_credentials_file():
# Instantiate the base transport with a credentials file
with mock.patch.object(
auth, "load_credentials_from_file"
) as load_creds, mock.patch(
"google.cloud.translate_v3.services.translation_service.transports.TranslationServiceTransport._prep_wrapped_messages"
) as Transport:
Transport.return_value = None
load_creds.return_value = (credentials.AnonymousCredentials(), None)
transport = transports.TranslationServiceTransport(
credentials_file="credentials.json", quota_project_id="octopus",
)
load_creds.assert_called_once_with(
"credentials.json",
scopes=None,
default_scopes=(
"https://www.googleapis.com/auth/cloud-platform",
"https://www.googleapis.com/auth/cloud-translation",
),
quota_project_id="octopus",
)


@requires_google_auth_lt_1_25_0
def test_translation_service_base_transport_with_credentials_file_old_google_auth():
# Instantiate the base transport with a credentials file
with mock.patch.object(
auth, "load_credentials_from_file"
Expand Down Expand Up @@ -2461,13 +2485,20 @@ def test_translation_service_auth_adc_old_google_auth():
)


@pytest.mark.parametrize(
"transport_class",
[
transports.TranslationServiceGrpcTransport,
transports.TranslationServiceGrpcAsyncIOTransport,
],
)
@requires_google_auth_gte_1_25_0
def test_translation_service_transport_auth_adc():
def test_translation_service_transport_auth_adc(transport_class):
# If credentials and host are not provided, the transport class should use
# ADC credentials.
with mock.patch.object(auth, "default", autospec=True) as adc:
adc.return_value = (credentials.AnonymousCredentials(), None)
transports.TranslationServiceGrpcTransport(
transport_class(
quota_project_id="octopus", scopes=["1", "2"]
)
adc.assert_called_once_with(
Expand Down