Implement multigpu processing using weird trick (data parallel for local layers, model parallel for fully-connected layers)

Remove compatibility with cudanet backend for imageset based models
Remove mpi-based parallelism
Remove dependency on NoPar structures
Fix check_grad to initialize with backend
Remove need to link and initialize layers separately
Add documentation for multiple gpu usage, device id specification
apark263 committed Jul 10, 2015
1 parent 1ac5567 commit 1982929
Showing 40 changed files with 1,613 additions and 696 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -4,8 +4,10 @@ neon.sublime-workspace
*.pkl
*.so
*.swp
*.prof
.DS_Store
.tox
*.prof
neon.egg-info
src
build
9 changes: 2 additions & 7 deletions Makefile
@@ -69,18 +69,13 @@ else
ifeq ($(GPU), cudanet)
INSTALL_REQUIRES := $(INSTALL_REQUIRES) \
'git+https://github.com/NervanaSystems/cuda-convnet2.git\#egg=cudanet>=0.2.7' \
'pycuda>=2014.1'
'pycuda>=2015.1'
endif
ifeq ($(GPU), nervanagpu)
INSTALL_REQUIRES := $(INSTALL_REQUIRES) \
'git+https://github.com/NervanaSystems/nervanagpu.git\#egg=nervanagpu>=0.3.2'
'git+https://github.com/NervanaSystems/nervanagpu.git\#egg=nervanagpu>=0.3.3'
endif
endif
ifeq ($(DIST), 0)
NOSE_ATTRS := $(NOSE_ATTRS),'!dist'
else
INSTALL_REQUIRES := $(INSTALL_REQUIRES) 'mpi4py>=1.3.1'
endif

.PHONY: default build develop install uninstall test test_all sanity speed \
grad all clean_pyc clean doc html style lint bench dist publish_doc \
15 changes: 5 additions & 10 deletions bin/neon
@@ -57,12 +57,6 @@ def parse_args():
help=' 1 for stochastic rounding, 0 for deterministic')
parser.add_argument('-n', '--nrv', action='store_true',
help='Attempt to run using the Nervana Engine hardware')
parser.add_argument('-p', '--datapar', action='store_true',
help='Use parallelization to partition the data over '
'multiple nodes')
parser.add_argument('-m', '--modelpar', action='store_true',
help='Use parallelization to partition the model over '
'multiple nodes')
parser.add_argument('-l', '--live', action='store_true',
help='Perform inference on live data')
parser.add_argument('-f', '--flexpoint', action='store_true',
@@ -71,8 +65,10 @@
parser.add_argument('-r', '--rng_seed', type=int,
help='Seed the random number generator for the backend'
' with the specified value.')
parser.add_argument('-i', '--device_id', type=int,
help='Select accelerator device id to run process on')
parser.add_argument('-i', '--device_id', type=int, nargs='+',
help='Select accelerator device id(s) to run process '
'on. (If specifying multiple devices, must occur at '
'end of command)')
parser.add_argument('-e', '--numerr_handling', action='store',
type=yaml.load, help='Set how numeric errors are '
'handled. Python dict syntax, parameters are the same'
@@ -193,9 +189,8 @@ def main():
# carry out the experiment
if not hasattr(experiment, 'backend') or experiment.backend is None:
backend = gen_backend(model=experiment.model, gpu=args.gpu,
nrv=args.nrv, datapar=args.datapar,
nrv=args.nrv,
stochastic_round=args.rounding,
modelpar=args.modelpar,
flexpoint=args.flexpoint,
rng_seed=args.rng_seed,
numerr_handling=args.numerr_handling,
117 changes: 88 additions & 29 deletions doc/source/distributed.rst
@@ -13,47 +13,106 @@
.. limitations under the License.
.. ---------------------------------------------------------------------------
Distributed Implementations using MPI
=====================================
Distributed Implementations using multiple GPUs
===============================================
In an effort to reduce the amount of time it takes to train and run models, it
is often advantageous to split the computation across several processes and
nodes so that they can be run in parallel.
is often advantageous to split the computation across several devices so that
they can be run in parallel. We have implemented multi-GPU support in neon
using the strategy described by Krizhevsky in [AK2014]_.

In neon, we currently have some preliminary support for this via
`MPI <http://www.open-mpi.org/>`_ and
`mpi4py <https://github.com/mpi4py/mpi4py>`_ (see :ref:`mpi_install`) to
install the required dependencies.
Note that we only support parallel computation with multiple GPUs and not on
multiple CPUs. Moreover, multi-GPU computation is only supported via our
``nervanagpu`` backend, which requires Maxwell class devices.

Note that distributed processing support in neon is still very experimental.
Performance speed-ups (at increasing scale) are still forthcoming.
The parallel implementation used in neon has been tested on up to 8 gpus.

Parallelization Model
---------------------
The "weird trick" parallelization model implemented by Krizhevsky uses data
parallel mode in local layers (convolutional, pooling), where the activations
outnumber the model parameters, and model parallel mode in fully connected
layers, where the model parameters outnumber the activations.

Available Models
----------------
In data-parallel mode, activations are fragmented and sent to different
devices, each of which holds a replica of the model parameters. During the
model's update step, the local parameter gradients are shared across devices
to generate a total gradient, which is then applied using the learning rule
for that layer.
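
As a rough illustration of this update step, the following is a minimal NumPy
sketch (not neon's actual implementation; all names here are hypothetical):

.. code-block:: python

    import numpy as np

    # Data-parallel sketch: the minibatch is split into per-device fragments,
    # each replica computes a local gradient, the local gradients are summed
    # into a total gradient, and every replica applies the same update.
    num_dev, batch, nin, nout = 2, 128, 64, 32
    rng = np.random.RandomState(0)

    x = rng.randn(batch, nin)                  # full minibatch of activations
    deltas = rng.randn(batch, nout)            # errors from the layer above
    w = rng.randn(nin, nout)                   # parameters, replicated per device

    x_frags = np.split(x, num_dev)             # activation fragments per device
    d_frags = np.split(deltas, num_dev)

    local_grads = [xf.T.dot(df) for xf, df in zip(x_frags, d_frags)]
    total_grad = sum(local_grads)              # shared across devices (all-reduce)

    w -= 0.01 * total_grad / batch             # identical update on every replica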

Existing Models and Datasets can be parallelized by adding the ``--datapar`` or
``--modelpar`` command line parameters.
In model-parallel mode, activations are shared so that each device receives
the same input activations. Each device retains only a fragment of the model
parameters (a slice of the weight matrix), which it uses to compute a portion
of the output activations. The output activations are then combined to generate
the replica activations that are used for the next layer.
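
The following minimal NumPy sketch illustrates this splitting (again purely
illustrative, not neon's implementation):

.. code-block:: python

    import numpy as np

    # Model-parallel sketch: every device sees the same input activations but
    # holds only a column slice of the weight matrix; each computes a slice of
    # the output, and the slices are concatenated into the replica activations.
    num_dev, batch, nin, nout = 2, 128, 64, 32   # nout must be a multiple of num_dev
    rng = np.random.RandomState(0)

    x = rng.randn(batch, nin)                    # identical input on every device
    w = rng.randn(nin, nout)                     # full weight matrix, for reference

    w_slices = np.split(w, num_dev, axis=1)      # each device keeps nout/num_dev columns
    y_slices = [x.dot(ws) for ws in w_slices]    # partial output activations
    y = np.concatenate(y_slices, axis=1)         # gathered replica activations

    assert np.allclose(y, x.dot(w))              # matches the single-device result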

In the ``--datapar`` (data parallel) approach, data examples are partitioned
and distributed across multiple processes. A separate model replica lives on
each process, and parameter values are synchronized across the models
to ensure each replica remains (eventually) consistent.
Requirements
------------
In order to parallelize across ``N`` GPUs, the following conditions must be
satisfied:

In the ``--modelpar`` (model parallel) approach, layer nodes are partitioned
and distributed across multiple processes. Activations are then communicated
between processes whose nodes are connected. At this time, we support model
parallelism on fully connected model layers only.
- In data parallel mode, the minibatch size must be a multiple of ``N``.
- In model parallel mode, the output size of each fully connected layer must be
  a multiple of ``N``.

Parameter server based asynchronous SGD is not yet implemented, but please
contact us if this is something you need for your use case.
For example, an MLP with no convolutional layers that has 3 hidden layers with
6, 200, and 20 hidden nodes can be parallelized across at most 2 gpus (because
``GCD(6, 200, 20) == 2``). If the first layer had 12 hidden nodes, the model
could be parallelized across 4 gpus.

Examples
--------
Since alexnet has fully connected layers with outputs of 4096, 4096, and 1000,
it can be split across up to 8 gpus (``GCD(4096, 4096, 1000) == 8``) as long as
the minibatch supplied is divisible by 8.
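
The divisibility rule above can be checked with a few lines of Python (a
sketch only; ``max_parallel_gpus`` is a hypothetical helper, not part of neon):

.. code-block:: python

    from functools import reduce
    from math import gcd

    # Largest GPU count that divides every fully connected output size; any
    # divisor of this value is also a valid number of devices, provided the
    # minibatch size is divisible by it as well.
    def max_parallel_gpus(fc_output_sizes):
        return reduce(gcd, fc_output_sizes)

    print(max_parallel_gpus([6, 200, 20]))          # 2
    print(max_parallel_gpus([12, 200, 20]))         # 4
    print(max_parallel_gpus([4096, 4096, 1000]))    # 8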

The following example illustrates how to train a data parallel convnet on the
MNIST dataset using 4 neon processes (on the same host):
Known Issues
============
Dropout Layers
--------------
Dropout layers occur between fully connected layers, which have replicated
activations across devices. However, since the binary masks used for dropout
are generated on device, each activation replica undergoes a different random
masking. This leads to slightly different results when training the same model
in parallel mode versus single-gpu mode. One way to mitigate this difference
would be to share masks during fprop, but this would introduce additional
communication overhead, and in practice we do not observe a penalty in network
performance with the current approach.
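
A toy illustration of the effect (NumPy only, not neon code):

.. code-block:: python

    import numpy as np

    # Both devices hold the same replicated activations, but each draws its own
    # dropout mask from its own RNG stream, so the post-dropout activations (and
    # hence the training trajectory) differ slightly from the single-gpu case.
    acts = np.ones((4, 6), dtype=np.float32)   # replicated FC activations
    keep = 0.5

    mask0 = np.random.RandomState(0).uniform(size=acts.shape) < keep   # device 0
    mask1 = np.random.RandomState(1).uniform(size=acts.shape) < keep   # device 1

    out0, out1 = acts * mask0 / keep, acts * mask1 / keep
    print(np.array_equal(out0, out1))          # False: replicas diverge slightly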


Batch Normalization
-------------------
For convolutional networks, using batch normalization with multiple gpus leads
to faster convergence compared to using a single gpu. This is because each
device sees only a portion of the overall batch, and the fragment batch
statistics are not shared during fprop. In our implementation, we average
batch norm parameter gradients prior to updating to ensure that parameters stay
consistent across model replicas.
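
The fragment-statistics effect can be seen directly (illustrative NumPy only):

.. code-block:: python

    import numpy as np

    # Each device normalizes with the mean/variance of its own batch fragment,
    # which differ from the full-batch statistics a single gpu would compute.
    rng = np.random.RandomState(0)
    batch_acts = rng.randn(128, 16)            # activations for the full minibatch

    full_mean = batch_acts.mean(axis=0)
    frag_means = [f.mean(axis=0) for f in np.split(batch_acts, 4)]   # 4 gpus

    print(np.abs(frag_means[0] - full_mean).max())   # nonzero gap per feature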

In fully connected layers, since activations are replicated on each device, the
batch normalization parameters should be identical across devices without the
need for sharing.

Usage
-----

The following example illustrates how to train a convnet on the MNIST dataset
across 2 gpus (devices selected by default):

.. code-block:: bash

    neon --gpu nervanagpu2 examples/convnet/mnist-small.yaml

The following example illustrates how to train the same convnet with 2 gpus,
but specifying devices 1 and 2 (Note that the device_ids specified here do not
necessarily correspond to how they appear when running ``nvidia-smi``):

.. code-block:: bash

    neon --gpu nervanagpu2 examples/convnet/mnist-small.yaml --device_id 1 2

The following example illustrates how to train a convnet on the i1k alexnet
model included with neon across 4 gpus:

.. code-block:: bash

    mpirun -n 4 neon --datapar examples/convnet/mnist-small.yaml
    neon --gpu nervanagpu4 examples/convnet/i1k-alexnet-fp32.yaml
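
The same multi-GPU backend can also be constructed programmatically; the
following is a minimal sketch based on the ``gen_backend`` signature shown in
this commit (exact argument handling may differ in your installation):

.. code-block:: python

    from neon.backends import gen_backend

    # Two-device nervanagpu backend; when device ids are given, the list length
    # must match the device count encoded in the gpu string ("nervanagpu2").
    be = gen_backend(gpu='nervanagpu2', device_id=[1, 2], rng_seed=0)
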
.. [AK2014] Alex Krizhevsky, One weird trick for parallelizing convolutional neural networks. http://arxiv.org/abs/1404.5997
11 changes: 5 additions & 6 deletions examples/run_integration_tests.sh
@@ -37,17 +37,16 @@ mkdir -p "$(dirname $LOG_FILE)"

cpu_yaml=("${THIS_DIR}/recurrent/mobydick-lstm-small.yaml" \
"${THIS_DIR}/recurrent/mobydick-rnn-small.yaml")
hpu_yaml=("${THIS_DIR}/convnet/i1k-alexnet-fp16.yaml")
gpu_yaml=("${THIS_DIR}/convnet/i1k-alexnet-fp32.yaml")
gpu_yaml=("${THIS_DIR}/convnet/i1k-alexnet-fp16.yaml" \
"${THIS_DIR}/convnet/i1k-alexnet-fp32.yaml")
all_yaml=("${THIS_DIR}/convnet/mnist-small.yaml" \
"${THIS_DIR}/mlp/mnist-small.yaml" \
"${THIS_DIR}/convnet/cifar10-small.yaml" \
"${THIS_DIR}/mlp/cifar10-small.yaml")

cpu_back=("cpu")
hpu_back=("nervanagpu")
gpu_back=("cudanet" "nervanagpu")
all_back=("cpu" "cudanet" "nervanagpu")
gpu_back=("nervanagpu")
all_back=("cpu" "nervanagpu")


run_yamls()
@@ -77,7 +76,7 @@ run_yamls()
done
}

run_yamls hpu_yaml[@] hpu_back[@]
# run_yamls hpu_yaml[@] hpu_back[@]
run_yamls gpu_yaml[@] gpu_back[@]
run_yamls cpu_yaml[@] cpu_back[@]
run_yamls all_yaml[@] all_back[@]
66 changes: 20 additions & 46 deletions neon/backends/__init__.py
@@ -22,12 +22,11 @@

# import shortcuts
from neon.backends.cpu import CPU
from neon.backends.par import NoPar, ModelPar, DataPar


def gen_backend(model=None, gpu=None, nrv=False, datapar=False, modelpar=False,
flexpoint=False, rng_seed=None, numerr_handling=None,
half=False, stochastic_round=0, device_id=None):
def gen_backend(model=None, gpu=None, nrv=False, flexpoint=False,
rng_seed=None, numerr_handling=None, half=False,
stochastic_round=0, device_id=None):
"""
Construct and return a backend instance of the appropriate type based on
the arguments given. With no parameters, a single CPU core, float32
@@ -48,18 +47,6 @@ def gen_backend(model=None, gpu=None, nrv=False, datapar=False, modelpar=False,
for computation (must be installed on the
system). Defaults to False which implies a CPU
based backend.
datapar (bool, optional): Set to True to ensure that data is
partitioned and each chunk is processed in
parallel on different compute cores. Requires
mpi4py. Defaults to False which implies that
all data will be processed sequentially on a
single compute core.
modelpar (bool, optional): Set to True to ensure that the nodes in each
model layer are partitioned and distributed
across multiple compute cores. Requires
mpi4py. Defaults to False which implies
that all nodes in all model layers will be
processed by the same single compute core.
flexpoint (bool, optional): If True, attempt to use FlexPoint(TM)
element typed data instead of the default
float32 which is in place if set to False.
@@ -98,22 +85,6 @@ def gen_backend(model=None, gpu=None, nrv=False, datapar=False, modelpar=False,
logger = logging.getLogger(__name__)
gpuflag = False

if datapar and modelpar:
raise NotImplementedError('Hybrid parallelization scheme not '
'implemented yet. Try with at most one of'
'datapar or modelpar')
if modelpar:
par = ModelPar()
elif datapar:
par = DataPar()
else:
par = NoPar()

if par.device_id is not None:
if device_id is not None:
logger.warn('Ignoring device id specified in command line.')
device_id = par.device_id

if gpu is not None:
gpu = gpu.lower()
if sys.platform.startswith("linux"):
@@ -130,23 +101,27 @@ def gen_backend(model=None, gpu=None, nrv=False, datapar=False, modelpar=False,
except ImportError:
logger.warning("cudanet not found, can't run via GPU")
gpuflag = False
elif gpuflag and gpu == 'nervanagpu':
elif gpuflag and gpu.startswith('nervanagpu'):
try:
import nervanagpu # noqa
try:
# import pycuda.autoinit
import pycuda.driver as drv
drv.init()
device_id = device_id if device_id is not None else 0
global ctx
ctx = drv.Device(device_id).make_context()
import atexit
atexit.register(ctx.pop)
from neon.backends.gpu import GPU
be_name = 'NervanaGPU'
be = GPU(rng_seed=rng_seed,
stochastic_round=stochastic_round,
device_id=device_id)
if gpu == 'nervanagpu':
device_id = 0 if device_id is None else device_id[0]
from neon.backends.gpu import GPU
be = GPU(rng_seed=rng_seed,
stochastic_round=stochastic_round,
device_id=device_id)
else:
from neon.backends.mgpu import MGPU
num_dev = int(gpu.strip('nervanagpu'))
if device_id is not None and len(device_id) != num_dev:
raise RuntimeError("Incorrect number of devices"
" specified ", device_id,
num_dev)
be = MGPU(rng_seed=rng_seed,
stochastic_round=stochastic_round,
device_id=device_id, num_dev=num_dev)
except ImportError:
logger.warning("pycuda error, can't run via GPU")
gpuflag = False
@@ -176,5 +151,4 @@ def gen_backend(model=None, gpu=None, nrv=False, datapar=False, modelpar=False,
logger.info("{} backend, RNG seed: {}, numerr: {}".format
(be_name, rng_seed, numerr_handling))

par.associate(be)
return be
25 changes: 19 additions & 6 deletions neon/backends/backend.py
@@ -30,6 +30,7 @@ class Backend(YAMLable):
Notes:
See the list of `implemented backends </backends.html>`_
"""
is_dist = False

def empty(self, shape, dtype=None, persist_values=True):
"""
@@ -1210,17 +1211,29 @@ def exp_mavg(self, mavg, newval, rho):
"""
raise NotImplementedError()

def distribute(self, data, dtype):
return self.par.distribute(data, dtype)
def empty_like(self, ary, dtype=None, persist_values=True):
return self.empty(ary.shape, dtype=dtype,
persist_values=persist_values)

def rank(self):
return self.par.rank()
def zeros_like(self, ary, dtype=None, persist_values=True):
return self.zeros(ary.shape, dtype=dtype,
persist_values=persist_values)

def set(self, tensor, data):
tensor[:] = data

def is_distributed(self):
return self.par.is_distributed()
return False

def reduce_tensor(self, tensor):
return self.par.reduce_tensor(tensor)
return tensor.asnumpyarray()

def scatter(self, src, dest):
dest.copy_from(src)

def allocate_fragment(self, buf_shape, dtype=None, persist_values=True):
return self.empty(buf_shape, dtype=dtype,
persist_values=persist_values)


class Tensor(object):