Skip to content
This repository was archived by the owner on Sep 3, 2022. It is now read-only.

Move cloud trainer and predictor from their own classes to Job and Model respectively. #192

Merged
merged 5 commits into from
Feb 14, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datalab/mlalpha/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from ._tensorboard import TensorBoard
from ._dataset import CsvDataSet, BigQueryDataSet
from ._package import Packager
from ._cloud_models import CloudModels, CloudModelVersions
from ._cloud_models import Models, ModelVersions
from ._confusion_matrix import ConfusionMatrix
from ._analysis import csv_to_dataframe
from ._package_runner import PackageRunner
Expand Down
101 changes: 60 additions & 41 deletions datalab/mlalpha/_cloud_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,35 +14,27 @@

from googleapiclient import discovery
import os
import time
import yaml

import datalab.context
import datalab.storage
import datalab.utils

from . import _util

# TODO(qimingj) Remove once the API is public since it will no longer be needed
_CLOUDML_DISCOVERY_URL = 'https://storage.googleapis.com/cloud-ml/discovery/' \
'ml_v1beta1_discovery.json'


class CloudModels(object):
class Models(object):
"""Represents a list of Cloud ML models for a project."""

def __init__(self, project_id=None):
"""Initializes an instance of a CloudML Model list that is iteratable
("for model in CloudModels()").

"""
Args:
project_id: project_id of the models. If not provided, default project_id will be used.
"""
if project_id is None:
project_id = datalab.context.Context.default().project_id
self._project_id = project_id
self._credentials = datalab.context.Context.default().credentials
self._api = discovery.build('ml', 'v1alpha3', credentials=self._credentials,
discoveryServiceUrl=_CLOUDML_DISCOVERY_URL)
self._api = discovery.build('ml', 'v1', credentials=self._credentials)

def _retrieve_models(self, page_token, page_size):
list_info = self._api.projects().models().list(
Expand All @@ -51,11 +43,13 @@ def _retrieve_models(self, page_token, page_size):
page_token = list_info.get('nextPageToken', None)
return models, page_token

def __iter__(self):
def get_iterator(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why the name change to get_iterator()? Also do the same change to the Jobs class.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done on the Jobs class. The reason for the change is that having a class both expose methods and expose itself as an iterator is a bit odd, and the iterator function is less discoverable.

"""Get iterator of models so it can be used as "for model in Models().get_iterator()".
"""
return iter(datalab.utils.Iterator(self._retrieve_models))

def get_model_details(self, model_name):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

very generic name. Expand comment: what details are returned?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

side question: why does create not return anything, but all the other functions return the execute() result?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a bit in comments. create() is instant (it does not return a long running operation) so there is no need to return anything.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now I think it is better to return whatever service returns, so it is consistent across the methods.

"""Get details of a model.
"""Get details of the specified model from CloudML Service.

Args:
model_name: the name of the model. It can be a model full name
Expand All @@ -72,10 +66,16 @@ def create(self, model_name):

Args:
model_name: the short name of the model, such as "iris".
Returns:
If successful, returns information about the model, such as
{u'regions': [u'us-central1'], u'name': u'projects/myproject/models/mymodel'}
Raises:
If the model creation failed.
"""
body = {'name': model_name}
parent = 'projects/' + self._project_id
self._api.projects().models().create(body=body, parent=parent).execute()
# Model creation is instant. If anything goes wrong, Exception will be thrown.
return self._api.projects().models().create(body=body, parent=parent).execute()

def delete(self, model_name):
"""Delete a model.
Expand All @@ -87,7 +87,10 @@ def delete(self, model_name):
full_name = model_name
if not model_name.startswith('projects/'):
full_name = ('projects/%s/models/%s' % (self._project_id, model_name))
return self._api.projects().models().delete(name=full_name).execute()
response = self._api.projects().models().delete(name=full_name).execute()
if 'name' not in response:
raise Exception('Invalid response from service. "name" is not found.')
_util.wait_for_long_running_operation(response['name'])

def list(self, count=10):
"""List models under the current project in a table view.
Expand Down Expand Up @@ -121,13 +124,11 @@ def describe(self, model_name):
print model_yaml


