From 524dbab16d248198ca10a08ecede4600fd36cefc Mon Sep 17 00:00:00 2001 From: arithmetic1728 <58957152+arithmetic1728@users.noreply.github.com> Date: Tue, 19 Jan 2021 14:29:26 -0800 Subject: [PATCH] feat: add mtls feature to rest transport (#731) * feat: add mtls support to rest transport * update * update * update --- .../%sub/services/%service/client.py.j2 | 16 +- .../services/%service/transports/grpc.py.j2 | 20 +- .../%service/transports/grpc_asyncio.py.j2 | 20 +- .../services/%service/transports/rest.py.j2 | 9 +- .../%name_%version/%sub/test_%service.py.j2 | 171 ++++++++++++------ 5 files changed, 157 insertions(+), 79 deletions(-) diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 index 21b0c9f201..ada75471b0 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 @@ -239,21 +239,15 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta): # Create SSL credentials for mutual TLS if needed. use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -292,7 +286,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta): credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc.py.j2 index 0bcd1fa64f..b1d1d18917 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc.py.j2 @@ -51,6 +51,7 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport): api_mtls_endpoint: str = None, client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -82,6 +83,10 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport): ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -98,6 +103,11 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport): """ self._ssl_channel_credentials = ssl_channel_credentials + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -107,8 +117,6 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport): self._grpc_channel = channel self._ssl_channel_credentials = None elif api_mtls_endpoint: - warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning) - host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443" if credentials is None: @@ -144,12 +152,18 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport): if credentials is None: credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id) + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + # create a new channel. The provided one is ignored. self._grpc_channel = type(self).create_channel( host, credentials=credentials, credentials_file=credentials_file, - ssl_credentials=ssl_channel_credentials, + ssl_credentials=self._ssl_channel_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, options=[ diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc_asyncio.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc_asyncio.py.j2 index 83013d83ee..94da8db76a 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc_asyncio.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc_asyncio.py.j2 @@ -94,6 +94,7 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport): api_mtls_endpoint: str = None, client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id=None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -126,6 +127,10 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport): ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -141,6 +146,11 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport): and ``credentials_file`` are passed. """ self._ssl_channel_credentials = ssl_channel_credentials + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) if channel: # Sanity check: Ensure that channel and credentials are not both @@ -151,8 +161,6 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport): self._grpc_channel = channel self._ssl_channel_credentials = None elif api_mtls_endpoint: - warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning) - host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443" if credentials is None: @@ -188,12 +196,18 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport): if credentials is None: credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id) + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + # create a new channel. The provided one is ignored. self._grpc_channel = type(self).create_channel( host, credentials=credentials, credentials_file=credentials_file, - ssl_credentials=ssl_channel_credentials, + ssl_credentials=self._ssl_channel_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, options=[ diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 index 338bfcbaff..3446a06e90 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 @@ -48,7 +48,7 @@ class {{ service.name }}RestTransport({{ service.name }}Transport): credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -68,8 +68,9 @@ class {{ service.name }}RestTransport({{ service.name }}Transport): This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ class {{ service.name }}RestTransport({{ service.name }}Transport): {%- if service.has_lro %} self._operations_client = None {%- endif %} + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) {%- if service.has_lro %} diff --git a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 index 6affa40bc8..860625b493 100644 --- a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 +++ b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 @@ -156,7 +156,7 @@ def test_{{ service.client_name|snake_case }}_client_options(client_class, trans credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -172,7 +172,7 @@ def test_{{ service.client_name|snake_case }}_client_options(client_class, trans credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -188,7 +188,7 @@ def test_{{ service.client_name|snake_case }}_client_options(client_class, trans credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -214,7 +214,7 @@ def test_{{ service.client_name|snake_case }}_client_options(client_class, trans credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -244,76 +244,67 @@ def test_{{ service.client_name|snake_case }}_mtls_env_auto(client_class, transp with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) with mock.patch.object(transport_class, '__init__') as patched: - ssl_channel_creds = mock.Mock() - with mock.patch('grpc.ssl_channel_credentials', return_value=ssl_channel_creds): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): - with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: - with mock.patch('google.auth.transport.grpc.SslCredentials.ssl_credentials', new_callable=mock.PropertyMock) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ssl_credentials_mock.return_value - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): + with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): - with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize("client_class,transport_class,transport_name", [ {%- if 'grpc' in opts.transport %} @@ -336,7 +327,7 @@ def test_{{ service.client_name|snake_case }}_client_options_scopes(client_class credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -362,7 +353,7 @@ def test_{{ service.client_name|snake_case }}_client_options_credentials_file(cl credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -380,7 +371,7 @@ def test_{{ service.client_name|snake_case }}_client_options_from_dict(): credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -1348,6 +1339,64 @@ def test_{{ service.name|snake_case }}_transport_auth_adc(): ) {% endif %} +{% if 'grpc' in opts.transport %} +@pytest.mark.parametrize("transport_class", [transports.{{ service.grpc_transport_name }}, transports.{{ service.grpc_asyncio_transport_name }}]) +def test_{{ service.name|snake_case }}_grpc_transport_client_cert_source_for_mtls( + transport_class +): + cred = credentials.AnonymousCredentials() + + # Check ssl_channel_credentials is used if provided. + with mock.patch.object(transport_class, "create_channel") as mock_create_channel: + mock_ssl_channel_creds = mock.Mock() + transport_class( + host="squid.clam.whelk", + credentials=cred, + ssl_channel_credentials=mock_ssl_channel_creds + ) + mock_create_channel.assert_called_once_with( + "squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=( + {%- for scope in service.oauth_scopes %} + '{{ scope }}', + {%- endfor %} + ), + ssl_credentials=mock_ssl_channel_creds, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls + # is used. + with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()): + with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: + transport_class( + credentials=cred, + client_cert_source_for_mtls=client_cert_source_callback + ) + expected_cert, expected_key = client_cert_source_callback() + mock_ssl_cred.assert_called_once_with( + certificate_chain=expected_cert, + private_key=expected_key + ) +{% endif %} + +{% if 'rest' in opts.transport %} +def test_{{ service.name|snake_case }}_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch("google.auth.transport.requests.AuthorizedSession.configure_mtls_channel") as mock_configure_mtls_channel: + transports.{{ service.rest_transport_name }} ( + credentials=cred, + client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) +{% endif %} + def test_{{ service.name|snake_case }}_host_no_port(): {% with host = (service.host|default('localhost', true)).split(':')[0] -%} client = {{ service.client_name }}( @@ -1394,6 +1443,8 @@ def test_{{ service.name|snake_case }}_grpc_asyncio_transport_channel(): assert transport._ssl_channel_credentials == None +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. @pytest.mark.parametrize("transport_class", [transports.{{ service.grpc_transport_name }}, transports.{{ service.grpc_asyncio_transport_name }}]) def test_{{ service.name|snake_case }}_transport_channel_mtls_with_client_cert_source( transport_class @@ -1440,6 +1491,8 @@ def test_{{ service.name|snake_case }}_transport_channel_mtls_with_client_cert_s assert transport._ssl_channel_credentials == mock_ssl_cred +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. @pytest.mark.parametrize("transport_class", [transports.{{ service.grpc_transport_name }}, transports.{{ service.grpc_asyncio_transport_name }}]) def test_{{ service.name|snake_case }}_transport_channel_mtls_with_adc( transport_class