Skip to content

Commit

Permalink
Enhancing SFT Training Efficiency Using Packing and FlashAttention2 w…
Browse files Browse the repository at this point in the history
…ith Position IDs (#31629)

* add DataCollatorBatchFlattening

* Update data_collator.py

* change name

* new FA2 flow if position_ids is provided

* add comments

* minor fix

* minor fix data collator

* add test cases for models

* add test case for data collator

* remove extra code

* formating for ruff check and check_repo.py

* ruff format

ruff format tests src utils

* custom_init_isort.py
  • Loading branch information
RhuiDih authored Jul 23, 2024
1 parent 7d92009 commit 9cf4f2a
Show file tree
Hide file tree
Showing 20 changed files with 226 additions and 0 deletions.
5 changes: 5 additions & 0 deletions docs/source/en/main_classes/data_collator.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,8 @@ Examples of use can be found in the [example scripts](../examples) or [example n
- numpy_mask_tokens
- tf_mask_tokens
- torch_mask_tokens

## DataCollatorWithFlattening

[[autodoc]] data.data_collator.DataCollatorWithFlattening

2 changes: 2 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@
"DataCollatorForSOP",
"DataCollatorForTokenClassification",
"DataCollatorForWholeWordMask",
"DataCollatorWithFlattening",
"DataCollatorWithPadding",
"DefaultDataCollator",
"default_data_collator",
Expand Down Expand Up @@ -4764,6 +4765,7 @@
DataCollatorForSOP,
DataCollatorForTokenClassification,
DataCollatorForWholeWordMask,
DataCollatorWithFlattening,
DataCollatorWithPadding,
DefaultDataCollator,
default_data_collator,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
DataCollatorForSOP,
DataCollatorForTokenClassification,
DataCollatorForWholeWordMask,
DataCollatorWithFlattening,
DataCollatorWithPadding,
DefaultDataCollator,
default_data_collator,
Expand Down
35 changes: 35 additions & 0 deletions src/transformers/data/data_collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1611,3 +1611,38 @@ def numpy_mask_tokens(self, inputs: Any) -> Tuple[Any, Any, Any, Any]:
) & masked_indices[i]

return inputs.astype(np.int64), perm_mask, target_mapping, labels.astype(np.int64)


@dataclass
class DataCollatorWithFlattening(DefaultDataCollator):
"""
Data collator used for padding free approach. Does the following:
- concatate the entire mini batch into single long sequence [1, total_tokens]
- no padding will be added, returns `input_ids`, `labels` and `position_ids`
"""

def __init__(self, *args, return_position_ids=True, **kwargs):
super().__init__(*args, **kwargs)
self.return_position_ids = return_position_ids
warnings.warn(
"Using `DataCollatorWithFlattening` will flatten the entire mini batch into single long sequence."
"Make sure your attention computation is able to handle it!"
)

def __call__(self, features, return_tensors=None):
if return_tensors is None:
return_tensors = self.return_tensors
is_labels_provided = "labels" in features[0]
ret = {"input_ids": [], "labels": []}
if self.return_position_ids:
ret.update({"position_ids": []})
for idx in range(0, len(features)):
ret["input_ids"] += features[idx]["input_ids"]
if is_labels_provided:
ret["labels"] += [-100] + features[idx]["labels"][1:]
else:
ret["labels"] += [-100] + features[idx]["input_ids"][1:]
if self.return_position_ids:
ret["position_ids"] += list(range(len(features[idx]["input_ids"])))
return default_data_collator([ret], return_tensors)
79 changes: 79 additions & 0 deletions src/transformers/modeling_flash_attention_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,56 @@ def _upad_input(
)


def prepare_fa2_from_position_ids(query, key, value, position_ids):
"""
This function returns necessary arguments to call `flash_attn_varlen_func`.
All three query, key, value states will be flattened.
Cummulative lengths of each examples in the batch will be extracted from position_ids.
NOTE: ideally cummulative lengths should be prepared at the data collator stage
Arguments:
query (`torch.Tensor`):
Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
key (`torch.Tensor`):
Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
value (`torch.Tensor`):
Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
position_ids (`torch.Tensor`):
Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
Return:
query (`torch.Tensor):
Query state without padding. Shape: (total_target_length, num_heads, head_dim).
key (`torch.Tensor`):
Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
value (`torch.Tensor`):
Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
indices_q (`torch.Tensor`):
The indices of non-masked tokens from the flattened input target sequence.
(cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`):
The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
(max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`):
Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
"""
query = query.view(-1, query.size(-2), query.size(-1))
key = key.view(-1, key.size(-2), key.size(-1))
value = value.view(-1, value.size(-2), value.size(-1))
position_ids = position_ids.flatten()
indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)

cu_seq_lens = torch.cat(
(
indices_q[position_ids == 0],
torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
)
)

max_length = position_ids.max() + 1

return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))


