Skip to content

Commit

Permalink
feat: enable self signed jwt for grpc (#920)
Browse files Browse the repository at this point in the history
* feat: enable self signe jwt for grpc

* update test

* update golden files
  • Loading branch information
arithmetic1728 authored Jun 16, 2021
1 parent 1e39f41 commit da119c7
Show file tree
Hide file tree
Showing 28 changed files with 149 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ from google.api_core import retry as retries # type: ignore
from google.api_core import operations_v1 # type: ignore
{% endif %}
from google.auth import credentials as ga_credentials # type: ignore
from google.oauth2 import service_account # type: ignore

{% filter sort_lines %}
{% for method in service.methods.values() %}
Expand Down Expand Up @@ -75,6 +76,7 @@ class {{ service.name }}Transport(abc.ABC):
scopes: Optional[Sequence[str]] = None,
quota_project_id: Optional[str] = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
always_use_jwt_access: Optional[bool] = False,
**kwargs,
) -> None:
"""Instantiate the transport.
Expand All @@ -98,6 +100,8 @@ class {{ service.name }}Transport(abc.ABC):
API requests. If ``None``, then default info will be used.
Generally, you only need to set this if you're developing
your own client library.
always_use_jwt_access (Optional[bool]): Whether self signed JWT should
be used for service account credentials.
"""
# Save the hostname. Default to port 443 (HTTPS) if none is specified.
if ':' not in host:
Expand All @@ -124,6 +128,10 @@ class {{ service.name }}Transport(abc.ABC):
elif credentials is None:
credentials, _ = google.auth.default(**scopes_kwargs, quota_project_id=quota_project_id)

# If the credentials is service account credentials, then always try to use self signed JWT.
if always_use_jwt_access and isinstance(credentials, service_account.Credentials) and hasattr(service_account.Credentials, "with_always_use_jwt_access"):
credentials = credentials.with_always_use_jwt_access(True)

# Save the credentials.
self._credentials = credentials

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
scopes=scopes,
quota_project_id=quota_project_id,
client_info=client_info,
always_use_jwt_access=True,
)

if not self._grpc_channel:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
scopes=scopes,
quota_project_id=quota_project_id,
client_info=client_info,
always_use_jwt_access=True,
)

