Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restore backward after each batch for grad accum #1917

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
35 changes: 29 additions & 6 deletions recipes/full_finetune_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,12 @@ def __init__(self, cfg: DictConfig) -> None:
self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False)

if self._gradient_accumulation_steps > 1 and self._optimizer_in_bwd:
raise RuntimeError(
"Gradient accumulation is not supported with optimizer in bwd."
"Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False."
)

# These are public properties which are updated by the checkpoint loader
# when ``resume_from_checkpoint`` is `True` or validated in tests
self.seed = training.set_seed(seed=cfg.seed)
Expand Down Expand Up @@ -631,7 +637,7 @@ def train(self) -> None:
# clean up before training begins
training.cleanup_before_training()

_, rank = training.get_world_size_and_rank()
world_size, rank = training.get_world_size_and_rank()

# zero out the gradients before starting training
if not self._optimizer_in_bwd:
Expand Down Expand Up @@ -697,15 +703,31 @@ def train(self) -> None:
# Compute loss
# Loss is normalized by default so we multiply by the number of tokens
# This way we can normalize by the total number of tokens if we're accumulating gradients
running_loss += self._loss_fn(logits, labels) * current_num_tokens
current_loss = self._loss_fn(logits, labels) * current_num_tokens
Copy link
Contributor

@pbontrager pbontrager Oct 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there was ever a issue with numerical stability, another option for scaling the loss would be:

if grad_accumulation_step == 0:
	base_num_tokens = current_num_tokens
	torch.distributed.broadcast(base_num_tokens, src=0)

current_loss = loss_fn(logits, labels) * current_num_tokens / base_num_tokens

This might over complicate things but I wanted to leave this here if in the future it turns out a reduced gradient/loss is necessary for smaller dtypes.


# free logits otherwise it peaks backward memory
del logits

running_loss += current_loss
felipemello1 marked this conversation as resolved.
Show resolved Hide resolved

# For optimizer in backward, we need to normalize before calling backward
# This case and gradient accumulation are mutually exclusive
if self._optimizer_in_bwd:
Copy link
Contributor

@pbontrager pbontrager Oct 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think 783 - 810 could be much more readable as:

if self.optimizer_in_bwd:
	raise if self._clip_grad_norm # or do this in init
	...
	current_loss.backward()
elif (idx + 1) % self._gradient_accumulation_steps == 0:
	current_loss.backward()
	...
	scale_grads()
	if self._clip_grad_norm is not None:
		...
	self._optimizer.step()
	self._optimizer.zero_grad(...)

This could be used in all the distributed recipes.

torch.distributed.all_reduce(num_tokens)
torch.distributed.all_reduce(running_loss)
current_loss = current_loss / num_tokens

current_loss.backward()

# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
loss = running_loss / num_tokens
loss.backward()
if not self._optimizer_in_bwd:
# Get total number of tokens across all ranks to normalize gradients
torch.distributed.all_reduce(num_tokens)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

HE'S A GENIUS

