Skip to content

Commit

Permalink
fix: fix mTLS logic (#374)
Browse files Browse the repository at this point in the history
Previous PR triggers mTLS if client_options.api_endpoint is different 
from the default one, in this PR, we change the logic, mTLS is triggered 
if client_options.api_endpoint is provided
  • Loading branch information
arithmetic1728 authored Apr 10, 2020
1 parent a354629 commit e3c079b
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore
DEFAULT_ENDPOINT
)
DEFAULT_OPTIONS = ClientOptions.ClientOptions(api_endpoint=DEFAULT_ENDPOINT)

@classmethod
def from_service_account_file(cls, filename: str, *args, **kwargs):
Expand Down Expand Up @@ -126,7 +125,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
def __init__(self, *,
credentials: credentials.Credentials = None,
transport: Union[str, {{ service.name }}Transport] = None,
client_options: ClientOptions = DEFAULT_OPTIONS,
client_options: ClientOptions = None,
) -> None:
"""Instantiate the {{ (service.client_name|snake_case).replace('_', ' ') }}.

Expand All @@ -143,12 +142,10 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
(1) The ``api_endpoint`` property can be used to override the
default endpoint provided by the client.
(2) If ``transport`` argument is None, ``client_options`` can be
used to create a mutual TLS transport. If ``api_endpoint`` is
provided and different from the default endpoint, or the
``client_cert_source`` property is provided, mutual TLS
transport will be created if client SSL credentials are found.
Client SSL credentials are obtained from ``client_cert_source``
or application default SSL credentials.
used to create a mutual TLS transport. If ``client_cert_source``
is provided, mutual TLS transport will be created with the given
``api_endpoint`` or the default mTLS endpoint, and the client
SSL credentials obtained from ``client_cert_source``.

Raises:
google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport
Expand All @@ -157,10 +154,6 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
if isinstance(client_options, dict):
client_options = ClientOptions.from_dict(client_options)

# Set default api endpoint if not set.
if client_options.api_endpoint is None:
client_options.api_endpoint = self.DEFAULT_ENDPOINT

# Save or instantiate the transport.
# Ordinarily, we provide the transport, but allowing a custom transport
# instance provides an extensibility point for unusual situations.
Expand All @@ -170,24 +163,37 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
raise ValueError('When providing a transport instance, '
'provide its credentials directly.')
self._transport = transport
elif transport is not None or (
client_options.api_endpoint == self.DEFAULT_ENDPOINT
elif client_options is None or (
client_options.api_endpoint == None
and client_options.client_cert_source is None
):
# Don't trigger mTLS.
# Don't trigger mTLS if we get an empty ClientOptions.
Transport = type(self).get_transport_class(transport)
self._transport = Transport(
credentials=credentials, host=client_options.api_endpoint
credentials=credentials, host=self.DEFAULT_ENDPOINT
)
else:
# Trigger mTLS. If the user overrides endpoint, use it as the mTLS
# endpoint, otherwise use the default mTLS endpoint.
option_endpoint = client_options.api_endpoint
api_mtls_endpoint = self.DEFAULT_MTLS_ENDPOINT if option_endpoint == self.DEFAULT_ENDPOINT else option_endpoint
# We have a non-empty ClientOptions. If client_cert_source is
# provided, trigger mTLS with user provided endpoint or the default
# mTLS endpoint.
if client_options.client_cert_source:
api_mtls_endpoint = (
client_options.api_endpoint
if client_options.api_endpoint
else self.DEFAULT_MTLS_ENDPOINT
)
else:
api_mtls_endpoint = None

api_endpoint = (
client_options.api_endpoint
if client_options.api_endpoint
else self.DEFAULT_ENDPOINT
)

self._transport = {{ service.name }}GrpcTransport(
credentials=credentials,
host=client_options.api_endpoint,
host=api_endpoint,
api_mtls_endpoint=api_mtls_endpoint,
client_cert_source=client_options.client_cert_source,
)
Expand Down
26 changes: 19 additions & 7 deletions gapic/templates/tests/unit/%name_%version/%sub/test_%service.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,6 @@ def test_{{ service.client_name|snake_case }}_from_service_account_file():


def test_{{ service.client_name|snake_case }}_client_options():
# Check the default options have their expected values.
assert {{ service.client_name }}.DEFAULT_OPTIONS.api_endpoint == {% if service.host %}'{{ service.host }}'{% else %}None{% endif %}
assert {{ service.client_name }}.DEFAULT_OPTIONS.api_endpoint == {{ service.client_name }}.DEFAULT_ENDPOINT

# Check that if channel is provided we won't create a new one.
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.{{ service.client_name }}.get_transport_class') as gtc:
transport = transports.{{ service.name }}GrpcTransport(
Expand All @@ -86,13 +82,14 @@ def test_{{ service.client_name|snake_case }}_client_options():
host=client.DEFAULT_ENDPOINT,
)

# Check mTLS is triggered with api endpoint override.
# Check mTLS is not triggered if api_endpoint is provided but
# client_cert_source is None.
options = client_options.ClientOptions(api_endpoint="squid.clam.whelk")
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport:
grpc_transport.return_value = None
client = {{ service.client_name }}(client_options=options)
grpc_transport.assert_called_once_with(
api_mtls_endpoint="squid.clam.whelk",
api_mtls_endpoint=None,
client_cert_source=None,
credentials=None,
host="squid.clam.whelk",
Expand All @@ -112,14 +109,29 @@ def test_{{ service.client_name|snake_case }}_client_options():
host=client.DEFAULT_ENDPOINT,
)

# Check mTLS is triggered if api_endpoint and client_cert_source are provided.
options = client_options.ClientOptions(
api_endpoint="squid.clam.whelk",
client_cert_source=client_cert_source_callback
)
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport:
grpc_transport.return_value = None
client = {{ service.client_name }}(client_options=options)
grpc_transport.assert_called_once_with(
api_mtls_endpoint="squid.clam.whelk",
client_cert_source=client_cert_source_callback,
credentials=None,
host="squid.clam.whelk",
)

def test_{{ service.client_name|snake_case }}_client_options_from_dict():
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport:
grpc_transport.return_value = None
client = {{ service.client_name }}(
client_options={'api_endpoint': 'squid.clam.whelk'}
)
grpc_transport.assert_called_once_with(
api_mtls_endpoint="squid.clam.whelk",
api_mtls_endpoint=None,
client_cert_source=None,
credentials=None,
host="squid.clam.whelk",
Expand Down

0 comments on commit e3c079b

Please sign in to comment.