Skip to content
This repository was archived by the owner on Sep 3, 2022. It is now read-only.

Commit 3090cd4

Browse files
brandondutra and qimingj
authored and committed
Cloudmlsdp (#177)
* added the ',' graph hack * sw * batch prediction done * sw * review comments
1 parent 7fc0ca5 commit 3090cd4

File tree

5 files changed

+621
-117
lines changed

5 files changed

+621
-117
lines changed

solutionbox/structured_data/datalab_solutions/structured_data/_package.py

Lines changed: 56 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -277,50 +277,30 @@ def cloud_train(train_file_pattern,
277277
print(job_request)
278278

279279

280-
def local_predict(model_dir, prediction_input_file):
280+
def local_predict():
281281
"""Runs local prediction.
282282
283283
Runs local prediction in memory and prints the results to the screen. For
284284
running prediction on a large dataset or saving the results, run
285285
local_batch_prediction or batch_prediction.
286286
287287
Args:
288-
model_dir: Path to folder that contains the model. This is usully OUT/model
289-
where OUT is the value of output_dir when local_training was ran.
290-
prediction_input_file: csv file that has the same schem as the input
291-
files used during local_preprocess, except that the target column is
292-
removed.
288+
293289
"""
294-
pass
295-
#TODO(brandondutra): remove this hack once cloudml 1.8 is released.
296-
# Check that the model folder has a metadata.yaml file. If not, copy it.
297-
# if not os.path.isfile(os.path.join(model_dir, 'metadata.yaml')):
298-
# shutil.copy2(os.path.join(model_dir, 'metadata.json'),
299-
# os.path.join(model_dir, 'metadata.yaml'))
290+
# Save the instances to a file, call local batch prediction, and print it back
291+
300292

301-
# cmd = ['gcloud beta ml local predict',
302-
# '--model-dir=%s' % model_dir,
303-
# '--text-instances=%s' % prediction_input_file]
304-
# print('Local prediction, running command: %s' % ' '.join(cmd))
305-
# _run_cmd(' '.join(cmd))
306-
# print('Local prediction done.')
307293

308294

309-
def cloud_predict(model_name, prediction_input_file, version_name=None):
295+
def cloud_predict():
310296
"""Use Online prediction.
311297
312298
Runs online prediction in the cloud and prints the results to the screen. For
313299
running prediction on a large dataset or saving the results, run
314300
local_batch_prediction or batch_prediction.
315301
316302
Args:
317-
model_dir: Path to folder that contains the model. This is usully OUT/model
318-
where OUT is the value of output_dir when local_training was ran.
319-
prediction_input_file: csv file that has the same schem as the input
320-
files used during local_preprocess, except that the target column is
321-
removed.
322-
vsersion_name: Optional version of the model to use. If None, the default
323-
version is used.
303+
324304
325305
Before using this, the model must be created. This can be done by running
326306
two gcloud commands:
@@ -334,91 +314,67 @@ def cloud_predict(model_name, prediction_input_file, version_name=None):
334314
Note that the model must be on GCS.
335315
"""
336316
pass
337-
# cmd = ['gcloud beta ml predict',
338-
# '--model=%s' % model_name,
339-
# '--text-instances=%s' % prediction_input_file]
340-
# if version_name:
341-
# cmd += ['--version=%s' % version_name]
342317

343-
# print('CloudML online prediction, running command: %s' % ' '.join(cmd))
344-
# _run_cmd(' '.join(cmd))
345-
# print('CloudML online prediction done.')
346318

347319

348-
def local_batch_predict(model_dir, prediction_input_file, output_dir):
320+
def local_batch_predict(model_dir, prediction_input_file, output_dir,
321+
batch_size=1000, shard_files=True):
349322
"""Local batch prediction.
350323
351324
Args:
352-
model_dir: local path to trained model.
353-
prediction_input_file: File path to input files. May contain a file pattern.
354-
Only csv files are supported, and the scema must match what was used
355-
in preprocessing except that the target column is removed.
356-
output_dir: folder to save results to.
325+
model_dir: local file path to trained model. Usually, this is
326+
training_output_dir/model.
327+
prediction_input_file: csv file pattern to a local file.
328+
output_dir: local output location to save the results.
329+
batch_size: Int. How many instances to run in memory at once. Larger values
330+
mean better performance but more memory consumed.
331+
shard_files: If false, the output files are not sharded.
357332
"""
358-
pass
359-
#TODO(brandondutra): remove this hack once cloudml 1.8 is released.
360-
# Check that the model folder has a metadata.yaml file. If not, copy it.
361-
# if not os.path.isfile(os.path.join(model_dir, 'metadata.yaml')):
362-
# shutil.copy2(os.path.join(model_dir, 'metadata.json'),
363-
# os.path.join(model_dir, 'metadata.yaml'))
333+
cmd = ['predict.py',
334+
'--predict_data=%s' % prediction_input_file,
335+
'--trained_model_dir=%s' % model_dir,
336+
'--output_dir=%s' % output_dir,
337+
'--output_format=csv',
338+
'--batch_size=%s' % str(batch_size)]
339+
340+
if shard_files:
341+
cmd.append('--shard_files')
342+
else:
343+
cmd.append('--no-shard_files')
364344

365-
# cmd = ['python -m google.cloud.ml.dataflow.batch_prediction_main',
366-
# '--input_file_format=text',
367-
# '--input_file_patterns=%s' % prediction_input_file,
368-
# '--output_location=%s' % output_dir,
369-
# '--model_dir=%s' % model_dir]
345+
print('Starting local batch prediction.')
346+
predict.predict.main(args)
347+
print('Local batch prediction done.')
370348

371-
# print('Local batch prediction, running command: %s' % ' '.join(cmd))
372-
# _run_cmd(' '.join(cmd))
373-
# print('Local batch prediction done.')
374349

375350

376-
def cloud_batch_predict(model_name, prediction_input_file, output_dir, region,
377-
job_name=None, version_name=None):
378-
"""Cloud batch prediction.
351+
def cloud_batch_predict(model_dir, prediction_input_file, output_dir,
352+
batch_size=1000, shard_files=True):
353+
"""Cloud batch prediction. Submitts a Dataflow job.
379354
380355
Args:
381-
model_name: name of the model. The model must already exist.
382-
prediction_input_file: File path to input files. May contain a file pattern.
383-
Only csv files are supported, and the scema must match what was used
384-
in preprocessing except that the target column is removed. Files must
385-
be on GCS
386-
output_dir: GCS folder to safe results to.
387-
region: GCP compute region to run the batch job. Try using your default
388-
region first, as this cloud batch prediction is not avaliable in all
389-
regions.
390-
job_name: job name used for the cloud job.
391-
version_name: model version to use. If node, the default version of the
392-
model is used.
356+
model_dir: GCS file path to trained model. Usually, this is
357+
training_output_dir/model.
358+
prediction_input_file: csv file pattern to a GCS file.
359+
output_dir: Location to save the results on GCS.
360+
batch_size: Int. How many instances to run in memory at once. Larger values
361+
mean better performance but more memory consumed.
362+
shard_files: If false, the output files are not sharded.
393363
"""
394-
pass
395-
# job_name = job_name or ('structured_data_batch_predict_' +
396-
# datetime.datetime.now().strftime('%Y%m%d%H%M%S'))
397-
398-
# if (not prediction_input_file.startswith('gs://') or
399-
# not output_dir.startswith('gs://')):
400-
# print('ERROR: prediction_input_file and output_dir must point to a '
401-
# 'location on GCS.')
402-
# return
403-
404-
# cmd = ['gcloud beta ml jobs submit prediction %s' % job_name,
405-
# '--model=%s' % model_name,
406-
# '--region=%s' % region,
407-
# '--data-format=TEXT',
408-
# '--input-paths=%s' % prediction_input_file,
409-
# '--output-path=%s' % output_dir]
410-
# if version_name:
411-
# cmd += ['--version=%s' % version_name]
412-
413-
# print('CloudML batch prediction, running command: %s' % ' '.join(cmd))
414-
# _run_cmd(' '.join(cmd))
415-
# print('CloudML batch prediction job submitted.')
416-
417-
# if _is_in_IPython():
418-
# import IPython
419-
420-
# dataflow_url = ('https://console.developers.google.com/ml/jobs?project=%s'
421-
# % _default_project())
422-
# html = ('<p>Click <a href="%s" target="_blank">here</a> to track '
423-
# 'the prediction job %s.</p><br/>' % (dataflow_url, job_name))
424-
# IPython.display.display_html(html, raw=True)
364+
cmd = ['predict.py',
365+
'--cloud',
366+
'--project_id=%s' % _default_project(),
367+
'--predict_data=%s' % prediction_input_file,
368+
'--trained_model_dir=%s' % model_dir,
369+
'--output_dir=%s' % output_dir,
370+
'--output_format=csv',
371+
'--batch_size=%s' % str(batch_size)]
372+
373+
if shard_files:
374+
cmd.append('--shard_files')
375+
else:
376+
cmd.append('--no-shard_files')
377+
378+
print('Starting cloud batch prediction.')
379+
predict.predict.main(args)
380+
print('See above link for job status.')
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright 2017 Google Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
import predict

0 commit comments

Comments
 (0)