Skip to content

Commit 41fa725

Browse files
feat: add GOOGLE_API_USE_MTLS support (#420)
Co-authored-by: Dov Shlachter <dovs@google.com>
1 parent 4957090 commit 41fa725

File tree

6 files changed

+272
-125
lines changed

6 files changed

+272
-125
lines changed

gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/client.py.j2

Lines changed: 40 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
{% block content %}
44
from collections import OrderedDict
5+
import os
56
import re
67
from typing import Callable, Dict, {% if service.any_server_streaming %}Iterable, {% endif %}{% if service.any_client_streaming %}Iterator, {% endif %}Sequence, Tuple, Type, Union
78
import pkg_resources
@@ -11,6 +12,8 @@ from google.api_core import exceptions # type: ignore
1112
from google.api_core import gapic_v1 # type: ignore
1213
from google.api_core import retry as retries # type: ignore
1314
from google.auth import credentials # type: ignore
15+
from google.auth.transport import mtls # type: ignore
16+
from google.auth.exceptions import MutualTLSChannelError # type: ignore
1417
from google.oauth2 import service_account # type: ignore
1518

1619
{% filter sort_lines -%}
@@ -144,21 +147,47 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
144147
transport (Union[str, ~.{{ service.name }}Transport]): The
145148
transport to use. If set to None, a transport is chosen
146149
automatically.
147-
client_options (ClientOptions): Custom options for the client.
150+
client_options (ClientOptions): Custom options for the client. It
151+
won't take effect unless ``transport`` is None.
148152
(1) The ``api_endpoint`` property can be used to override the
149-
default endpoint provided by the client.
150-
(2) If ``transport`` argument is None, ``client_options`` can be
151-
used to create a mutual TLS transport. If ``client_cert_source``
152-
is provided, mutual TLS transport will be created with the given
153-
``api_endpoint`` or the default mTLS endpoint, and the client
154-
SSL credentials obtained from ``client_cert_source``.
153+
default endpoint provided by the client. GOOGLE_API_USE_MTLS
154+
environment variable can also be used to override the endpoint:
155+
"Always" (always use the default mTLS endpoint), "Never" (always
156+
use the default regular endpoint, this is the default value for
157+
the environment variable) and "Auto" (auto switch to the default
158+
mTLS endpoint if client SSL credentials is present). However,
159+
the ``api_endpoint`` property takes precedence if provided.
160+
(2) The ``client_cert_source`` property is used to provide client
161+
SSL credentials for mutual TLS transport. If not provided, the
162+
default SSL credentials will be used if present.
155163

156164
Raises:
157-
google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport
165+
google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
158166
creation failed for any reason.
159167
"""
160168
if isinstance(client_options, dict):
161169
client_options = ClientOptions.from_dict(client_options)
170+
if client_options is None:
171+
client_options = ClientOptions.ClientOptions()
172+
173+
if transport is None and client_options.api_endpoint is None:
174+
use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS", "Never")
175+
if use_mtls_env == "Never":
176+
client_options.api_endpoint = self.DEFAULT_ENDPOINT
177+
elif use_mtls_env == "Always":
178+
client_options.api_endpoint = self.DEFAULT_MTLS_ENDPOINT
179+
elif use_mtls_env == "Auto":
180+
has_client_cert_source = (
181+
client_options.client_cert_source is not None
182+
or mtls.has_default_client_cert_source()
183+
)
184+
client_options.api_endpoint = (
185+
self.DEFAULT_MTLS_ENDPOINT if has_client_cert_source else self.DEFAULT_ENDPOINT
186+
)
187+
else:
188+
raise MutualTLSChannelError(
189+
"Unsupported GOOGLE_API_USE_MTLS value. Accepted values: Never, Auto, Always"
190+
)
162191

163192
# Save or instantiate the transport.
164193
# Ordinarily, we provide the transport, but allowing a custom transport
@@ -169,38 +198,16 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
169198
raise ValueError('When providing a transport instance, '
170199
'provide its credentials directly.')
171200
self._transport = transport
172-
elif client_options is None or (
173-
client_options.api_endpoint is None
174-
and client_options.client_cert_source is None
175-
):
176-
# Don't trigger mTLS if we get an empty ClientOptions.
201+
elif isinstance(transport, str):
177202
Transport = type(self).get_transport_class(transport)
178203
self._transport = Transport(
179204
credentials=credentials, host=self.DEFAULT_ENDPOINT
180205
)
181206
else:
182-
# We have a non-empty ClientOptions. If client_cert_source is
183-
# provided, trigger mTLS with user provided endpoint or the default
184-
# mTLS endpoint.
185-
if client_options.client_cert_source:
186-
api_mtls_endpoint = (
187-
client_options.api_endpoint
188-
if client_options.api_endpoint
189-
else self.DEFAULT_MTLS_ENDPOINT
190-
)
191-
else:
192-
api_mtls_endpoint = None
193-
194-
api_endpoint = (
195-
client_options.api_endpoint
196-
if client_options.api_endpoint
197-
else self.DEFAULT_ENDPOINT
198-
)
199-
200207
self._transport = {{ service.name }}GrpcTransport(
201208
credentials=credentials,
202-
host=api_endpoint,
203-
api_mtls_endpoint=api_mtls_endpoint,
209+
host=client_options.api_endpoint,
210+
api_mtls_endpoint=client_options.api_endpoint,
204211
client_cert_source=client_options.client_cert_source,
205212
)
206213

gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/transports/grpc.py.j2

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ from google.api_core import grpc_helpers # type: ignore
77
{%- if service.has_lro %}
88
from google.api_core import operations_v1 # type: ignore
99
{%- endif %}
10+
from google import auth # type: ignore
1011
from google.auth import credentials # type: ignore
1112
from google.auth.transport.grpc import SslCredentials # type: ignore
1213

@@ -63,7 +64,7 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
6364
is None.
6465

6566
Raises:
66-
google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport
67+
google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
6768
creation failed for any reason.
6869
"""
6970
if channel:
@@ -76,6 +77,9 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
7677
elif api_mtls_endpoint:
7778
host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443"
7879

80+
if credentials is None:
81+
credentials, _ = auth.default(scopes=self.AUTH_SCOPES)
82+
7983
# Create SSL credentials with client_cert_source or application
8084
# default SSL credentials.
8185
if client_cert_source:
@@ -96,7 +100,7 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
96100

97101
# Run the base constructor.
98102
super().__init__(host=host, credentials=credentials)
99-
self._stubs = {} # type: Dict[str, Callable]
103+
self._stubs = {} # type: Dict[str, Callable]
100104

101105

102106
@classmethod

gapic/ads-templates/tests/unit/%name_%version/%sub/test_%service.py.j2

Lines changed: 90 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
{% extends "_base.py.j2" %}
22

33
{% block content %}
4+
import os
45
from unittest import mock
56

67
import grpc
@@ -11,6 +12,7 @@ import pytest
1112
{% filter sort_lines -%}
1213
from google import auth
1314
from google.auth import credentials
15+
from google.auth.exceptions import MutualTLSChannelError
1416
from google.oauth2 import service_account
1517
from {{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }} import {{ service.client_name }}
1618
from {{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }} import transports
@@ -63,6 +65,14 @@ def test_{{ service.client_name|snake_case }}_from_service_account_file():
6365
{% if service.host %}assert client._transport._host == '{{ service.host }}{% if ":" not in service.host %}:443{% endif %}'{% endif %}
6466

6567

68+
def test_{{ service.client_name|snake_case }}_get_transport_class():
69+
transport = {{ service.client_name }}.get_transport_class()
70+
assert transport == transports.{{ service.name }}GrpcTransport
71+
72+
transport = {{ service.client_name }}.get_transport_class("grpc")
73+
assert transport == transports.{{ service.name }}GrpcTransport
74+
75+
6676
def test_{{ service.client_name|snake_case }}_client_options():
6777
# Check that if channel is provided we won't create a new one.
6878
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.{{ service.client_name }}.get_transport_class') as gtc:
@@ -72,58 +82,99 @@ def test_{{ service.client_name|snake_case }}_client_options():
7282
client = {{ service.client_name }}(transport=transport)
7383
gtc.assert_not_called()
7484

75-
# Check mTLS is not triggered with empty client options.
76-
options = client_options.ClientOptions()
85+
# Check that if channel is provided via str we will create a new one.
7786
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.{{ service.client_name }}.get_transport_class') as gtc:
78-
transport = gtc.return_value = mock.MagicMock()
79-
client = {{ service.client_name }}(client_options=options)
80-
transport.assert_called_once_with(
81-
credentials=None,
82-
host=client.DEFAULT_ENDPOINT,
83-
)
87+
client = {{ service.client_name }}(transport="grpc")
88+
gtc.assert_called()
8489

85-
# Check mTLS is not triggered if api_endpoint is provided but
86-
# client_cert_source is None.
90+
# Check the case api_endpoint is provided.
8791
options = client_options.ClientOptions(api_endpoint="squid.clam.whelk")
8892
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport:
8993
grpc_transport.return_value = None
9094
client = {{ service.client_name }}(client_options=options)
9195
grpc_transport.assert_called_once_with(
92-
api_mtls_endpoint=None,
96+
api_mtls_endpoint="squid.clam.whelk",
9397
client_cert_source=None,
9498
credentials=None,
9599
host="squid.clam.whelk",
96100
)
97101

98-
# Check mTLS is triggered if client_cert_source is provided.
99-
options = client_options.ClientOptions(
100-
client_cert_source=client_cert_source_callback
101-
)
102+
# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS is
103+
# "Never".
104+
os.environ["GOOGLE_API_USE_MTLS"] = "Never"
102105
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport:
103106
grpc_transport.return_value = None
104-
client = {{ service.client_name }}(client_options=options)
107+
client = {{ service.client_name }}()
105108
grpc_transport.assert_called_once_with(
106-
api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT,
107-
client_cert_source=client_cert_source_callback,
109+
api_mtls_endpoint=client.DEFAULT_ENDPOINT,
110+
client_cert_source=None,
108111
credentials=None,
109112
host=client.DEFAULT_ENDPOINT,
110113
)
111114

112-
# Check mTLS is triggered if api_endpoint and client_cert_source are provided.
113-
options = client_options.ClientOptions(
114-
api_endpoint="squid.clam.whelk",
115-
client_cert_source=client_cert_source_callback
116-
)
115+
# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS is
116+
# "Always".
117+
os.environ["GOOGLE_API_USE_MTLS"] = "Always"
118+
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport:
119+
grpc_transport.return_value = None
120+
client = {{ service.client_name }}()
121+
grpc_transport.assert_called_once_with(
122+
api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT,
123+
client_cert_source=None,
124+
credentials=None,
125+
host=client.DEFAULT_MTLS_ENDPOINT,
126+
)
127+
128+
# Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is
129+
# "Auto", and client_cert_source is provided.
130+
os.environ["GOOGLE_API_USE_MTLS"] = "Auto"
131+
options = client_options.ClientOptions(client_cert_source=client_cert_source_callback)
117132
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport:
118133
grpc_transport.return_value = None
119134
client = {{ service.client_name }}(client_options=options)
120135
grpc_transport.assert_called_once_with(
121-
api_mtls_endpoint="squid.clam.whelk",
136+
api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT,
122137
client_cert_source=client_cert_source_callback,
123138
credentials=None,
124-
host="squid.clam.whelk",
139+
host=client.DEFAULT_MTLS_ENDPOINT,
125140
)
126141

142+
# Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is
143+
# "Auto", and default_client_cert_source is provided.
144+
os.environ["GOOGLE_API_USE_MTLS"] = "Auto"
145+
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport:
146+
with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True):
147+
grpc_transport.return_value = None
148+
client = {{ service.client_name }}()
149+
grpc_transport.assert_called_once_with(
150+
api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT,
151+
client_cert_source=None,
152+
credentials=None,
153+
host=client.DEFAULT_MTLS_ENDPOINT,
154+
)
155+
156+
# Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is
157+
# "Auto", but client_cert_source and default_client_cert_source are None.
158+
os.environ["GOOGLE_API_USE_MTLS"] = "Auto"
159+
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport:
160+
with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=False):
161+
grpc_transport.return_value = None
162+
client = {{ service.client_name }}()
163+
grpc_transport.assert_called_once_with(
164+
api_mtls_endpoint=client.DEFAULT_ENDPOINT,
165+
client_cert_source=None,
166+
credentials=None,
167+
host=client.DEFAULT_ENDPOINT,
168+
)
169+
170+
# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS has
171+
# unsupported value.
172+
os.environ["GOOGLE_API_USE_MTLS"] = "Unsupported"
173+
with pytest.raises(MutualTLSChannelError):
174+
client = {{ service.client_name }}()
175+
176+
del os.environ["GOOGLE_API_USE_MTLS"]
177+
127178

