Skip to content

Commit 561de92

Browse files
committed
ReuseParams: Pass param_name to add_param if using custom_func
Fixes #447
1 parent c04171c commit 561de92

File tree

1 file changed

+17
-9
lines changed

1 file changed

+17
-9
lines changed

returnn/tf/layers/base.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -800,7 +800,7 @@ def add_param(self, param, custom_update=None, trainable=None, saveable=None, ax
800800
:param bool|None saveable:
801801
:param list[list[int]]|None axes_split_info: e.g. [[n],[n]*4] for LSTM matrices
802802
:return: param
803-
:rtype tf.Variable
803+
:rtype tf.Variable|tf.Tensor
804804
"""
805805
_param = param
806806
if isinstance(param, tf.Tensor):
@@ -836,14 +836,18 @@ def add_param(self, param, custom_update=None, trainable=None, saveable=None, ax
836836
custom_update.set_on_var(param)
837837
if axes_split_info:
838838
tf_util.set_param_axes_split_info(param, axes_split_info)
839-
if self.reuse_params:
840-
name_scope_prefix = self.reuse_params.get_absolute_name_scope_prefix(base_layer=self, param=param)
839+
if getattr(param, "RETURNN_layer_map_name", None) is not None:
840+
# Be explicit, take param_name directly from ReuseParams.variable_custom_getter
841+
param_name = param.RETURNN_layer_map_name
841842
else:
842-
name_scope_prefix = self.get_absolute_name_scope_prefix()
843-
assert param.name
844-
assert param.name[:len(name_scope_prefix)] == name_scope_prefix
845-
assert param.name[-2:] == ":0"
846-
param_name = param.name[len(name_scope_prefix):-2]
843+
if self.reuse_params:
844+
name_scope_prefix = self.reuse_params.get_absolute_name_scope_prefix(base_layer=self, param=param)
845+
else:
846+
name_scope_prefix = self.get_absolute_name_scope_prefix()
847+
assert param.name
848+
assert param.name[:len(name_scope_prefix)] == name_scope_prefix
849+
assert param.name[-2:] == ":0"
850+
param_name = param.name[len(name_scope_prefix):-2]
847851
if param_name not in self.params:
848852
self.params[param_name] = param
849853
else:
@@ -1755,8 +1759,12 @@ def custom_getter(getter, name, *args, **kwargs):
17551759
assert name.startswith(abs_scope_prefix)
17561760
param_name = name[len(abs_scope_prefix):] # e.g. "W" (not "rec/W")
17571761
if self.custom_func:
1758-
return self.custom_func(
1762+
variable = self.custom_func(
17591763
base_layer=base_layer, reuse_layer=self.reuse_layer, name=param_name, getter=getter, full_name=name, **kwargs)
1764+
# The name of the variable created by custom_func might not match param_name.
1765+
# We store it here for LayerBase.add_param.
1766+
variable.RETURNN_layer_map_name = param_name
1767+
return variable
17601768
if self.param_map is not None:
17611769
if not self.auto_create_missing:
17621770
assert param_name in self.param_map

0 commit comments

Comments
 (0)