Skip to content

Commit 027a18d

Browse files
committed
test_reuse_params_map_custom_dep_loop2
1 parent b71bea9 commit 027a18d

File tree

1 file changed

+72
-0
lines changed

1 file changed

+72
-0
lines changed

tests/test_TFNetworkLayer.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3097,6 +3097,78 @@ def test_reuse_params_map_custom_dep_loop():
30973097
search_net.construct_from_dict(config.typed_dict["network"])
30983098

30993099

3100+
def test_reuse_params_map_custom_dep_loop2():
3101+
# like test_reuse_params_map_custom_dep_loop, but with target_embed weights shared from encoder
3102+
config = Config()
3103+
n_in, n_out = 3, 3
3104+
config.update({
3105+
"num_outputs": n_out,
3106+
"num_inputs": n_in,
3107+
"network": {
3108+
'encoder': {
3109+
'activation': None, 'class': 'linear',
3110+
'forward_weights_init': "variance_scaling_initializer(mode='fan_in', distribution='uniform', scale=1.0)",
3111+
'n_out': 6, 'with_bias': False, "from": "data"},
3112+
"enc_ctx": {"class": "linear", "activation": None, "with_bias": True, "from": ["encoder"], "n_out": 10},
3113+
"inv_fertility": {"class": "linear", "activation": "sigmoid", "with_bias": False, "from": ["encoder"],
3114+
"n_out": 1},
3115+
"output": {"class": "rec", "from": [], "unit": {
3116+
'output': {'class': 'choice', 'target': 'classes', 'beam_size': 5, 'from': ["output_prob"],
3117+
"initial_output": 0},
3118+
"end": {"class": "compare", "from": ["output"], "value": 0},
3119+
'target_embed': {
3120+
'class': 'linear', 'activation': None, "with_bias": False, 'from': ['output'], "n_out": 6,
3121+
"initial_output": 0,
3122+
'reuse_params': {'map': {'W': {'reuse_layer': 'base:encoder'}, 'b': None}}},
3123+
"weight_feedback": {"class": "linear", "activation": None, "with_bias": False,
3124+
"from": ["prev:accum_att_weights"], "n_out": 10},
3125+
"prev_s_state": {"class": "get_last_hidden_state", "from": ["prev:s"], "n_out": 20},
3126+
"prev_s_transformed": {"class": "linear", "activation": None, "with_bias": False, "from": ["prev_s_state"],
3127+
"n_out": 10},
3128+
"energy_in": {"class": "combine", "kind": "add",
3129+
"from": ["base:enc_ctx", "weight_feedback", "prev_s_transformed"], "n_out": 10},
3130+
"energy_tanh": {"class": "activation", "activation": "tanh", "from": ["energy_in"]},
3131+
"energy": {"class": "linear", "activation": None, "with_bias": False, "from": ["energy_tanh"], "n_out": 1},
3132+
"att_weights": {"class": "softmax_over_spatial", "from": ["energy"]},
3133+
"accum_att_weights": {"class": "eval",
3134+
"from": ["prev:accum_att_weights", "att_weights", "base:inv_fertility"],
3135+
"eval": "source(0) + source(1) * source(2) * 0.5",
3136+
"out_type": {"dim": 1, "shape": (None, 1)}},
3137+
"att": {"class": "generic_attention", "weights": "att_weights", "base": "base:encoder", "auto_squeeze": True},
3138+
"s": {"class": "rnn_cell", "unit": "LSTMBlock", "from": ["target_embed", "att"], "n_out": 10},
3139+
"readout_in": {"class": "linear", "from": ["prev:s", "prev:target_embed", "att"], "activation": None,
3140+
"n_out": 2 * 6},
3141+
"readout": {"class": "reduce_out", "mode": "max", "num_pieces": 2, "from": ["readout_in"]},
3142+
"output_prob": {
3143+
"class": "softmax", "from": ["readout"], "dropout": 0.3,
3144+
"reuse_params": {
3145+
"map": {
3146+
"W": {
3147+
"reuse_layer": "target_embed",
3148+
"custom": (lambda reuse_layer, **kwargs: tf.transpose(reuse_layer.params["W"]))},
3149+
"b": None}},
3150+
"target": "classes", "loss": "ce", "loss_opts": {"label_smoothing": 0.1}}
3151+
}, "target": "classes", "max_seq_len": "max_len_from('base:encoder')"},
3152+
}
3153+
})
3154+
with make_scope() as session:
3155+
print("Construct for training")
3156+
from returnn.tf.layers.rec import RecLayer, _SubnetworkRecCell
3157+
train_net = TFNetwork(config=config, train_flag=True)
3158+
train_net.construct_from_dict(config.typed_dict["network"])
3159+
train_rec_layer = train_net.layers["output"]
3160+
assert isinstance(train_rec_layer, RecLayer)
3161+
assert isinstance(train_rec_layer.cell, _SubnetworkRecCell)
3162+
assert_equal(set(train_rec_layer.cell.input_layers_moved_out), {"output", "target_embed"})
3163+
assert_equal(set(train_rec_layer.cell.output_layers_moved_out), {"output_prob", "readout", "readout_in"})
3164+
assert isinstance(train_rec_layer.cell.output_layers_net, TFNetwork)
3165+
assert_equal(set(train_rec_layer.cell.output_layers_net.layers["output_prob"].params.keys()), {"b"})
3166+
with make_scope() as session:
3167+
print("Construct for search")
3168+
search_net = TFNetwork(config=config, train_flag=False, eval_flag=True, search_flag=True)
3169+
search_net.construct_from_dict(config.typed_dict["network"])
3170+
3171+
31003172
def test_name_scope():
31013173
with make_scope() as session:
31023174
n_in, n_out = 2, 3

0 commit comments

Comments
 (0)