
Bug in attention mask - currently ESM3 and ESMC pad tokens attend to each other #299

@lhallee

Description


In the current ESMC implementation, pad tokens attend to each other. This does not affect the non-pad tokens, but it does result in vastly different hidden states in aggregate.
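
To make that concrete, here is a minimal self-contained sketch (my own toy example built on torch.nn.functional.scaled_dot_product_attention, not ESMC's actual attention code). Under the repo-style mask the non-pad positions produce exactly the same output as an unpadded run, but the pad positions carry attention over the other pad embeddings, so anything aggregated over all positions shifts:

import torch
import torch.nn.functional as F

torch.manual_seed(0)

H, D = 1, 8                    # one head, small head dim, purely for illustration
real_len, pad_len = 3, 3
L = real_len + pad_len

# Toy query/key/value states for one sequence: 3 real positions + 3 pad positions
x = torch.randn(1, H, L, D)
seq_id = torch.tensor([[True] * real_len + [False] * pad_len])

# Mask built the way the repo builds it (see below): pad queries may attend to pad keys
repo_mask = (seq_id.unsqueeze(-1) == seq_id.unsqueeze(-2)).unsqueeze(1)
out_padded = F.scaled_dot_product_attention(x, x, x, attn_mask=repo_mask)

# Reference: the same 3 real positions run with no padding at all
x_real = x[:, :, :real_len]
out_ref = F.scaled_dot_product_attention(x_real, x_real, x_real)

# Non-pad positions are unaffected ...
print(torch.allclose(out_padded[:, :, :real_len], out_ref, atol=1e-6))  # True

# ... but the pad positions are not inert: they hold attention over the other pad
# embeddings, so hidden states aggregated over all positions differ
print(out_padded[:, :, real_len:].abs().max())        # non-zero
print(out_padded.mean(dim=2) - out_ref.mean(dim=2))   # naive mean-pool diverges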

The current mask calculation:

mask_BLL = seq_id.unsqueeze(-1) == seq_id.unsqueeze(-2)
mask_BHLL = mask_BLL.unsqueeze(1)
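
If seq_id is the boolean non-pad mask (as in the script below), two pad positions compare equal (False == False), so the == comparison lets every pad token attend to every other pad token in its sequence. A possible drop-in fix in the same notation (just a sketch, assuming seq_id really is that boolean mask; not tested against the repo):

mask_BLL = seq_id.unsqueeze(-1) & seq_id.unsqueeze(-2)
mask_BHLL = mask_BLL.unsqueeze(1)

One caveat: this gives pad queries a fully masked row, which can make softmax-based kernels such as F.scaled_dot_product_attention return NaN at those positions, so the pad outputs may need to be zeroed or otherwise ignored downstream.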

This short script illustrates the issue and fix:

import torch

VOCAB_SIZE = 64
PAD_TOKEN = 0

# Sample ids from 1..VOCAB_SIZE-1 so random tokens never collide with the pad token
input_ids_1 = torch.randint(1, VOCAB_SIZE, (1, 6))
input_ids_2 = torch.randint(1, VOCAB_SIZE, (1, 6))

# Pad the last three positions of the second sequence
input_ids_2[:, -3:] = PAD_TOKEN

batch = torch.cat([input_ids_1, input_ids_2], dim=0)

seq_id = batch != PAD_TOKEN  # (batch_size, seq_len) boolean non-pad mask
print("2D attention mask:")
print(seq_id)

print("4D attention mask from ESM repo:")
mask = seq_id.unsqueeze(-1) == seq_id.unsqueeze(-2)  # pad == pad evaluates to True
mask = mask.unsqueeze(1)
print(mask)
print(mask.shape)

print("A correct 4D attention mask:")
correct_mask = seq_id[:, None, :, None] & seq_id[:, None, None, :]  # pad & anything is False
print(correct_mask)
print(correct_mask.shape)

Output:
2D attention mask:
tensor([[ True,  True,  True,  True,  True,  True],
        [ True,  True,  True, False, False, False]])

Correct: the first element of the batch has no pad tokens, so everything is attended to

4D attention mask from ESM repo:
tensor([[[[ True,  True,  True,  True,  True,  True],
          [ True,  True,  True,  True,  True,  True],
          [ True,  True,  True,  True,  True,  True],
          [ True,  True,  True,  True,  True,  True],
          [ True,  True,  True,  True,  True,  True],
          [ True,  True,  True,  True,  True,  True]]],

Correct: the second element has 3 pad tokens, which are ignored

        [[[ True,  True,  True, False, False, False],
          [ True,  True,  True, False, False, False],
          [ True,  True,  True, False, False, False],

Incorrect: the pad tokens attend to each other instead of to nothing

          [False, False, False,  True,  True,  True],
          [False, False, False,  True,  True,  True],
          [False, False, False,  True,  True,  True]]]])

(batch_size, 1, seq_len, seq_len) shape, which is correct

torch.Size([2, 1, 6, 6])

Correct: the first element of the batch has no pad tokens, so everything is attended to

A correct 4D attention mask:
tensor([[[[ True,  True,  True,  True,  True,  True],
          [ True,  True,  True,  True,  True,  True],
          [ True,  True,  True,  True,  True,  True],
          [ True,  True,  True,  True,  True,  True],
          [ True,  True,  True,  True,  True,  True],
          [ True,  True,  True,  True,  True,  True]]],

Correct: the second element has 3 pad tokens, which are ignored

        [[[ True,  True,  True, False, False, False],
          [ True,  True,  True, False, False, False],
          [ True,  True,  True, False, False, False],

Correct: the pad tokens attend to nothing

          [False, False, False, False, False, False],
          [False, False, False, False, False, False],
          [False, False, False, False, False, False]]]])

(batch_size, 1, seq_len, seq_len) shape, which is correct

torch.Size([2, 1, 6, 6])
