@@ -3482,6 +3482,43 @@ def test_reclayer_optimize_out_dot():
34823482 rtol = 1e-3 )
34833483
34843484
def test_reclayer_optimize_out_dot_consistent_axes():
    # Regression test for https://github.com/rwth-i6/returnn/issues/569
    # The network below is a multi-head dot-attention setup; the layer under
    # test is the "dot" layer ("energy") inside the rec subnetwork.
    num_heads = 4
    key_dim_per_head = 5
    value_dim_per_head = 7
    key_dim_total = num_heads * key_dim_per_head
    value_dim_total = num_heads * value_dim_per_head

    def _linear_no_bias(source, n_out):
        # Bias-free linear projection without activation.
        return {
            "class": "linear", "activation": None, "with_bias": False,
            "from": source, "n_out": n_out,
        }

    def _split_heads(source, dim_per_head):
        # Splits the feature axis into (num_heads, dim_per_head).
        return {
            "class": "split_dims", "axis": "F",
            "dims": (num_heads, dim_per_head), "from": source,
        }

    # Layers living in the shared base network (outside the rec loop).
    enc_ctx = _split_heads("enc_ctx0", key_dim_per_head)  # (B, enc-T, H, D/H)
    enc_ctx["is_output_layer"] = True
    # Feature axis is split into (num_heads, value_dim_per_head) here,
    # i.e. presumably (B, enc-T, H, V/H) with V = value_dim_total.
    enc_value = _split_heads("enc_value0", value_dim_per_head)
    enc_value["is_output_layer"] = True
    base_net = {
        "encoder": {"class": "copy", "from": "data"},
        "enc_ctx0": _linear_no_bias("encoder", key_dim_total),  # (B, enc-T, D)
        "enc_ctx": enc_ctx,
        "enc_value0": _linear_no_bias("encoder", value_dim_total),
        "enc_value": enc_value,
    }

    # Layers living inside the rec subnetwork.
    rec_layers = {
        # (B, D) -- Q (query). D should be the same as for enc_ctx.
        "s": _linear_no_bias("data:source", key_dim_total),
        "att_query": _split_heads("s", key_dim_per_head),  # (B, H, D/H)
        # Here is the main test, the dot-layer.
        # Inside the loop, energy will be (B, H, enc-T, 1).
        # Outside the loop, energy will be (B, H, enc-T, dec-T),
        # i.e. enc-T is still the first time axis.
        "energy": {
            "class": "dot", "red1": -1, "red2": -1, "var1": "T", "var2": "T",
            "from": ["base:enc_ctx", "att_query"],
        },
        # Softmax over the spatial axis of energy; presumably the same axes
        # layout as energy -- TODO confirm.
        "att_weights": {"class": "softmax_over_spatial", "from": "energy"},
        "att0": {
            "class": "generic_attention", "weights": "att_weights",
            "base": "base:enc_value",
        },  # (B, H, V)
        # (B, H*V); "static" is used deliberately for the merged axes.
        "att": {"class": "merge_dims", "axes": "static", "from": "att0"},
    }

    output_layer = {"class": "linear", "activation": None, "from": "att"}
    check_reclayer_optimize_out(
        output_layer,
        other_subnet_layers=rec_layers,
        shared_base_net=base_net,
        rtol=1e-3)
3520+
3521+
34853522def test_reclayer_optimize_out_dot_kv_in_rec ():
34863523 # Same as test_reclayer_optimize_out_dot, but with the att key/value layers declared INSIDE the rec layer.
34873524 AttNumHeads = 4
0 commit comments