class CloudModelVersions(object):
class ModelVersions(object):
"""Represents a list of versions for a Cloud ML model."""

def __init__(self, model_name, project_id=None):
"""Initializes an instance of a CloudML model version list that is iteratable
("for version in CloudModelVersions()").

"""
Args:
model_name: the name of the model. It can be a model full name
("projects/[project_id]/models/[model_name]") or just [model_name].
Expand All @@ -137,8 +138,7 @@ def __init__(self, model_name, project_id=None):
if project_id is None:
self._project_id = datalab.context.Context.default().project_id
self._credentials = datalab.context.Context.default().credentials
self._api = discovery.build('ml', 'v1alpha3', credentials=self._credentials,
discoveryServiceUrl=_CLOUDML_DISCOVERY_URL)
self._api = discovery.build('ml', 'v1', credentials=self._credentials)
if not model_name.startswith('projects/'):
model_name = ('projects/%s/models/%s' % (self._project_id, model_name))
self._full_model_name = model_name
Expand All @@ -152,7 +152,10 @@ def _retrieve_versions(self, page_token, page_size):
page_token = list_info.get('nextPageToken', None)
return versions, page_token

def __iter__(self):
def get_iterator(self):
"""Get iterator of versions so it can be used as
"for v in ModelVersions(model_name).get_iterator()".
"""
return iter(datalab.utils.Iterator(self._retrieve_versions))

def get_version_details(self, version_name):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_version_details does not return anything?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does. "return self._api.projects().models().versions().get(name=name).execute()"? It returns what the REST API call returns.

Expand All @@ -165,21 +168,6 @@ def get_version_details(self, version_name):
name = ('%s/versions/%s' % (self._full_model_name, version_name))
return self._api.projects().models().versions().get(name=name).execute()

def _wait_for_long_running_operation(self, response):
if 'name' not in response:
raise Exception('Invaid response from service. Cannot find "name" field.')
print('Waiting for job "%s"' % response['name'])
while True:
response = self._api.projects().operations().get(name=response['name']).execute()
if 'done' not in response or response['done'] != True:
time.sleep(3)
else:
if 'error' in response:
print(response['error'])
else:
print('Done.')
break

def deploy(self, version_name, path):
"""Deploy a model version to the cloud.

Expand Down Expand Up @@ -211,7 +199,9 @@ def deploy(self, version_name, path):
}
response = self._api.projects().models().versions().create(body=body,
parent=self._full_model_name).execute()
self._wait_for_long_running_operation(response)
if 'name' not in response:
raise Exception('Invalid response from service. "name" is not found.')
_util.wait_for_long_running_operation(response['name'])

def delete(self, version_name):
"""Delete a version of model.
Expand All @@ -221,8 +211,37 @@ def delete(self, version_name):
"""
name = ('%s/versions/%s' % (self._full_model_name, version_name))
response = self._api.projects().models().versions().delete(name=name).execute()
self._wait_for_long_running_operation(response)

if 'name' not in response:
raise Exception('Invalid response from service. "name" is not found.')
_util.wait_for_long_running_operation(response['name'])

def predict(self, version_name, data):
  """Request online predictions from one version of this model.

  Args:
    version_name: the short name of the model version to query. It is appended to the
        model's full name as ".../versions/[version_name]".
    data: the instances to predict on, usually a list. It is sent unchanged as the
        "instances" field of the request body. What each instance must look like depends
        on the model; a structured data model may, for example, expect one csv line per
        instance. Online prediction only works on models that take a single placeholder
        value, such as a string encoding a csv line.
  Returns:
    The "predictions" field of the service response: a list with one entry per submitted
    instance, each a dictionary of the graph's output mapping. For example:
      [{"predictions": 1, "score": [0.00078, 0.71406, 0.28515]},
       {"predictions": 1, "score": [0.00244, 0.99634, 0.00121]}]
  Raises:
    Exception: if the service response contains no "predictions" field.
  """
  target = '%s/versions/%s' % (self._full_model_name, version_name)
  payload = {'instances': data}
  prediction_request = self._api.projects().predict(body=payload, name=target)
  # Tag the call so the service can attribute it to Datalab.
  prediction_request.headers['user-agent'] = 'GoogleCloudDataLab/1.0'
  reply = prediction_request.execute()
  if 'predictions' in reply:
    return reply['predictions']

  raise Exception('Invalid response from service. Cannot find "predictions" in response.')

def describe(self, version_name):
"""Print information of a specified model.

