Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhancing SFT Training Efficiency Using Packing and FlashAttention2 with Position IDs #31629

Merged
merged 13 commits into from
Jul 23, 2024
5 changes: 5 additions & 0 deletions docs/source/en/main_classes/data_collator.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,8 @@ Examples of use can be found in the [example scripts](../examples) or [example n
- numpy_mask_tokens
- tf_mask_tokens
- torch_mask_tokens

## DataCollatorWithFlattening

[[autodoc]] data.data_collator.DataCollatorWithFlattening

2 changes: 2 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@
"DataCollatorForSOP",
"DataCollatorForTokenClassification",
"DataCollatorForWholeWordMask",
"DataCollatorWithFlattening",
"DataCollatorWithPadding",
"DefaultDataCollator",
"default_data_collator",
Expand Down Expand Up @@ -4764,6 +4765,7 @@
DataCollatorForSOP,
DataCollatorForTokenClassification,
DataCollatorForWholeWordMask,
DataCollatorWithFlattening,
DataCollatorWithPadding,
DefaultDataCollator,
default_data_collator,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
DataCollatorForSOP,
DataCollatorForTokenClassification,
DataCollatorForWholeWordMask,
DataCollatorWithFlattening,
DataCollatorWithPadding,
DefaultDataCollator,
default_data_collator,
Expand Down
35 changes: 35 additions & 0 deletions src/transformers/data/data_collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1611,3 +1611,38 @@ def numpy_mask_tokens(self, inputs: Any) -> Tuple[Any, Any, Any, Any]:
) & masked_indices[i]

return inputs.astype(np.int64), perm_mask, target_mapping, labels.astype(np.int64)


@dataclass
class DataCollatorWithFlattening(DefaultDataCollator):
"""
Data collator used for padding free approach. Does the following:

- concatate the entire mini batch into single long sequence [1, total_tokens]
- no padding will be added, returns `input_ids`, `labels` and `position_ids`
"""

def __init__(self, *args, return_position_ids=True, **kwargs):
super().__init__(*args, **kwargs)
self.return_position_ids = return_position_ids
warnings.warn(
"Using `DataCollatorWithFlattening` will flatten the entire mini batch into single long sequence."
"Make sure your attention computation is able to handle it!"
)
Comment on lines +1628 to +1631
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How should a user make sure?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no easy way, the user will have to make sure the model supports FA2 and have position_ids
until the model accepting cu_seq_len which will allow every model remove padding once for all, i think this serve as a good warning


def __call__(self, features, return_tensors=None):
if return_tensors is None:
return_tensors = self.return_tensors
is_labels_provided = "labels" in features[0]
ret = {"input_ids": [], "labels": []}
if self.return_position_ids:
ret.update({"position_ids": []})
for idx in range(0, len(features)):
ret["input_ids"] += features[idx]["input_ids"]
if is_labels_provided:
ret["labels"] += [-100] + features[idx]["labels"][1:]
else:
ret["labels"] += [-100] + features[idx]["input_ids"][1:]
if self.return_position_ids:
ret["position_ids"] += list(range(len(features[idx]["input_ids"])))
return default_data_collator([ret], return_tensors)
79 changes: 79 additions & 0 deletions src/transformers/modeling_flash_attention_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,56 @@ def _upad_input(
)


def prepare_fa2_from_position_ids(query, key, value, position_ids):
"""
This function returns necessary arguments to call `flash_attn_varlen_func`.
All three query, key, value states will be flattened.
Cummulative lengths of each examples in the batch will be extracted from position_ids.

NOTE: ideally cummulative lengths should be prepared at the data collator stage

Arguments:
query (`torch.Tensor`):
Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
key (`torch.Tensor`):
Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
value (`torch.Tensor`):
Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
position_ids (`torch.Tensor`):
Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.

Return:
query (`torch.Tensor):
Query state without padding. Shape: (total_target_length, num_heads, head_dim).
key (`torch.Tensor`):
Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
value (`torch.Tensor`):
Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
indices_q (`torch.Tensor`):
The indices of non-masked tokens from the flattened input target sequence.
(cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`):
The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
(max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`):
Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
"""
query = query.view(-1, query.size(-2), query.size(-1))
key = key.view(-1, key.size(-2), key.size(-1))
value = value.view(-1, value.size(-2), value.size(-1))
position_ids = position_ids.flatten()
indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)

cu_seq_lens = torch.cat(
(
indices_q[position_ids == 0],
torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
)
)

max_length = position_ids.max() + 1

return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))


