Skip to content

Commit

Permalink
Refactor MLEngine code and add deploy and set_default commands
Browse files Browse the repository at this point in the history
  • Loading branch information
hongye-sun committed Feb 27, 2019
1 parent ea26a57 commit c82024c
Show file tree
Hide file tree
Showing 13 changed files with 267 additions and 76 deletions.
2 changes: 1 addition & 1 deletion component_sdk/python/kfp_component/google/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from . import ml_engine, dataflow
from . import ml_engine, dataflow
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,6 @@
from ._create_version import create_version
from ._delete_version import delete_version
from ._train import train
from ._batch_predict import batch_predict
from ._batch_predict import batch_predict
from ._deploy import deploy
from ._set_default_version import set_default_version
30 changes: 13 additions & 17 deletions component_sdk/python/kfp_component/google/ml_engine/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,79 +81,75 @@ def create_model(self, project_id, model):
body = model
).execute()

def get_model(self, project_id, model_name):
def get_model(self, model_name):
"""Gets a model.
Args:
project_id: the ID of the parent project.
model_name: the name of the model.
Returns:
The retrieved model.
"""
return self._ml_client.projects().models().get(
name = 'projects/{}/models/{}'.format(
project_id, model_name)
name = model_name
).execute()

def create_version(self, project_id, model_name, version):
def create_version(self, model_name, version):
"""Creates a new version.
Args:
project_id: the ID of the parent project.
model_name: the name of the parent model.
version: the payload of the version.
Returns:
The created version.
"""
return self._ml_client.projects().models().versions().create(
parent = 'projects/{}/models/{}'.format(project_id, model_name),
parent = model_name,
body = version
).execute()

def get_version(self, project_id, model_name, version_name):
def get_version(self, version_name):
"""Gets a version.
Args:
project_id: the ID of the parent project.
model_name: the name of the parent model.
version_name: the name of the version.
Returns:
The retrieved version. None if the version is not found.
"""
try:
return self._ml_client.projects().models().versions().get(
name = 'projects/{}/models/{}/versions/{}'.format(
project_id, model_name, version_name)
name = version_name
).execute()
except errors.HttpError as e:
if e.resp.status == 404:
return None
raise

def delete_version(self, project_id, model_name, version_name):
def delete_version(self, version_name):
"""Deletes a version.
Args:
project_id: the ID of the parent project.
model_name: the name of the parent model.
version_name: the name of the version.
Returns:
The delete operation. None if the version is not found.
"""
try:
return self._ml_client.projects().models().versions().delete(
name = 'projects/{}/models/{}/versions/{}'.format(
project_id, model_name, version_name)
name = version_name
).execute()
except errors.HttpError as e:
if e.resp.status == 404:
logging.info('The version has already been deleted.')
return None
raise

def set_default_version(self, version_name):
return self._ml_client.projects().models().versions().setDefault(
name = version_name
).execute()

