Skip to content

Commit

Permalink
fix: expose transport property for clients (#645)
Browse files Browse the repository at this point in the history
Sometimes it's useful to get a reference to the transport for a client object.

Closes #640
  • Loading branch information
software-dov authored Oct 9, 2020
1 parent 7ff5963 commit 13cddda
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 77 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,15 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):

from_service_account_json = from_service_account_file

@property
def transport(self) -> {{ service.name }}Transport:
"""Return the transport used by the client instance.

Returns:
{{ service.name }}Transport: The transport used by the client instance.
"""
return self._transport


{% for message in service.resource_messages|sort(attribute="resource_type") -%}
@staticmethod
Expand All @@ -143,7 +152,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
"""Parse a {{ resource_msg.message_type.resource_type|snake_case }} path into its component segments."""
m = re.match(r"{{ resource_msg.message_type.path_regex_str }}", path)
return m.groupdict() if m else {}

{% endfor %} {# common resources #}

def __init__(self, *,
Expand Down Expand Up @@ -179,12 +188,12 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
not provided, the default SSL client certificate will be used if
present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not
set, no client certificate will be used.
client_info (google.api_core.gapic_v1.client_info.ClientInfo):
The client info used to send a user-agent string along with
API requests. If ``None``, then default info will be used.
Generally, you only need to set this if you're developing
client_info (google.api_core.gapic_v1.client_info.ClientInfo):
The client info used to send a user-agent string along with
API requests. If ``None``, then default info will be used.
Generally, you only need to set this if you're developing
your own client library.

Raises:
google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
creation failed for any reason.
Expand All @@ -193,10 +202,10 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
client_options = client_options_lib.from_dict(client_options)
if client_options is None:
client_options = client_options_lib.ClientOptions()

# Create SSL credentials for mutual TLS if needed.
use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")))

ssl_credentials = None
is_mtls = False
if use_client_cert:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,12 @@ def test_{{ service.client_name|snake_case }}_from_service_account_file():
with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory:
factory.return_value = creds
client = {{ service.client_name }}.from_service_account_file("dummy/file/path.json")
assert client._transport._credentials == creds
assert client.transport._credentials == creds

client = {{ service.client_name }}.from_service_account_json("dummy/file/path.json")
assert client._transport._credentials == creds
assert client.transport._credentials == creds

{% if service.host %}assert client._transport._host == '{{ service.host }}{% if ":" not in service.host %}:443{% endif %}'{% endif %}
{% if service.host %}assert client.transport._host == '{{ service.host }}{% if ":" not in service.host %}:443{% endif %}'{% endif %}


def test_{{ service.client_name|snake_case }}_get_transport_class():
Expand Down Expand Up @@ -170,7 +170,7 @@ def test_{{ service.client_name|snake_case }}_mtls_env_auto(use_client_cert_env)
else:
expected_ssl_channel_creds = ssl_channel_creds
expected_host = client.DEFAULT_MTLS_ENDPOINT

grpc_transport.assert_called_once_with(
ssl_channel_credentials=expected_ssl_channel_creds,
credentials=None,
Expand All @@ -182,9 +182,9 @@ def test_{{ service.client_name|snake_case }}_mtls_env_auto(use_client_cert_env)
# GOOGLE_API_USE_CLIENT_CERTIFICATE value.
with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}):
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport:
with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None):
with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None):
with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock:
with mock.patch('google.auth.transport.grpc.SslCredentials.ssl_credentials', new_callable=mock.PropertyMock) as ssl_credentials_mock:
with mock.patch('google.auth.transport.grpc.SslCredentials.ssl_credentials', new_callable=mock.PropertyMock) as ssl_credentials_mock:
if use_client_cert_env == "false":
is_mtls_mock.return_value = False
ssl_credentials_mock.return_value = None
Expand All @@ -195,7 +195,7 @@ def test_{{ service.client_name|snake_case }}_mtls_env_auto(use_client_cert_env)
ssl_credentials_mock.return_value = mock.Mock()
expected_host = client.DEFAULT_MTLS_ENDPOINT
expected_ssl_channel_creds = ssl_credentials_mock.return_value