def _flash_attention_forward(
query_states: torch.Tensor,
key_states: torch.Tensor,
Expand All @@ -138,6 +188,7 @@ def _flash_attention_forward(
query_length: int,
is_causal: bool,
dropout: float = 0.0,
position_ids: Optional[torch.Tensor] = None,
softmax_scale: Optional[float] = None,
sliding_window: Optional[int] = None,
use_top_left_mask: bool = False,
Expand Down Expand Up @@ -210,6 +261,34 @@ def _flash_attention_forward(
**flash_kwargs,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)

# if position_ids is provided and check not all examples (row) contain only 1 sequence,
# then use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
elif position_ids is not None and not (position_ids[:, -1] == position_ids.size(1) - 1).all():
batch_size = query_states.size(0)
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
query_states, key_states, value_states, position_ids
)

cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

attn_output = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
**flash_kwargs,
)

attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))

else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/dbrx/modeling_dbrx.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,7 @@ def forward(
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/falcon/modeling_falcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,7 @@ def forward(
value_layer,
attention_mask,
query_length,
position_ids=position_ids,
dropout=attn_dropout,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/gemma/modeling_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,7 @@ def forward(
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
sliding_window=getattr(self, "sliding_window", None),
is_causal=self.is_causal,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,7 @@ def forward(
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
sliding_window=getattr(self, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/mistral/modeling_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,7 @@ def forward(
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
sliding_window=getattr(self.config, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/mixtral/modeling_mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,7 @@ def forward(
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
sliding_window=getattr(self.config, "sliding_window", None),
is_causal=self.is_causal,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/olmo/modeling_olmo.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,7 @@ def forward(
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/phi/modeling_phi.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,7 @@ def forward(
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=attn_dropout,
softmax_scale=None,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/phi3/modeling_phi3.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,7 @@ def forward(
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=attn_dropout,
sliding_window=getattr(self.config, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/qwen2/modeling_qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,7 @@ def forward(
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
sliding_window=sliding_window,
is_causal=self.is_causal,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,7 @@ def forward(
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
sliding_window=sliding_window,
is_causal=self.is_causal,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/stablelm/modeling_stablelm.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,7 @@ def forward(
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/starcoder2/modeling_starcoder2.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,7 @@ def forward(
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
sliding_window=getattr(self.config, "sliding_window", None),
is_causal=self.is_causal,
Expand Down
72 changes: 72 additions & 0 deletions tests/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4327,6 +4327,78 @@ def test_flash_attn_2_fp32_ln(self):
# with attention mask
_ = model(dummy_input, attention_mask=dummy_attention_mask)

@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")

max_new_tokens = 30

for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")

config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)

# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1

model = model_class(config)

with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)

assert 0 in inputs_dict["attention_mask"], "assert padding in testing inputs"
# ensure left padding, to adapt for some models
if 0 in inputs_dict["attention_mask"][:, -1]:
inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
dummy_attention_mask = inputs_dict["attention_mask"]
inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.pad_token_id

model = (
model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
)
.to(torch_device)
.eval()
)

# flatten
padfree_inputs_dict = {
k: v[dummy_attention_mask.bool()].unsqueeze(0)
for k, v in inputs_dict.items()
if not k == "attention_mask"
}
# add position_ids
padfree_inputs_dict["position_ids"] = (
torch.cat([torch.arange(length) for length in dummy_attention_mask.sum(1).tolist()])
.long()
.unsqueeze(0)
.to(torch_device)
)

res_padded = model(**inputs_dict)
res_padfree = model(**padfree_inputs_dict)

logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
logits_padfree = res_padfree.logits[0]

torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), atol=0, rtol=0)
# acceptable numerical instability
tol = torch.finfo(torch.float16).eps
torch.testing.assert_close(logits_padded, logits_padfree, atol=tol, rtol=tol)

@is_pt_tf_cross_test
def test_tf_from_pt_safetensors(self):
for model_class in self.all_model_classes:
Expand Down
19 changes: 19 additions & 0 deletions tests/trainer/test_data_collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
DataCollatorForSeq2Seq,
DataCollatorForTokenClassification,
DataCollatorForWholeWordMask,
DataCollatorWithFlattening,
DataCollatorWithPadding,
default_data_collator,
is_tf_available,
Expand Down Expand Up @@ -1531,6 +1532,24 @@ def test_data_collator_with_padding(self):
batch = data_collator(features)
self.assertEqual(batch["input_ids"].shape, (2, 8))

def test_data_collator_with_flattening(self):
features = [
{"input_ids": [10, 11, 12]},
{"input_ids": [20, 21, 22, 23, 24, 25]},
{"input_ids": [30, 31, 32, 33, 34, 35, 36]},
]

data_collator = DataCollatorWithFlattening(return_tensors="np")
batch = data_collator(features)
self.assertEqual(batch["input_ids"].shape, (1, 16))
self.assertEqual(
batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
)
self.assertNotIn("attention_mask", batch)
self.assertIn("position_ids", batch)
self.assertEqual(batch["position_ids"].shape, (1, 16))
self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])

def test_data_collator_for_token_classification(self):
tokenizer = BertTokenizer(self.vocab_file)
features = [
Expand Down

0 comments on commit 9cf4f2a

Please sign in to comment.