
4D attention_mask support #27539

Merged (14 commits, Dec 17, 2023)
31 changes: 29 additions & 2 deletions src/transformers/modeling_attn_mask_utils.py
@@ -302,10 +302,22 @@ def _prepare_4d_causal_attention_mask(
    key_value_length = input_shape[-1] + past_key_values_length

    # 4d mask is passed through the layers
-   if attention_mask is not None:
+   if attention_mask is not None and len(attention_mask.shape) == 2:
        attention_mask = attn_mask_converter.to_4d(
            attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype
        )
+   elif attention_mask is not None and len(attention_mask.shape) == 4:
+       expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
+       if tuple(attention_mask.shape) != expected_shape:
+           raise ValueError(
+               f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
+           )
+       else:
+           # if the 4D mask has correct shape - invert it and fill with negative infinity
+           inverted_mask = 1.0 - attention_mask
+           attention_mask = inverted_mask.masked_fill(
+               inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
+           )
    else:
        attention_mask = attn_mask_converter.to_causal_4d(
            input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
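
Aside (not part of the diff): to make the contract of the new branch concrete, here is a minimal PyTorch-only sketch of the shape check and the invert-and-fill step; the tensor names and sizes are illustrative.

import torch

batch_size, query_length, past_key_values_length = 1, 6, 0
key_value_length = query_length + past_key_values_length
dtype = torch.float32

# A user-supplied "keep" mask of shape (batch, 1, query_length, key_value_length):
# 1 = attend, 0 = masked. A plain causal mask is used here only as an example.
attention_mask_4d = torch.tril(torch.ones(query_length, key_value_length))[None, None]

# Mirrors the ValueError check in the branch above.
expected_shape = (batch_size, 1, query_length, key_value_length)
assert tuple(attention_mask_4d.shape) == expected_shape

# The branch then turns the 1/0 mask into an additive float mask:
# kept positions become 0.0, masked positions become the dtype minimum.
inverted_mask = 1.0 - attention_mask_4d
additive_mask = inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
print(additive_mask[0, 0])  # 0.0 where attending, ~-3.4e38 where masked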
@@ -340,7 +352,22 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
    is_tracing = torch.jit.is_tracing()

    if attention_mask is not None:
-       if torch.all(attention_mask == 1):
+       # 4d mask is passed through
+       if len(attention_mask.shape) == 4:
+           expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
+           if tuple(attention_mask.shape) != expected_shape:
+               raise ValueError(
+                   f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
+               )
+           else:
+               # if the 4D mask has correct shape - invert it and fill with negative infinity
+               inverted_mask = 1.0 - attention_mask.to(inputs_embeds.dtype)
+               attention_mask = inverted_mask.masked_fill(
+                   inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
+               )
+           return attention_mask
+
+       elif torch.all(attention_mask == 1):
            if is_tracing:
                pass
            elif query_length == 1:
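
Aside (not part of the diff): in the SDPA path the converted 4D mask is returned as-is and is later consumed as an additive attn_mask by torch.nn.functional.scaled_dot_product_attention. A small sketch of that consumption, with illustrative shapes:

import torch
import torch.nn.functional as F

batch, heads, q_len, kv_len, head_dim = 1, 2, 6, 6, 8
q = torch.randn(batch, heads, q_len, head_dim)
k = torch.randn(batch, heads, kv_len, head_dim)
v = torch.randn(batch, heads, kv_len, head_dim)

# Additive float mask: 0.0 where attending, dtype minimum where masked;
# shape (batch, 1, q_len, kv_len) broadcasts over the head dimension.
keep = torch.tril(torch.ones(q_len, kv_len))[None, None]
inverted = 1.0 - keep
additive_mask = inverted.masked_fill(inverted.bool(), torch.finfo(q.dtype).min)

out = F.scaled_dot_product_attention(q, k, v, attn_mask=additive_mask)
print(out.shape)  # torch.Size([1, 2, 6, 8])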
98 changes: 98 additions & 0 deletions tests/test_modeling_utils.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
+import gc
import glob
import json
import os
@@ -1850,3 +1851,100 @@ def test_not_available_sdpa(self):
            )

        self.assertTrue("PyTorch SDPA requirements in Transformers are not met" in str(cm.exception))


@require_torch
@slow
class Mask4DTest(unittest.TestCase):
    def setUp(self):
        self.device = torch.device("cuda:0")
        model_name = "JackFram/llama-160m"  # small Llama-like model from FlexFlow
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32).to(self.device)
Collaborator review comment on the line above:

Suggested change:
-        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32).to(self.device)
+        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to(self.device)

"the smaller the better for our CI"
Contributor Author reply:

"I observed that fp16 tests are more noisy, so what I did is:

  • retained fp32 tests, but used an even smaller model
  • added an fp16 test with relaxed tolerances
  • added an fp16 testing option for the top-tokens order."
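
Aside (not part of the diff): a minimal sketch of the fp16 strategy the author describes above, i.e. relaxed tolerances plus a check on the order of the top predicted tokens. The helper name is hypothetical; logits_0_last_tokens / logits_1_last_tokens are assumed to be computed exactly as in the test_causal_model_logits test further down, just with a float16 model.

import torch

def assert_packed_logits_match_fp16(logits_0_last_tokens, logits_1_last_tokens, top_k=5):
    # hypothetical helper, not part of the PR
    # fp16 is noisier, so loosen the element-wise tolerances...
    torch.testing.assert_close(logits_0_last_tokens, logits_1_last_tokens, atol=1e-2, rtol=1e-2)
    # ...and additionally require the same ranking of the most likely tokens.
    top_0 = logits_0_last_tokens.topk(top_k, dim=-1).indices
    top_1 = logits_1_last_tokens.topk(top_k, dim=-1).indices
    assert torch.equal(top_0, top_1)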


    def tearDown(self):
        r"""
        TearDown function needs to be called at the end of each test to free the GPU memory and cache, also to
        avoid unexpected behaviors. Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27
        """
        gc.collect()
        torch.cuda.empty_cache()

    def get_test_data(self):
        texts = ["the cat sat", "the cat had", "the cat is"]
        encoded = [self.tokenizer.encode(t) for t in texts]
        input_0 = torch.tensor(encoded, device=self.device)
        # tensor([[   1,  278, 6635, 3290],
        #         [   1,  278, 6635,  750],
        #         [   1,  278, 6635,  338]], device='cuda:0')

        # Combining common prefix with the unique ending tokens:
        input_1 = torch.cat([input_0[0][:-1], input_0[:, -1]]).unsqueeze(0)
        # tensor([[   1,  278, 6635, 3290,  750,  338]], device='cuda:0')

        # Creating a 4D mask where each of the last 3 tokens do not attend to each other.
        mask_1 = torch.tensor(
            [
                [
                    [
                        [1, 0, 0, 0, 0, 0],
                        [1, 1, 0, 0, 0, 0],
                        [1, 1, 1, 0, 0, 0],
                        [1, 1, 1, 1, 0, 0],
                        [1, 1, 1, 0, 1, 0],
                        [1, 1, 1, 0, 0, 1],
                    ]
                ]
            ],
            device="cuda:0",
            dtype=torch.int64,
        )

        # Creating a position_ids tensor. note the repeating figures in the end.
        position_ids_1 = torch.tensor([[0, 1, 2, 3, 3, 3]], device=self.device, dtype=torch.int64)

        return input_0, input_1, mask_1, position_ids_1
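
Aside (not part of the diff): the hard-coded mask_1 and position_ids_1 above generalize naturally. Below is a sketch of building the same packed inputs programmatically; build_packed_inputs is a hypothetical helper, not a transformers API.

import torch

def build_packed_inputs(prefix_ids, branch_ids, device="cpu"):
    """Pack one shared prefix plus several single-token branches into one row.

    Each branch token attends causally to the prefix and to itself, but not to
    the other branches - the same pattern as mask_1 in get_test_data above.
    """
    p, n = len(prefix_ids), len(branch_ids)
    total = p + n

    input_ids = torch.tensor([list(prefix_ids) + list(branch_ids)], device=device)

    # Start from a causal mask, then cut the edges between different branches.
    mask = torch.tril(torch.ones(total, total, dtype=torch.int64, device=device))
    for i in range(p, total):
        for j in range(p, total):
            if i != j:
                mask[i, j] = 0
    mask_4d = mask[None, None]  # shape (1, 1, total, total)

    # The prefix keeps positions 0..p-1; every branch token sits at position p.
    position_ids = torch.tensor([list(range(p)) + [p] * n], device=device)
    return input_ids, mask_4d, position_ids

# Reproduces the tensors above (shared "<s> the cat" prefix; branches "sat" / "had" / "is"):
input_1, mask_1, position_ids_1 = build_packed_inputs([1, 278, 6635], [3290, 750, 338])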

    def test_attention(self):
        """comparing outputs of attention layer"""
        input_0, input_1, mask_1, position_ids_1 = self.get_test_data()

        hid_0 = self.model.model.embed_tokens(input_0)
        outs_0 = self.model.model.layers[0].self_attn.forward(hid_0)[0]
        # outs_0.shape == torch.Size([3, 4, 768])

        hid_1 = self.model.model.embed_tokens(input_1)
        outs_1 = self.model.model.layers[0].self_attn.forward(
            hid_1, attention_mask=mask_1.bool(), position_ids=position_ids_1
        )[0]
        # outs_1.shape == torch.Size([1, 6, 768])

        outs_0_last_tokens = outs_0[:, -1, :]  # last tokens in each batch line
        outs_1_last_tokens = outs_1[0, -3:, :]  # last three tokens
        assert torch.allclose(outs_0_last_tokens, outs_1_last_tokens, atol=1e-8)

    def test_model(self):
        """comparing hidden outputs of whole inner model"""
        input_0, input_1, mask_1, position_ids_1 = self.get_test_data()

        hidden_0 = self.model.model.forward(input_0).last_hidden_state
        hidden_1 = self.model.model.forward(
            input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1
        ).last_hidden_state

        hidden_0_last_tokens = hidden_0[:, -1, :]  # last tokens in each batch line
        hidden_1_last_tokens = hidden_1[0, -3:, :]  # last three tokens
        assert torch.allclose(
            hidden_0_last_tokens, hidden_1_last_tokens, atol=1e-5
        )  # note higher atol set to deal with noise

    def test_causal_model_logits(self):
        """comparing logits outputs of the whole causal model"""
        input_0, input_1, mask_1, position_ids_1 = self.get_test_data()

        logits_0 = self.model.forward(input_0).logits
        logits_1 = self.model.forward(input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1).logits

        logits_0_last_tokens = logits_0[:, -1, :]  # last tokens in each batch line
        logits_1_last_tokens = logits_1[0, -3:, :]  # last three tokens
        assert torch.allclose(
            logits_0_last_tokens, logits_1_last_tokens, atol=1e-5
        )  # note higher atol set to deal with noise