Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
piiswrong committed Apr 30, 2018
1 parent c2de915 commit d7508bb
Showing 1 changed file with 14 additions and 1 deletion.
15 changes: 14 additions & 1 deletion python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,18 @@ def hybrid_forward(self, F, x, *args, **kwargs):
# pylint: disable= invalid-name
raise NotImplementedError

def _common_prefix(names):
"""Get the common prefix for all names"""
if not names:
return ''
prefix = names[0]
for name in names:
i = 0
while i < len(prefix) and i < len(name) and prefix[i] == name[i]:
i += 1
prefix = prefix[:i]
return prefix


class SymbolBlock(HybridBlock):
"""Construct block from symbol. This is useful for using pre-trained models
Expand Down Expand Up @@ -709,7 +721,8 @@ def __init__(self, outputs, inputs, params=None):
self.params.get(i, grad_req='null', allow_deferred_init=True)

self._cached_graph = syms, out
self._reg_params = dict(self._params.items())
len_prefix = len(_common_prefix(list(self._params.keys())))
self._reg_params = {key[len_prefix:]: val for key, val in self._params.items()}

def forward(self, x, *args):
if isinstance(x, NDArray):
Expand Down

0 comments on commit d7508bb

Please sign in to comment.