Commit
Use MNIST files loaded by SageMaker
EC2 Default User committed May 30, 2018
1 parent 5acda0b commit ee44cab
Showing 2 changed files with 59 additions and 28 deletions.
51 changes: 41 additions & 10 deletions keras/01-custom-container/mnist_cnn.py
@@ -1,7 +1,9 @@
#!/usr/bin/env python3

from __future__ import print_function
import os, sys, json, traceback
import os, sys, json, traceback, gzip
import numpy as np

import keras
from keras.datasets import mnist
from keras.models import Sequential, save_mxnet_model
@@ -11,30 +13,59 @@

from keras.utils import multi_gpu_model

# SageMaker paths
prefix = '/opt/ml/'
input_path = prefix + 'input/data/'
output_path = os.path.join(prefix, 'output')
model_path = os.path.join(prefix, 'model')
param_path = os.path.join(prefix, 'input/config/hyperparameters.json')
data_path = os.path.join(prefix, 'input/config/inputdataconfig.json')

# Load MNIST data copied by SageMaker
def load_data(input_path):
    # Adapted from https://github.com/keras-team/keras/blob/master/keras/datasets/fashion_mnist.py

    # Training and validation files
    files = ['training/train-labels-idx1-ubyte.gz', 'training/train-images-idx3-ubyte.gz',
             'validation/t10k-labels-idx1-ubyte.gz', 'validation/t10k-images-idx3-ubyte.gz']
    # Load training labels
    with gzip.open(input_path+files[0], 'rb') as lbpath:
        y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)
    # Load training samples
    with gzip.open(input_path+files[1], 'rb') as imgpath:
        x_train = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)
    # Load validation labels
    with gzip.open(input_path+files[2], 'rb') as lbpath:
        y_test = np.frombuffer(lbpath.read(), np.uint8, offset=8)
    # Load validation samples
    with gzip.open(input_path+files[3], 'rb') as imgpath:
        x_test = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(y_test), 28, 28)
    print("Files loaded")
    return (x_train, y_train), (x_test, y_test)

# Main code
try:
    # Read in any hyperparameters that the user passed with the training job
    # Read hyper parameters passed by SageMaker
    with open(param_path, 'r') as params:
        trainingParams = json.load(params)
        print(trainingParams)
        hyperParams = json.load(params)
        print("Hyper parameters: " + str(hyperParams))

    lr = float(trainingParams.get('lr', '0.1'))
    batch_size = int(trainingParams.get('batch_size', '128'))
    epochs = int(trainingParams.get('epochs', '10'))
    gpu_count = int(trainingParams.get('gpu_count', '0'))
    lr = float(hyperParams.get('lr', '0.1'))
    batch_size = int(hyperParams.get('batch_size', '128'))
    epochs = int(hyperParams.get('epochs', '10'))
    gpu_count = int(hyperParams.get('gpu_count', '0'))

    num_classes = 10
    # Read input data config passed by SageMaker
    with open(data_path, 'r') as params:
        inputParams = json.load(params)
        print("Input parameters: " + str(inputParams))

    num_classes = 10
    # input image dimensions
    img_rows, img_cols = 28, 28

    # the data, split between train and test sets
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    (x_train, y_train), (x_test, y_test) = load_data(input_path)

    if K.image_data_format() == 'channels_first':
        x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
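For reference, the two offsets in load_data() correspond to the fixed IDX headers of the MNIST files: 8 bytes for the label files (magic number plus item count) and 16 bytes for the image files (magic number, image count, rows, columns). A minimal standalone sketch of the same parsing, assuming a hypothetical local copy of the training channel in ./data/training/ instead of the /opt/ml/input/data/ path used inside the container:

import gzip
import numpy as np

# Hypothetical local copy of the 'training' channel files (an assumption for illustration).
data_dir = './data/training/'

with gzip.open(data_dir + 'train-labels-idx1-ubyte.gz', 'rb') as lbpath:
    # offset=8 skips the IDX label header (magic number + item count)
    y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)

with gzip.open(data_dir + 'train-images-idx3-ubyte.gz', 'rb') as imgpath:
    # offset=16 skips the IDX image header (magic number + count + rows + cols)
    x_train = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)

print(y_train.shape, x_train.shape)  # expected: (60000,) (60000, 28, 28)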
36 changes: 18 additions & 18 deletions keras/01-custom-container/notebook.ipynb
@@ -7,6 +7,19 @@
"# Building CPU and GPU containers for Keras-MXNet on Amazon SageMaker"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sagemaker\n",
"sess = sagemaker.Session()\n",
"role = sagemaker.get_execution_role()\n",
"account = sess.boto_session.client('sts').get_caller_identity()['Account']\n",
"region = sess.boto_session.region_name"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -60,19 +73,6 @@
"## Create and login to a repository in ECR"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sagemaker\n",
"sess = sagemaker.Session()\n",
"role = sagemaker.get_execution_role()\n",
"account = sess.boto_session.client('sts').get_caller_identity()['Account']\n",
"region = sess.boto_session.region_name"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -232,7 +232,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Upload training data to S3"
"## Upload MNIST data to S3"
]
},
{
@@ -242,9 +242,10 @@
"outputs": [],
"source": [
"local_directory = 'data'\n",
"prefix = repo_name+'/input'\n",
"prefix = repo_name+'/input'\n",
"\n",
"input_path = sess.upload_data(local_directory, key_prefix=prefix)"
"train_input_path = sess.upload_data(local_directory+'/train/', key_prefix=prefix+'/train')\n",
"validation_input_path = sess.upload_data(local_directory+'/validation/', key_prefix=prefix+'/validation')"
]
},
{
@@ -263,7 +264,6 @@
"output_path = 's3://{}/{}/output'.format(sess.default_bucket(), repo_name)\n",
"image_name = '{}.dkr.ecr.{}.amazonaws.com/{}:latest'.format(account, region, repo_name)\n",
"\n",
"print(input_path)\n",
"print(output_path)\n",
"print(image_name)\n",
"\n",
@@ -278,7 +278,7 @@
"\n",
"estimator.set_hyperparameters(lr=0.01, epochs=10, gpus=gpu_count, batch_size=batch_size)\n",
"\n",
"estimator.fit(input_path)"
"estimator.fit({'training': train_input_path, 'validation': validation_input_path})"
]
},
{
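The channel names passed to estimator.fit() ('training' and 'validation') surface inside the container as subdirectories of /opt/ml/input/data/, which is exactly where mnist_cnn.py looks for the gzip files; SageMaker also writes the hyperparameters and the input data configuration under /opt/ml/input/config/. A minimal pre-flight sketch, intended to run inside the training container before the Keras code starts (the file list simply mirrors the files list in load_data()):

import os

input_path = '/opt/ml/input/data/'
config_path = '/opt/ml/input/config/'

expected = ['training/train-labels-idx1-ubyte.gz', 'training/train-images-idx3-ubyte.gz',
            'validation/t10k-labels-idx1-ubyte.gz', 'validation/t10k-images-idx3-ubyte.gz']

# Fail early with a clear message if a channel or file was not mounted as expected.
missing = [f for f in expected if not os.path.exists(os.path.join(input_path, f))]
if missing:
    raise FileNotFoundError('Missing channel files: {}'.format(missing))

for cfg in ('hyperparameters.json', 'inputdataconfig.json'):
    print(cfg, 'present:', os.path.exists(os.path.join(config_path, cfg)))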
