Skip to content
This repository was archived by the owner on Sep 3, 2022. It is now read-only.

Commit 186600a

Browse files
committed
Move cloud trainer and predictor from their own classes to Job and Model respectively. (#192)
* Move cloud trainer and predictor from their own classes to Job and Model respectively. Cloud trainer and predictor will be cleaned up in a seperate change. * Rename CloudModels to Models, CloudModelVersions to ModelVersions. Move their iterator from self to get_iterator() method. * Switch to cloudml v1 endpoint. * Remove one comment. * Follow up on CR comments. Fix a bug in datalab iterator that count keeps incrementing incorrectly.
1 parent b81244f commit 186600a

File tree

8 files changed

+160
-86
lines changed

8 files changed

+160
-86
lines changed

datalab/mlalpha/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from ._tensorboard import TensorBoard
2525
from ._dataset import CsvDataSet, BigQueryDataSet
2626
from ._package import Packager
27-
from ._cloud_models import CloudModels, CloudModelVersions
27+
from ._cloud_models import Models, ModelVersions
2828
from ._confusion_matrix import ConfusionMatrix
2929
from ._analysis import csv_to_dataframe
3030
from ._package_runner import PackageRunner

datalab/mlalpha/_cloud_models.py

Lines changed: 60 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -14,35 +14,27 @@
1414

1515
from googleapiclient import discovery
1616
import os
17-
import time
1817
import yaml
1918

2019
import datalab.context
2120
import datalab.storage
2221
import datalab.utils
2322

23+
from . import _util
2424

25-
# TODO(qimingj) Remove once the API is public since it will no longer be needed
26-
_CLOUDML_DISCOVERY_URL = 'https://storage.googleapis.com/cloud-ml/discovery/' \
27-
'ml_v1beta1_discovery.json'
28-
29-
30-
class CloudModels(object):
25+
class Models(object):
3126
"""Represents a list of Cloud ML models for a project."""
3227

3328
def __init__(self, project_id=None):
34-
"""Initializes an instance of a CloudML Model list that is iteratable
35-
("for model in CloudModels()").
36-
29+
"""
3730
Args:
3831
project_id: project_id of the models. If not provided, default project_id will be used.
3932
"""
4033
if project_id is None:
4134
project_id = datalab.context.Context.default().project_id
4235
self._project_id = project_id
4336
self._credentials = datalab.context.Context.default().credentials
44-
self._api = discovery.build('ml', 'v1alpha3', credentials=self._credentials,
45-
discoveryServiceUrl=_CLOUDML_DISCOVERY_URL)
37+
self._api = discovery.build('ml', 'v1', credentials=self._credentials)
4638

4739
def _retrieve_models(self, page_token, page_size):
4840
list_info = self._api.projects().models().list(
@@ -51,11 +43,13 @@ def _retrieve_models(self, page_token, page_size):
5143
page_token = list_info.get('nextPageToken', None)
5244
return models, page_token
5345

54-
def __iter__(self):
46+
def get_iterator(self):
47+
"""Get iterator of models so it can be used as "for model in Models().get_iterator()".
48+
"""
5549
return iter(datalab.utils.Iterator(self._retrieve_models))
5650

5751
def get_model_details(self, model_name):
58-
"""Get details of a model.
52+
"""Get details of the specified model from CloudML Service.
5953
6054
Args:
6155
model_name: the name of the model. It can be a model full name
@@ -72,10 +66,16 @@ def create(self, model_name):
7266
7367
Args:
7468
model_name: the short name of the model, such as "iris".
69+
Returns:
70+
If successful, returns informaiton of the model, such as
71+
{u'regions': [u'us-central1'], u'name': u'projects/myproject/models/mymodel'}
72+
Raises:
73+
If the model creation failed.
7574
"""
7675
body = {'name': model_name}
7776
parent = 'projects/' + self._project_id
78-
self._api.projects().models().create(body=body, parent=parent).execute()
77+
# Model creation is instant. If anything goes wrong, Exception will be thrown.
78+
return self._api.projects().models().create(body=body, parent=parent).execute()
7979

8080
def delete(self, model_name):
8181
"""Delete a model.
@@ -87,7 +87,10 @@ def delete(self, model_name):
8787
full_name = model_name
8888
if not model_name.startswith('projects/'):
8989
full_name = ('projects/%s/models/%s' % (self._project_id, model_name))
90-
return self._api.projects().models().delete(name=full_name).execute()
90+
response = self._api.projects().models().delete(name=full_name).execute()
91+
if 'name' not in response:
92+
raise Exception('Invalid response from service. "name" is not found.')
93+
_util.wait_for_long_running_operation(response['name'])
9194

9295
def list(self, count=10):
9396
"""List models under the current project in a table view.
@@ -121,13 +124,11 @@ def describe(self, model_name):
121124
print model_yaml
122125

123126

124-
class CloudModelVersions(object):
127+
class ModelVersions(object):
125128
"""Represents a list of versions for a Cloud ML model."""
126129

127130
def __init__(self, model_name, project_id=None):
128-
"""Initializes an instance of a CloudML model version list that is iteratable
129-
("for version in CloudModelVersions()").
130-
131+
"""
131132
Args:
132133
model_name: the name of the model. It can be a model full name
133134
("projects/[project_id]/models/[model_name]") or just [model_name].
@@ -137,8 +138,7 @@ def __init__(self, model_name, project_id=None):
137138
if project_id is None:
138139
self._project_id = datalab.context.Context.default().project_id
139140
self._credentials = datalab.context.Context.default().credentials
140-
self._api = discovery.build('ml', 'v1alpha3', credentials=self._credentials,
141-
discoveryServiceUrl=_CLOUDML_DISCOVERY_URL)
141+
self._api = discovery.build('ml', 'v1', credentials=self._credentials)
142142
if not model_name.startswith('projects/'):
143143
model_name = ('projects/%s/models/%s' % (self._project_id, model_name))
144144
self._full_model_name = model_name
@@ -152,7 +152,10 @@ def _retrieve_versions(self, page_token, page_size):
152152
page_token = list_info.get('nextPageToken', None)
153153
return versions, page_token
154154

155-
def __iter__(self):
155+
def get_iterator(self):
156+
"""Get iterator of versions so it can be used as
157+
"for v in ModelVersions(model_name).get_iterator()".
158+
"""
156159
return iter(datalab.utils.Iterator(self._retrieve_versions))
157160

158161
def get_version_details(self, version_name):
@@ -165,21 +168,6 @@ def get_version_details(self, version_name):
165168
name = ('%s/versions/%s' % (self._full_model_name, version_name))
166169
return self._api.projects().models().versions().get(name=name).execute()
167170

168-
def _wait_for_long_running_operation(self, response):
169-
if 'name' not in response:
170-
raise Exception('Invaid response from service. Cannot find "name" field.')
171-
print('Waiting for job "%s"' % response['name'])
172-
while True:
173-
response = self._api.projects().operations().get(name=response['name']).execute()
174-
if 'done' not in response or response['done'] != True:
175-
time.sleep(3)
176-
else:
177-
if 'error' in response:
178-
print(response['error'])
179-
else:
180-
print('Done.')
181-
break
182-
183171
def deploy(self, version_name, path):
184172
"""Deploy a model version to the cloud.
185173
@@ -211,7 +199,9 @@ def deploy(self, version_name, path):
211199
}
212200
response = self._api.projects().models().versions().create(body=body,
213201
parent=self._full_model_name).execute()
214-
self._wait_for_long_running_operation(response)
202+
if 'name' not in response:
203+
raise Exception('Invalid response from service. "name" is not found.')
204+
_util.wait_for_long_running_operation(response['name'])
215205

216206
def delete(self, version_name):
217207
"""Delete a version of model.
@@ -221,8 +211,37 @@ def delete(self, version_name):
221211
"""
222212
name = ('%s/versions/%s' % (self._full_model_name, version_name))
223213
response = self._api.projects().models().versions().delete(name=name).execute()
224-
self._wait_for_long_running_operation(response)
225-
214+
if 'name' not in response:
215+
raise Exception('Invalid response from service. "name" is not found.')
216+
_util.wait_for_long_running_operation(response['name'])
217+
218+
def predict(self, version_name, data):
219+
"""Get prediction results from features instances.
220+
221+
Args:
222+
version_name: the name of the version used for prediction.
223+
data: typically a list of instance to be submitted for prediction. The format of the
224+
instance depends on the model. For example, structured data model may require
225+
a csv line for each instance.
226+
Note that online prediction only works on models that take one placeholder value,
227+
such as a string encoding a csv line.
228+
Returns:
229+
A list of prediction results for given instances. Each element is a dictionary representing
230+
output mapping from the graph.
231+
An example:
232+
[{"predictions": 1, "score": [0.00078, 0.71406, 0.28515]},
233+
{"predictions": 1, "score": [0.00244, 0.99634, 0.00121]}]
234+
"""
235+
full_version_name = ('%s/versions/%s' % (self._full_model_name, version_name))
236+
request = self._api.projects().predict(body={'instances': data},
237+
name=full_version_name)
238+
request.headers['user-agent'] = 'GoogleCloudDataLab/1.0'
239+
result = request.execute()
240+
if 'predictions' not in result:
241+
raise Exception('Invalid response from service. Cannot find "predictions" in response.')
242+
243+
return result['predictions']
244+
226245
def describe(self, version_name):
227246
"""Print information of a specified model.
228247

datalab/mlalpha/_job.py

Lines changed: 64 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -18,39 +18,22 @@
1818
from googleapiclient import discovery
1919
import yaml
2020

21-
# TODO(qimingj) Remove once the API is public since it will no longer be needed
22-
_CLOUDML_DISCOVERY_URL = 'https://storage.googleapis.com/cloud-ml/discovery/' \
23-
'ml_v1beta1_discovery.json'
24-
25-
import datalab.utils
26-
import datalab.context
27-
from googleapiclient import discovery
28-
import yaml
29-
30-
# TODO(qimingj) Remove once the API is public since it will no longer be needed
31-
_CLOUDML_DISCOVERY_URL = 'https://storage.googleapis.com/cloud-ml/discovery/' \
32-
'ml_v1beta1_discovery.json'
33-
3421

3522
class Job(object):
3623
"""Represents a Cloud ML job."""
3724

38-
def __init__(self, name, context=None, api=None):
25+
def __init__(self, name, context=None):
3926
"""Initializes an instance of a CloudML Job.
4027
4128
Args:
4229
name: the name of the job. It can be an operation full name
43-
("projects/[project_id]/operations/[operation_name]") or just [operation_name].
30+
("projects/[project_id]/jobs/[operation_name]") or just [operation_name].
4431
context: an optional Context object providing project_id and credentials.
45-
api: optional CloudML API client.
4632
"""
4733
if context is None:
4834
context = datalab.context.Context.default()
4935
self._context = context
50-
if api is None:
51-
api = discovery.build('ml', 'v1beta1', credentials=self._context.credentials,
52-
discoveryServiceUrl=_CLOUDML_DISCOVERY_URL)
53-
self._api = api
36+
self._api = discovery.build('ml', 'v1', credentials=self._context.credentials)
5437
if not name.startswith('projects/'):
5538
name = 'projects/' + self._context.project_id + '/jobs/' + name
5639
self._name = name
@@ -68,6 +51,63 @@ def describe(self):
6851
job_yaml = yaml.safe_dump(self._info, default_flow_style=False)
6952
print job_yaml
7053

54+
@staticmethod
55+
def submit_training(job_request, job_id=None):
56+
"""Submit a training job.
57+
58+
Args:
59+
job_request: the arguments of the training job in a dict. For example,
60+
{
61+
'package_uris': 'gs://my-bucket/iris/trainer-0.1.tar.gz',
62+
'python_module': 'trainer.task',
63+
'scale_tier': 'BASIC',
64+
'region': 'us-central1',
65+
'args': {
66+
'train_data_paths': ['gs://mubucket/data/features_train'],
67+
'eval_data_paths': ['gs://mubucket/data/features_eval'],
68+
'metadata_path': 'gs://mubucket/data/metadata.yaml',
69+
'output_path': 'gs://mubucket/data/mymodel/',
70+
}
71+
}
72+
If 'args' is present in job_request and is a dict, it will be expanded to
73+
--key value or --key list_item_0 --key list_item_1, ...
74+
job_id: id for the training job. If None, an id based on timestamp will be generated.
75+
Returns:
76+
A Job object representing the cloud training job.
77+
"""
78+
new_job_request = dict(job_request)
79+
# convert job_args from dict to list as service required.
80+
if 'args' in job_request and isinstance(job_request['args'], dict):
81+
job_args = job_request['args']
82+
args = []
83+
for k,v in job_args.iteritems():
84+
if isinstance(v, list):
85+
for item in v:
86+
args.append('--' + str(k))
87+
args.append(str(item))
88+
else:
89+
args.append('--' + str(k))
90+
args.append(str(v))
91+
new_job_request['args'] = args
92+
93+
if job_id is None:
94+
job_id = datetime.datetime.now().strftime('%y%m%d_%H%M%S')
95+
if 'python_module' in new_job_request:
96+
job_id = new_job_request['python_module'].replace('.', '_') + \
97+
'_' + job_id
98+
99+
job = {
100+
'job_id': job_id,
101+
'training_input': new_job_request,
102+
}
103+
context = datalab.context.Context.default()
104+
cloudml = discovery.build('ml', 'v1', credentials=context.credentials)
105+
request = cloudml.projects().jobs().create(body=job,
106+
parent='projects/' + context.project_id)
107+
request.headers['user-agent'] = 'GoogleCloudDataLab/1.0'
108+
request.execute()
109+
return Job(job_id)
110+
71111

72112
class Jobs(object):
73113
"""Represents a list of Cloud ML jobs for a project."""
@@ -82,8 +122,7 @@ def __init__(self, filter=None):
82122
"""
83123
self._filter = filter
84124
self._context = datalab.context.Context.default()
85-
self._api = discovery.build('ml', 'v1beta1', credentials=self._context.credentials,
86-
discoveryServiceUrl=_CLOUDML_DISCOVERY_URL)
125+
self._api = discovery.build('ml', 'v1', credentials=self._context.credentials)
87126

88127
def _retrieve_jobs(self, page_token, page_size):
89128
list_info = self._api.projects().jobs().list(parent='projects/' + self._context.project_id,
@@ -93,7 +132,9 @@ def _retrieve_jobs(self, page_token, page_size):
93132
page_token = list_info.get('nextPageToken', None)
94133
return jobs, page_token
95134

96-
def __iter__(self):
135+
def get_iterator(self):
136+
"""Get iterator of jobs so it can be used as "for model in Jobs().get_iterator()".
137+
"""
97138
return iter(datalab.utils.Iterator(self._retrieve_jobs))
98139

99140
def list(self, count=10):

datalab/mlalpha/_util.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,29 @@
1313
# limitations under the License.
1414

1515
import datetime
16+
from googleapiclient import discovery
1617
import os
1718
import shutil
1819
import subprocess
1920
import tempfile
21+
import time
22+
23+
import datalab.context
24+
25+
# TODO: Create an Operation class.
26+
def wait_for_long_running_operation(operation_full_name):
27+
print('Waiting for operation "%s"' % operation_full_name)
28+
api = discovery.build('ml', 'v1', credentials=datalab.context.Context.default().credentials)
29+
while True:
30+
response = api.projects().operations().get(name=operation_full_name).execute()
31+
if 'done' not in response or response['done'] != True:
32+
time.sleep(3)
33+
else:
34+
if 'error' in response:
35+
print(response['error'])
36+
else:
37+
print('Done.')
38+
break
2039

2140

2241
def package_and_copy(package_root_dir, setup_py, output_tar_path):

datalab/utils/_iterator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ def __iter__(self):
3737

3838
self._page_token = next_page_token
3939
self._first_page = False
40-
self._count += len(items)
40+
if self._count == 0:
41+
self._count = len(items)
4142

4243
for item in items:
4344
yield item

0 commit comments

Comments
 (0)