def _flash_attention_forward(
query_states: torch.Tensor,
key_states: torch.Tensor,
Expand All @@ -138,6 +188,7 @@ def _flash_attention_forward(
query_length: int,
is_causal: bool,
dropout: float = 0.0,
position_ids: Optional[torch.Tensor] = None,
softmax_scale: Optional[float] = None,
sliding_window: Optional[int] = None,
use_top_left_mask: bool = False,
Expand Down Expand Up @@ -210,6 +261,34 @@ def _flash_attention_forward(
**flash_kwargs,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)

# if position_ids is provided and check not all examples (row) contain only 1 sequence,
# then use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
elif position_ids is not None and not (position_ids[:, -1] == position_ids.size(1) - 1).all():
batch_size = query_states.size(0)
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
query_states, key_states, value_states, position_ids
)

cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

attn_output = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
**flash_kwargs,
)

attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))

else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/dbrx/modeling_dbrx.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,7 @@ def forward(
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/falcon/modeling_falcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,7 @@ def forward(
value_layer,
attention_mask,
query_length,
position_ids=position_ids,
dropout=attn_dropout,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/gemma/modeling_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,7 @@ def forward(
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
sliding_window=getattr(self, "sliding_window", None),
is_causal=self.is_causal,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,7 @@ def forward(
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
sliding_window=getattr(self, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/mistral/modeling_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,7 @@ def forward(
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
sliding_window=getattr(self.config, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/mixtral/modeling_mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,7 @@ def forward(
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
sliding_window=getattr(self.config, "sliding_window", None),
is_causal=self.is_causal,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/olmo/modeling_olmo.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,7 @@ def forward(
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/phi/modeling_phi.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,7 @@ def forward(
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=attn_dropout,
softmax_scale=None,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/phi3/modeling_phi3.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,7 @@ def forward(
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=attn_dropout,
sliding_window=getattr(self.config, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/qwen2/modeling_qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,7 @@ def forward(
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
sliding_window=sliding_window,
is_causal=self.is_causal,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,7 @@ def forward(
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
sliding_window=sliding_window,
is_causal=self.is_causal,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/stablelm/modeling_stablelm.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,7 @@ def forward(
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,7 @@ def forward(
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
sliding_window=getattr(self.config, "sliding_window", None),
is_causal=self.is_causal,
Expand Down
72 changes: 72 additions & 0 deletions tests/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4327,6 +4327,78 @@ def test_flash_attn_2_fp32_ln(self):
# with attention mask
_ = model(dummy_input, attention_mask=dummy_attention_mask)

@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")

max_new_tokens = 30

for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")

config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)

# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1

model = model_class(config)

with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)

assert 0 in inputs_dict["attention_mask"], "assert padding in testing inputs"
# ensure left padding, to adapt for some models
if 0 in inputs_dict["attention_mask"][:, -1]:
inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
dummy_attention_mask = inputs_dict["attention_mask"]
inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.pad_token_id

model = (
model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
)
.to(torch_device)
.eval()
)

# flatten
padfree_inputs_dict = {
k: v[dummy_attention_mask.bool()].unsqueeze(0)
for k, v in inputs_dict.items()
if not k == "attention_mask"
}
# add position_ids
padfree_inputs_dict["position_ids"] = (
torch.cat([torch.arange(length) for length in dummy_attention_mask.sum(1).tolist()])
.long()
.unsqueeze(0)
.to(torch_device)
)

res_padded = model(**inputs_dict)
res_padfree = model(**padfree_inputs_dict)

logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
logits_padfree = res_padfree.logits[0]

torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), atol=0, rtol=0)
# acceptable numerical instability
tol = torch.finfo(torch.float16).eps
torch.testing.assert_close(logits_padded, logits_padfree, atol=tol, rtol=tol)

@is_pt_tf_cross_test
def test_tf_from_pt_safetensors(self):
for model_class in self.all_model_classes:
Expand Down
19 changes: 19 additions & 0 deletions tests/trainer/test_data_collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
DataCollatorForSeq2Seq,
DataCollatorForTokenClassification,
DataCollatorForWholeWordMask,
DataCollatorWithFlattening,
DataCollatorWithPadding,
default_data_collator,
is_tf_available,
Expand Down Expand Up @@ -1531,6 +1532,24 @@ def test_data_collator_with_padding(self):
batch = data_collator(features)
self.assertEqual(batch["input_ids"].shape, (2, 8))

def test_data_collator_with_flattening(self):
features = [
{"input_ids": [10, 11, 12]},
{"input_ids": [20, 21, 22, 23, 24, 25]},
{"input_ids": [30, 31, 32, 33, 34, 35, 36]},
]

data_collator = DataCollatorWithFlattening(return_tensors="np")
batch = data_collator(features)
self.assertEqual(batch["input_ids"].shape, (1, 16))
self.assertEqual(
batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
)
self.assertNotIn("attention_mask", batch)
self.assertIn("position_ids", batch)
self.assertEqual(batch["position_ids"].shape, (1, 16))
self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])

def test_data_collator_for_token_classification(self):
tokenizer = BertTokenizer(self.vocab_file)
features = [
Expand Down
Loading