diff --git a/packages/google-cloud-retail/.github/header-checker-lint.yml b/packages/google-cloud-retail/.github/header-checker-lint.yml
new file mode 100644
index 000000000000..fc281c05bd55
--- /dev/null
+++ b/packages/google-cloud-retail/.github/header-checker-lint.yml
@@ -0,0 +1,15 @@
+{"allowedCopyrightHolders": ["Google LLC"],
+ "allowedLicenses": ["Apache-2.0", "MIT", "BSD-3"],
+ "ignoreFiles": ["**/requirements.txt", "**/requirements-test.txt"],
+ "sourceFileExtensions": [
+ "ts",
+ "js",
+ "java",
+ "sh",
+ "Dockerfile",
+ "yaml",
+ "py",
+ "html",
+ "txt"
+ ]
+}
\ No newline at end of file
diff --git a/packages/google-cloud-retail/.gitignore b/packages/google-cloud-retail/.gitignore
index b9daa52f118d..b4243ced74e4 100644
--- a/packages/google-cloud-retail/.gitignore
+++ b/packages/google-cloud-retail/.gitignore
@@ -50,8 +50,10 @@ docs.metadata
# Virtual environment
env/
+
+# Test logs
coverage.xml
-sponge_log.xml
+*sponge_log.xml
# System test environment variables.
system_tests/local_test_setup
diff --git a/packages/google-cloud-retail/.kokoro/build.sh b/packages/google-cloud-retail/.kokoro/build.sh
index 09b60c979502..a98eb35b1d4d 100755
--- a/packages/google-cloud-retail/.kokoro/build.sh
+++ b/packages/google-cloud-retail/.kokoro/build.sh
@@ -40,6 +40,16 @@ python3 -m pip uninstall --yes --quiet nox-automation
python3 -m pip install --upgrade --quiet nox
python3 -m nox --version
+# If this is a continuous build, send the test log to the FlakyBot.
+# See https://github.com/googleapis/repo-automation-bots/tree/master/packages/flakybot.
+if [[ $KOKORO_BUILD_ARTIFACTS_SUBDIR = *"continuous"* ]]; then
+ cleanup() {
+ chmod +x $KOKORO_GFILE_DIR/linux_amd64/flakybot
+ $KOKORO_GFILE_DIR/linux_amd64/flakybot
+ }
+ trap cleanup EXIT HUP
+fi
+
# If NOX_SESSION is set, it only runs the specified session,
# otherwise run all the sessions.
if [[ -n "${NOX_SESSION:-}" ]]; then
diff --git a/packages/google-cloud-retail/.kokoro/samples/python3.6/periodic-head.cfg b/packages/google-cloud-retail/.kokoro/samples/python3.6/periodic-head.cfg
new file mode 100644
index 000000000000..f9cfcd33e058
--- /dev/null
+++ b/packages/google-cloud-retail/.kokoro/samples/python3.6/periodic-head.cfg
@@ -0,0 +1,11 @@
+# Format: //devtools/kokoro/config/proto/build.proto
+
+env_vars: {
+ key: "INSTALL_LIBRARY_FROM_SOURCE"
+ value: "True"
+}
+
+env_vars: {
+ key: "TRAMPOLINE_BUILD_FILE"
+ value: "github/python-pubsub/.kokoro/test-samples-against-head.sh"
+}
diff --git a/packages/google-cloud-retail/.kokoro/samples/python3.7/periodic-head.cfg b/packages/google-cloud-retail/.kokoro/samples/python3.7/periodic-head.cfg
new file mode 100644
index 000000000000..f9cfcd33e058
--- /dev/null
+++ b/packages/google-cloud-retail/.kokoro/samples/python3.7/periodic-head.cfg
@@ -0,0 +1,11 @@
+# Format: //devtools/kokoro/config/proto/build.proto
+
+env_vars: {
+ key: "INSTALL_LIBRARY_FROM_SOURCE"
+ value: "True"
+}
+
+env_vars: {
+ key: "TRAMPOLINE_BUILD_FILE"
+ value: "github/python-pubsub/.kokoro/test-samples-against-head.sh"
+}
diff --git a/packages/google-cloud-retail/.kokoro/samples/python3.8/periodic-head.cfg b/packages/google-cloud-retail/.kokoro/samples/python3.8/periodic-head.cfg
new file mode 100644
index 000000000000..f9cfcd33e058
--- /dev/null
+++ b/packages/google-cloud-retail/.kokoro/samples/python3.8/periodic-head.cfg
@@ -0,0 +1,11 @@
+# Format: //devtools/kokoro/config/proto/build.proto
+
+env_vars: {
+ key: "INSTALL_LIBRARY_FROM_SOURCE"
+ value: "True"
+}
+
+env_vars: {
+ key: "TRAMPOLINE_BUILD_FILE"
+ value: "github/python-pubsub/.kokoro/test-samples-against-head.sh"
+}
diff --git a/packages/google-cloud-retail/.kokoro/test-samples-against-head.sh b/packages/google-cloud-retail/.kokoro/test-samples-against-head.sh
new file mode 100755
index 000000000000..a1e41a6acf69
--- /dev/null
+++ b/packages/google-cloud-retail/.kokoro/test-samples-against-head.sh
@@ -0,0 +1,28 @@
+#!/bin/bash
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# A customized test runner for samples.
+#
+# For periodic builds, you can specify this file for testing against head.
+
+# `-e` enables the script to automatically fail when a command fails
+# `-o pipefail` sets the exit code to the rightmost comment to exit with a non-zero
+set -eo pipefail
+# Enables `**` to include files nested inside sub-folders
+shopt -s globstar
+
+cd github/python-retail
+
+exec .kokoro/test-samples-impl.sh
diff --git a/packages/google-cloud-retail/.kokoro/test-samples-impl.sh b/packages/google-cloud-retail/.kokoro/test-samples-impl.sh
new file mode 100755
index 000000000000..cf5de74c17a5
--- /dev/null
+++ b/packages/google-cloud-retail/.kokoro/test-samples-impl.sh
@@ -0,0 +1,102 @@
+#!/bin/bash
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# `-e` enables the script to automatically fail when a command fails
+# `-o pipefail` sets the exit code to the rightmost comment to exit with a non-zero
+set -eo pipefail
+# Enables `**` to include files nested inside sub-folders
+shopt -s globstar
+
+# Exit early if samples directory doesn't exist
+if [ ! -d "./samples" ]; then
+ echo "No tests run. `./samples` not found"
+ exit 0
+fi
+
+# Disable buffering, so that the logs stream through.
+export PYTHONUNBUFFERED=1
+
+# Debug: show build environment
+env | grep KOKORO
+
+# Install nox
+python3.6 -m pip install --upgrade --quiet nox
+
+# Use secrets acessor service account to get secrets
+if [[ -f "${KOKORO_GFILE_DIR}/secrets_viewer_service_account.json" ]]; then
+ gcloud auth activate-service-account \
+ --key-file="${KOKORO_GFILE_DIR}/secrets_viewer_service_account.json" \
+ --project="cloud-devrel-kokoro-resources"
+fi
+
+# This script will create 3 files:
+# - testing/test-env.sh
+# - testing/service-account.json
+# - testing/client-secrets.json
+./scripts/decrypt-secrets.sh
+
+source ./testing/test-env.sh
+export GOOGLE_APPLICATION_CREDENTIALS=$(pwd)/testing/service-account.json
+
+# For cloud-run session, we activate the service account for gcloud sdk.
+gcloud auth activate-service-account \
+ --key-file "${GOOGLE_APPLICATION_CREDENTIALS}"
+
+export GOOGLE_CLIENT_SECRETS=$(pwd)/testing/client-secrets.json
+
+echo -e "\n******************** TESTING PROJECTS ********************"
+
+# Switch to 'fail at end' to allow all tests to complete before exiting.
+set +e
+# Use RTN to return a non-zero value if the test fails.
+RTN=0
+ROOT=$(pwd)
+# Find all requirements.txt in the samples directory (may break on whitespace).
+for file in samples/**/requirements.txt; do
+ cd "$ROOT"
+ # Navigate to the project folder.
+ file=$(dirname "$file")
+ cd "$file"
+
+ echo "------------------------------------------------------------"
+ echo "- testing $file"
+ echo "------------------------------------------------------------"
+
+ # Use nox to execute the tests for the project.
+ python3.6 -m nox -s "$RUN_TESTS_SESSION"
+ EXIT=$?
+
+ # If this is a periodic build, send the test log to the FlakyBot.
+ # See https://github.com/googleapis/repo-automation-bots/tree/master/packages/flakybot.
+ if [[ $KOKORO_BUILD_ARTIFACTS_SUBDIR = *"periodic"* ]]; then
+ chmod +x $KOKORO_GFILE_DIR/linux_amd64/flakybot
+ $KOKORO_GFILE_DIR/linux_amd64/flakybot
+ fi
+
+ if [[ $EXIT -ne 0 ]]; then
+ RTN=1
+ echo -e "\n Testing failed: Nox returned a non-zero exit code. \n"
+ else
+ echo -e "\n Testing completed.\n"
+ fi
+
+done
+cd "$ROOT"
+
+# Workaround for Kokoro permissions issue: delete secrets
+rm testing/{test-env.sh,client-secrets.json,service-account.json}
+
+exit "$RTN"
diff --git a/packages/google-cloud-retail/.kokoro/test-samples.sh b/packages/google-cloud-retail/.kokoro/test-samples.sh
index a3a4e1030a4e..4013eb994a1b 100755
--- a/packages/google-cloud-retail/.kokoro/test-samples.sh
+++ b/packages/google-cloud-retail/.kokoro/test-samples.sh
@@ -13,6 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+# The default test runner for samples.
+#
+# For periodic builds, we rewinds the repo to the latest release, and
+# run test-samples-impl.sh.
# `-e` enables the script to automatically fail when a command fails
# `-o pipefail` sets the exit code to the rightmost comment to exit with a non-zero
@@ -24,87 +28,19 @@ cd github/python-retail
# Run periodic samples tests at latest release
if [[ $KOKORO_BUILD_ARTIFACTS_SUBDIR = *"periodic"* ]]; then
+ # preserving the test runner implementation.
+ cp .kokoro/test-samples-impl.sh "${TMPDIR}/test-samples-impl.sh"
+ echo "--- IMPORTANT IMPORTANT IMPORTANT ---"
+ echo "Now we rewind the repo back to the latest release..."
LATEST_RELEASE=$(git describe --abbrev=0 --tags)
git checkout $LATEST_RELEASE
-fi
-
-# Exit early if samples directory doesn't exist
-if [ ! -d "./samples" ]; then
- echo "No tests run. `./samples` not found"
- exit 0
-fi
-
-# Disable buffering, so that the logs stream through.
-export PYTHONUNBUFFERED=1
-
-# Debug: show build environment
-env | grep KOKORO
-
-# Install nox
-python3.6 -m pip install --upgrade --quiet nox
-
-# Use secrets acessor service account to get secrets
-if [[ -f "${KOKORO_GFILE_DIR}/secrets_viewer_service_account.json" ]]; then
- gcloud auth activate-service-account \
- --key-file="${KOKORO_GFILE_DIR}/secrets_viewer_service_account.json" \
- --project="cloud-devrel-kokoro-resources"
-fi
-
-# This script will create 3 files:
-# - testing/test-env.sh
-# - testing/service-account.json
-# - testing/client-secrets.json
-./scripts/decrypt-secrets.sh
-
-source ./testing/test-env.sh
-export GOOGLE_APPLICATION_CREDENTIALS=$(pwd)/testing/service-account.json
-
-# For cloud-run session, we activate the service account for gcloud sdk.
-gcloud auth activate-service-account \
- --key-file "${GOOGLE_APPLICATION_CREDENTIALS}"
-
-export GOOGLE_CLIENT_SECRETS=$(pwd)/testing/client-secrets.json
-
-echo -e "\n******************** TESTING PROJECTS ********************"
-
-# Switch to 'fail at end' to allow all tests to complete before exiting.
-set +e
-# Use RTN to return a non-zero value if the test fails.
-RTN=0
-ROOT=$(pwd)
-# Find all requirements.txt in the samples directory (may break on whitespace).
-for file in samples/**/requirements.txt; do
- cd "$ROOT"
- # Navigate to the project folder.
- file=$(dirname "$file")
- cd "$file"
-
- echo "------------------------------------------------------------"
- echo "- testing $file"
- echo "------------------------------------------------------------"
-
- # Use nox to execute the tests for the project.
- python3.6 -m nox -s "$RUN_TESTS_SESSION"
- EXIT=$?
-
- # If this is a periodic build, send the test log to the FlakyBot.
- # See https://github.com/googleapis/repo-automation-bots/tree/master/packages/flakybot.
- if [[ $KOKORO_BUILD_ARTIFACTS_SUBDIR = *"periodic"* ]]; then
- chmod +x $KOKORO_GFILE_DIR/linux_amd64/flakybot
- $KOKORO_GFILE_DIR/linux_amd64/flakybot
+ echo "The current head is: "
+ echo $(git rev-parse --verify HEAD)
+ echo "--- IMPORTANT IMPORTANT IMPORTANT ---"
+ # move back the test runner implementation if there's no file.
+ if [ ! -f .kokoro/test-samples-impl.sh ]; then
+ cp "${TMPDIR}/test-samples-impl.sh" .kokoro/test-samples-impl.sh
fi
+fi
- if [[ $EXIT -ne 0 ]]; then
- RTN=1
- echo -e "\n Testing failed: Nox returned a non-zero exit code. \n"
- else
- echo -e "\n Testing completed.\n"
- fi
-
-done
-cd "$ROOT"
-
-# Workaround for Kokoro permissions issue: delete secrets
-rm testing/{test-env.sh,client-secrets.json,service-account.json}
-
-exit "$RTN"
+exec .kokoro/test-samples-impl.sh
diff --git a/packages/google-cloud-retail/.pre-commit-config.yaml b/packages/google-cloud-retail/.pre-commit-config.yaml
index a9024b15d725..32302e4883a1 100644
--- a/packages/google-cloud-retail/.pre-commit-config.yaml
+++ b/packages/google-cloud-retail/.pre-commit-config.yaml
@@ -12,6 +12,6 @@ repos:
hooks:
- id: black
- repo: https://gitlab.com/pycqa/flake8
- rev: 3.8.4
+ rev: 3.9.0
hooks:
- id: flake8
diff --git a/packages/google-cloud-retail/.trampolinerc b/packages/google-cloud-retail/.trampolinerc
index c7d663ae9c57..383b6ec89fbc 100644
--- a/packages/google-cloud-retail/.trampolinerc
+++ b/packages/google-cloud-retail/.trampolinerc
@@ -18,7 +18,6 @@
required_envvars+=(
"STAGING_BUCKET"
"V2_STAGING_BUCKET"
- "NOX_SESSION"
)
# Add env vars which are passed down into the container here.
diff --git a/packages/google-cloud-retail/CONTRIBUTING.rst b/packages/google-cloud-retail/CONTRIBUTING.rst
index adefacaacf11..b764fa966cac 100644
--- a/packages/google-cloud-retail/CONTRIBUTING.rst
+++ b/packages/google-cloud-retail/CONTRIBUTING.rst
@@ -70,9 +70,14 @@ We use `nox `__ to instrument our tests.
- To test your changes, run unit tests with ``nox``::
$ nox -s unit-2.7
- $ nox -s unit-3.7
+ $ nox -s unit-3.8
$ ...
+- Args to pytest can be passed through the nox command separated by a `--`. For
+ example, to run a single test::
+
+ $ nox -s unit-3.8 -- -k
+
.. note::
The unit tests and system tests are described in the
@@ -93,8 +98,12 @@ On Debian/Ubuntu::
************
Coding Style
************
+- We use the automatic code formatter ``black``. You can run it using
+ the nox session ``blacken``. This will eliminate many lint errors. Run via::
+
+ $ nox -s blacken
-- PEP8 compliance, with exceptions defined in the linter configuration.
+- PEP8 compliance is required, with exceptions defined in the linter configuration.
If you have ``nox`` installed, you can test that you have not introduced
any non-compliant code via::
@@ -133,13 +142,18 @@ Running System Tests
- To run system tests, you can execute::
- $ nox -s system-3.7
+ # Run all system tests
+ $ nox -s system-3.8
$ nox -s system-2.7
+ # Run a single system test
+ $ nox -s system-3.8 -- -k
+
+
.. note::
System tests are only configured to run under Python 2.7 and
- Python 3.7. For expediency, we do not run them in older versions
+ Python 3.8. For expediency, we do not run them in older versions
of Python 3.
This alone will not run the tests. You'll need to change some local
diff --git a/packages/google-cloud-retail/MANIFEST.in b/packages/google-cloud-retail/MANIFEST.in
index e9e29d12033d..e783f4c6209b 100644
--- a/packages/google-cloud-retail/MANIFEST.in
+++ b/packages/google-cloud-retail/MANIFEST.in
@@ -16,10 +16,10 @@
# Generated by synthtool. DO NOT EDIT!
include README.rst LICENSE
-recursive-include google *.json *.proto
+recursive-include google *.json *.proto py.typed
recursive-include tests *
global-exclude *.py[co]
global-exclude __pycache__
# Exclude scripts for samples readmegen
-prune scripts/readme-gen
\ No newline at end of file
+prune scripts/readme-gen
diff --git a/packages/google-cloud-retail/google/cloud/retail_v2/__init__.py b/packages/google-cloud-retail/google/cloud/retail_v2/__init__.py
index 0ba7c13ed6c7..f20ae04d3471 100644
--- a/packages/google-cloud-retail/google/cloud/retail_v2/__init__.py
+++ b/packages/google-cloud-retail/google/cloud/retail_v2/__init__.py
@@ -82,7 +82,6 @@
"ListCatalogsResponse",
"PredictRequest",
"PredictResponse",
- "PredictionServiceClient",
"PriceInfo",
"Product",
"ProductDetail",
@@ -103,7 +102,8 @@
"UserEventImportSummary",
"UserEventInlineSource",
"UserEventInputConfig",
+ "UserEventServiceClient",
"UserInfo",
"WriteUserEventRequest",
- "UserEventServiceClient",
+ "PredictionServiceClient",
)
diff --git a/packages/google-cloud-retail/google/cloud/retail_v2/services/catalog_service/async_client.py b/packages/google-cloud-retail/google/cloud/retail_v2/services/catalog_service/async_client.py
index 95cc56663592..b50645fcb695 100644
--- a/packages/google-cloud-retail/google/cloud/retail_v2/services/catalog_service/async_client.py
+++ b/packages/google-cloud-retail/google/cloud/retail_v2/services/catalog_service/async_client.py
@@ -79,8 +79,36 @@ class CatalogServiceAsyncClient:
CatalogServiceClient.parse_common_location_path
)
- from_service_account_info = CatalogServiceClient.from_service_account_info
- from_service_account_file = CatalogServiceClient.from_service_account_file
+ @classmethod
+ def from_service_account_info(cls, info: dict, *args, **kwargs):
+ """Creates an instance of this client using the provided credentials info.
+
+ Args:
+ info (dict): The service account private key info.
+ args: Additional arguments to pass to the constructor.
+ kwargs: Additional arguments to pass to the constructor.
+
+ Returns:
+ CatalogServiceAsyncClient: The constructed client.
+ """
+ return CatalogServiceClient.from_service_account_info.__func__(CatalogServiceAsyncClient, info, *args, **kwargs) # type: ignore
+
+ @classmethod
+ def from_service_account_file(cls, filename: str, *args, **kwargs):
+ """Creates an instance of this client using the provided credentials
+ file.
+
+ Args:
+ filename (str): The path to the service account private key json
+ file.
+ args: Additional arguments to pass to the constructor.
+ kwargs: Additional arguments to pass to the constructor.
+
+ Returns:
+ CatalogServiceAsyncClient: The constructed client.
+ """
+ return CatalogServiceClient.from_service_account_file.__func__(CatalogServiceAsyncClient, filename, *args, **kwargs) # type: ignore
+
from_service_account_json = from_service_account_file
@property
diff --git a/packages/google-cloud-retail/google/cloud/retail_v2/services/catalog_service/client.py b/packages/google-cloud-retail/google/cloud/retail_v2/services/catalog_service/client.py
index e951309994e9..b7b40b09e0ca 100644
--- a/packages/google-cloud-retail/google/cloud/retail_v2/services/catalog_service/client.py
+++ b/packages/google-cloud-retail/google/cloud/retail_v2/services/catalog_service/client.py
@@ -288,21 +288,17 @@ def __init__(
util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))
)
- ssl_credentials = None
+ client_cert_source_func = None
is_mtls = False
if use_client_cert:
if client_options.client_cert_source:
- import grpc # type: ignore
-
- cert, key = client_options.client_cert_source()
- ssl_credentials = grpc.ssl_channel_credentials(
- certificate_chain=cert, private_key=key
- )
is_mtls = True
+ client_cert_source_func = client_options.client_cert_source
else:
- creds = SslCredentials()
- is_mtls = creds.is_mtls
- ssl_credentials = creds.ssl_credentials if is_mtls else None
+ is_mtls = mtls.has_default_client_cert_source()
+ client_cert_source_func = (
+ mtls.default_client_cert_source() if is_mtls else None
+ )
# Figure out which api endpoint to use.
if client_options.api_endpoint is not None:
@@ -345,7 +341,7 @@ def __init__(
credentials_file=client_options.credentials_file,
host=api_endpoint,
scopes=client_options.scopes,
- ssl_channel_credentials=ssl_credentials,
+ client_cert_source_for_mtls=client_cert_source_func,
quota_project_id=client_options.quota_project_id,
client_info=client_info,
)
diff --git a/packages/google-cloud-retail/google/cloud/retail_v2/services/catalog_service/pagers.py b/packages/google-cloud-retail/google/cloud/retail_v2/services/catalog_service/pagers.py
index e18a9515c51b..5df5143d1fde 100644
--- a/packages/google-cloud-retail/google/cloud/retail_v2/services/catalog_service/pagers.py
+++ b/packages/google-cloud-retail/google/cloud/retail_v2/services/catalog_service/pagers.py
@@ -15,7 +15,16 @@
# limitations under the License.
#
-from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple
+from typing import (
+ Any,
+ AsyncIterable,
+ Awaitable,
+ Callable,
+ Iterable,
+ Sequence,
+ Tuple,
+ Optional,
+)
from google.cloud.retail_v2.types import catalog
from google.cloud.retail_v2.types import catalog_service
diff --git a/packages/google-cloud-retail/google/cloud/retail_v2/services/catalog_service/transports/base.py b/packages/google-cloud-retail/google/cloud/retail_v2/services/catalog_service/transports/base.py
index 6c5906c06368..199f863c85fc 100644
--- a/packages/google-cloud-retail/google/cloud/retail_v2/services/catalog_service/transports/base.py
+++ b/packages/google-cloud-retail/google/cloud/retail_v2/services/catalog_service/transports/base.py
@@ -68,10 +68,10 @@ def __init__(
scope (Optional[Sequence[str]]): A list of scopes.
quota_project_id (Optional[str]): An optional project to use for billing
and quota.
- client_info (google.api_core.gapic_v1.client_info.ClientInfo):
- The client info used to send a user-agent string along with
- API requests. If ``None``, then default info will be used.
- Generally, you only need to set this if you're developing
+ client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+ The client info used to send a user-agent string along with
+ API requests. If ``None``, then default info will be used.
+ Generally, you only need to set this if you're developing
your own client library.
"""
# Save the hostname. Default to port 443 (HTTPS) if none is specified.
@@ -79,6 +79,9 @@ def __init__(
host += ":443"
self._host = host
+ # Save the scopes.
+ self._scopes = scopes or self.AUTH_SCOPES
+
# If no credentials are provided, then determine the appropriate
# defaults.
if credentials and credentials_file:
@@ -88,20 +91,17 @@ def __init__(
if credentials_file is not None:
credentials, _ = auth.load_credentials_from_file(
- credentials_file, scopes=scopes, quota_project_id=quota_project_id
+ credentials_file, scopes=self._scopes, quota_project_id=quota_project_id
)
elif credentials is None:
credentials, _ = auth.default(
- scopes=scopes, quota_project_id=quota_project_id
+ scopes=self._scopes, quota_project_id=quota_project_id
)
# Save the credentials.
self._credentials = credentials
- # Lifted into its own function so it can be stubbed out during tests.
- self._prep_wrapped_messages(client_info)
-
def _prep_wrapped_messages(self, client_info):
# Precompute the wrapped methods.
self._wrapped_methods = {
diff --git a/packages/google-cloud-retail/google/cloud/retail_v2/services/catalog_service/transports/grpc.py b/packages/google-cloud-retail/google/cloud/retail_v2/services/catalog_service/transports/grpc.py
index abb6a850a67e..2a3e1e632995 100644
--- a/packages/google-cloud-retail/google/cloud/retail_v2/services/catalog_service/transports/grpc.py
+++ b/packages/google-cloud-retail/google/cloud/retail_v2/services/catalog_service/transports/grpc.py
@@ -58,6 +58,7 @@ def __init__(
api_mtls_endpoint: str = None,
client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
ssl_channel_credentials: grpc.ChannelCredentials = None,
+ client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
quota_project_id: Optional[str] = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
) -> None:
@@ -88,6 +89,10 @@ def __init__(
``api_mtls_endpoint`` is None.
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
for grpc channel. It is ignored if ``channel`` is provided.
+ client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
+ A callback to provide client certificate bytes and private key bytes,
+ both in PEM format. It is used to configure mutual TLS channel. It is
+ ignored if ``channel`` or ``ssl_channel_credentials`` is provided.
quota_project_id (Optional[str]): An optional project to use for billing
and quota.
client_info (google.api_core.gapic_v1.client_info.ClientInfo):
@@ -102,72 +107,60 @@ def __init__(
google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
and ``credentials_file`` are passed.
"""
+ self._grpc_channel = None
self._ssl_channel_credentials = ssl_channel_credentials
+ self._stubs: Dict[str, Callable] = {}
+
+ if api_mtls_endpoint:
+ warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
+ if client_cert_source:
+ warnings.warn("client_cert_source is deprecated", DeprecationWarning)
if channel:
- # Sanity check: Ensure that channel and credentials are not both
- # provided.
+ # Ignore credentials if a channel was passed.
credentials = False
-
# If a channel was explicitly provided, set it.
self._grpc_channel = channel
self._ssl_channel_credentials = None
- elif api_mtls_endpoint:
- warnings.warn(
- "api_mtls_endpoint and client_cert_source are deprecated",
- DeprecationWarning,
- )
- host = (
- api_mtls_endpoint
- if ":" in api_mtls_endpoint
- else api_mtls_endpoint + ":443"
- )
+ else:
+ if api_mtls_endpoint:
+ host = api_mtls_endpoint
+
+ # Create SSL credentials with client_cert_source or application
+ # default SSL credentials.
+ if client_cert_source:
+ cert, key = client_cert_source()
+ self._ssl_channel_credentials = grpc.ssl_channel_credentials(
+ certificate_chain=cert, private_key=key
+ )
+ else:
+ self._ssl_channel_credentials = SslCredentials().ssl_credentials
- if credentials is None:
- credentials, _ = auth.default(
- scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
- )
-
- # Create SSL credentials with client_cert_source or application
- # default SSL credentials.
- if client_cert_source:
- cert, key = client_cert_source()
- ssl_credentials = grpc.ssl_channel_credentials(
- certificate_chain=cert, private_key=key
- )
else:
- ssl_credentials = SslCredentials().ssl_credentials
-
- # create a new channel. The provided one is ignored.
- self._grpc_channel = type(self).create_channel(
- host,
- credentials=credentials,
- credentials_file=credentials_file,
- ssl_credentials=ssl_credentials,
- scopes=scopes or self.AUTH_SCOPES,
- quota_project_id=quota_project_id,
- options=[
- ("grpc.max_send_message_length", -1),
- ("grpc.max_receive_message_length", -1),
- ],
- )
- self._ssl_channel_credentials = ssl_credentials
- else:
- host = host if ":" in host else host + ":443"
+ if client_cert_source_for_mtls and not ssl_channel_credentials:
+ cert, key = client_cert_source_for_mtls()
+ self._ssl_channel_credentials = grpc.ssl_channel_credentials(
+ certificate_chain=cert, private_key=key
+ )
- if credentials is None:
- credentials, _ = auth.default(
- scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
- )
+ # The base transport sets the host, credentials and scopes
+ super().__init__(
+ host=host,
+ credentials=credentials,
+ credentials_file=credentials_file,
+ scopes=scopes,
+ quota_project_id=quota_project_id,
+ client_info=client_info,
+ )
- # create a new channel. The provided one is ignored.
+ if not self._grpc_channel:
self._grpc_channel = type(self).create_channel(
- host,
- credentials=credentials,
+ self._host,
+ credentials=self._credentials,
credentials_file=credentials_file,
- ssl_credentials=ssl_channel_credentials,
- scopes=scopes or self.AUTH_SCOPES,
+ scopes=self._scopes,
+ ssl_credentials=self._ssl_channel_credentials,
quota_project_id=quota_project_id,
options=[
("grpc.max_send_message_length", -1),
@@ -175,17 +168,8 @@ def __init__(
],
)
- self._stubs = {} # type: Dict[str, Callable]
-
- # Run the base constructor.
- super().__init__(
- host=host,
- credentials=credentials,
- credentials_file=credentials_file,
- scopes=scopes or self.AUTH_SCOPES,
- quota_project_id=quota_project_id,
- client_info=client_info,
- )
+ # Wrap messages. This must be done after self._grpc_channel exists
+ self._prep_wrapped_messages(client_info)
@classmethod
def create_channel(
@@ -199,7 +183,7 @@ def create_channel(
) -> grpc.Channel:
"""Create and return a gRPC channel object.
Args:
- address (Optional[str]): The host for the channel to use.
+ host (Optional[str]): The host for the channel to use.
credentials (Optional[~.Credentials]): The
authorization credentials to attach to requests. These
credentials identify this application to the service. If
diff --git a/packages/google-cloud-retail/google/cloud/retail_v2/services/catalog_service/transports/grpc_asyncio.py b/packages/google-cloud-retail/google/cloud/retail_v2/services/catalog_service/transports/grpc_asyncio.py
index 6195c5bd38c4..efe5aea8674d 100644
--- a/packages/google-cloud-retail/google/cloud/retail_v2/services/catalog_service/transports/grpc_asyncio.py
+++ b/packages/google-cloud-retail/google/cloud/retail_v2/services/catalog_service/transports/grpc_asyncio.py
@@ -62,7 +62,7 @@ def create_channel(
) -> aio.Channel:
"""Create and return a gRPC AsyncIO channel object.
Args:
- address (Optional[str]): The host for the channel to use.
+ host (Optional[str]): The host for the channel to use.
credentials (Optional[~.Credentials]): The
authorization credentials to attach to requests. These
credentials identify this application to the service. If
@@ -102,6 +102,7 @@ def __init__(
api_mtls_endpoint: str = None,
client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
ssl_channel_credentials: grpc.ChannelCredentials = None,
+ client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
quota_project_id=None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
) -> None:
@@ -133,12 +134,16 @@ def __init__(
``api_mtls_endpoint`` is None.
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
for grpc channel. It is ignored if ``channel`` is provided.
+ client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
+ A callback to provide client certificate bytes and private key bytes,
+ both in PEM format. It is used to configure mutual TLS channel. It is
+ ignored if ``channel`` or ``ssl_channel_credentials`` is provided.
quota_project_id (Optional[str]): An optional project to use for billing
and quota.
- client_info (google.api_core.gapic_v1.client_info.ClientInfo):
- The client info used to send a user-agent string along with
- API requests. If ``None``, then default info will be used.
- Generally, you only need to set this if you're developing
+ client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+ The client info used to send a user-agent string along with
+ API requests. If ``None``, then default info will be used.
+ Generally, you only need to set this if you're developing
your own client library.
Raises:
@@ -147,72 +152,60 @@ def __init__(
google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
and ``credentials_file`` are passed.
"""
+ self._grpc_channel = None
self._ssl_channel_credentials = ssl_channel_credentials
+ self._stubs: Dict[str, Callable] = {}
+
+ if api_mtls_endpoint:
+ warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
+ if client_cert_source:
+ warnings.warn("client_cert_source is deprecated", DeprecationWarning)
if channel:
- # Sanity check: Ensure that channel and credentials are not both
- # provided.
+ # Ignore credentials if a channel was passed.
credentials = False
-
# If a channel was explicitly provided, set it.
self._grpc_channel = channel
self._ssl_channel_credentials = None
- elif api_mtls_endpoint:
- warnings.warn(
- "api_mtls_endpoint and client_cert_source are deprecated",
- DeprecationWarning,
- )
- host = (
- api_mtls_endpoint
- if ":" in api_mtls_endpoint
- else api_mtls_endpoint + ":443"
- )
+ else:
+ if api_mtls_endpoint:
+ host = api_mtls_endpoint
+
+ # Create SSL credentials with client_cert_source or application
+ # default SSL credentials.
+ if client_cert_source:
+ cert, key = client_cert_source()
+ self._ssl_channel_credentials = grpc.ssl_channel_credentials(
+ certificate_chain=cert, private_key=key
+ )
+ else:
+ self._ssl_channel_credentials = SslCredentials().ssl_credentials
- if credentials is None:
- credentials, _ = auth.default(
- scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
- )
-
- # Create SSL credentials with client_cert_source or application
- # default SSL credentials.
- if client_cert_source:
- cert, key = client_cert_source()
- ssl_credentials = grpc.ssl_channel_credentials(
- certificate_chain=cert, private_key=key
- )
else:
- ssl_credentials = SslCredentials().ssl_credentials
-
- # create a new channel. The provided one is ignored.
- self._grpc_channel = type(self).create_channel(
- host,
- credentials=credentials,
- credentials_file=credentials_file,
- ssl_credentials=ssl_credentials,
- scopes=scopes or self.AUTH_SCOPES,
- quota_project_id=quota_project_id,
- options=[
- ("grpc.max_send_message_length", -1),
- ("grpc.max_receive_message_length", -1),
- ],
- )
- self._ssl_channel_credentials = ssl_credentials
- else:
- host = host if ":" in host else host + ":443"
+ if client_cert_source_for_mtls and not ssl_channel_credentials:
+ cert, key = client_cert_source_for_mtls()
+ self._ssl_channel_credentials = grpc.ssl_channel_credentials(
+ certificate_chain=cert, private_key=key
+ )
- if credentials is None:
- credentials, _ = auth.default(
- scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
- )
+ # The base transport sets the host, credentials and scopes
+ super().__init__(
+ host=host,
+ credentials=credentials,
+ credentials_file=credentials_file,
+ scopes=scopes,
+ quota_project_id=quota_project_id,
+ client_info=client_info,
+ )
- # create a new channel. The provided one is ignored.
+ if not self._grpc_channel:
self._grpc_channel = type(self).create_channel(
- host,
- credentials=credentials,
+ self._host,
+ credentials=self._credentials,
credentials_file=credentials_file,
- ssl_credentials=ssl_channel_credentials,
- scopes=scopes or self.AUTH_SCOPES,
+ scopes=self._scopes,
+ ssl_credentials=self._ssl_channel_credentials,
quota_project_id=quota_project_id,
options=[
("grpc.max_send_message_length", -1),
@@ -220,17 +213,8 @@ def __init__(
],
)
- # Run the base constructor.
- super().__init__(
- host=host,
- credentials=credentials,
- credentials_file=credentials_file,
- scopes=scopes or self.AUTH_SCOPES,
- quota_project_id=quota_project_id,
- client_info=client_info,
- )
-
- self._stubs = {}
+ # Wrap messages. This must be done after self._grpc_channel exists
+ self._prep_wrapped_messages(client_info)
@property
def grpc_channel(self) -> aio.Channel:
diff --git a/packages/google-cloud-retail/google/cloud/retail_v2/services/prediction_service/async_client.py b/packages/google-cloud-retail/google/cloud/retail_v2/services/prediction_service/async_client.py
index 8276b225a596..02a5f64cc0b5 100644
--- a/packages/google-cloud-retail/google/cloud/retail_v2/services/prediction_service/async_client.py
+++ b/packages/google-cloud-retail/google/cloud/retail_v2/services/prediction_service/async_client.py
@@ -75,8 +75,36 @@ class PredictionServiceAsyncClient:
PredictionServiceClient.parse_common_location_path
)
- from_service_account_info = PredictionServiceClient.from_service_account_info
- from_service_account_file = PredictionServiceClient.from_service_account_file
+ @classmethod
+ def from_service_account_info(cls, info: dict, *args, **kwargs):
+ """Creates an instance of this client using the provided credentials info.
+
+ Args:
+ info (dict): The service account private key info.
+ args: Additional arguments to pass to the constructor.
+ kwargs: Additional arguments to pass to the constructor.
+
+ Returns:
+ PredictionServiceAsyncClient: The constructed client.
+ """
+ return PredictionServiceClient.from_service_account_info.__func__(PredictionServiceAsyncClient, info, *args, **kwargs) # type: ignore
+
+ @classmethod
+ def from_service_account_file(cls, filename: str, *args, **kwargs):
+ """Creates an instance of this client using the provided credentials
+ file.
+
+ Args:
+ filename (str): The path to the service account private key json
+ file.
+ args: Additional arguments to pass to the constructor.
+ kwargs: Additional arguments to pass to the constructor.
+
+ Returns:
+ PredictionServiceAsyncClient: The constructed client.
+ """
+ return PredictionServiceClient.from_service_account_file.__func__(PredictionServiceAsyncClient, filename, *args, **kwargs) # type: ignore
+
from_service_account_json = from_service_account_file
@property
diff --git a/packages/google-cloud-retail/google/cloud/retail_v2/services/prediction_service/client.py b/packages/google-cloud-retail/google/cloud/retail_v2/services/prediction_service/client.py
index 5727be6f2cc3..86105d26b8c6 100644
--- a/packages/google-cloud-retail/google/cloud/retail_v2/services/prediction_service/client.py
+++ b/packages/google-cloud-retail/google/cloud/retail_v2/services/prediction_service/client.py
@@ -292,21 +292,17 @@ def __init__(
util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))
)
- ssl_credentials = None
+ client_cert_source_func = None
is_mtls = False
if use_client_cert:
if client_options.client_cert_source:
- import grpc # type: ignore
-
- cert, key = client_options.client_cert_source()
- ssl_credentials = grpc.ssl_channel_credentials(
- certificate_chain=cert, private_key=key
- )
is_mtls = True
+ client_cert_source_func = client_options.client_cert_source
else:
- creds = SslCredentials()
- is_mtls = creds.is_mtls
- ssl_credentials = creds.ssl_credentials if is_mtls else None
+ is_mtls = mtls.has_default_client_cert_source()
+ client_cert_source_func = (
+ mtls.default_client_cert_source() if is_mtls else None
+ )
# Figure out which api endpoint to use.
if client_options.api_endpoint is not None:
@@ -349,7 +345,7 @@ def __init__(
credentials_file=client_options.credentials_file,
host=api_endpoint,
scopes=client_options.scopes,
- ssl_channel_credentials=ssl_credentials,
+ client_cert_source_for_mtls=client_cert_source_func,
quota_project_id=client_options.quota_project_id,
client_info=client_info,
)
diff --git a/packages/google-cloud-retail/google/cloud/retail_v2/services/prediction_service/transports/base.py b/packages/google-cloud-retail/google/cloud/retail_v2/services/prediction_service/transports/base.py
index ab258d9a170d..d76509ebf582 100644
--- a/packages/google-cloud-retail/google/cloud/retail_v2/services/prediction_service/transports/base.py
+++ b/packages/google-cloud-retail/google/cloud/retail_v2/services/prediction_service/transports/base.py
@@ -67,10 +67,10 @@ def __init__(
scope (Optional[Sequence[str]]): A list of scopes.
quota_project_id (Optional[str]): An optional project to use for billing
and quota.
- client_info (google.api_core.gapic_v1.client_info.ClientInfo):
- The client info used to send a user-agent string along with
- API requests. If ``None``, then default info will be used.
- Generally, you only need to set this if you're developing
+ client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+ The client info used to send a user-agent string along with
+ API requests. If ``None``, then default info will be used.
+ Generally, you only need to set this if you're developing
your own client library.
"""
# Save the hostname. Default to port 443 (HTTPS) if none is specified.
@@ -78,6 +78,9 @@ def __init__(
host += ":443"
self._host = host
+ # Save the scopes.
+ self._scopes = scopes or self.AUTH_SCOPES
+
# If no credentials are provided, then determine the appropriate
# defaults.
if credentials and credentials_file:
@@ -87,20 +90,17 @@ def __init__(
if credentials_file is not None:
credentials, _ = auth.load_credentials_from_file(
- credentials_file, scopes=scopes, quota_project_id=quota_project_id
+ credentials_file, scopes=self._scopes, quota_project_id=quota_project_id
)
elif credentials is None:
credentials, _ = auth.default(
- scopes=scopes, quota_project_id=quota_project_id
+ scopes=self._scopes, quota_project_id=quota_project_id
)
# Save the credentials.
self._credentials = credentials
- # Lifted into its own function so it can be stubbed out during tests.
- self._prep_wrapped_messages(client_info)
-
def _prep_wrapped_messages(self, client_info):
# Precompute the wrapped methods.
self._wrapped_methods = {
diff --git a/packages/google-cloud-retail/google/cloud/retail_v2/services/prediction_service/transports/grpc.py b/packages/google-cloud-retail/google/cloud/retail_v2/services/prediction_service/transports/grpc.py
index 9448d7d83232..cc8b8a45f3c4 100644
--- a/packages/google-cloud-retail/google/cloud/retail_v2/services/prediction_service/transports/grpc.py
+++ b/packages/google-cloud-retail/google/cloud/retail_v2/services/prediction_service/transports/grpc.py
@@ -57,6 +57,7 @@ def __init__(
api_mtls_endpoint: str = None,
client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
ssl_channel_credentials: grpc.ChannelCredentials = None,
+ client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
quota_project_id: Optional[str] = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
) -> None:
@@ -87,6 +88,10 @@ def __init__(
``api_mtls_endpoint`` is None.
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
for grpc channel. It is ignored if ``channel`` is provided.
+ client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
+ A callback to provide client certificate bytes and private key bytes,
+ both in PEM format. It is used to configure mutual TLS channel. It is
+ ignored if ``channel`` or ``ssl_channel_credentials`` is provided.
quota_project_id (Optional[str]): An optional project to use for billing
and quota.
client_info (google.api_core.gapic_v1.client_info.ClientInfo):
@@ -101,72 +106,60 @@ def __init__(
google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
and ``credentials_file`` are passed.
"""
+ self._grpc_channel = None
self._ssl_channel_credentials = ssl_channel_credentials
+ self._stubs: Dict[str, Callable] = {}
+
+ if api_mtls_endpoint:
+ warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
+ if client_cert_source:
+ warnings.warn("client_cert_source is deprecated", DeprecationWarning)
if channel:
- # Sanity check: Ensure that channel and credentials are not both
- # provided.
+ # Ignore credentials if a channel was passed.
credentials = False
-
# If a channel was explicitly provided, set it.
self._grpc_channel = channel
self._ssl_channel_credentials = None
- elif api_mtls_endpoint:
- warnings.warn(
- "api_mtls_endpoint and client_cert_source are deprecated",
- DeprecationWarning,
- )
-
- host = (
- api_mtls_endpoint
- if ":" in api_mtls_endpoint
- else api_mtls_endpoint + ":443"
- )
- if credentials is None:
- credentials, _ = auth.default(
- scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
- )
+ else:
+ if api_mtls_endpoint:
+ host = api_mtls_endpoint
+
+ # Create SSL credentials with client_cert_source or application
+ # default SSL credentials.
+ if client_cert_source:
+ cert, key = client_cert_source()
+ self._ssl_channel_credentials = grpc.ssl_channel_credentials(
+ certificate_chain=cert, private_key=key
+ )
+ else:
+ self._ssl_channel_credentials = SslCredentials().ssl_credentials
- # Create SSL credentials with client_cert_source or application
- # default SSL credentials.
- if client_cert_source:
- cert, key = client_cert_source()
- ssl_credentials = grpc.ssl_channel_credentials(
- certificate_chain=cert, private_key=key
- )
else:
- ssl_credentials = SslCredentials().ssl_credentials
-
- # create a new channel. The provided one is ignored.
- self._grpc_channel = type(self).create_channel(
- host,
- credentials=credentials,
- credentials_file=credentials_file,
- ssl_credentials=ssl_credentials,
- scopes=scopes or self.AUTH_SCOPES,
- quota_project_id=quota_project_id,
- options=[
- ("grpc.max_send_message_length", -1),
- ("grpc.max_receive_message_length", -1),
- ],
- )
- self._ssl_channel_credentials = ssl_credentials
- else:
- host = host if ":" in host else host + ":443"
+ if client_cert_source_for_mtls and not ssl_channel_credentials:
+ cert, key = client_cert_source_for_mtls()
+ self._ssl_channel_credentials = grpc.ssl_channel_credentials(
+ certificate_chain=cert, private_key=key
+ )
- if credentials is None:
- credentials, _ = auth.default(
- scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
- )
+ # The base transport sets the host, credentials and scopes
+ super().__init__(
+ host=host,
+ credentials=credentials,
+ credentials_file=credentials_file,
+ scopes=scopes,
+ quota_project_id=quota_project_id,
+ client_info=client_info,
+ )
- # create a new channel. The provided one is ignored.
+ if not self._grpc_channel:
self._grpc_channel = type(self).create_channel(
- host,
- credentials=credentials,
+ self._host,
+ credentials=self._credentials,
credentials_file=credentials_file,
- ssl_credentials=ssl_channel_credentials,
- scopes=scopes or self.AUTH_SCOPES,
+ scopes=self._scopes,
+ ssl_credentials=self._ssl_channel_credentials,
quota_project_id=quota_project_id,
options=[
("grpc.max_send_message_length", -1),
@@ -174,17 +167,8 @@ def __init__(
],
)
- self._stubs = {} # type: Dict[str, Callable]
-
- # Run the base constructor.
- super().__init__(
- host=host,
- credentials=credentials,
- credentials_file=credentials_file,
- scopes=scopes or self.AUTH_SCOPES,
- quota_project_id=quota_project_id,
- client_info=client_info,
- )
+ # Wrap messages. This must be done after self._grpc_channel exists
+ self._prep_wrapped_messages(client_info)
@classmethod
def create_channel(
@@ -198,7 +182,7 @@ def create_channel(
) -> grpc.Channel:
"""Create and return a gRPC channel object.
Args:
- address (Optional[str]): The host for the channel to use.
+ host (Optional[str]): The host for the channel to use.
credentials (Optional[~.Credentials]): The
authorization credentials to attach to requests. These
credentials identify this application to the service. If
diff --git a/packages/google-cloud-retail/google/cloud/retail_v2/services/prediction_service/transports/grpc_asyncio.py b/packages/google-cloud-retail/google/cloud/retail_v2/services/prediction_service/transports/grpc_asyncio.py
index 0d8525fd2c2a..39eec67f5b75 100644
--- a/packages/google-cloud-retail/google/cloud/retail_v2/services/prediction_service/transports/grpc_asyncio.py
+++ b/packages/google-cloud-retail/google/cloud/retail_v2/services/prediction_service/transports/grpc_asyncio.py
@@ -61,7 +61,7 @@ def create_channel(
) -> aio.Channel:
"""Create and return a gRPC AsyncIO channel object.
Args:
- address (Optional[str]): The host for the channel to use.
+ host (Optional[str]): The host for the channel to use.
credentials (Optional[~.Credentials]): The
authorization credentials to attach to requests. These
credentials identify this application to the service. If
@@ -101,6 +101,7 @@ def __init__(
api_mtls_endpoint: str = None,
client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
ssl_channel_credentials: grpc.ChannelCredentials = None,
+ client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
quota_project_id=None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
) -> None:
@@ -132,12 +133,16 @@ def __init__(
``api_mtls_endpoint`` is None.
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
for grpc channel. It is ignored if ``channel`` is provided.
+ client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
+ A callback to provide client certificate bytes and private key bytes,
+ both in PEM format. It is used to configure mutual TLS channel. It is
+ ignored if ``channel`` or ``ssl_channel_credentials`` is provided.
quota_project_id (Optional[str]): An optional project to use for billing
and quota.
- client_info (google.api_core.gapic_v1.client_info.ClientInfo):
- The client info used to send a user-agent string along with
- API requests. If ``None``, then default info will be used.
- Generally, you only need to set this if you're developing
+ client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+ The client info used to send a user-agent string along with
+ API requests. If ``None``, then default info will be used.
+ Generally, you only need to set this if you're developing
your own client library.
Raises:
@@ -146,72 +151,60 @@ def __init__(
google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
and ``credentials_file`` are passed.
"""
+ self._grpc_channel = None
self._ssl_channel_credentials = ssl_channel_credentials
+ self._stubs: Dict[str, Callable] = {}
+
+ if api_mtls_endpoint:
+ warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
+ if client_cert_source:
+ warnings.warn("client_cert_source is deprecated", DeprecationWarning)
if channel:
- # Sanity check: Ensure that channel and credentials are not both
- # provided.
+ # Ignore credentials if a channel was passed.
credentials = False
-
# If a channel was explicitly provided, set it.
self._grpc_channel = channel
self._ssl_channel_credentials = None
- elif api_mtls_endpoint:
- warnings.warn(
- "api_mtls_endpoint and client_cert_source are deprecated",
- DeprecationWarning,
- )
-
- host = (
- api_mtls_endpoint
- if ":" in api_mtls_endpoint
- else api_mtls_endpoint + ":443"
- )
- if credentials is None:
- credentials, _ = auth.default(
- scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
- )
+ else:
+ if api_mtls_endpoint:
+ host = api_mtls_endpoint
+
+ # Create SSL credentials with client_cert_source or application
+ # default SSL credentials.
+ if client_cert_source:
+ cert, key = client_cert_source()
+ self._ssl_channel_credentials = grpc.ssl_channel_credentials(
+ certificate_chain=cert, private_key=key
+ )
+ else:
+ self._ssl_channel_credentials = SslCredentials().ssl_credentials
- # Create SSL credentials with client_cert_source or application
- # default SSL credentials.
- if client_cert_source:
- cert, key = client_cert_source()
- ssl_credentials = grpc.ssl_channel_credentials(
- certificate_chain=cert, private_key=key
- )
else:
- ssl_credentials = SslCredentials().ssl_credentials
-
- # create a new channel. The provided one is ignored.
- self._grpc_channel = type(self).create_channel(
- host,
- credentials=credentials,
- credentials_file=credentials_file,
- ssl_credentials=ssl_credentials,
- scopes=scopes or self.AUTH_SCOPES,
- quota_project_id=quota_project_id,
- options=[
- ("grpc.max_send_message_length", -1),
- ("grpc.max_receive_message_length", -1),
- ],
- )
- self._ssl_channel_credentials = ssl_credentials
- else:
- host = host if ":" in host else host + ":443"
+ if client_cert_source_for_mtls and not ssl_channel_credentials:
+ cert, key = client_cert_source_for_mtls()
+ self._ssl_channel_credentials = grpc.ssl_channel_credentials(
+ certificate_chain=cert, private_key=key
+ )
- if credentials is None:
- credentials, _ = auth.default(
- scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
- )
+ # The base transport sets the host, credentials and scopes
+ super().__init__(
+ host=host,
+ credentials=credentials,
+ credentials_file=credentials_file,
+ scopes=scopes,
+ quota_project_id=quota_project_id,
+ client_info=client_info,
+ )
- # create a new channel. The provided one is ignored.
+ if not self._grpc_channel:
self._grpc_channel = type(self).create_channel(
- host,
- credentials=credentials,
+ self._host,
+ credentials=self._credentials,
credentials_file=credentials_file,
- ssl_credentials=ssl_channel_credentials,
- scopes=scopes or self.AUTH_SCOPES,
+ scopes=self._scopes,
+ ssl_credentials=self._ssl_channel_credentials,
quota_project_id=quota_project_id,
options=[
("grpc.max_send_message_length", -1),
@@ -219,17 +212,8 @@ def __init__(
],
)
- # Run the base constructor.
- super().__init__(
- host=host,
- credentials=credentials,
- credentials_file=credentials_file,
- scopes=scopes or self.AUTH_SCOPES,
- quota_project_id=quota_project_id,
- client_info=client_info,
- )
-
- self._stubs = {}
+ # Wrap messages. This must be done after self._grpc_channel exists
+ self._prep_wrapped_messages(client_info)
@property
def grpc_channel(self) -> aio.Channel:
diff --git a/packages/google-cloud-retail/google/cloud/retail_v2/services/product_service/async_client.py b/packages/google-cloud-retail/google/cloud/retail_v2/services/product_service/async_client.py
index 9980b3797113..ce1f96d31fcc 100644
--- a/packages/google-cloud-retail/google/cloud/retail_v2/services/product_service/async_client.py
+++ b/packages/google-cloud-retail/google/cloud/retail_v2/services/product_service/async_client.py
@@ -88,8 +88,36 @@ class ProductServiceAsyncClient:
ProductServiceClient.parse_common_location_path
)
- from_service_account_info = ProductServiceClient.from_service_account_info
- from_service_account_file = ProductServiceClient.from_service_account_file
+ @classmethod
+ def from_service_account_info(cls, info: dict, *args, **kwargs):
+ """Creates an instance of this client using the provided credentials info.
+
+ Args:
+ info (dict): The service account private key info.
+ args: Additional arguments to pass to the constructor.
+ kwargs: Additional arguments to pass to the constructor.
+
+ Returns:
+ ProductServiceAsyncClient: The constructed client.
+ """
+ return ProductServiceClient.from_service_account_info.__func__(ProductServiceAsyncClient, info, *args, **kwargs) # type: ignore
+
+ @classmethod
+ def from_service_account_file(cls, filename: str, *args, **kwargs):
+ """Creates an instance of this client using the provided credentials
+ file.
+
+ Args:
+ filename (str): The path to the service account private key json
+ file.
+ args: Additional arguments to pass to the constructor.
+ kwargs: Additional arguments to pass to the constructor.
+
+ Returns:
+ ProductServiceAsyncClient: The constructed client.
+ """
+ return ProductServiceClient.from_service_account_file.__func__(ProductServiceAsyncClient, filename, *args, **kwargs) # type: ignore
+
from_service_account_json = from_service_account_file
@property
diff --git a/packages/google-cloud-retail/google/cloud/retail_v2/services/product_service/client.py b/packages/google-cloud-retail/google/cloud/retail_v2/services/product_service/client.py
index 8581183227ca..c7f65839f6b4 100644
--- a/packages/google-cloud-retail/google/cloud/retail_v2/services/product_service/client.py
+++ b/packages/google-cloud-retail/google/cloud/retail_v2/services/product_service/client.py
@@ -317,21 +317,17 @@ def __init__(
util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))
)
- ssl_credentials = None
+ client_cert_source_func = None
is_mtls = False
if use_client_cert:
if client_options.client_cert_source:
- import grpc # type: ignore
-
- cert, key = client_options.client_cert_source()
- ssl_credentials = grpc.ssl_channel_credentials(
- certificate_chain=cert, private_key=key
- )
is_mtls = True
+ client_cert_source_func = client_options.client_cert_source
else:
- creds = SslCredentials()
- is_mtls = creds.is_mtls
- ssl_credentials = creds.ssl_credentials if is_mtls else None
+ is_mtls = mtls.has_default_client_cert_source()
+ client_cert_source_func = (
+ mtls.default_client_cert_source() if is_mtls else None
+ )
# Figure out which api endpoint to use.
if client_options.api_endpoint is not None:
@@ -374,7 +370,7 @@ def __init__(
credentials_file=client_options.credentials_file,
host=api_endpoint,
scopes=client_options.scopes,
- ssl_channel_credentials=ssl_credentials,
+ client_cert_source_for_mtls=client_cert_source_func,
quota_project_id=client_options.quota_project_id,
client_info=client_info,
)
diff --git a/packages/google-cloud-retail/google/cloud/retail_v2/services/product_service/transports/base.py b/packages/google-cloud-retail/google/cloud/retail_v2/services/product_service/transports/base.py
index 7db37c0d5df1..914e01127a13 100644
--- a/packages/google-cloud-retail/google/cloud/retail_v2/services/product_service/transports/base.py
+++ b/packages/google-cloud-retail/google/cloud/retail_v2/services/product_service/transports/base.py
@@ -73,10 +73,10 @@ def __init__(
scope (Optional[Sequence[str]]): A list of scopes.
quota_project_id (Optional[str]): An optional project to use for billing
and quota.
- client_info (google.api_core.gapic_v1.client_info.ClientInfo):
- The client info used to send a user-agent string along with
- API requests. If ``None``, then default info will be used.
- Generally, you only need to set this if you're developing
+ client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+ The client info used to send a user-agent string along with
+ API requests. If ``None``, then default info will be used.
+ Generally, you only need to set this if you're developing
your own client library.
"""
# Save the hostname. Default to port 443 (HTTPS) if none is specified.
@@ -84,6 +84,9 @@ def __init__(
host += ":443"
self._host = host
+ # Save the scopes.
+ self._scopes = scopes or self.AUTH_SCOPES
+
# If no credentials are provided, then determine the appropriate
# defaults.
if credentials and credentials_file:
@@ -93,20 +96,17 @@ def __init__(
if credentials_file is not None:
credentials, _ = auth.load_credentials_from_file(
- credentials_file, scopes=scopes, quota_project_id=quota_project_id
+ credentials_file, scopes=self._scopes, quota_project_id=quota_project_id
)
elif credentials is None:
credentials, _ = auth.default(
- scopes=scopes, quota_project_id=quota_project_id
+ scopes=self._scopes, quota_project_id=quota_project_id
)
# Save the credentials.
self._credentials = credentials
- # Lifted into its own function so it can be stubbed out during tests.
- self._prep_wrapped_messages(client_info)
-
def _prep_wrapped_messages(self, client_info):
# Precompute the wrapped methods.
self._wrapped_methods = {
diff --git a/packages/google-cloud-retail/google/cloud/retail_v2/services/product_service/transports/grpc.py b/packages/google-cloud-retail/google/cloud/retail_v2/services/product_service/transports/grpc.py
index 93af7f030882..3de3736c144d 100644
--- a/packages/google-cloud-retail/google/cloud/retail_v2/services/product_service/transports/grpc.py
+++ b/packages/google-cloud-retail/google/cloud/retail_v2/services/product_service/transports/grpc.py
@@ -64,6 +64,7 @@ def __init__(
api_mtls_endpoint: str = None,
client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
ssl_channel_credentials: grpc.ChannelCredentials = None,
+ client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
quota_project_id: Optional[str] = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
) -> None:
@@ -94,6 +95,10 @@ def __init__(
``api_mtls_endpoint`` is None.
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
for grpc channel. It is ignored if ``channel`` is provided.
+ client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
+ A callback to provide client certificate bytes and private key bytes,
+ both in PEM format. It is used to configure mutual TLS channel. It is
+ ignored if ``channel`` or ``ssl_channel_credentials`` is provided.
quota_project_id (Optional[str]): An optional project to use for billing
and quota.
client_info (google.api_core.gapic_v1.client_info.ClientInfo):
@@ -108,72 +113,61 @@ def __init__(
google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
and ``credentials_file`` are passed.
"""
+ self._grpc_channel = None
self._ssl_channel_credentials = ssl_channel_credentials
+ self._stubs: Dict[str, Callable] = {}
+ self._operations_client = None
+
+ if api_mtls_endpoint:
+ warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
+ if client_cert_source:
+ warnings.warn("client_cert_source is deprecated", DeprecationWarning)
if channel:
- # Sanity check: Ensure that channel and credentials are not both
- # provided.
+ # Ignore credentials if a channel was passed.
credentials = False
-
# If a channel was explicitly provided, set it.
self._grpc_channel = channel
self._ssl_channel_credentials = None
- elif api_mtls_endpoint:
- warnings.warn(
- "api_mtls_endpoint and client_cert_source are deprecated",
- DeprecationWarning,
- )
- host = (
- api_mtls_endpoint
- if ":" in api_mtls_endpoint
- else api_mtls_endpoint + ":443"
- )
+ else:
+ if api_mtls_endpoint:
+ host = api_mtls_endpoint
+
+ # Create SSL credentials with client_cert_source or application
+ # default SSL credentials.
+ if client_cert_source:
+ cert, key = client_cert_source()
+ self._ssl_channel_credentials = grpc.ssl_channel_credentials(
+ certificate_chain=cert, private_key=key
+ )
+ else:
+ self._ssl_channel_credentials = SslCredentials().ssl_credentials
- if credentials is None:
- credentials, _ = auth.default(
- scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
- )
-
- # Create SSL credentials with client_cert_source or application
- # default SSL credentials.
- if client_cert_source:
- cert, key = client_cert_source()
- ssl_credentials = grpc.ssl_channel_credentials(
- certificate_chain=cert, private_key=key
- )
else:
- ssl_credentials = SslCredentials().ssl_credentials
+ if client_cert_source_for_mtls and not ssl_channel_credentials:
+ cert, key = client_cert_source_for_mtls()
+ self._ssl_channel_credentials = grpc.ssl_channel_credentials(
+ certificate_chain=cert, private_key=key
+ )
- # create a new channel. The provided one is ignored.
- self._grpc_channel = type(self).create_channel(
- host,
- credentials=credentials,
- credentials_file=credentials_file,
- ssl_credentials=ssl_credentials,
- scopes=scopes or self.AUTH_SCOPES,
- quota_project_id=quota_project_id,
- options=[
- ("grpc.max_send_message_length", -1),
- ("grpc.max_receive_message_length", -1),
- ],
- )
- self._ssl_channel_credentials = ssl_credentials
- else:
- host = host if ":" in host else host + ":443"
-
- if credentials is None:
- credentials, _ = auth.default(
- scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
- )
+ # The base transport sets the host, credentials and scopes
+ super().__init__(
+ host=host,
+ credentials=credentials,
+ credentials_file=credentials_file,
+ scopes=scopes,
+ quota_project_id=quota_project_id,
+ client_info=client_info,
+ )
- # create a new channel. The provided one is ignored.
+ if not self._grpc_channel:
self._grpc_channel = type(self).create_channel(
- host,
- credentials=credentials,
+ self._host,
+ credentials=self._credentials,
credentials_file=credentials_file,
- ssl_credentials=ssl_channel_credentials,
- scopes=scopes or self.AUTH_SCOPES,
+ scopes=self._scopes,
+ ssl_credentials=self._ssl_channel_credentials,
quota_project_id=quota_project_id,
options=[
("grpc.max_send_message_length", -1),
@@ -181,18 +175,8 @@ def __init__(
],
)
- self._stubs = {} # type: Dict[str, Callable]
- self._operations_client = None
-
- # Run the base constructor.
- super().__init__(
- host=host,
- credentials=credentials,
- credentials_file=credentials_file,
- scopes=scopes or self.AUTH_SCOPES,
- quota_project_id=quota_project_id,
- client_info=client_info,
- )
+ # Wrap messages. This must be done after self._grpc_channel exists
+ self._prep_wrapped_messages(client_info)
@classmethod
def create_channel(
@@ -206,7 +190,7 @@ def create_channel(
) -> grpc.Channel:
"""Create and return a gRPC channel object.
Args:
- address (Optional[str]): The host for the channel to use.
+ host (Optional[str]): The host for the channel to use.
credentials (Optional[~.Credentials]): The
authorization credentials to attach to requests. These
credentials identify this application to the service. If
diff --git a/packages/google-cloud-retail/google/cloud/retail_v2/services/product_service/transports/grpc_asyncio.py b/packages/google-cloud-retail/google/cloud/retail_v2/services/product_service/transports/grpc_asyncio.py
index 465eed7d64e6..c9100b79fd0c 100644
--- a/packages/google-cloud-retail/google/cloud/retail_v2/services/product_service/transports/grpc_asyncio.py
+++ b/packages/google-cloud-retail/google/cloud/retail_v2/services/product_service/transports/grpc_asyncio.py
@@ -68,7 +68,7 @@ def create_channel(
) -> aio.Channel:
"""Create and return a gRPC AsyncIO channel object.
Args:
- address (Optional[str]): The host for the channel to use.
+ host (Optional[str]): The host for the channel to use.
credentials (Optional[~.Credentials]): The
authorization credentials to attach to requests. These
credentials identify this application to the service. If
@@ -108,6 +108,7 @@ def __init__(
api_mtls_endpoint: str = None,
client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
ssl_channel_credentials: grpc.ChannelCredentials = None,
+ client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
quota_project_id=None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
) -> None:
@@ -139,12 +140,16 @@ def __init__(
``api_mtls_endpoint`` is None.
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
for grpc channel. It is ignored if ``channel`` is provided.
+ client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
+ A callback to provide client certificate bytes and private key bytes,
+ both in PEM format. It is used to configure mutual TLS channel. It is
+ ignored if ``channel`` or ``ssl_channel_credentials`` is provided.
quota_project_id (Optional[str]): An optional project to use for billing
and quota.
- client_info (google.api_core.gapic_v1.client_info.ClientInfo):
- The client info used to send a user-agent string along with
- API requests. If ``None``, then default info will be used.
- Generally, you only need to set this if you're developing
+ client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+ The client info used to send a user-agent string along with
+ API requests. If ``None``, then default info will be used.
+ Generally, you only need to set this if you're developing
your own client library.
Raises:
@@ -153,72 +158,61 @@ def __init__(
google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
and ``credentials_file`` are passed.
"""
+ self._grpc_channel = None
self._ssl_channel_credentials = ssl_channel_credentials
+ self._stubs: Dict[str, Callable] = {}
+ self._operations_client = None
+
+ if api_mtls_endpoint:
+ warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
+ if client_cert_source:
+ warnings.warn("client_cert_source is deprecated", DeprecationWarning)
if channel:
- # Sanity check: Ensure that channel and credentials are not both
- # provided.
+ # Ignore credentials if a channel was passed.
credentials = False
-
# If a channel was explicitly provided, set it.
self._grpc_channel = channel
self._ssl_channel_credentials = None
- elif api_mtls_endpoint:
- warnings.warn(
- "api_mtls_endpoint and client_cert_source are deprecated",
- DeprecationWarning,
- )
- host = (
- api_mtls_endpoint
- if ":" in api_mtls_endpoint
- else api_mtls_endpoint + ":443"
- )
+ else:
+ if api_mtls_endpoint:
+ host = api_mtls_endpoint
+
+ # Create SSL credentials with client_cert_source or application
+ # default SSL credentials.
+ if client_cert_source:
+ cert, key = client_cert_source()
+ self._ssl_channel_credentials = grpc.ssl_channel_credentials(
+ certificate_chain=cert, private_key=key
+ )
+ else:
+ self._ssl_channel_credentials = SslCredentials().ssl_credentials
- if credentials is None:
- credentials, _ = auth.default(
- scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
- )
-
- # Create SSL credentials with client_cert_source or application
- # default SSL credentials.
- if client_cert_source:
- cert, key = client_cert_source()
- ssl_credentials = grpc.ssl_channel_credentials(
- certificate_chain=cert, private_key=key
- )
else:
- ssl_credentials = SslCredentials().ssl_credentials
+ if client_cert_source_for_mtls and not ssl_channel_credentials:
+ cert, key = client_cert_source_for_mtls()
+ self._ssl_channel_credentials = grpc.ssl_channel_credentials(
+ certificate_chain=cert, private_key=key
+ )
- # create a new channel. The provided one is ignored.
- self._grpc_channel = type(self).create_channel(
- host,
- credentials=credentials,
- credentials_file=credentials_file,
- ssl_credentials=ssl_credentials,
- scopes=scopes or self.AUTH_SCOPES,
- quota_project_id=quota_project_id,
- options=[
- ("grpc.max_send_message_length", -1),
- ("grpc.max_receive_message_length", -1),
- ],
- )
- self._ssl_channel_credentials = ssl_credentials
- else:
- host = host if ":" in host else host + ":443"
-
- if credentials is None:
- credentials, _ = auth.default(
- scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
- )
+ # The base transport sets the host, credentials and scopes
+ super().__init__(
+ host=host,
+ credentials=credentials,
+ credentials_file=credentials_file,
+ scopes=scopes,
+ quota_project_id=quota_project_id,
+ client_info=client_info,
+ )
- # create a new channel. The provided one is ignored.
+ if not self._grpc_channel:
self._grpc_channel = type(self).create_channel(
- host,
- credentials=credentials,
+ self._host,
+ credentials=self._credentials,
credentials_file=credentials_file,
- ssl_credentials=ssl_channel_credentials,
- scopes=scopes or self.AUTH_SCOPES,
+ scopes=self._scopes,
+ ssl_credentials=self._ssl_channel_credentials,
quota_project_id=quota_project_id,
options=[
("grpc.max_send_message_length", -1),
@@ -226,18 +220,8 @@ def __init__(
],
)
- # Run the base constructor.
- super().__init__(
- host=host,
- credentials=credentials,
- credentials_file=credentials_file,
- scopes=scopes or self.AUTH_SCOPES,
- quota_project_id=quota_project_id,
- client_info=client_info,
- )
-
- self._stubs = {}
- self._operations_client = None
+ # Wrap messages. This must be done after self._grpc_channel exists
+ self._prep_wrapped_messages(client_info)
@property
def grpc_channel(self) -> aio.Channel:
diff --git a/packages/google-cloud-retail/google/cloud/retail_v2/services/user_event_service/async_client.py b/packages/google-cloud-retail/google/cloud/retail_v2/services/user_event_service/async_client.py
index 2729d74d8c19..b435d2847579 100644
--- a/packages/google-cloud-retail/google/cloud/retail_v2/services/user_event_service/async_client.py
+++ b/packages/google-cloud-retail/google/cloud/retail_v2/services/user_event_service/async_client.py
@@ -86,8 +86,36 @@ class UserEventServiceAsyncClient:
UserEventServiceClient.parse_common_location_path
)
- from_service_account_info = UserEventServiceClient.from_service_account_info
- from_service_account_file = UserEventServiceClient.from_service_account_file
+ @classmethod
+ def from_service_account_info(cls, info: dict, *args, **kwargs):
+ """Creates an instance of this client using the provided credentials info.
+
+ Args:
+ info (dict): The service account private key info.
+ args: Additional arguments to pass to the constructor.
+ kwargs: Additional arguments to pass to the constructor.
+
+ Returns:
+ UserEventServiceAsyncClient: The constructed client.
+ """
+ return UserEventServiceClient.from_service_account_info.__func__(UserEventServiceAsyncClient, info, *args, **kwargs) # type: ignore
+
+ @classmethod
+ def from_service_account_file(cls, filename: str, *args, **kwargs):
+ """Creates an instance of this client using the provided credentials
+ file.
+
+ Args:
+ filename (str): The path to the service account private key json
+ file.
+ args: Additional arguments to pass to the constructor.
+ kwargs: Additional arguments to pass to the constructor.
+
+ Returns:
+ UserEventServiceAsyncClient: The constructed client.
+ """
+ return UserEventServiceClient.from_service_account_file.__func__(UserEventServiceAsyncClient, filename, *args, **kwargs) # type: ignore
+
from_service_account_json = from_service_account_file
@property
diff --git a/packages/google-cloud-retail/google/cloud/retail_v2/services/user_event_service/client.py b/packages/google-cloud-retail/google/cloud/retail_v2/services/user_event_service/client.py
index b178d41062e9..a5adcdeca0d5 100644
--- a/packages/google-cloud-retail/google/cloud/retail_v2/services/user_event_service/client.py
+++ b/packages/google-cloud-retail/google/cloud/retail_v2/services/user_event_service/client.py
@@ -301,21 +301,17 @@ def __init__(
util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))
)
- ssl_credentials = None
+ client_cert_source_func = None
is_mtls = False
if use_client_cert:
if client_options.client_cert_source:
- import grpc # type: ignore
-
- cert, key = client_options.client_cert_source()
- ssl_credentials = grpc.ssl_channel_credentials(
- certificate_chain=cert, private_key=key
- )
is_mtls = True
+ client_cert_source_func = client_options.client_cert_source
else:
- creds = SslCredentials()
- is_mtls = creds.is_mtls
- ssl_credentials = creds.ssl_credentials if is_mtls else None
+ is_mtls = mtls.has_default_client_cert_source()
+ client_cert_source_func = (
+ mtls.default_client_cert_source() if is_mtls else None
+ )
# Figure out which api endpoint to use.
if client_options.api_endpoint is not None:
@@ -358,7 +354,7 @@ def __init__(
credentials_file=client_options.credentials_file,
host=api_endpoint,
scopes=client_options.scopes,
- ssl_channel_credentials=ssl_credentials,
+ client_cert_source_for_mtls=client_cert_source_func,
quota_project_id=client_options.quota_project_id,
client_info=client_info,
)
diff --git a/packages/google-cloud-retail/google/cloud/retail_v2/services/user_event_service/transports/base.py b/packages/google-cloud-retail/google/cloud/retail_v2/services/user_event_service/transports/base.py
index eebc342bee7d..69909d9a3be0 100644
--- a/packages/google-cloud-retail/google/cloud/retail_v2/services/user_event_service/transports/base.py
+++ b/packages/google-cloud-retail/google/cloud/retail_v2/services/user_event_service/transports/base.py
@@ -73,10 +73,10 @@ def __init__(
scope (Optional[Sequence[str]]): A list of scopes.
quota_project_id (Optional[str]): An optional project to use for billing
and quota.
- client_info (google.api_core.gapic_v1.client_info.ClientInfo):
- The client info used to send a user-agent string along with
- API requests. If ``None``, then default info will be used.
- Generally, you only need to set this if you're developing
+ client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+ The client info used to send a user-agent string along with
+ API requests. If ``None``, then default info will be used.
+ Generally, you only need to set this if you're developing
your own client library.
"""
# Save the hostname. Default to port 443 (HTTPS) if none is specified.
@@ -84,6 +84,9 @@ def __init__(
host += ":443"
self._host = host
+ # Save the scopes.
+ self._scopes = scopes or self.AUTH_SCOPES
+
# If no credentials are provided, then determine the appropriate
# defaults.
if credentials and credentials_file:
@@ -93,20 +96,17 @@ def __init__(
if credentials_file is not None:
credentials, _ = auth.load_credentials_from_file(
- credentials_file, scopes=scopes, quota_project_id=quota_project_id
+ credentials_file, scopes=self._scopes, quota_project_id=quota_project_id
)
elif credentials is None:
credentials, _ = auth.default(
- scopes=scopes, quota_project_id=quota_project_id
+ scopes=self._scopes, quota_project_id=quota_project_id
)
# Save the credentials.
self._credentials = credentials
- # Lifted into its own function so it can be stubbed out during tests.
- self._prep_wrapped_messages(client_info)
-
def _prep_wrapped_messages(self, client_info):
# Precompute the wrapped methods.
self._wrapped_methods = {
diff --git a/packages/google-cloud-retail/google/cloud/retail_v2/services/user_event_service/transports/grpc.py b/packages/google-cloud-retail/google/cloud/retail_v2/services/user_event_service/transports/grpc.py
index d2a112df1221..d72f23bfaec9 100644
--- a/packages/google-cloud-retail/google/cloud/retail_v2/services/user_event_service/transports/grpc.py
+++ b/packages/google-cloud-retail/google/cloud/retail_v2/services/user_event_service/transports/grpc.py
@@ -64,6 +64,7 @@ def __init__(
api_mtls_endpoint: str = None,
client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
ssl_channel_credentials: grpc.ChannelCredentials = None,
+ client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
quota_project_id: Optional[str] = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
) -> None:
@@ -94,6 +95,10 @@ def __init__(
``api_mtls_endpoint`` is None.
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
for grpc channel. It is ignored if ``channel`` is provided.
+ client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
+ A callback to provide client certificate bytes and private key bytes,
+ both in PEM format. It is used to configure mutual TLS channel. It is
+ ignored if ``channel`` or ``ssl_channel_credentials`` is provided.
quota_project_id (Optional[str]): An optional project to use for billing
and quota.
client_info (google.api_core.gapic_v1.client_info.ClientInfo):
@@ -108,72 +113,61 @@ def __init__(
google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
and ``credentials_file`` are passed.
"""
+ self._grpc_channel = None
self._ssl_channel_credentials = ssl_channel_credentials
+ self._stubs: Dict[str, Callable] = {}
+ self._operations_client = None
+
+ if api_mtls_endpoint:
+ warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
+ if client_cert_source:
+ warnings.warn("client_cert_source is deprecated", DeprecationWarning)
if channel:
- # Sanity check: Ensure that channel and credentials are not both
- # provided.
+ # Ignore credentials if a channel was passed.
credentials = False
-
# If a channel was explicitly provided, set it.
self._grpc_channel = channel
self._ssl_channel_credentials = None
- elif api_mtls_endpoint:
- warnings.warn(
- "api_mtls_endpoint and client_cert_source are deprecated",
- DeprecationWarning,
- )
- host = (
- api_mtls_endpoint
- if ":" in api_mtls_endpoint
- else api_mtls_endpoint + ":443"
- )
+ else:
+ if api_mtls_endpoint:
+ host = api_mtls_endpoint
+
+ # Create SSL credentials with client_cert_source or application
+ # default SSL credentials.
+ if client_cert_source:
+ cert, key = client_cert_source()
+ self._ssl_channel_credentials = grpc.ssl_channel_credentials(
+ certificate_chain=cert, private_key=key
+ )
+ else:
+ self._ssl_channel_credentials = SslCredentials().ssl_credentials
- if credentials is None:
- credentials, _ = auth.default(
- scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
- )
-
- # Create SSL credentials with client_cert_source or application
- # default SSL credentials.
- if client_cert_source:
- cert, key = client_cert_source()
- ssl_credentials = grpc.ssl_channel_credentials(
- certificate_chain=cert, private_key=key
- )
else:
- ssl_credentials = SslCredentials().ssl_credentials
+ if client_cert_source_for_mtls and not ssl_channel_credentials:
+ cert, key = client_cert_source_for_mtls()
+ self._ssl_channel_credentials = grpc.ssl_channel_credentials(
+ certificate_chain=cert, private_key=key
+ )
- # create a new channel. The provided one is ignored.
- self._grpc_channel = type(self).create_channel(
- host,
- credentials=credentials,
- credentials_file=credentials_file,
- ssl_credentials=ssl_credentials,
- scopes=scopes or self.AUTH_SCOPES,
- quota_project_id=quota_project_id,
- options=[
- ("grpc.max_send_message_length", -1),
- ("grpc.max_receive_message_length", -1),
- ],
- )
- self._ssl_channel_credentials = ssl_credentials
- else:
- host = host if ":" in host else host + ":443"
-
- if credentials is None:
- credentials, _ = auth.default(
- scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
- )
+ # The base transport sets the host, credentials and scopes
+ super().__init__(
+ host=host,
+ credentials=credentials,
+ credentials_file=credentials_file,
+ scopes=scopes,
+ quota_project_id=quota_project_id,
+ client_info=client_info,
+ )
- # create a new channel. The provided one is ignored.
+ if not self._grpc_channel:
self._grpc_channel = type(self).create_channel(
- host,
- credentials=credentials,
+ self._host,
+ credentials=self._credentials,
credentials_file=credentials_file,
- ssl_credentials=ssl_channel_credentials,
- scopes=scopes or self.AUTH_SCOPES,
+ scopes=self._scopes,
+ ssl_credentials=self._ssl_channel_credentials,
quota_project_id=quota_project_id,
options=[
("grpc.max_send_message_length", -1),
@@ -181,18 +175,8 @@ def __init__(
],
)
- self._stubs = {} # type: Dict[str, Callable]
- self._operations_client = None
-
- # Run the base constructor.
- super().__init__(
- host=host,
- credentials=credentials,
- credentials_file=credentials_file,
- scopes=scopes or self.AUTH_SCOPES,
- quota_project_id=quota_project_id,
- client_info=client_info,
- )
+ # Wrap messages. This must be done after self._grpc_channel exists
+ self._prep_wrapped_messages(client_info)
@classmethod
def create_channel(
@@ -206,7 +190,7 @@ def create_channel(
) -> grpc.Channel:
"""Create and return a gRPC channel object.
Args:
- address (Optional[str]): The host for the channel to use.
+ host (Optional[str]): The host for the channel to use.
credentials (Optional[~.Credentials]): The
authorization credentials to attach to requests. These
credentials identify this application to the service. If
diff --git a/packages/google-cloud-retail/google/cloud/retail_v2/services/user_event_service/transports/grpc_asyncio.py b/packages/google-cloud-retail/google/cloud/retail_v2/services/user_event_service/transports/grpc_asyncio.py
index 1022639d5d2f..18a15fb2e925 100644
--- a/packages/google-cloud-retail/google/cloud/retail_v2/services/user_event_service/transports/grpc_asyncio.py
+++ b/packages/google-cloud-retail/google/cloud/retail_v2/services/user_event_service/transports/grpc_asyncio.py
@@ -68,7 +68,7 @@ def create_channel(
) -> aio.Channel:
"""Create and return a gRPC AsyncIO channel object.
Args:
- address (Optional[str]): The host for the channel to use.
+ host (Optional[str]): The host for the channel to use.
credentials (Optional[~.Credentials]): The
authorization credentials to attach to requests. These
credentials identify this application to the service. If
@@ -108,6 +108,7 @@ def __init__(
api_mtls_endpoint: str = None,
client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
ssl_channel_credentials: grpc.ChannelCredentials = None,
+ client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
quota_project_id=None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
) -> None:
@@ -139,12 +140,16 @@ def __init__(
``api_mtls_endpoint`` is None.
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
for grpc channel. It is ignored if ``channel`` is provided.
+ client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
+ A callback to provide client certificate bytes and private key bytes,
+ both in PEM format. It is used to configure mutual TLS channel. It is
+ ignored if ``channel`` or ``ssl_channel_credentials`` is provided.
quota_project_id (Optional[str]): An optional project to use for billing
and quota.
- client_info (google.api_core.gapic_v1.client_info.ClientInfo):
- The client info used to send a user-agent string along with
- API requests. If ``None``, then default info will be used.
- Generally, you only need to set this if you're developing
+ client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+ The client info used to send a user-agent string along with
+ API requests. If ``None``, then default info will be used.
+ Generally, you only need to set this if you're developing
your own client library.
Raises:
@@ -153,72 +158,61 @@ def __init__(
google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
and ``credentials_file`` are passed.
"""
+ self._grpc_channel = None
self._ssl_channel_credentials = ssl_channel_credentials
+ self._stubs: Dict[str, Callable] = {}
+ self._operations_client = None
+
+ if api_mtls_endpoint:
+ warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
+ if client_cert_source:
+ warnings.warn("client_cert_source is deprecated", DeprecationWarning)
if channel:
- # Sanity check: Ensure that channel and credentials are not both
- # provided.
+ # Ignore credentials if a channel was passed.
credentials = False
-
# If a channel was explicitly provided, set it.
self._grpc_channel = channel
self._ssl_channel_credentials = None
- elif api_mtls_endpoint:
- warnings.warn(
- "api_mtls_endpoint and client_cert_source are deprecated",
- DeprecationWarning,
- )
- host = (
- api_mtls_endpoint
- if ":" in api_mtls_endpoint
- else api_mtls_endpoint + ":443"
- )
+ else:
+ if api_mtls_endpoint:
+ host = api_mtls_endpoint
+
+ # Create SSL credentials with client_cert_source or application
+ # default SSL credentials.
+ if client_cert_source:
+ cert, key = client_cert_source()
+ self._ssl_channel_credentials = grpc.ssl_channel_credentials(
+ certificate_chain=cert, private_key=key
+ )
+ else:
+ self._ssl_channel_credentials = SslCredentials().ssl_credentials
- if credentials is None:
- credentials, _ = auth.default(
- scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
- )
-
- # Create SSL credentials with client_cert_source or application
- # default SSL credentials.
- if client_cert_source:
- cert, key = client_cert_source()
- ssl_credentials = grpc.ssl_channel_credentials(
- certificate_chain=cert, private_key=key
- )
else:
- ssl_credentials = SslCredentials().ssl_credentials
+ if client_cert_source_for_mtls and not ssl_channel_credentials:
+ cert, key = client_cert_source_for_mtls()
+ self._ssl_channel_credentials = grpc.ssl_channel_credentials(
+ certificate_chain=cert, private_key=key
+ )
- # create a new channel. The provided one is ignored.
- self._grpc_channel = type(self).create_channel(
- host,
- credentials=credentials,
- credentials_file=credentials_file,
- ssl_credentials=ssl_credentials,
- scopes=scopes or self.AUTH_SCOPES,
- quota_project_id=quota_project_id,
- options=[
- ("grpc.max_send_message_length", -1),
- ("grpc.max_receive_message_length", -1),
- ],
- )
- self._ssl_channel_credentials = ssl_credentials
- else:
- host = host if ":" in host else host + ":443"
-
- if credentials is None:
- credentials, _ = auth.default(
- scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
- )
+ # The base transport sets the host, credentials and scopes
+ super().__init__(
+ host=host,
+ credentials=credentials,
+ credentials_file=credentials_file,
+ scopes=scopes,
+ quota_project_id=quota_project_id,
+ client_info=client_info,
+ )
- # create a new channel. The provided one is ignored.
+ if not self._grpc_channel:
self._grpc_channel = type(self).create_channel(
- host,
- credentials=credentials,
+ self._host,
+ credentials=self._credentials,
credentials_file=credentials_file,
- ssl_credentials=ssl_channel_credentials,
- scopes=scopes or self.AUTH_SCOPES,
+ scopes=self._scopes,
+ ssl_credentials=self._ssl_channel_credentials,
quota_project_id=quota_project_id,
options=[
("grpc.max_send_message_length", -1),
@@ -226,18 +220,8 @@ def __init__(
],
)
- # Run the base constructor.
- super().__init__(
- host=host,
- credentials=credentials,
- credentials_file=credentials_file,
- scopes=scopes or self.AUTH_SCOPES,
- quota_project_id=quota_project_id,
- client_info=client_info,
- )
-
- self._stubs = {}
- self._operations_client = None
+ # Wrap messages. This must be done after self._grpc_channel exists
+ self._prep_wrapped_messages(client_info)
@property
def grpc_channel(self) -> aio.Channel:
diff --git a/packages/google-cloud-retail/google/cloud/retail_v2/types/__init__.py b/packages/google-cloud-retail/google/cloud/retail_v2/types/__init__.py
index ca1f4b800090..77a897de49fe 100644
--- a/packages/google-cloud-retail/google/cloud/retail_v2/types/__init__.py
+++ b/packages/google-cloud-retail/google/cloud/retail_v2/types/__init__.py
@@ -16,8 +16,8 @@
#
from .catalog import (
- ProductLevelConfig,
Catalog,
+ ProductLevelConfig,
)
from .catalog_service import (
ListCatalogsRequest,
@@ -30,53 +30,53 @@
PriceInfo,
UserInfo,
)
-from .product import Product
-from .user_event import (
- UserEvent,
- ProductDetail,
- PurchaseTransaction,
-)
from .import_config import (
- GcsSource,
BigQuerySource,
- ProductInlineSource,
- UserEventInlineSource,
+ GcsSource,
ImportErrorsConfig,
- ImportProductsRequest,
- ImportUserEventsRequest,
- ProductInputConfig,
- UserEventInputConfig,
ImportMetadata,
+ ImportProductsRequest,
ImportProductsResponse,
+ ImportUserEventsRequest,
ImportUserEventsResponse,
+ ProductInlineSource,
+ ProductInputConfig,
UserEventImportSummary,
+ UserEventInlineSource,
+ UserEventInputConfig,
)
from .prediction_service import (
PredictRequest,
PredictResponse,
)
+from .product import Product
+from .product_service import (
+ CreateProductRequest,
+ DeleteProductRequest,
+ GetProductRequest,
+ UpdateProductRequest,
+)
from .purge_config import (
PurgeMetadata,
PurgeUserEventsRequest,
PurgeUserEventsResponse,
)
-from .product_service import (
- CreateProductRequest,
- GetProductRequest,
- UpdateProductRequest,
- DeleteProductRequest,
+from .user_event import (
+ ProductDetail,
+ PurchaseTransaction,
+ UserEvent,
)
from .user_event_service import (
- WriteUserEventRequest,
CollectUserEventRequest,
+ RejoinUserEventsMetadata,
RejoinUserEventsRequest,
RejoinUserEventsResponse,
- RejoinUserEventsMetadata,
+ WriteUserEventRequest,
)
__all__ = (
- "ProductLevelConfig",
"Catalog",
+ "ProductLevelConfig",
"ListCatalogsRequest",
"ListCatalogsResponse",
"UpdateCatalogRequest",
@@ -84,35 +84,35 @@
"Image",
"PriceInfo",
"UserInfo",
- "Product",
- "UserEvent",
- "ProductDetail",
- "PurchaseTransaction",
- "GcsSource",
"BigQuerySource",
- "ProductInlineSource",
- "UserEventInlineSource",
+ "GcsSource",
"ImportErrorsConfig",
- "ImportProductsRequest",
- "ImportUserEventsRequest",
- "ProductInputConfig",
- "UserEventInputConfig",
"ImportMetadata",
+ "ImportProductsRequest",
"ImportProductsResponse",
+ "ImportUserEventsRequest",
"ImportUserEventsResponse",
+ "ProductInlineSource",
+ "ProductInputConfig",
"UserEventImportSummary",
+ "UserEventInlineSource",
+ "UserEventInputConfig",
"PredictRequest",
"PredictResponse",
- "PurgeMetadata",
- "PurgeUserEventsRequest",
- "PurgeUserEventsResponse",
+ "Product",
"CreateProductRequest",
+ "DeleteProductRequest",
"GetProductRequest",
"UpdateProductRequest",
- "DeleteProductRequest",
- "WriteUserEventRequest",
+ "PurgeMetadata",
+ "PurgeUserEventsRequest",
+ "PurgeUserEventsResponse",
+ "ProductDetail",
+ "PurchaseTransaction",
+ "UserEvent",
"CollectUserEventRequest",
+ "RejoinUserEventsMetadata",
"RejoinUserEventsRequest",
"RejoinUserEventsResponse",
- "RejoinUserEventsMetadata",
+ "WriteUserEventRequest",
)
diff --git a/packages/google-cloud-retail/google/cloud/retail_v2/types/import_config.py b/packages/google-cloud-retail/google/cloud/retail_v2/types/import_config.py
index ca2ab88579a6..a1ff0d7ce800 100644
--- a/packages/google-cloud-retail/google/cloud/retail_v2/types/import_config.py
+++ b/packages/google-cloud-retail/google/cloud/retail_v2/types/import_config.py
@@ -190,7 +190,7 @@ class ImportProductsRequest(proto.Message):
Attributes:
parent (str):
Required.
- "projects/1234/locations/global/catalogs/default_catalog/branches/default_branch"
+ ``projects/1234/locations/global/catalogs/default_catalog/branches/default_branch``
If no updateMask is specified, requires products.create
permission. If updateMask is specified, requires
@@ -222,7 +222,7 @@ class ImportUserEventsRequest(proto.Message):
Attributes:
parent (str):
Required.
- "projects/1234/locations/global/catalogs/default_catalog".
+ ``projects/1234/locations/global/catalogs/default_catalog``
input_config (google.cloud.retail_v2.types.UserEventInputConfig):
Required. The desired input location of the
data.
diff --git a/packages/google-cloud-retail/google/cloud/retail_v2/types/user_event_service.py b/packages/google-cloud-retail/google/cloud/retail_v2/types/user_event_service.py
index 06d3621df15c..8403ed8236b3 100644
--- a/packages/google-cloud-retail/google/cloud/retail_v2/types/user_event_service.py
+++ b/packages/google-cloud-retail/google/cloud/retail_v2/types/user_event_service.py
@@ -39,7 +39,7 @@ class WriteUserEventRequest(proto.Message):
Attributes:
parent (str):
Required. The parent catalog resource name, such as
- "projects/1234/locations/global/catalogs/default_catalog".
+ ``projects/1234/locations/global/catalogs/default_catalog``.
user_event (google.cloud.retail_v2.types.UserEvent):
Required. User event to write.
"""
@@ -55,7 +55,7 @@ class CollectUserEventRequest(proto.Message):
Attributes:
parent (str):
Required. The parent catalog name, such as
- "projects/1234/locations/global/catalogs/default_catalog".
+ ``projects/1234/locations/global/catalogs/default_catalog``.
user_event (str):
Required. URL encoded UserEvent proto with a
length limit of 2,000,000 characters.
@@ -87,7 +87,7 @@ class RejoinUserEventsRequest(proto.Message):
Attributes:
parent (str):
Required. The parent catalog resource name, such as
- "projects/1234/locations/global/catalogs/default_catalog".
+ ``projects/1234/locations/global/catalogs/default_catalog``.
user_event_rejoin_scope (google.cloud.retail_v2.types.RejoinUserEventsRequest.UserEventRejoinScope):
The type of the user event rejoin to define the scope and
range of the user events to be rejoined with the latest
diff --git a/packages/google-cloud-retail/noxfile.py b/packages/google-cloud-retail/noxfile.py
index e938b7f369d5..8e9b9413d302 100644
--- a/packages/google-cloud-retail/noxfile.py
+++ b/packages/google-cloud-retail/noxfile.py
@@ -18,6 +18,7 @@
from __future__ import absolute_import
import os
+import pathlib
import shutil
import nox
@@ -30,6 +31,8 @@
SYSTEM_TEST_PYTHON_VERSIONS = ["3.8"]
UNIT_TEST_PYTHON_VERSIONS = ["3.6", "3.7", "3.8", "3.9"]
+CURRENT_DIRECTORY = pathlib.Path(__file__).parent.absolute()
+
# 'docfx' is excluded since it only needs to run in 'docs-presubmit'
nox.options.sessions = [
"unit",
@@ -41,6 +44,9 @@
"docs",
]
+# Error if a python version is missing
+nox.options.error_on_missing_interpreters = True
+
@nox.session(python=DEFAULT_PYTHON_VERSION)
def lint(session):
@@ -81,17 +87,21 @@ def lint_setup_py(session):
def default(session):
# Install all test dependencies, then install this package in-place.
- session.install("asyncmock", "pytest-asyncio")
- session.install(
- "mock", "pytest", "pytest-cov",
+ constraints_path = str(
+ CURRENT_DIRECTORY / "testing" / f"constraints-{session.python}.txt"
)
- session.install("-e", ".")
+ session.install("asyncmock", "pytest-asyncio", "-c", constraints_path)
+
+ session.install("mock", "pytest", "pytest-cov", "-c", constraints_path)
+
+ session.install("-e", ".", "-c", constraints_path)
# Run py.test against the unit tests.
session.run(
"py.test",
"--quiet",
+ f"--junitxml=unit_{session.python}_sponge_log.xml",
"--cov=google/cloud",
"--cov=tests/unit",
"--cov-append",
@@ -112,6 +122,9 @@ def unit(session):
@nox.session(python=SYSTEM_TEST_PYTHON_VERSIONS)
def system(session):
"""Run the system test suite."""
+ constraints_path = str(
+ CURRENT_DIRECTORY / "testing" / f"constraints-{session.python}.txt"
+ )
system_test_path = os.path.join("tests", "system.py")
system_test_folder_path = os.path.join("tests", "system")
@@ -121,6 +134,9 @@ def system(session):
# Sanity check: Only run tests if the environment variable is set.
if not os.environ.get("GOOGLE_APPLICATION_CREDENTIALS", ""):
session.skip("Credentials must be set via environment variable")
+ # Install pyopenssl for mTLS testing.
+ if os.environ.get("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true":
+ session.install("pyopenssl")
system_test_exists = os.path.exists(system_test_path)
system_test_folder_exists = os.path.exists(system_test_folder_path)
@@ -133,16 +149,26 @@ def system(session):
# Install all test dependencies, then install this package into the
# virtualenv's dist-packages.
- session.install(
- "mock", "pytest", "google-cloud-testutils",
- )
- session.install("-e", ".")
+ session.install("mock", "pytest", "google-cloud-testutils", "-c", constraints_path)
+ session.install("-e", ".", "-c", constraints_path)
# Run py.test against the system tests.
if system_test_exists:
- session.run("py.test", "--quiet", system_test_path, *session.posargs)
+ session.run(
+ "py.test",
+ "--quiet",
+ f"--junitxml=system_{session.python}_sponge_log.xml",
+ system_test_path,
+ *session.posargs,
+ )
if system_test_folder_exists:
- session.run("py.test", "--quiet", system_test_folder_path, *session.posargs)
+ session.run(
+ "py.test",
+ "--quiet",
+ f"--junitxml=system_{session.python}_sponge_log.xml",
+ system_test_folder_path,
+ *session.posargs,
+ )
@nox.session(python=DEFAULT_PYTHON_VERSION)
diff --git a/packages/google-cloud-retail/renovate.json b/packages/google-cloud-retail/renovate.json
index 4fa949311b20..f08bc22c9a55 100644
--- a/packages/google-cloud-retail/renovate.json
+++ b/packages/google-cloud-retail/renovate.json
@@ -1,5 +1,6 @@
{
"extends": [
"config:base", ":preserveSemverRanges"
- ]
+ ],
+ "ignorePaths": [".pre-commit-config.yaml"]
}
diff --git a/packages/google-cloud-retail/synth.metadata b/packages/google-cloud-retail/synth.metadata
index 04b67db75f06..952757583085 100644
--- a/packages/google-cloud-retail/synth.metadata
+++ b/packages/google-cloud-retail/synth.metadata
@@ -3,21 +3,30 @@
{
"git": {
"name": ".",
- "remote": "sso://devrel/cloud/libraries/python/python-retail"
+ "remote": "https://github.com/googleapis/python-retail.git",
+ "sha": "06dbca264cab76c37876557a6e9954fa1105988f"
+ }
+ },
+ {
+ "git": {
+ "name": "googleapis",
+ "remote": "https://github.com/googleapis/googleapis.git",
+ "sha": "c2bdbfa6f7423369da902a5faaa86bfd213b5169",
+ "internalRef": "364696134"
}
},
{
"git": {
"name": "synthtool",
"remote": "https://github.com/googleapis/synthtool.git",
- "sha": "16ec872dd898d7de6e1822badfac32484b5d9031"
+ "sha": "86ed43d4f56e6404d068e62e497029018879c771"
}
},
{
"git": {
"name": "synthtool",
"remote": "https://github.com/googleapis/synthtool.git",
- "sha": "16ec872dd898d7de6e1822badfac32484b5d9031"
+ "sha": "86ed43d4f56e6404d068e62e497029018879c771"
}
}
],
@@ -31,5 +40,136 @@
"generator": "bazel"
}
}
+ ],
+ "generatedFiles": [
+ ".coveragerc",
+ ".flake8",
+ ".github/CONTRIBUTING.md",
+ ".github/ISSUE_TEMPLATE/bug_report.md",
+ ".github/ISSUE_TEMPLATE/feature_request.md",
+ ".github/ISSUE_TEMPLATE/support_request.md",
+ ".github/PULL_REQUEST_TEMPLATE.md",
+ ".github/header-checker-lint.yml",
+ ".github/release-please.yml",
+ ".github/snippet-bot.yml",
+ ".gitignore",
+ ".kokoro/build.sh",
+ ".kokoro/continuous/common.cfg",
+ ".kokoro/continuous/continuous.cfg",
+ ".kokoro/docker/docs/Dockerfile",
+ ".kokoro/docker/docs/fetch_gpg_keys.sh",
+ ".kokoro/docs/common.cfg",
+ ".kokoro/docs/docs-presubmit.cfg",
+ ".kokoro/docs/docs.cfg",
+ ".kokoro/populate-secrets.sh",
+ ".kokoro/presubmit/common.cfg",
+ ".kokoro/presubmit/presubmit.cfg",
+ ".kokoro/publish-docs.sh",
+ ".kokoro/release.sh",
+ ".kokoro/release/common.cfg",
+ ".kokoro/release/release.cfg",
+ ".kokoro/samples/lint/common.cfg",
+ ".kokoro/samples/lint/continuous.cfg",
+ ".kokoro/samples/lint/periodic.cfg",
+ ".kokoro/samples/lint/presubmit.cfg",
+ ".kokoro/samples/python3.6/common.cfg",
+ ".kokoro/samples/python3.6/continuous.cfg",
+ ".kokoro/samples/python3.6/periodic-head.cfg",
+ ".kokoro/samples/python3.6/periodic.cfg",
+ ".kokoro/samples/python3.6/presubmit.cfg",
+ ".kokoro/samples/python3.7/common.cfg",
+ ".kokoro/samples/python3.7/continuous.cfg",
+ ".kokoro/samples/python3.7/periodic-head.cfg",
+ ".kokoro/samples/python3.7/periodic.cfg",
+ ".kokoro/samples/python3.7/presubmit.cfg",
+ ".kokoro/samples/python3.8/common.cfg",
+ ".kokoro/samples/python3.8/continuous.cfg",
+ ".kokoro/samples/python3.8/periodic-head.cfg",
+ ".kokoro/samples/python3.8/periodic.cfg",
+ ".kokoro/samples/python3.8/presubmit.cfg",
+ ".kokoro/test-samples-against-head.sh",
+ ".kokoro/test-samples-impl.sh",
+ ".kokoro/test-samples.sh",
+ ".kokoro/trampoline.sh",
+ ".kokoro/trampoline_v2.sh",
+ ".pre-commit-config.yaml",
+ ".trampolinerc",
+ "CODE_OF_CONDUCT.md",
+ "CONTRIBUTING.rst",
+ "LICENSE",
+ "MANIFEST.in",
+ "docs/_static/custom.css",
+ "docs/_templates/layout.html",
+ "docs/conf.py",
+ "docs/multiprocessing.rst",
+ "docs/retail_v2/catalog_service.rst",
+ "docs/retail_v2/prediction_service.rst",
+ "docs/retail_v2/product_service.rst",
+ "docs/retail_v2/services.rst",
+ "docs/retail_v2/types.rst",
+ "docs/retail_v2/user_event_service.rst",
+ "google/cloud/retail/__init__.py",
+ "google/cloud/retail/py.typed",
+ "google/cloud/retail_v2/__init__.py",
+ "google/cloud/retail_v2/py.typed",
+ "google/cloud/retail_v2/services/__init__.py",
+ "google/cloud/retail_v2/services/catalog_service/__init__.py",
+ "google/cloud/retail_v2/services/catalog_service/async_client.py",
+ "google/cloud/retail_v2/services/catalog_service/client.py",
+ "google/cloud/retail_v2/services/catalog_service/pagers.py",
+ "google/cloud/retail_v2/services/catalog_service/transports/__init__.py",
+ "google/cloud/retail_v2/services/catalog_service/transports/base.py",
+ "google/cloud/retail_v2/services/catalog_service/transports/grpc.py",
+ "google/cloud/retail_v2/services/catalog_service/transports/grpc_asyncio.py",
+ "google/cloud/retail_v2/services/prediction_service/__init__.py",
+ "google/cloud/retail_v2/services/prediction_service/async_client.py",
+ "google/cloud/retail_v2/services/prediction_service/client.py",
+ "google/cloud/retail_v2/services/prediction_service/transports/__init__.py",
+ "google/cloud/retail_v2/services/prediction_service/transports/base.py",
+ "google/cloud/retail_v2/services/prediction_service/transports/grpc.py",
+ "google/cloud/retail_v2/services/prediction_service/transports/grpc_asyncio.py",
+ "google/cloud/retail_v2/services/product_service/__init__.py",
+ "google/cloud/retail_v2/services/product_service/async_client.py",
+ "google/cloud/retail_v2/services/product_service/client.py",
+ "google/cloud/retail_v2/services/product_service/transports/__init__.py",
+ "google/cloud/retail_v2/services/product_service/transports/base.py",
+ "google/cloud/retail_v2/services/product_service/transports/grpc.py",
+ "google/cloud/retail_v2/services/product_service/transports/grpc_asyncio.py",
+ "google/cloud/retail_v2/services/user_event_service/__init__.py",
+ "google/cloud/retail_v2/services/user_event_service/async_client.py",
+ "google/cloud/retail_v2/services/user_event_service/client.py",
+ "google/cloud/retail_v2/services/user_event_service/transports/__init__.py",
+ "google/cloud/retail_v2/services/user_event_service/transports/base.py",
+ "google/cloud/retail_v2/services/user_event_service/transports/grpc.py",
+ "google/cloud/retail_v2/services/user_event_service/transports/grpc_asyncio.py",
+ "google/cloud/retail_v2/types/__init__.py",
+ "google/cloud/retail_v2/types/catalog.py",
+ "google/cloud/retail_v2/types/catalog_service.py",
+ "google/cloud/retail_v2/types/common.py",
+ "google/cloud/retail_v2/types/import_config.py",
+ "google/cloud/retail_v2/types/prediction_service.py",
+ "google/cloud/retail_v2/types/product.py",
+ "google/cloud/retail_v2/types/product_service.py",
+ "google/cloud/retail_v2/types/purge_config.py",
+ "google/cloud/retail_v2/types/user_event.py",
+ "google/cloud/retail_v2/types/user_event_service.py",
+ "mypy.ini",
+ "noxfile.py",
+ "renovate.json",
+ "scripts/decrypt-secrets.sh",
+ "scripts/fixup_retail_v2_keywords.py",
+ "scripts/readme-gen/readme_gen.py",
+ "scripts/readme-gen/templates/README.tmpl.rst",
+ "scripts/readme-gen/templates/auth.tmpl.rst",
+ "scripts/readme-gen/templates/auth_api_key.tmpl.rst",
+ "scripts/readme-gen/templates/install_deps.tmpl.rst",
+ "scripts/readme-gen/templates/install_portaudio.tmpl.rst",
+ "setup.cfg",
+ "testing/.gitignore",
+ "tests/unit/gapic/retail_v2/__init__.py",
+ "tests/unit/gapic/retail_v2/test_catalog_service.py",
+ "tests/unit/gapic/retail_v2/test_prediction_service.py",
+ "tests/unit/gapic/retail_v2/test_product_service.py",
+ "tests/unit/gapic/retail_v2/test_user_event_service.py"
]
}
\ No newline at end of file
diff --git a/packages/google-cloud-retail/testing/constraints-3.7.txt b/packages/google-cloud-retail/testing/constraints-3.7.txt
index e69de29bb2d1..da93009be5fe 100644
--- a/packages/google-cloud-retail/testing/constraints-3.7.txt
+++ b/packages/google-cloud-retail/testing/constraints-3.7.txt
@@ -0,0 +1,2 @@
+# This constraints file is left inentionally empty
+# so the latest version of dependencies is installed
\ No newline at end of file
diff --git a/packages/google-cloud-retail/testing/constraints-3.8.txt b/packages/google-cloud-retail/testing/constraints-3.8.txt
index e69de29bb2d1..da93009be5fe 100644
--- a/packages/google-cloud-retail/testing/constraints-3.8.txt
+++ b/packages/google-cloud-retail/testing/constraints-3.8.txt
@@ -0,0 +1,2 @@
+# This constraints file is left inentionally empty
+# so the latest version of dependencies is installed
\ No newline at end of file
diff --git a/packages/google-cloud-retail/testing/constraints-3.9.txt b/packages/google-cloud-retail/testing/constraints-3.9.txt
index e69de29bb2d1..da93009be5fe 100644
--- a/packages/google-cloud-retail/testing/constraints-3.9.txt
+++ b/packages/google-cloud-retail/testing/constraints-3.9.txt
@@ -0,0 +1,2 @@
+# This constraints file is left inentionally empty
+# so the latest version of dependencies is installed
\ No newline at end of file
diff --git a/packages/google-cloud-retail/tests/unit/gapic/retail_v2/__init__.py b/packages/google-cloud-retail/tests/unit/gapic/retail_v2/__init__.py
index 8b137891791f..42ffdf2bc43d 100644
--- a/packages/google-cloud-retail/tests/unit/gapic/retail_v2/__init__.py
+++ b/packages/google-cloud-retail/tests/unit/gapic/retail_v2/__init__.py
@@ -1 +1,16 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/packages/google-cloud-retail/tests/unit/gapic/retail_v2/test_catalog_service.py b/packages/google-cloud-retail/tests/unit/gapic/retail_v2/test_catalog_service.py
index 4698c63cbb61..4736ef59535b 100644
--- a/packages/google-cloud-retail/tests/unit/gapic/retail_v2/test_catalog_service.py
+++ b/packages/google-cloud-retail/tests/unit/gapic/retail_v2/test_catalog_service.py
@@ -87,15 +87,19 @@ def test__get_default_mtls_endpoint():
)
-def test_catalog_service_client_from_service_account_info():
+@pytest.mark.parametrize(
+ "client_class", [CatalogServiceClient, CatalogServiceAsyncClient,]
+)
+def test_catalog_service_client_from_service_account_info(client_class):
creds = credentials.AnonymousCredentials()
with mock.patch.object(
service_account.Credentials, "from_service_account_info"
) as factory:
factory.return_value = creds
info = {"valid": True}
- client = CatalogServiceClient.from_service_account_info(info)
+ client = client_class.from_service_account_info(info)
assert client.transport._credentials == creds
+ assert isinstance(client, client_class)
assert client.transport._host == "retail.googleapis.com:443"
@@ -111,9 +115,11 @@ def test_catalog_service_client_from_service_account_file(client_class):
factory.return_value = creds
client = client_class.from_service_account_file("dummy/file/path.json")
assert client.transport._credentials == creds
+ assert isinstance(client, client_class)
client = client_class.from_service_account_json("dummy/file/path.json")
assert client.transport._credentials == creds
+ assert isinstance(client, client_class)
assert client.transport._host == "retail.googleapis.com:443"
@@ -174,7 +180,7 @@ def test_catalog_service_client_client_options(
credentials_file=None,
host="squid.clam.whelk",
scopes=None,
- ssl_channel_credentials=None,
+ client_cert_source_for_mtls=None,
quota_project_id=None,
client_info=transports.base.DEFAULT_CLIENT_INFO,
)
@@ -190,7 +196,7 @@ def test_catalog_service_client_client_options(
credentials_file=None,
host=client.DEFAULT_ENDPOINT,
scopes=None,
- ssl_channel_credentials=None,
+ client_cert_source_for_mtls=None,
quota_project_id=None,
client_info=transports.base.DEFAULT_CLIENT_INFO,
)
@@ -206,7 +212,7 @@ def test_catalog_service_client_client_options(
credentials_file=None,
host=client.DEFAULT_MTLS_ENDPOINT,
scopes=None,
- ssl_channel_credentials=None,
+ client_cert_source_for_mtls=None,
quota_project_id=None,
client_info=transports.base.DEFAULT_CLIENT_INFO,
)
@@ -234,7 +240,7 @@ def test_catalog_service_client_client_options(
credentials_file=None,
host=client.DEFAULT_ENDPOINT,
scopes=None,
- ssl_channel_credentials=None,
+ client_cert_source_for_mtls=None,
quota_project_id="octopus",
client_info=transports.base.DEFAULT_CLIENT_INFO,
)
@@ -285,29 +291,25 @@ def test_catalog_service_client_mtls_env_auto(
client_cert_source=client_cert_source_callback
)
with mock.patch.object(transport_class, "__init__") as patched:
- ssl_channel_creds = mock.Mock()
- with mock.patch(
- "grpc.ssl_channel_credentials", return_value=ssl_channel_creds
- ):
- patched.return_value = None
- client = client_class(client_options=options)
+ patched.return_value = None
+ client = client_class(client_options=options)
- if use_client_cert_env == "false":
- expected_ssl_channel_creds = None
- expected_host = client.DEFAULT_ENDPOINT
- else:
- expected_ssl_channel_creds = ssl_channel_creds
- expected_host = client.DEFAULT_MTLS_ENDPOINT
+ if use_client_cert_env == "false":
+ expected_client_cert_source = None
+ expected_host = client.DEFAULT_ENDPOINT
+ else:
+ expected_client_cert_source = client_cert_source_callback
+ expected_host = client.DEFAULT_MTLS_ENDPOINT
- patched.assert_called_once_with(
- credentials=None,
- credentials_file=None,
- host=expected_host,
- scopes=None,
- ssl_channel_credentials=expected_ssl_channel_creds,
- quota_project_id=None,
- client_info=transports.base.DEFAULT_CLIENT_INFO,
- )
+ patched.assert_called_once_with(
+ credentials=None,
+ credentials_file=None,
+ host=expected_host,
+ scopes=None,
+ client_cert_source_for_mtls=expected_client_cert_source,
+ quota_project_id=None,
+ client_info=transports.base.DEFAULT_CLIENT_INFO,
+ )
# Check the case ADC client cert is provided. Whether client cert is used depends on
# GOOGLE_API_USE_CLIENT_CERTIFICATE value.
@@ -316,66 +318,53 @@ def test_catalog_service_client_mtls_env_auto(
):
with mock.patch.object(transport_class, "__init__") as patched:
with mock.patch(
- "google.auth.transport.grpc.SslCredentials.__init__", return_value=None
+ "google.auth.transport.mtls.has_default_client_cert_source",
+ return_value=True,
):
with mock.patch(
- "google.auth.transport.grpc.SslCredentials.is_mtls",
- new_callable=mock.PropertyMock,
- ) as is_mtls_mock:
- with mock.patch(
- "google.auth.transport.grpc.SslCredentials.ssl_credentials",
- new_callable=mock.PropertyMock,
- ) as ssl_credentials_mock:
- if use_client_cert_env == "false":
- is_mtls_mock.return_value = False
- ssl_credentials_mock.return_value = None
- expected_host = client.DEFAULT_ENDPOINT
- expected_ssl_channel_creds = None
- else:
- is_mtls_mock.return_value = True
- ssl_credentials_mock.return_value = mock.Mock()
- expected_host = client.DEFAULT_MTLS_ENDPOINT
- expected_ssl_channel_creds = (
- ssl_credentials_mock.return_value
- )
-
- patched.return_value = None
- client = client_class()
- patched.assert_called_once_with(
- credentials=None,
- credentials_file=None,
- host=expected_host,
- scopes=None,
- ssl_channel_credentials=expected_ssl_channel_creds,
- quota_project_id=None,
- client_info=transports.base.DEFAULT_CLIENT_INFO,
- )
+ "google.auth.transport.mtls.default_client_cert_source",
+ return_value=client_cert_source_callback,
+ ):
+ if use_client_cert_env == "false":
+ expected_host = client.DEFAULT_ENDPOINT
+ expected_client_cert_source = None
+ else:
+ expected_host = client.DEFAULT_MTLS_ENDPOINT
+ expected_client_cert_source = client_cert_source_callback
- # Check the case client_cert_source and ADC client cert are not provided.
- with mock.patch.dict(
- os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
- ):
- with mock.patch.object(transport_class, "__init__") as patched:
- with mock.patch(
- "google.auth.transport.grpc.SslCredentials.__init__", return_value=None
- ):
- with mock.patch(
- "google.auth.transport.grpc.SslCredentials.is_mtls",
- new_callable=mock.PropertyMock,
- ) as is_mtls_mock:
- is_mtls_mock.return_value = False
patched.return_value = None
client = client_class()
patched.assert_called_once_with(
credentials=None,
credentials_file=None,
- host=client.DEFAULT_ENDPOINT,
+ host=expected_host,
scopes=None,
- ssl_channel_credentials=None,
+ client_cert_source_for_mtls=expected_client_cert_source,
quota_project_id=None,
client_info=transports.base.DEFAULT_CLIENT_INFO,
)
+ # Check the case client_cert_source and ADC client cert are not provided.
+ with mock.patch.dict(
+ os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
+ ):
+ with mock.patch.object(transport_class, "__init__") as patched:
+ with mock.patch(
+ "google.auth.transport.mtls.has_default_client_cert_source",
+ return_value=False,
+ ):
+ patched.return_value = None
+ client = client_class()
+ patched.assert_called_once_with(
+ credentials=None,
+ credentials_file=None,
+ host=client.DEFAULT_ENDPOINT,
+ scopes=None,
+ client_cert_source_for_mtls=None,
+ quota_project_id=None,
+ client_info=transports.base.DEFAULT_CLIENT_INFO,
+ )
+
@pytest.mark.parametrize(
"client_class,transport_class,transport_name",
@@ -401,7 +390,7 @@ def test_catalog_service_client_client_options_scopes(
credentials_file=None,
host=client.DEFAULT_ENDPOINT,
scopes=["1", "2"],
- ssl_channel_credentials=None,
+ client_cert_source_for_mtls=None,
quota_project_id=None,
client_info=transports.base.DEFAULT_CLIENT_INFO,
)
@@ -431,7 +420,7 @@ def test_catalog_service_client_client_options_credentials_file(
credentials_file="credentials.json",
host=client.DEFAULT_ENDPOINT,
scopes=None,
- ssl_channel_credentials=None,
+ client_cert_source_for_mtls=None,
quota_project_id=None,
client_info=transports.base.DEFAULT_CLIENT_INFO,
)
@@ -450,7 +439,7 @@ def test_catalog_service_client_client_options_from_dict():
credentials_file=None,
host="squid.clam.whelk",
scopes=None,
- ssl_channel_credentials=None,
+ client_cert_source_for_mtls=None,
quota_project_id=None,
client_info=transports.base.DEFAULT_CLIENT_INFO,
)
@@ -493,6 +482,22 @@ def test_list_catalogs_from_dict():
test_list_catalogs(request_type=dict)
+def test_list_catalogs_empty_call():
+ # This test is a coverage failsafe to make sure that totally empty calls,
+ # i.e. request == None and no flattened fields passed, work.
+ client = CatalogServiceClient(
+ credentials=credentials.AnonymousCredentials(), transport="grpc",
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.list_catalogs), "__call__") as call:
+ client.list_catalogs()
+ call.assert_called()
+ _, args, _ = call.mock_calls[0]
+
+ assert args[0] == catalog_service.ListCatalogsRequest()
+
+
@pytest.mark.asyncio
async def test_list_catalogs_async(
transport: str = "grpc_asyncio", request_type=catalog_service.ListCatalogsRequest
@@ -812,6 +817,22 @@ def test_update_catalog_from_dict():
test_update_catalog(request_type=dict)
+def test_update_catalog_empty_call():
+ # This test is a coverage failsafe to make sure that totally empty calls,
+ # i.e. request == None and no flattened fields passed, work.
+ client = CatalogServiceClient(
+ credentials=credentials.AnonymousCredentials(), transport="grpc",
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.update_catalog), "__call__") as call:
+ client.update_catalog()
+ call.assert_called()
+ _, args, _ = call.mock_calls[0]
+
+ assert args[0] == catalog_service.UpdateCatalogRequest()
+
+
@pytest.mark.asyncio
async def test_update_catalog_async(
transport: str = "grpc_asyncio", request_type=catalog_service.UpdateCatalogRequest
@@ -1144,6 +1165,51 @@ def test_catalog_service_transport_auth_adc():
)
+@pytest.mark.parametrize(
+ "transport_class",
+ [
+ transports.CatalogServiceGrpcTransport,
+ transports.CatalogServiceGrpcAsyncIOTransport,
+ ],
+)
+def test_catalog_service_grpc_transport_client_cert_source_for_mtls(transport_class):
+ cred = credentials.AnonymousCredentials()
+
+ # Check ssl_channel_credentials is used if provided.
+ with mock.patch.object(transport_class, "create_channel") as mock_create_channel:
+ mock_ssl_channel_creds = mock.Mock()
+ transport_class(
+ host="squid.clam.whelk",
+ credentials=cred,
+ ssl_channel_credentials=mock_ssl_channel_creds,
+ )
+ mock_create_channel.assert_called_once_with(
+ "squid.clam.whelk:443",
+ credentials=cred,
+ credentials_file=None,
+ scopes=("https://www.googleapis.com/auth/cloud-platform",),
+ ssl_credentials=mock_ssl_channel_creds,
+ quota_project_id=None,
+ options=[
+ ("grpc.max_send_message_length", -1),
+ ("grpc.max_receive_message_length", -1),
+ ],
+ )
+
+ # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls
+ # is used.
+ with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()):
+ with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred:
+ transport_class(
+ credentials=cred,
+ client_cert_source_for_mtls=client_cert_source_callback,
+ )
+ expected_cert, expected_key = client_cert_source_callback()
+ mock_ssl_cred.assert_called_once_with(
+ certificate_chain=expected_cert, private_key=expected_key
+ )
+
+
def test_catalog_service_host_no_port():
client = CatalogServiceClient(
credentials=credentials.AnonymousCredentials(),
@@ -1188,6 +1254,8 @@ def test_catalog_service_grpc_asyncio_transport_channel():
assert transport._ssl_channel_credentials == None
+# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are
+# removed from grpc/grpc_asyncio transport constructor.
@pytest.mark.parametrize(
"transport_class",
[
@@ -1240,6 +1308,8 @@ def test_catalog_service_transport_channel_mtls_with_client_cert_source(
assert transport._ssl_channel_credentials == mock_ssl_cred
+# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are
+# removed from grpc/grpc_asyncio transport constructor.
@pytest.mark.parametrize(
"transport_class",
[
diff --git a/packages/google-cloud-retail/tests/unit/gapic/retail_v2/test_prediction_service.py b/packages/google-cloud-retail/tests/unit/gapic/retail_v2/test_prediction_service.py
index fdd420f50cc0..41592b9bf731 100644
--- a/packages/google-cloud-retail/tests/unit/gapic/retail_v2/test_prediction_service.py
+++ b/packages/google-cloud-retail/tests/unit/gapic/retail_v2/test_prediction_service.py
@@ -92,15 +92,19 @@ def test__get_default_mtls_endpoint():
)
-def test_prediction_service_client_from_service_account_info():
+@pytest.mark.parametrize(
+ "client_class", [PredictionServiceClient, PredictionServiceAsyncClient,]
+)
+def test_prediction_service_client_from_service_account_info(client_class):
creds = credentials.AnonymousCredentials()
with mock.patch.object(
service_account.Credentials, "from_service_account_info"
) as factory:
factory.return_value = creds
info = {"valid": True}
- client = PredictionServiceClient.from_service_account_info(info)
+ client = client_class.from_service_account_info(info)
assert client.transport._credentials == creds
+ assert isinstance(client, client_class)
assert client.transport._host == "retail.googleapis.com:443"
@@ -116,9 +120,11 @@ def test_prediction_service_client_from_service_account_file(client_class):
factory.return_value = creds
client = client_class.from_service_account_file("dummy/file/path.json")
assert client.transport._credentials == creds
+ assert isinstance(client, client_class)
client = client_class.from_service_account_json("dummy/file/path.json")
assert client.transport._credentials == creds
+ assert isinstance(client, client_class)
assert client.transport._host == "retail.googleapis.com:443"
@@ -179,7 +185,7 @@ def test_prediction_service_client_client_options(
credentials_file=None,
host="squid.clam.whelk",
scopes=None,
- ssl_channel_credentials=None,
+ client_cert_source_for_mtls=None,
quota_project_id=None,
client_info=transports.base.DEFAULT_CLIENT_INFO,
)
@@ -195,7 +201,7 @@ def test_prediction_service_client_client_options(
credentials_file=None,
host=client.DEFAULT_ENDPOINT,
scopes=None,
- ssl_channel_credentials=None,
+ client_cert_source_for_mtls=None,
quota_project_id=None,
client_info=transports.base.DEFAULT_CLIENT_INFO,
)
@@ -211,7 +217,7 @@ def test_prediction_service_client_client_options(
credentials_file=None,
host=client.DEFAULT_MTLS_ENDPOINT,
scopes=None,
- ssl_channel_credentials=None,
+ client_cert_source_for_mtls=None,
quota_project_id=None,
client_info=transports.base.DEFAULT_CLIENT_INFO,
)
@@ -239,7 +245,7 @@ def test_prediction_service_client_client_options(
credentials_file=None,
host=client.DEFAULT_ENDPOINT,
scopes=None,
- ssl_channel_credentials=None,
+ client_cert_source_for_mtls=None,
quota_project_id="octopus",
client_info=transports.base.DEFAULT_CLIENT_INFO,
)
@@ -300,29 +306,25 @@ def test_prediction_service_client_mtls_env_auto(
client_cert_source=client_cert_source_callback
)
with mock.patch.object(transport_class, "__init__") as patched:
- ssl_channel_creds = mock.Mock()
- with mock.patch(
- "grpc.ssl_channel_credentials", return_value=ssl_channel_creds
- ):
- patched.return_value = None
- client = client_class(client_options=options)
+ patched.return_value = None
+ client = client_class(client_options=options)
- if use_client_cert_env == "false":
- expected_ssl_channel_creds = None
- expected_host = client.DEFAULT_ENDPOINT
- else:
- expected_ssl_channel_creds = ssl_channel_creds
- expected_host = client.DEFAULT_MTLS_ENDPOINT
+ if use_client_cert_env == "false":
+ expected_client_cert_source = None
+ expected_host = client.DEFAULT_ENDPOINT
+ else:
+ expected_client_cert_source = client_cert_source_callback
+ expected_host = client.DEFAULT_MTLS_ENDPOINT
- patched.assert_called_once_with(
- credentials=None,
- credentials_file=None,
- host=expected_host,
- scopes=None,
- ssl_channel_credentials=expected_ssl_channel_creds,
- quota_project_id=None,
- client_info=transports.base.DEFAULT_CLIENT_INFO,
- )
+ patched.assert_called_once_with(
+ credentials=None,
+ credentials_file=None,
+ host=expected_host,
+ scopes=None,
+ client_cert_source_for_mtls=expected_client_cert_source,
+ quota_project_id=None,
+ client_info=transports.base.DEFAULT_CLIENT_INFO,
+ )
# Check the case ADC client cert is provided. Whether client cert is used depends on
# GOOGLE_API_USE_CLIENT_CERTIFICATE value.
@@ -331,66 +333,53 @@ def test_prediction_service_client_mtls_env_auto(
):
with mock.patch.object(transport_class, "__init__") as patched:
with mock.patch(
- "google.auth.transport.grpc.SslCredentials.__init__", return_value=None
+ "google.auth.transport.mtls.has_default_client_cert_source",
+ return_value=True,
):
with mock.patch(
- "google.auth.transport.grpc.SslCredentials.is_mtls",
- new_callable=mock.PropertyMock,
- ) as is_mtls_mock:
- with mock.patch(
- "google.auth.transport.grpc.SslCredentials.ssl_credentials",
- new_callable=mock.PropertyMock,
- ) as ssl_credentials_mock:
- if use_client_cert_env == "false":
- is_mtls_mock.return_value = False
- ssl_credentials_mock.return_value = None
- expected_host = client.DEFAULT_ENDPOINT
- expected_ssl_channel_creds = None
- else:
- is_mtls_mock.return_value = True
- ssl_credentials_mock.return_value = mock.Mock()
- expected_host = client.DEFAULT_MTLS_ENDPOINT
- expected_ssl_channel_creds = (
- ssl_credentials_mock.return_value
- )
-
- patched.return_value = None
- client = client_class()
- patched.assert_called_once_with(
- credentials=None,
- credentials_file=None,
- host=expected_host,
- scopes=None,
- ssl_channel_credentials=expected_ssl_channel_creds,
- quota_project_id=None,
- client_info=transports.base.DEFAULT_CLIENT_INFO,
- )
+ "google.auth.transport.mtls.default_client_cert_source",
+ return_value=client_cert_source_callback,
+ ):
+ if use_client_cert_env == "false":
+ expected_host = client.DEFAULT_ENDPOINT
+ expected_client_cert_source = None
+ else:
+ expected_host = client.DEFAULT_MTLS_ENDPOINT
+ expected_client_cert_source = client_cert_source_callback
- # Check the case client_cert_source and ADC client cert are not provided.
- with mock.patch.dict(
- os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
- ):
- with mock.patch.object(transport_class, "__init__") as patched:
- with mock.patch(
- "google.auth.transport.grpc.SslCredentials.__init__", return_value=None
- ):
- with mock.patch(
- "google.auth.transport.grpc.SslCredentials.is_mtls",
- new_callable=mock.PropertyMock,
- ) as is_mtls_mock:
- is_mtls_mock.return_value = False
patched.return_value = None
client = client_class()
patched.assert_called_once_with(
credentials=None,
credentials_file=None,
- host=client.DEFAULT_ENDPOINT,
+ host=expected_host,
scopes=None,
- ssl_channel_credentials=None,
+ client_cert_source_for_mtls=expected_client_cert_source,
quota_project_id=None,
client_info=transports.base.DEFAULT_CLIENT_INFO,
)
+ # Check the case client_cert_source and ADC client cert are not provided.
+ with mock.patch.dict(
+ os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
+ ):
+ with mock.patch.object(transport_class, "__init__") as patched:
+ with mock.patch(
+ "google.auth.transport.mtls.has_default_client_cert_source",
+ return_value=False,
+ ):
+ patched.return_value = None
+ client = client_class()
+ patched.assert_called_once_with(
+ credentials=None,
+ credentials_file=None,
+ host=client.DEFAULT_ENDPOINT,
+ scopes=None,
+ client_cert_source_for_mtls=None,
+ quota_project_id=None,
+ client_info=transports.base.DEFAULT_CLIENT_INFO,
+ )
+
@pytest.mark.parametrize(
"client_class,transport_class,transport_name",
@@ -416,7 +405,7 @@ def test_prediction_service_client_client_options_scopes(
credentials_file=None,
host=client.DEFAULT_ENDPOINT,
scopes=["1", "2"],
- ssl_channel_credentials=None,
+ client_cert_source_for_mtls=None,
quota_project_id=None,
client_info=transports.base.DEFAULT_CLIENT_INFO,
)
@@ -446,7 +435,7 @@ def test_prediction_service_client_client_options_credentials_file(
credentials_file="credentials.json",
host=client.DEFAULT_ENDPOINT,
scopes=None,
- ssl_channel_credentials=None,
+ client_cert_source_for_mtls=None,
quota_project_id=None,
client_info=transports.base.DEFAULT_CLIENT_INFO,
)
@@ -465,7 +454,7 @@ def test_prediction_service_client_client_options_from_dict():
credentials_file=None,
host="squid.clam.whelk",
scopes=None,
- ssl_channel_credentials=None,
+ client_cert_source_for_mtls=None,
quota_project_id=None,
client_info=transports.base.DEFAULT_CLIENT_INFO,
)
@@ -514,6 +503,22 @@ def test_predict_from_dict():
test_predict(request_type=dict)
+def test_predict_empty_call():
+ # This test is a coverage failsafe to make sure that totally empty calls,
+ # i.e. request == None and no flattened fields passed, work.
+ client = PredictionServiceClient(
+ credentials=credentials.AnonymousCredentials(), transport="grpc",
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.predict), "__call__") as call:
+ client.predict()
+ call.assert_called()
+ _, args, _ = call.mock_calls[0]
+
+ assert args[0] == prediction_service.PredictRequest()
+
+
@pytest.mark.asyncio
async def test_predict_async(
transport: str = "grpc_asyncio", request_type=prediction_service.PredictRequest
@@ -770,6 +775,51 @@ def test_prediction_service_transport_auth_adc():
)
+@pytest.mark.parametrize(
+ "transport_class",
+ [
+ transports.PredictionServiceGrpcTransport,
+ transports.PredictionServiceGrpcAsyncIOTransport,
+ ],
+)
+def test_prediction_service_grpc_transport_client_cert_source_for_mtls(transport_class):
+ cred = credentials.AnonymousCredentials()
+
+ # Check ssl_channel_credentials is used if provided.
+ with mock.patch.object(transport_class, "create_channel") as mock_create_channel:
+ mock_ssl_channel_creds = mock.Mock()
+ transport_class(
+ host="squid.clam.whelk",
+ credentials=cred,
+ ssl_channel_credentials=mock_ssl_channel_creds,
+ )
+ mock_create_channel.assert_called_once_with(
+ "squid.clam.whelk:443",
+ credentials=cred,
+ credentials_file=None,
+ scopes=("https://www.googleapis.com/auth/cloud-platform",),
+ ssl_credentials=mock_ssl_channel_creds,
+ quota_project_id=None,
+ options=[
+ ("grpc.max_send_message_length", -1),
+ ("grpc.max_receive_message_length", -1),
+ ],
+ )
+
+ # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls
+ # is used.
+ with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()):
+ with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred:
+ transport_class(
+ credentials=cred,
+ client_cert_source_for_mtls=client_cert_source_callback,
+ )
+ expected_cert, expected_key = client_cert_source_callback()
+ mock_ssl_cred.assert_called_once_with(
+ certificate_chain=expected_cert, private_key=expected_key
+ )
+
+
def test_prediction_service_host_no_port():
client = PredictionServiceClient(
credentials=credentials.AnonymousCredentials(),
@@ -814,6 +864,8 @@ def test_prediction_service_grpc_asyncio_transport_channel():
assert transport._ssl_channel_credentials == None
+# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are
+# removed from grpc/grpc_asyncio transport constructor.
@pytest.mark.parametrize(
"transport_class",
[
@@ -866,6 +918,8 @@ def test_prediction_service_transport_channel_mtls_with_client_cert_source(
assert transport._ssl_channel_credentials == mock_ssl_cred
+# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are
+# removed from grpc/grpc_asyncio transport constructor.
@pytest.mark.parametrize(
"transport_class",
[
diff --git a/packages/google-cloud-retail/tests/unit/gapic/retail_v2/test_product_service.py b/packages/google-cloud-retail/tests/unit/gapic/retail_v2/test_product_service.py
index 1e754d4d57ac..63b7ba160d44 100644
--- a/packages/google-cloud-retail/tests/unit/gapic/retail_v2/test_product_service.py
+++ b/packages/google-cloud-retail/tests/unit/gapic/retail_v2/test_product_service.py
@@ -94,15 +94,19 @@ def test__get_default_mtls_endpoint():
)
-def test_product_service_client_from_service_account_info():
+@pytest.mark.parametrize(
+ "client_class", [ProductServiceClient, ProductServiceAsyncClient,]
+)
+def test_product_service_client_from_service_account_info(client_class):
creds = credentials.AnonymousCredentials()
with mock.patch.object(
service_account.Credentials, "from_service_account_info"
) as factory:
factory.return_value = creds
info = {"valid": True}
- client = ProductServiceClient.from_service_account_info(info)
+ client = client_class.from_service_account_info(info)
assert client.transport._credentials == creds
+ assert isinstance(client, client_class)
assert client.transport._host == "retail.googleapis.com:443"
@@ -118,9 +122,11 @@ def test_product_service_client_from_service_account_file(client_class):
factory.return_value = creds
client = client_class.from_service_account_file("dummy/file/path.json")
assert client.transport._credentials == creds
+ assert isinstance(client, client_class)
client = client_class.from_service_account_json("dummy/file/path.json")
assert client.transport._credentials == creds
+ assert isinstance(client, client_class)
assert client.transport._host == "retail.googleapis.com:443"
@@ -181,7 +187,7 @@ def test_product_service_client_client_options(
credentials_file=None,
host="squid.clam.whelk",
scopes=None,
- ssl_channel_credentials=None,
+ client_cert_source_for_mtls=None,
quota_project_id=None,
client_info=transports.base.DEFAULT_CLIENT_INFO,
)
@@ -197,7 +203,7 @@ def test_product_service_client_client_options(
credentials_file=None,
host=client.DEFAULT_ENDPOINT,
scopes=None,
- ssl_channel_credentials=None,
+ client_cert_source_for_mtls=None,
quota_project_id=None,
client_info=transports.base.DEFAULT_CLIENT_INFO,
)
@@ -213,7 +219,7 @@ def test_product_service_client_client_options(
credentials_file=None,
host=client.DEFAULT_MTLS_ENDPOINT,
scopes=None,
- ssl_channel_credentials=None,
+ client_cert_source_for_mtls=None,
quota_project_id=None,
client_info=transports.base.DEFAULT_CLIENT_INFO,
)
@@ -241,7 +247,7 @@ def test_product_service_client_client_options(
credentials_file=None,
host=client.DEFAULT_ENDPOINT,
scopes=None,
- ssl_channel_credentials=None,
+ client_cert_source_for_mtls=None,
quota_project_id="octopus",
client_info=transports.base.DEFAULT_CLIENT_INFO,
)
@@ -292,29 +298,25 @@ def test_product_service_client_mtls_env_auto(
client_cert_source=client_cert_source_callback
)
with mock.patch.object(transport_class, "__init__") as patched:
- ssl_channel_creds = mock.Mock()
- with mock.patch(
- "grpc.ssl_channel_credentials", return_value=ssl_channel_creds
- ):
- patched.return_value = None
- client = client_class(client_options=options)
+ patched.return_value = None
+ client = client_class(client_options=options)
- if use_client_cert_env == "false":
- expected_ssl_channel_creds = None
- expected_host = client.DEFAULT_ENDPOINT
- else:
- expected_ssl_channel_creds = ssl_channel_creds
- expected_host = client.DEFAULT_MTLS_ENDPOINT
+ if use_client_cert_env == "false":
+ expected_client_cert_source = None
+ expected_host = client.DEFAULT_ENDPOINT
+ else:
+ expected_client_cert_source = client_cert_source_callback
+ expected_host = client.DEFAULT_MTLS_ENDPOINT
- patched.assert_called_once_with(
- credentials=None,
- credentials_file=None,
- host=expected_host,
- scopes=None,
- ssl_channel_credentials=expected_ssl_channel_creds,
- quota_project_id=None,
- client_info=transports.base.DEFAULT_CLIENT_INFO,
- )
+ patched.assert_called_once_with(
+ credentials=None,
+ credentials_file=None,
+ host=expected_host,
+ scopes=None,
+ client_cert_source_for_mtls=expected_client_cert_source,
+ quota_project_id=None,
+ client_info=transports.base.DEFAULT_CLIENT_INFO,
+ )
# Check the case ADC client cert is provided. Whether client cert is used depends on
# GOOGLE_API_USE_CLIENT_CERTIFICATE value.
@@ -323,66 +325,53 @@ def test_product_service_client_mtls_env_auto(
):
with mock.patch.object(transport_class, "__init__") as patched:
with mock.patch(
- "google.auth.transport.grpc.SslCredentials.__init__", return_value=None
+ "google.auth.transport.mtls.has_default_client_cert_source",
+ return_value=True,
):
with mock.patch(
- "google.auth.transport.grpc.SslCredentials.is_mtls",
- new_callable=mock.PropertyMock,
- ) as is_mtls_mock:
- with mock.patch(
- "google.auth.transport.grpc.SslCredentials.ssl_credentials",
- new_callable=mock.PropertyMock,
- ) as ssl_credentials_mock:
- if use_client_cert_env == "false":
- is_mtls_mock.return_value = False
- ssl_credentials_mock.return_value = None
- expected_host = client.DEFAULT_ENDPOINT
- expected_ssl_channel_creds = None
- else:
- is_mtls_mock.return_value = True
- ssl_credentials_mock.return_value = mock.Mock()
- expected_host = client.DEFAULT_MTLS_ENDPOINT
- expected_ssl_channel_creds = (
- ssl_credentials_mock.return_value
- )
-
- patched.return_value = None
- client = client_class()
- patched.assert_called_once_with(
- credentials=None,
- credentials_file=None,
- host=expected_host,
- scopes=None,
- ssl_channel_credentials=expected_ssl_channel_creds,
- quota_project_id=None,
- client_info=transports.base.DEFAULT_CLIENT_INFO,
- )
+ "google.auth.transport.mtls.default_client_cert_source",
+ return_value=client_cert_source_callback,
+ ):
+ if use_client_cert_env == "false":
+ expected_host = client.DEFAULT_ENDPOINT
+ expected_client_cert_source = None
+ else:
+ expected_host = client.DEFAULT_MTLS_ENDPOINT
+ expected_client_cert_source = client_cert_source_callback
- # Check the case client_cert_source and ADC client cert are not provided.
- with mock.patch.dict(
- os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
- ):
- with mock.patch.object(transport_class, "__init__") as patched:
- with mock.patch(
- "google.auth.transport.grpc.SslCredentials.__init__", return_value=None
- ):
- with mock.patch(
- "google.auth.transport.grpc.SslCredentials.is_mtls",
- new_callable=mock.PropertyMock,
- ) as is_mtls_mock:
- is_mtls_mock.return_value = False
patched.return_value = None
client = client_class()
patched.assert_called_once_with(
credentials=None,
credentials_file=None,
- host=client.DEFAULT_ENDPOINT,
+ host=expected_host,
scopes=None,
- ssl_channel_credentials=None,
+ client_cert_source_for_mtls=expected_client_cert_source,
quota_project_id=None,
client_info=transports.base.DEFAULT_CLIENT_INFO,
)
+ # Check the case client_cert_source and ADC client cert are not provided.
+ with mock.patch.dict(
+ os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
+ ):
+ with mock.patch.object(transport_class, "__init__") as patched:
+ with mock.patch(
+ "google.auth.transport.mtls.has_default_client_cert_source",
+ return_value=False,
+ ):
+ patched.return_value = None
+ client = client_class()
+ patched.assert_called_once_with(
+ credentials=None,
+ credentials_file=None,
+ host=client.DEFAULT_ENDPOINT,
+ scopes=None,
+ client_cert_source_for_mtls=None,
+ quota_project_id=None,
+ client_info=transports.base.DEFAULT_CLIENT_INFO,
+ )
+
@pytest.mark.parametrize(
"client_class,transport_class,transport_name",
@@ -408,7 +397,7 @@ def test_product_service_client_client_options_scopes(
credentials_file=None,
host=client.DEFAULT_ENDPOINT,
scopes=["1", "2"],
- ssl_channel_credentials=None,
+ client_cert_source_for_mtls=None,
quota_project_id=None,
client_info=transports.base.DEFAULT_CLIENT_INFO,
)
@@ -438,7 +427,7 @@ def test_product_service_client_client_options_credentials_file(
credentials_file="credentials.json",
host=client.DEFAULT_ENDPOINT,
scopes=None,
- ssl_channel_credentials=None,
+ client_cert_source_for_mtls=None,
quota_project_id=None,
client_info=transports.base.DEFAULT_CLIENT_INFO,
)
@@ -457,7 +446,7 @@ def test_product_service_client_client_options_from_dict():
credentials_file=None,
host="squid.clam.whelk",
scopes=None,
- ssl_channel_credentials=None,
+ client_cert_source_for_mtls=None,
quota_project_id=None,
client_info=transports.base.DEFAULT_CLIENT_INFO,
)
@@ -527,6 +516,22 @@ def test_create_product_from_dict():
test_create_product(request_type=dict)
+def test_create_product_empty_call():
+ # This test is a coverage failsafe to make sure that totally empty calls,
+ # i.e. request == None and no flattened fields passed, work.
+ client = ProductServiceClient(
+ credentials=credentials.AnonymousCredentials(), transport="grpc",
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.create_product), "__call__") as call:
+ client.create_product()
+ call.assert_called()
+ _, args, _ = call.mock_calls[0]
+
+ assert args[0] == product_service.CreateProductRequest()
+
+
@pytest.mark.asyncio
async def test_create_product_async(
transport: str = "grpc_asyncio", request_type=product_service.CreateProductRequest
@@ -794,6 +799,22 @@ def test_get_product_from_dict():
test_get_product(request_type=dict)
+def test_get_product_empty_call():
+ # This test is a coverage failsafe to make sure that totally empty calls,
+ # i.e. request == None and no flattened fields passed, work.
+ client = ProductServiceClient(
+ credentials=credentials.AnonymousCredentials(), transport="grpc",
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.get_product), "__call__") as call:
+ client.get_product()
+ call.assert_called()
+ _, args, _ = call.mock_calls[0]
+
+ assert args[0] == product_service.GetProductRequest()
+
+
@pytest.mark.asyncio
async def test_get_product_async(
transport: str = "grpc_asyncio", request_type=product_service.GetProductRequest
@@ -1039,6 +1060,22 @@ def test_update_product_from_dict():
test_update_product(request_type=dict)
+def test_update_product_empty_call():
+ # This test is a coverage failsafe to make sure that totally empty calls,
+ # i.e. request == None and no flattened fields passed, work.
+ client = ProductServiceClient(
+ credentials=credentials.AnonymousCredentials(), transport="grpc",
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.update_product), "__call__") as call:
+ client.update_product()
+ call.assert_called()
+ _, args, _ = call.mock_calls[0]
+
+ assert args[0] == product_service.UpdateProductRequest()
+
+
@pytest.mark.asyncio
async def test_update_product_async(
transport: str = "grpc_asyncio", request_type=product_service.UpdateProductRequest
@@ -1270,6 +1307,22 @@ def test_delete_product_from_dict():
test_delete_product(request_type=dict)
+def test_delete_product_empty_call():
+ # This test is a coverage failsafe to make sure that totally empty calls,
+ # i.e. request == None and no flattened fields passed, work.
+ client = ProductServiceClient(
+ credentials=credentials.AnonymousCredentials(), transport="grpc",
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.delete_product), "__call__") as call:
+ client.delete_product()
+ call.assert_called()
+ _, args, _ = call.mock_calls[0]
+
+ assert args[0] == product_service.DeleteProductRequest()
+
+
@pytest.mark.asyncio
async def test_delete_product_async(
transport: str = "grpc_asyncio", request_type=product_service.DeleteProductRequest
@@ -1450,6 +1503,22 @@ def test_import_products_from_dict():
test_import_products(request_type=dict)
+def test_import_products_empty_call():
+ # This test is a coverage failsafe to make sure that totally empty calls,
+ # i.e. request == None and no flattened fields passed, work.
+ client = ProductServiceClient(
+ credentials=credentials.AnonymousCredentials(), transport="grpc",
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.import_products), "__call__") as call:
+ client.import_products()
+ call.assert_called()
+ _, args, _ = call.mock_calls[0]
+
+ assert args[0] == import_config.ImportProductsRequest()
+
+
@pytest.mark.asyncio
async def test_import_products_async(
transport: str = "grpc_asyncio", request_type=import_config.ImportProductsRequest
@@ -1705,6 +1774,51 @@ def test_product_service_transport_auth_adc():
)
+@pytest.mark.parametrize(
+ "transport_class",
+ [
+ transports.ProductServiceGrpcTransport,
+ transports.ProductServiceGrpcAsyncIOTransport,
+ ],
+)
+def test_product_service_grpc_transport_client_cert_source_for_mtls(transport_class):
+ cred = credentials.AnonymousCredentials()
+
+ # Check ssl_channel_credentials is used if provided.
+ with mock.patch.object(transport_class, "create_channel") as mock_create_channel:
+ mock_ssl_channel_creds = mock.Mock()
+ transport_class(
+ host="squid.clam.whelk",
+ credentials=cred,
+ ssl_channel_credentials=mock_ssl_channel_creds,
+ )
+ mock_create_channel.assert_called_once_with(
+ "squid.clam.whelk:443",
+ credentials=cred,
+ credentials_file=None,
+ scopes=("https://www.googleapis.com/auth/cloud-platform",),
+ ssl_credentials=mock_ssl_channel_creds,
+ quota_project_id=None,
+ options=[
+ ("grpc.max_send_message_length", -1),
+ ("grpc.max_receive_message_length", -1),
+ ],
+ )
+
+ # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls
+ # is used.
+ with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()):
+ with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred:
+ transport_class(
+ credentials=cred,
+ client_cert_source_for_mtls=client_cert_source_callback,
+ )
+ expected_cert, expected_key = client_cert_source_callback()
+ mock_ssl_cred.assert_called_once_with(
+ certificate_chain=expected_cert, private_key=expected_key
+ )
+
+
def test_product_service_host_no_port():
client = ProductServiceClient(
credentials=credentials.AnonymousCredentials(),
@@ -1749,6 +1863,8 @@ def test_product_service_grpc_asyncio_transport_channel():
assert transport._ssl_channel_credentials == None
+# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are
+# removed from grpc/grpc_asyncio transport constructor.
@pytest.mark.parametrize(
"transport_class",
[
@@ -1801,6 +1917,8 @@ def test_product_service_transport_channel_mtls_with_client_cert_source(
assert transport._ssl_channel_credentials == mock_ssl_cred
+# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are
+# removed from grpc/grpc_asyncio transport constructor.
@pytest.mark.parametrize(
"transport_class",
[
diff --git a/packages/google-cloud-retail/tests/unit/gapic/retail_v2/test_user_event_service.py b/packages/google-cloud-retail/tests/unit/gapic/retail_v2/test_user_event_service.py
index 8a1417898dcb..4e17ea283e87 100644
--- a/packages/google-cloud-retail/tests/unit/gapic/retail_v2/test_user_event_service.py
+++ b/packages/google-cloud-retail/tests/unit/gapic/retail_v2/test_user_event_service.py
@@ -99,15 +99,19 @@ def test__get_default_mtls_endpoint():
)
-def test_user_event_service_client_from_service_account_info():
+@pytest.mark.parametrize(
+ "client_class", [UserEventServiceClient, UserEventServiceAsyncClient,]
+)
+def test_user_event_service_client_from_service_account_info(client_class):
creds = credentials.AnonymousCredentials()
with mock.patch.object(
service_account.Credentials, "from_service_account_info"
) as factory:
factory.return_value = creds
info = {"valid": True}
- client = UserEventServiceClient.from_service_account_info(info)
+ client = client_class.from_service_account_info(info)
assert client.transport._credentials == creds
+ assert isinstance(client, client_class)
assert client.transport._host == "retail.googleapis.com:443"
@@ -123,9 +127,11 @@ def test_user_event_service_client_from_service_account_file(client_class):
factory.return_value = creds
client = client_class.from_service_account_file("dummy/file/path.json")
assert client.transport._credentials == creds
+ assert isinstance(client, client_class)
client = client_class.from_service_account_json("dummy/file/path.json")
assert client.transport._credentials == creds
+ assert isinstance(client, client_class)
assert client.transport._host == "retail.googleapis.com:443"
@@ -186,7 +192,7 @@ def test_user_event_service_client_client_options(
credentials_file=None,
host="squid.clam.whelk",
scopes=None,
- ssl_channel_credentials=None,
+ client_cert_source_for_mtls=None,
quota_project_id=None,
client_info=transports.base.DEFAULT_CLIENT_INFO,
)
@@ -202,7 +208,7 @@ def test_user_event_service_client_client_options(
credentials_file=None,
host=client.DEFAULT_ENDPOINT,
scopes=None,
- ssl_channel_credentials=None,
+ client_cert_source_for_mtls=None,
quota_project_id=None,
client_info=transports.base.DEFAULT_CLIENT_INFO,
)
@@ -218,7 +224,7 @@ def test_user_event_service_client_client_options(
credentials_file=None,
host=client.DEFAULT_MTLS_ENDPOINT,
scopes=None,
- ssl_channel_credentials=None,
+ client_cert_source_for_mtls=None,
quota_project_id=None,
client_info=transports.base.DEFAULT_CLIENT_INFO,
)
@@ -246,7 +252,7 @@ def test_user_event_service_client_client_options(
credentials_file=None,
host=client.DEFAULT_ENDPOINT,
scopes=None,
- ssl_channel_credentials=None,
+ client_cert_source_for_mtls=None,
quota_project_id="octopus",
client_info=transports.base.DEFAULT_CLIENT_INFO,
)
@@ -307,29 +313,25 @@ def test_user_event_service_client_mtls_env_auto(
client_cert_source=client_cert_source_callback
)
with mock.patch.object(transport_class, "__init__") as patched:
- ssl_channel_creds = mock.Mock()
- with mock.patch(
- "grpc.ssl_channel_credentials", return_value=ssl_channel_creds
- ):
- patched.return_value = None
- client = client_class(client_options=options)
+ patched.return_value = None
+ client = client_class(client_options=options)
- if use_client_cert_env == "false":
- expected_ssl_channel_creds = None
- expected_host = client.DEFAULT_ENDPOINT
- else:
- expected_ssl_channel_creds = ssl_channel_creds
- expected_host = client.DEFAULT_MTLS_ENDPOINT
+ if use_client_cert_env == "false":
+ expected_client_cert_source = None
+ expected_host = client.DEFAULT_ENDPOINT
+ else:
+ expected_client_cert_source = client_cert_source_callback
+ expected_host = client.DEFAULT_MTLS_ENDPOINT
- patched.assert_called_once_with(
- credentials=None,
- credentials_file=None,
- host=expected_host,
- scopes=None,
- ssl_channel_credentials=expected_ssl_channel_creds,
- quota_project_id=None,
- client_info=transports.base.DEFAULT_CLIENT_INFO,
- )
+ patched.assert_called_once_with(
+ credentials=None,
+ credentials_file=None,
+ host=expected_host,
+ scopes=None,
+ client_cert_source_for_mtls=expected_client_cert_source,
+ quota_project_id=None,
+ client_info=transports.base.DEFAULT_CLIENT_INFO,
+ )
# Check the case ADC client cert is provided. Whether client cert is used depends on
# GOOGLE_API_USE_CLIENT_CERTIFICATE value.
@@ -338,66 +340,53 @@ def test_user_event_service_client_mtls_env_auto(
):
with mock.patch.object(transport_class, "__init__") as patched:
with mock.patch(
- "google.auth.transport.grpc.SslCredentials.__init__", return_value=None
+ "google.auth.transport.mtls.has_default_client_cert_source",
+ return_value=True,
):
with mock.patch(
- "google.auth.transport.grpc.SslCredentials.is_mtls",
- new_callable=mock.PropertyMock,
- ) as is_mtls_mock:
- with mock.patch(
- "google.auth.transport.grpc.SslCredentials.ssl_credentials",
- new_callable=mock.PropertyMock,
- ) as ssl_credentials_mock:
- if use_client_cert_env == "false":
- is_mtls_mock.return_value = False
- ssl_credentials_mock.return_value = None
- expected_host = client.DEFAULT_ENDPOINT
- expected_ssl_channel_creds = None
- else:
- is_mtls_mock.return_value = True
- ssl_credentials_mock.return_value = mock.Mock()
- expected_host = client.DEFAULT_MTLS_ENDPOINT
- expected_ssl_channel_creds = (
- ssl_credentials_mock.return_value
- )
-
- patched.return_value = None
- client = client_class()
- patched.assert_called_once_with(
- credentials=None,
- credentials_file=None,
- host=expected_host,
- scopes=None,
- ssl_channel_credentials=expected_ssl_channel_creds,
- quota_project_id=None,
- client_info=transports.base.DEFAULT_CLIENT_INFO,
- )
+ "google.auth.transport.mtls.default_client_cert_source",
+ return_value=client_cert_source_callback,
+ ):
+ if use_client_cert_env == "false":
+ expected_host = client.DEFAULT_ENDPOINT
+ expected_client_cert_source = None
+ else:
+ expected_host = client.DEFAULT_MTLS_ENDPOINT
+ expected_client_cert_source = client_cert_source_callback
- # Check the case client_cert_source and ADC client cert are not provided.
- with mock.patch.dict(
- os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
- ):
- with mock.patch.object(transport_class, "__init__") as patched:
- with mock.patch(
- "google.auth.transport.grpc.SslCredentials.__init__", return_value=None
- ):
- with mock.patch(
- "google.auth.transport.grpc.SslCredentials.is_mtls",
- new_callable=mock.PropertyMock,
- ) as is_mtls_mock:
- is_mtls_mock.return_value = False
patched.return_value = None
client = client_class()
patched.assert_called_once_with(
credentials=None,
credentials_file=None,
- host=client.DEFAULT_ENDPOINT,
+ host=expected_host,
scopes=None,
- ssl_channel_credentials=None,
+ client_cert_source_for_mtls=expected_client_cert_source,
quota_project_id=None,
client_info=transports.base.DEFAULT_CLIENT_INFO,
)
+ # Check the case client_cert_source and ADC client cert are not provided.
+ with mock.patch.dict(
+ os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
+ ):
+ with mock.patch.object(transport_class, "__init__") as patched:
+ with mock.patch(
+ "google.auth.transport.mtls.has_default_client_cert_source",
+ return_value=False,
+ ):
+ patched.return_value = None
+ client = client_class()
+ patched.assert_called_once_with(
+ credentials=None,
+ credentials_file=None,
+ host=client.DEFAULT_ENDPOINT,
+ scopes=None,
+ client_cert_source_for_mtls=None,
+ quota_project_id=None,
+ client_info=transports.base.DEFAULT_CLIENT_INFO,
+ )
+
@pytest.mark.parametrize(
"client_class,transport_class,transport_name",
@@ -423,7 +412,7 @@ def test_user_event_service_client_client_options_scopes(
credentials_file=None,
host=client.DEFAULT_ENDPOINT,
scopes=["1", "2"],
- ssl_channel_credentials=None,
+ client_cert_source_for_mtls=None,
quota_project_id=None,
client_info=transports.base.DEFAULT_CLIENT_INFO,
)
@@ -453,7 +442,7 @@ def test_user_event_service_client_client_options_credentials_file(
credentials_file="credentials.json",
host=client.DEFAULT_ENDPOINT,
scopes=None,
- ssl_channel_credentials=None,
+ client_cert_source_for_mtls=None,
quota_project_id=None,
client_info=transports.base.DEFAULT_CLIENT_INFO,
)
@@ -472,7 +461,7 @@ def test_user_event_service_client_client_options_from_dict():
credentials_file=None,
host="squid.clam.whelk",
scopes=None,
- ssl_channel_credentials=None,
+ client_cert_source_for_mtls=None,
quota_project_id=None,
client_info=transports.base.DEFAULT_CLIENT_INFO,
)
@@ -542,6 +531,22 @@ def test_write_user_event_from_dict():
test_write_user_event(request_type=dict)
+def test_write_user_event_empty_call():
+ # This test is a coverage failsafe to make sure that totally empty calls,
+ # i.e. request == None and no flattened fields passed, work.
+ client = UserEventServiceClient(
+ credentials=credentials.AnonymousCredentials(), transport="grpc",
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.write_user_event), "__call__") as call:
+ client.write_user_event()
+ call.assert_called()
+ _, args, _ = call.mock_calls[0]
+
+ assert args[0] == user_event_service.WriteUserEventRequest()
+
+
@pytest.mark.asyncio
async def test_write_user_event_async(
transport: str = "grpc_asyncio",
@@ -704,6 +709,24 @@ def test_collect_user_event_from_dict():
test_collect_user_event(request_type=dict)
+def test_collect_user_event_empty_call():
+ # This test is a coverage failsafe to make sure that totally empty calls,
+ # i.e. request == None and no flattened fields passed, work.
+ client = UserEventServiceClient(
+ credentials=credentials.AnonymousCredentials(), transport="grpc",
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.collect_user_event), "__call__"
+ ) as call:
+ client.collect_user_event()
+ call.assert_called()
+ _, args, _ = call.mock_calls[0]
+
+ assert args[0] == user_event_service.CollectUserEventRequest()
+
+
@pytest.mark.asyncio
async def test_collect_user_event_async(
transport: str = "grpc_asyncio",
@@ -836,6 +859,24 @@ def test_purge_user_events_from_dict():
test_purge_user_events(request_type=dict)
+def test_purge_user_events_empty_call():
+ # This test is a coverage failsafe to make sure that totally empty calls,
+ # i.e. request == None and no flattened fields passed, work.
+ client = UserEventServiceClient(
+ credentials=credentials.AnonymousCredentials(), transport="grpc",
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.purge_user_events), "__call__"
+ ) as call:
+ client.purge_user_events()
+ call.assert_called()
+ _, args, _ = call.mock_calls[0]
+
+ assert args[0] == purge_config.PurgeUserEventsRequest()
+
+
@pytest.mark.asyncio
async def test_purge_user_events_async(
transport: str = "grpc_asyncio", request_type=purge_config.PurgeUserEventsRequest
@@ -965,6 +1006,24 @@ def test_import_user_events_from_dict():
test_import_user_events(request_type=dict)
+def test_import_user_events_empty_call():
+ # This test is a coverage failsafe to make sure that totally empty calls,
+ # i.e. request == None and no flattened fields passed, work.
+ client = UserEventServiceClient(
+ credentials=credentials.AnonymousCredentials(), transport="grpc",
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.import_user_events), "__call__"
+ ) as call:
+ client.import_user_events()
+ call.assert_called()
+ _, args, _ = call.mock_calls[0]
+
+ assert args[0] == import_config.ImportUserEventsRequest()
+
+
@pytest.mark.asyncio
async def test_import_user_events_async(
transport: str = "grpc_asyncio", request_type=import_config.ImportUserEventsRequest
@@ -1094,6 +1153,24 @@ def test_rejoin_user_events_from_dict():
test_rejoin_user_events(request_type=dict)
+def test_rejoin_user_events_empty_call():
+ # This test is a coverage failsafe to make sure that totally empty calls,
+ # i.e. request == None and no flattened fields passed, work.
+ client = UserEventServiceClient(
+ credentials=credentials.AnonymousCredentials(), transport="grpc",
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.rejoin_user_events), "__call__"
+ ) as call:
+ client.rejoin_user_events()
+ call.assert_called()
+ _, args, _ = call.mock_calls[0]
+
+ assert args[0] == user_event_service.RejoinUserEventsRequest()
+
+
@pytest.mark.asyncio
async def test_rejoin_user_events_async(
transport: str = "grpc_asyncio",
@@ -1358,6 +1435,51 @@ def test_user_event_service_transport_auth_adc():
)
+@pytest.mark.parametrize(
+ "transport_class",
+ [
+ transports.UserEventServiceGrpcTransport,
+ transports.UserEventServiceGrpcAsyncIOTransport,
+ ],
+)
+def test_user_event_service_grpc_transport_client_cert_source_for_mtls(transport_class):
+ cred = credentials.AnonymousCredentials()
+
+ # Check ssl_channel_credentials is used if provided.
+ with mock.patch.object(transport_class, "create_channel") as mock_create_channel:
+ mock_ssl_channel_creds = mock.Mock()
+ transport_class(
+ host="squid.clam.whelk",
+ credentials=cred,
+ ssl_channel_credentials=mock_ssl_channel_creds,
+ )
+ mock_create_channel.assert_called_once_with(
+ "squid.clam.whelk:443",
+ credentials=cred,
+ credentials_file=None,
+ scopes=("https://www.googleapis.com/auth/cloud-platform",),
+ ssl_credentials=mock_ssl_channel_creds,
+ quota_project_id=None,
+ options=[
+ ("grpc.max_send_message_length", -1),
+ ("grpc.max_receive_message_length", -1),
+ ],
+ )
+
+ # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls
+ # is used.
+ with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()):
+ with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred:
+ transport_class(
+ credentials=cred,
+ client_cert_source_for_mtls=client_cert_source_callback,
+ )
+ expected_cert, expected_key = client_cert_source_callback()
+ mock_ssl_cred.assert_called_once_with(
+ certificate_chain=expected_cert, private_key=expected_key
+ )
+
+
def test_user_event_service_host_no_port():
client = UserEventServiceClient(
credentials=credentials.AnonymousCredentials(),
@@ -1402,6 +1524,8 @@ def test_user_event_service_grpc_asyncio_transport_channel():
assert transport._ssl_channel_credentials == None
+# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are
+# removed from grpc/grpc_asyncio transport constructor.
@pytest.mark.parametrize(
"transport_class",
[
@@ -1454,6 +1578,8 @@ def test_user_event_service_transport_channel_mtls_with_client_cert_source(
assert transport._ssl_channel_credentials == mock_ssl_cred
+# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are
+# removed from grpc/grpc_asyncio transport constructor.
@pytest.mark.parametrize(
"transport_class",
[