From 181a2e28a61e9406318a8ba08f36fc4c8e23e8c1 Mon Sep 17 00:00:00 2001 From: Sheng Zha Date: Wed, 8 Aug 2018 11:14:56 -0700 Subject: [PATCH] fix unidirectional model's parameter format (#12055) * fix unidirectional model's parameter format * Update rnn_layer.py --- python/mxnet/gluon/rnn/rnn_layer.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py index 4a7a0be2bc30..d2c6ac9d9f2f 100644 --- a/python/mxnet/gluon/rnn/rnn_layer.py +++ b/python/mxnet/gluon/rnn/rnn_layer.py @@ -21,6 +21,8 @@ # pylint: disable=too-many-lines, arguments-differ """Definition of various recurrent neural network layers.""" from __future__ import print_function +import re + __all__ = ['RNN', 'LSTM', 'GRU'] from ... import ndarray, symbol @@ -92,10 +94,17 @@ def __repr__(self): def _collect_params_with_prefix(self, prefix=''): if prefix: prefix += '.' - def convert_key(key): # for compatibility with old parameter format - key = key.split('_') - return '_unfused.{}.{}_cell.{}'.format(key[0][1:], key[0][0], '_'.join(key[1:])) - ret = {prefix + convert_key(key) : val for key, val in self._reg_params.items()} + pattern = re.compile(r'(l|r)(\d)_(i2h|h2h)_(weight|bias)\Z') + def convert_key(m, bidirectional): # for compatibility with old parameter format + d, l, g, t = [m.group(i) for i in range(1, 5)] + if bidirectional: + return '_unfused.{}.{}_cell.{}_{}'.format(l, d, g, t) + else: + return '_unfused.{}.{}_{}'.format(l, g, t) + bidirectional = any(pattern.match(k).group(1) == 'r' for k in self._reg_params) + + ret = {prefix + convert_key(pattern.match(key), bidirectional) : val + for key, val in self._reg_params.items()} for name, child in self._children.items(): ret.update(child._collect_params_with_prefix(prefix + name)) return ret