System Info
- `transformers` version: 4.44.2
- Platform: Linux-6.2.0-37-generic-x86_64-with-glibc2.35
- Python version: 3.10.12
- Huggingface_hub version: 0.24.6
- Safetensors version: 0.4.4
- Accelerate version: not installed
- Accelerate config: not found
- PyTorch version (GPU?): 2.4.0+cu121 (True)
- Tensorflow version (GPU?): 2.13.1 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?: "ddp", "deepspeed_stage_2"
- Using GPU in script?: tested on 8xH100 and 1xA100-40GB
- GPU type: NVIDIA A100-SXM4-40GB
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
PR #31629 enabled packing with no cross-contamination and without having to build attention masks when using flash-attention-2.
However, the `prepare_fa2_from_position_ids` function raises an error when training with a `batch_size` greater than 1.
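For context, my understanding of the position_ids-based packing (a minimal illustration, not the library code): sequence boundaries are recovered from the points where `position_ids` reset to 0, so no attention mask has to be built.

```python
import torch

# Hypothetical pack of two sequences of lengths 3 and 4 in a single row.
input_ids    = torch.tensor([[11, 12, 13,  21, 22, 23, 24]])
position_ids = torch.tensor([[ 0,  1,  2,   0,  1,  2,  3]])

# Each reset to 0 marks the start of a new packed sequence.
starts = (position_ids.flatten() == 0).nonzero().flatten()
print(starts)  # tensor([0, 3]) -> two sequences starting at offsets 0 and 3
```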
Below is an end-to-end example to reproduce the error:
```python
import numpy as np
import torch
import lightning
from transformers import MistralForCausalLM, MistralConfig

config = MistralConfig(max_position_embeddings = 1024,
                       hidden_size = 1024,
                       intermediate_size = 3584,
                       num_hidden_layers = 8,
                       pad_token_id = 0,
                       bos_token_id = 2,
                       eos_token_id = 3)
config._attn_implementation = "flash_attention_2"

batch_size = 2


class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, max_num_pack_attempts):
        self.dataset = []
        for _ in range(100_000):
            # Generate samples of different lengths
            sample_len = np.random.randint(10, config.max_position_embeddings)
            tokens = np.random.randint(0, config.vocab_size, size = sample_len)
            self.dataset.append(tokens)
        self.max_num_pack_attempts = max_num_pack_attempts

    def get_single_sample(self):
        idx = np.random.randint(0, len(self.dataset))
        tokens = self.dataset[idx]
        return tokens.tolist()

    def generate_pack(self):
        input_ids = []
        labels = [0]  # placeholder for the model to shift right by 1
        position_ids = []
        num_failed_attempts = 0
        while (len(input_ids) < config.max_position_embeddings) and (num_failed_attempts < self.max_num_pack_attempts):
            sample = self.get_single_sample()
            # If there is empty room
            if len(input_ids) + len(sample) + 1 < config.max_position_embeddings:
                input_ids += [config.bos_token_id] + sample
                labels += sample + [config.eos_token_id]
                position_ids += range(len(sample) + 1)
            else:
                num_failed_attempts += 1
        # Pad
        input_ids = input_ids + [config.pad_token_id] * (config.max_position_embeddings - len(input_ids))
        position_ids = position_ids + [config.pad_token_id] * (config.max_position_embeddings - len(position_ids))
        labels = labels + [-100] * (config.max_position_embeddings - len(labels))
        return {
            'input_ids': torch.tensor(input_ids),
            'position_ids': torch.tensor(position_ids),
            'labels': torch.tensor(labels, dtype=torch.long)
        }

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.generate_pack()


class CustomModel(lightning.LightningModule):
    def __init__(self, config, learning_rate):
        super(CustomModel, self).__init__()
        self.model = MistralForCausalLM(config = config)
        self.learning_rate = learning_rate
        num_params = sum(p.numel() for p in self.model.parameters())
        print(f'Number of parameters in Mistral: {num_params:,}')

    def forward(self, input_ids, position_ids, labels = None):
        return self.model(input_ids = input_ids,
                          position_ids = position_ids,
                          labels = labels,
                          use_cache=False)

    def training_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        position_ids = batch['position_ids']
        labels = batch['labels']
        outputs = self(input_ids, position_ids, labels)
        loss = outputs.loss
        return loss

    def configure_optimizers(self):
        opt = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, betas = (0.9, 0.95), eps = 1e-8, weight_decay = 0.1)
        return {"optimizer": opt}


data_loader = torch.utils.data.DataLoader(CustomDataset(10), batch_size=batch_size, num_workers=1)
model = CustomModel(config, 3e-4)
print(model.model.model.layers[0].self_attn)  # print the model's self-attention layer name to make sure it uses FA2

# TRAINER
trainer = lightning.Trainer(
    max_steps = 2_000,
    accelerator="gpu",
    precision = "bf16-mixed",
    limit_train_batches = 1_000
)
trainer.fit(model, train_dataloaders=data_loader)
```
The error:

```
Epoch 0: 0%| | 0/1000 [00:00<?, ?it/s]Traceback (most recent call last):
File "/home/ubuntu/fa2_from_position_ids_test.py", line 111, in <module>
trainer.fit(model, train_dataloaders=data_loader)
File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 538, in fit
call._call_and_handle_interrupt(
File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 47, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 574, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 981, in _run
results = self._run_stage()
File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1025, in _run_stage
self.fit_loop.run()
File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py", line 205, in run
self.advance()
File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py", line 363, in advance
self.epoch_loop.run(self._data_fetcher)
File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 140, in run
self.advance(data_fetcher)
File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 250, in advance
batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 190, in run
self._optimizer_step(batch_idx, closure)
File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 268, in _optimizer_step
call._call_lightning_module_hook(
File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 167, in _call_lightning_module_hook
output = fn(*args, **kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/core/module.py", line 1306, in optimizer_step
optimizer.step(closure=optimizer_closure)
File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/core/optimizer.py", line 153, in step
step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 238, in optimizer_step
return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/plugins/precision/amp.py", line 75, in optimizer_step
return super().optimizer_step(optimizer, model=model, closure=closure, **kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/plugins/precision/precision.py", line 122, in optimizer_step
return optimizer.step(closure=closure, **kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/optim/optimizer.py", line 484, in wrapper
out = func(*args, **kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/optim/optimizer.py", line 89, in _use_grad
ret = func(self, *args, **kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/optim/adamw.py", line 204, in step
loss = closure()
File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/plugins/precision/precision.py", line 108, in _wrap_closure
closure_result = closure()
File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 144, in __call__
self._result = self.closure(*args, **kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 129, in closure
step_output = self._step_fn()
File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 317, in _training_step
training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 319, in _call_strategy_hook
output = fn(*args, **kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 390, in training_step
return self.lightning_module.training_step(*args, **kwargs)
File "/home/ubuntu/fa2_from_position_ids_test.py", line 91, in training_step
outputs = self(input_ids, position_ids, labels)
File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ubuntu/fa2_from_position_ids_test.py", line 81, in forward
return self.model(input_ids = input_ids,
File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py", line 1033, in forward
outputs = self.model(
File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py", line 808, in forward
layer_outputs = decoder_layer(
File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py", line 549, in forward
hidden_states, self_attn_weights, present_key_value = self.self_attn(
File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py", line 382, in forward
attn_output = _flash_attention_forward(
File "/home/ubuntu/.local/lib/python3.10/site-packages/transformers/modeling_flash_attention_utils.py", line 272, in _flash_attention_forward
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
File "/home/ubuntu/.local/lib/python3.10/site-packages/transformers/modeling_flash_attention_utils.py", line 166, in prepare_fa2_from_position_ids
key = key.view(-1, key.size(-2), key.size(-1))
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
```
When `batch_size` is set to 1, training runs without an error.
I tested on 8xH100 and on 1xA100-40GB, with different training strategies (e.g., "ddp", "deepspeed_stage_2"), and ended up with the same error every time.
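For what it's worth, my guess at the cause (a minimal sketch under my assumptions, not the actual library code): the key/value states that reach `_flash_attention_forward` have been transposed to `(batch_size, seq_len, num_heads, head_dim)` and are therefore non-contiguous, and `prepare_fa2_from_position_ids` flattens the batch and sequence dimensions with `.view()`, which only happens to work when the batch dimension is 1:

```python
import torch

batch_size, num_heads, seq_len, head_dim = 2, 4, 8, 16

# Emulate key_states after the usual transpose from
# (batch, num_heads, seq_len, head_dim) to (batch, seq_len, num_heads, head_dim);
# the result is non-contiguous.
key = torch.randn(batch_size, num_heads, seq_len, head_dim).transpose(1, 2)

key.reshape(-1, key.size(-2), key.size(-1))  # works (copies if necessary)
key.view(-1, key.size(-2), key.size(-1))     # RuntimeError: view size is not compatible ...
```

With `batch_size == 1` the size-1 batch dimension can be folded away without a copy, so `.view()` succeeds, which would explain why the error only appears for larger batches; switching to `.reshape(...)` (as the error message itself suggests) or calling `.contiguous()` first would presumably avoid it.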
Expected behavior
Training should run without errors for any `batch_size` value, just as it does for `batch_size = 1`.