1
1
{% extends "_base.py.j2" %}
2
2
3
3
{% block content %}
4
+ import os
4
5
from unittest import mock
5
6
6
7
import grpc
@@ -11,6 +12,7 @@ import pytest
11
12
{% filter sort_lines -%}
12
13
from google import auth
13
14
from google.auth import credentials
15
+ from google.auth.exceptions import MutualTLSChannelError
14
16
from google.oauth2 import service_account
15
17
from {{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }} import {{ service.client_name }}
16
18
from {{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }} import transports
@@ -63,6 +65,14 @@ def test_{{ service.client_name|snake_case }}_from_service_account_file():
63
65
{% if service .host %} assert client._transport._host == '{{ service.host }}{% if ":" not in service .host %} :443{% endif %} '{% endif %}
64
66
65
67
68
+ def test_{{ service.client_name|snake_case }}_get_transport_class():
69
+ transport = {{ service.client_name }}.get_transport_class()
70
+ assert transport == transports.{{ service.name }}GrpcTransport
71
+
72
+ transport = {{ service.client_name }}.get_transport_class("grpc")
73
+ assert transport == transports.{{ service.name }}GrpcTransport
74
+
75
+
66
76
def test_{{ service.client_name|snake_case }}_client_options():
67
77
# Check that if channel is provided we won't create a new one.
68
78
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.{{ service.client_name }}.get_transport_class') as gtc:
@@ -72,58 +82,99 @@ def test_{{ service.client_name|snake_case }}_client_options():
72
82
client = {{ service.client_name }}(transport=transport)
73
83
gtc.assert_not_called()
74
84
75
- # Check mTLS is not triggered with empty client options.
76
- options = client_options.ClientOptions()
85
+ # Check that if channel is provided via str we will create a new one.
77
86
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.{{ service.client_name }}.get_transport_class') as gtc:
78
- transport = gtc.return_value = mock.MagicMock()
79
- client = {{ service.client_name }}(client_options=options)
80
- transport.assert_called_once_with(
81
- credentials=None,
82
- host=client.DEFAULT_ENDPOINT,
83
- )
87
+ client = {{ service.client_name }}(transport="grpc")
88
+ gtc.assert_called()
84
89
85
- # Check mTLS is not triggered if api_endpoint is provided but
86
- # client_cert_source is None.
90
+ # Check the case api_endpoint is provided.
87
91
options = client_options.ClientOptions(api_endpoint="squid.clam.whelk")
88
92
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport:
89
93
grpc_transport.return_value = None
90
94
client = {{ service.client_name }}(client_options=options)
91
95
grpc_transport.assert_called_once_with(
92
- api_mtls_endpoint=None ,
96
+ api_mtls_endpoint="squid.clam.whelk" ,
93
97
client_cert_source=None,
94
98
credentials=None,
95
99
host="squid.clam.whelk",
96
100
)
97
101
98
- # Check mTLS is triggered if client_cert_source is provided.
99
- options = client_options.ClientOptions(
100
- client_cert_source=client_cert_source_callback
101
- )
102
+ # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS is
103
+ # "Never".
104
+ os.environ["GOOGLE_API_USE_MTLS"] = "Never"
102
105
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport:
103
106
grpc_transport.return_value = None
104
- client = {{ service.client_name }}(client_options=options )
107
+ client = {{ service.client_name }}()
105
108
grpc_transport.assert_called_once_with(
106
- api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT ,
107
- client_cert_source=client_cert_source_callback ,
109
+ api_mtls_endpoint=client.DEFAULT_ENDPOINT ,
110
+ client_cert_source=None ,
108
111
credentials=None,
109
112
host=client.DEFAULT_ENDPOINT,
110
113
)
111
114
112
- # Check mTLS is triggered if api_endpoint and client_cert_source are provided.
113
- options = client_options.ClientOptions(
114
- api_endpoint="squid.clam.whelk",
115
- client_cert_source=client_cert_source_callback
116
- )
115
+ # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS is
116
+ # "Always".
117
+ os.environ["GOOGLE_API_USE_MTLS"] = "Always"
118
+ with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport:
119
+ grpc_transport.return_value = None
120
+ client = {{ service.client_name }}()
121
+ grpc_transport.assert_called_once_with(
122
+ api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT,
123
+ client_cert_source=None,
124
+ credentials=None,
125
+ host=client.DEFAULT_MTLS_ENDPOINT,
126
+ )
127
+
128
+ # Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is
129
+ # "Auto", and client_cert_source is provided.
130
+ os.environ["GOOGLE_API_USE_MTLS"] = "Auto"
131
+ options = client_options.ClientOptions(client_cert_source=client_cert_source_callback)
117
132
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport:
118
133
grpc_transport.return_value = None
119
134
client = {{ service.client_name }}(client_options=options)
120
135
grpc_transport.assert_called_once_with(
121
- api_mtls_endpoint="squid.clam.whelk" ,
136
+ api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT ,
122
137
client_cert_source=client_cert_source_callback,
123
138
credentials=None,
124
- host="squid.clam.whelk" ,
139
+ host=client.DEFAULT_MTLS_ENDPOINT ,
125
140
)
126
141
142
+ # Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is
143
+ # "Auto", and default_client_cert_source is provided.
144
+ os.environ["GOOGLE_API_USE_MTLS"] = "Auto"
145
+ with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport:
146
+ with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True):
147
+ grpc_transport.return_value = None
148
+ client = {{ service.client_name }}()
149
+ grpc_transport.assert_called_once_with(
150
+ api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT,
151
+ client_cert_source=None,
152
+ credentials=None,
153
+ host=client.DEFAULT_MTLS_ENDPOINT,
154
+ )
155
+
156
+ # Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is
157
+ # "Auto", but client_cert_source and default_client_cert_source are None.
158
+ os.environ["GOOGLE_API_USE_MTLS"] = "Auto"
159
+ with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport:
160
+ with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=False):
161
+ grpc_transport.return_value = None
162
+ client = {{ service.client_name }}()
163
+ grpc_transport.assert_called_once_with(
164
+ api_mtls_endpoint=client.DEFAULT_ENDPOINT,
165
+ client_cert_source=None,
166
+ credentials=None,
167
+ host=client.DEFAULT_ENDPOINT,
168
+ )
169
+
170
+ # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS has
171
+ # unsupported value.
172
+ os.environ["GOOGLE_API_USE_MTLS"] = "Unsupported"
173
+ with pytest.raises(MutualTLSChannelError):
174
+ client = {{ service.client_name }}()
175
+
176
+ del os.environ["GOOGLE_API_USE_MTLS"]
177
+
127
178
128
179
def test_{{ service.client_name|snake_case }}_client_options_from_dict():
129
180
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport:
@@ -132,7 +183,7 @@ def test_{{ service.client_name|snake_case }}_client_options_from_dict():
132
183
client_options={'api_endpoint': 'squid.clam.whelk'}
133
184
)
134
185
grpc_transport.assert_called_once_with(
135
- api_mtls_endpoint=None ,
186
+ api_mtls_endpoint="squid.clam.whelk" ,
136
187
client_cert_source=None,
137
188
credentials=None,
138
189
host="squid.clam.whelk",
@@ -490,12 +541,24 @@ def test_{{ service.name|snake_case }}_auth_adc():
490
541
))
491
542
492
543
544
+ def test_{{ service.name|snake_case }}_transport_auth_adc():
545
+ # If credentials and host are not provided, the transport class should use
546
+ # ADC credentials.
547
+ with mock.patch.object(auth, 'default') as adc:
548
+ adc.return_value = (credentials.AnonymousCredentials(), None)
549
+ transports.{{ service.name }}GrpcTransport(host="squid.clam.whelk")
550
+ adc.assert_called_once_with(scopes=(
551
+ {% - for scope in service .oauth_scopes %}
552
+ '{{ scope }}',
553
+ {% - endfor %}
554
+ ))
555
+
556
+
493
557
def test_{{ service.name|snake_case }}_host_no_port():
494
558
{% with host = (service .host |default ('localhost' , true )).split (':' )[0] -%}
495
559
client = {{ service.client_name }}(
496
560
credentials=credentials.AnonymousCredentials(),
497
561
client_options=client_options.ClientOptions(api_endpoint='{{ host }}'),
498
- transport='grpc',
499
562
)
500
563
assert client._transport._host == '{{ host }}:443'
501
564
{% endwith %}
@@ -506,7 +569,6 @@ def test_{{ service.name|snake_case }}_host_with_port():
506
569
client = {{ service.client_name }}(
507
570
credentials=credentials.AnonymousCredentials(),
508
571
client_options=client_options.ClientOptions(api_endpoint='{{ host }}:8000'),
509
- transport='grpc',
510
572
)
511
573
assert client._transport._host == '{{ host }}:8000'
512
574
{% endwith %}
0 commit comments