forked from keras-team/keras
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
365 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
from ..layers.merge import concatenate | ||
from .. import backend as K | ||
from ..layers.core import Lambda | ||
from ..engine.training import Model | ||
|
||
|
||
def _normalize_device_name(name):
    """Converts a TensorFlow device name to the short `/type:index` form.

    Only the last two colon-separated fields of the lowercased name
    are kept, so `/device:GPU:0` becomes `/gpu:0`, while a name that
    is already short (e.g. `/cpu:0`) is returned unchanged.

    # Arguments
        name: String, a device name as reported by TensorFlow.

    # Returns
        The normalized device name, as a string.
    """
    fields = name.lower().replace('/', '').split(':')
    return '/' + ':'.join(fields[-2:])


def _get_available_devices():
    """Lists the devices reported by the local TensorFlow runtime.

    Device names are normalized to the short `/type:index` form
    (e.g. `/cpu:0`, `/gpu:0`) so that they can be compared against
    short device strings whether TensorFlow reports `/gpu:0`
    or `/device:GPU:0`.

    # Returns
        A list of normalized device name strings.
    """
    # Imported lazily so that this module can be imported
    # without TensorFlow being installed.
    from tensorflow.python.client import device_lib
    local_device_protos = device_lib.list_local_devices()
    return [_normalize_device_name(x.name) for x in local_device_protos]
|
||
|
||
def multi_gpu_model(model, gpus):
    """Replicates a model on different GPUs.

    Specifically, this function implements single-machine
    multi-GPU data parallelism. It works in the following way:

    - Divide the model's input(s) into multiple sub-batches.
    - Apply a model copy on each sub-batch. Every model copy
        is executed on a dedicated GPU.
    - Concatenate the results (on CPU) into one big batch.

    E.g. if your `batch_size` is 64 and you use `gpus=2`,
    then we will divide the input into 2 sub-batches of 32 samples,
    process each sub-batch on one GPU, then return the full
    batch of 64 processed samples.

    This induces quasi-linear speedup on up to 8 GPUs.

    This function is only available with the TensorFlow backend
    for the time being.

    # Arguments
        model: A Keras model instance. To avoid OOM errors,
            this model could have been built on CPU, for instance
            (see usage example below).
        gpus: Integer >= 2, number of GPUs on which to create
            model replicas.

    # Returns
        A Keras `Model` instance which can be used just like the initial
        `model` argument, but which distributes its workload on multiple GPUs.

    # Raises
        ValueError: if the backend is not TensorFlow, if `gpus <= 1`,
            or if `/cpu:0` or one of `/gpu:0` ... `/gpu:<gpus - 1>`
            is not among the available devices.

    # Example

    ```python
        import numpy as np
        import tensorflow as tf
        from keras.applications import Xception

        num_samples = 1000
        height = 224
        width = 224
        num_classes = 1000

        # Instantiate the base model
        # (here, we do it on CPU, which is optional).
        with tf.device('/cpu:0'):
            model = Xception(weights=None,
                             input_shape=(height, width, 3),
                             classes=num_classes)

        # Replicates the model on 8 GPUs.
        # This assumes that your machine has 8 available GPUs.
        parallel_model = multi_gpu_model(model, gpus=8)
        parallel_model.compile(loss='categorical_crossentropy',
                               optimizer='rmsprop')

        # Generate dummy data.
        x = np.random.random((num_samples, height, width, 3))
        y = np.random.random((num_samples, num_classes))

        # This `fit` call will be distributed on 8 GPUs.
        # Since the batch size is 256, each GPU will process 32 samples.
        parallel_model.fit(x, y, epochs=20, batch_size=256)
    ```
    """
    if K.backend() != 'tensorflow':
        raise ValueError('`multi_gpu_model` is only available '
                         'with the TensorFlow backend.')
    if gpus <= 1:
        raise ValueError('For multi-gpu usage to be effective, '
                         'call `multi_gpu_model` with `gpus >= 2`. '
                         'Received: `gpus=%d`' % gpus)

    import tensorflow as tf

    # The CPU is required as well: it hosts the final concatenation.
    target_devices = ['/cpu:0'] + ['/gpu:%d' % i for i in range(gpus)]
    available_devices = _get_available_devices()
    for device in target_devices:
        if device not in available_devices:
            raise ValueError(
                'To call `multi_gpu_model` with `gpus=%d`, '
                'we expect the following devices to be available: %s. '
                'However this machine only has: %s. '
                'Try reducing `gpus`.' % (gpus,
                                          target_devices,
                                          available_devices))

    def get_slice(data, i, parts):
        # Returns the `i`-th of `parts` contiguous slices of `data`
        # along the first (batch) axis.
        shape = tf.shape(data)
        batch_size = shape[:1]
        input_shape = shape[1:]
        step = batch_size // parts
        if i == parts - 1:
            # The last slice also takes the remainder left over when
            # the batch size is not a multiple of `parts`.
            size = batch_size - step * i
        else:
            size = step
        size = tf.concat([size, input_shape], axis=0)
        # Only the batch axis is offset; all other axes start at 0.
        stride = tf.concat([step, input_shape * 0], axis=0)
        start = stride * i
        return tf.slice(data, start, size)

    # One list per model output, collecting that output from every replica.
    all_outputs = [[] for _ in model.outputs]

    # Place a copy of the model on each GPU,
    # each getting a slice of the inputs.
    for i in range(gpus):
        with tf.device('/gpu:%d' % i):
            with tf.name_scope('replica_%d' % i):
                inputs = []
                # Retrieve a slice of the input.
                for x in model.inputs:
                    input_shape = tuple(x.get_shape().as_list())[1:]
                    slice_i = Lambda(get_slice,
                                     output_shape=input_shape,
                                     arguments={'i': i,
                                                'parts': gpus})(x)
                    inputs.append(slice_i)

                # Apply model on slice
                # (creating a model replica on the target device).
                outputs = model(inputs)
                if not isinstance(outputs, list):
                    outputs = [outputs]

                # Save the outputs for merging back together later.
                for collected, output in zip(all_outputs, outputs):
                    collected.append(output)

    # Merge outputs on CPU.
    with tf.device('/cpu:0'):
        merged = [concatenate(outputs, axis=0) for outputs in all_outputs]
    return Model(model.inputs, merged)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,215 @@ | ||
