Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[Module] add save load optimizer states (#4408)
Browse files Browse the repository at this point in the history
  • Loading branch information
piiswrong committed Dec 29, 2016
1 parent 12d51e8 commit 528dee0
Show file tree
Hide file tree
Showing 6 changed files with 213 additions and 12 deletions.
28 changes: 28 additions & 0 deletions python/mxnet/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,34 @@
import time
from .model import save_checkpoint

def module_checkpoint(mod, prefix, period=1, save_optimizer_states=False):
"""Callback to checkpoint Module to prefix every epoch.
Parameters
----------
mod : subclass of BaseModule
The module to checkpoint.
prefix : str
The file prefix to checkpoint to
period : int
How many epochs to wait before checkpointing. Default is 1.
save_optimizer_states : bool
Whether to save optimizer states for continue training
Returns
-------
callback : function
The callback function that can be passed as iter_end_callback to fit.
"""
period = int(max(1, period))
# pylint: disable=unused-argument
def _callback(iter_no, sym=None, arg=None, aux=None):
"""The checkpoint function."""
if (iter_no + 1) % period == 0:
mod.save_checkpoint(prefix, iter_no + 1, save_optimizer_states)
return _callback


def do_checkpoint(prefix, period=1):
"""Callback to checkpoint the model to prefix every epoch.
Expand Down
25 changes: 25 additions & 0 deletions python/mxnet/kvstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def __init__(self, handle):
"""
assert isinstance(handle, KVStoreHandle)
self.handle = handle
self._updater = None
self._updater_func = None

def __del__(self):
Expand Down Expand Up @@ -294,6 +295,29 @@ def num_workers(self):
check_call(_LIB.MXKVStoreGetGroupSize(self.handle, ctypes.byref(size)))
return size.value

def save_optimizer_states(self, fname):
"""Save optimizer (updater) state to file
Parameters
----------
fname : str
Path to output states file.
"""
assert self._updater is not None, "Cannot save states for distributed training"
with open(fname, 'wb') as fout:
fout.write(self._updater.get_states())

def load_optimizer_states(self, fname):
"""Load optimizer (updater) state from file
Parameters
----------
fname : str
Path to input states file.
"""
assert self._updater is not None, "Cannot save states for distributed training"
self._updater.set_states(open(fname, 'rb').read())

def _set_updater(self, updater):
"""Set a push updater into the store.
Expand Down Expand Up @@ -322,6 +346,7 @@ def _set_updater(self, updater):
[[ 6. 6. 6.]
[ 6. 6. 6.]]
"""
self._updater = updater
_updater_proto = ctypes.CFUNCTYPE(
None, ctypes.c_int, NDArrayHandle, NDArrayHandle, ctypes.c_void_p)
self._updater_func = _updater_proto(_updater_wrapper(updater))
Expand Down
99 changes: 97 additions & 2 deletions python/mxnet/module/module.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# pylint: disable=too-many-instance-attributes, too-many-arguments, protected-access, too-many-branches
# pylint: disable=too-many-public-methods
"""A `Module` implement the `BaseModule` API by wrapping a `Symbol` and one or
more `Executor` for data parallelization.
"""
Expand All @@ -11,6 +12,7 @@

from .executor_group import DataParallelExecutorGroup
from ..model import _create_kvstore, _initialize_kvstore, _update_params, _update_params_on_kvstore
from ..model import load_checkpoint
from ..initializer import Uniform

from .base_module import BaseModule
Expand Down Expand Up @@ -71,11 +73,69 @@ def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',),
self._kvstore = None
self._update_on_kvstore = None
self._updater = None
self._preload_opt_states = None

self._exec_group = None
self._data_shapes = None
self._label_shapes = None

@staticmethod
def load(prefix, epoch, load_optimizer_states=False, **kwargs):
"""Create a model from previously saved checkpoint.
Parameters
----------
prefix : str
path prefix of saved model files. You should have
"prefix-symbol.json", "prefix-xxxx.params", and
optionally "prefix-xxxx.states", where xxxx is the
epoch number.
epoch : int
epoch to load.
load_optimizer_states : bool
whether to load optimizer states. Checkpoint needs
to have been made with save_optimizer_states=True.
data_names : list of str
Default is `('data')` for a typical model used in image classification.
label_names : list of str
Default is `('softmax_label')` for a typical model used in image
classification.
logger : Logger
Default is `logging`.
context : Context or list of Context
Default is `cpu()`.
work_load_list : list of number
Default `None`, indicating uniform workload.
fixed_param_names: list of str
Default `None`, indicating no network parameters are fixed.
"""
sym, args, auxs = load_checkpoint(prefix, epoch)
mod = Module(symbol=sym, **kwargs)
mod._arg_params = args
mod._aux_params = auxs
mod.params_initialized = True
if load_optimizer_states:
mod._preload_opt_states = '%s-%04d.states'%(prefix, epoch)
return mod

def save_checkpoint(self, prefix, epoch, save_optimizer_states=False):
"""Save current progress to checkpoint.
Use mx.callback.module_checkpoint as epoch_end_callback to save during training.
Parameters
----------
prefix : str
The file prefix to checkpoint to
epoch : int
The current epoch number
save_optimizer_states : bool
Whether to save optimizer states for continue training
"""
self._symbol.save('%s-symbol.json'%prefix)
self.save_params('%s-%04d.params'%(prefix, epoch))
if save_optimizer_states:
self.save_optimizer_states('%s-%04d.states'%(prefix, epoch))

def _reset_bind(self):
"""Internal function to reset binded state."""
self.binded = False
Expand Down Expand Up @@ -341,8 +401,6 @@ def init_optimizer(self, kvstore='local', optimizer='sgd',
self._update_on_kvstore = update_on_kvstore
self._updater = None

if not update_on_kvstore:
self._updater = opt.get_updater(optimizer)
if kvstore:
# copy initialized local parameters to kvstore
_initialize_kvstore(kvstore=kvstore,
Expand All @@ -352,9 +410,15 @@ def init_optimizer(self, kvstore='local', optimizer='sgd',
update_on_kvstore=update_on_kvstore)
if update_on_kvstore:
kvstore.set_optimizer(self._optimizer)
else:
self._updater = opt.get_updater(optimizer)

self.optimizer_initialized = True

if self._preload_opt_states is not None:
self.load_optimizer_states(self._preload_opt_states)
self._preload_opt_states = None

def borrow_optimizer(self, shared_module):
"""Borrow optimizer from a shared module. Used in bucketing, where exactly the same
optimizer (esp. kvstore) is used.
Expand Down Expand Up @@ -472,6 +536,37 @@ def _sync_params_from_devices(self):
"""
self._exec_group.get_params(self._arg_params, self._aux_params)

def save_optimizer_states(self, fname):
"""Save optimizer (updater) state to file
Parameters
----------
fname : str
Path to output states file.
"""
assert self.optimizer_initialized

if self._update_on_kvstore:
self._kvstore.save_optimizer_states(fname)
else:
with open(fname, 'wb') as fout:
fout.write(self._updater.get_states())

def load_optimizer_states(self, fname):
"""Load optimizer (updater) state from file
Parameters
----------
fname : str
Path to input states file.
"""
assert self.optimizer_initialized

if self._update_on_kvstore:
self._kvstore.load_optimizer_states(fname)
else:
self._updater.set_states(open(fname, 'rb').read())

def install_monitor(self, mon):
""" Install monitor on all executors """
assert self.binded
Expand Down
29 changes: 21 additions & 8 deletions python/mxnet/optimizer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# pylint: disable=fixme, invalid-name, unused-argument, too-many-arguments, no-name-in-module
"""Common Optimization algorithms with regularizations."""
import math
import pickle
from .ndarray import NDArray, zeros, clip, sqrt
from .ndarray import sgd_update, sgd_mom_update, adam_update
from .random import normal
Expand Down Expand Up @@ -678,6 +679,25 @@ def update(self, index, weight, grad, state):
# backward compatibility wrapper for Optimizer.CreateOptimizer
create = Optimizer.create_optimizer

class Updater(object):
"""updater for kvstore"""
def __init__(self, optimizer):
self.optimizer = optimizer
self.states = {}

def __call__(self, index, grad, weight):
"""Update weight given gradient and index"""
if index not in self.states:
self.states[index] = self.optimizer.create_state(index, weight)
self.optimizer.update(index, weight, grad, self.states[index])

def set_states(self, states):
"""set updater states"""
self.states = pickle.loads(states)

def get_states(self):
"""get updater states"""
return pickle.dumps(self.states)

def get_updater(optimizer):
"""Return a clossure of the updater needed for kvstore
Expand All @@ -692,11 +712,4 @@ def get_updater(optimizer):
updater: function
The clossure of the updater
"""
states = dict()

def updater(index, grad, weight):
"""updater for kvstore"""
if index not in states:
states[index] = optimizer.create_state(index, weight)
optimizer.update(index, weight, grad, states[index])
return updater
return Updater(optimizer)
40 changes: 40 additions & 0 deletions tests/python/unittest/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,45 @@ def test_module_layout():
for x in mod.get_outputs(merge_multi_context=False)[0]:
assert x.shape == hdshape

def test_save_load():
def dict_equ(a, b):
assert set(a) == set(b)
for k in a:
assert (a[k].asnumpy() == b[k].asnumpy()).all()

sym = mx.sym.Variable('data')
sym = mx.sym.FullyConnected(sym, num_hidden=100)

# single device
mod = mx.mod.Module(sym, ('data',))
mod.bind(data_shapes=[('data', (10, 10))])
mod.init_params()
mod.init_optimizer(optimizer_params={'learning_rate':0.1, 'momentum':0.9})
mod.update()
mod.save_checkpoint('test', 0, save_optimizer_states=True)

mod2 = mx.mod.Module.load('test', 0, load_optimizer_states=True, data_names=('data',))
mod2.bind(data_shapes=[('data', (10, 10))])
mod2.init_optimizer(optimizer_params={'learning_rate':0.1, 'momentum':0.9})
assert mod._symbol.tojson() == mod2._symbol.tojson()
dict_equ(mod.get_params()[0], mod2.get_params()[0])
dict_equ(mod._updater.states, mod2._updater.states)

# multi device
mod = mx.mod.Module(sym, ('data',), context=[mx.cpu(0), mx.cpu(1)])
mod.bind(data_shapes=[('data', (10, 10))])
mod.init_params()
mod.init_optimizer(optimizer_params={'learning_rate':0.1, 'momentum':0.9})
mod.update()
mod.save_checkpoint('test', 0, save_optimizer_states=True)

mod2 = mx.mod.Module.load('test', 0, load_optimizer_states=True, data_names=('data',))
mod2.bind(data_shapes=[('data', (10, 10))])
mod2.init_optimizer(optimizer_params={'learning_rate':0.1, 'momentum':0.9})
assert mod._symbol.tojson() == mod2._symbol.tojson()
dict_equ(mod.get_params()[0], mod2.get_params()[0])
dict_equ(mod._kvstore._updater.states, mod2._updater.states)

if __name__ == '__main__':
test_save_load()
test_module_layout()
4 changes: 2 additions & 2 deletions tools/im2rec.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def parse_args():
help='If true recursively walk through subdirs and assign an unique label\
to images in each folder. Otherwise only include images in the root folder\
and give them label 0.')
cgroup.add_argument('--shuffle', default=True, help='If this is set as True, \
cgroup.add_argument('--shuffle', type=bool, default=True, help='If this is set as True, \
im2rec will randomize the image order in <prefix>.lst')

rgroup = parser.add_argument_group('Options for creating database')
Expand All @@ -220,7 +220,7 @@ def parse_args():
-1:Loads image as such including alpha channel.')
rgroup.add_argument('--encoding', type=str, default='.jpg', choices=['.jpg', '.png'],
help='specify the encoding of the images.')
rgroup.add_argument('--pack-label', default=False,
rgroup.add_argument('--pack-label', type=bool, default=False,
help='Whether to also pack multi dimensional label in the record file')
args = parser.parse_args()
args.prefix = os.path.abspath(args.prefix)
Expand Down

0 comments on commit 528dee0

Please sign in to comment.