@@ -3097,6 +3097,78 @@ def test_reuse_params_map_custom_dep_loop():
30973097 search_net .construct_from_dict (config .typed_dict ["network" ])
30983098
30993099
def test_reuse_params_map_custom_dep_loop2():
    # Variant of test_reuse_params_map_custom_dep_loop:
    # here target_embed takes its "W" from base:encoder, and output_prob in turn
    # takes its "W" from target_embed (transposed via the custom function below).

    def transposed_embed_weights(reuse_layer, **kwargs):
        # Called with keyword arguments; only reuse_layer is needed here.
        return tf.transpose(reuse_layer.params["W"])

    num_in = 3
    num_out = 3

    encoder_layer = {
        "activation": None,
        "class": "linear",
        "forward_weights_init": "variance_scaling_initializer(mode='fan_in', distribution='uniform', scale=1.0)",
        "n_out": 6,
        "with_bias": False,
        "from": "data",
    }

    # Layers of the recurrent unit. Key order is kept as in the original definition.
    rec_unit = {
        "output": {
            "class": "choice",
            "target": "classes",
            "beam_size": 5,
            "from": ["output_prob"],
            "initial_output": 0,
        },
        "end": {"class": "compare", "from": ["output"], "value": 0},
        "target_embed": {
            "class": "linear",
            "activation": None,
            "with_bias": False,
            "from": ["output"],
            "n_out": 6,
            "initial_output": 0,
            # "W" comes from the encoder layer of the parent network.
            "reuse_params": {"map": {"W": {"reuse_layer": "base:encoder"}, "b": None}},
        },
        "weight_feedback": {
            "class": "linear",
            "activation": None,
            "with_bias": False,
            "from": ["prev:accum_att_weights"],
            "n_out": 10,
        },
        "prev_s_state": {"class": "get_last_hidden_state", "from": ["prev:s"], "n_out": 20},
        "prev_s_transformed": {
            "class": "linear",
            "activation": None,
            "with_bias": False,
            "from": ["prev_s_state"],
            "n_out": 10,
        },
        "energy_in": {
            "class": "combine",
            "kind": "add",
            "from": ["base:enc_ctx", "weight_feedback", "prev_s_transformed"],
            "n_out": 10,
        },
        "energy_tanh": {"class": "activation", "activation": "tanh", "from": ["energy_in"]},
        "energy": {
            "class": "linear",
            "activation": None,
            "with_bias": False,
            "from": ["energy_tanh"],
            "n_out": 1,
        },
        "att_weights": {"class": "softmax_over_spatial", "from": ["energy"]},
        "accum_att_weights": {
            "class": "eval",
            "from": ["prev:accum_att_weights", "att_weights", "base:inv_fertility"],
            "eval": "source(0) + source(1) * source(2) * 0.5",
            "out_type": {"dim": 1, "shape": (None, 1)},
        },
        "att": {
            "class": "generic_attention",
            "weights": "att_weights",
            "base": "base:encoder",
            "auto_squeeze": True,
        },
        "s": {"class": "rnn_cell", "unit": "LSTMBlock", "from": ["target_embed", "att"], "n_out": 10},
        "readout_in": {
            "class": "linear",
            "from": ["prev:s", "prev:target_embed", "att"],
            "activation": None,
            "n_out": 2 * 6,
        },
        "readout": {"class": "reduce_out", "mode": "max", "num_pieces": 2, "from": ["readout_in"]},
        "output_prob": {
            "class": "softmax",
            "from": ["readout"],
            "dropout": 0.3,
            # "W" is derived from target_embed's "W"; "b" is mapped to None.
            "reuse_params": {
                "map": {
                    "W": {"reuse_layer": "target_embed", "custom": transposed_embed_weights},
                    "b": None,
                },
            },
            "target": "classes",
            "loss": "ce",
            "loss_opts": {"label_smoothing": 0.1},
        },
    }

    network = {
        "encoder": encoder_layer,
        "enc_ctx": {
            "class": "linear",
            "activation": None,
            "with_bias": True,
            "from": ["encoder"],
            "n_out": 10,
        },
        "inv_fertility": {
            "class": "linear",
            "activation": "sigmoid",
            "with_bias": False,
            "from": ["encoder"],
            "n_out": 1,
        },
        "output": {
            "class": "rec",
            "from": [],
            "unit": rec_unit,
            "target": "classes",
            "max_seq_len": "max_len_from('base:encoder')",
        },
    }

    config = Config()
    config.update({"num_outputs": num_out, "num_inputs": num_in, "network": network})

    with make_scope():
        print("Construct for training")
        from returnn.tf.layers.rec import RecLayer, _SubnetworkRecCell
        train_net = TFNetwork(config=config, train_flag=True)
        train_net.construct_from_dict(config.typed_dict["network"])
        rec_layer = train_net.layers["output"]
        assert isinstance(rec_layer, RecLayer)
        cell = rec_layer.cell
        assert isinstance(cell, _SubnetworkRecCell)
        # Check which layers ended up outside the recurrent loop.
        assert_equal(set(cell.input_layers_moved_out), {"output", "target_embed"})
        assert_equal(set(cell.output_layers_moved_out), {"output_prob", "readout", "readout_in"})
        out_net = cell.output_layers_net
        assert isinstance(out_net, TFNetwork)
        # output_prob is expected to hold only "b" as its own param.
        assert_equal(set(out_net.layers["output_prob"].params.keys()), {"b"})

    with make_scope():
        print("Construct for search")
        search_net = TFNetwork(config=config, train_flag=False, eval_flag=True, search_flag=True)
        search_net.construct_from_dict(config.typed_dict["network"])
3170+
3171+
31003172def test_name_scope ():
31013173 with make_scope () as session :
31023174 n_in , n_out = 2 , 3
0 commit comments