# This will ensure that the logged loss matches what we're optimizing
torch.distributed.all_reduce(running_loss)
# Manually scale the gradients from unnormalized loss by total # of tokens
training.scale_grads(self._model, 1 / num_tokens)
felipemello1 marked this conversation as resolved.
Show resolved Hide resolved
if self._clip_grad_norm is not None:
if self._optimizer_in_bwd:
raise NotImplementedError(
Expand All @@ -722,7 +744,7 @@ def train(self) -> None:
# Update the number of steps when the weights are updated
self.global_step += 1

loss_to_log = loss.item()
loss_to_log = running_loss.item() / num_tokens
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should probably normalize by local_num_tokens?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update: I am probably gonna keep it like this since it should be representative of the loss we are actually using to step (even though it means our loss curves will look slightly different than they do today)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think it makes sense. Will it break all regression tests though?

pbar.update(1)
pbar.set_description(
f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
Expand All @@ -743,7 +765,8 @@ def train(self) -> None:
else self._optim_ckpt_wrapper
),
),
"tokens_per_second_per_gpu": num_tokens / time_per_step,
"tokens_per_second_per_gpu": num_tokens
/ (time_per_step * world_size),
}
if self._log_peak_memory_stats:
log_dict.update(
Expand Down
9 changes: 5 additions & 4 deletions recipes/full_finetune_single_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,12 +641,13 @@ def train(self) -> None:

# Loss is normalized by default so we multiply by the number of tokens
# This way we can normalize by the total number of tokens if we're accumulating gradients
running_loss += self._loss_step(batch) * current_num_tokens
current_loss = self._loss_step(batch) * current_num_tokens
running_loss += current_loss
current_loss.backward()

# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
loss = running_loss / num_tokens
loss.backward()
training.scale_grads(self._model, 1 / num_tokens)
if self._clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
Expand All @@ -661,7 +662,7 @@ def train(self) -> None:
self._lr_scheduler.step()
self.global_step += 1

loss_to_log = loss.item()
loss_to_log = running_loss.item() / num_tokens
pbar.update(1)
pbar.set_description(
f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
Expand Down
15 changes: 7 additions & 8 deletions recipes/knowledge_distillation_single_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,15 +704,14 @@ def train(self) -> None:
class_loss, kd_loss = self._loss_step(batch)
running_class_loss += class_loss * current_num_tokens
running_kd_loss += kd_loss * current_num_tokens
current_loss = (
1 - self._kd_ratio
) * class_loss + self._kd_ratio * kd_loss
current_loss.backward()

# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
class_loss = running_class_loss / num_tokens
kd_loss = running_kd_loss / num_tokens
loss = (
1 - self._kd_ratio
) * class_loss + self._kd_ratio * kd_loss
loss.backward()
training.scale_grads(self._model, 1 / num_tokens)
if self._clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
Expand All @@ -724,8 +723,8 @@ def train(self) -> None:
# Update the number of steps when the weights are updated
self.global_step += 1

class_loss_to_log = class_loss.item()
kd_loss_to_log = kd_loss.item()
class_loss_to_log = running_class_loss.item() / num_tokens
kd_loss_to_log = running_kd_loss.item() / num_tokens
loss_to_log = (
1 - self._kd_ratio
) * class_loss_to_log + self._kd_ratio * kd_loss_to_log
Expand Down
20 changes: 14 additions & 6 deletions recipes/lora_finetune_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,7 @@ def train(self) -> None:
# clean up before training begins
training.cleanup_before_training()

_, rank = training.get_world_size_and_rank()
world_size, rank = training.get_world_size_and_rank()

# zero out the gradients before starting training
self._optimizer.zero_grad()
Expand Down Expand Up @@ -797,15 +797,22 @@ def train(self) -> None:
# Compute loss
# Loss is normalized by default so we multiply by the number of tokens
# This way we can normalize by the total number of tokens if we're accumulating gradients
running_loss += self._loss_fn(logits, labels) * current_num_tokens
current_loss = self._loss_fn(logits, labels) * current_num_tokens

# free logits otherwise it peaks backward memory
del logits

running_loss += current_loss
current_loss.backward()

# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
loss = running_loss / num_tokens
loss.backward()
# Get total number of tokens across all ranks to normalize gradients
torch.distributed.all_reduce(num_tokens)
# This will ensure that the logged loss matches what we're optimizing
torch.distributed.all_reduce(running_loss)
# Manually scale the gradients from unnormalized loss by total # of tokens
training.scale_grads(self._model, 1 / num_tokens)
if self._clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
Expand All @@ -818,7 +825,7 @@ def train(self) -> None:
# Update the number of steps when the weights are updated
self.global_step += 1

loss_to_log = loss.item()
loss_to_log = running_loss.item() / num_tokens
pbar.update(1)
pbar.set_description(
f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
Expand All @@ -833,7 +840,8 @@ def train(self) -> None:
log_dict = {
"loss": loss_to_log,
"lr": self._optimizer.param_groups[0]["lr"],
"tokens_per_second_per_gpu": num_tokens / time_per_step,
"tokens_per_second_per_gpu": num_tokens
/ (time_per_step * world_size),
}
if self._log_peak_memory_stats:
log_dict.update(
Expand Down
9 changes: 5 additions & 4 deletions recipes/lora_finetune_single_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,12 +694,13 @@ def train(self) -> None:

# Loss is normalized by default so we multiply by the number of tokens
# This way we can normalize by the total number of tokens if we're accumulating gradients
running_loss += self._loss_step(batch) * current_num_tokens
current_loss = self._loss_step(batch) * current_num_tokens
running_loss += current_loss
current_loss.backward()

# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
loss = running_loss / num_tokens
loss.backward()
training.scale_grads(self._model, 1 / num_tokens)
if self._clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
Expand All @@ -711,7 +712,7 @@ def train(self) -> None:
# Update the number of steps when the weights are updated
self.global_step += 1

loss_to_log = loss.item()
loss_to_log = running_loss.item() / num_tokens
pbar.update(1)
pbar.set_description(
f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
Expand Down
35 changes: 21 additions & 14 deletions recipes/qat_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,8 +599,7 @@ def train(self) -> None:
"""
# clean up before training begins
training.cleanup_before_training()

_, rank = training.get_world_size_and_rank()
world_size, rank = training.get_world_size_and_rank()

# zero out the gradients before starting training
self._optimizer.zero_grad()
Expand Down Expand Up @@ -668,18 +667,16 @@ def train(self) -> None:

# Calculate the number of unmasked tokens in the current batch
# and increment the total number of tokens seen in the step

utils.batch_to_device(batch, self._device)

current_num_tokens = (
batch["labels"] != self._loss_fn.ignore_index
).sum()
num_tokens += current_num_tokens
labels = batch.pop("labels")

labels = labels.to(self._device)
mask = mask.to(self._device) if mask is not None else None
input_pos = (
input_pos.to(self._device) if input_pos is not None else None
)

logits = self._model(tokens, mask=mask, input_pos=input_pos)
logits = self._model(**batch)

# Shift labels to compute loss
# equivalent to doing labels[..., 1:] and logits[..., :-1, :]
Expand All @@ -692,22 +689,30 @@ def train(self) -> None:
logits = logits.reshape(-1, logits.size(-1))

# Compute loss
running_loss += self._loss_fn(logits, labels) * current_num_tokens
current_loss = self._loss_fn(logits, labels) * current_num_tokens

# free logits otherwise it peaks backward memory
del logits

running_loss += current_loss
current_loss.backward()

# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
loss = running_loss / num_tokens
loss.backward()
# Get total number of tokens across all ranks to normalize gradients
torch.distributed.all_reduce(num_tokens)
# This will ensure that the logged loss matches what we're optimizing
torch.distributed.all_reduce(running_loss)
# Manually scale the gradients from unnormalized loss by total # of tokens
training.scale_grads(self._model, 1 / num_tokens)

self._optimizer.step()
self._optimizer.zero_grad(set_to_none=True)

# Update the number of steps when the weights are updated
self.global_step += 1

loss_to_log = loss.item()
loss_to_log = running_loss.item() / num_tokens
pbar.update(1)
pbar.set_description(
f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
Expand All @@ -722,7 +727,9 @@ def train(self) -> None:
log_dict = {
"loss": loss_to_log,
"lr": self._optimizer.param_groups[0]["lr"],
"tokens_per_second_per_gpu": num_tokens / time_per_step,
"tokens_per_second_per_gpu": (
num_tokens / time_per_step * world_size
),
}
if self._log_peak_memory_stats:
log_dict.update(
Expand Down
14 changes: 7 additions & 7 deletions tests/recipes/test_full_finetune_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,21 +44,20 @@ def _get_test_config_overrides(self):

def _fetch_expected_loss_values(self, model_type):
loss_values_map = {
"llama2": [10.5136, 10.4813, 10.5088, 10.5250],
"llama3": [12.0673, 11.9072, 11.9302, 11.9355],
"llama2": [10.5209, 10.5217, 10.4945, 10.5136],
"llama3": [11.9839, 11.9684, 11.9596, 11.93656],
}
return loss_values_map[model_type]

@pytest.mark.integration_test
@pytest.mark.parametrize(
"config, model_type, ckpt_type, micro_batch_size, gradient_accumulation_steps",
"config, model_type, ckpt_type, micro_batch_size, gradient_accumulation_steps, optim_in_bwd",
[
("llama2/7B_full", "llama2", "hf", 1, 4),
("llama3/8B_full", "llama3", "tune", 1, 4),
("llama3/8B_full", "llama3", "tune", 4, 1),
("llama2/7B_full", "llama2", "hf", 1, 4, False),
("llama3/8B_full", "llama3", "tune", 1, 4, False),
("llama3/8B_full", "llama3", "tune", 4, 1, True),
],
)
@pytest.mark.parametrize("optim_in_bwd", [True, False])
@gpu_test(gpu_count=2)
def test_loss(
self,
Expand Down Expand Up @@ -113,3 +112,4 @@ def test_loss(
torch.testing.assert_close(
loss_values, expected_loss_values, rtol=1e-4, atol=1e-4
)
raise ValueError("done")
8 changes: 4 additions & 4 deletions tests/recipes/test_qat_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,18 @@ def _get_test_config_overrides(self):

def _fetch_expected_loss_values(self, model_type):
loss_values_map = {
"llama2": [10.5164, 10.4830, 10.5138, 10.5199],
"llama3": [12.0672, 11.9067, 11.9304, 11.9351],
"llama2": [10.5211, 10.5217, 10.4944, 10.5134],
"llama3": [11.9836, 11.9683, 11.9594, 11.9366],
}
return loss_values_map[model_type]

@pytest.mark.integration_test
@pytest.mark.parametrize(
"config, model_type, ckpt_type, micro_batch_size, gradient_accumulation_steps",
[
("llama2/7B_qat_full", "llama2", "hf", 4, 1),
("llama3/8B_qat_full", "llama3", "tune", 4, 1),
# ("llama2/7B_qat_full", "llama2", "hf", 4, 1),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Commented out?

("llama3/8B_qat_full", "llama3", "tune", 4, 1),
("llama3/8B_qat_full", "llama3", "tune", 1, 4),
],
)
@gpu_test(gpu_count=2)
Expand Down
2 changes: 2 additions & 0 deletions torchtune/training/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
shard_model,
validate_no_params_on_meta_device,
)
from torchtune.training._grad_scaler import scale_grads
from torchtune.training._profiler import (
DEFAULT_PROFILE_DIR,
DEFAULT_PROFILER_ACTIVITIES,
Expand Down Expand Up @@ -132,4 +133,5 @@
"NoOpManager",
"OffloadActivations",
"FormattedCheckpointFiles",
"scale_grads",
]
26 changes: 26 additions & 0 deletions torchtune/training/_grad_scaler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this really need its own file?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where do you wanna put it then? Otherwise I am gonna copy-paste this in every recipe which is worse imo

# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn


def scale_grads(model: nn.Module, scaler: torch.Tensor) -> None:
"""
Utility to scale the gradients of a model.
This is useful for gradient accumulation where we want to normalize
the gradients by the total number of tokens seen.
Inputs:
model (nn.Module): model whose gradients should be scaled
scaler (torch.Tensor): scaling factor to apply to the gradients
Outputs:
None (grad fields are modified in place)
"""
for p in model.parameters():
if p.grad is not None:
p.grad *= scaler
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any concern here around overflows for lower dtypes? We could do a scaler range check based on dtype. Or is it better to leave it to the recipe to safely choose scaler values?

Loading