
add import_ for SymbolBlock (#11127)
* add import_ for SymbolBlock

* fix

* Update block.py

* add save_parameters

* fix

* fix lint

* fix

* fix

* fix

* fix

* fix

* Update save_load_params.md
piiswrong authored Jun 14, 2018
1 parent d79e1ad commit 66ab27e
Showing 20 changed files with 164 additions and 62 deletions.
2 changes: 1 addition & 1 deletion docs/tutorials/gluon/hybrid.md
@@ -117,7 +117,7 @@ x = mx.sym.var('data')
y = net(x)
print(y)
y.save('model.json')
-net.save_params('model.params')
+net.save_parameters('model.params')
```

If your network outputs more than one value, you can use `mx.sym.Group` to
6 changes: 3 additions & 3 deletions docs/tutorials/gluon/naming.md
@@ -203,12 +203,12 @@ except Exception as e:
Parameter 'model1_dense0_weight' is missing in file 'model.params', which contains parameters: 'model0_mydense_weight', 'model0_dense1_bias', 'model0_dense1_weight', 'model0_dense0_weight', 'model0_dense0_bias', 'model0_mydense_bias'. Please make sure source and target networks have the same prefix.


-To solve this problem, we use `save_params`/`load_params` instead of `collect_params` and `save`/`load`. `save_params` uses model structure, instead of parameter name, to match parameters.
+To solve this problem, we use `save_parameters`/`load_parameters` instead of `collect_params` and `save`/`load`. `save_parameters` uses the model structure, instead of parameter names, to match parameters.


```python
-model0.save_params('model.params')
-model1.load_params('model.params')
+model0.save_parameters('model.params')
+model1.load_parameters('model.params')
print(mx.nd.load('model.params').keys())
```

16 changes: 4 additions & 12 deletions docs/tutorials/gluon/save_load_params.md
@@ -10,7 +10,7 @@ Parameters of any Gluon model can be saved using the `save_params` and `load_par

**2. Save/load model parameters AND architecture**

-The Model architecture of `Hybrid` models stays static and don't change during execution. Therefore both model parameters AND architecture can be saved and loaded using `export`, `load_checkpoint` and `load` methods.
+The model architecture of `Hybrid` models stays static and doesn't change during execution. Therefore both model parameters AND architecture can be saved and loaded using the `export` and `imports` methods.
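
As a quick preview, here is a minimal sketch of that second approach (a hypothetical one-layer network; `export` needs the network hybridized and run once so the symbolic graph is cached):

```python
import mxnet as mx
from mxnet import gluon

net = gluon.nn.HybridSequential()
with net.name_scope():
    net.add(gluon.nn.Dense(10))
net.initialize()
net.hybridize()
net(mx.nd.zeros((1, 20)))     # one forward pass caches the symbolic graph

net.export('model', epoch=1)  # writes model-symbol.json and model-0001.params
loaded = gluon.nn.SymbolBlock.imports(
    'model-symbol.json', ['data'], 'model-0001.params')
```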

Let's look at the above methods in more detail. Let's start by importing the modules we'll need.

@@ -61,7 +61,7 @@ def build_lenet(net):
net.add(gluon.nn.Dense(512, activation="relu"))
# Second fully connected layer with as many neurons as the number of classes
net.add(gluon.nn.Dense(num_outputs))

return net

# Train a given model using MNIST data
@@ -240,18 +240,10 @@ One of the main reasons to serialize model architecture into a JSON file is to l

### From Python

-Serialized Hybrid networks (saved as .JSON and .params file) can be loaded and used inside Python frontend using `mx.model.load_checkpoint` and `gluon.nn.SymbolBlock`. To demonstrate that, let's load the network we serialized above.
+Serialized Hybrid networks (saved as .JSON and .params files) can be loaded and used inside the Python frontend using `gluon.nn.SymbolBlock`. To demonstrate that, let's load the network we serialized above.

```python
-# Load the network architecture and parameters
-sym = mx.sym.load('lenet-symbol.json')
-# Create a Gluon Block using the loaded network architecture.
-# 'inputs' parameter specifies the name of the symbol in the computation graph
-# that should be treated as input. 'data' is the default name used for input when
-# a model architecture is saved to a file.
-deserialized_net = gluon.nn.SymbolBlock(outputs=sym, inputs=mx.sym.var('data'))
-# Load the parameters
-deserialized_net.collect_params().load('lenet-0001.params', ctx=ctx)
+deserialized_net = gluon.nn.SymbolBlock.imports("lenet-symbol.json", ['data'], "lenet-0001.params")
```

`deserialized_net` now contains the network we deserialized from files. Let's test the deserialized network to make sure it works.
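
The test cell is collapsed in this view; a minimal check might look like the following (assuming the 1x28x28 MNIST input shape used earlier in this tutorial, with parameters on the default CPU context):

```python
# Feed one MNIST-shaped batch through the deserialized network
x = mx.nd.random.uniform(shape=(1, 1, 28, 28))
output = deserialized_net(x)
print(output.shape)  # (1, 10): one score per digit class
```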
8 changes: 4 additions & 4 deletions example/gluon/dcgan.py
@@ -229,8 +229,8 @@ def transformer(data, label):
logging.info('time: %f' % (time.time() - tic))

if check_point:
-netG.save_params(os.path.join(outf,'generator_epoch_%d.params' %epoch))
-netD.save_params(os.path.join(outf,'discriminator_epoch_%d.params' % epoch))
+netG.save_parameters(os.path.join(outf,'generator_epoch_%d.params' %epoch))
+netD.save_parameters(os.path.join(outf,'discriminator_epoch_%d.params' % epoch))

-netG.save_params(os.path.join(outf, 'generator.params'))
-netD.save_params(os.path.join(outf, 'discriminator.params'))
+netG.save_parameters(os.path.join(outf, 'generator.params'))
+netD.save_parameters(os.path.join(outf, 'discriminator.params'))
2 changes: 1 addition & 1 deletion example/gluon/embedding_learning/train.py
@@ -246,7 +246,7 @@ def train(epochs, ctx):
if val_accs[0] > best_val:
best_val = val_accs[0]
logging.info('Saving %s.' % opt.save_model_prefix)
-net.save_params('%s.params' % opt.save_model_prefix)
+net.save_parameters('%s.params' % opt.save_model_prefix)
return best_val


8 changes: 4 additions & 4 deletions example/gluon/image_classification.py
@@ -122,7 +122,7 @@ def get_model(model, ctx, opt):

net = models.get_model(model, **kwargs)
if opt.resume:
-net.load_params(opt.resume)
+net.load_parameters(opt.resume)
elif not opt.use_pretrained:
if model in ['alexnet']:
net.initialize(mx.init.Normal())
@@ -176,12 +176,12 @@ def update_learning_rate(lr, trainer, epoch, ratio, steps):
def save_checkpoint(epoch, top1, best_acc):
if opt.save_frequency and (epoch + 1) % opt.save_frequency == 0:
fname = os.path.join(opt.prefix, '%s_%d_acc_%.4f.params' % (opt.model, epoch, top1))
-net.save_params(fname)
+net.save_parameters(fname)
logger.info('[Epoch %d] Saving checkpoint to %s with Accuracy: %.4f', epoch, fname, top1)
if top1 > best_acc[0]:
best_acc[0] = top1
fname = os.path.join(opt.prefix, '%s_best.params' % (opt.model))
-net.save_params(fname)
+net.save_parameters(fname)
logger.info('[Epoch %d] Saving checkpoint to %s with Accuracy: %.4f', epoch, fname, top1)

def train(opt, ctx):
@@ -267,7 +267,7 @@ def main():
optimizer = 'sgd',
optimizer_params = {'learning_rate': opt.lr, 'wd': opt.wd, 'momentum': opt.momentum, 'multi_precision': True},
initializer = mx.init.Xavier(magnitude=2))
-mod.save_params('image-classifier-%s-%d-final.params'%(opt.model, opt.epochs))
+mod.save_parameters('image-classifier-%s-%d-final.params'%(opt.model, opt.epochs))
else:
if opt.mode == 'hybrid':
net.hybridize()
2 changes: 1 addition & 1 deletion example/gluon/mnist.py
@@ -117,7 +117,7 @@ def train(epochs, ctx):
name, val_acc = test(ctx)
print('[Epoch %d] Validation: %s=%f'%(epoch, name, val_acc))

-net.save_params('mnist.params')
+net.save_parameters('mnist.params')


if __name__ == '__main__':
8 changes: 4 additions & 4 deletions example/gluon/style_transfer/main.py
@@ -55,7 +55,7 @@ def train(args):
style_model.initialize(init=mx.initializer.MSRAPrelu(), ctx=ctx)
if args.resume is not None:
print('Resuming, initializing using weight from {}.'.format(args.resume))
-style_model.load_params(args.resume, ctx=ctx)
+style_model.load_parameters(args.resume, ctx=ctx)
print('style_model:',style_model)
# optimizer and loss
trainer = gluon.Trainer(style_model.collect_params(), 'adam',
@@ -121,14 +121,14 @@ def train(args):
str(count) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
args.content_weight) + "_" + str(args.style_weight) + ".params"
save_model_path = os.path.join(args.save_model_dir, save_model_filename)
-style_model.save_params(save_model_path)
+style_model.save_parameters(save_model_path)
print("\nCheckpoint, trained model saved at", save_model_path)

# save model
save_model_filename = "Final_epoch_" + str(args.epochs) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
args.content_weight) + "_" + str(args.style_weight) + ".params"
save_model_path = os.path.join(args.save_model_dir, save_model_filename)
-style_model.save_params(save_model_path)
+style_model.save_parameters(save_model_path)
print("\nDone, trained model saved at", save_model_path)


@@ -143,7 +143,7 @@ def evaluate(args):
style_image = utils.preprocess_batch(style_image)
# model
style_model = net.Net(ngf=args.ngf)
-style_model.load_params(args.model, ctx=ctx)
+style_model.load_parameters(args.model, ctx=ctx)
# forward
style_model.set_target(style_image)
output = style_model(content_image)
4 changes: 2 additions & 2 deletions example/gluon/super_resolution.py
@@ -168,13 +168,13 @@ def train(epoch, ctx):
print('training mse at epoch %d: %s=%f'%(i, name, acc))
test(ctx)

-net.save_params('superres.params')
+net.save_parameters('superres.params')

def resolve(ctx):
from PIL import Image
if isinstance(ctx, list):
ctx = [ctx[0]]
-net.load_params('superres.params', ctx=ctx)
+net.load_parameters('superres.params', ctx=ctx)
img = Image.open(opt.resolve_img).convert('YCbCr')
y, cb, cr = img.split()
data = mx.nd.expand_dims(mx.nd.expand_dims(mx.nd.array(y), axis=0), axis=0)
2 changes: 1 addition & 1 deletion example/gluon/tree_lstm/main.py
@@ -138,7 +138,7 @@ def test(ctx, data_iter, best, mode='validation', num_iter=-1):
if test_r >= best:
best = test_r
logging.info('New optimum found: {}. Checkpointing.'.format(best))
-net.save_params('childsum_tree_lstm_{}.params'.format(num_iter))
+net.save_parameters('childsum_tree_lstm_{}.params'.format(num_iter))
test(ctx, test_iter, -1, 'test')
return best

4 changes: 2 additions & 2 deletions example/gluon/word_language_model/train.py
@@ -185,14 +185,14 @@ def train():
if val_L < best_val:
best_val = val_L
test_L = eval(test_data)
-model.save_params(args.save)
+model.save_parameters(args.save)
print('test loss %.2f, test ppl %.2f'%(test_L, math.exp(test_L)))
else:
args.lr = args.lr*0.25
trainer.set_learning_rate(args.lr)

if __name__ == '__main__':
train()
-model.load_params(args.save, context)
+model.load_parameters(args.save, context)
test_L = eval(test_data)
print('Best test loss %.2f, test ppl %.2f'%(test_L, math.exp(test_L)))
90 changes: 84 additions & 6 deletions python/mxnet/gluon/block.py
@@ -16,7 +16,7 @@
# under the License.

# coding: utf-8
-# pylint: disable= arguments-differ
+# pylint: disable= arguments-differ, too-many-lines
"""Base container class for all neural network models."""
__all__ = ['Block', 'HybridBlock', 'SymbolBlock']

@@ -307,7 +307,7 @@ def _collect_params_with_prefix(self, prefix=''):
ret.update(child._collect_params_with_prefix(prefix + name))
return ret

-def save_params(self, filename):
+def save_parameters(self, filename):
"""Save parameters to file.
filename : str
@@ -317,8 +317,23 @@ def save_params(self, filename):
arg_dict = {key : val._reduce() for key, val in params.items()}
ndarray.save(filename, arg_dict)

-def load_params(self, filename, ctx=None, allow_missing=False,
-                ignore_extra=False):
+def save_params(self, filename):
+    """[Deprecated] Please use save_parameters.
+    Save parameters to file.
+    filename : str
+        Path to file.
+    """
+    warnings.warn("save_params is deprecated. Please use save_parameters.")
+    try:
+        self.collect_params().save(filename, strip_prefix=self.prefix)
+    except ValueError as e:
+        raise ValueError('%s\nsave_params is deprecated. Using ' \
+                         'save_parameters may resolve this error.'%e.message)
+
+def load_parameters(self, filename, ctx=None, allow_missing=False,
+                    ignore_extra=False):
    """Load parameters from file.
    filename : str
@@ -358,6 +373,25 @@ def load_params(self, filename, ctx=None, allow_missing=False,
if name in params:
params[name]._load_init(loaded[name], ctx)

+def load_params(self, filename, ctx=None, allow_missing=False,
+                ignore_extra=False):
+    """[Deprecated] Please use load_parameters.
+    Load parameters from file.
+    filename : str
+        Path to parameter file.
+    ctx : Context or list of Context, default cpu()
+        Context(s) to initialize loaded parameters on.
+    allow_missing : bool, default False
+        Whether to silently skip loading parameters not represented in the file.
+    ignore_extra : bool, default False
+        Whether to silently ignore parameters from the file that are not
+        present in this Block.
+    """
+    warnings.warn("load_params is deprecated. Please use load_parameters.")
+    self.load_parameters(filename, ctx, allow_missing, ignore_extra)

def register_child(self, block, name=None):
"""Registers block as a child of self. :py:class:`Block` s assigned to self as
attributes will be registered automatically."""
@@ -771,8 +805,8 @@ def infer_type(self, *args):
self._infer_attrs('infer_type', 'dtype', *args)

def export(self, path, epoch=0):
"""Export HybridBlock to json format that can be loaded by `mxnet.mod.Module`
or the C++ interface.
"""Export HybridBlock to json format that can be loaded by
`SymbolBlock.imports`, `mxnet.mod.Module` or the C++ interface.
.. note:: When there are only one input, it will have name `data`. When there
Are more than one inputs, they will be named as `data0`, `data1`, etc.
@@ -886,6 +920,50 @@ class SymbolBlock(HybridBlock):
>>> x = mx.nd.random.normal(shape=(16, 3, 224, 224))
>>> print(feat_model(x))
"""
+@staticmethod
+def imports(symbol_file, input_names, param_file=None, ctx=None):
+    """Import model previously saved by `HybridBlock.export` or
+    `Module.save_checkpoint` as a SymbolBlock for use in Gluon.
+
+    Parameters
+    ----------
+    symbol_file : str
+        Path to symbol file.
+    input_names : list of str
+        List of input variable names.
+    param_file : str, optional
+        Path to parameter file.
+    ctx : Context, default None
+        The context to initialize SymbolBlock on.
+
+    Returns
+    -------
+    SymbolBlock
+        SymbolBlock loaded from symbol and parameter files.
+
+    Examples
+    --------
+    >>> net1 = gluon.model_zoo.vision.resnet18_v1(
+    ...     prefix='resnet', pretrained=True)
+    >>> net1.hybridize()
+    >>> x = mx.nd.random.normal(shape=(1, 3, 32, 32))
+    >>> out1 = net1(x)
+    >>> net1.export('net1', epoch=1)
+    >>>
+    >>> net2 = gluon.SymbolBlock.imports(
+    ...     'net1-symbol.json', ['data'], 'net1-0001.params')
+    >>> out2 = net2(x)
+    """
+    sym = symbol.load(symbol_file)
+    if isinstance(input_names, str):
+        input_names = [input_names]
+    inputs = [symbol.var(i) for i in input_names]
+    ret = SymbolBlock(sym, inputs)
+    if param_file is not None:
+        ret.collect_params().load(param_file, ctx=ctx)
+    return ret


def __init__(self, outputs, inputs, params=None):
super(SymbolBlock, self).__init__(prefix=None, params=None)
self._prefix = ''
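Taken together, this change pairs `HybridBlock.export` with the new `SymbolBlock.imports`, while `save_params`/`load_params` live on as warning-emitting aliases. A minimal sketch of the migration path (hypothetical file names; the `save_params` call is included only to show the deprecation warning):

```python
import warnings
import mxnet as mx
from mxnet import gluon

net = gluon.nn.HybridSequential()
with net.name_scope():
    net.add(gluon.nn.Dense(10))
net.initialize()
net(mx.nd.zeros((1, 20)))  # forward once so deferred shapes are inferred

# New, preferred parameter-only round trip
net.save_parameters('net.params')
net.load_parameters('net.params')

# The old name still works but now emits a warning
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    net.save_params('net_old.params')
print(caught[0].message)  # save_params is deprecated. Please use save_parameters.
```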
2 changes: 1 addition & 1 deletion python/mxnet/gluon/model_zoo/vision/alexnet.py
@@ -83,5 +83,5 @@ def alexnet(pretrained=False, ctx=cpu(),
net = AlexNet(**kwargs)
if pretrained:
from ..model_store import get_model_file
-net.load_params(get_model_file('alexnet', root=root), ctx=ctx)
+net.load_parameters(get_model_file('alexnet', root=root), ctx=ctx)
return net
2 changes: 1 addition & 1 deletion python/mxnet/gluon/model_zoo/vision/densenet.py
@@ -141,7 +141,7 @@ def get_densenet(num_layers, pretrained=False, ctx=cpu(),
net = DenseNet(num_init_features, growth_rate, block_config, **kwargs)
if pretrained:
from ..model_store import get_model_file
-net.load_params(get_model_file('densenet%d'%(num_layers), root=root), ctx=ctx)
+net.load_parameters(get_model_file('densenet%d'%(num_layers), root=root), ctx=ctx)
return net

def densenet121(**kwargs):
2 changes: 1 addition & 1 deletion python/mxnet/gluon/model_zoo/vision/inception.py
@@ -216,5 +216,5 @@ def inception_v3(pretrained=False, ctx=cpu(),
net = Inception3(**kwargs)
if pretrained:
from ..model_store import get_model_file
-net.load_params(get_model_file('inceptionv3', root=root), ctx=ctx)
+net.load_parameters(get_model_file('inceptionv3', root=root), ctx=ctx)
return net
4 changes: 2 additions & 2 deletions python/mxnet/gluon/model_zoo/vision/mobilenet.py
@@ -213,7 +213,7 @@ def get_mobilenet(multiplier, pretrained=False, ctx=cpu(),
version_suffix = '{0:.2f}'.format(multiplier)
if version_suffix in ('1.00', '0.50'):
version_suffix = version_suffix[:-1]
-net.load_params(
+net.load_parameters(
get_model_file('mobilenet%s' % version_suffix, root=root), ctx=ctx)
return net

@@ -245,7 +245,7 @@ def get_mobilenet_v2(multiplier, pretrained=False, ctx=cpu(),
version_suffix = '{0:.2f}'.format(multiplier)
if version_suffix in ('1.00', '0.50'):
version_suffix = version_suffix[:-1]
-net.load_params(
+net.load_parameters(
get_model_file('mobilenetv2_%s' % version_suffix, root=root), ctx=ctx)
return net

4 changes: 2 additions & 2 deletions python/mxnet/gluon/model_zoo/vision/resnet.py
@@ -386,8 +386,8 @@ def get_resnet(version, num_layers, pretrained=False, ctx=cpu(),
net = resnet_class(block_class, layers, channels, **kwargs)
if pretrained:
from ..model_store import get_model_file
-net.load_params(get_model_file('resnet%d_v%d'%(num_layers, version),
-root=root), ctx=ctx)
+net.load_parameters(get_model_file('resnet%d_v%d'%(num_layers, version),
+root=root), ctx=ctx)
return net

def resnet18_v1(**kwargs):
2 changes: 1 addition & 1 deletion python/mxnet/gluon/model_zoo/vision/squeezenet.py
@@ -132,7 +132,7 @@ def get_squeezenet(version, pretrained=False, ctx=cpu(),
net = SqueezeNet(version, **kwargs)
if pretrained:
from ..model_store import get_model_file
-net.load_params(get_model_file('squeezenet%s'%version, root=root), ctx=ctx)
+net.load_parameters(get_model_file('squeezenet%s'%version, root=root), ctx=ctx)
return net

def squeezenet1_0(**kwargs):