if not self._grpc_channel:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,19 @@ def test_{{ service.client_name|snake_case }}_from_service_account_info(client_c
{% endif %}


@pytest.mark.parametrize("client_class", [
{{ service.client_name }},
{% if 'grpc' in opts.transport %}
{{ service.async_client_name }},
{% endif %}
])
def test_{{ service.client_name|snake_case }}_service_account_always_use_jwt(client_class):
with mock.patch.object(service_account.Credentials, 'with_always_use_jwt_access', create=True) as use_jwt:
creds = service_account.Credentials(None, None, None)
client = client_class(credentials=creds)
use_jwt.assert_called_with(True)


@pytest.mark.parametrize("client_class", [
{{ service.client_name }},
{% if 'grpc' in opts.transport %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from google.api_core import retry as retries # type: ignore
from google.api_core import operations_v1 # type: ignore
from google.auth import credentials as ga_credentials # type: ignore
from google.oauth2 import service_account # type: ignore

from google.cloud.asset_v1.types import asset_service
from google.longrunning import operations_pb2 # type: ignore
Expand Down Expand Up @@ -65,6 +66,7 @@ def __init__(
scopes: Optional[Sequence[str]] = None,
quota_project_id: Optional[str] = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
always_use_jwt_access: Optional[bool] = False,
**kwargs,
) -> None:
"""Instantiate the transport.
Expand All @@ -88,6 +90,8 @@ def __init__(
API requests. If ``None``, then default info will be used.
Generally, you only need to set this if you're developing
your own client library.
always_use_jwt_access (Optional[bool]): Whether self signed JWT should
be used for service account credentials.
"""
# Save the hostname. Default to port 443 (HTTPS) if none is specified.
if ':' not in host:
Expand All @@ -114,6 +118,10 @@ def __init__(
elif credentials is None:
credentials, _ = google.auth.default(**scopes_kwargs, quota_project_id=quota_project_id)

# If the credentials is service account credentials, then always try to use self signed JWT.
if always_use_jwt_access and isinstance(credentials, service_account.Credentials) and hasattr(service_account.Credentials, "with_always_use_jwt_access"):
credentials = credentials.with_always_use_jwt_access(True)

# Save the credentials.
self._credentials = credentials

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def __init__(self, *,
scopes=scopes,
quota_project_id=quota_project_id,
client_info=client_info,
always_use_jwt_access=True,
)

if not self._grpc_channel:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def __init__(self, *,
scopes=scopes,
quota_project_id=quota_project_id,
client_info=client_info,
always_use_jwt_access=True,
)

if not self._grpc_channel:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,17 @@ def test_asset_service_client_from_service_account_info(client_class):
assert client.transport._host == 'cloudasset.googleapis.com:443'


@pytest.mark.parametrize("client_class", [
AssetServiceClient,
AssetServiceAsyncClient,
])
def test_asset_service_client_service_account_always_use_jwt(client_class):
with mock.patch.object(service_account.Credentials, 'with_always_use_jwt_access', create=True) as use_jwt:
creds = service_account.Credentials(None, None, None)
client = client_class(credentials=creds)
use_jwt.assert_called_with(True)


@pytest.mark.parametrize("client_class", [
AssetServiceClient,
AssetServiceAsyncClient,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from google.api_core import gapic_v1 # type: ignore
from google.api_core import retry as retries # type: ignore
from google.auth import credentials as ga_credentials # type: ignore
from google.oauth2 import service_account # type: ignore

from google.iam.credentials_v1.types import common

Expand Down Expand Up @@ -62,6 +63,7 @@ def __init__(
scopes: Optional[Sequence[str]] = None,
quota_project_id: Optional[str] = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
always_use_jwt_access: Optional[bool] = False,
**kwargs,
) -> None:
"""Instantiate the transport.
Expand All @@ -85,6 +87,8 @@ def __init__(
API requests. If ``None``, then default info will be used.
Generally, you only need to set this if you're developing
your own client library.
always_use_jwt_access (Optional[bool]): Whether self signed JWT should
be used for service account credentials.
"""
# Save the hostname. Default to port 443 (HTTPS) if none is specified.
if ':' not in host:
Expand All @@ -111,6 +115,10 @@ def __init__(
elif credentials is None:
credentials, _ = google.auth.default(**scopes_kwargs, quota_project_id=quota_project_id)

# If the credentials is service account credentials, then always try to use self signed JWT.
if always_use_jwt_access and isinstance(credentials, service_account.Credentials) and hasattr(service_account.Credentials, "with_always_use_jwt_access"):
credentials = credentials.with_always_use_jwt_access(True)

# Save the credentials.
self._credentials = credentials

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ def __init__(self, *,
scopes=scopes,
quota_project_id=quota_project_id,
client_info=client_info,
always_use_jwt_access=True,
)

if not self._grpc_channel:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ def __init__(self, *,
scopes=scopes,
quota_project_id=quota_project_id,
client_info=client_info,
always_use_jwt_access=True,
)

if not self._grpc_channel:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,17 @@ def test_iam_credentials_client_from_service_account_info(client_class):
assert client.transport._host == 'iamcredentials.googleapis.com:443'


@pytest.mark.parametrize("client_class", [
IAMCredentialsClient,
IAMCredentialsAsyncClient,
])
def test_iam_credentials_client_service_account_always_use_jwt(client_class):
with mock.patch.object(service_account.Credentials, 'with_always_use_jwt_access', create=True) as use_jwt:
creds = service_account.Credentials(None, None, None)
client = client_class(credentials=creds)
use_jwt.assert_called_with(True)


@pytest.mark.parametrize("client_class", [
IAMCredentialsClient,
IAMCredentialsAsyncClient,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from google.api_core import gapic_v1 # type: ignore
from google.api_core import retry as retries # type: ignore
from google.auth import credentials as ga_credentials # type: ignore
from google.oauth2 import service_account # type: ignore

from google.cloud.logging_v2.types import logging_config
from google.protobuf import empty_pb2 # type: ignore
Expand Down Expand Up @@ -66,6 +67,7 @@ def __init__(
scopes: Optional[Sequence[str]] = None,
quota_project_id: Optional[str] = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
always_use_jwt_access: Optional[bool] = False,
**kwargs,
) -> None:
"""Instantiate the transport.
Expand All @@ -89,6 +91,8 @@ def __init__(
API requests. If ``None``, then default info will be used.
Generally, you only need to set this if you're developing
your own client library.
always_use_jwt_access (Optional[bool]): Whether self signed JWT should
be used for service account credentials.
"""
# Save the hostname. Default to port 443 (HTTPS) if none is specified.
if ':' not in host:
Expand All @@ -115,6 +119,10 @@ def __init__(
elif credentials is None:
credentials, _ = google.auth.default(**scopes_kwargs, quota_project_id=quota_project_id)

# If the credentials is service account credentials, then always try to use self signed JWT.
if always_use_jwt_access and isinstance(credentials, service_account.Credentials) and hasattr(service_account.Credentials, "with_always_use_jwt_access"):
credentials = credentials.with_always_use_jwt_access(True)

# Save the credentials.
self._credentials = credentials

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def __init__(self, *,
scopes=scopes,
quota_project_id=quota_project_id,
client_info=client_info,
always_use_jwt_access=True,
)

if not self._grpc_channel:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def __init__(self, *,
scopes=scopes,
quota_project_id=quota_project_id,
client_info=client_info,
always_use_jwt_access=True,
)

if not self._grpc_channel:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from google.api_core import gapic_v1 # type: ignore
from google.api_core import retry as retries # type: ignore
from google.auth import credentials as ga_credentials # type: ignore
from google.oauth2 import service_account # type: ignore

from google.cloud.logging_v2.types import logging
from google.protobuf import empty_pb2 # type: ignore
Expand Down Expand Up @@ -67,6 +68,7 @@ def __init__(
scopes: Optional[Sequence[str]] = None,
quota_project_id: Optional[str] = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
always_use_jwt_access: Optional[bool] = False,
**kwargs,
) -> None:
"""Instantiate the transport.
Expand All @@ -90,6 +92,8 @@ def __init__(
API requests. If ``None``, then default info will be used.
Generally, you only need to set this if you're developing
your own client library.
always_use_jwt_access (Optional[bool]): Whether self signed JWT should
be used for service account credentials.
"""
# Save the hostname. Default to port 443 (HTTPS) if none is specified.
if ':' not in host:
Expand All @@ -116,6 +120,10 @@ def __init__(
elif credentials is None:
credentials, _ = google.auth.default(**scopes_kwargs, quota_project_id=quota_project_id)

# If the credentials is service account credentials, then always try to use self signed JWT.
if always_use_jwt_access and isinstance(credentials, service_account.Credentials) and hasattr(service_account.Credentials, "with_always_use_jwt_access"):
credentials = credentials.with_always_use_jwt_access(True)

# Save the credentials.
self._credentials = credentials

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def __init__(self, *,
scopes=scopes,
quota_project_id=quota_project_id,
client_info=client_info,
always_use_jwt_access=True,
)

if not self._grpc_channel:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def __init__(self, *,
scopes=scopes,
quota_project_id=quota_project_id,
client_info=client_info,
always_use_jwt_access=True,
)

if not self._grpc_channel:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from google.api_core import gapic_v1 # type: ignore
from google.api_core import retry as retries # type: ignore
from google.auth import credentials as ga_credentials # type: ignore
from google.oauth2 import service_account # type: ignore

from google.cloud.logging_v2.types import logging_metrics
from google.protobuf import empty_pb2 # type: ignore
Expand Down Expand Up @@ -67,6 +68,7 @@ def __init__(
scopes: Optional[Sequence[str]] = None,
quota_project_id: Optional[str] = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
always_use_jwt_access: Optional[bool] = False,
**kwargs,
) -> None:
"""Instantiate the transport.
Expand All @@ -90,6 +92,8 @@ def __init__(
API requests. If ``None``, then default info will be used.
Generally, you only need to set this if you're developing
your own client library.
always_use_jwt_access (Optional[bool]): Whether self signed JWT should
be used for service account credentials.
"""
# Save the hostname. Default to port 443 (HTTPS) if none is specified.
if ':' not in host:
Expand All @@ -116,6 +120,10 @@ def __init__(
elif credentials is None:
credentials, _ = google.auth.default(**scopes_kwargs, quota_project_id=quota_project_id)

# If the credentials is service account credentials, then always try to use self signed JWT.
if always_use_jwt_access and isinstance(credentials, service_account.Credentials) and hasattr(service_account.Credentials, "with_always_use_jwt_access"):
credentials = credentials.with_always_use_jwt_access(True)

# Save the credentials.
self._credentials = credentials

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def __init__(self, *,
scopes=scopes,
quota_project_id=quota_project_id,
client_info=client_info,
always_use_jwt_access=True,
)

if not self._grpc_channel:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def __init__(self, *,
scopes=scopes,
quota_project_id=quota_project_id,
client_info=client_info,
always_use_jwt_access=True,
)

if not self._grpc_channel:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,17 @@ def test_config_service_v2_client_from_service_account_info(client_class):
assert client.transport._host == 'logging.googleapis.com:443'


@pytest.mark.parametrize("client_class", [
ConfigServiceV2Client,
ConfigServiceV2AsyncClient,
])
def test_config_service_v2_client_service_account_always_use_jwt(client_class):
with mock.patch.object(service_account.Credentials, 'with_always_use_jwt_access', create=True) as use_jwt:
creds = service_account.Credentials(None, None, None)
client = client_class(credentials=creds)
use_jwt.assert_called_with(True)


@pytest.mark.parametrize("client_class", [
ConfigServiceV2Client,
ConfigServiceV2AsyncClient,
Expand Down
Loading

0 comments on commit da119c7

Please sign in to comment.