@@ -800,7 +800,7 @@ def add_param(self, param, custom_update=None, trainable=None, saveable=None, ax
800800 :param bool|None saveable:
801801 :param list[list[int]]|None axes_split_info: e.g. [[n],[n]*4] for LSTM matrices
802802 :return: param
803- :rtype tf.Variable
803+ :rtype tf.Variable|tf.Tensor
804804 """
805805 _param = param
806806 if isinstance (param , tf .Tensor ):
@@ -836,14 +836,18 @@ def add_param(self, param, custom_update=None, trainable=None, saveable=None, ax
836836 custom_update .set_on_var (param )
837837 if axes_split_info :
838838 tf_util .set_param_axes_split_info (param , axes_split_info )
839- if self .reuse_params :
840- name_scope_prefix = self .reuse_params .get_absolute_name_scope_prefix (base_layer = self , param = param )
839+ if getattr (param , "RETURNN_layer_map_name" , None ) is not None :
840+ # Be explicit, take param_name directly from ReuseParams.variable_custom_getter
841+ param_name = param .RETURNN_layer_map_name
841842 else :
842- name_scope_prefix = self .get_absolute_name_scope_prefix ()
843- assert param .name
844- assert param .name [:len (name_scope_prefix )] == name_scope_prefix
845- assert param .name [- 2 :] == ":0"
846- param_name = param .name [len (name_scope_prefix ):- 2 ]
843+ if self .reuse_params :
844+ name_scope_prefix = self .reuse_params .get_absolute_name_scope_prefix (base_layer = self , param = param )
845+ else :
846+ name_scope_prefix = self .get_absolute_name_scope_prefix ()
847+ assert param .name
848+ assert param .name [:len (name_scope_prefix )] == name_scope_prefix
849+ assert param .name [- 2 :] == ":0"
850+ param_name = param .name [len (name_scope_prefix ):- 2 ]
847851 if param_name not in self .params :
848852 self .params [param_name ] = param
849853 else :
@@ -1755,8 +1759,12 @@ def custom_getter(getter, name, *args, **kwargs):
17551759 assert name .startswith (abs_scope_prefix )
17561760 param_name = name [len (abs_scope_prefix ):] # e.g. "W" (not "rec/W")
17571761 if self .custom_func :
1758- return self .custom_func (
1762+ variable = self .custom_func (
17591763 base_layer = base_layer , reuse_layer = self .reuse_layer , name = param_name , getter = getter , full_name = name , ** kwargs )
1764+ # The name of the variable created by custom_func might not match param_name.
1765+ # We store it here for LayerBase.add_param.
1766+ variable .RETURNN_layer_map_name = param_name
1767+ return variable
17601768 if self .param_map is not None :
17611769 if not self .auto_create_missing :
17621770 assert param_name in self .param_map
0 commit comments