Skip to content

Commit

Permalink
Merge pull request #43 from kubeflow/cmle
Browse files: browse the repository at this point in the history.
Fix resnet-cmle sample.
  • Loading branch information
qimingj authored Nov 5, 2018
2 parents (5fbb1e8 + 7cd1e70); commit 2fc92f3
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 10 deletions.
12 changes: 5 additions & 7 deletions components/resnet-cmle/resnet/train.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -33,10 +33,9 @@ def parse_arguments():
type = str,
default = 'gs://flowers_resnet/tpu/resnet/data',
help = 'The data directory file generated by the preprocess component.')
- parser.add_argument('--bucket',
+ parser.add_argument('--output',
type = str,
- default = 'flowers_resnet',
- help = 'Path to GCS bucket.')
+ help = 'Path to GCS location to store output.')
parser.add_argument('--region',
type = str,
default = 'us-central1',
Expand Down Expand Up @@ -75,7 +74,7 @@ def parse_arguments():
help = 'Number of classes.')
parser.add_argument('--TFVERSION',
type = str,
- default = '1.8',
+ default = '1.9',
help = 'Version of TensorFlow to use.')
args = parser.parse_args()
return args
Expand All @@ -85,14 +84,13 @@ def parse_arguments():
args = parse_arguments()
job_name = 'imgclass_' + strftime("%y%m%d_%H%M%S", gmtime())

- output_dir = 'gs://' + args.bucket + '/tpu/model'
+ output_dir = args.output + '/tpu/model'
logging.info('Submitting job for training to Cloud Machine Learning Engine')
subprocess.check_call('gcloud ml-engine jobs submit training ' + job_name + ' \
--region=' + args.region + ' \
--module-name=trainer.resnet_main \
--package-path=/resnet/resnet_model/trainer \
--job-dir=' + output_dir + ' \
- --staging-bucket=gs://' + args.bucket + ' \
--scale-tier=BASIC_TPU \
--stream-logs \
--runtime-version=' + args.TFVERSION + ' \
Expand All @@ -116,4 +114,4 @@ def parse_arguments():
}]
}
with open('/mlpipeline-ui-metadata.json', 'w') as f:
- json.dump(metadata, f)   (no newline at end of file)
+ json.dump(metadata, f)
6 changes: 3 additions & 3 deletions samples/resnet-cmle/resnet-train-pipeline.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -20,7 +20,7 @@
def resnet_preprocess_op(project_id: 'GcpProject', output: 'GcsUri', train_csv: 'GcsUri[text/csv]', validation_csv: 'GcsUri[text/csv]', labels, step_name='preprocess'):
return dsl.ContainerOp(
name = step_name,
- image = 'gcr.io/ml-pipeline/resnet-preprocess:0.0.42', # TODO: update it with a version number after a new release.
+ image = 'gcr.io/ml-pipeline/resnet-preprocess:staging', # TODO: update it with a version number after a new release.
arguments = [
'--project_id', project_id,
'--output', output,
Expand All @@ -34,7 +34,7 @@ def resnet_preprocess_op(project_id: 'GcpProject', output: 'GcsUri', train_csv:
def resnet_train_op(data_dir, output: 'GcsUri', region: 'GcpRegion', depth: int, train_batch_size: int, eval_batch_size: int, steps_per_eval: int, train_steps: int, num_train_images: int, num_eval_images: int, num_label_classes: int, tf_version, step_name='train'):
return dsl.ContainerOp(
name = step_name,
- image = 'gcr.io/ml-pipeline/resnet-train:0.0.42',
+ image = 'gcr.io/ml-pipeline/resnet-train:staging', # TODO: update it with a version number after a new release.
arguments = [
'--data_dir', data_dir,
'--output', output,
Expand Down Expand Up @@ -76,7 +76,7 @@ def resnet_train(project_id: dsl.PipelineParam,
region: dsl.PipelineParam=dsl.PipelineParam(name='region', value='us-central1'),
model: dsl.PipelineParam=dsl.PipelineParam(name='model', value='bolts'),
version: dsl.PipelineParam=dsl.PipelineParam(name='version', value='beta1'),
- tf_version: dsl.PipelineParam=dsl.PipelineParam(name='tf-version', value='1.8'),
+ tf_version: dsl.PipelineParam=dsl.PipelineParam(name='tf-version', value='1.9'), # TODO: CMLE TPU doesn't work with 1.9. Waiting for 1.11.
train_csv: dsl.PipelineParam=dsl.PipelineParam(name='train-csv', value='gs://bolts_image_dataset/bolt_images_train.csv'),
validation_csv: dsl.PipelineParam=dsl.PipelineParam(name='validation-csv', value='gs://bolts_image_dataset/bolt_images_validate.csv'),
labels: dsl.PipelineParam=dsl.PipelineParam(name='labels', value='gs://bolts_image_dataset/labels.txt'),
Expand Down

0 comments on commit 2fc92f3

Please sign in to comment.