
ValueError when using gradient_checkpointing with SFTTrainer #1088

Closed
alwayshalffull opened this issue Dec 12, 2023 · 1 comment · Fixed by huggingface/transformers#28061

@alwayshalffull

Hi all,

I'm running into an issue when I enable gradient checkpointing in the example sft.py training script. My jobs run fine without gradient checkpointing, but as soon as it's enabled, they fail with a ValueError (see the traceback below).

Traceback (most recent call last):
  File "/workspace/trl/examples/scripts/sft.py", line 156, in <module>
    trainer.train()
  File "/opt/conda/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 306, in train
    output = super().train(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 1537, in train
    return inner_training_loop(
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 1854, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 2732, in training_step
    self.accelerator.backward(loss)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/accelerator.py", line 1905, in backward
    loss.backward(**kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "/opt/conda/lib/python3.10/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/opt/conda/lib/python3.10/site-packages/torch/autograd/function.py", line 288, in apply
    return user_fn(self, *args)
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 271, in backward
    outputs = ctx.run_function(*detached_inputs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py", line 796, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py", line 342, in forward
    raise ValueError(
ValueError: Attention mask should be of size (1, 1, 1024, 2048), but is torch.Size([1, 1, 1024, 1024])

Environment:

accelerate==0.25.0
bitsandbytes==0.41.3.post2
peft==0.7.1
torch==2.1.0
torchaudio==2.1.0
torchelastic==0.2.2
torchvision==0.16.0
transformers==4.36.0
triton==2.1.0
trl @ git+https://github.com/huggingface/trl@48b3ef0b7bba0f0c249c091781631ddfd98cde7 (pulled from source today)

CUDA 12.1, driver 525.85.12

Command to launch:
```shell
accelerate launch --config_file examples/accelerate_configs/multi_gpu.yaml --num_processes=4 \
    examples/scripts/sft.py \
    --model_name mistralai/Mixtral-8x7B-v0.1 \
    --dataset_name trl-lib/ultrachat_200k_chatml \
    --batch_size 1 \
    --gradient_accumulation_steps 1 \
    --learning_rate 2e-4 \
    --save_steps 200_000 \
    --use_peft \
    --peft_lora_r 16 \
    --peft_lora_alpha 32 \
    --target_modules q_proj k_proj v_proj o_proj \
    --load_in_4bit \
    --seq_length 1024 \
    --gradient_checkpointing True
```
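
For reference, this is roughly the Python-side setup the command corresponds to (a sketch reconstructed from the flags above, not the actual sft.py code; `output_dir`, the dataset split, and `dataset_text_field` are placeholders I filled in):

```python
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments
from trl import SFTTrainer

# --load_in_4bit
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)

# --use_peft --peft_lora_r 16 --peft_lora_alpha 32 --target_modules q_proj k_proj v_proj o_proj
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)

args = TrainingArguments(
    output_dir="./sft-mixtral",          # placeholder
    per_device_train_batch_size=1,       # --batch_size 1
    gradient_accumulation_steps=1,
    learning_rate=2e-4,
    save_steps=200_000,
    gradient_checkpointing=True,         # the flag that triggers the ValueError
)

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=load_dataset("trl-lib/ultrachat_200k_chatml", split="train"),  # split is a guess
    dataset_text_field="text",           # placeholder field name
    max_seq_length=1024,                 # --seq_length 1024
    peft_config=peft_config,
)
trainer.train()  # fails in the backward pass once gradient checkpointing is on
```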

Is gradient_checkpointing compatible with the latest Mixtral model when using SFTTrainer?

@younesbelkada
Contributor

Thanks for the issue! huggingface/transformers#28031 should fix it for models other than Mixtral, and huggingface/transformers#28061 will fix it for Mixtral.
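
(Editor's note, not from the thread: until those PRs are in a release, one workaround sketch is to disable the KV cache during training, since the mask size mismatch appears to come from cached keys being reused during the checkpointed re-forward; treating `use_cache` as the culprit is my assumption, not a confirmed fix.)

```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Interim workaround sketch (assumption, not a confirmed fix from this thread):
# turn off the KV cache during training so the checkpointed re-forward does not
# reuse keys cached by the original forward pass, which is what appears to make
# kv_seq_len 2048 while the attention mask is still 1024.
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)
model.config.use_cache = False  # the cache is not useful during training anyway
```

Passing `gradient_checkpointing_kwargs={"use_reentrant": False}` in `TrainingArguments` (available in transformers >= 4.35) is another knob worth trying, though I have not verified that it avoids this particular error.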
