diff --git a/python/mxnet/module/base_module.py b/python/mxnet/module/base_module.py index 5549965c..4465d494 100644 --- a/python/mxnet/module/base_module.py +++ b/python/mxnet/module/base_module.py @@ -452,7 +452,7 @@ def init_params(self, initializer=Uniform(0.01), arg_params=None, aux_params=Non """ raise NotImplementedError() - def set_params(self, arg_params, aux_params): + def set_params(self, arg_params, aux_params, allow_missing=False, force_init=True): """Assign parameter and aux state values. Parameters @@ -461,9 +461,15 @@ def set_params(self, arg_params, aux_params): Dictionary of name to value (`NDArray`) mapping. aux_params : dict Dictionary of name to value (`NDArray`) mapping. + allow_missing : bool + If true, params could contain missing values, and the initializer will be + called to fill those missing params. + force_init : bool + If true, will force re-initialize even if already initialized. + """ self.init_params(initializer=None, arg_params=arg_params, aux_params=aux_params, - allow_missing=False, force_init=True) + allow_missing=allow_missing, force_init=force_init) def save_params(self, fname): """Save model parameters to file. diff --git a/python/mxnet/module/module.py b/python/mxnet/module/module.py index 3156ac98..e327a6b4 100644 --- a/python/mxnet/module/module.py +++ b/python/mxnet/module/module.py @@ -169,7 +169,7 @@ def init_params(self, initializer=Uniform(0.01), arg_params=None, aux_params=Non def _impl(name, arr, cache): """Internal helper for parameter initialization""" if cache is not None: - if cache.has_key(name): + if name in cache: cache_arr = cache[name] # just in case the cached array is just the target itself @@ -177,7 +177,8 @@ def _impl(name, arr, cache): cache_arr.copyto(arr) else: assert allow_missing - initializer(name, arr) + if initializer != None: + initializer(name, arr) else: initializer(name, arr)