Skip to content
This repository was archived by the owner on Sep 3, 2022. It is now read-only.

Commit fa99d76

Browse files
authored
Datalab Inception (image classification) solution. (#117)
* Datalab Inception (image classification) solution. * Fix dataflow URL.
1 parent 9b27cbe commit fa99d76

File tree

13 files changed

+2216
-0
lines changed

13 files changed

+2216
-0
lines changed
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# Copyright 2017 Google Inc. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
4+
# in compliance with the License. You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software distributed under the License
9+
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
10+
# or implied. See the License for the specific language governing permissions and limitations under
11+
# the License.
12+
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright 2017 Google Inc. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
4+
# in compliance with the License. You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software distributed under the License
9+
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
10+
# or implied. See the License for the specific language governing permissions and limitations under
11+
# the License.
12+
13+
14+
from ._package import local_preprocess, cloud_preprocess, local_train, cloud_train, local_predict, cloud_predict
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
# Copyright 2017 Google Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
"""Cloud implementation for preprocessing, training and prediction for inception model.
17+
"""
18+
19+
import apache_beam as beam
20+
import base64
21+
import collections
22+
import datetime
23+
from googleapiclient import discovery
24+
import google.cloud.ml as ml
25+
import logging
26+
import os
27+
28+
from . import _model
29+
from . import _preprocess
30+
from . import _trainer
31+
from . import _util
32+
33+
34+
# Discovery document for the pre-GA Cloud ML v1beta1 API; the service is not
# yet listed in the standard googleapiclient discovery index, so we point
# discovery.build() at this GCS-hosted document explicitly.
_CLOUDML_DISCOVERY_URL = 'https://storage.googleapis.com/cloud-ml/discovery/' \
    'ml_v1beta1_discovery.json'
# TensorFlow 0.12 (cp27) wheel staged on GCS; shipped to Dataflow workers as
# an extra package so the preprocessing pipeline can import tensorflow.
_TF_GS_URL= 'gs://cloud-datalab/deploy/tf/tensorflow-0.12.0rc0-cp27-none-linux_x86_64.whl'
37+
38+
39+
class Cloud(object):
  """Runs inception-model preprocessing, training and prediction on GCP.

  Preprocessing runs on Cloud Dataflow, training on the Cloud ML training
  service, and prediction on the Cloud ML online prediction service.
  """

  def __init__(self, project, checkpoint=None):
    """Initializes the cloud runner.

    Args:
      project: GCP project id used for the Dataflow and Cloud ML requests.
      checkpoint: GCS URL of the inception checkpoint to start from. Falls
          back to the default public checkpoint when None.
    """
    self._project = project
    self._checkpoint = checkpoint
    if self._checkpoint is None:
      self._checkpoint = _util._DEFAULT_CHECKPOINT_GSURL

  def preprocess(self, input_csvs, labels_file, output_dir, pipeline_option=None):
    """Cloud preprocessing with Cloud DataFlow.

    Args:
      input_csvs: input CSV file(s) listing images and labels.
      labels_file: file containing the set of labels.
      output_dir: GCS directory for preprocessed output (also used for
          Dataflow staging/temp locations).
      pipeline_option: optional dict of extra Dataflow pipeline options;
          entries override the defaults built here.
    """
    job_name = 'preprocess-inception-' + datetime.datetime.now().strftime('%y%m%d-%H%M%S')
    options = {
        'staging_location': os.path.join(output_dir, 'tmp', 'staging'),
        'temp_location': os.path.join(output_dir, 'tmp'),
        'job_name': job_name,
        'project': self._project,
        # Workers need the ML SDK, this solution's package, and tensorflow.
        'extra_packages': [ml.sdk_location, _util._PACKAGE_GS_URL, _TF_GS_URL],
        'teardown_policy': 'TEARDOWN_ALWAYS',
        'no_save_main_session': True
    }
    if pipeline_option is not None:
      options.update(pipeline_option)

    opts = beam.pipeline.PipelineOptions(flags=[], **options)
    p = beam.Pipeline('DataflowPipelineRunner', options=opts)
    _preprocess.configure_pipeline(
        p, self._checkpoint, input_csvs, labels_file, output_dir, job_name)
    p.run()

  def train(self, labels_file, input_dir, batch_size, max_steps, output_path, credentials,
            region, scale_tier):
    """Cloud training with CloudML trainer service.

    Args:
      labels_file: file containing the set of labels; its length determines
          the number of output classes.
      input_dir: directory of preprocessed training data.
      batch_size: training batch size.
      max_steps: maximum number of training steps.
      output_path: GCS path for training output.
      credentials: OAuth2 credentials used to call the Cloud ML API.
      region: GCP region to run the training job in.
      scale_tier: Cloud ML scale tier for the job.

    Returns:
      The job resource returned by the Cloud ML jobs.create call.
    """
    num_classes = len(_util.get_labels(labels_file))
    job_id = 'inception_train_' + datetime.datetime.now().strftime('%y%m%d_%H%M%S')
    job_args_dict = {
        'input_dir': input_dir,
        'output_path': output_path,
        'max_steps': max_steps,
        'batch_size': batch_size,
        'num_classes': num_classes,
        'checkpoint': self._checkpoint
    }
    # Convert job_args from dict to a flat ['--key', 'value', ...] list as
    # the service requires. Fix: use items() rather than the Python-2-only
    # iteritems() so the module also runs under Python 3.
    job_args = []
    for k, v in job_args_dict.items():
      if isinstance(v, list):
        for item in v:
          job_args.append('--' + k)
          job_args.append(str(item))
      else:
        job_args.append('--' + k)
        job_args.append(str(v))

    job_request = {
        # Fix: packageUris is a repeated field in the Cloud ML API, so the
        # single trainer package must be wrapped in a list.
        'package_uris': [_util._PACKAGE_GS_URL],
        'python_module': 'datalab_solutions.inception.task',
        'scale_tier': scale_tier,
        'region': region,
        'args': job_args
    }
    job = {
        'job_id': job_id,
        'training_input': job_request,
    }
    cloudml = discovery.build('ml', 'v1beta1', discoveryServiceUrl=_CLOUDML_DISCOVERY_URL,
                              credentials=credentials)
    request = cloudml.projects().jobs().create(body=job,
                                               parent='projects/' + self._project)
    request.headers['user-agent'] = 'GoogleCloudDataLab/1.0'
    job_info = request.execute()
    return job_info

  def predict(self, model_id, image_files, labels_file, credentials):
    """Cloud prediction with CloudML prediction service.

    Args:
      model_id: deployed model identifier in "model.version" form.
      image_files: list of local or GCS image paths to classify.
      labels_file: file containing the set of labels, used to map predicted
          class indices back to label names.
      credentials: OAuth2 credentials used to call the Cloud ML API.

    Returns:
      A list of (label, score) tuples, one per input image.

    Raises:
      Exception: if model_id is malformed or the service response does not
          contain predictions.
    """
    labels = _util.get_labels(labels_file)
    data = []
    for ii, img_file in enumerate(image_files):
      with ml.util._file.open_local_or_gcs(img_file, 'rb') as f:
        img = base64.b64encode(f.read())
      data.append({
          'key': str(ii),
          'image_bytes': {'b64': img}
      })
    parts = model_id.split('.')
    if len(parts) != 2:
      raise Exception('Invalid model name for cloud prediction. Use "model.version".')
    full_version_name = ('projects/%s/models/%s/versions/%s' % (self._project, parts[0], parts[1]))
    api = discovery.build('ml', 'v1beta1', credentials=credentials,
                          discoveryServiceUrl=_CLOUDML_DISCOVERY_URL)
    request = api.projects().predict(body={'instances': data}, name=full_version_name)
    job_results = request.execute()
    if 'predictions' not in job_results:
      raise Exception('Invalid response from service. Cannot find "predictions" in response.')
    predictions = job_results['predictions']
    labels_and_scores = [(labels[x['prediction']], x['scores'][x['prediction']])
                         for x in predictions]
    return labels_and_scores

0 commit comments

Comments
 (0)