grpc_transport.return_value = None
client = {{ service.client_name }}()
grpc_transport.assert_called_once_with(
Expand All @@ -208,7 +208,7 @@ def test_{{ service.client_name|snake_case }}_mtls_env_auto(use_client_cert_env)
# Check the case client_cert_source and ADC client cert are not provided.
with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}):
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport:
with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None):
with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None):
with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock:
is_mtls_mock.return_value = False
grpc_transport.return_value = None
Expand Down Expand Up @@ -251,7 +251,7 @@ def test_{{ method.name|snake_case }}(transport: str = 'grpc', request_type={{ m

# Mock the actual call within the gRPC stub, and fake the request.
with mock.patch.object(
type(client._transport.{{ method.name|snake_case }}),
type(client.transport.{{ method.name|snake_case }}),
'__call__') as call:
# Designate an appropriate return value for the call.
{% if method.void -%}
Expand Down Expand Up @@ -331,7 +331,7 @@ def test_{{ method.name|snake_case }}_field_headers():

# Mock the actual call within the gRPC stub, and fake the request.
with mock.patch.object(
type(client._transport.{{ method.name|snake_case }}),
type(client.transport.{{ method.name|snake_case }}),
'__call__') as call:
{% if method.void -%}
call.return_value = None
Expand Down Expand Up @@ -367,7 +367,7 @@ def test_{{ method.name|snake_case }}_from_dict():
)
# Mock the actual call within the gRPC stub, and fake the request.
with mock.patch.object(
type(client._transport.{{ method.name|snake_case }}),
type(client.transport.{{ method.name|snake_case }}),
'__call__') as call:
# Designate an appropriate return value for the call.
{% if method.void -%}
Expand Down Expand Up @@ -397,7 +397,7 @@ def test_{{ method.name|snake_case }}_flattened():

# Mock the actual call within the gRPC stub, and fake the request.
with mock.patch.object(
type(client._transport.{{ method.name|snake_case }}),
type(client.transport.{{ method.name|snake_case }}),
'__call__') as call:
# Designate an appropriate return value for the call.
{% if method.void -%}
Expand Down Expand Up @@ -462,7 +462,7 @@ def test_{{ method.name|snake_case }}_pager():

# Mock the actual call within the gRPC stub, and fake the request.
with mock.patch.object(
type(client._transport.{{ method.name|snake_case }}),
type(client.transport.{{ method.name|snake_case }}),
'__call__') as call:
# Set the response to a series of pages.
call.side_effect = (
Expand Down Expand Up @@ -521,7 +521,7 @@ def test_{{ method.name|snake_case }}_pages():

# Mock the actual call within the gRPC stub, and fake the request.
with mock.patch.object(
type(client._transport.{{ method.name|snake_case }}),
type(client.transport.{{ method.name|snake_case }}),
'__call__') as call:
# Set the response to a series of pages.
call.side_effect = (
Expand Down Expand Up @@ -580,7 +580,7 @@ def test_transport_instance():
credentials=credentials.AnonymousCredentials(),
)
client = {{ service.client_name }}(transport=transport)
assert client._transport is transport
assert client.transport is transport


def test_transport_grpc_default():
Expand All @@ -589,7 +589,7 @@ def test_transport_grpc_default():
credentials=credentials.AnonymousCredentials(),
)
assert isinstance(
client._transport,
client.transport,
transports.{{ service.name }}GrpcTransport,
)

Expand Down Expand Up @@ -669,7 +669,7 @@ def test_{{ service.name|snake_case }}_host_no_port():
credentials=credentials.AnonymousCredentials(),
client_options=client_options.ClientOptions(api_endpoint='{{ host }}'),
)
assert client._transport._host == '{{ host }}:443'
assert client.transport._host == '{{ host }}:443'
{% endwith %}


Expand All @@ -679,7 +679,7 @@ def test_{{ service.name|snake_case }}_host_with_port():
credentials=credentials.AnonymousCredentials(),
client_options=client_options.ClientOptions(api_endpoint='{{ host }}:8000'),
)
assert client._transport._host == '{{ host }}:8000'
assert client.transport._host == '{{ host }}:8000'
{% endwith %}


Expand All @@ -701,7 +701,7 @@ def test_{{ service.name|snake_case }}_grpc_lro_client():
credentials=credentials.AnonymousCredentials(),
transport='grpc',
)
transport = client._transport
transport = client.transport

# Ensure that we have a api-core operations client.
assert isinstance(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,15 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):

from_service_account_json = from_service_account_file

@property
def transport(self) -> {{ service.name }}Transport:
"""Return the transport used by the client instance.

Returns:
{{ service.name }}Transport: The transport used by the client instance.
"""
return self._transport


{% for message in service.resource_messages|sort(attribute="resource_type") -%}
@staticmethod
Expand All @@ -150,7 +159,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
"""Parse a {{ resource_msg.message_type.resource_type|snake_case }} path into its component segments."""
m = re.match(r"{{ resource_msg.message_type.path_regex_str }}", path)
return m.groupdict() if m else {}

{% endfor %} {# common resources #}

def __init__(self, *,
Expand Down Expand Up @@ -186,12 +195,12 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
not provided, the default SSL client certificate will be used if
present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not
set, no client certificate will be used.
client_info (google.api_core.gapic_v1.client_info.ClientInfo):
The client info used to send a user-agent string along with
API requests. If ``None``, then default info will be used.
Generally, you only need to set this if you're developing
client_info (google.api_core.gapic_v1.client_info.ClientInfo):
The client info used to send a user-agent string along with
API requests. If ``None``, then default info will be used.
Generally, you only need to set this if you're developing
your own client library.

Raises:
google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
creation failed for any reason.
Expand All @@ -200,10 +209,10 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
client_options = client_options_lib.from_dict(client_options)
if client_options is None:
client_options = client_options_lib.ClientOptions()

# Create SSL credentials for mutual TLS if needed.
use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")))

ssl_credentials = None
is_mtls = False
if use_client_cert:
Expand Down
Loading

0 comments on commit 13cddda

Please sign in to comment.