src/transformers/trainer.py (10 additions & 7 deletions)
@@ -2483,8 +2483,7 @@ def _inner_training_loop(
                 model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
                     self.model, self.optimizer, self.lr_scheduler
                 )
-        elif self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
-            # In this case we are in DDP + LOMO, which should be supported
+        else:
             self.optimizer = self.accelerator.prepare(self.optimizer)
 
         if self.is_fsdp_enabled:
@@ -3783,7 +3782,7 @@ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
         """
         if self.state.epoch is not None:
             logs["epoch"] = self.state.epoch
-        if self.args.include_num_input_tokens_seen:
+        if self.args.include_num_input_tokens_seen != "no":
             logs["num_input_tokens_seen"] = self.state.num_input_tokens_seen
             if start_time is not None:
                 logs.update(speed_metrics("train", start_time, num_tokens=self.state.num_input_tokens_seen))
@@ -4143,7 +4142,7 @@ def compute_loss(
             and (self.model_accepts_loss_kwargs or self.compute_loss_func)
             and num_items_in_batch is not None
         ):
-            loss *= self.accelerator.num_processes
+            loss *= self.accelerator.num_processes if self.args.n_gpu <= 1 else self.args.n_gpu
 
         return (loss, outputs) if return_outputs else loss
 
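Why the loss is rescaled here, for reviewers skimming the hunk: when average_tokens_across_devices is on, each replica divides its local loss sum by the gathered global token count, and the subsequent averaging across replicas (DDP's gradient all-reduce mean, or the loss mean the Trainer takes over nn.DataParallel outputs) divides by the replica count once more, so the loss must be multiplied back by that count. DataParallel runs in a single process, so accelerator.num_processes stays 1 and the correct factor is args.n_gpu, which is what the new conditional picks. A minimal arithmetic sketch of this cancellation, with assumed values that are not taken from the PR:

# Hypothetical numbers, for illustration only: two replicas, ten labeled tokens in the global batch.
per_device_loss_sums = [12.0, 8.0]            # sum of token losses computed on each replica
global_num_items_in_batch = 10                # token count gathered across replicas

# Each replica divides its local sum by the *global* token count ...
per_device_losses = [s / global_num_items_in_batch for s in per_device_loss_sums]

# ... and averaging across replicas divides by the replica count again,
# leaving the result too small by a factor of len(per_device_losses):
averaged = sum(per_device_losses) / len(per_device_losses)       # 1.0

# Multiplying by the replica count (num_processes for DDP, n_gpu for DataParallel)
# restores the true token-normalized loss:
corrected = averaged * len(per_device_losses)                    # 2.0
assert corrected == sum(per_device_loss_sums) / global_num_items_in_batch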
@@ -5617,15 +5616,19 @@ def get_batch_samples(
                 pass
 
         if num_items_in_batch is not None:
-            if self.args.average_tokens_across_devices:
+            if self.args.average_tokens_across_devices and self.args.world_size >= 1:
                 num_items_in_batch = self.accelerator.gather(num_items_in_batch.to(device)).sum()
+            elif self.args.n_gpu >= 1:
+                # In DP case, if we don't average, we need to divide by the number of gpu. This is the simplest approximation.
+                # Otherwise, we would have to scatter labels and calculate num_items_in_batch for each gpu.
+                num_items_in_batch = num_items_in_batch // self.args.n_gpu
 
             if torch.is_tensor(num_items_in_batch):
                 num_items_in_batch = num_items_in_batch.to(device)
 
                 if self.args.n_gpu > 1 and num_items_in_batch.dim() == 0:
-                    # In the DataParallel case, convert the scalar tensor into a 1-dim tensor
-                    num_items_in_batch = num_items_in_batch.unsqueeze(0)
+                    # In the DataParallel case, convert the scalar tensor into a 2-dim tensor with the same value repeated
+                    num_items_in_batch = num_items_in_batch.unsqueeze(0).expand(self.args.n_gpu, -1)
                 # Divide by number of devices with the same batch
                 if pc := getattr(self.accelerator, "parallelism_config", None):
                     num_items_in_batch = num_items_in_batch // pc.non_data_parallel_size
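For context on the unsqueeze/expand change above: nn.DataParallel scatters tensor arguments by chunking them along dim 0, which a 0-dim scalar cannot support, so the count is repeated n_gpu times and the scatter hands one identical row to each GPU; the new elif branch, as its comment says, falls back to integer-dividing the un-gathered count by n_gpu as a rough per-replica approximation. A standalone sketch of the shape handling, with hypothetical values for n_gpu and the token count:

import torch

n_gpu = 2                                     # assumed number of DataParallel replicas
num_items_in_batch = torch.tensor(10)         # 0-dim scalar tensor, dim() == 0

# unsqueeze(0) gives shape (1,); expand(n_gpu, -1) gives shape (n_gpu, 1), i.e. the same
# count repeated once per GPU, so DataParallel's scatter can send one row to each replica.
per_replica = num_items_in_batch.unsqueeze(0).expand(n_gpu, -1)
print(per_replica.shape)                      # torch.Size([2, 1])
print(per_replica)                            # tensor([[10], [10]])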
src/transformers/training_args.py (0 additions & 12 deletions)
@@ -1790,18 +1790,6 @@ def __post_init__(self):
         if self.framework == "pt" and is_torch_available():
             self.device
 
-        # Disable average tokens when using single device
-        if self.average_tokens_across_devices:
-            try:
-                if self.world_size == 1:
-                    logger.info(
-                        "average_tokens_across_devices is True but world size is 1. Setting it to False automatically."
-                    )
-                    self.average_tokens_across_devices = False
-            except ImportError as e:
-                logger.warning(f"Can not specify world size due to {e}. Turn average_tokens_across_devices to False.")
-                self.average_tokens_across_devices = False
-
         if self.torchdynamo is not None:
             warnings.warn(
                 "`torchdynamo` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"
tests/trainer/test_trainer.py (12 additions & 2 deletions)
@@ -1270,6 +1270,18 @@ def test_adafactor_lr_none(self):
         self.assertFalse(torch.allclose(trainer.model.b, b))
         self.assertGreater(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 0)
 
+    @require_torch_fp16
+    @require_torch_accelerator
+    def test_mixed_fp16(self):
+        # very basic test
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            trainer = get_regression_trainer(learning_rate=0.1, fp16=True, logging_steps=1, output_dir=tmp_dir)
+            trainer.train()
+            self.check_trained_model(trainer.model, atol=ATOL, rtol=RTOL)
+            log_0 = trainer.state.log_history[:-1][0]
+            # check that the grads were properly clipped due to the grad scaler. Otherwise, we get huge values
+            self.assertEqual(log_0["grad_norm"] < 100, True)
+
     @require_torch_bf16
     @require_torch_accelerator
     def test_mixed_bf16(self):
@@ -1286,8 +1298,6 @@ def test_mixed_bf16(self):
                 learning_rate=0.1, bf16=True, half_precision_backend="apex", output_dir=tmp_dir
             )
 
-        # will add more specific tests once there are some bugs to fix
-
     @require_torch_gpu
     @require_torch_tf32
     def test_tf32(self):