-
Notifications
You must be signed in to change notification settings - Fork 78
Move cloud trainer and predictor from their own classes to Job and Model respectively. #192
Changes from all commits
d005a12
bd0cd1c
426b93a
3a3007d
9e0b8fe
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,35 +14,27 @@ | |
|
||
from googleapiclient import discovery | ||
import os | ||
import time | ||
import yaml | ||
|
||
import datalab.context | ||
import datalab.storage | ||
import datalab.utils | ||
|
||
from . import _util | ||
|
||
# TODO(qimingj) Remove once the API is public since it will no longer be needed | ||
_CLOUDML_DISCOVERY_URL = 'https://storage.googleapis.com/cloud-ml/discovery/' \ | ||
'ml_v1beta1_discovery.json' | ||
|
||
|
||
class CloudModels(object): | ||
class Models(object): | ||
"""Represents a list of Cloud ML models for a project.""" | ||
|
||
def __init__(self, project_id=None): | ||
"""Initializes an instance of a CloudML Model list that is iteratable | ||
("for model in CloudModels()"). | ||
|
||
""" | ||
Args: | ||
project_id: project_id of the models. If not provided, default project_id will be used. | ||
""" | ||
if project_id is None: | ||
project_id = datalab.context.Context.default().project_id | ||
self._project_id = project_id | ||
self._credentials = datalab.context.Context.default().credentials | ||
self._api = discovery.build('ml', 'v1alpha3', credentials=self._credentials, | ||
discoveryServiceUrl=_CLOUDML_DISCOVERY_URL) | ||
self._api = discovery.build('ml', 'v1', credentials=self._credentials) | ||
|
||
def _retrieve_models(self, page_token, page_size): | ||
list_info = self._api.projects().models().list( | ||
|
@@ -51,11 +43,13 @@ def _retrieve_models(self, page_token, page_size): | |
page_token = list_info.get('nextPageToken', None) | ||
return models, page_token | ||
|
||
def __iter__(self): | ||
def get_iterator(self): | ||
"""Get iterator of models so it can be used as "for model in Models().get_iterator()". | ||
""" | ||
return iter(datalab.utils.Iterator(self._retrieve_models)) | ||
|
||
def get_model_details(self, model_name): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. very generic name. Expand comment: what details are returned? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. side question: why does create not return anything, but all the other functions return the execute() result? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added a bit in comments. create() is instant (it does not return a long running operation) so there is no need to return anything. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Now I think it is better to return whatever service returns, so it is consistent across the methods. |
||
"""Get details of a model. | ||
"""Get details of the specified model from CloudML Service. | ||
|
||
Args: | ||
model_name: the name of the model. It can be a model full name | ||
|
@@ -72,10 +66,16 @@ def create(self, model_name): | |
|
||
Args: | ||
model_name: the short name of the model, such as "iris". | ||
Returns: | ||
If successful, returns informaiton of the model, such as | ||
{u'regions': [u'us-central1'], u'name': u'projects/myproject/models/mymodel'} | ||
Raises: | ||
If the model creation failed. | ||
""" | ||
body = {'name': model_name} | ||
parent = 'projects/' + self._project_id | ||
self._api.projects().models().create(body=body, parent=parent).execute() | ||
# Model creation is instant. If anything goes wrong, Exception will be thrown. | ||
return self._api.projects().models().create(body=body, parent=parent).execute() | ||
|
||
def delete(self, model_name): | ||
"""Delete a model. | ||
|
@@ -87,7 +87,10 @@ def delete(self, model_name): | |
full_name = model_name | ||
if not model_name.startswith('projects/'): | ||
full_name = ('projects/%s/models/%s' % (self._project_id, model_name)) | ||
return self._api.projects().models().delete(name=full_name).execute() | ||
response = self._api.projects().models().delete(name=full_name).execute() | ||
if 'name' not in response: | ||
raise Exception('Invalid response from service. "name" is not found.') | ||
_util.wait_for_long_running_operation(response['name']) | ||
|
||
def list(self, count=10): | ||
"""List models under the current project in a table view. | ||
|
@@ -121,13 +124,11 @@ def describe(self, model_name): | |
print model_yaml | ||
|
||
|
||
class CloudModelVersions(object): | ||
class ModelVersions(object): | ||
"""Represents a list of versions for a Cloud ML model.""" | ||
|
||
def __init__(self, model_name, project_id=None): | ||
"""Initializes an instance of a CloudML model version list that is iteratable | ||
("for version in CloudModelVersions()"). | ||
|
||
""" | ||
Args: | ||
model_name: the name of the model. It can be a model full name | ||
("projects/[project_id]/models/[model_name]") or just [model_name]. | ||
|
@@ -137,8 +138,7 @@ def __init__(self, model_name, project_id=None): | |
if project_id is None: | ||
self._project_id = datalab.context.Context.default().project_id | ||
self._credentials = datalab.context.Context.default().credentials | ||
self._api = discovery.build('ml', 'v1alpha3', credentials=self._credentials, | ||
discoveryServiceUrl=_CLOUDML_DISCOVERY_URL) | ||
self._api = discovery.build('ml', 'v1', credentials=self._credentials) | ||
if not model_name.startswith('projects/'): | ||
model_name = ('projects/%s/models/%s' % (self._project_id, model_name)) | ||
self._full_model_name = model_name | ||
|
@@ -152,7 +152,10 @@ def _retrieve_versions(self, page_token, page_size): | |
page_token = list_info.get('nextPageToken', None) | ||
return versions, page_token | ||
|
||
def __iter__(self): | ||
def get_iterator(self): | ||
"""Get iterator of versions so it can be used as | ||
"for v in ModelVersions(model_name).get_iterator()". | ||
""" | ||
return iter(datalab.utils.Iterator(self._retrieve_versions)) | ||
|
||
def get_version_details(self, version_name): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. get_version_details does not return anything? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It does. "return self._api.projects().models().versions().get(name=name).execute()"? It returns what the REST API call returns. |
||
|
@@ -165,21 +168,6 @@ def get_version_details(self, version_name): | |
name = ('%s/versions/%s' % (self._full_model_name, version_name)) | ||
return self._api.projects().models().versions().get(name=name).execute() | ||
|
||
def _wait_for_long_running_operation(self, response): | ||
if 'name' not in response: | ||
raise Exception('Invaid response from service. Cannot find "name" field.') | ||
print('Waiting for job "%s"' % response['name']) | ||
while True: | ||
response = self._api.projects().operations().get(name=response['name']).execute() | ||
if 'done' not in response or response['done'] != True: | ||
time.sleep(3) | ||
else: | ||
if 'error' in response: | ||
print(response['error']) | ||
else: | ||
print('Done.') | ||
break | ||
|
||
def deploy(self, version_name, path): | ||
"""Deploy a model version to the cloud. | ||
|
||
|
@@ -211,7 +199,9 @@ def deploy(self, version_name, path): | |
} | ||
response = self._api.projects().models().versions().create(body=body, | ||
parent=self._full_model_name).execute() | ||
self._wait_for_long_running_operation(response) | ||
if 'name' not in response: | ||
raise Exception('Invalid response from service. "name" is not found.') | ||
_util.wait_for_long_running_operation(response['name']) | ||
|
||
def delete(self, version_name): | ||
"""Delete a version of model. | ||
|
@@ -221,8 +211,37 @@ def delete(self, version_name): | |
""" | ||
name = ('%s/versions/%s' % (self._full_model_name, version_name)) | ||
response = self._api.projects().models().versions().delete(name=name).execute() | ||
self._wait_for_long_running_operation(response) | ||
|
||
if 'name' not in response: | ||
raise Exception('Invalid response from service. "name" is not found.') | ||
_util.wait_for_long_running_operation(response['name']) | ||
|
||
def predict(self, version_name, data): | ||
"""Get prediction results from features instances. | ||
|
||
Args: | ||
version_name: the name of the version used for prediction. | ||
data: typically a list of instance to be submitted for prediction. The format of the | ||
instance depends on the model. For example, structured data model may require | ||
a csv line for each instance. | ||
Note that online prediction only works on models that take one placeholder value, | ||
such as a string encoding a csv line. | ||
Returns: | ||
A list of prediction results for given instances. Each element is a dictionary representing | ||
output mapping from the graph. | ||
An example: | ||
[{"predictions": 1, "score": [0.00078, 0.71406, 0.28515]}, | ||
{"predictions": 1, "score": [0.00244, 0.99634, 0.00121]}] | ||
""" | ||
full_version_name = ('%s/versions/%s' % (self._full_model_name, version_name)) | ||
request = self._api.projects().predict(body={'instances': data}, | ||
name=full_version_name) | ||
request.headers['user-agent'] = 'GoogleCloudDataLab/1.0' | ||
result = request.execute() | ||
if 'predictions' not in result: | ||
raise Exception('Invalid response from service. Cannot find "predictions" in response.') | ||
|
||
return result['predictions'] | ||
|
||
def describe(self, version_name): | ||
"""Print information of a specified model. | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,39 +18,22 @@ | |
from googleapiclient import discovery | ||
import yaml | ||
|
||
# TODO(qimingj) Remove once the API is public since it will no longer be needed | ||
_CLOUDML_DISCOVERY_URL = 'https://storage.googleapis.com/cloud-ml/discovery/' \ | ||
'ml_v1beta1_discovery.json' | ||
|
||
import datalab.utils | ||
import datalab.context | ||
from googleapiclient import discovery | ||
import yaml | ||
|
||
# TODO(qimingj) Remove once the API is public since it will no longer be needed | ||
_CLOUDML_DISCOVERY_URL = 'https://storage.googleapis.com/cloud-ml/discovery/' \ | ||
'ml_v1beta1_discovery.json' | ||
|
||
|
||
class Job(object): | ||
"""Represents a Cloud ML job.""" | ||
|
||
def __init__(self, name, context=None, api=None): | ||
def __init__(self, name, context=None): | ||
"""Initializes an instance of a CloudML Job. | ||
|
||
Args: | ||
name: the name of the job. It can be an operation full name | ||
("projects/[project_id]/operations/[operation_name]") or just [operation_name]. | ||
("projects/[project_id]/jobs/[operation_name]") or just [operation_name]. | ||
context: an optional Context object providing project_id and credentials. | ||
api: optional CloudML API client. | ||
""" | ||
if context is None: | ||
context = datalab.context.Context.default() | ||
self._context = context | ||
if api is None: | ||
api = discovery.build('ml', 'v1beta1', credentials=self._context.credentials, | ||
discoveryServiceUrl=_CLOUDML_DISCOVERY_URL) | ||
self._api = api | ||
self._api = discovery.build('ml', 'v1', credentials=self._context.credentials) | ||
if not name.startswith('projects/'): | ||
name = 'projects/' + self._context.project_id + '/jobs/' + name | ||
self._name = name | ||
|
@@ -68,6 +51,63 @@ def describe(self): | |
job_yaml = yaml.safe_dump(self._info, default_flow_style=False) | ||
print job_yaml | ||
|
||
@staticmethod | ||
def submit_training(job_request, job_id=None): | ||
"""Submit a training job. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. expand comment that the args dict will get expanded to Also say that the python_module "trainer/task.py" should be able to parse these args. Wait, regular flags are not supported? So task.py cannot use just "--foo". Only key-value parameters are supported? I don't like this restriction. I also bet we don't document this restriction anywhere. So command lines that are not supported in datalab:
I would rather allow args to be a dict or list. So then we support every kind of command line. Or we should at least document task.py arg's restrictions There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done on comments. I think we only expand args if it is a dict. If it is a list we don't mess up with it. |
||
|
||
Args: | ||
job_request: the arguments of the training job in a dict. For example, | ||
{ | ||
'package_uris': 'gs://my-bucket/iris/trainer-0.1.tar.gz', | ||
'python_module': 'trainer.task', | ||
'scale_tier': 'BASIC', | ||
'region': 'us-central1', | ||
'args': { | ||
'train_data_paths': ['gs://mubucket/data/features_train'], | ||
'eval_data_paths': ['gs://mubucket/data/features_eval'], | ||
'metadata_path': 'gs://mubucket/data/metadata.yaml', | ||
'output_path': 'gs://mubucket/data/mymodel/', | ||
} | ||
} | ||
If 'args' is present in job_request and is a dict, it will be expanded to | ||
--key value or --key list_item_0 --key list_item_1, ... | ||
job_id: id for the training job. If None, an id based on timestamp will be generated. | ||
Returns: | ||
A Job object representing the cloud training job. | ||
""" | ||
new_job_request = dict(job_request) | ||
# convert job_args from dict to list as service required. | ||
if 'args' in job_request and isinstance(job_request['args'], dict): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if job_request['args'] is a list because someone did ['--abc=123', '--def=345', ..] or something else, there is no error message! Add some else statements. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If someone passes in a list, we pass it through to the service, and service will respond with errors if it is not expected. |
||
job_args = job_request['args'] | ||
args = [] | ||
for k,v in job_args.iteritems(): | ||
if isinstance(v, list): | ||
for item in v: | ||
args.append('--' + str(k)) | ||
args.append(str(item)) | ||
else: | ||
args.append('--' + str(k)) | ||
args.append(str(v)) | ||
new_job_request['args'] = args | ||
|
||
if job_id is None: | ||
job_id = datetime.datetime.now().strftime('%y%m%d_%H%M%S') | ||
if 'python_module' in new_job_request: | ||
job_id = new_job_request['python_module'].replace('.', '_') + \ | ||
'_' + job_id | ||
|
||
job = { | ||
'job_id': job_id, | ||
'training_input': new_job_request, | ||
} | ||
context = datalab.context.Context.default() | ||
cloudml = discovery.build('ml', 'v1', credentials=context.credentials) | ||
request = cloudml.projects().jobs().create(body=job, | ||
parent='projects/' + context.project_id) | ||
request.headers['user-agent'] = 'GoogleCloudDataLab/1.0' | ||
request.execute() | ||
return Job(job_id) | ||
|
||
|
||
class Jobs(object): | ||
"""Represents a list of Cloud ML jobs for a project.""" | ||
|
@@ -82,8 +122,7 @@ def __init__(self, filter=None): | |
""" | ||
self._filter = filter | ||
self._context = datalab.context.Context.default() | ||
self._api = discovery.build('ml', 'v1beta1', credentials=self._context.credentials, | ||
discoveryServiceUrl=_CLOUDML_DISCOVERY_URL) | ||
self._api = discovery.build('ml', 'v1', credentials=self._context.credentials) | ||
|
||
def _retrieve_jobs(self, page_token, page_size): | ||
list_info = self._api.projects().jobs().list(parent='projects/' + self._context.project_id, | ||
|
@@ -93,7 +132,9 @@ def _retrieve_jobs(self, page_token, page_size): | |
page_token = list_info.get('nextPageToken', None) | ||
return jobs, page_token | ||
|
||
def __iter__(self): | ||
def get_iterator(self): | ||
"""Get iterator of jobs so it can be used as "for model in Jobs().get_iterator()". | ||
""" | ||
return iter(datalab.utils.Iterator(self._retrieve_jobs)) | ||
|
||
def list(self, count=10): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why the name change to get_iterator()? Also do the same change to the Jobs class.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done on Jobs class. The reason of the change is that having a class both exposing methods and exposing itself as iterator is a bit odd and the iterator function is less discoverable.