Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix symbolblock save_params (#10748)
Browse files Browse the repository at this point in the history
* fix symbolblock save_params

* fix
  • Loading branch information
piiswrong authored May 15, 2018
1 parent 275378a commit 0fb57ff
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
14 changes: 14 additions & 0 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,18 @@ def hybrid_forward(self, F, x, *args, **kwargs):
# pylint: disable= invalid-name
raise NotImplementedError

def _common_prefix(names):
"""Get the common prefix for all names"""
if not names:
return ''
prefix = names[0]
for name in names:
i = 0
while i < len(prefix) and i < len(name) and prefix[i] == name[i]:
i += 1
prefix = prefix[:i]
return prefix


class SymbolBlock(HybridBlock):
"""Construct block from symbol. This is useful for using pre-trained models
Expand Down Expand Up @@ -710,6 +722,8 @@ def __init__(self, outputs, inputs, params=None):
self.params.get(i, grad_req='null', allow_deferred_init=True)

self._cached_graph = syms, out
len_prefix = len(_common_prefix(list(self._params.keys())))
self._reg_params = {key[len_prefix:]: val for key, val in self._params.items()}

def forward(self, x, *args):
if isinstance(x, NDArray):
Expand Down
27 changes: 27 additions & 0 deletions tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -986,6 +986,33 @@ def test_save_load():

net.load_params('test.params')

def test_symbol_block_save_load():
class Net(gluon.HybridBlock):
def __init__(self):
super(Net, self).__init__()
with self.name_scope():
backbone = gluon.model_zoo.vision.resnet18_v1()
data = mx.sym.var('data')
featnames = ['stage1_activation0', 'stage2_activation0', 'stage3_activation0']
out_names = ['_'.join([backbone.name, featname, 'output']) for featname in featnames]
internals = backbone(data).get_internals()
outs = [internals[out_name] for out_name in out_names]
self.backbone = gluon.SymbolBlock(outs, data, params=backbone.collect_params())
self.body = nn.Conv2D(3, 1)

def hybrid_forward(self, F, x):
x = self.body(x)
return self.backbone(x)

net1 = Net()
net1.initialize(mx.init.Normal())
net1.hybridize()
net1(mx.nd.random.normal(shape=(1, 3, 32, 32)))
net1.save_params('./test.params')

net2 = Net()
net2.load_params('./test.params', ctx=mx.cpu())


def test_hybrid_multi_context():
net = mx.gluon.model_zoo.vision.get_resnet(1, 18)
Expand Down

0 comments on commit 0fb57ff

Please sign in to comment.