diff --git a/component_sdk/python/kfp_component/google/__init__.py b/component_sdk/python/kfp_component/google/__init__.py index e8a8d80fe37..6906d32256b 100644 --- a/component_sdk/python/kfp_component/google/__init__.py +++ b/component_sdk/python/kfp_component/google/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from . import ml_engine, dataflow \ No newline at end of file +from . import ml_engine, dataflow diff --git a/component_sdk/python/kfp_component/google/ml_engine/__init__.py b/component_sdk/python/kfp_component/google/ml_engine/__init__.py index 7075ef99d18..4ebbcc8854b 100644 --- a/component_sdk/python/kfp_component/google/ml_engine/__init__.py +++ b/component_sdk/python/kfp_component/google/ml_engine/__init__.py @@ -27,4 +27,6 @@ from ._create_version import create_version from ._delete_version import delete_version from ._train import train -from ._batch_predict import batch_predict \ No newline at end of file +from ._batch_predict import batch_predict +from ._deploy import deploy +from ._set_default_version import set_default_version \ No newline at end of file diff --git a/component_sdk/python/kfp_component/google/ml_engine/_client.py b/component_sdk/python/kfp_component/google/ml_engine/_client.py index f40f1e2aa94..b9e80ac523e 100644 --- a/component_sdk/python/kfp_component/google/ml_engine/_client.py +++ b/component_sdk/python/kfp_component/google/ml_engine/_client.py @@ -81,25 +81,22 @@ def create_model(self, project_id, model): body = model ).execute() - def get_model(self, project_id, model_name): + def get_model(self, model_name): """Gets a model. Args: - project_id: the ID of the parent project. model_name: the name of the model. Returns: The retrieved model. """ return self._ml_client.projects().models().get( - name = 'projects/{}/models/{}'.format( - project_id, model_name) + name = model_name ).execute() - def create_version(self, project_id, model_name, version): + def create_version(self, model_name, version): """Creates a new version. Args: - project_id: the ID of the parent project. model_name: the name of the parent model. version: the payload of the version. @@ -107,16 +104,14 @@ def create_version(self, project_id, model_name, version): The created version. """ return self._ml_client.projects().models().versions().create( - parent = 'projects/{}/models/{}'.format(project_id, model_name), + parent = model_name, body = version ).execute() - def get_version(self, project_id, model_name, version_name): + def get_version(self, version_name): """Gets a version. Args: - project_id: the ID of the parent project. - model_name: the name of the parent model. version_name: the name of the version. Returns: @@ -124,20 +119,17 @@ def get_version(self, project_id, model_name, version_name): """ try: return self._ml_client.projects().models().versions().get( - name = 'projects/{}/models/{}/versions/{}'.format( - project_id, model_name, version_name) + name = version_name ).execute() except errors.HttpError as e: if e.resp.status == 404: return None raise - def delete_version(self, project_id, model_name, version_name): + def delete_version(self, version_name): """Deletes a version. Args: - project_id: the ID of the parent project. - model_name: the name of the parent model. version_name: the name of the version. Returns: @@ -145,8 +137,7 @@ def delete_version(self, project_id, model_name, version_name): """ try: return self._ml_client.projects().models().versions().delete( - name = 'projects/{}/models/{}/versions/{}'.format( - project_id, model_name, version_name) + name = version_name ).execute() except errors.HttpError as e: if e.resp.status == 404: @@ -154,6 +145,11 @@ def delete_version(self, project_id, model_name, version_name): return None raise + def set_default_version(self, version_name): + return self._ml_client.projects().models().versions().setDefault( + name = version_name + ).execute() + def get_operation(self, operation_name): """Gets an operation. diff --git a/component_sdk/python/kfp_component/google/ml_engine/_common_ops.py b/component_sdk/python/kfp_component/google/ml_engine/_common_ops.py index 23b6008e999..99d314c735f 100644 --- a/component_sdk/python/kfp_component/google/ml_engine/_common_ops.py +++ b/component_sdk/python/kfp_component/google/ml_engine/_common_ops.py @@ -17,11 +17,9 @@ from googleapiclient import errors -def wait_existing_version(ml_client, project_id, model_name, - version_name, wait_interval): +def wait_existing_version(ml_client, version_name, wait_interval): while True: - existing_version = ml_client.get_version( - project_id, model_name, version_name) + existing_version = ml_client.get_version(version_name) if not existing_version: return None state = existing_version.get('state', None) diff --git a/component_sdk/python/kfp_component/google/ml_engine/_create_job.py b/component_sdk/python/kfp_component/google/ml_engine/_create_job.py index c682e108967..73c26ee2d92 100644 --- a/component_sdk/python/kfp_component/google/ml_engine/_create_job.py +++ b/component_sdk/python/kfp_component/google/ml_engine/_create_job.py @@ -127,5 +127,5 @@ def _dump_metadata(self): def _dump_job(self, job): logging.info('Dumping job: {}'.format(job)) - gcp_common.dump_file('/tmp/outputs/output.txt', json.dumps(job)) - gcp_common.dump_file('/tmp/outputs/job_id.txt', job['jobId']) + gcp_common.dump_file('/tmp/kfp/output/ml_engine/job.json', json.dumps(job)) + gcp_common.dump_file('/tmp/kfp/output/ml_engine/job_id.txt', job['jobId']) diff --git a/component_sdk/python/kfp_component/google/ml_engine/_create_model.py b/component_sdk/python/kfp_component/google/ml_engine/_create_model.py index 11504069372..12b14117001 100644 --- a/component_sdk/python/kfp_component/google/ml_engine/_create_model.py +++ b/component_sdk/python/kfp_component/google/ml_engine/_create_model.py @@ -37,7 +37,8 @@ class CreateModelOp: def __init__(self, project_id, name, model): self._ml = MLEngineClient() self._project_id = project_id - self._model_name = name + self._model_short_name = name + self._model_name = None if model: self._model = model else: @@ -53,8 +54,7 @@ def execute(self): model = self._model) except errors.HttpError as e: if e.resp.status == 409: - existing_model = self._ml.get_model( - self._project_id, self._model_name) + existing_model = self._ml.get_model(self._model_name) if not self._is_dup_model(existing_model): raise logging.info('The same model {} has been submitted' @@ -67,9 +67,11 @@ def execute(self): return created_model def _set_model_name(self, context_id): - if not self._model_name: - self._model_name = 'model_' + context_id - self._model['name'] = gcp_common.normalize_name(self._model_name) + if not self._model_short_name: + self._model_short_name = 'model_' + context_id + self._model['name'] = gcp_common.normalize_name(self._model_short_name) + self._model_name = 'projects/{}/models/{}'.format( + self._project_id, self._model_short_name) def _is_dup_model(self, existing_model): @@ -82,11 +84,11 @@ def _is_dup_model(self, existing_model): def _dump_metadata(self): display.display(display.Link( 'https://console.cloud.google.com/mlengine/models/{}?project={}'.format( - self._model_name, self._project_id), + self._model_short_name, self._project_id), 'Model Details' )) def _dump_model(self, model): logging.info('Dumping model: {}'.format(model)) - gcp_common.dump_file('/tmp/outputs/output.txt', json.dumps(model)) - gcp_common.dump_file('/tmp/outputs/model_name.txt', self._model_name) \ No newline at end of file + gcp_common.dump_file('/tmp/kfp/output/ml_engine/model.json', json.dumps(model)) + gcp_common.dump_file('/tmp/kfp/output/ml_engine/model_name.txt', self._model_name) \ No newline at end of file diff --git a/component_sdk/python/kfp_component/google/ml_engine/_create_version.py b/component_sdk/python/kfp_component/google/ml_engine/_create_version.py index fdea66c56ec..e70e2e39982 100644 --- a/component_sdk/python/kfp_component/google/ml_engine/_create_version.py +++ b/component_sdk/python/kfp_component/google/ml_engine/_create_version.py @@ -15,6 +15,7 @@ import json import logging import time +import re from googleapiclient import errors from fire import decorators @@ -25,13 +26,12 @@ from ._common_ops import wait_existing_version, wait_for_operation_done @decorators.SetParseFns(python_version=str, runtime_version=str) -def create_version(project_id, model_name, deployemnt_uri=None, version_name=None, +def create_version(model_name, deployemnt_uri=None, version_name=None, runtime_version=None, python_version=None, version=None, replace_existing=False, wait_interval=30): """Creates a MLEngine version and wait for the operation to be done. Args: - project_id (str): required, the ID of the parent project. model_name (str): required, the name of the parent model. deployment_uri (str): optional, the Google Cloud Storage location of the trained model used to create the version. @@ -60,16 +60,17 @@ def create_version(project_id, model_name, deployemnt_uri=None, version_name=Non if python_version: version['pythonVersion'] = python_version - return CreateVersionOp(project_id, model_name, version, + return CreateVersionOp(model_name, version, replace_existing, wait_interval).execute_and_wait() class CreateVersionOp: - def __init__(self, project_id, model_name, version, + def __init__(self, model_name, version, replace_existing, wait_interval): self._ml = MLEngineClient() - self._project_id = project_id - self._model_name = gcp_common.normalize_name(model_name) + self._model_name = model_name + self._project_id, self._model_short_name = self._parse_model_name(model_name) self._version_name = None + self._version_short_name = None self._version = version self._replace_existing = replace_existing self._wait_interval = wait_interval @@ -81,7 +82,7 @@ def execute_and_wait(self): self._set_version_name(ctx.context_id()) self._dump_metadata() existing_version = wait_existing_version(self._ml, - self._project_id, self._model_name, self._version_name, + self._version_name, self._wait_interval) if existing_version and self._is_dup_version(existing_version): return self._handle_completed_version(existing_version) @@ -95,15 +96,21 @@ def execute_and_wait(self): created_version = self._create_version_and_wait() return self._handle_completed_version(created_version) + + def _parse_model_name(self, model_name): + match = re.search(r'^projects/([^/]+)/models/([^/]+)$', model_name) + if not match: + raise ValueError('model name "{}" is not in desired format.'.format(model_name)) + return (match.group(1), match.group(2)) def _set_version_name(self, context_id): - version_name = self._version.get('name', None) - if not version_name: - version_name = 'ver_' + context_id - version_name = gcp_common.normalize_name(version_name) - self._version_name = version_name - self._version['name'] = version_name - + name = self._version.get('name', None) + if not name: + name = 'ver_' + context_id + name = gcp_common.normalize_name(name) + self._version_short_name = name + self._version['name'] = name + self._version_name = '{}/versions/{}'.format(self._model_name, name) def _cancel(self): if self._delete_operation_name: @@ -113,8 +120,7 @@ def _cancel(self): self._ml.cancel_operation(self._create_operation_name) def _create_version_and_wait(self): - operation = self._ml.create_version(self._project_id, - self._model_name, self._version) + operation = self._ml.create_version(self._model_name, self._version) # Cache operation name for cancellation. self._create_operation_name = operation.get('name') try: @@ -128,8 +134,7 @@ def _create_version_and_wait(self): return operation.get('response', None) def _delete_version_and_wait(self): - operation = self._ml.delete_version( - self._project_id, self._model_name, self._version_name) + operation = self._ml.delete_version(self._version_name) # Cache operation name for cancellation. self._delete_operation_name = operation.get('name') try: @@ -147,20 +152,22 @@ def _handle_completed_version(self, version): error_message = version.get('errorMessage', 'Unknown failure') raise RuntimeError('Version is in failed state: {}'.format( error_message)) + # Workaround issue that CMLE doesn't return the full version name. + version['name'] = self._version_name self._dump_version(version) return version def _dump_metadata(self): display.display(display.Link( 'https://console.cloud.google.com/mlengine/models/{}/versions/{}?project={}'.format( - self._model_name, self._version_name, self._project_id), + self._model_short_name, self._version_short_name, self._project_id), 'Version Details' )) def _dump_version(self, version): logging.info('Dumping version: {}'.format(version)) - gcp_common.dump_file('/tmp/outputs/output.txt', json.dumps(version)) - gcp_common.dump_file('/tmp/outputs/version_name.txt', version['name']) + gcp_common.dump_file('/tmp/kfp/output/ml_engine/version.json', json.dumps(version)) + gcp_common.dump_file('/tmp/kfp/output/ml_engine/version_name.txt', version['name']) def _is_dup_version(self, existing_version): return not gcp_common.check_resource_changed( diff --git a/component_sdk/python/kfp_component/google/ml_engine/_delete_version.py b/component_sdk/python/kfp_component/google/ml_engine/_delete_version.py index 4bc68e2205f..acc1e784a8e 100644 --- a/component_sdk/python/kfp_component/google/ml_engine/_delete_version.py +++ b/component_sdk/python/kfp_component/google/ml_engine/_delete_version.py @@ -22,39 +22,33 @@ from .. import common as gcp_common from ._common_ops import wait_existing_version, wait_for_operation_done -def delete_version(project_id, model_name, version_name, wait_interval=30): +def delete_version(version_name, wait_interval=30): """Deletes a MLEngine version and wait. Args: - project_id (str): required, the ID of the parent project. - model_name (str): required, the name of the parent model. version_name (str): required, the name of the version. wait_interval (int): the interval to wait for a long running operation. """ - DeleteVersionOp(project_id, model_name, version_name, - wait_interval).execute_and_wait() + DeleteVersionOp(version_name, wait_interval).execute_and_wait() class DeleteVersionOp: - def __init__(self, project_id, model_name, version_name, wait_interval): + def __init__(self, version_name, wait_interval): self._ml = MLEngineClient() - self._project_id = project_id - self._model_name = gcp_common.normalize_name(model_name) - self._version_name = gcp_common.normalize_name(version_name) + self._version_name = version_name self._wait_interval = wait_interval self._delete_operation_name = None def execute_and_wait(self): with KfpExecutionContext(on_cancel=self._cancel): existing_version = wait_existing_version(self._ml, - self._project_id, self._model_name, self._version_name, + self._version_name, self._wait_interval) if not existing_version: logging.info('The version has already been deleted.') return None logging.info('Deleting existing version...') - operation = self._ml.delete_version( - self._project_id, self._model_name, self._version_name) + operation = self._ml.delete_version(self._version_name) # Cache operation name for cancellation. self._delete_operation_name = operation.get('name') try: diff --git a/component_sdk/python/kfp_component/google/ml_engine/_deploy.py b/component_sdk/python/kfp_component/google/ml_engine/_deploy.py new file mode 100644 index 00000000000..167da126928 --- /dev/null +++ b/component_sdk/python/kfp_component/google/ml_engine/_deploy.py @@ -0,0 +1,79 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +from fire import decorators + +from google.cloud import storage +from .. import common as gcp_common +from ..storage import parse_blob_path +from ._create_model import create_model +from ._create_version import create_version +from ._set_default_version import set_default_version + +@decorators.SetParseFns(python_version=str, runtime_version=str) +def deploy(model_uri, project_id, model_name=None, version_name=None, + runtime_version=None, python_version=None, version=None, + replace_existing_version=False, set_default=False, wait_interval=30): + """Deploy a model to MLEngine from GCS URI + + Args: + model_uri (str): required, the GCS URI which contains a model file. + Common used TF model search path (export/exporter) will be used + if exist. + project_id (str): required, the ID of the parent project. + model_name (str): optional, the name of the parent model. + version_name (str): optional, the name of the version. If it is not + provided, the operation uses a random name. + runtime_version (str): optinal, the Cloud ML Engine runtime version + to use for this deployment. If not set, Cloud ML Engine uses + the default stable version, 1.0. + python_version (str): optinal, the version of Python used in prediction. + If not set, the default version is '2.7'. Python '3.5' is available + when runtimeVersion is set to '1.4' and above. Python '2.7' works + with all supported runtime versions. + version (str): optional, the payload of the new version. + replace_existing_version (boolean): boolean flag indicates whether to replace + existing version in case of conflict. + set_default (boolean): boolean flag indicates whether to set the new + version as default version in the model. + wait_interval (int): the interval to wait for a long running operation. + """ + model_uri = _search_tf_model_common_exporter_dir(model_uri) + gcp_common.dump_file('/tmp/kfp/output/ml_engine/model_uri.txt', + model_uri) + model = create_model(project_id, model_name) + model_name = model.get('name') + version = create_version(model_name, model_uri, version_name, + runtime_version, python_version, version, replace_existing_version, + wait_interval) + if set_default: + version_name = version.get('name') + version = set_default_version(version_name) + return version + +def _search_tf_model_common_exporter_dir(model_uri): + bucket_name, blob_name = parse_blob_path(model_uri) + storage_client = storage.Client() + bucket = storage_client.get_bucket(bucket_name) + exporter_path = os.path.join(blob_name, 'export/exporter/') + iterator = bucket.list_blobs(prefix=exporter_path, delimiter='/') + for _ in iterator.pages: + # Iterate to the last page + pass + if iterator.prefixes: + prefixes = list(iterator.prefixes) + prefixes.sort(reverse=True) + return 'gs://{}/{}'.format(bucket_name, prefixes[0]) + return model_uri diff --git a/component_sdk/python/kfp_component/google/ml_engine/_set_default_version.py b/component_sdk/python/kfp_component/google/ml_engine/_set_default_version.py new file mode 100644 index 00000000000..eb3c1da98f3 --- /dev/null +++ b/component_sdk/python/kfp_component/google/ml_engine/_set_default_version.py @@ -0,0 +1,18 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._client import MLEngineClient + +def set_default_version(version_name): + return MLEngineClient().set_default_version(version_name) \ No newline at end of file diff --git a/component_sdk/python/tests/google/ml_engine/test__create_version.py b/component_sdk/python/tests/google/ml_engine/test__create_version.py index 053fb69194e..2cadc994500 100644 --- a/component_sdk/python/tests/google/ml_engine/test__create_version.py +++ b/component_sdk/python/tests/google/ml_engine/test__create_version.py @@ -37,7 +37,7 @@ def test_create_version_succeed(self, mock_mlengine_client, 'response': version } - result = create_version('mock_project', 'mock_model', + result = create_version('projects/mock_project/models/mock_model', deployemnt_uri = 'gs://test-location', version_name = 'mock_version', version = version, replace_existing = True) @@ -64,7 +64,7 @@ def test_create_version_fail(self, mock_mlengine_client, } with self.assertRaises(RuntimeError) as context: - create_version('mock_project', 'mock_model', + create_version('projects/mock_project/models/mock_model', version = version, replace_existing = True, wait_interval = 30) self.assertEqual( @@ -89,7 +89,7 @@ def test_create_version_dup_version_succeed(self, mock_mlengine_client, mock_mlengine_client().get_version.side_effect = [ pending_version, ready_version] - result = create_version('mock_project', 'mock_model', version = version, + result = create_version('projects/mock_project/models/mock_model', version = version, replace_existing = True, wait_interval = 0) self.assertEqual(ready_version, result) @@ -114,7 +114,7 @@ def test_create_version_failed_state(self, mock_mlengine_client, pending_version, failed_version] with self.assertRaises(RuntimeError) as context: - create_version('mock_project', 'mock_model', version = version, + create_version('projects/mock_project/models/mock_model', version = version, replace_existing = True, wait_interval = 0) self.assertEqual( @@ -148,7 +148,7 @@ def test_create_version_conflict_version_replace_succeed(self, mock_mlengine_cli create_operation ] - result = create_version('mock_project', 'mock_model', version = version, + result = create_version('projects/mock_project/models/mock_model', version = version, replace_existing = True, wait_interval = 0) self.assertEqual(version, result) @@ -180,7 +180,7 @@ def test_create_version_conflict_version_delete_fail(self, mock_mlengine_client, mock_mlengine_client().get_operation.return_value = delete_operation with self.assertRaises(RuntimeError) as context: - create_version('mock_project', 'mock_model', version = version, + create_version('projects/mock_project/models/mock_model', version = version, replace_existing = True, wait_interval = 0) self.assertEqual( @@ -203,7 +203,7 @@ def test_create_version_conflict_version_fail(self, mock_mlengine_client, mock_mlengine_client().get_version.return_value = conflicting_version with self.assertRaises(RuntimeError) as context: - create_version('mock_project', 'mock_model', version = version, + create_version('projects/mock_project/models/mock_model', version = version, replace_existing = False, wait_interval = 0) self.assertEqual( diff --git a/component_sdk/python/tests/google/ml_engine/test__delete_version.py b/component_sdk/python/tests/google/ml_engine/test__delete_version.py index 5976ef3e448..6f4a7c93f29 100644 --- a/component_sdk/python/tests/google/ml_engine/test__delete_version.py +++ b/component_sdk/python/tests/google/ml_engine/test__delete_version.py @@ -34,7 +34,7 @@ def test_execute_succeed(self, mock_mlengine_client, 'done': True } - delete_version('mock_project', 'mock_model', 'mock_version', + delete_version('projects/mock_project/models/mock_model/versions/mock_version', wait_interval = 30) mock_mlengine_client().delete_version.assert_called_once() @@ -46,7 +46,7 @@ def test_execute_retry_succeed(self, mock_mlengine_client, } mock_mlengine_client().get_version.side_effect = [pending_version, None] - delete_version('mock_project', 'mock_model', 'mock_version', + delete_version('projects/mock_project/models/mock_model/versions/mock_version', wait_interval = 0) self.assertEqual(2, mock_mlengine_client().get_version.call_count) \ No newline at end of file diff --git a/component_sdk/python/tests/google/ml_engine/test__deploy.py b/component_sdk/python/tests/google/ml_engine/test__deploy.py new file mode 100644 index 00000000000..dfb6cae46fb --- /dev/null +++ b/component_sdk/python/tests/google/ml_engine/test__deploy.py @@ -0,0 +1,95 @@ +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import mock +import unittest + +from googleapiclient import errors +from kfp_component.google.ml_engine import deploy + +MODULE = 'kfp_component.google.ml_engine._deploy' + +@mock.patch(MODULE + '.storage.Client') +@mock.patch(MODULE + '.create_model') +@mock.patch(MODULE + '.create_version') +@mock.patch(MODULE + '.set_default_version') +class TestDeploy(unittest.TestCase): + + def test_deploy_default_path(self, mock_set_default_version, mock_create_version, + mock_create_model, mock_storage_client): + + mock_storage_client().get_bucket().list_blobs().prefixes = [] + mock_create_model.return_value = { + 'name': 'projects/mock-project/models/mock-model' + } + expected_version = { + 'name': 'projects/mock-project/models/mock-model/version/mock-version' + } + mock_create_version.return_value = expected_version + + result = deploy('gs://model/uri', 'mock-project') + + self.assertEqual(expected_version, result) + mock_create_version.assert_called_with( + 'projects/mock-project/models/mock-model', + 'gs://model/uri', + None, # version_name + None, # runtime_version + None, # python_version + None, # version + False, # replace_existing_version + 30) + + def test_deploy_tf_exporter_path(self, mock_set_default_version, mock_create_version, + mock_create_model, mock_storage_client): + + mock_storage_client().get_bucket().list_blobs().prefixes = [ + 'uri/export/exporter/123' + ] + mock_create_model.return_value = { + 'name': 'projects/mock-project/models/mock-model' + } + expected_version = { + 'name': 'projects/mock-project/models/mock-model/version/mock-version' + } + mock_create_version.return_value = expected_version + + result = deploy('gs://model/uri', 'mock-project') + + self.assertEqual(expected_version, result) + mock_create_version.assert_called_with( + 'projects/mock-project/models/mock-model', + 'gs://model/uri/export/exporter/123', + None, # version_name + None, # runtime_version + None, # python_version + None, # version + False, # replace_existing_version + 30) + + def test_deploy_set_default_version(self, mock_set_default_version, mock_create_version, + mock_create_model, mock_storage_client): + + mock_storage_client().get_bucket().list_blobs().prefixes = [] + mock_create_model.return_value = { + 'name': 'projects/mock-project/models/mock-model' + } + expected_version = { + 'name': 'projects/mock-project/models/mock-model/version/mock-version' + } + mock_create_version.return_value = expected_version + mock_set_default_version.return_value = expected_version + + result = deploy('gs://model/uri', 'mock-project', set_default=True) + + self.assertEqual(expected_version, result) + mock_set_default_version.assert_called_with( + 'projects/mock-project/models/mock-model/version/mock-version') \ No newline at end of file