
Commit 10dd38f

Add an option where the boundary 'words' of a sentence are used as the sinks in the constituency relative attention module
1 parent a6288d8 commit 10dd38f

3 files changed, +11 -4 lines changed


stanza/models/common/relative_attn.py

+5 -2
@@ -47,7 +47,7 @@ def __init__(self, d_model, num_heads, window=8, dropout=0.2, reverse=False, d_o
 
         self.reverse = reverse
 
-    def forward(self, x):
+    def forward(self, x, sink=None):
         # x.shape == (batch_size, seq_len, d_model)
         batch_size, seq_len, d_model = x.shape
         if d_model != self.d_model:
@@ -66,7 +66,10 @@ def forward(self, x):
         # could keep a parameter to train sinks, but as it turns out,
         # the position vectors just overlap that parameter space anyway
         # generally the model trains the sinks to zero if we do that
-        sink = torch.zeros((batch_size, self.num_sinks, d_model), dtype=x.dtype, device=x.device)
+        if sink is None:
+            sink = torch.zeros((batch_size, self.num_sinks, d_model), dtype=x.dtype, device=x.device)
+        else:
+            sink = sink.expand(batch_size, self.num_sinks, d_model)
         x = torch.cat((sink, x), axis=1)
 
         # k.shape = (batch_size, num_heads, d_head, seq_len + num_sinks)
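
For reference, a minimal standalone sketch of what the new branch does when a sink vector is passed in (stand-in sizes, outside the module; the real values come from the layer's configuration):

import torch

# stand-in sizes for the sketch
batch_size, seq_len, d_model, num_sinks = 1, 5, 8, 1

x = torch.randn(batch_size, seq_len, d_model)
sink = x[0, 0]                                       # a boundary word vector, shape (d_model,)

sink = sink.expand(batch_size, num_sinks, d_model)   # broadcast to (batch_size, num_sinks, d_model)
x = torch.cat((sink, x), axis=1)                     # prepend the sink position(s) to the sequence
assert x.shape == (batch_size, seq_len + num_sinks, d_model)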

stanza/models/constituency/lstm_model.py

+5 -2
@@ -824,9 +824,12 @@ def map_word(word):
         rattn_inputs = [[x] for x in all_word_inputs]
 
         if self.rel_attn_forward is not None:
-            rattn_inputs = [x + [self.rel_attn_forward(x[0].unsqueeze(0)).squeeze(0)] for x in rattn_inputs]
+            if self.args['rattn_use_endpoint_sinks']:
+                rattn_inputs = [x + [self.rel_attn_forward(x[0].unsqueeze(0), x[0][0]).squeeze(0)] for x in rattn_inputs]
+            else:
+                rattn_inputs = [x + [self.rel_attn_forward(x[0].unsqueeze(0)).squeeze(0)] for x in rattn_inputs]
         if self.rel_attn_reverse is not None:
-            rattn_inputs = [x + [self.rel_attn_reverse(x[0].unsqueeze(0)).squeeze(0)] for x in rattn_inputs]
+            rattn_inputs = [x + [self.rel_attn_reverse(x[0].unsqueeze(0), x[0][-1]).squeeze(0)] for x in rattn_inputs]
 
         if self.args['rattn_cat']:
             all_word_inputs = [torch.cat(x, axis=1) for x in rattn_inputs]
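
A small sketch of how the boundary words are picked out per sentence, using a hypothetical word tensor in place of x[0] from rattn_inputs (both sizes are made up):

import torch

words = torch.randn(7, 200)     # one row per word in the sentence

sentence = words.unsqueeze(0)   # (1, num_words, emb_dim), a batch of one sentence
forward_sink = words[0]         # first word -> sink for the forward attention
reverse_sink = words[-1]        # last word  -> sink for the reverse attention

# the module calls then mirror the diff above:
#   self.rel_attn_forward(sentence, forward_sink).squeeze(0)
#   self.rel_attn_reverse(sentence, reverse_sink).squeeze(0)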

stanza/models/constituency_parser.py

+1
@@ -737,6 +737,7 @@ def build_argparse():
     parser.add_argument('--rattn_cat', default=True, action='store_true', help='Stack the rattn layers instead of adding them')
     parser.add_argument('--rattn_dim', default=200, type=int, help='Dimension of the rattn output when cat')
     parser.add_argument('--rattn_sinks', default=0, type=int, help='Number of attention sink tokens to learn')
+    parser.add_argument('--rattn_use_endpoint_sinks', default=False, action='store_true', help='Use the endpoints of the sentences as sinks')
 
     parser.add_argument('--log_norms', default=False, action='store_true', help='Log the parameters norms while training. A very noisy option')
     parser.add_argument('--log_shapes', default=False, action='store_true', help='Log the parameters shapes at the beginning')
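
A standalone sketch of the new flag next to the existing sink count (only the two arguments from the diff are reproduced; the real build_argparse() defines many more). Since the sink block in relative_attn.py is expanded to num_sinks rows, the endpoint option presumably only takes effect when --rattn_sinks is at least 1:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--rattn_sinks', default=0, type=int,
                    help='Number of attention sink tokens to learn')
parser.add_argument('--rattn_use_endpoint_sinks', default=False, action='store_true',
                    help='Use the endpoints of the sentences as sinks')

args = parser.parse_args(['--rattn_sinks', '1', '--rattn_use_endpoint_sinks'])
print(args.rattn_sinks, args.rattn_use_endpoint_sinks)   # -> 1 True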
