@@ -3278,6 +3278,194 @@ def test_name_scope_share_params():
32783278 assert_equal (set (network .get_trainable_params ()), {l1 .params ["W" ], l1 .params ["b" ]})
32793279
32803280
3281+ def test_reuse_params_map_custom_transitive_dependency ():
3282+ # target_embed_raw shares from base:source_embed_raw
3283+ # output_prob shares from target_embed_raw (via custom)
3284+ config = Config ()
3285+ n_in , n_out = 3 , 3
3286+ net_dict = {'dec_01_att_key' : {'axis' : 'F' , 'class' : 'split_dims' , 'dims' : (8 , 64 ), 'from' : ['dec_01_att_key0' ]},
3287+ 'dec_01_att_key0' : {'activation' : None ,
3288+ 'class' : 'linear' ,
3289+ 'forward_weights_init' : "variance_scaling_initializer(mode='fan_in', distribution='uniform', scale=1.0)" ,
3290+ 'from' : ['encoder' ],
3291+ 'n_out' : 512 ,
3292+ 'with_bias' : False },
3293+ 'dec_01_att_value' : {'axis' : 'F' , 'class' : 'split_dims' , 'dims' : (8 , 64 ), 'from' : ['dec_01_att_value0' ]},
3294+ 'dec_01_att_value0' : {'activation' : None ,
3295+ 'class' : 'linear' ,
3296+ 'forward_weights_init' : "variance_scaling_initializer(mode='fan_in', distribution='uniform', scale=1.0)" ,
3297+ 'from' : ['encoder' ],
3298+ 'n_out' : 512 ,
3299+ 'with_bias' : False },
3300+ 'decision' : {'class' : 'decide' , 'from' : ['output' ], 'loss' : 'edit_distance' , 'loss_opts' : {}, 'target' : 'classes' },
3301+ 'enc_01' : {'class' : 'copy' , 'from' : ['enc_01_ff_out' ]},
3302+ 'enc_01_ff_conv1' : {'activation' : 'relu' ,
3303+ 'class' : 'linear' ,
3304+ 'forward_weights_init' : "variance_scaling_initializer(mode='fan_in', distribution='uniform', scale=1.0)" ,
3305+ 'from' : ['enc_01_ff_laynorm' ],
3306+ 'n_out' : 2048 ,
3307+ 'with_bias' : True },
3308+ 'enc_01_ff_conv2' : {'activation' : None ,
3309+ 'class' : 'linear' ,
3310+ 'dropout' : 0.3 ,
3311+ 'forward_weights_init' : "variance_scaling_initializer(mode='fan_in', distribution='uniform', scale=1.0)" ,
3312+ 'from' : ['enc_01_ff_conv1' ],
3313+ 'n_out' : 512 ,
3314+ 'with_bias' : True },
3315+ 'enc_01_ff_drop' : {'class' : 'dropout' , 'dropout' : 0.3 , 'from' : ['enc_01_ff_conv2' ]},
3316+ 'enc_01_ff_laynorm' : {'class' : 'layer_norm' , 'from' : ['enc_01_self_att_out' ]},
3317+ 'enc_01_ff_out' : {'class' : 'combine' , 'from' : ['enc_01_self_att_out' , 'enc_01_ff_drop' ], 'kind' : 'add' ,
3318+ 'n_out' : 512 },
3319+ 'enc_01_self_att_att' : {'attention_dropout' : 0.3 ,
3320+ 'attention_left_only' : False ,
3321+ 'class' : 'self_attention' ,
3322+ 'forward_weights_init' : "variance_scaling_initializer(mode='fan_in', distribution='uniform', scale=1.0)" ,
3323+ 'from' : ['enc_01_self_att_laynorm' ],
3324+ 'n_out' : 512 ,
3325+ 'num_heads' : 8 ,
3326+ 'total_key_dim' : 512 },
3327+ 'enc_01_self_att_drop' : {'class' : 'dropout' , 'dropout' : 0.3 , 'from' : ['enc_01_self_att_lin' ]},
3328+ 'enc_01_self_att_laynorm' : {'class' : 'layer_norm' , 'from' : ['source_embed' ]},
3329+ 'enc_01_self_att_lin' : {'activation' : None ,
3330+ 'class' : 'linear' ,
3331+ 'forward_weights_init' : "variance_scaling_initializer(mode='fan_in', distribution='uniform', scale=1.0)" ,
3332+ 'from' : ['enc_01_self_att_att' ],
3333+ 'n_out' : 512 ,
3334+ 'with_bias' : False },
3335+ 'enc_01_self_att_out' : {'class' : 'combine' , 'from' : ['source_embed' , 'enc_01_self_att_drop' ], 'kind' : 'add' ,
3336+ 'n_out' : 512 },
3337+ 'encoder' : {'class' : 'layer_norm' , 'from' : ['enc_01' ]},
3338+ 'output' : {'class' : 'rec' ,
3339+ 'from' : [],
3340+ 'max_seq_len' : "max_len_from('base:encoder') * 3" ,
3341+ 'target' : 'classes' ,
3342+ 'unit' : {'dec_01' : {'class' : 'copy' , 'from' : ['dec_01_ff_out' ]},
3343+ 'dec_01_att0' : {'base' : 'base:dec_01_att_value' , 'class' : 'generic_attention' ,
3344+ 'weights' : 'dec_01_att_weights_drop' },
3345+ 'dec_01_att_att' : {'axes' : ['dim:8' , 'dim:64' ], 'class' : 'merge_dims' , 'from' : ['dec_01_att0' ]},
3346+ 'dec_01_att_drop' : {'class' : 'dropout' , 'dropout' : 0.3 , 'from' : ['dec_01_att_lin' ]},
3347+ 'dec_01_att_energy' : {'class' : 'dot' ,
3348+ 'from' : ['base:dec_01_att_key' , 'dec_01_att_query' ],
3349+ 'red1' : 'F' , 'red2' : 'F' , 'var1' : 'T' , 'var2' : 'T?' },
3350+ 'dec_01_att_laynorm' : {'class' : 'layer_norm' , 'from' : ['dec_01_self_att_out' ]},
3351+ 'dec_01_att_lin' : {'activation' : None ,
3352+ 'class' : 'linear' ,
3353+ 'forward_weights_init' : "variance_scaling_initializer(mode='fan_in', distribution='uniform', "
3354+ 'scale=1.0)' ,
3355+ 'from' : ['dec_01_att_att' ],
3356+ 'n_out' : 512 ,
3357+ 'with_bias' : False },
3358+ 'dec_01_att_out' : {'class' : 'combine' , 'from' : ['dec_01_self_att_out' , 'dec_01_att_drop' ], 'kind' : 'add' ,
3359+ 'n_out' : 512 },
3360+ 'dec_01_att_query' : {'axis' : 'F' , 'class' : 'split_dims' , 'dims' : (8 , 64 ), 'from' : ['dec_01_att_query0' ]},
3361+ 'dec_01_att_query0' : {'activation' : None ,
3362+ 'class' : 'linear' ,
3363+ 'forward_weights_init' : "variance_scaling_initializer(mode='fan_in', distribution='uniform', "
3364+ 'scale=1.0)' ,
3365+ 'from' : ['dec_01_att_laynorm' ],
3366+ 'n_out' : 512 ,
3367+ 'with_bias' : False },
3368+ 'dec_01_att_weights' : {'axis' : 'stag:extern_data:data' ,
3369+ 'class' : 'softmax_over_spatial' ,
3370+ 'energy_factor' : 0.125 ,
3371+ 'from' : ['dec_01_att_energy' ]},
3372+ 'dec_01_att_weights_drop' : {'class' : 'dropout' ,
3373+ 'dropout' : 0.3 ,
3374+ 'dropout_noise_shape' : {'*' : None },
3375+ 'from' : ['dec_01_att_weights' ]},
3376+ 'dec_01_ff_conv1' : {'activation' : 'relu' ,
3377+ 'class' : 'linear' ,
3378+ 'forward_weights_init' : "variance_scaling_initializer(mode='fan_in', distribution='uniform', "
3379+ 'scale=1.0)' ,
3380+ 'from' : ['dec_01_ff_laynorm' ],
3381+ 'n_out' : 2048 ,
3382+ 'with_bias' : True },
3383+ 'dec_01_ff_conv2' : {'activation' : None ,
3384+ 'class' : 'linear' ,
3385+ 'dropout' : 0.3 ,
3386+ 'forward_weights_init' : "variance_scaling_initializer(mode='fan_in', distribution='uniform', "
3387+ 'scale=1.0)' ,
3388+ 'from' : ['dec_01_ff_conv1' ],
3389+ 'n_out' : 512 ,
3390+ 'with_bias' : True },
3391+ 'dec_01_ff_drop' : {'class' : 'dropout' , 'dropout' : 0.3 , 'from' : ['dec_01_ff_conv2' ]},
3392+ 'dec_01_ff_laynorm' : {'class' : 'layer_norm' , 'from' : ['dec_01_att_out' ]},
3393+ 'dec_01_ff_out' : {'class' : 'combine' , 'from' : ['dec_01_att_out' , 'dec_01_ff_drop' ], 'kind' : 'add' ,
3394+ 'n_out' : 512 },
3395+ 'dec_01_self_att_att' : {'attention_dropout' : 0.3 ,
3396+ 'attention_left_only' : True ,
3397+ 'class' : 'self_attention' ,
3398+ 'forward_weights_init' : "variance_scaling_initializer(mode='fan_in', "
3399+ "distribution='uniform', scale=1.0)" ,
3400+ 'from' : ['dec_01_self_att_laynorm' ],
3401+ 'n_out' : 512 ,
3402+ 'num_heads' : 8 ,
3403+ 'total_key_dim' : 512 },
3404+ 'dec_01_self_att_drop' : {'class' : 'dropout' , 'dropout' : 0.3 , 'from' : ['dec_01_self_att_lin' ]},
3405+ 'dec_01_self_att_laynorm' : {'class' : 'layer_norm' , 'from' : ['target_embed' ]},
3406+ 'dec_01_self_att_lin' : {'activation' : None ,
3407+ 'class' : 'linear' ,
3408+ 'forward_weights_init' : "variance_scaling_initializer(mode='fan_in', "
3409+ "distribution='uniform', scale=1.0)" ,
3410+ 'from' : ['dec_01_self_att_att' ],
3411+ 'n_out' : 512 ,
3412+ 'with_bias' : False },
3413+ 'dec_01_self_att_out' : {'class' : 'combine' ,
3414+ 'from' : ['target_embed' , 'dec_01_self_att_drop' ],
3415+ 'kind' : 'add' ,
3416+ 'n_out' : 512 },
3417+ 'decoder' : {'class' : 'layer_norm' , 'from' : ['dec_01' ]},
3418+ 'end' : {'class' : 'compare' , 'from' : ['output' ], 'value' : 0 },
3419+ 'output' : {'beam_size' : 12 , 'class' : 'choice' , 'from' : ['output_prob' ], 'initial_output' : 0 ,
3420+ 'target' : 'classes' },
3421+ 'output_prob' : {'class' : 'softmax' ,
3422+ 'dropout' : 0.0 ,
3423+ 'forward_weights_init' : "variance_scaling_initializer(mode='fan_in', distribution='uniform', "
3424+ 'scale=1.0)' ,
3425+ 'from' : ['decoder' ],
3426+ 'loss' : 'ce' ,
3427+ 'loss_opts' : {'label_smoothing' : 0.2 , 'use_normalized_loss' : True },
3428+ 'reuse_params' : {
3429+ 'map' : {'W' : {'custom' : (lambda reuse_layer , ** kwargs : tf .transpose (reuse_layer .params ["W" ])),
3430+ 'reuse_layer' : 'target_embed_raw' },
3431+ 'b' : None }},
3432+ 'target' : 'classes' ,
3433+ 'with_bias' : True },
3434+ 'target_embed' : {'class' : 'dropout' , 'dropout' : 0.0 , 'from' : ['target_embed_with_pos' ]},
3435+ 'target_embed_raw' : {'activation' : None ,
3436+ 'class' : 'linear' ,
3437+ 'forward_weights_init' : "variance_scaling_initializer(mode='fan_in', distribution='uniform', "
3438+ 'scale=1.0)' ,
3439+ 'from' : ['prev:output' ],
3440+ 'n_out' : 512 ,
3441+ 'reuse_params' : {'map' : {'W' : {'reuse_layer' : 'base:source_embed_raw' }, 'b' : None }},
3442+ 'with_bias' : False },
3443+ 'target_embed_weighted' : {'class' : 'eval' , 'eval' : 'source(0) * 22.627417' , 'from' : ['target_embed_raw' ]},
3444+ 'target_embed_with_pos' : {'add_to_input' : True , 'class' : 'positional_encoding' ,
3445+ 'from' : ['target_embed_weighted' ]}}},
3446+ 'source_embed' : {'class' : 'dropout' , 'dropout' : 0.0 , 'from' : ['source_embed_with_pos' ]},
3447+ 'source_embed_raw' : {'activation' : None ,
3448+ 'class' : 'linear' ,
3449+ 'forward_weights_init' : "variance_scaling_initializer(mode='fan_in', distribution='uniform', scale=1.0)" ,
3450+ 'n_out' : 512 ,
3451+ 'with_bias' : False , 'from' : 'data:data' },
3452+ 'source_embed_weighted' : {'class' : 'eval' , 'eval' : 'source(0) * 22.627417' , 'from' : ['source_embed_raw' ]},
3453+ 'source_embed_with_pos' : {'add_to_input' : True , 'class' : 'positional_encoding' , 'from' : ['source_embed_weighted' ]}}
3454+ config .update ({
3455+ "num_outputs" : n_out ,
3456+ "num_inputs" : n_in ,
3457+ "network" : net_dict })
3458+ with make_scope () as session :
3459+ print ("Construct for training" )
3460+ from returnn .tf .layers .rec import RecLayer , _SubnetworkRecCell
3461+ train_net = TFNetwork (config = config , train_flag = True )
3462+ train_net .construct_from_dict (config .typed_dict ["network" ])
3463+ with make_scope () as session :
3464+ print ("Construct for search" )
3465+ search_net = TFNetwork (config = config , train_flag = False , eval_flag = True , search_flag = True )
3466+ search_net .construct_from_dict (config .typed_dict ["network" ])
3467+
3468+
32813469def test_SliceLayer_output_placeholder ():
32823470 with make_scope () as session :
32833471 net = TFNetwork (extern_data = ExternData ())
0 commit comments