@@ -455,6 +455,16 @@ def test_relative_attention_cat_sinks(pretrain_file):
455
455
model = build_model (pretrain_file , '--no_use_lattn' , '--use_rattn' , '--rattn_heads' , '10' , '--rattn_cat' , '--rattn_sinks' , '2' )
456
456
run_forward_checks (model )
457
457
458
+ def test_relative_attention_endpoint_sinks (pretrain_file ):
459
+ model = build_model (pretrain_file , '--no_use_lattn' , '--use_rattn' , '--rattn_heads' , '10' , '--rattn_use_endpoint_sinks' , '--rattn_window' , '2' , '--rattn_sinks' , '1' )
460
+ run_forward_checks (model )
461
+ model = build_model (pretrain_file , '--no_use_lattn' , '--use_rattn' , '--rattn_heads' , '10' , '--rattn_use_endpoint_sinks' , '--rattn_sinks' , '1' )
462
+ run_forward_checks (model )
463
+ model = build_model (pretrain_file , '--no_use_lattn' , '--use_rattn' , '--rattn_heads' , '10' , '--rattn_use_endpoint_sinks' , '--rattn_window' , '2' , '--rattn_sinks' , '2' )
464
+ run_forward_checks (model )
465
+ model = build_model (pretrain_file , '--no_use_lattn' , '--use_rattn' , '--rattn_heads' , '10' , '--rattn_use_endpoint_sinks' , '--rattn_sinks' , '2' )
466
+ run_forward_checks (model )
467
+
458
468
def test_lstm_tree_forward (pretrain_file ):
459
469
"""
460
470
Test the LSTM_TREE forward pass
0 commit comments