Expand Down
87 changes: 64 additions & 23 deletions datalab/mlalpha/_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,39 +18,22 @@
from googleapiclient import discovery
import yaml

# TODO(qimingj) Remove once the API is public since it will no longer be needed
_CLOUDML_DISCOVERY_URL = 'https://storage.googleapis.com/cloud-ml/discovery/' \
'ml_v1beta1_discovery.json'

import datalab.utils
import datalab.context
from googleapiclient import discovery
import yaml

# TODO(qimingj) Remove once the API is public since it will no longer be needed
_CLOUDML_DISCOVERY_URL = 'https://storage.googleapis.com/cloud-ml/discovery/' \
'ml_v1beta1_discovery.json'


class Job(object):
"""Represents a Cloud ML job."""

def __init__(self, name, context=None, api=None):
def __init__(self, name, context=None):
"""Initializes an instance of a CloudML Job.

Args:
name: the name of the job. It can be an operation full name
("projects/[project_id]/operations/[operation_name]") or just [operation_name].
("projects/[project_id]/jobs/[operation_name]") or just [operation_name].
context: an optional Context object providing project_id and credentials.
api: optional CloudML API client.
"""
if context is None:
context = datalab.context.Context.default()
self._context = context
if api is None:
api = discovery.build('ml', 'v1beta1', credentials=self._context.credentials,
discoveryServiceUrl=_CLOUDML_DISCOVERY_URL)
self._api = api
self._api = discovery.build('ml', 'v1', credentials=self._context.credentials)
if not name.startswith('projects/'):
name = 'projects/' + self._context.project_id + '/jobs/' + name
self._name = name
Expand All @@ -68,6 +51,63 @@ def describe(self):
job_yaml = yaml.safe_dump(self._info, default_flow_style=False)
print job_yaml

@staticmethod
def submit_training(job_request, job_id=None):
"""Submit a training job.
Copy link
Contributor

@brandondutra brandondutra Feb 14, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

expand comment that the args dict will get expanded to
--key value or --key list_item_0 --key list_item_1, ...

Also say that the python_module "trainer/task.py" should be able to parse these args.

Wait, regular flags are not supported? So task.py cannot use just "--foo". Only key-value parameters are supported? I don't like this restriction. I also bet we don't document this restriction anywhere.

So command lines that are not supported in datalab:

  1. --foo
  2. --foo 1 2 3
  3. --foo=--bar=2 (because we append the key and value in two calls to .append)
  4. -n
  5. -n 200

I would rather allow args to be a dict or a list, so that we support every kind of command line. Otherwise we should at least document the restrictions on the args that task.py can accept.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in the comments. I think we only expand args if it is a dict. If it is a list, we don't mess with it.


Args:
job_request: the arguments of the training job in a dict. For example,
{
'package_uris': 'gs://my-bucket/iris/trainer-0.1.tar.gz',
'python_module': 'trainer.task',
'scale_tier': 'BASIC',
'region': 'us-central1',
'args': {
'train_data_paths': ['gs://mubucket/data/features_train'],
'eval_data_paths': ['gs://mubucket/data/features_eval'],
'metadata_path': 'gs://mubucket/data/metadata.yaml',
'output_path': 'gs://mubucket/data/mymodel/',
}
}
If 'args' is present in job_request and is a dict, it will be expanded to
--key value or --key list_item_0 --key list_item_1, ...
job_id: id for the training job. If None, an id based on timestamp will be generated.
Returns:
A Job object representing the cloud training job.
"""
new_job_request = dict(job_request)
# convert job_args from dict to list as service required.
if 'args' in job_request and isinstance(job_request['args'], dict):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if job_request['args'] is a list because someone did ['--abc=123', '--def=345', ..] or something else, there is no error message! Add some else statements.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If someone passes in a list, we pass it through to the service, and the service will respond with errors if it is not expected.

