Skip to content

Commit 8c90bd9

Browse files
committed
Use the --endpoint flag for reverse sinks as well... add a unit test. Is there a way to check that the unit test is doing what we want?
1 parent 10dd38f commit 8c90bd9

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

stanza/models/constituency/lstm_model.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -829,7 +829,10 @@ def map_word(word):
829829
else:
830830
rattn_inputs = [x + [self.rel_attn_forward(x[0].unsqueeze(0)).squeeze(0)] for x in rattn_inputs]
831831
if self.rel_attn_reverse is not None:
832-
rattn_inputs = [x + [self.rel_attn_reverse(x[0].unsqueeze(0), x[0][-1]).squeeze(0)] for x in rattn_inputs]
832+
if self.args['rattn_use_endpoint_sinks']:
833+
rattn_inputs = [x + [self.rel_attn_reverse(x[0].unsqueeze(0), x[0][-1]).squeeze(0)] for x in rattn_inputs]
834+
else:
835+
rattn_inputs = [x + [self.rel_attn_reverse(x[0].unsqueeze(0)).squeeze(0)] for x in rattn_inputs]
833836

834837
if self.args['rattn_cat']:
835838
all_word_inputs = [torch.cat(x, axis=1) for x in rattn_inputs]

stanza/tests/constituency/test_lstm_model.py

+10
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,16 @@ def test_relative_attention_cat_sinks(pretrain_file):
455455
model = build_model(pretrain_file, '--no_use_lattn', '--use_rattn', '--rattn_heads', '10', '--rattn_cat', '--rattn_sinks', '2')
456456
run_forward_checks(model)
457457

458+
def test_relative_attention_endpoint_sinks(pretrain_file):
459+
model = build_model(pretrain_file, '--no_use_lattn', '--use_rattn', '--rattn_heads', '10', '--rattn_use_endpoint_sinks', '--rattn_window', '2', '--rattn_sinks', '1')
460+
run_forward_checks(model)
461+
model = build_model(pretrain_file, '--no_use_lattn', '--use_rattn', '--rattn_heads', '10', '--rattn_use_endpoint_sinks', '--rattn_sinks', '1')
462+
run_forward_checks(model)
463+
model = build_model(pretrain_file, '--no_use_lattn', '--use_rattn', '--rattn_heads', '10', '--rattn_use_endpoint_sinks', '--rattn_window', '2', '--rattn_sinks', '2')
464+
run_forward_checks(model)
465+
model = build_model(pretrain_file, '--no_use_lattn', '--use_rattn', '--rattn_heads', '10', '--rattn_use_endpoint_sinks', '--rattn_sinks', '2')
466+
run_forward_checks(model)
467+
458468
def test_lstm_tree_forward(pretrain_file):
459469
"""
460470
Test the LSTM_TREE forward pass

0 commit comments

Comments
 (0)