diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 0916e2345fe4..0d47d2fc1e2c 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -137,15 +137,6 @@ def __init__(self, prefix=None, params=None):
         self._scope = _BlockScope(self)
         self._children = []
 
-    def __setattr__(self, name, value):
-        """Registers parameters."""
-        super(Block, self).__setattr__(name, value)
-        if isinstance(value, Block):
-            self.register_child(value)
-
-    def _alias(self):
-        return self.__class__.__name__.lower()
-
     def __repr__(self):
         s = '{name}(\n{modstr}\n)'
         modstr = '\n'.join(['  ({key}): {block}'.format(key=key,
@@ -154,20 +145,14 @@ def __repr__(self):
         return s.format(name=self.__class__.__name__,
                         modstr=modstr)
 
-    @property
-    def params(self):
-        """Returns this `Block`'s parameter dictionary (does not include its
-        children's parameters)."""
-        return self._params
+    def __setattr__(self, name, value):
+        """Registers parameters."""
+        super(Block, self).__setattr__(name, value)
+        if isinstance(value, Block):
+            self.register_child(value)
 
-    def collect_params(self):
-        """Returns a `ParameterDict` containing this `Block` and all of its
-        children's Parameters."""
-        ret = ParameterDict(self._params.prefix)
-        ret.update(self.params)
-        for cld in self._children:
-            ret.update(cld.collect_params())
-        return ret
+    def _alias(self):
+        return self.__class__.__name__.lower()
 
     @property
     def prefix(self):
@@ -190,6 +175,47 @@ def name_scope(self):
         """
         return self._scope
 
+    @property
+    def params(self):
+        """Returns this `Block`'s parameter dictionary (does not include its
+        children's parameters)."""
+        return self._params
+
+    def collect_params(self):
+        """Returns a `ParameterDict` containing this `Block` and all of its
+        children's Parameters."""
+        ret = ParameterDict(self._params.prefix)
+        ret.update(self.params)
+        for cld in self._children:
+            ret.update(cld.collect_params())
+        return ret
+
+    def save_params(self, filename):
+        """Save parameters to file.
+
+        filename : str
+            Path to file.
+        """
+        self.collect_params().save(filename, strip_prefix=self.prefix)
+
+    def load_params(self, filename, ctx, allow_missing=False,
+                    ignore_extra=False):
+        """Load parameters from file.
+
+        filename : str
+            Path to parameter file.
+        ctx : Context or list of Context
+            Context(s) to initialize loaded parameters on.
+        allow_missing : bool, default False
+            Whether to silently skip loading parameters not present in the file.
+        ignore_extra : bool, default False
+            Whether to silently ignore parameters from the file that are not
+            present in this Block.
+        """
+        self.collect_params().load(filename, ctx, allow_missing, ignore_extra,
+                                   self.prefix)
+
+
     def register_child(self, block):
         """Registers block as a child of self. `Block`s assigned to self as
         attributes will be registered automatically."""
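For reference, a minimal usage sketch of the Block.save_params/load_params methods added above, assuming the gluon API as patched here; the layer size, prefixes, and file name are illustrative and not part of the change (in_units is given so parameters can be initialized without a forward pass):

import mxnet as mx
from mxnet.gluon import nn

net1 = nn.Dense(5, in_units=5, prefix='net1_')
net1.collect_params().initialize()
net1.save_params('net1.params')            # 'net1_' is stripped before saving

# A block with a different prefix can load the same file, because
# load_params passes its own prefix to ParameterDict.load for restoring.
net2 = nn.Dense(5, in_units=5, prefix='net2_')
net2.load_params('net1.params', mx.cpu())

Saving under prefix-free names is what lets a 'net2_' block load a file written by a 'net1_' block; the test added at the end of this patch exercises the same round trip.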
+ """ arg_dict = {} for param in self.values(): block = param.list_data() weight = sum(w.copyto(context.cpu()) for w in block) / len(block) - arg_dict[param.name] = weight + if not param.name.startswith(strip_prefix): + raise ValueError( + "Prefix %s is to be striped before saving, but Parameter " \ + "%s does not start with %s. If you are using Block.save_params, " \ + "This may be due to your Block shares parameters from other " \ + "Blocks or you forgot to use `with name_scope()`` during init. " \ + "Consider switching to Block.collect_params.save and " \ + "Block.collect_params.load instead."%( + strip_prefix, param.name, strip_prefix)) + arg_dict[param.name[len(strip_prefix):]] = weight ndarray.save(filename, arg_dict) - def load(self, filename, ctx, allow_missing=False, ignore_extra=False): - arg_dict = ndarray.load(filename) + def load(self, filename, ctx, allow_missing=False, + ignore_extra=False, restore_prefix=''): + """Load parameters from file. + + filename : str + Path to parameter file. + ctx : Context or list of Context + Context(s) initialize loaded parameters on. + allow_missing : bool, default False + Whether to silently skip loading parameters not represents in the file. + ignore_extra : bool, default False + Whether to silently ignore parameters from the file that are not + present in this ParameterDict. + restore_prefix : str, default '' + prepend prefix to names of stored parameters before loading. + """ + if restore_prefix: + for name in self.keys(): + assert name.startswith(restore_prefix), \ + "restore_prefix is %s but Parameters name %s does not start " \ + "with %s"%(restore_prefix, name, restore_prefix) + lprefix = len(restore_prefix) + arg_dict = {restore_prefix+k: v for k, v in ndarray.load(filename).items()} if not allow_missing: for name in self.keys(): assert name in arg_dict, \ - "Parameter %s is missing in file %s"%(name, filename) + "Parameter %s is missing in file %s"%(name[lprefix:], filename) for name in arg_dict: if name not in self._params: assert ignore_extra, \ "Parameter %s loaded from file %s is not present in ParameterDict"%( - name, filename) + name[lprefix:], filename) continue self[name]._load_init(arg_dict[name], ctx) diff --git a/tests/python/unittest/test_nn.py b/tests/python/unittest/test_nn.py index cc1b2dd48553..58839785b9f2 100644 --- a/tests/python/unittest/test_nn.py +++ b/tests/python/unittest/test_nn.py @@ -39,6 +39,11 @@ def forward(self, x): net1.collect_params().initialize() net2(mx.nd.zeros((3, 5))) + net1.save_params('net1.params') + + net3 = Net(prefix='net3_') + net3.load_params('net1.params', mx.cpu()) + def test_basic(): model = nn.Sequential()