Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Merge pull request #113 from tqchen/master
Browse files Browse the repository at this point in the history
Make model training multiple device
  • Loading branch information
tqchen committed Sep 21, 2015
2 parents a48a9e0 + 0bb9546 commit a5941d8
Showing 1 changed file with 152 additions and 66 deletions.
218 changes: 152 additions & 66 deletions python/mxnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from . import symbol as sym
from . import optimizer as opt
from . import metric
from . import kvstore
from .context import Context, cpu
from .initializer import Xavier

Expand Down Expand Up @@ -74,12 +75,54 @@ def _check_arguments(symbol):
return (data_index, label_index)


def _train(symbol, ctx, input_shape,
arg_params, aux_params,
begin_round, end_round, optimizer,
train_data, eval_data=None, eval_metric=None,
iter_end_callback=None, logger=None):
"""Inernal training function.
def _split_input_slice(input_shape, num_split):
"""Get input slice from the input shape.
Parameters
----------
input_shape : tuple
The input shape of the net.
num_split : int
The number of split we want to have.
Returns
-------
slices : list of slice
The split slices to get a specific slice.
shapes : list of tuples
The shape of each split slice.
Raises
------
ValueError
If there are two many splits such that some slice can be empty.
"""
batch_size = input_shape[0]
step = (batch_size + num_split - 1) / num_split
slices = []
shapes = []
for k in range(num_split):
begin = min(k * step, batch_size)
end = min((k+1) * step, batch_size)
if begin == end:
raise ValueError('Too many slices such that some splits are empty')
slices.append(slice(begin, end))
s = list(input_shape)
s[0] = end - begin
shapes.append(tuple(s))
return (slices, shapes)


def _train_multi_device(symbol, ctx, input_shape,
arg_params, aux_params,
begin_round, end_round, optimizer,
train_data, eval_data=None, eval_metric=None,
iter_end_callback=None, logger=None):
"""Internal training function on multiple devices.
This function also works for a single device.
Parameters
----------
Expand Down Expand Up @@ -127,80 +170,121 @@ def _train(symbol, ctx, input_shape,
-----
This function will update the NDArrays in arg_params and aux_params in place.
"""
assert(len(ctx) == 1)
if logger is None:
logger = logging
# bind the symbol
train_exec = symbol.simple_bind(ctx[0], data=input_shape, grad_req='write')
# preparation
num_device = len(ctx)
logging.info('Start training with %d devices', num_device)

slices, shapes = _split_input_slice(input_shape, num_device)
train_execs = [symbol.simple_bind(ctx=c, data=s, grad_req='write')
for c, s in zip(ctx, shapes)]
arg_names = symbol.list_arguments()
aux_names = symbol.list_auxiliary_states()
arg_arrays = train_exec.arg_arrays
grad_arrays = train_exec.grad_arrays
aux_arrays = train_exec.aux_arrays
# copy initialized parameters to executor parameters
for key, weight in zip(arg_names, arg_arrays):
if key in arg_params:
arg_params[key].copyto(weight)
for key, weight in zip(aux_names, aux_arrays):
if key in aux_params:
aux_params[key].copyto(weight)
# setup helper data structures
# data structure
arg_blocks = [
[x.arg_arrays[index] for x in train_execs]
for index in range(len(train_execs[0].arg_arrays))]
grad_blocks = [
[x.grad_arrays[index] for x in train_execs]
for index in range(len(train_execs[0].grad_arrays))]
aux_blocks = [
[x.aux_arrays[index] for x in train_execs]
for index in range(len(train_execs[0].aux_arrays))]
for name, block in zip(arg_names, arg_blocks):
if name in arg_params:
for w in block:
arg_params[name].copyto(w)
for name, block in zip(aux_names, aux_blocks):
if name in aux_params:
for w in block:
aux_params[name].copyto(w)
# key value store
kv = kvstore.create() if num_device != 1 else None
# If there are multiple devices, initialize the weights.
for index, pair in enumerate(zip(arg_blocks, grad_blocks)):
arg, grad = pair
if kv and grad[0] is not None:
kv.init(index, arg[0])
# Input and output data structure
data_index, label_index = _check_arguments(symbol)
data_array, label_array = arg_arrays[data_index], arg_arrays[label_index]
out_array = train_exec.outputs[0]
out_cpu_array = nd.zeros(out_array.shape)
arg_blocks = list(zip(arg_arrays, grad_arrays))

for i in range(begin_round, end_round):
# training phase
merged_shape = list(train_execs[0].outputs[0].shape)
merged_shape[0] = input_shape[0]
merged_shape = tuple(merged_shape)
out_cpu_array = nd.zeros(merged_shape, cpu())

# Now start training
for iteration in range(begin_round, end_round):
# Training phase
tic = time.time()
train_data.reset()
optimizer.begin_round(i)
optimizer.begin_round(iteration)
eval_metric.reset()

# Iterate over training data.
for data, label in train_data:
label.copyto(label_array)
data.copyto(data_array)
train_exec.forward()
out_array.copyto(out_cpu_array)
train_exec.backward()
# Copy data into the target
for target, islice in zip(arg_blocks[label_index], slices):
label[islice].copyto(target)
for target, islice in zip(arg_blocks[data_index], slices):
data[islice].copyto(target)
# forward backward pass
for texec, islice in zip(train_execs, slices):
texec.forward()
texec.outputs[0].copyto(out_cpu_array[islice])
for texec in train_execs:
texec.backward()
# update the parameters
for index, block in enumerate(arg_blocks):
weight, grad = block
if grad is not None:
optimizer.update(index, weight, grad)
for index, pair in enumerate(zip(arg_blocks, grad_blocks)):
arg_list, grad_list = pair
if grad_list[0] is None:
continue
# Gradient synchronization
if kv:
# push gradient
kv.push(index, grad_list)
# pull back the sum, to the same locations.
kv.pull(index, grad_list)
# optimize
for w, g in zip(arg_list, grad_list):
optimizer.update(index, w, g)
# evaluate at end, so out_cpu_array can lazy copy
eval_metric.update(out_cpu_array, label)

name, value = eval_metric.get()
logger.info('Iteration[%d] Train-%s=%f', i, name, value)
logger.info('Iteration[%d] Train-%s=%f', iteration, name, value)
toc = time.time()
logger.info('Iteration[%d] Time cost=%.3f', i, (toc - tic))

# evaluation phase
if eval_data is not None:
logger.info('Iteration[%d] Time cost=%.3f', iteration, (toc - tic))
# evaluation
if eval_data:
eval_metric.reset()
eval_data.reset()
for data, label in eval_data:
data.copyto(data_array)
# TODO(bing): add is_train=False
train_exec.forward(is_train=False)
out_array.copyto(out_cpu_array)
eval_metric.update(out_array, label)

# Copy data into the target
for target, islice in zip(arg_blocks[label_index], slices):
label[islice].copyto(target)
for target, islice in zip(arg_blocks[data_index], slices):
data[islice].copyto(target)
# forward pass
for texec, islice in zip(train_execs, slices):
texec.forward(is_train=False)
texec.outputs[0].copyto(out_cpu_array[islice])
eval_metric.update(out_cpu_array, label)
name, value = eval_metric.get()
logger.info('Iteration[%d] Validation-%s=%f', i, name, value)
logger.info('Iteration[%d] Validation-%s=%f', iteration, name, value)

if iter_end_callback or i + 1 == end_round:
if iter_end_callback or iteration + 1 == end_round:
# copy data back to cpu
for key, weight in zip(arg_names, arg_arrays):
if key in arg_params:
weight.copyto(arg_params[key])
for key, arr in zip(aux_names, aux_arrays):
arr.copyto(aux_params[key])
for name, block in zip(arg_names, arg_blocks):
if name in arg_params:
weight = sum(w.copyto(cpu()) for w in block) / len(block)
weight.copyto(arg_params[name])
for name, block in zip(aux_names, aux_blocks):
if name in aux_params:
weight = sum(w.copyto(cpu()) for w in block) / len(block)
weight.copyto(aux_params[name])
if iter_end_callback:
iter_end_callback(i, symbol, arg_params, aux_params)
# end of the function
iter_end_callback(iteration, symbol, arg_params, aux_params)
# end of all iterations
return


Expand Down Expand Up @@ -332,6 +416,8 @@ def __init__(self, symbol, ctx=None,
num_round=None, optimizer='sgd', initializer=Xavier(),
arg_params=None, aux_params=None,
**kwargs):
# check if symbol contain duplicated names.
_check_arguments(symbol)
# basic configuration
self.symbol = symbol
if ctx is None:
Expand Down Expand Up @@ -467,14 +553,14 @@ def fit(self, X, y=None, eval_data=None, eval_metric='acc',
batch_size = input_shape[0]
optimizer = opt.create(optimizer, rescale_grad=(1.0/batch_size), **(self.kwargs))
# do training
_train(self.symbol, self.ctx, input_shape,
self.arg_params, self.aux_params,
begin_round=0, end_round=self.num_round,
optimizer=optimizer,
train_data=X, eval_data=eval_data,
eval_metric=eval_metric,
iter_end_callback=iter_end_callback,
logger=logger)
_train_multi_device(self.symbol, self.ctx, input_shape,
self.arg_params, self.aux_params,
begin_round=0, end_round=self.num_round,
optimizer=optimizer,
train_data=X, eval_data=eval_data,
eval_metric=eval_metric,
iter_end_callback=iter_end_callback,
logger=logger)

def save(self, prefix, iteration=None):
"""Checkpoint the model checkpoint into file.
Expand Down

0 comments on commit a5941d8

Please sign in to comment.