def get_operation(self, operation_name):
"""Gets an operation.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,9 @@

from googleapiclient import errors

def wait_existing_version(ml_client, project_id, model_name,
version_name, wait_interval):
def wait_existing_version(ml_client, version_name, wait_interval):
while True:
existing_version = ml_client.get_version(
project_id, model_name, version_name)
existing_version = ml_client.get_version(version_name)
if not existing_version:
return None
state = existing_version.get('state', None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,5 +127,5 @@ def _dump_metadata(self):

def _dump_job(self, job):
logging.info('Dumping job: {}'.format(job))
gcp_common.dump_file('/tmp/outputs/output.txt', json.dumps(job))
gcp_common.dump_file('/tmp/outputs/job_id.txt', job['jobId'])
gcp_common.dump_file('/tmp/kfp/output/ml_engine/job.json', json.dumps(job))
gcp_common.dump_file('/tmp/kfp/output/ml_engine/job_id.txt', job['jobId'])
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ class CreateModelOp:
def __init__(self, project_id, name, model):
self._ml = MLEngineClient()
self._project_id = project_id
self._model_name = name
self._model_short_name = name
self._model_name = None
if model:
self._model = model
else:
Expand All @@ -53,8 +54,7 @@ def execute(self):
model = self._model)
except errors.HttpError as e:
if e.resp.status == 409:
existing_model = self._ml.get_model(
self._project_id, self._model_name)
existing_model = self._ml.get_model(self._model_name)
if not self._is_dup_model(existing_model):
raise
logging.info('The same model {} has been submitted'
Expand All @@ -67,9 +67,11 @@ def execute(self):
return created_model

def _set_model_name(self, context_id):
if not self._model_name:
self._model_name = 'model_' + context_id
self._model['name'] = gcp_common.normalize_name(self._model_name)
if not self._model_short_name:
self._model_short_name = 'model_' + context_id
self._model['name'] = gcp_common.normalize_name(self._model_short_name)
self._model_name = 'projects/{}/models/{}'.format(
self._project_id, self._model_short_name)


def _is_dup_model(self, existing_model):
Expand All @@ -82,11 +84,11 @@ def _is_dup_model(self, existing_model):
def _dump_metadata(self):
display.display(display.Link(
'https://console.cloud.google.com/mlengine/models/{}?project={}'.format(
self._model_name, self._project_id),
self._model_short_name, self._project_id),
'Model Details'
))

def _dump_model(self, model):
logging.info('Dumping model: {}'.format(model))
gcp_common.dump_file('/tmp/outputs/output.txt', json.dumps(model))
gcp_common.dump_file('/tmp/outputs/model_name.txt', self._model_name)
gcp_common.dump_file('/tmp/kfp/output/ml_engine/model.json', json.dumps(model))
gcp_common.dump_file('/tmp/kfp/output/ml_engine/model_name.txt', self._model_name)
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import json
import logging
import time
import re

from googleapiclient import errors
from fire import decorators
Expand All @@ -25,13 +26,12 @@
from ._common_ops import wait_existing_version, wait_for_operation_done

@decorators.SetParseFns(python_version=str, runtime_version=str)
def create_version(project_id, model_name, deployemnt_uri=None, version_name=None,
def create_version(model_name, deployemnt_uri=None, version_name=None,
runtime_version=None, python_version=None, version=None,
replace_existing=False, wait_interval=30):
"""Creates a MLEngine version and wait for the operation to be done.
Args:
project_id (str): required, the ID of the parent project.
model_name (str): required, the name of the parent model.
deployment_uri (str): optional, the Google Cloud Storage location of
the trained model used to create the version.
Expand Down Expand Up @@ -60,16 +60,17 @@ def create_version(project_id, model_name, deployemnt_uri=None, version_name=Non
if python_version:
version['pythonVersion'] = python_version

return CreateVersionOp(project_id, model_name, version,
return CreateVersionOp(model_name, version,
replace_existing, wait_interval).execute_and_wait()

class CreateVersionOp:
def __init__(self, project_id, model_name, version,
def __init__(self, model_name, version,
replace_existing, wait_interval):
self._ml = MLEngineClient()
self._project_id = project_id
self._model_name = gcp_common.normalize_name(model_name)
self._model_name = model_name
self._project_id, self._model_short_name = self._parse_model_name(model_name)
self._version_name = None
self._version_short_name = None
self._version = version
self._replace_existing = replace_existing
self._wait_interval = wait_interval
Expand All @@ -81,7 +82,7 @@ def execute_and_wait(self):
self._set_version_name(ctx.context_id())
self._dump_metadata()
existing_version = wait_existing_version(self._ml,
self._project_id, self._model_name, self._version_name,
self._version_name,
self._wait_interval)
if existing_version and self._is_dup_version(existing_version):
return self._handle_completed_version(existing_version)
Expand All @@ -95,15 +96,21 @@ def execute_and_wait(self):

created_version = self._create_version_and_wait()
return self._handle_completed_version(created_version)

def _parse_model_name(self, model_name):
match = re.search(r'^projects/([^/]+)/models/([^/]+)$', model_name)
if not match:
raise ValueError('model name "{}" is not in desired format.'.format(model_name))
return (match.group(1), match.group(2))

def _set_version_name(self, context_id):
version_name = self._version.get('name', None)
if not version_name:
version_name = 'ver_' + context_id
version_name = gcp_common.normalize_name(version_name)
self._version_name = version_name
self._version['name'] = version_name

name = self._version.get('name', None)
if not name:
name = 'ver_' + context_id
name = gcp_common.normalize_name(name)
self._version_short_name = name
self._version['name'] = name
self._version_name = '{}/versions/{}'.format(self._model_name, name)

def _cancel(self):
if self._delete_operation_name:
Expand All @@ -113,8 +120,7 @@ def _cancel(self):
self._ml.cancel_operation(self._create_operation_name)

def _create_version_and_wait(self):
operation = self._ml.create_version(self._project_id,
self._model_name, self._version)
operation = self._ml.create_version(self._model_name, self._version)
# Cache operation name for cancellation.
self._create_operation_name = operation.get('name')
try:
Expand All @@ -128,8 +134,7 @@ def _create_version_and_wait(self):
return operation.get('response', None)

def _delete_version_and_wait(self):
operation = self._ml.delete_version(
self._project_id, self._model_name, self._version_name)
operation = self._ml.delete_version(self._version_name)
# Cache operation name for cancellation.
self._delete_operation_name = operation.get('name')
try:
Expand All @@ -147,20 +152,22 @@ def _handle_completed_version(self, version):
error_message = version.get('errorMessage', 'Unknown failure')
raise RuntimeError('Version is in failed state: {}'.format(
error_message))
# Workaround issue that CMLE doesn't return the full version name.
version['name'] = self._version_name
self._dump_version(version)
return version

def _dump_metadata(self):
display.display(display.Link(
'https://console.cloud.google.com/mlengine/models/{}/versions/{}?project={}'.format(
self._model_name, self._version_name, self._project_id),
self._model_short_name, self._version_short_name, self._project_id),
'Version Details'
))

def _dump_version(self, version):
logging.info('Dumping version: {}'.format(version))
gcp_common.dump_file('/tmp/outputs/output.txt', json.dumps(version))
gcp_common.dump_file('/tmp/outputs/version_name.txt', version['name'])
gcp_common.dump_file('/tmp/kfp/output/ml_engine/version.json', json.dumps(version))
gcp_common.dump_file('/tmp/kfp/output/ml_engine/version_name.txt', version['name'])

def _is_dup_version(self, existing_version):
return not gcp_common.check_resource_changed(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,39 +22,33 @@
from .. import common as gcp_common
from ._common_ops import wait_existing_version, wait_for_operation_done

def delete_version(project_id, model_name, version_name, wait_interval=30):
def delete_version(version_name, wait_interval=30):
"""Deletes a MLEngine version and wait.
Args:
project_id (str): required, the ID of the parent project.
model_name (str): required, the name of the parent model.
version_name (str): required, the name of the version.
wait_interval (int): the interval to wait for a long running operation.
"""
DeleteVersionOp(project_id, model_name, version_name,
wait_interval).execute_and_wait()
DeleteVersionOp(version_name, wait_interval).execute_and_wait()

class DeleteVersionOp:
def __init__(self, project_id, model_name, version_name, wait_interval):
def __init__(self, version_name, wait_interval):
self._ml = MLEngineClient()
self._project_id = project_id
self._model_name = gcp_common.normalize_name(model_name)
self._version_name = gcp_common.normalize_name(version_name)
self._version_name = version_name
self._wait_interval = wait_interval
self._delete_operation_name = None

def execute_and_wait(self):
with KfpExecutionContext(on_cancel=self._cancel):
existing_version = wait_existing_version(self._ml,
self._project_id, self._model_name, self._version_name,
self._version_name,
self._wait_interval)
if not existing_version:
logging.info('The version has already been deleted.')
return None

logging.info('Deleting existing version...')
operation = self._ml.delete_version(
self._project_id, self._model_name, self._version_name)
operation = self._ml.delete_version(self._version_name)
# Cache operation name for cancellation.
self._delete_operation_name = operation.get('name')
try:
Expand Down
Loading

0 comments on commit c82024c

Please sign in to comment.