Add parallel_model.py for GPU_COUNT > 1
1 parent ca1565e · commit 32abe21
Showing 1 changed file with 173 additions and 0 deletions.
parallel_model.py
@@ -0,0 +1,173 @@
""" | ||
Mask R-CNN | ||
Multi-GPU Support for Keras. | ||
Copyright (c) 2017 Matterport, Inc. | ||
Licensed under the MIT License (see LICENSE for details) | ||
Written by Waleed Abdulla | ||
Ideas and a small code snippets from these sources: | ||
https://github.com/fchollet/keras/issues/2436 | ||
https://medium.com/@kuza55/transparent-multi-gpu-training-on-tensorflow-with-keras-8b0016fd9012 | ||
https://github.com/avolkov1/keras_experiments/blob/master/keras_exp/multigpu/ | ||
https://github.com/fchollet/keras/blob/master/keras/utils/training_utils.py | ||
""" | ||
|
||
import tensorflow as tf | ||
import keras.backend as K | ||
import keras.layers as KL | ||
import keras.models as KM | ||
|
||
|
||
class ParallelModel(KM.Model): | ||
"""Subclasses the standard Keras Model and adds multi-GPU support. | ||
It works by creating a copy of the model on each GPU. Then it slices | ||
the inputs and sends a slice to each copy of the model, and then | ||
merges the outputs together and applies the loss on the combined | ||
outputs. | ||
""" | ||

    def __init__(self, keras_model, gpu_count):
        """Class constructor.
        keras_model: The Keras model to parallelize
        gpu_count: Number of GPUs. Must be > 1
        """
        self.inner_model = keras_model
        self.gpu_count = gpu_count
        merged_outputs = self.make_parallel()
        super(ParallelModel, self).__init__(inputs=self.inner_model.inputs,
                                            outputs=merged_outputs)

    def __getattribute__(self, attrname):
        """Redirect loading and saving methods to the inner model. That's where
        the weights are stored."""
        if 'load' in attrname or 'save' in attrname:
            return getattr(self.inner_model, attrname)
        return super(ParallelModel, self).__getattribute__(attrname)
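
    # As a result, calls such as save_weights()/load_weights() on a
    # ParallelModel act on the single-GPU inner model, so checkpoints stay
    # compatible with models built without the multi-GPU wrapper.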

    def summary(self, *args, **kwargs):
        """Override summary() to display summaries of both the wrapper
        and the inner model."""
        super(ParallelModel, self).summary(*args, **kwargs)
        self.inner_model.summary(*args, **kwargs)

    def make_parallel(self):
        """Creates a new wrapper model that consists of multiple replicas of
        the original model placed on different GPUs.
        """
        # Slice inputs on the CPU to avoid sending a copy of the full inputs
        # to all GPUs. Saves on bandwidth and memory.
        input_slices = {name: tf.split(x, self.gpu_count)
                        for name, x in zip(self.inner_model.input_names,
                                           self.inner_model.inputs)}
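        # For example, with gpu_count=2 and a batch of 64 images, tf.split
        # yields two slices of 32 per input tensor, so the batch size must be
        # divisible by the GPU count.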

        output_names = self.inner_model.output_names
        outputs_all = []
        for i in range(len(self.inner_model.outputs)):
            outputs_all.append([])

        # Run the model call() on each GPU to place the ops there
        for i in range(self.gpu_count):
            with tf.device('/gpu:%d' % i):
                with tf.name_scope('tower_%d' % i):
                    # Run a slice of inputs through this replica
                    zipped_inputs = zip(self.inner_model.input_names,
                                        self.inner_model.inputs)
                    inputs = [
                        KL.Lambda(lambda s: input_slices[name][i],
                                  output_shape=lambda s: (None,) + s[1:])(tensor)
                        for name, tensor in zipped_inputs]
                    # Create the model replica and get the outputs
                    outputs = self.inner_model(inputs)
                    if not isinstance(outputs, list):
                        outputs = [outputs]
                    # Save the outputs for merging back together later
                    for l, o in enumerate(outputs):
                        outputs_all[l].append(o)
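        # After the loop, outputs_all[k] holds gpu_count tensors (one per
        # tower) for output k; these are concatenated along the batch axis
        # below.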

        # Merge outputs on CPU
        with tf.device('/cpu:0'):
            merged = []
            for outputs, name in zip(outputs_all, output_names):
                # If outputs are numbers without dimensions, add a batch dim.
                def add_dim(tensor):
                    """Add a dimension to tensors that don't have any."""
                    if K.int_shape(tensor) == ():
                        return KL.Lambda(lambda t: K.reshape(t, [1, 1]))(tensor)
                    return tensor
                outputs = list(map(add_dim, outputs))
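                # For example, a scalar per-tower loss becomes a (1, 1) tensor,
                # so concatenating 2 towers below gives shape (2, 1) instead of
                # failing on zero-dimensional inputs.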

                # Concatenate
                merged.append(KL.Concatenate(axis=0, name=name)(outputs))
        return merged


if __name__ == "__main__":
    # Testing code below. It creates a simple model to train on MNIST and
    # tries to run it on 2 GPUs. It saves the graph so it can be viewed
    # in TensorBoard. Run it as:
    #
    # python3 parallel_model.py

    import os
    import numpy as np
    import keras.optimizers
    from keras.datasets import mnist
    from keras.preprocessing.image import ImageDataGenerator

    GPU_COUNT = 2

    # Root directory of the project
    ROOT_DIR = os.getcwd()

    # Directory to save logs and trained model
    MODEL_DIR = os.path.join(ROOT_DIR, "logs/parallel")

    def build_model(x_train, num_classes):
        # Reset default graph. Keras leaves old ops in the graph,
        # which are ignored for execution but clutter graph
        # visualization in TensorBoard.
        tf.reset_default_graph()

        inputs = KL.Input(shape=x_train.shape[1:], name="input_image")
        x = KL.Conv2D(32, (3, 3), activation='relu', padding="same",
                      name="conv1")(inputs)
        x = KL.Conv2D(64, (3, 3), activation='relu', padding="same",
                      name="conv2")(x)
        x = KL.MaxPooling2D(pool_size=(2, 2), name="pool1")(x)
        x = KL.Flatten(name="flat1")(x)
        x = KL.Dense(128, activation='relu', name="dense1")(x)
        x = KL.Dense(num_classes, activation='softmax', name="dense2")(x)

        return KM.Model(inputs, x, "digit_classifier_model")

    # Load MNIST Data
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train = np.expand_dims(x_train, -1).astype('float32') / 255
    x_test = np.expand_dims(x_test, -1).astype('float32') / 255

    print('x_train shape:', x_train.shape)
    print('x_test shape:', x_test.shape)

    # Build data generator and model
    datagen = ImageDataGenerator()
    model = build_model(x_train, 10)

    # Add multi-GPU support.
    model = ParallelModel(model, GPU_COUNT)

    optimizer = keras.optimizers.SGD(lr=0.01, momentum=0.9, clipnorm=5.0)

    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=optimizer, metrics=['accuracy'])

    model.summary()

    # Train
    model.fit_generator(
        datagen.flow(x_train, y_train, batch_size=64),
        steps_per_epoch=50, epochs=10, verbose=1,
        validation_data=(x_test, y_test),
        callbacks=[keras.callbacks.TensorBoard(log_dir=MODEL_DIR,
                                               write_graph=True)]
    )
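
    # Note: the generator's batch_size (64) is split evenly across GPU_COUNT
    # towers, so each GPU processes 32 images per training step.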