diff --git a/keras/utils/__init__.py b/keras/utils/__init__.py
index d11713db4af..e4725783d4c 100644
--- a/keras/utils/__init__.py
+++ b/keras/utils/__init__.py
@@ -21,3 +21,4 @@
 from .vis_utils import plot_model
 from .np_utils import to_categorical
 from .np_utils import normalize
+from .training_utils import multi_gpu_model
diff --git a/keras/utils/training_utils.py b/keras/utils/training_utils.py
new file mode 100644
index 00000000000..b8a200228a2
--- /dev/null
+++ b/keras/utils/training_utils.py
@@ -0,0 +1,149 @@
+from ..layers.merge import concatenate
+from .. import backend as K
+from ..layers.core import Lambda
+from ..engine.training import Model
+
+
+def _get_available_devices():
+    from tensorflow.python.client import device_lib
+    local_device_protos = device_lib.list_local_devices()
+    return [x.name for x in local_device_protos]
+
+
+def multi_gpu_model(model, gpus):
+    """Replicates a model on different GPUs.
+
+    Specifically, this function implements single-machine
+    multi-GPU data parallelism. It works in the following way:
+
+    - Divide the model's input(s) into multiple sub-batches.
+    - Apply a model copy on each sub-batch. Every model copy
+        is executed on a dedicated GPU.
+    - Concatenate the results (on CPU) into one big batch.
+
+    E.g. if your `batch_size` is 64 and you use `gpus=2`,
+    then we will divide the input into 2 sub-batches of 32 samples,
+    process each sub-batch on one GPU, then return the full
+    batch of 64 processed samples.
+
+    This induces quasi-linear speedup on up to 8 GPUs.
+
+    This function is only available with the TensorFlow backend
+    for the time being.
+
+    # Arguments
+        model: A Keras model instance. To avoid OOM errors,
+            this model could have been built on CPU, for instance
+            (see usage example below).
+        gpus: Integer >= 2, number of GPUs on which to create
+            model replicas.
+
+    # Returns
+        A Keras `Model` instance which can be used just like the initial
+        `model` argument, but which distributes its workload on multiple GPUs.
+
+    # Example
+
+    ```python
+        import numpy as np
+        import tensorflow as tf
+        from keras.applications import Xception
+        from keras.utils import multi_gpu_model
+
+        num_samples = 1000
+        height = 224
+        width = 224
+        num_classes = 1000
+
+        # Instantiate the base model
+        # (here, we do it on CPU, which is optional).
+        with tf.device('/cpu:0'):
+            model = Xception(weights=None,
+                             input_shape=(height, width, 3),
+                             classes=num_classes)
+
+        # Replicates the model on 8 GPUs.
+        # This assumes that your machine has 8 available GPUs.
+        parallel_model = multi_gpu_model(model, gpus=8)
+        parallel_model.compile(loss='categorical_crossentropy',
+                               optimizer='rmsprop')
+
+        # Generate dummy data.
+        x = np.random.random((num_samples, height, width, 3))
+        y = np.random.random((num_samples, num_classes))
+
+        # This `fit` call will be distributed on 8 GPUs.
+        # Since the batch size is 256, each GPU will process 32 samples.
+        parallel_model.fit(x, y, epochs=20, batch_size=256)
+    ```
+    """
+    if K.backend() != 'tensorflow':
+        raise ValueError('`multi_gpu_model` is only available '
+                         'with the TensorFlow backend.')
+    if gpus <= 1:
+        raise ValueError('For multi-gpu usage to be effective, '
+                         'call `multi_gpu_model` with `gpus >= 2`. '
+                         'Received: `gpus=%d`' % gpus)
+
+    import tensorflow as tf
+
+    target_devices = ['/cpu:0'] + ['/gpu:%d' % i for i in range(gpus)]
+    available_devices = _get_available_devices()
+    for device in target_devices:
+        if device not in available_devices:
+            raise ValueError(
+                'To call `multi_gpu_model` with `gpus=%d`, '
+                'we expect the following devices to be available: %s. '
+                'However this machine only has: %s. '
+                'Try reducing `gpus`.' % (gpus,
+                                          target_devices,
+                                          available_devices))
+
+    def get_slice(data, i, parts):
+        shape = tf.shape(data)
+        batch_size = shape[:1]
+        input_shape = shape[1:]
+        step = batch_size // parts
+        if i == gpus - 1:
+            size = batch_size - step * i
+        else:
+            size = step
+        size = tf.concat([size, input_shape], axis=0)
+        stride = tf.concat([step, input_shape * 0], axis=0)
+        start = stride * i
+        return tf.slice(data, start, size)
+
+    all_outputs = []
+    for i in range(len(model.outputs)):
+        all_outputs.append([])
+
+    # Place a copy of the model on each GPU,
+    # each getting a slice of the inputs.
+    for i in range(gpus):
+        with tf.device('/gpu:%d' % i):
+            with tf.name_scope('replica_%d' % i):
+                inputs = []
+                # Retrieve a slice of the input.
+                for x in model.inputs:
+                    input_shape = tuple(x.get_shape().as_list())[1:]
+                    slice_i = Lambda(get_slice,
+                                     output_shape=input_shape,
+                                     arguments={'i': i,
+                                                'parts': gpus})(x)
+                    inputs.append(slice_i)
+
+                # Apply model on slice
+                # (creating a model replica on the target device).
+                outputs = model(inputs)
+                if not isinstance(outputs, list):
+                    outputs = [outputs]
+
+                # Save the outputs for merging back together later.
+                for o in range(len(outputs)):
+                    all_outputs[o].append(outputs[o])
+
+    # Merge outputs on CPU.
+    with tf.device('/cpu:0'):
+        merged = []
+        for outputs in all_outputs:
+            merged.append(concatenate(outputs,
+                                      axis=0))
+        return Model(model.inputs, merged)
diff --git a/tests/keras/utils/multi_gpu_test.py b/tests/keras/utils/multi_gpu_test.py
new file mode 100644
index 00000000000..119730bd92b
--- /dev/null
+++ b/tests/keras/utils/multi_gpu_test.py
@@ -0,0 +1,215 @@
+"""These tests are not meant to be run on CI.
+"""
+from __future__ import print_function

+import keras
+from keras import backend as K
+from keras.utils import multi_gpu_model
+
+import numpy as np
+import pytest
+import time
+import tensorflow as tf
+from keras.preprocessing.image import ImageDataGenerator
+
+
+def multi_gpu_test_simple_model():
+    print('####### test simple model')
+    num_samples = 1000
+    input_dim = 10
+    output_dim = 1
+    hidden_dim = 10
+    gpus = 8
+    epochs = 2
+    model = keras.models.Sequential()
+    model.add(keras.layers.Dense(hidden_dim,
+                                 input_shape=(input_dim,)))
+    model.add(keras.layers.Dense(output_dim))
+
+    x = np.random.random((num_samples, input_dim))
+    y = np.random.random((num_samples, output_dim))
+    parallel_model = multi_gpu_model(model, gpus=gpus)
+
+    parallel_model.compile(loss='mse', optimizer='rmsprop')
+    parallel_model.fit(x, y, epochs=epochs)
+
+
+def multi_gpu_test_multi_io_model():
+    print('####### test multi-io model')
+    num_samples = 1000
+    input_dim_a = 10
+    input_dim_b = 5
+    output_dim_a = 1
+    output_dim_b = 2
+    hidden_dim = 10
+    gpus = 8
+    epochs = 2
+
+    input_a = keras.Input((input_dim_a,))
+    input_b = keras.Input((input_dim_b,))
+    a = keras.layers.Dense(hidden_dim)(input_a)
+    b = keras.layers.Dense(hidden_dim)(input_b)
+    c = keras.layers.concatenate([a, b])
+    output_a = keras.layers.Dense(output_dim_a)(c)
+    output_b = keras.layers.Dense(output_dim_b)(c)
+    model = keras.models.Model([input_a, input_b], [output_a, output_b])
+
+    a_x = np.random.random((num_samples, input_dim_a))
+    b_x = np.random.random((num_samples, input_dim_b))
+    a_y = np.random.random((num_samples, output_dim_a))
+    b_y = np.random.random((num_samples, output_dim_b))
+
+    parallel_model = multi_gpu_model(model, gpus=gpus)
+    parallel_model.compile(loss='mse', optimizer='rmsprop')
+    parallel_model.fit([a_x, b_x], [a_y, b_y], epochs=epochs)
+
+
+def multi_gpu_test_invalid_devices():
+    input_shape = (1000, 10)
+    model = keras.models.Sequential()
+    model.add(keras.layers.Dense(10,
+                                 activation='relu',
+                                 input_shape=input_shape[1:]))
+    model.add(keras.layers.Dense(1, activation='sigmoid'))
+    model.compile(loss='mse', optimizer='rmsprop')
+
+    x = np.random.random(input_shape)
+    y = np.random.random((input_shape[0], 1))
+    with pytest.raises(ValueError):
+        parallel_model = multi_gpu_model(model, gpus=10)
+        parallel_model.fit(x, y, epochs=2)
+
+
+def multi_gpu_application_np_array_benchmark():
+    print('####### Xception benchmark - np i/o')
+    model_cls = keras.applications.Xception
+
+    num_samples = 1000
+    height = 224
+    width = 224
+    num_classes = 1000
+    epochs = 4
+    batch_size = 40
+    x = np.random.random((num_samples, height, width, 3))
+    y = np.random.random((num_samples, num_classes))
+
+    # Baseline
+    model = model_cls(weights=None,
+                      input_shape=(height, width, 3),
+                      classes=num_classes)
+    model.compile(loss='categorical_crossentropy',
+                  optimizer='rmsprop')
+
+    # Training
+    start_time = time.time()
+    model.fit(x, y, epochs=epochs)
+    total_time = time.time() - start_time
+    print('baseline training:', total_time)
+
+    # Inference
+    start_time = time.time()
+    model.predict(x)
+    total_time = time.time() - start_time
+    print('baseline inference:', total_time)
+
+    for i in range(8, 9):
+        K.clear_session()
+        with tf.device('/cpu:0'):
+            model = model_cls(weights=None,
+                              input_shape=(height, width, 3),
+                              classes=num_classes)
+        parallel_model = multi_gpu_model(model, gpus=i)
+        parallel_model.compile(loss='categorical_crossentropy',
+                               optimizer='rmsprop')
+
+        start_time = time.time()
+        parallel_model.fit(x, y, epochs=epochs, batch_size=batch_size)
+        total_time = time.time() - start_time
+        print('%d gpus training:' % i, total_time)
+
+        # Inference
+        start_time = time.time()
+        parallel_model.predict(x, batch_size=batch_size)
+        total_time = time.time() - start_time
+        print('%d gpus inference:' % i, total_time)
+
+
+def multi_gpu_application_folder_generator_benchmark():
+    """Before running this test:
+
+    wget https://s3.amazonaws.com/img-datasets/cats_and_dogs_small.zip
+    unzip cats_and_dogs_small.zip
+    """
+    print('####### Xception benchmark - folder generator i/o')
+    model_cls = keras.applications.Xception
+
+    height = 150
+    width = 150
+    num_classes = 2
+    epochs = 3
+    steps_per_epoch = 100
+    batch_size = 64
+
+    # Baseline
+    model = model_cls(weights=None,
+                      input_shape=(height, width, 3),
+                      classes=num_classes)
+    model.compile(loss='categorical_crossentropy',
+                  optimizer='rmsprop')
+
+    datagen = ImageDataGenerator(
+        rotation_range=40,
+        width_shift_range=0.2,
+        height_shift_range=0.2,
+        shear_range=0.2,
+        zoom_range=0.2,
+        horizontal_flip=True,
+        fill_mode='nearest')
+    train_dir = '/home/ubuntu/cats_and_dogs_small/train'  # Change this
+    train_gen = datagen.flow_from_directory(
+        train_dir,
+        target_size=(height, width),
+        batch_size=batch_size,
+        class_mode='categorical')
+
+    # Training
+    start_time = time.time()
+    model.fit_generator(train_gen,
+                        steps_per_epoch=steps_per_epoch,
+                        epochs=epochs,
+                        workers=4)
+    total_time = time.time() - start_time
+    print('baseline training:', total_time)
+
+    for i in range(2, 9):
+        K.clear_session()
+        with tf.device('/cpu:0'):
+            model = model_cls(weights=None,
+                              input_shape=(height, width, 3),
+                              classes=num_classes)
+        parallel_model = multi_gpu_model(model, gpus=i)
+        parallel_model.compile(loss='categorical_crossentropy',
+                               optimizer='rmsprop')
+
+        train_gen = datagen.flow_from_directory(
+            train_dir,
+            target_size=(height, width),
+            batch_size=batch_size,
+            class_mode='categorical')
+
+        start_time = time.time()
+        parallel_model.fit_generator(
+            train_gen,
+            steps_per_epoch=steps_per_epoch,
+            epochs=epochs,
+            workers=4 * i)
+        total_time = time.time() - start_time
+        print('%d gpus training:' % i, total_time)
+
+
+if __name__ == '__main__':
+    multi_gpu_test_simple_model()
+    multi_gpu_test_multi_io_model()
+    multi_gpu_test_invalid_devices()
+    multi_gpu_application_np_array_benchmark()
+    multi_gpu_application_folder_generator_benchmark()
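Note on the slicing logic (not part of the patch): the sketch below mirrors the start/size arithmetic that `get_slice` performs on the batch dimension, written with plain NumPy instead of `tf.shape`/`tf.slice` so it can be checked by hand. The helper name `slice_bounds` is hypothetical. Each replica receives a contiguous sub-batch, and the last replica absorbs the remainder when the batch size is not divisible by `gpus`.

```python
# Illustration only: the same start/size computation as get_slice,
# using a hypothetical helper and plain Python/NumPy.
import numpy as np


def slice_bounds(batch_size, parts):
    """Hypothetical helper: (start, size) of each replica's sub-batch."""
    step = batch_size // parts
    bounds = []
    for i in range(parts):
        # The last replica takes whatever is left, so no sample is dropped.
        size = batch_size - step * i if i == parts - 1 else step
        bounds.append((step * i, size))
    return bounds


x = np.random.random((66, 10))           # batch of 66 samples, 10 features
for start, size in slice_bounds(66, 4):  # -> (0, 16) (16, 16) (32, 16) (48, 18)
    sub_batch = x[start:start + size]    # what one GPU replica would receive
    print(sub_batch.shape)               # (16, 10) three times, then (18, 10)
```

On the CPU, `concatenate(..., axis=0)` then stitches the per-replica outputs back into a single full batch, which is why the returned `Model` can be used just like the original one.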