Skip to content

Commit

Permalink
allow partial set paramter in module (#2482)
Browse files Browse the repository at this point in the history
allow partial set parameter in module
  • Loading branch information
antinucleon committed Jun 20, 2016
1 parent ba57aa9 commit 0475454
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
10 changes: 8 additions & 2 deletions python/mxnet/module/base_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ def init_params(self, initializer=Uniform(0.01), arg_params=None, aux_params=Non
"""
raise NotImplementedError()

def set_params(self, arg_params, aux_params):
def set_params(self, arg_params, aux_params, allow_missing=False, force_init=True):
"""Assign parameter and aux state values.
Parameters
Expand All @@ -461,9 +461,15 @@ def set_params(self, arg_params, aux_params):
Dictionary of name to value (`NDArray`) mapping.
aux_params : dict
Dictionary of name to value (`NDArray`) mapping.
allow_missing : bool
If true, params could contain missing values, and the initializer will be
called to fill those missing params.
force_init : bool
If true, will force re-initialize even if already initialized.
"""
self.init_params(initializer=None, arg_params=arg_params, aux_params=aux_params,
allow_missing=False, force_init=True)
allow_missing=allow_missing, force_init=force_init)

def save_params(self, fname):
"""Save model parameters to file.
Expand Down
5 changes: 3 additions & 2 deletions python/mxnet/module/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,15 +169,16 @@ def init_params(self, initializer=Uniform(0.01), arg_params=None, aux_params=Non
def _impl(name, arr, cache):
"""Internal helper for parameter initialization"""
if cache is not None:
if cache.has_key(name):
if name in cache:
cache_arr = cache[name]

# just in case the cached array is just the target itself
if cache_arr is not arr:
cache_arr.copyto(arr)
else:
assert allow_missing
initializer(name, arr)
if initializer != None:
initializer(name, arr)
else:
initializer(name, arr)

Expand Down

0 comments on commit 0475454

Please sign in to comment.