Skip to content

Commit

Permalink
fix: expose ssl credentials from transport (#677)
Browse files Browse the repository at this point in the history
Expose ssl credentials from transport.

This is used to fix pubsub client [mtls issue](googleapis/python-pubsub#224). Pubsub client creates its own transport so mtls is completely missing. The solution would be taking the ssl credentials from the auto-generated client's transport and passing it when the handwritten client creates the transport.
  • Loading branch information
arithmetic1728 authored Oct 21, 2020
1 parent 0fe9330 commit da0ee3e
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,16 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
creation failed for any reason.
"""
self._ssl_channel_credentials = ssl_channel_credentials

if channel:
# Sanity check: Ensure that channel and credentials are not both
# provided.
credentials = False

# If a channel was explicitly provided, set it.
self._grpc_channel = channel
self._ssl_channel_credentials = None
elif api_mtls_endpoint:
warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning)

Expand Down Expand Up @@ -122,6 +125,7 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
scopes=scopes or self.AUTH_SCOPES,
quota_project_id=quota_project_id,
)
self._ssl_channel_credentials = ssl_credentials
else:
host = host if ":" in host else host + ":443"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -708,6 +708,7 @@ def test_{{ service.name|snake_case }}_grpc_transport_channel():
)
assert transport.grpc_channel == channel
assert transport._host == "squid.clam.whelk:443"
assert transport._ssl_channel_credentials == None


@pytest.mark.parametrize("transport_class", [transports.{{ service.grpc_transport_name }}])
Expand Down Expand Up @@ -749,6 +750,7 @@ def test_{{ service.name|snake_case }}_transport_channel_mtls_with_client_cert_s
quota_project_id=None,
)
assert transport.grpc_channel == mock_grpc_channel
assert transport._ssl_channel_credentials == mock_ssl_cred


@pytest.mark.parametrize("transport_class", [transports.{{ service.grpc_transport_name }},])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,16 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
and ``credentials_file`` are passed.
"""
self._ssl_channel_credentials = ssl_channel_credentials

if channel:
# Sanity check: Ensure that channel and credentials are not both
# provided.
credentials = False

# If a channel was explicitly provided, set it.
self._grpc_channel = channel
self._ssl_channel_credentials = None
elif api_mtls_endpoint:
warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning)

Expand Down Expand Up @@ -130,6 +133,7 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
scopes=scopes or self.AUTH_SCOPES,
quota_project_id=quota_project_id,
)
self._ssl_channel_credentials = ssl_credentials
else:
host = host if ":" in host else host + ":443"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,13 +140,16 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
and ``credentials_file`` are passed.
"""
self._ssl_channel_credentials = ssl_channel_credentials

if channel:
# Sanity check: Ensure that channel and credentials are not both
# provided.
credentials = False

# If a channel was explicitly provided, set it.
self._grpc_channel = channel
self._ssl_channel_credentials = None
elif api_mtls_endpoint:
warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning)

Expand Down Expand Up @@ -174,6 +177,7 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
scopes=scopes or self.AUTH_SCOPES,
quota_project_id=quota_project_id,
)
self._ssl_channel_credentials = ssl_credentials
else:
host = host if ":" in host else host + ":443"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1184,6 +1184,7 @@ def test_{{ service.name|snake_case }}_grpc_transport_channel():
)
assert transport.grpc_channel == channel
assert transport._host == "squid.clam.whelk:443"
assert transport._ssl_channel_credentials == None


def test_{{ service.name|snake_case }}_grpc_asyncio_transport_channel():
Expand All @@ -1196,6 +1197,7 @@ def test_{{ service.name|snake_case }}_grpc_asyncio_transport_channel():
)
assert transport.grpc_channel == channel
assert transport._host == "squid.clam.whelk:443"
assert transport._ssl_channel_credentials == None


@pytest.mark.parametrize("transport_class", [transports.{{ service.grpc_transport_name }}, transports.{{ service.grpc_asyncio_transport_name }}])
Expand Down Expand Up @@ -1237,6 +1239,7 @@ def test_{{ service.name|snake_case }}_transport_channel_mtls_with_client_cert_s
quota_project_id=None,
)
assert transport.grpc_channel == mock_grpc_channel
assert transport._ssl_channel_credentials == mock_ssl_cred


@pytest.mark.parametrize("transport_class", [transports.{{ service.grpc_transport_name }}, transports.{{ service.grpc_asyncio_transport_name }}])
Expand Down

0 comments on commit da0ee3e

Please sign in to comment.