Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add mTLS support to generator #359

Merged
merged 11 commits into from
Apr 8, 2020
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

{% block content %}
from collections import OrderedDict
import re
from typing import Callable, Dict, {% if service.any_server_streaming %}Iterable, {% endif %}{% if service.any_client_streaming %}Iterator, {% endif %}Sequence, Tuple, Type, Union
import pkg_resources

Expand All @@ -23,37 +24,6 @@ from .transports.base import {{ service.name }}Transport
from .transports.grpc import {{ service.name }}GrpcTransport


def _get_default_mtls_endpoint(api_endpoint):
"""Convert api endpoint to mTLS endpoint.
Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to
"*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively.
Args:
api_endpoint (Optional[str]): the api endpoint to convert.
Returns:
str: converted mTLS api endpoint.
"""
if (
api_endpoint is None
or api_endpoint.find("mtls.sandbox.googleapis.com") != -1
or api_endpoint.find("mtls.googleapis.com") != -1
or api_endpoint.find(".googleapis.com") == -1
):
# If the endpoint is already mTLS or the endpoint is not a googleapi,
# there is no need to convert.
return api_endpoint

if api_endpoint.find(".sandbox.googleapis.com") != -1:
return api_endpoint.replace(
".sandbox.googleapis.com", ".mtls.sandbox.googleapis.com"
)

return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com")


_DEFAULT_ENDPOINT = {% if service.host %}'{{ service.host }}'{% else %}None{% endif %}
_DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint(_DEFAULT_ENDPOINT)


