Restore backward after each batch for grad accum #1917

Open · wants to merge 9 commits into base: main
25 changes: 18 additions & 7 deletions recipes/full_finetune_distributed.py
@@ -130,7 +130,8 @@ def __init__(self, cfg: DictConfig) -> None:

# _is_rank_zero is used primarily for logging. In the future, the logger
# should directly take care of this
_, rank = training.get_world_size_and_rank()
self._world_size, rank = training.get_world_size_and_rank()
self._rank = rank
self._is_rank_zero = rank == 0

# Training cfg
@@ -631,7 +632,7 @@ def train(self) -> None:
# clean up before training begins
training.cleanup_before_training()

_, rank = training.get_world_size_and_rank()
self._world_size, rank = training.get_world_size_and_rank()

# zero out the gradients before starting training
if not self._optimizer_in_bwd:
@@ -697,15 +698,24 @@ def train(self) -> None:
# Compute loss
# Loss is normalized by default so we multiply by the number of tokens
# This way we can normalize by the total number of tokens if we're accumulating gradients
running_loss += self._loss_fn(logits, labels) * current_num_tokens
current_loss = self._loss_fn(logits, labels) * current_num_tokens
@pbontrager (Contributor) commented on Oct 31, 2024:

If there were ever an issue with numerical stability, another option for scaling the loss would be:

    if grad_accumulation_step == 0:
        base_num_tokens = current_num_tokens
        torch.distributed.broadcast(base_num_tokens, src=0)

    current_loss = loss_fn(logits, labels) * current_num_tokens / base_num_tokens

This might overcomplicate things, but I wanted to leave it here in case a reduced gradient/loss turns out to be necessary for smaller dtypes in the future.


# free logits otherwise it peaks backward memory
del logits

running_loss += current_loss

if (idx + 1) % self._gradient_accumulation_steps != 0:
with training.no_sync(self._model):
current_loss.backward()
else:
current_loss.backward()

# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
loss = running_loss / num_tokens
loss.backward()
local_num_tokens = num_tokens.detach().clone()
torch.distributed.all_reduce(num_tokens)
training.scale_grads(self._model, self._world_size / num_tokens)
Contributor:

There are so many lines taking care of the all_reduce, backward, etc., that it makes me wonder if this should be a utility.

Contributor (Author):

Yeah, maybe. In this case I feel the logic is important enough (and tricky enough) that it should be done very explicitly. Whichever route we go, I will ultimately make it more explicit what's happening here.
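For illustration, a minimal sketch of what such a utility could look like, mirroring the backward/all_reduce/scale sequence above; the function name and signature are hypothetical, not an API from this PR:

    import torch
    import torch.distributed as dist
    from torch import nn
    from torchtune import training

    def backward_and_scale_for_grad_accum(
        model: nn.Module,
        current_loss: torch.Tensor,
        num_tokens: torch.Tensor,
        world_size: int,
        sync_step: bool,
    ) -> None:
        # Hypothetical helper bundling the per-microbatch backward with the
        # final-step reduction and gradient rescaling.
        if not sync_step:
            # Intermediate microbatch: accumulate local grads, skip cross-rank sync.
            with training.no_sync(model):
                current_loss.backward()
            return
        # Final microbatch of the accumulation window: sync grads via backward,
        # sum token counts across ranks, then renormalize the averaged grads.
        current_loss.backward()
        dist.all_reduce(num_tokens)
        training.scale_grads(model, world_size / num_tokens)

The PR keeps this logic inline in each recipe instead, so the reduction and scaling stay visible at the call site.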

if self._clip_grad_norm is not None:
if self._optimizer_in_bwd:
raise NotImplementedError(
@@ -722,7 +732,7 @@ def train(self) -> None:
# Update the number of steps when the weights are updated
self.global_step += 1

loss_to_log = loss.item()
loss_to_log = running_loss.item() / num_tokens
Contributor (Author):

Should we probably normalize by local_num_tokens?

Contributor (Author):

Update: I am probably going to keep it like this, since it should be representative of the loss we are actually using to step (even though it means our loss curves will look slightly different than they do today).

Contributor:

I think it makes sense. Will it break all the regression tests, though?
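For reference, the alternative being weighed here (not what the PR does) would log the loss normalized by this rank's own token count rather than the all-reduced one:

    # Hypothetical alternative: normalize the logged loss by the local token count.
    loss_to_log = running_loss.item() / local_num_tokens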

pbar.update(1)
pbar.set_description(
f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
@@ -743,7 +753,8 @@ def train(self) -> None:
else self._optim_ckpt_wrapper
),
),
"tokens_per_second_per_gpu": num_tokens / time_per_step,
"tokens_per_second_per_gpu": local_num_tokens
/ time_per_step,
}
if self._log_peak_memory_stats:
log_dict.update(
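Why this is equivalent to the old single backward on running_loss / num_tokens (a sketch, assuming each microbatch loss above is the token-normalized loss multiplied back by its token count): with unnormalized microbatch losses L_i and total token count N summed across ranks,

    \nabla_\theta \Big( \tfrac{1}{N} \sum_i L_i \Big) = \tfrac{1}{N} \sum_i \nabla_\theta L_i,

so calling backward() on each L_i as it is produced and rescaling the accumulated gradients by 1/N at step time yields the same gradients, while freeing each microbatch's graph immediately. The world_size factor passed to scale_grads compensates for FSDP averaging (rather than summing) gradients across data-parallel ranks during the synced backward, assuming the default mean reduction.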
9 changes: 5 additions & 4 deletions recipes/full_finetune_single_device.py
@@ -641,12 +641,13 @@ def train(self) -> None:

# Loss is normalized by default so we multiply by the number of tokens
# This way we can normalize by the total number of tokens if we're accumulating gradients
running_loss += self._loss_step(batch) * current_num_tokens
current_loss = self._loss_step(batch) * current_num_tokens
running_loss += current_loss
current_loss.backward()

# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
loss = running_loss / num_tokens
loss.backward()
training.scale_grads(self._model, 1 / num_tokens)
if self._clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
Expand All @@ -661,7 +662,7 @@ def train(self) -> None:
self._lr_scheduler.step()
self.global_step += 1

loss_to_log = loss.item()
loss_to_log = running_loss.item() / num_tokens
pbar.update(1)
pbar.set_description(
f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
15 changes: 7 additions & 8 deletions recipes/knowledge_distillation_single_device.py
@@ -704,15 +704,14 @@ def train(self) -> None:
class_loss, kd_loss = self._loss_step(batch)
running_class_loss += class_loss * current_num_tokens
running_kd_loss += kd_loss * current_num_tokens
current_loss = (
1 - self._kd_ratio
) * class_loss + self._kd_ratio * kd_loss
current_loss.backward()

# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
class_loss = running_class_loss / num_tokens
kd_loss = running_kd_loss / num_tokens
loss = (
1 - self._kd_ratio
) * class_loss + self._kd_ratio * kd_loss
loss.backward()
training.scale_grads(self._model, 1 / num_tokens)
if self._clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
Expand All @@ -724,8 +723,8 @@ def train(self) -> None:
# Update the number of steps when the weights are updated
self.global_step += 1

class_loss_to_log = class_loss.item()
kd_loss_to_log = kd_loss.item()
class_loss_to_log = running_class_loss.item() / num_tokens
kd_loss_to_log = running_kd_loss.item() / num_tokens
loss_to_log = (
1 - self._kd_ratio
) * class_loss_to_log + self._kd_ratio * kd_loss_to_log
10 changes: 6 additions & 4 deletions recipes/lora_finetune_distributed.py
@@ -797,15 +797,17 @@ def train(self) -> None:
# Compute loss
# Loss is normalized by default so we multiply by the number of tokens
# This way we can normalize by the total number of tokens if we're accumulating gradients
running_loss += self._loss_fn(logits, labels) * current_num_tokens
current_loss = self._loss_fn(logits, labels) * current_num_tokens

# free logits otherwise it peaks backward memory
del logits

running_loss += current_loss
current_loss.backward()

# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
loss = running_loss / num_tokens
loss.backward()
training.scale_grads(self._model, 1 / num_tokens)
if self._clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
@@ -818,7 +820,7 @@ def train(self) -> None:
# Update the number of steps when the weights are updated
self.global_step += 1

loss_to_log = loss.item()
loss_to_log = running_loss.item() / num_tokens
pbar.update(1)
pbar.set_description(
f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
9 changes: 5 additions & 4 deletions recipes/lora_finetune_single_device.py
@@ -694,12 +694,13 @@ def train(self) -> None:

# Loss is normalized by default so we multiply by the number of tokens
# This way we can normalize by the total number of tokens if we're accumulating gradients
running_loss += self._loss_step(batch) * current_num_tokens
current_loss = self._loss_step(batch) * current_num_tokens
running_loss += current_loss
current_loss.backward()

# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
loss = running_loss / num_tokens
loss.backward()
training.scale_grads(self._model, 1 / num_tokens)
if self._clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
Expand All @@ -711,7 +712,7 @@ def train(self) -> None:
# Update the number of steps when the weights are updated
self.global_step += 1

loss_to_log = loss.item()
loss_to_log = running_loss.item() / num_tokens
pbar.update(1)
pbar.set_description(
f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
11 changes: 7 additions & 4 deletions recipes/qat_distributed.py
@@ -692,22 +692,25 @@ def train(self) -> None:
logits = logits.reshape(-1, logits.size(-1))

# Compute loss
running_loss += self._loss_fn(logits, labels) * current_num_tokens
current_loss = self._loss_fn(logits, labels) * current_num_tokens

# free logits otherwise it peaks backward memory
del logits

running_loss += current_loss
current_loss.backward()

# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
loss = running_loss / num_tokens
loss.backward()
training.scale_grads(self._model, 1 / num_tokens)

self._optimizer.step()
self._optimizer.zero_grad(set_to_none=True)

# Update the number of steps when the weights are updated
self.global_step += 1

loss_to_log = loss.item()
loss_to_log = running_loss.item() / num_tokens
pbar.update(1)
pbar.set_description(
f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
4 changes: 4 additions & 0 deletions torchtune/training/__init__.py
@@ -18,11 +18,13 @@
load_from_full_model_state_dict,
load_from_full_optimizer_state_dict,
lora_fsdp_wrap_policy,
no_sync,
prepare_model_for_fsdp_with_meta_device,
set_torch_num_threads,
shard_model,
validate_no_params_on_meta_device,
)
from torchtune.training._grad_scaler import scale_grads
from torchtune.training._profiler import (
DEFAULT_PROFILE_DIR,
DEFAULT_PROFILER_ACTIVITIES,
@@ -132,4 +134,6 @@
"NoOpManager",
"OffloadActivations",
"FormattedCheckpointFiles",
"scale_grads",
"no_sync",
]
23 changes: 22 additions & 1 deletion torchtune/training/_distributed.py
@@ -5,10 +5,22 @@
# LICENSE file in the root directory of this source tree.


import contextlib
import logging
import os
from itertools import chain
from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Type
from typing import (
Any,
Callable,
cast,
Dict,
Generator,
List,
Optional,
Set,
Tuple,
Type,
)

import torch
import torch.distributed as dist
@@ -679,3 +691,12 @@ def shard_model(

# Finally shard the entire model to account for any stragglers
fully_shard(model, **fsdp_kwargs)


@contextlib.contextmanager
def no_sync(model: nn.Module) -> Generator[None, None, None]:
Contributor:

The name could be more descriptive, maybe no_grad_sync.

model.set_requires_gradient_sync(False)
try:
yield
finally:
model.set_requires_gradient_sync(True)
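A short usage sketch, adapted from the recipe changes above, showing how this context manager skips gradient synchronization on intermediate gradient-accumulation steps:

    # Inside the training loop of a distributed recipe:
    if (idx + 1) % self._gradient_accumulation_steps != 0:
        # Intermediate microbatch: accumulate grads locally, skip cross-rank sync.
        with training.no_sync(self._model):
            current_loss.backward()
    else:
        # Final microbatch: let FSDP synchronize gradients as usual.
        current_loss.backward()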
14 changes: 14 additions & 0 deletions torchtune/training/_grad_scaler.py
@@ -0,0 +1,14 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Contributor:

Does this really need its own file?

Contributor:

No.

Contributor (Author):

Where do you want to put it then? Otherwise I am going to copy-paste this into every recipe, which is worse IMO.

# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn


def scale_grads(m: nn.Module, scaler: torch.Tensor) -> None:
for p in m.parameters():
if p.grad is not None:
p.grad *= scaler
Contributor:

Is there any concern here around overflows for lower dtypes? We could do a scaler range check based on dtype. Or is it better to leave it to the recipe to safely choose scaler values?
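A rough sketch of the kind of dtype-based range check the reviewer mentions (not part of this PR; the function name and the error behavior are placeholders):

    import torch
    from torch import nn

    def scale_grads_with_range_check(m: nn.Module, scaler: torch.Tensor) -> None:
        # Illustrative variant: refuse to scale if the result could overflow the grad dtype.
        for p in m.parameters():
            if p.grad is None:
                continue
            finfo = torch.finfo(p.grad.dtype)
            max_abs = p.grad.abs().max().float()
            if (max_abs * scaler.float().abs()).item() > finfo.max:
                raise OverflowError(
                    f"scaling grads of dtype {p.grad.dtype} by {scaler.item()} would overflow"
                )
            p.grad *= scaler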
