The current ESMC implementation has pad tokens attending to each other. This does not affect non-pad tokens, but it does result in vastly different hidden states in aggregate.
mask_BLL = seq_id.unsqueeze(-1) == seq_id.unsqueeze(-2)
mask_BHLL = mask_BLL.unsqueeze(1)
This short script illustrates the issue and fix:
import torch
VOCAB_SIZE = 64
PAD_TOKEN = 0
input_ids_1 = torch.randint(0, VOCAB_SIZE, (1, 6))
input_ids_2 = torch.randint(0, VOCAB_SIZE, (1, 6))
input_ids_2[:,-3:] = PAD_TOKEN
batch = torch.cat([input_ids_1, input_ids_2], dim=0)
seq_id = batch != PAD_TOKEN
print("2D attention mask:")
print(seq_id)
print("4D attention mask from ESM repo:")
mask = seq_id.unsqueeze(-1) == seq_id.unsqueeze(-2)
mask = mask.unsqueeze(1)
print(mask)
print(mask.shape)
print("A correct 4D attention mask:")
correct_mask = seq_id[:, None, :, None] & seq_id[:, None, None, :]
print(correct_mask)
print(correct_mask.shape)
2D attention mask:
tensor([[ True, True, True, True, True, True],
[ True, True, True, False, False, False]])
Correct, first element of the batch has no pad tokens, everything is attended to
4D attention mask from ESM repo:
tensor([[[[ True, True, True, True, True, True],
[ True, True, True, True, True, True],
[ True, True, True, True, True, True],
[ True, True, True, True, True, True],
[ True, True, True, True, True, True],
[ True, True, True, True, True, True]]],
Correct, second element has 3 pad tokens which are ignored
[[[ True, True, True, False, False, False],
[ True, True, True, False, False, False],
[ True, True, True, False, False, False],
Incorrect, pad tokens are attending to themselves instead of nothing
[False, False, False, True, True, True],
[False, False, False, True, True, True],
[False, False, False, True, True, True]]]])
(batch_size, 1, seq_len, seq_len) shape, which is correct
torch.Size([2, 1, 6, 6])
A correct 4D attention mask:
Correct, first element of the batch has no pad tokens, everything is attended to
tensor([[[[ True, True, True, True, True, True],
[ True, True, True, True, True, True],
[ True, True, True, True, True, True],
[ True, True, True, True, True, True],
[ True, True, True, True, True, True],
[ True, True, True, True, True, True]]],
Correct, second element has 3 pad tokens which are ignored
[[[ True, True, True, False, False, False],
[ True, True, True, False, False, False],
[ True, True, True, False, False, False],
Correct, pad tokens attend to nothing
[False, False, False, False, False, False],
[False, False, False, False, False, False],
[False, False, False, False, False, False]]]])
(batch_size, 1, seq_len, seq_len) shape, which is correct
torch.Size([2, 1, 6, 6])
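To make the downstream effect concrete, here is a small sketch (not from the ESMC repo) that runs both masks through torch.nn.functional.scaled_dot_product_attention on random Q/K/V. It uses a key-only mask (every query may attend only to non-pad keys) as one common alternative fix, so pad query rows stay well-defined instead of all-False; the tensor shapes and names here are illustrative assumptions:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, L, D = 2, 1, 6, 8  # batch, heads, seq_len, head_dim (illustrative sizes)
q = torch.randn(B, H, L, D)
k = torch.randn(B, H, L, D)
v = torch.randn(B, H, L, D)

# Same setup as above: batch element 1 has 3 trailing pad positions.
seq_id = torch.tensor([[True] * 6, [True] * 3 + [False] * 3])

# Mask as built in the ESM repo: pad positions attend to each other.
repo_mask = (seq_id.unsqueeze(-1) == seq_id.unsqueeze(-2)).unsqueeze(1)

# Key-only mask: every query may attend only to non-pad keys.
key_mask = seq_id[:, None, None, :].expand(B, 1, L, L)

out_repo = F.scaled_dot_product_attention(q, k, v, attn_mask=repo_mask)
out_key = F.scaled_dot_product_attention(q, k, v, attn_mask=key_mask)

# Non-pad rows of batch element 1 agree under both masks...
print(torch.allclose(out_repo[1, :, :3], out_key[1, :, :3]))
# ...but the pad rows differ: under the repo mask they mix pad values.
print(torch.allclose(out_repo[1, :, 3:], out_key[1, :, 3:]))
```

This matches the report: the bug is invisible if you only read non-pad positions, but any aggregate over all hidden states (e.g. a mean over the full sequence length) picks up the pad rows.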