Skip to content

Commit

Permalink
Llama: partial 4d masks (huggingface#29731)
Browse files Browse the repository at this point in the history
* partial 4d masks

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
  • Loading branch information
gante and amyeroberts authored Mar 19, 2024
1 parent 425ba56 commit 4294f0c
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 10 deletions.
14 changes: 11 additions & 3 deletions src/transformers/models/cohere/modeling_cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,7 +899,7 @@ def forward(
if position_ids is None:
position_ids = cache_position.unsqueeze(0)

causal_mask = self._update_causal_mask(attention_mask, inputs_embeds)
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, past_seen_tokens)

# embed positions
hidden_states = inputs_embeds
Expand Down Expand Up @@ -967,7 +967,7 @@ def forward(
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
def _update_causal_mask(self, attention_mask, input_tensor):
def _update_causal_mask(self, attention_mask, input_tensor, past_seen_tokens):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
Expand All @@ -993,9 +993,17 @@ def _update_causal_mask(self, attention_mask, input_tensor):
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
elif attention_mask.dim() == 4:
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
# cache. In that case, the 4D attention mask attends to the newest tokens only.
if attention_mask.shape[-2] < past_seen_tokens + input_tensor.shape[1]:
offset = past_seen_tokens
else:
offset = 0
mask_shape = attention_mask.shape
mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
causal_mask[: mask_shape[0], : mask_shape[1], : mask_shape[2], : mask_shape[3]] = mask_slice
causal_mask[
: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
] = mask_slice

if (
self.config._attn_implementation == "sdpa"
Expand Down
14 changes: 11 additions & 3 deletions src/transformers/models/gemma/modeling_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -901,7 +901,7 @@ def forward(
if position_ids is None:
position_ids = cache_position.unsqueeze(0)

causal_mask = self._update_causal_mask(attention_mask, inputs_embeds)
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, past_seen_tokens)

# embed positions
hidden_states = inputs_embeds
Expand Down Expand Up @@ -975,7 +975,7 @@ def forward(
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
def _update_causal_mask(self, attention_mask, input_tensor):
def _update_causal_mask(self, attention_mask, input_tensor, past_seen_tokens):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
Expand All @@ -1002,9 +1002,17 @@ def _update_causal_mask(self, attention_mask, input_tensor):
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
elif attention_mask.dim() == 4:
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
# cache. In that case, the 4D attention mask attends to the newest tokens only.
if attention_mask.shape[-2] < past_seen_tokens + input_tensor.shape[1]:
offset = past_seen_tokens
else:
offset = 0
mask_shape = attention_mask.shape
mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
causal_mask[: mask_shape[0], : mask_shape[1], : mask_shape[2], : mask_shape[3]] = mask_slice
causal_mask[
: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
] = mask_slice

if (
self.config._attn_implementation == "sdpa"
Expand Down
14 changes: 11 additions & 3 deletions src/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -1000,7 +1000,7 @@ def forward(
if position_ids is None:
position_ids = cache_position.unsqueeze(0)

causal_mask = self._update_causal_mask(attention_mask, inputs_embeds)
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, past_seen_tokens)

# embed positions
hidden_states = inputs_embeds
Expand Down Expand Up @@ -1068,7 +1068,7 @@ def forward(
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
def _update_causal_mask(self, attention_mask, input_tensor):
def _update_causal_mask(self, attention_mask, input_tensor, past_seen_tokens):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
Expand All @@ -1094,9 +1094,17 @@ def _update_causal_mask(self, attention_mask, input_tensor):
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
elif attention_mask.dim() == 4:
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
# cache. In that case, the 4D attention mask attends to the newest tokens only.
if attention_mask.shape[-2] < past_seen_tokens + input_tensor.shape[1]:
offset = past_seen_tokens
else:
offset = 0
mask_shape = attention_mask.shape
mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
causal_mask[: mask_shape[0], : mask_shape[1], : mask_shape[2], : mask_shape[3]] = mask_slice
causal_mask[
: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
] = mask_slice

if (
self.config._attn_implementation == "sdpa"
Expand Down
110 changes: 109 additions & 1 deletion tests/test_modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1956,7 +1956,6 @@ def test_not_available_sdpa(self):
self.assertTrue("PyTorch SDPA requirements in Transformers are not met" in str(cm.exception))


@slow
@require_torch_gpu
class Mask4DTestBase(unittest.TestCase):
def tearDown(self):
Expand Down Expand Up @@ -2011,6 +2010,7 @@ def setUp(self):

def test_attention(self):
"""comparing outputs of attention layer"""
# Input 0: one row per sentence; Input 1: same data, but stacked into a single row with custom attention
input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data()
causal_mask_1 = (1 - mask_1).to(self.model_dtype) * torch.finfo(self.model_dtype).min

Expand All @@ -2030,6 +2030,7 @@ def test_attention(self):

def test_causal_model_logits(self):
"""comparing logits outputs of whole inner model"""
# Input 0: one row per sentence; Input 1: same data, but stacked into a single row with custom attention
input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data()

logits_0 = self.model.forward(input_0, position_ids=position_ids_0).logits
Expand All @@ -2052,6 +2053,7 @@ def setUp(self):

def test_causal_model_logits(self):
"""comparing logits outputs of whole inner model"""
# Input 0: one row per sentence; Input 1: same data, but stacked into a single row with custom attention
input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data()

logits_0 = self.model.forward(input_0, position_ids=position_ids_0).logits
Expand All @@ -2069,3 +2071,109 @@ def test_causal_model_logits(self):
# checking tokens order for the top tokens
for token_ids_0, token_ids_1 in zip(indices_0, indices_1):
self.assertTrue(torch.equal(token_ids_0[:128], token_ids_1[:128]))


@slow
@require_torch_gpu
class Mask4DTestHard(unittest.TestCase):
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()

def setUp(self):
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
self.model_dtype = torch.float32
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device)

def get_test_data(self):
template = "my favorite {}"
items = ("pet is a", "artist plays a", "name is L") # same number of tokens in each item

batch_0 = [template.format(x) for x in items] # 3 separate lines
batch_1 = template.format(" ".join(items)) # 1 line with options concatenated

input_0 = self.tokenizer(batch_0, return_tensors="pt").input_ids.to(torch_device)
input_1 = self.tokenizer(batch_1, return_tensors="pt").input_ids.to(torch_device)

mask_1 = torch.tensor(
[
[
[
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0],
[1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0],
[1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1],
]
]
],
device=torch_device,
dtype=torch.int64,
)

position_ids_0 = torch.arange(input_0.shape[1]).tile(input_0.shape[0], 1).to(torch_device)
# equivalent: position_ids_1 = torch.tensor([[0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5]]).to(device)
position_ids_1 = (mask_1.sum(dim=-1) - 1).reshape(1, -1) # same but nicer

return input_0, position_ids_0, input_1, mask_1, position_ids_1

def test_stacked_causal_mask(self):
# Input 0: one row per sentence; Input 1: same data, but stacked into a single row with custom attention
input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data()

# regular batch
logits_0 = self.model.forward(input_0, position_ids=position_ids_0).logits
logits_0_last = logits_0[:, -1, :] # last tokens in each batch line
decoded_0 = [self.tokenizer.decode(t) for t in logits_0_last.argmax(dim=-1)]

# single forward run with 4D custom mask
logits_1 = self.model.forward(input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1).logits
logits_1_last = logits_1[0, torch.where(position_ids_1 == position_ids_1.max())[1], :] # last three tokens
decoded_1 = [self.tokenizer.decode(t) for t in logits_1_last.argmax(dim=-1)]

self.assertEqual(decoded_0, decoded_1)

def test_partial_stacked_causal_mask(self):
# Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention
# masks

# Input 0: one row per sentence; Input 1: same data, but stacked into a single row with custom attention
input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data()

# regular batch
logits_0 = self.model.forward(input_0, position_ids=position_ids_0).logits
logits_0_last = logits_0[:, -1, :] # last tokens in each batch line
decoded_0 = [self.tokenizer.decode(t) for t in logits_0_last.argmax(dim=-1)]

# 2 forward runs with custom 4D masks
part_a = 3 # split point

input_1a = input_1[:, :part_a]
position_ids_1a = position_ids_1[:, :part_a]
mask_1a = mask_1[:, :, :part_a, :part_a]

outs_1a = self.model.forward(input_1a, attention_mask=mask_1a.bool(), position_ids=position_ids_1a)
past_key_values_a = outs_1a["past_key_values"]

input_1b = input_1[:, part_a:]
position_ids_1b = position_ids_1[:, part_a:]
mask_1b = mask_1[:, :, part_a:, :]

outs_1b = self.model.forward(
input_1b, attention_mask=mask_1b.bool(), position_ids=position_ids_1b, past_key_values=past_key_values_a
)

decoded_1b = [
self.tokenizer.decode(t)
for t in outs_1b.logits.argmax(-1)[0, torch.where(position_ids_1 == position_ids_1.max())[1] - part_a]
]

self.assertEqual(decoded_0, decoded_1b)

0 comments on commit 4294f0c

Please sign in to comment.