Skip to content
This repository was archived by the owner on Sep 3, 2022. It is now read-only.

Cloudmlsdp #177

Merged
merged 5 commits into from
Feb 9, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -277,50 +277,30 @@ def cloud_train(train_file_pattern,
print(job_request)


def local_predict():
  """Runs local prediction.

  Runs local prediction in memory and prints the results to the screen. For
  running prediction on a large dataset or saving the results, run
  local_batch_predict or batch_predict.

  NOTE: not implemented yet — this is currently a placeholder that does
  nothing and returns None.
  """
  # TODO(brandondutra): implement by saving the instances to a file, calling
  # local batch prediction, and printing the results back.
  pass


def cloud_predict():
  """Use online prediction.

  Runs online prediction in the cloud and prints the results to the screen. For
  running prediction on a large dataset or saving the results, run
  local_batch_predict or batch_predict.

  Before using this, the model must be created. This can be done by running
  two gcloud commands. Note that the model must be on GCS.

  NOTE: not implemented yet — this is currently a placeholder that does
  nothing and returns None.
  """
  pass


def local_batch_predict(model_dir, prediction_input_file, output_dir,
                        batch_size=1000, shard_files=True):
  """Local batch prediction.

  Args:
    model_dir: local file path to the trained model. Usually, this is
        training_output_dir/model.
    prediction_input_file: csv file pattern to a local file.
    output_dir: local output location to save the results.
    batch_size: Int. How many instances to run in memory at once. Larger
        values mean better performance but more memory consumed.
    shard_files: If False, the output files are not sharded.
  """
  # Build the argument list for the prediction entry point. Despite the name
  # `cmd`, this is not run in a subprocess: the list is handed straight to the
  # predict module's main().
  cmd = ['predict.py',
         '--predict_data=%s' % prediction_input_file,
         '--trained_model_dir=%s' % model_dir,
         '--output_dir=%s' % output_dir,
         '--output_format=csv',
         '--batch_size=%s' % str(batch_size)]

  if shard_files:
    cmd.append('--shard_files')
  else:
    cmd.append('--no-shard_files')

  print('Starting local batch prediction.')
  # BUG FIX: the original called predict.predict.main(args), but no `args`
  # name exists in this scope, so the call raised NameError. The intended
  # argument is the `cmd` list built above.
  predict.predict.main(cmd)
  print('Local batch prediction done.')


def cloud_batch_predict(model_dir, prediction_input_file, output_dir,
                        batch_size=1000, shard_files=True):
  """Cloud batch prediction. Submits a Dataflow job.

  Args:
    model_dir: GCS file path to the trained model. Usually, this is
        training_output_dir/model.
    prediction_input_file: csv file pattern to a GCS file.
    output_dir: location to save the results on GCS.
    batch_size: Int. How many instances to run in memory at once. Larger
        values mean better performance but more memory consumed.
    shard_files: If False, the output files are not sharded.
  """
  # Build the argument list for the prediction entry point. Despite the name
  # `cmd`, this is not run in a subprocess: the list is handed straight to the
  # predict module's main(), which submits the Dataflow job.
  cmd = ['predict.py',
         '--cloud',
         '--project_id=%s' % _default_project(),
         '--predict_data=%s' % prediction_input_file,
         '--trained_model_dir=%s' % model_dir,
         '--output_dir=%s' % output_dir,
         '--output_format=csv',
         '--batch_size=%s' % str(batch_size)]

  if shard_files:
    cmd.append('--shard_files')
  else:
    cmd.append('--no-shard_files')

  print('Starting cloud batch prediction.')
  # BUG FIX: the original called predict.predict.main(args), but no `args`
  # name exists in this scope, so the call raised NameError. The intended
  # argument is the `cmd` list built above.
  predict.predict.main(cmd)
  print('See above link for job status.')
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import predict
Loading