"""These tests are not meant to be run on CI. | ||
""" | ||
from __future__ import print_function | ||
|
||
import keras | ||
from keras import backend as K | ||
from keras.utils import multi_gpu_model | ||
|
||
import numpy as np | ||
import pytest | ||
import time | ||
import tensorflow as tf | ||
from keras.preprocessing.image import ImageDataGenerator | ||
|
||
|
||
def multi_gpu_test_simple_model():
    """Fits a two-layer `Sequential` model replicated on 8 GPUs.

    The inputs and targets are random arrays; the run only checks that
    building, compiling and fitting the parallel model goes through.
    """
    print('####### test simple model')
    n_samples, in_dim, out_dim, n_hidden = 1000, 10, 1, 10
    n_gpus, n_epochs = 8, 2

    base_model = keras.models.Sequential([
        keras.layers.Dense(n_hidden, input_shape=(in_dim,)),
        keras.layers.Dense(out_dim),
    ])

    features = np.random.random((n_samples, in_dim))
    targets = np.random.random((n_samples, out_dim))

    replicated = multi_gpu_model(base_model, gpus=n_gpus)
    replicated.compile(loss='mse', optimizer='rmsprop')
    replicated.fit(features, targets, epochs=n_epochs)
|
||
|
||
def multi_gpu_test_multi_io_model():
    """Fits a two-input, two-output functional model replicated on 8 GPUs.

    Both inputs go through their own `Dense` layer, are concatenated,
    and feed two separate `Dense` heads. All data is random.
    """
    print('####### test multi-io model')
    n_samples = 1000
    in_dims = (10, 5)
    out_dims = (1, 2)
    n_hidden = 10
    n_gpus = 8
    n_epochs = 2

    # Graph definition (layers are created in the same order as before).
    in_a = keras.Input((in_dims[0],))
    in_b = keras.Input((in_dims[1],))
    branch_a = keras.layers.Dense(n_hidden)(in_a)
    branch_b = keras.layers.Dense(n_hidden)(in_b)
    joined = keras.layers.concatenate([branch_a, branch_b])
    head_a = keras.layers.Dense(out_dims[0])(joined)
    head_b = keras.layers.Dense(out_dims[1])(joined)
    base_model = keras.models.Model([in_a, in_b], [head_a, head_b])

    # Random inputs first, then random targets.
    xs = [np.random.random((n_samples, dim)) for dim in in_dims]
    ys = [np.random.random((n_samples, dim)) for dim in out_dims]

    replicated = multi_gpu_model(base_model, gpus=n_gpus)
    replicated.compile(loss='mse', optimizer='rmsprop')
    replicated.fit(xs, ys, epochs=n_epochs)
|
||
|
||
def multi_gpu_test_invalid_devices():
    """Checks that asking for unavailable GPUs raises a `ValueError`.

    The model is replicated with `gpus=10`, so this check assumes that
    the machine running it has fewer than 10 GPUs.
    """
    input_shape = (1000, 10)
    model = keras.models.Sequential()
    model.add(keras.layers.Dense(10,
                                 activation='relu',
                                 input_shape=input_shape[1:]))
    model.add(keras.layers.Dense(1, activation='sigmoid'))
    model.compile(loss='mse', optimizer='rmsprop')

    # Only the call expected to fail goes inside the `raises` block:
    # any statement placed after it would never run when the check
    # passes, and could hide a regression by raising on its own.
    with pytest.raises(ValueError):
        multi_gpu_model(model, gpus=10)