job_args = job_request['args']
args = []
for k,v in job_args.iteritems():
if isinstance(v, list):
for item in v:
args.append('--' + str(k))
args.append(str(item))
else:
args.append('--' + str(k))
args.append(str(v))
new_job_request['args'] = args

if job_id is None:
job_id = datetime.datetime.now().strftime('%y%m%d_%H%M%S')
if 'python_module' in new_job_request:
job_id = new_job_request['python_module'].replace('.', '_') + \
'_' + job_id

job = {
'job_id': job_id,
'training_input': new_job_request,
}
context = datalab.context.Context.default()
cloudml = discovery.build('ml', 'v1', credentials=context.credentials)
request = cloudml.projects().jobs().create(body=job,
parent='projects/' + context.project_id)
request.headers['user-agent'] = 'GoogleCloudDataLab/1.0'
request.execute()
return Job(job_id)


class Jobs(object):
"""Represents a list of Cloud ML jobs for a project."""
Expand All @@ -82,8 +122,7 @@ def __init__(self, filter=None):
"""
self._filter = filter
self._context = datalab.context.Context.default()
self._api = discovery.build('ml', 'v1beta1', credentials=self._context.credentials,
discoveryServiceUrl=_CLOUDML_DISCOVERY_URL)
self._api = discovery.build('ml', 'v1', credentials=self._context.credentials)

def _retrieve_jobs(self, page_token, page_size):
list_info = self._api.projects().jobs().list(parent='projects/' + self._context.project_id,
Expand All @@ -93,7 +132,9 @@ def _retrieve_jobs(self, page_token, page_size):
page_token = list_info.get('nextPageToken', None)
return jobs, page_token

def __iter__(self):
def get_iterator(self):
"""Get iterator of jobs so it can be used as "for job in Jobs().get_iterator()".
"""
return iter(datalab.utils.Iterator(self._retrieve_jobs))

def list(self, count=10):
Expand Down
19 changes: 19 additions & 0 deletions datalab/mlalpha/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,29 @@
# limitations under the License.

import datetime
from googleapiclient import discovery
import os
import shutil
import subprocess
import tempfile
import time

import datalab.context

# TODO: Create an Operation class.
def wait_for_long_running_operation(operation_full_name):
  """Poll a CloudML operation until the service reports it as done.

  Progress is written to stdout. The operation is queried every 3 seconds using a client
  built from the default Context's credentials. When the finished operation carries an
  "error" field, that error is printed rather than raised; otherwise "Done." is printed.

  Args:
    operation_full_name: the operation name passed to projects().operations().get().
  """
  print('Waiting for operation "%s"' % operation_full_name)
  credentials = datalab.context.Context.default().credentials
  api = discovery.build('ml', 'v1', credentials=credentials)
  while True:
    status = api.projects().operations().get(name=operation_full_name).execute()
    # A missing "done" field is treated the same as "not done yet".
    if status.get('done') == True:
      break
    time.sleep(3)

  if 'error' in status:
    print(status['error'])
  else:
    print('Done.')


def package_and_copy(package_root_dir, setup_py, output_tar_path):
Expand Down
3 changes: 2 additions & 1 deletion datalab/utils/_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ def __iter__(self):

self._page_token = next_page_token
self._first_page = False
self._count += len(items)
if self._count == 0:
self._count = len(items)

for item in items:
yield item
Expand Down
Loading