@@ -38,7 +38,7 @@ def __init__(self, d_model, num_heads, window=8, dropout=0.2, reverse=False, d_o
 
         self.register_buffer(
             "mask",
-            torch.tril(torch.ones(window + num_sinks, window + num_sinks), diagonal=-1).unsqueeze(0).unsqueeze(0).unsqueeze(0)
+            torch.tril(torch.ones(window, window), diagonal=-1).unsqueeze(0).unsqueeze(0).unsqueeze(0)
         )
         self.register_buffer(
             "flipped_mask",
@@ -63,44 +63,43 @@ def forward(self, x):
         seq_len = self.window
 
         if self.num_sinks > 0:
-            orig_seq_len += self.num_sinks
-            seq_len += self.num_sinks
             # could keep a parameter to train sinks, but as it turns out,
             # the position vectors just overlap that parameter space anyway
             # generally the model trains the sinks to zero if we do that
            sink = torch.zeros((batch_size, self.num_sinks, d_model), dtype=x.dtype, device=x.device)
             x = torch.cat((sink, x), axis=1)
 
-        # k.shape = (batch_size, num_heads, d_head, seq_len)
-        k = self.key(x).reshape(batch_size, seq_len, self.num_heads, -1).permute(0, 2, 3, 1)
+        # k.shape = (batch_size, num_heads, d_head, seq_len + num_sinks)
+        k = self.key(x).reshape(batch_size, seq_len + self.num_sinks, self.num_heads, -1).permute(0, 2, 3, 1)[:, :, :, self.num_sinks:]
 
-        # v.shape = (batch_size, num_heads, d_head, seq_len)
-        v = self.value(x).reshape(batch_size, seq_len, self.num_heads, -1).permute(0, 2, 3, 1)
+        # v.shape = (batch_size, num_heads, d_head, seq_len + num_sinks)
+        v = self.value(x).reshape(batch_size, seq_len + self.num_sinks, self.num_heads, -1).permute(0, 2, 3, 1)
 
-        # q.shape = (batch_size, num_heads, d_head, seq_len)
-        q = self.query(x).reshape(batch_size, seq_len, self.num_heads, -1).permute(0, 2, 3, 1)
-        # q.shape = (batch_size, num_heads, d_head, window, seq_len)
+        # q.shape = (batch_size, num_heads, d_head, seq_len + num_sinks)
+        q = self.query(x).reshape(batch_size, seq_len + self.num_sinks, self.num_heads, -1).permute(0, 2, 3, 1)
+        # q.shape = (batch_size, num_heads, d_head, window + num_sinks, seq_len)
         q = self.skew_repeat(q)
         q = q + self.position
 
-        # qk.shape = (batch_size, num_heads, d_head, window, seq_len)
+        # qk.shape = (batch_size, num_heads, d_head, window + num_sinks, seq_len)
         qk = torch.einsum('bndws,bnds->bndws', q, k)
 
+        # TODO: fix mask
         # mask out the padding spaces at the end
         # can only attend to spots that aren't padded
         if orig_seq_len < seq_len:
             # mask out the part of the sentence which is empty
-            shorter_mask = self.flipped_mask[:, :, :, :, -orig_seq_len:]
-            qk[:, :, :, :, :orig_seq_len] = qk[:, :, :, :, :orig_seq_len].masked_fill(shorter_mask == 1, float("-inf"))
-            qk = qk[:, :, :, :orig_seq_len, :orig_seq_len]
+            shorter_mask = self.flipped_mask[:, :, :, :orig_seq_len, -orig_seq_len:]
+            qk = qk[:, :, :, :(orig_seq_len + self.num_sinks), :orig_seq_len]
+            qk[:, :, :, -orig_seq_len:, :] = qk[:, :, :, -orig_seq_len:, :].masked_fill(shorter_mask == 1, float("-inf"))
         else:
-            qk[:, :, :, :, -(self.window + self.num_sinks):] = qk[:, :, :, :, -(self.window + self.num_sinks):].masked_fill(self.flipped_mask == 1, float("-inf"))
+            qk[:, :, :, -self.window:, -self.window:] = qk[:, :, :, -self.window:, -self.window:].masked_fill(self.flipped_mask == 1, float("-inf"))
         qk = F.softmax(qk, dim=3)
 
         # v.shape = (batch_size, num_heads, d_head, window, seq_len)
         v = self.skew_repeat(v)
         if orig_seq_len < seq_len:
-            v = v[:, :, :, :orig_seq_len, :orig_seq_len]
+            v = v[:, :, :, :(orig_seq_len + self.num_sinks), :orig_seq_len]
         # result.shape = (batch_size, num_heads, d_head, orig_seq_len)
         result = torch.einsum('bndws,bndws->bnds', qk, v)
         # batch_size, orig_seq_len, d_output
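As a side note on the score computation above (toy shapes, not part of the patch): the einsum 'bndws,bnds->bndws' keeps every index, so it is an elementwise product that broadcasts k across the window axis of the skewed q, with no reduction:

import torch

# assumed toy sizes: batch 1, heads 2, d_head 3, window 4, seq_len 5
b, n, d, w, s = 1, 2, 3, 4, 5
q = torch.randn(b, n, d, w, s)  # as it would look after skew_repeat
k = torch.randn(b, n, d, s)

qk = torch.einsum('bndws,bnds->bndws', q, k)
# same result as broadcasting k over the window dimension
assert torch.allclose(qk, q * k.unsqueeze(3))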
@@ -109,32 +108,38 @@ def forward(self, x):
         if self.reverse:
             result = torch.flip(result, (1,))
 
-        return self.dropout(result[:, self.num_sinks:, :])
+        return self.dropout(result)
 
     def skew_repeat(self, q):
-        total_window = self.window + self.num_sinks
+        if self.num_sinks > 0:
+            q_sink = q[:, :, :, :self.num_sinks]
+            q_sink = q_sink.unsqueeze(3)
+            q_sink = q_sink.repeat(1, 1, 1, 1, q.shape[-1] - self.num_sinks)
+            q = q[:, :, :, self.num_sinks:]
         # make stripes that look like this
         # (seq_len 5, window 3)
         # 1 2 3 4 5
         # 1 2 3 4 5
         # 1 2 3 4 5
-        q = q.unsqueeze(4).repeat(1, 1, 1, 1, total_window).transpose(3, 4)
+        q = q.unsqueeze(4).repeat(1, 1, 1, 1, self.window).transpose(3, 4)
         # now the stripes look like
         # 1 2 3 4 5
         # 0 2 3 4 5
         # 0 0 3 4 5
-        q[:, :, :, :, :total_window] = q[:, :, :, :, :total_window].masked_fill(self.mask == 1, 0)
+        q[:, :, :, :, :self.window] = q[:, :, :, :, :self.window].masked_fill(self.mask == 1, 0)
         q_shape = list(q.shape)
         q_new_shape = list(q.shape)[:-2] + [-1]
         q = q.reshape(q_new_shape)
         zeros = torch.zeros_like(q[:, :, :, :1])
-        zeros = zeros.repeat(1, 1, 1, total_window)
+        zeros = zeros.repeat(1, 1, 1, self.window)
         q = torch.cat((q, zeros), axis=-1)
-        q_new_shape = q_new_shape[:-1] + [total_window, -1]
+        q_new_shape = q_new_shape[:-1] + [self.window, -1]
         # now the stripes look like
         # 1 2 3 4 5
         # 2 3 4 5 0
         # 3 4 5 0 0
         # q.shape = (batch_size, num_heads, d_head, window, seq_len)
         q = q.reshape(q_new_shape)[:, :, :, :, :-1]
+        if self.num_sinks > 0:
+            q = torch.cat([q_sink, q], dim=3)
         return q
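For readers new to the reshaping trick, a standalone toy sketch (plain 2-D tensors, illustrative only, not part of the patch) that reproduces the stripe diagrams in the comments above for seq_len 5 and window 3:

import torch

window, seq_len = 3, 5
x = torch.arange(1, seq_len + 1).float()  # 1 2 3 4 5

# repeat the sequence into window rows, then zero the strictly lower triangle
stripes = x.unsqueeze(1).repeat(1, window).t().contiguous()
mask = torch.tril(torch.ones(window, window), diagonal=-1)
stripes[:, :window] = stripes[:, :window].masked_fill(mask == 1, 0)
# 1 2 3 4 5
# 0 2 3 4 5
# 0 0 3 4 5

# flatten, pad with window zeros, reshape, and drop the last column
flat = torch.cat((stripes.reshape(-1), torch.zeros(window)))
skewed = flat.reshape(window, -1)[:, :-1]
print(skewed)
# 1 2 3 4 5
# 2 3 4 5 0
# 3 4 5 0 0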