|
||
|
||
def multi_gpu_application_np_array_benchmark():
    """Times Xception on random NumPy arrays, with and without replication.

    A baseline model is built outside any explicit device scope and
    fitted without a `batch_size` argument. Then, for `gpus=8`, a model
    built under `/cpu:0` is wrapped with `multi_gpu_model` and run with
    `batch_size=40`. Training and inference wall-clock times are printed
    for both configurations.
    """
    print('####### Xception benchmark - np i/o')
    model_cls = keras.applications.Xception

    num_samples = 1000
    height = width = 224
    num_classes = 1000
    epochs = 4
    batch_size = 40
    image_shape = (height, width, 3)

    x = np.random.random((num_samples,) + image_shape)
    y = np.random.random((num_samples, num_classes))

    def new_network():
        # Build an Xception with `weights=None` for the benchmark shapes.
        return model_cls(weights=None,
                         input_shape=image_shape,
                         classes=num_classes)

    # Baseline
    model = new_network()
    model.compile(loss='categorical_crossentropy',
                  optimizer='rmsprop')

    # Training
    tic = time.time()
    model.fit(x, y, epochs=epochs)
    elapsed = time.time() - tic
    print('baseline training:', elapsed)

    # Inference
    tic = time.time()
    model.predict(x)
    elapsed = time.time() - tic
    print('baseline inference:', elapsed)

    for n_gpus in range(8, 9):
        # Drop the previous graph before building the next configuration.
        K.clear_session()
        with tf.device('/cpu:0'):
            model = new_network()
        parallel_model = multi_gpu_model(model, gpus=n_gpus)
        parallel_model.compile(loss='categorical_crossentropy',
                               optimizer='rmsprop')

        # Training
        tic = time.time()
        parallel_model.fit(x, y, epochs=epochs, batch_size=batch_size)
        elapsed = time.time() - tic
        print('%d gpus training:' % n_gpus, elapsed)

        # Inference
        tic = time.time()
        parallel_model.predict(x, batch_size=batch_size)
        elapsed = time.time() - tic
        print('%d gpus inference:' % n_gpus, elapsed)
|
||
|
||
def multi_gpu_application_folder_generator_benchmark(
        train_dir='/home/ubuntu/cats_and_dogs_small/train'):
    """Times Xception training fed by a directory image generator.

    Before running this test:

        wget https://s3.amazonaws.com/img-datasets/cats_and_dogs_small.zip
        unzip cats_and_dogs_small.zip

    A baseline model is trained first, then one `multi_gpu_model`
    replica set for each GPU count from 2 to 8. The wall-clock
    training time of every configuration is printed.

    # Arguments
        train_dir: String, path to the training directory passed to
            `ImageDataGenerator.flow_from_directory`. Defaults to the
            location used on the original benchmark machine; pass the
            `train` folder of the unzipped archive on other machines.
    """
    print('####### Xception benchmark - folder generator i/o')
    model_cls = keras.applications.Xception

    height = 150
    width = 150
    num_classes = 2
    epochs = 3
    steps_per_epoch = 100
    batch_size = 64

    # Baseline
    model = model_cls(weights=None,
                      input_shape=(height, width, 3),
                      classes=num_classes)
    model.compile(loss='categorical_crossentropy',
                  optimizer='rmsprop')

    datagen = ImageDataGenerator(
        rotation_range=40,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest')

    def make_train_generator():
        # A new iterator is created for every configuration.
        return datagen.flow_from_directory(
            train_dir,
            target_size=(height, width),
            batch_size=batch_size,
            class_mode='categorical')

    train_gen = make_train_generator()

    # Training
    start_time = time.time()
    model.fit_generator(train_gen,
                        steps_per_epoch=steps_per_epoch,
                        epochs=epochs,
                        workers=4)
    total_time = time.time() - start_time
    print('baseline training:', total_time)

    for i in range(2, 9):
        K.clear_session()
        with tf.device('/cpu:0'):
            model = model_cls(weights=None,
                              input_shape=(height, width, 3),
                              classes=num_classes)
        parallel_model = multi_gpu_model(model, gpus=i)
        parallel_model.compile(loss='categorical_crossentropy',
                               optimizer='rmsprop')

        train_gen = make_train_generator()

        start_time = time.time()
        # The number of generator workers grows with the number of GPUs.
        parallel_model.fit_generator(
            train_gen,
            steps_per_epoch=steps_per_epoch,
            epochs=epochs,
            workers=4 * i)
        total_time = time.time() - start_time
        print('%d gpus training:' % i, total_time)
|
||
|
||
if __name__ == '__main__':
    # Run the checks first, then the benchmarks, in a fixed order.
    for entry_point in (multi_gpu_test_simple_model,
                        multi_gpu_test_multi_io_model,
                        multi_gpu_test_invalid_devices,
                        multi_gpu_application_np_array_benchmark,
                        multi_gpu_application_folder_generator_benchmark):
        entry_point()