128179
def test_{{ service.client_name|snake_case }}_client_options_from_dict():
129180
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport:
@@ -132,7 +183,7 @@ def test_{{ service.client_name|snake_case }}_client_options_from_dict():
132183
client_options={'api_endpoint': 'squid.clam.whelk'}
133184
)
134185
grpc_transport.assert_called_once_with(
135-
api_mtls_endpoint=None,
186+
api_mtls_endpoint="squid.clam.whelk",
136187
client_cert_source=None,
137188
credentials=None,
138189
host="squid.clam.whelk",
@@ -490,12 +541,24 @@ def test_{{ service.name|snake_case }}_auth_adc():
490541
))
491542

492543

544+
def test_{{ service.name|snake_case }}_transport_auth_adc():
545+
# If credentials and host are not provided, the transport class should use
546+
# ADC credentials.
547+
with mock.patch.object(auth, 'default') as adc:
548+
adc.return_value = (credentials.AnonymousCredentials(), None)
549+
transports.{{ service.name }}GrpcTransport(host="squid.clam.whelk")
550+
adc.assert_called_once_with(scopes=(
551+
{%- for scope in service.oauth_scopes %}
552+
'{{ scope }}',
553+
{%- endfor %}
554+
))
555+
556+
493557
def test_{{ service.name|snake_case }}_host_no_port():
494558
{% with host = (service.host|default('localhost', true)).split(':')[0] -%}
495559
client = {{ service.client_name }}(
496560
credentials=credentials.AnonymousCredentials(),
497561
client_options=client_options.ClientOptions(api_endpoint='{{ host }}'),
498-
transport='grpc',
499562
)
500563
assert client._transport._host == '{{ host }}:443'
501564
{% endwith %}
@@ -506,7 +569,6 @@ def test_{{ service.name|snake_case }}_host_with_port():
506569
client = {{ service.client_name }}(
507570
credentials=credentials.AnonymousCredentials(),
508571
client_options=client_options.ClientOptions(api_endpoint='{{ host }}:8000'),
509-
transport='grpc',
510572
)
511573
assert client._transport._host == '{{ host }}:8000'
512574
{% endwith %}

0 commit comments

Comments
 (0)