prepare_fa2_from_position_ids error in training with batch_size > 1 #33268

Closed
@meliksahturker

System Info

  • transformers version: 4.44.2
  • Platform: Linux-6.2.0-37-generic-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.24.6
  • Safetensors version: 0.4.4
  • Accelerate version: not installed
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.4.0+cu121 (True)
  • Tensorflow version (GPU?): 2.13.1 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: "ddp", "deepspeed_stage_2"
  • Using GPU in script?: tested on 8xH100 and 1xA100-40GB
  • GPU type: NVIDIA A100-SXM4-40GB

Who can help?

@RhuiDih
@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

PR #31629 enabled packing without cross-contamination for flash-attention-2, with no need to handle attention masks.
However, the prepare_fa2_from_position_ids function raises an error when training with a batch_size greater than 1.
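
For context, here is a rough sketch of how sequence boundaries can be recovered from position_ids alone, which is the idea that packing without attention masks relies on. It assumes a new sample starts wherever the position id resets to 0. The helper name cu_seq_lens_from_position_ids is hypothetical, and this is plain Python for illustration, not the library's implementation.

```python
def cu_seq_lens_from_position_ids(position_ids):
    """Return cumulative sequence boundaries for one flattened, packed row.

    Assumption: a new sample begins wherever the position id resets to 0.
    (Hypothetical helper for illustration; not the transformers code.)
    """
    starts = [i for i, p in enumerate(position_ids) if p == 0]
    # Close the last sample with the total number of tokens.
    return starts + [len(position_ids)]


# Two packed samples of lengths 3 and 4.
position_ids = [0, 1, 2, 0, 1, 2, 3]
bounds = cu_seq_lens_from_position_ids(position_ids)
lengths = [end - start for start, end in zip(bounds, bounds[1:])]

print(bounds)   # [0, 3, 7]
print(lengths)  # [3, 4]
```

Under this assumption, every 0 in position_ids opens a new sample, so zero-valued padding positions would each be read as a length-1 sample.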

Below is an end-to-end example to reproduce the error:

import numpy as np
import torch
import lightning

from transformers import MistralForCausalLM, MistralConfig


config = MistralConfig(max_position_embeddings = 1024,
                       hidden_size = 1024,
                       intermediate_size = 3584,
                       num_hidden_layers = 8,
                       pad_token_id = 0,
                       bos_token_id = 2,
                       eos_token_id = 3)
config._attn_implementation = "flash_attention_2"
batch_size = 2

class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, max_num_pack_attempts):
        self.dataset = []
        for _ in range(100_000):
            # Generate samples of different lengths
            sample_len = np.random.randint(10, config.max_position_embeddings)
            tokens = np.random.randint(0, config.vocab_size, size = sample_len)
            self.dataset.append(tokens)
        self.max_num_pack_attempts = max_num_pack_attempts


    def get_single_sample(self):
        idx = np.random.randint(0, len(self.dataset))
        tokens = self.dataset[idx]
        return tokens.tolist()

    def generate_pack(self):
        input_ids = []
        labels = [0] # placeholder for the model to shift right by 1
        position_ids = []
        num_failed_attempts = 0
        
        while (len(input_ids) < config.max_position_embeddings) and (num_failed_attempts < self.max_num_pack_attempts):
            sample = self.get_single_sample()
        
            # If there is empty room
            if len(input_ids) + len(sample) + 1 < config.max_position_embeddings:
                input_ids += [config.bos_token_id] + sample
                labels += sample + [config.eos_token_id]
                position_ids += range(len(sample) + 1)
            else:
                num_failed_attempts += 1
        
        # Pad
        input_ids = input_ids + [config.pad_token_id] * (config.max_position_embeddings - len(input_ids))
        position_ids = position_ids + [config.pad_token_id] * (config.max_position_embeddings - len(position_ids))
        labels = labels + [-100] * (config.max_position_embeddings - len(labels))
    
        return {
                'input_ids': torch.tensor(input_ids),
                'position_ids': torch.tensor(position_ids),
                'labels': torch.tensor(labels, dtype=torch.long)
               }

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.generate_pack()


class CustomModel(lightning.LightningModule):
    def __init__(self, config, learning_rate):
        super(CustomModel, self).__init__()
        self.model = MistralForCausalLM(config = config)
        self.learning_rate = learning_rate

        num_params = sum(p.numel() for p in self.model.parameters())
        print(f'Number of parameters in Mistral: {num_params:,}')

    def forward(self, input_ids, position_ids, labels = None):
        return self.model(input_ids = input_ids,
                          position_ids = position_ids,
                          labels = labels,
                          use_cache=False)

    def training_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        position_ids = batch['position_ids']
        labels = batch['labels']
        
        outputs = self(input_ids, position_ids, labels)
        loss = outputs.loss
        return loss

    def configure_optimizers(self):
        opt = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, betas = (0.9, 0.95), eps = 1e-8, weight_decay = 0.1)
        return {"optimizer": opt}

data_loader = torch.utils.data.DataLoader(CustomDataset(10), batch_size=batch_size, num_workers=1)
model = CustomModel(config, 3e-4)
print(model.model.model.layers[0].self_attn) # print the model's self attention layer name to make sure it uses FA2

# TRAINER
trainer = lightning.Trainer(
    max_steps = 2_000,
    accelerator="gpu",
    precision = "bf16-mixed",
    limit_train_batches = 1_000
)

trainer.fit(model, train_dataloaders=data_loader)

The error:

Epoch 0:   0%|          | 0/1000 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/ubuntu/fa2_from_position_ids_test.py", line 111, in <module>
    trainer.fit(model, train_dataloaders=data_loader)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 538, in fit
    call._call_and_handle_interrupt(
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 47, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 574, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 981, in _run
    results = self._run_stage()
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1025, in _run_stage
    self.fit_loop.run()
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py", line 205, in run
    self.advance()
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py", line 363, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 140, in run
    self.advance(data_fetcher)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 250, in advance
    batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 190, in run
    self._optimizer_step(batch_idx, closure)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 268, in _optimizer_step
    call._call_lightning_module_hook(
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 167, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/core/module.py", line 1306, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/core/optimizer.py", line 153, in step
    step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 238, in optimizer_step
    return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/plugins/precision/amp.py", line 75, in optimizer_step
    return super().optimizer_step(optimizer, model=model, closure=closure, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/plugins/precision/precision.py", line 122, in optimizer_step
    return optimizer.step(closure=closure, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/optim/optimizer.py", line 484, in wrapper
    out = func(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/optim/optimizer.py", line 89, in _use_grad
    ret = func(self, *args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/optim/adamw.py", line 204, in step
    loss = closure()
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/plugins/precision/precision.py", line 108, in _wrap_closure
    closure_result = closure()
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 144, in __call__
    self._result = self.closure(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 129, in closure
    step_output = self._step_fn()
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 317, in _training_step
    training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 319, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 390, in training_step
    return self.lightning_module.training_step(*args, **kwargs)
  File "/home/ubuntu/fa2_from_position_ids_test.py", line 91, in training_step
    outputs = self(input_ids, position_ids, labels)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/fa2_from_position_ids_test.py", line 81, in forward
    return self.model(input_ids = input_ids,
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py", line 1033, in forward
    outputs = self.model(
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py", line 808, in forward
    layer_outputs = decoder_layer(
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py", line 549, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py", line 382, in forward
    attn_output = _flash_attention_forward(
  File "/home/ubuntu/.local/lib/python3.10/site-packages/transformers/modeling_flash_attention_utils.py", line 272, in _flash_attention_forward
    query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
  File "/home/ubuntu/.local/lib/python3.10/site-packages/transformers/modeling_flash_attention_utils.py", line 166, in prepare_fa2_from_position_ids
    key = key.view(-1, key.size(-2), key.size(-1))
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

When batch_size is set to 1, training runs without error.
I ran tests on 8xH100 and 1xA100-40GB with different training strategies ("ddp", "deepspeed_stage_2") and got the same error every time.
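
One possible explanation for the batch_size dependence, sketched below in plain Python: view can only merge two dimensions without copying when their strides line up, and a dimension of size 1 never blocks the merge. The sketch assumes the key tensor reaches prepare_fa2_from_position_ids as a transposed view of a (batch, heads, seq, head_dim) tensor. That is my assumption and not something the traceback confirms. All function names here are hypothetical and the sizes are illustrative only.

```python
def contiguous_strides(shape):
    """Strides (in elements) of a contiguous row-major tensor."""
    strides, step = [], 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return strides[::-1]


def transpose(shape, strides, a, b):
    """Swap two dims the way a transposed view does: no data is moved."""
    shape, strides = list(shape), list(strides)
    shape[a], shape[b] = shape[b], shape[a]
    strides[a], strides[b] = strides[b], strides[a]
    return shape, strides


def can_merge_leading_dims(shape, strides):
    """True if dims 0 and 1 can be flattened into one without a copy."""
    if shape[0] == 1 or shape[1] == 1:
        return True  # a size-1 dim never blocks the merge
    return strides[0] == shape[1] * strides[1]


def flatten_is_view_compatible(batch, heads, seq, head_dim):
    # Assumed layout: (batch, heads, seq, head_dim), transposed to
    # (batch, seq, heads, head_dim) before view(-1, heads, head_dim).
    shape = [batch, heads, seq, head_dim]
    shape, strides = transpose(shape, contiguous_strides(shape), 1, 2)
    return can_merge_leading_dims(shape, strides)


print(flatten_is_view_compatible(1, 32, 1024, 32))  # True
print(flatten_is_view_compatible(2, 32, 1024, 32))  # False
```

If this is what happens, making the tensor contiguous first or using reshape, as the error message suggests, would avoid the failure at the cost of a copy.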

Expected behavior

Training should run without error for any batch_size value.