class {{ service.client_name }}Meta(type):
"""Metaclass for the {{ service.name }} client.

Expand Down Expand Up @@ -88,7 +58,40 @@ class {{ service.client_name }}Meta(type):
class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
"""{{ service.meta.doc|rst(width=72, indent=4) }}"""

DEFAULT_OPTIONS = ClientOptions.ClientOptions(api_endpoint=_DEFAULT_ENDPOINT)
@staticmethod
def _get_default_mtls_endpoint(api_endpoint):
"""Convert api endpoint to mTLS endpoint.
Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to
"*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively.
Args:
api_endpoint (Optional[str]): the api endpoint to convert.
Returns:
str: converted mTLS api endpoint.
"""
if not api_endpoint:
return api_endpoint

mtls_endpoint_re = re.compile(
r"(?P<name>[^.]+)(?P<mtls>\.mtls)?(?P<sandbox>\.sandbox)?(?P<googledomain>\.googleapis\.com)?"
)

m = mtls_endpoint_re.match(api_endpoint)
name, mtls, sandbox, googledomain = m.groups()
if mtls or not googledomain:
return api_endpoint

if sandbox:
return api_endpoint.replace(
"sandbox.googleapis.com", "mtls.sandbox.googleapis.com"
)

return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com")

DEFAULT_ENDPOINT = {% if service.host %}'{{ service.host }}'{% else %}None{% endif %}
DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore
DEFAULT_ENDPOINT
)
DEFAULT_OPTIONS = ClientOptions.ClientOptions(api_endpoint=DEFAULT_ENDPOINT)

@classmethod
def from_service_account_file(cls, filename: str, *args, **kwargs):
Expand Down Expand Up @@ -156,7 +159,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):

# Set default api endpoint if not set.
if client_options.api_endpoint is None:
client_options.api_endpoint = _DEFAULT_ENDPOINT
client_options.api_endpoint = {{ service.client_name }}.DEFAULT_ENDPOINT
arithmetic1728 marked this conversation as resolved.
Show resolved Hide resolved

# Save or instantiate the transport.
# Ordinarily, we provide the transport, but allowing a custom transport
Expand All @@ -168,7 +171,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
'provide its credentials directly.')
self._transport = transport
elif transport is not None or (
client_options.api_endpoint == _DEFAULT_ENDPOINT
client_options.api_endpoint == {{ service.client_name }}.DEFAULT_ENDPOINT
arithmetic1728 marked this conversation as resolved.
Show resolved Hide resolved
and client_options.client_cert_source is None
):
# Don't trigger mTLS.
Expand All @@ -179,11 +182,10 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
else:
# Trigger mTLS. If the user overrides endpoint, use it as the mTLS
# endpoint, otherwise use the default mTLS endpoint.
api_mtls_endpoint = (
(client_options.api_endpoint != _DEFAULT_ENDPOINT)
and client_options.api_endpoint
or _DEFAULT_MTLS_ENDPOINT
)
if client_options.api_endpoint != {{ service.client_name }}.DEFAULT_ENDPOINT:
arithmetic1728 marked this conversation as resolved.
Show resolved Hide resolved
api_mtls_endpoint = client_options.api_endpoint
else:
api_mtls_endpoint = {{ service.client_name }}.DEFAULT_MTLS_ENDPOINT
arithmetic1728 marked this conversation as resolved.
Show resolved Hide resolved
arithmetic1728 marked this conversation as resolved.
Show resolved Hide resolved

self._transport = {{ service.name }}GrpcTransport(
credentials=credentials,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,10 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
# If a channel was explicitly provided, set it.
self._grpc_channel = channel
elif api_mtls_endpoint:
host = (
(":" in api_mtls_endpoint)
and api_mtls_endpoint
or (api_mtls_endpoint + ":443")
)
if ":" in api_mtls_endpoint:
host = api_mtls_endpoint
else:
host = api_mtls_endpoint + ":443"
arithmetic1728 marked this conversation as resolved.
Show resolved Hide resolved

# Create SSL credentials with client_cert_source or application
# default SSL credentials.
Expand Down
32 changes: 17 additions & 15 deletions gapic/templates/tests/unit/%name_%version/%sub/test_%service.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@ from google.auth import credentials
from google.oauth2 import service_account
from {{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }} import {{ service.client_name }}
from {{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }} import transports
from {{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.client import _get_default_mtls_endpoint
from {{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.client import _DEFAULT_ENDPOINT
from {{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.client import _DEFAULT_MTLS_ENDPOINT
from google.api_core import client_options
from google.api_core import grpc_helpers
{% if service.has_lro -%}
Expand Down Expand Up @@ -45,12 +42,12 @@ def test__get_default_mtls_endpoint():
sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com"
non_googleapi = "api.example.com"

assert _get_default_mtls_endpoint(None) == None
assert _get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint
assert _get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint
assert _get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint
assert _get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint
assert _get_default_mtls_endpoint(non_googleapi) == non_googleapi
assert {{ service.client_name }}._get_default_mtls_endpoint(None) == None
assert {{ service.client_name }}._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint
assert {{ service.client_name }}._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint
assert {{ service.client_name }}._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint
assert {{ service.client_name }}._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint
assert {{ service.client_name }}._get_default_mtls_endpoint(non_googleapi) == non_googleapi


def test_{{ service.client_name|snake_case }}_from_service_account_file():
Expand All @@ -69,7 +66,7 @@ def test_{{ service.client_name|snake_case }}_from_service_account_file():
def test_{{ service.client_name|snake_case }}_client_options():
# Check the default options have their expected values.
assert {{ service.client_name }}.DEFAULT_OPTIONS.api_endpoint == {% if service.host %}'{{ service.host }}'{% else %}None{% endif %}
assert {{ service.client_name }}.DEFAULT_OPTIONS.api_endpoint == _DEFAULT_ENDPOINT
assert {{ service.client_name }}.DEFAULT_OPTIONS.api_endpoint == {{ service.client_name }}.DEFAULT_ENDPOINT

# Check that if channel is provided we won't create a new one.
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.{{ service.client_name }}.get_transport_class') as gtc:
Expand All @@ -86,7 +83,7 @@ def test_{{ service.client_name|snake_case }}_client_options():
client = {{ service.client_name }}(client_options=options)
transport.assert_called_once_with(
credentials=None,
host=_DEFAULT_ENDPOINT,
host={{ service.client_name }}.DEFAULT_ENDPOINT,
arithmetic1728 marked this conversation as resolved.
Show resolved Hide resolved
)

# Check mTLS is triggered with api endpoint override.
Expand All @@ -109,10 +106,10 @@ def test_{{ service.client_name|snake_case }}_client_options():
grpc_transport.return_value = None
client = {{ service.client_name }}(client_options=options)
grpc_transport.assert_called_once_with(
api_mtls_endpoint=_DEFAULT_MTLS_ENDPOINT,
api_mtls_endpoint={{ service.client_name }}.DEFAULT_MTLS_ENDPOINT,
arithmetic1728 marked this conversation as resolved.
Show resolved Hide resolved
client_cert_source=client_cert_source_callback,
credentials=None,
host=_DEFAULT_ENDPOINT,
host={{ service.client_name }}.DEFAULT_ENDPOINT,
arithmetic1728 marked this conversation as resolved.
Show resolved Hide resolved
)

def test_{{ service.client_name|snake_case }}_client_options_from_dict():
Expand Down Expand Up @@ -624,8 +621,13 @@ def test_{{ service.name|snake_case }}_grpc_transport_channel_mtls_with_client_c
assert transport.grpc_channel == mock_grpc_channel


@pytest.mark.parametrize(
"api_mtls_endpoint", ["mtls.squid.clam.whelk", "mtls.squid.clam.whelk:443"]
)
arithmetic1728 marked this conversation as resolved.
Show resolved Hide resolved
@mock.patch("google.api_core.grpc_helpers.create_channel", autospec=True)
def test_{{ service.name|snake_case }}_grpc_transport_channel_mtls_with_adc(grpc_create_channel):
def test_{{ service.name|snake_case }}_grpc_transport_channel_mtls_with_adc(
grpc_create_channel, api_mtls_endpoint
):
# Check that if channel and client_cert_source are None, but api_mtls_endpoint
# is provided, then a mTLS channel will be created with SSL ADC.
mock_grpc_channel = mock.Mock()
Expand All @@ -642,7 +644,7 @@ def test_{{ service.name|snake_case }}_grpc_transport_channel_mtls_with_adc(grpc
transport = transports.{{ service.name }}GrpcTransport(
host="squid.clam.whelk",
credentials=mock_cred,
api_mtls_endpoint="mtls.squid.clam.whelk",
api_mtls_endpoint=api_mtls_endpoint,
client_cert_source=None,
)
grpc_create_channel.assert_called_once_with(
Expand Down