Skip to content
This repository was archived by the owner on Sep 3, 2022. It is now read-only.

Commit efc5b88

Browse files
committed
Add CloudTrainingConfig namedtuple to wrap cloud training configurations (#178)
* Add a CloudTrainingConfig namedtuple to wrap cloud training configurations.
* Follow up on code-review comments.
1 parent 286cbc4 commit efc5b88

File tree

4 files changed

+53
-6
lines changed

4 files changed

+53
-6
lines changed

datalab/mlalpha/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from ._analysis import csv_to_dataframe
3030
from ._package_runner import PackageRunner
3131
from ._feature_slice_view import FeatureSliceView
32+
from ._cloud_training_config import CloudTrainingConfig
3233
from ._util import *
3334

3435

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright 2017 Google Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from collections import namedtuple

# Private base tuple holding the raw fields; the public CloudTrainingConfig
# subclass below exists only to attach user-facing documentation.
# Typename matches the public class so repr/pickle identity is consistent
# (the original used the mismatched name "CloudConfig").
_CloudTrainingConfig = namedtuple(
    'CloudTrainingConfig',
    ['region', 'scale_tier', 'master_type', 'worker_type',
     'parameter_server_type', 'worker_count', 'parameter_server_count'])

# Every field except "region" is optional: "scale_tier" defaults to 'BASIC',
# and the CUSTOM-tier machine-type/count fields default to None.
_CloudTrainingConfig.__new__.__defaults__ = ('BASIC', None, None, None, None, None)


class CloudTrainingConfig(_CloudTrainingConfig):
  """A config namedtuple containing cloud specific configurations for CloudML training.

  Fields:
    region: the region of the training job to be submitted. For example, "us-central1".
        Run "gcloud compute regions list" to get a list of regions.
    scale_tier: Specifies the machine types, the number of replicas for workers and
        parameter servers. For example, "STANDARD_1". See
        https://cloud.google.com/ml/reference/rest/v1beta1/projects.jobs#scaletier
        for list of accepted values. Defaults to 'BASIC'.
    master_type: specifies the type of virtual machine to use for your training
        job's master worker. Must set this value when scale_tier is set to CUSTOM.
        See the link in "scale_tier".
    worker_type: specifies the type of virtual machine to use for your training
        job's worker nodes. Must set this value when scale_tier is set to CUSTOM.
    parameter_server_type: specifies the type of virtual machine to use for your training
        job's parameter server. Must set this value when scale_tier is set to CUSTOM.
    worker_count: the number of worker replicas to use for the training job. Each
        replica in the cluster will be of the type specified in "worker_type".
        Must set this value when scale_tier is set to CUSTOM.
    parameter_server_count: the number of parameter server replicas to use. Each
        replica in the cluster will be of the type specified in "parameter_server_type".
        Must set this value when scale_tier is set to CUSTOM.
  """

solutionbox/inception/datalab_solutions/inception/_cloud.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,7 @@ def preprocess(self, dataset, output_dir, pipeline_option=None):
8686
p.run()
8787
return job_name
8888

89-
def train(self, input_dir, batch_size, max_steps, output_path,
90-
region, scale_tier):
89+
def train(self, input_dir, batch_size, max_steps, output_path, cloud_train_config):
9190
"""Cloud training with CloudML trainer service."""
9291

9392
import datalab.mlalpha as mlalpha
@@ -103,10 +102,9 @@ def train(self, input_dir, batch_size, max_steps, output_path,
103102
job_request = {
104103
'package_uris': staging_package_url,
105104
'python_module': 'datalab_solutions.inception.task',
106-
'scale_tier': scale_tier,
107-
'region': region,
108105
'args': job_args
109106
}
107+
job_request.update(dict(cloud_train_config._asdict()))
110108
cloud_runner = mlalpha.CloudRunner(job_request)
111109
job_id = 'inception_train_' + datetime.datetime.now().strftime('%y%m%d_%H%M%S')
112110
return cloud_runner.run(job_id)

solutionbox/inception/datalab_solutions/inception/_package.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,19 +102,20 @@ def local_train(input_dir, batch_size, max_steps, output_dir, checkpoint=None):
102102

103103

104104
def cloud_train(input_dir, batch_size, max_steps, output_dir,
105-
region, scale_tier='BASIC', checkpoint=None):
105+
cloud_train_config, checkpoint=None):
106106
"""Train model in the cloud with CloudML trainer service.
107107
The output can be used for local prediction or for online deployment.
108108
Args:
109109
input_dir: A directory path containing preprocessed results. GCS path only.
110110
batch_size: size of batch used for training.
111111
max_steps: number of steps to train.
112112
output_dir: The output directory to use. GCS path only.
113+
cloud_train_config: a datalab.ml.CloudTrainingConfig object.
113114
checkpoint: the Inception checkpoint to use.
114115
"""
115116

116117
job_info = _cloud.Cloud(checkpoint=checkpoint).train(input_dir, batch_size,
117-
max_steps, output_dir, region, scale_tier)
118+
max_steps, output_dir, cloud_train_config)
118119
if (_util.is_in_IPython()):
119120
import IPython
120121
log_url_query_strings = {

0 commit comments

Comments
 (0)