Hi all,

I'm running into an issue when I try to enable gradient checkpointing in the example sft.py training script. My jobs run fine without gradient checkpointing, but as soon as it is enabled, they fail with a ValueError (full traceback below).
Traceback (most recent call last):
  File "/workspace/trl/examples/scripts/sft.py", line 156, in <module>
    trainer.train()
  File "/opt/conda/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 306, in train
    output = super().train(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 1537, in train
    return inner_training_loop(
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 1854, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 2732, in training_step
    self.accelerator.backward(loss)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/accelerator.py", line 1905, in backward
    loss.backward(**kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "/opt/conda/lib/python3.10/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/opt/conda/lib/python3.10/site-packages/torch/autograd/function.py", line 288, in apply
    return user_fn(self, *args)
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 271, in backward
    outputs = ctx.run_function(*detached_inputs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py", line 796, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py", line 342, in forward
    raise ValueError(
ValueError: Attention mask should be of size (1, 1, 1024, 2048), but is torch.Size([1, 1, 1024, 1024])
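For what it's worth, the expected size (1, 1, 1024, 2048) reads as query length 1024 against key/value length 2048, i.e. 1024 new tokens plus 1024 that look cached. My guess is that the KV cache populated during the forward pass is reused when the checkpointed segment is recomputed in backward, so the recomputed attention sees twice the sequence length. Below is a minimal sketch of the workaround I'd try first; the loading call mirrors the script's flags but is my own reconstruction, not the actual sft.py code.

```python
# Sketch, not the actual sft.py code: disable the KV cache while training,
# since past key/values are only needed for generation and appear to clash
# with gradient-checkpoint recomputation here.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1",
    load_in_4bit=True,  # mirrors --load_in_4bit; requires bitsandbytes
)
model.config.use_cache = False  # no cached key/values during training
model.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs={"use_reentrant": False}
)
```

If the cache is the culprit, this may also be addressed in a newer transformers patch release (I'm on 4.36.0), so upgrading is probably worth a try as well.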
Environment:
accelerate==0.25.0
bitsandbytes==0.41.3.post2
peft==0.7.1
torch==2.1.0
torchaudio==2.1.0
torchelastic==0.2.2
torchvision==0.16.0
transformers==4.36.0
triton==2.1.0
trl @ git+https://github.com/huggingface/trl@48b3ef0b7bba0f0c249c091781631ddfd98cde7 (pulled from source today)
CUDA 12.1, driver 525.85.12
Command to launch:
accelerate launch --config_file examples/accelerate_configs/multi_gpu.yaml --num_processes=4 \
  examples/scripts/sft.py \
  --model_name mistralai/Mixtral-8x7B-v0.1 \
  --dataset_name trl-lib/ultrachat_200k_chatml \
  --batch_size 1 \
  --gradient_accumulation_steps 1 \
  --learning_rate 2e-4 \
  --save_steps 200_000 \
  --use_peft \
  --peft_lora_r 16 \
  --peft_lora_alpha 32 \
  --target_modules q_proj k_proj v_proj o_proj \
  --load_in_4bit \
  --seq_length 1024 \
  --gradient_checkpointing True
Is gradient_checkpointing compatible with the latest Mixtral model when using SFTTrainer?
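In case it matters for a repro: the other knob I know of is passing checkpointing options through TrainingArguments, since gradient_checkpointing_kwargs is available in transformers >= 4.35 and use_reentrant=False selects the non-reentrant checkpoint implementation. A sketch with placeholder values:

```python
from transformers import TrainingArguments

# Placeholder values; the relevant part is gradient_checkpointing_kwargs,
# which Trainer forwards to model.gradient_checkpointing_enable().
training_args = TrainingArguments(
    output_dir="./sft-mixtral",
    per_device_train_batch_size=1,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
)
```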