Skip to content

Commit

Permalink
add save_params and load_params to Block (apache#7097)
Browse files Browse the repository at this point in the history
* add save_params and load_params to Block

* fix
  • Loading branch information
piiswrong authored Jul 21, 2017
1 parent 799ed45 commit 37c1823
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 28 deletions.
70 changes: 48 additions & 22 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,15 +137,6 @@ def __init__(self, prefix=None, params=None):
self._scope = _BlockScope(self)
self._children = []

def __setattr__(self, name, value):
    """Set attribute `name`, registering `value` as a child if it is a `Block`.

    The attribute is always stored through the parent class first; a
    `Block` value is then additionally passed to `register_child`.
    """
    super(Block, self).__setattr__(name, value)
    if not isinstance(value, Block):
        return
    self.register_child(value)

def _alias(self):
return self.__class__.__name__.lower()

def __repr__(self):
s = '{name}(\n{modstr}\n)'
modstr = '\n'.join([' ({key}): {block}'.format(key=key,
Expand All @@ -154,20 +145,14 @@ def __repr__(self):
return s.format(name=self.__class__.__name__,
modstr=modstr)

@property
def params(self):
    """Return this `Block`'s own parameter dictionary.

    Parameters belonging to child blocks are not included.
    """
    own_params = self._params
    return own_params
def __setattr__(self, name, value):
    """Set attribute `name`, registering `value` as a child if it is a `Block`.

    The attribute is always stored through the parent class first; a
    `Block` value is then additionally passed to `register_child`.
    """
    super(Block, self).__setattr__(name, value)
    if not isinstance(value, Block):
        return
    self.register_child(value)

def collect_params(self):
    """Return a `ParameterDict` with the Parameters of this `Block` and,
    recursively, of all of its children.

    The returned dictionary is created with the prefix of this `Block`'s
    own parameter dictionary.
    """
    collected = ParameterDict(self._params.prefix)
    collected.update(self.params)
    for child in self._children:
        child_params = child.collect_params()
        collected.update(child_params)
    return collected
def _alias(self):
return self.__class__.__name__.lower()

@property
def prefix(self):
Expand All @@ -190,6 +175,47 @@ def name_scope(self):
"""
return self._scope

@property
def params(self):
    """Return this `Block`'s own parameter dictionary.

    Parameters belonging to child blocks are not included.
    """
    own_params = self._params
    return own_params

def collect_params(self):
    """Return a `ParameterDict` with the Parameters of this `Block` and,
    recursively, of all of its children.

    The returned dictionary is created with the prefix of this `Block`'s
    own parameter dictionary.
    """
    collected = ParameterDict(self._params.prefix)
    collected.update(self.params)
    for child in self._children:
        child_params = child.collect_params()
        collected.update(child_params)
    return collected

def save_params(self, filename):
    """Save the parameters collected from this `Block` to file.

    The collected parameters are saved with this `Block`'s prefix passed
    as `strip_prefix`.

    Parameters
    ----------
    filename : str
        Path to file.
    """
    all_params = self.collect_params()
    all_params.save(filename, strip_prefix=self.prefix)

def load_params(self, filename, ctx, allow_missing=False,
                ignore_extra=False):
    """Load parameters from file into this `Block`.

    The collected parameters are loaded with this `Block`'s prefix passed
    as the prefix to restore.

    Parameters
    ----------
    filename : str
        Path to parameter file.
    ctx : Context or list of Context
        Context(s) to initialize loaded parameters on.
    allow_missing : bool, default False
        Whether to silently skip loading parameters not represented in
        the file.
    ignore_extra : bool, default False
        Whether to silently ignore parameters from the file that are not
        present in this Block.
    """
    all_params = self.collect_params()
    all_params.load(filename, ctx, allow_missing, ignore_extra, self.prefix)


def register_child(self, block):
"""Registers block as a child of self. `Block`s assigned to self as
attributes will be registered automatically."""
Expand Down
50 changes: 44 additions & 6 deletions python/mxnet/gluon/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from collections import OrderedDict
import numpy as np


from ..base import mx_real_t, MXNetError
from .. import symbol, ndarray, initializer, context
from ..context import Context
Expand Down Expand Up @@ -425,24 +426,61 @@ def zero_grad(self):
for i in self.values():
i.zero_grad()

def save(self, filename, strip_prefix=''):
    """Save parameters to file.

    Parameters
    ----------
    filename : str
        Path to parameter file.
    strip_prefix : str, default ''
        Strip prefix from parameter names before saving.

    Raises
    ------
    ValueError
        If a parameter's name does not start with `strip_prefix`.
    """
    arg_dict = {}
    for param in self.values():
        block = param.list_data()
        # Average the arrays returned by list_data() after copying each to CPU.
        weight = sum(w.copyto(context.cpu()) for w in block) / len(block)
        if not param.name.startswith(strip_prefix):
            raise ValueError(
                "Prefix %s is to be stripped before saving, but Parameter "
                "%s does not start with %s. If you are using Block.save_params, "
                "This may be due to your Block shares parameters from other "
                "Blocks or you forgot to use `with name_scope()` during init. "
                "Consider switching to Block.collect_params.save and "
                "Block.collect_params.load instead."%(
                    strip_prefix, param.name, strip_prefix))
        # Only the stripped name is stored; the full name must not be kept
        # as a second key.
        arg_dict[param.name[len(strip_prefix):]] = weight
    ndarray.save(filename, arg_dict)

def load(self, filename, ctx, allow_missing=False,
         ignore_extra=False, restore_prefix=''):
    """Load parameters from file.

    Parameters
    ----------
    filename : str
        Path to parameter file.
    ctx : Context or list of Context
        Context(s) to initialize loaded parameters on.
    allow_missing : bool, default False
        Whether to silently skip loading parameters not represented in
        the file.
    ignore_extra : bool, default False
        Whether to silently ignore parameters from the file that are not
        present in this ParameterDict.
    restore_prefix : str, default ''
        Prepend prefix to names of stored parameters before loading.

    Raises
    ------
    AssertionError
        If a parameter name in this dictionary does not start with
        `restore_prefix`, if a parameter is missing from the file while
        `allow_missing` is False, or if the file holds a parameter not in
        this dictionary while `ignore_extra` is False.
    """
    # Validation raises AssertionError explicitly (instead of using `assert`)
    # so the checks are not stripped when Python runs with -O.
    if restore_prefix:
        for name in self.keys():
            if not name.startswith(restore_prefix):
                raise AssertionError(
                    "restore_prefix is %s but Parameters name %s does not start "
                    "with %s"%(restore_prefix, name, restore_prefix))
    # Used to report names as they appear in the file (without the prefix).
    lprefix = len(restore_prefix)
    arg_dict = {restore_prefix+k: v for k, v in ndarray.load(filename).items()}
    if not allow_missing:
        for name in self.keys():
            if name not in arg_dict:
                raise AssertionError(
                    "Parameter %s is missing in file %s"%(name[lprefix:], filename))
    for name in arg_dict:
        if name not in self._params:
            if not ignore_extra:
                raise AssertionError(
                    "Parameter %s loaded from file %s is not present in ParameterDict"%(
                        name[lprefix:], filename))
            continue
        self[name]._load_init(arg_dict[name], ctx)
5 changes: 5 additions & 0 deletions tests/python/unittest/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ def forward(self, x):
net1.collect_params().initialize()
net2(mx.nd.zeros((3, 5)))

net1.save_params('net1.params')

net3 = Net(prefix='net3_')
net3.load_params('net1.params', mx.cpu())


def test_basic():
model = nn.Sequential()
Expand Down

0 comments on commit 37c1823

Please sign in to comment.