Description
System Info
- transformers version: 4.51.3
- Platform: Linux-5.10.134-010.ali5000.al8.x86_64-x86_64-with-glibc2.32
- Python version: 3.10.16
- Huggingface_hub version: 0.30.2
- Safetensors version: 0.5.3
- Accelerate version: 1.6.0
- Accelerate config: not found
- DeepSpeed version: 0.15.4
- PyTorch version (GPU?): 2.6.0+cu124 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?:
- Using GPU in script?:
- GPU type: NVIDIA A800-SXM4-80GB
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
When using gradient_accumulation_steps in the Trainer, the calculated loss is divided by this number before the backward pass, as shown here:
transformers/src/transformers/trainer.py
Line 3790 in d5d007a
if (not self.model_accepts_loss_kwargs or num_items_in_batch is None) and self.compute_loss_func is None:
loss = loss / self.args.gradient_accumulation_steps
This is intended to average the loss over the accumulated steps. However, a problem arises on the last update step of an epoch if the number of batches remaining in the dataloader is less than gradient_accumulation_steps.
As shown below, this happens when num_batches = args.gradient_accumulation_steps and num_batches > len(batch_samples):
num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches, args.device)
for i, inputs in enumerate(batch_samples):
    ...  # training step elided
In this scenario, the loss is still divided by the full gradient_accumulation_steps, even though the actual number of accumulated batches is smaller. This results in a final loss value that is artificially small, leading to an incorrect gradient magnitude for the final optimization step.
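The arithmetic behind this can be sketched in plain Python, without the Trainer. This is only an illustration: the helper `logged_losses` is hypothetical, and it assumes every batch has the same unscaled loss of 0.9984 (the value from the log further down) and that the logged loss is the sum of the scaled per-batch losses in a cycle.

```python
def logged_losses(per_batch_losses, gradient_accumulation_steps, divide_by_actual):
    """Sum the scaled per-batch losses over each accumulation cycle.

    Hypothetical helper for illustration only; it is not part of transformers.
    """
    results = []
    for start in range(0, len(per_batch_losses), gradient_accumulation_steps):
        cycle = per_batch_losses[start:start + gradient_accumulation_steps]
        # Current behaviour: always divide by gradient_accumulation_steps.
        # Proposed behaviour: divide by the number of batches actually in the cycle.
        divisor = len(cycle) if divide_by_actual else gradient_accumulation_steps
        results.append(round(sum(loss / divisor for loss in cycle), 4))
    return results


# 5 batches per epoch, gradient_accumulation_steps=2 -> cycles of 2, 2 and 1 batches.
losses = [0.9984] * 5
current = logged_losses(losses, 2, divide_by_actual=False)
expected = logged_losses(losses, 2, divide_by_actual=True)
print(current)   # [0.9984, 0.9984, 0.4992]
print(expected)  # [0.9984, 0.9984, 0.9984]
```

With a fixed divisor the last cycle reports exactly half of the other values, which is the factor 1 / gradient_accumulation_steps for a cycle containing a single batch.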
To Reproduce
- Initialize a Trainer.
- Use a dataset where the total number of samples is not perfectly divisible by per_device_train_batch_size * gradient_accumulation_steps.
- Train the model for one epoch.
- Observe the loss value on the final logging step. It will be significantly smaller than the others if the last accumulation cycle has fewer batches than gradient_accumulation_steps.
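Whether a given configuration hits the short final cycle can be checked ahead of time. The sketch below is a hypothetical helper (`last_cycle_batches` is not a transformers function) and assumes a single device and a dataloader that keeps the last partial batch.

```python
import math


def last_cycle_batches(num_samples, per_device_train_batch_size, gradient_accumulation_steps):
    """Return how many batches fall into the final accumulation cycle of an epoch.

    Assumes one device and that the last partial batch is not dropped.
    """
    num_batches = math.ceil(num_samples / per_device_train_batch_size)
    remainder = num_batches % gradient_accumulation_steps
    # A remainder of 0 means the final cycle is complete.
    return remainder if remainder else gradient_accumulation_steps


size = last_cycle_batches(10, 2, 2)
print(size)  # 1, so the last cycle is shorter than gradient_accumulation_steps
```

Under these assumptions, what matters is the number of batches per epoch: 7 samples with the same settings give 4 batches and a complete final cycle, even though 7 is not divisible by 4.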
A simple example is below:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
)

# 1. Define the model
model_name = "bert-base-uncased"
config = AutoConfig.from_pretrained(model_name)
# Set all dropout probabilities to 0.0 to eliminate randomness in each forward pass
config.hidden_dropout_prob = 0.0
config.attention_probs_dropout_prob = 0.0
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    config=config,
)

# 2. Create a simple dataset
# 10 samples in total, batch_size=2, gradient_accumulation_steps=2.
# This gives 5 batches per epoch. The first 4 batches form two complete accumulation cycles.
# The 5th batch forms the last accumulation cycle by itself, with only one batch.
num_samples = 10
# The same randomly generated sample is repeated, so every batch yields the same loss
train_dataset = [
    {
        "input_ids": torch.randint(100, 2000, (8,)),  # randomly generated token ids
        "attention_mask": torch.ones(8),
        "labels": torch.randint(0, 2, (1,)).item(),
    }
] * num_samples

# 3. Set training parameters
training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=2,
    num_train_epochs=2,
    logging_steps=1,
    report_to="none",
    lr_scheduler_type="constant",
    learning_rate=0.0,  # no parameter updates
)

# 4. Train
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
)
trainer.train()
Running the code produces the following log:
{'loss': 0.9984, 'grad_norm': 36.21815490722656, 'learning_rate': 0.0, 'epoch': 0.4}
{'loss': 0.9984, 'grad_norm': 36.21815490722656, 'learning_rate': 0.0, 'epoch': 0.8}
{'loss': 0.4992, 'grad_norm': 18.10907745361328, 'learning_rate': 0.0, 'epoch': 1.0} <-- The problem!
{'loss': 0.9984, 'grad_norm': 36.21815490722656, 'learning_rate': 0.0, 'epoch': 1.4}
Expected behavior
The loss scaling should be adjusted based on the actual number of batches accumulated in a given cycle. For the final (and potentially incomplete) accumulation cycle, the loss should be divided by the number of batches actually processed in that cycle, not by the total gradient_accumulation_steps.
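One way to express this behaviour is to derive the divisor from the cycle itself rather than from the configured value. The following is a minimal sketch, not the Trainer's implementation: `iter_accumulation_cycles` is a hypothetical generator, and plain floats stand in for the per-batch losses.

```python
from itertools import islice


def iter_accumulation_cycles(batches, gradient_accumulation_steps):
    """Yield (cycle, divisor) pairs, where divisor is the actual cycle length.

    Hypothetical helper for illustration; not part of transformers.
    """
    iterator = iter(batches)
    while True:
        cycle = list(islice(iterator, gradient_accumulation_steps))
        if not cycle:
            return
        yield cycle, len(cycle)


# Floats stand in for per-batch losses; a real loop would call backward() on each scaled loss.
per_batch_losses = [0.9984] * 5
divisors = []
scaled = []
for cycle, divisor in iter_accumulation_cycles(per_batch_losses, 2):
    divisors.append(divisor)
    scaled.append(sum(loss / divisor for loss in cycle))

print(divisors)  # [2, 2, 1]
```

Because the divisor comes from the batches actually fetched, the final single-batch cycle is divided by 1 and reports the same loss as the complete cycles.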