Restore backward after each batch for grad accum #1917

Open · wants to merge 8 commits into main

Conversation


@ebsmothers ebsmothers commented Oct 29, 2024

The fix to normalize CE loss by the total number of tokens in a step moved our backward call from per-batch to per-step. As a result, a bunch of activations hang around longer than they should and blow up our memory. Instead, we should call backward on the unnormalized loss for each batch, then manually scale the gradients just before the optimizer step.

This is easy enough for our single-device recipes, but for our distributed recipes it's slightly more work. In fact it cannot be done in a way that is both (a) correct and (b) backwards compatible. If you want to know why this is the case, see the "Long digression" section below.

TLDR of the changes:

  1. Gradient accumulation logic changes: We now (a) call backward on every batch's unnormalized loss, (b) accumulate num_tokens over every batch, (c) (distributed only) all_reduce num_tokens on the last batch prior to stepping, then (d) manually scale the gradients prior to stepping with the optimizer. This reverts the memory regression introduced by #1875 (Normalize CE loss by total number of (non-padding) tokens) while keeping the correctness of the gradient accumulation logic. (A rough sketch of the gradient-scaling helper follows this list.)
  2. Logging changes: If we are stepping based on total_loss / total_tokens (where "total" is over all batches and all ranks), that's what should show up in the logs. Similarly, we can now use the number of tokens over all ranks instead of just rank 0 (though we still normalize tokens/sec to per-GPU).
  3. Optimizer in backward: This shouldn't be supported when gradient accumulation is enabled. Previously we didn't raise an explicit error about this (it also happened to work when we called .backward() only once per step, but we don't want to do that anymore). We now raise an error, and I disabled a couple of test cases that snuck in during this period.
  4. Minor changes to KD and QAT recipes: Integrate utils.batch_to_device into the QAT recipe. Previously labels weren't moved to the device until later, which caused num_tokens to end up on CPU instead of GPU.
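For reference, the gradient scaling in step (d) of item 1 is handled by a small training.scale_grads utility added in this PR (it shows up in the diff further down). A plausible minimal implementation, purely illustrative and possibly different from the actual utility, would just multiply every gradient in place:

import torch
from torch import nn

def scale_grads(model: nn.Module, scaler: torch.Tensor) -> None:
    """Multiply every parameter gradient in place by ``scaler`` (e.g. 1 / total_num_tokens)."""
    for p in model.parameters():
        if p.grad is not None:
            p.grad.mul_(scaler)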

Test plan

All the single-device recipe tests succeed without any changes. However, for our distributed recipe tests some of the expected values need to change.

Why are you changing the expected values?!

Again, see the "long digression" section below, but we have been normalizing our loss by the local number of tokens, not the number of tokens seen over all ranks. Is this a huge deal? Honestly probably not, but technically it's not correct. How do I know this version is correct? I've run the following command both on main and on this PR:

pytest -m integration_test tests/recipes/test_full_finetune_distributed.py -k 'test_loss[False-llama2/7B_full-llama2-hf-1-4]'

In both cases I added logging of # of tokens and grad for the token embeddings weight. On main I also commented out the loss normalization so that we can get the raw values. Diff for my changes on main (changes on this PR are just the identical logging).

Note that after the first iteration the results diverge due to the difference in how the gradients are calculated, so we look at the logs from just the first iteration:

On main

rank: 1, idx: 0, num_tokens: 12
rank: 0, idx: 0, num_tokens: 59
rank: 0, idx: 1, num_tokens: 60
rank: 1, idx: 1, num_tokens: 268
rank: 0, idx: 2, num_tokens: 14
rank: 1, idx: 2, num_tokens: 133
rank: 0, idx: 3, num_tokens: 11
rank: 1, idx: 3, num_tokens: 6
rank: 0, unnormalized grad: DTensor(local_tensor=-13.8838529586792, device_mesh=DeviceMesh('cuda', [0, 1]), placements=(Partial(sum),))
rank: 1, unnormalized grad: DTensor(local_tensor=1.0207021236419678, device_mesh=DeviceMesh('cuda', [0, 1]), placements=(Partial(sum),))

On this PR

rank: 1, idx: 0, num_tokens: 12
rank: 0, idx: 0, num_tokens: 59
rank: 0, idx: 1, num_tokens: 60
rank: 1, idx: 1, num_tokens: 268
rank: 0, idx: 2, num_tokens: 14
rank: 1, idx: 2, num_tokens: 133
rank: 0, idx: 3, num_tokens: 11
rank: 1, idx: 3, num_tokens: 6
rank: 0, grad: DTensor(local_tensor=-0.02466048300266266, device_mesh=DeviceMesh('cuda', [0, 1]), placements=(Partial(sum),))
rank: 1, grad: DTensor(local_tensor=0.0018129688687622547, device_mesh=DeviceMesh('cuda', [0, 1]), placements=(Partial(sum),))

But what does it all mean?

In both the preceding snippets, the first step sees 12 + 59 + 60 + 268 + 14 + 133 + 11 + 6 = 563 tokens.

The grad value logged is just the sum of the elements in the tensor. On main this is -13.8838529586792 on rank 0 and 1.0207021236419678 on rank 1. On this PR it is -0.02466048300266266 on rank 0 and 0.0018129688687622547 on rank 1. In both cases, the value on this PR equals the unnormalized value on main divided by 563 (up to floating-point accumulation error), which is exactly what we would expect.
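For anyone who wants to double-check the arithmetic, here is a quick standalone check of those numbers (values copied from the logs above; the tiny residual is expected floating-point noise):

total_tokens = 12 + 59 + 60 + 268 + 14 + 133 + 11 + 6  # 563

unnormalized = {0: -13.8838529586792, 1: 1.0207021236419678}      # main
normalized = {0: -0.02466048300266266, 1: 0.0018129688687622547}  # this PR

for rank in (0, 1):
    expected = unnormalized[rank] / total_tokens
    rel_err = abs(expected - normalized[rank]) / abs(normalized[rank])
    print(f"rank {rank}: expected {expected:.6e}, logged {normalized[rank]:.6e}, rel err {rel_err:.1e}")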

End-to-end testing

Llama3 8B full finetune on 4 devices

Repro:

tune run --nproc_per_node 4 full_finetune_distributed --config llama3/8B_full \
max_steps_per_epoch=500 gradient_accumulation_steps=4

Peak allocated memory drops substantially after this PR (back to where it was before the first gradient accumulation PR, but with actual correct loss calculation now).

[Screenshot: peak allocated memory before vs. after this PR]

Llama 3.2 1B full finetune on 2 devices

tune run --nproc_per_node 2 full_finetune_distributed --config llama3_2/1B_full max_steps_per_epoch=500

We can see that (a) peak allocated memory drops back down to where it was before #1875, (b) tokens per second is the same as #1875 (and faster than before), and (c) the loss curves look similar.

[Screenshots: peak allocated memory, tokens per second, and loss curves]

Long digression: data parallel and gradient accumulation aren't so different

First: on a single device

Let's consider a simple example of how to calculate the correctly-normalized CE loss on a single device (i.e. no data parallel) with two gradient accumulation steps. Say our unnormalized cross-entropy loss for the first batch is L1 and our (similarly unnormalized) loss for the second batch is L2, and we have n1 tokens in the first batch and n2 tokens in the second batch. Then the properly-normalized cross-entropy loss to step with will be (L1 + L2) / (n1 + n2). The naive approach taken in #1875 is to just accumulate a running sum of both loss and number of tokens, then call backward on the ratio running_loss / running_num_tokens just before optimizer step. The problem with this is that we only call backward once, so our activations stick around and blow up our memory. What can we do instead?

Sticking with the single device case, it's not too hard to fix this. Repeated calls to .backward() accumulate gradients, so e.g. the following two are numerically equivalent:

# Option 1: a single backward call on the summed loss
# (pseudocode: treat model(batch) as returning that batch's unnormalized loss)
loss_1 = model(batch_1)
loss_2 = model(batch_2)
summed_loss = loss_1 + loss_2
summed_loss.backward()

# Option 2: one backward call per batch; gradients accumulate in param.grad,
# and each batch's activations can be freed as soon as its backward finishes
loss_1 = model(batch_1)
loss_1.backward()
loss_2 = model(batch_2)
loss_2.backward()
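To convince yourself of the equivalence, here's a tiny self-contained check (a toy linear model, not the recipe code; sum() stands in for an unnormalized loss):

import torch
import torch.nn as nn

batch_1, batch_2 = torch.randn(4, 8), torch.randn(4, 8)

def make_model() -> nn.Linear:
    torch.manual_seed(0)  # identical init so the two runs are comparable
    return nn.Linear(8, 1)

# Single backward on the summed loss
m1 = make_model()
(m1(batch_1).sum() + m1(batch_2).sum()).backward()

# One backward per batch; grads accumulate in .grad
m2 = make_model()
m2(batch_1).sum().backward()
m2(batch_2).sum().backward()

assert torch.allclose(m1.weight.grad, m2.weight.grad)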

Then for the single-device case we can take the second approach with L1 and L2. But remember that we still want to normalize by n1 + n2. This isn't too hard: we can manually scale the gradients ourselves just before the optimizer step (since grad(c*L) = c*grad(L) for any constant c, scaling the gradients is equivalent to scaling the loss).
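Concretely, a hedged sketch of what the single-device loop looks like under this scheme (illustrative names, not the exact recipe code; it assumes loss_fn returns the summed, unnormalized per-token CE and that labels use -100 as the ignore index):

def grad_accum_step(model, optimizer, batches, loss_fn):
    running_num_tokens = 0
    for tokens, labels in batches:  # the batches that make up one optimizer step
        running_num_tokens += (labels != -100).sum().item()
        loss = loss_fn(model(tokens), labels)  # unnormalized (summed) CE
        loss.backward()  # per-batch backward frees this batch's activations

    # Equivalent to having called backward on (L1 + ... + Lk) / (n1 + ... + nk)
    for p in model.parameters():
        if p.grad is not None:
            p.grad.div_(running_num_tokens)

    optimizer.step()
    optimizer.zero_grad(set_to_none=True)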

Adding data parallel

This is where things get slightly messier. Let's extend the example to two devices: using Lij to refer to the unnormalized cross-entropy loss for rank i in its jth batch, and similarly for nij with number of tokens.

So in the first batch the model will see:
On rank 1: loss L11 based on n11 tokens
On rank 2: loss L21 based on n21 tokens

In the second batch, it will see:
On rank 1: loss L12 based on n12 tokens
On rank 2: loss L22 based on n22 tokens

Similarly to the single-device case, the total number of tokens seen across all batches and all ranks will be n11 + n21 + n12 + n22. This means that the properly-normalized loss should be given by (L11 + L21 + L12 + L22) / (n11 + n21 + n12 + n22).

What are we doing today?

Currently we take a similar approach to the single-device case described previously: we accumulate the losses and tokens, then call .backward() on the ratio running_loss / running_num_tokens just before stepping with the optimizer. See the below code:

running_loss += self._loss_fn(logits, labels) * current_num_tokens
# free logits otherwise it peaks backward memory
del logits
# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
    loss = running_loss / num_tokens
    loss.backward()

What's wrong with this?

During data parallel training, the loss on a given rank is based only on the subset of data seen by that rank. Similarly, our calculation of running_num_tokens is based only on the tokens from that rank. This means that when we normalize we are normalizing only over iterations, not over ranks. Put another way, the line loss = running_loss / num_tokens in the above snippet will yield (L11 + L12) / (n11 + n12) on rank 1 and (L21 + L22) / (n21 + n22) on rank 2. Finally, we call .backward(), which calculates local grads before firing a hook to sync by reducing over all ranks. The upshot is that our loss winds up as [(L11 + L12) / (n11 + n12)] + [(L21 + L22) / (n21 + n22)], which is definitely not (L11 + L21 + L12 + L22) / (n11 + n21 + n12 + n22).

How can we fix it?

Aside from the correctness issue described in the previous section, we also need an approach that still calls .backward() on each batch to free the activation memory. How can we do this? Actually it's not so bad: we just need to all-reduce the number of tokens before our final gradient normalization and let the data parallel backward hooks take care of the rest. More explicitly:

First batch:
On rank 1: calculate loss L11 based on n11 tokens
On rank 2: calculate loss L21 based on n21 tokens
Call backward -> this triggers a grad sync and each rank now has the (local) grads for L11 + L21

Second batch:
On rank 1: calculate loss L12 based on n12 tokens
On rank 2: calculate loss L22 based on n22 tokens
Call backward -> the new local grads accumulate on top of the L11 + L21 grads already stored, then another sync fires so that each rank now has the grads for L11 + L21 + L12 + L22

All that's left is n11 + n21 + n12 + n22; fortunately this is just an all-reduce on running_num_tokens. Then we manually scale the gradients just like in the single-device case.

In summary, we wind up with the following process (a rough code sketch follows the list):

For each batch:

  • Call backward on the unnormalized losses
  • Keep a running tally of the number of tokens seen (per rank)

Before optimizer step:

  • [data parallel only] Reduce the running tally of number of tokens over all ranks
  • Scale gradients by 1 / total_num_tokens
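And a hedged sketch of the distributed version (illustrative only, not the exact recipe code; it reuses the hypothetical scale_grads from earlier and makes the same assumptions about loss_fn and the ignore index):

import torch
import torch.distributed as dist

def grad_accum_step_distributed(model, optimizer, batches, loss_fn, world_size):
    running_num_tokens = torch.tensor(0, device="cuda")
    for tokens, labels in batches:  # the batches that make up one optimizer step
        running_num_tokens += (labels != -100).sum()
        loss = loss_fn(model(tokens), labels)  # unnormalized (summed) CE
        loss.backward()  # frees activations; DP hooks sync/accumulate grads

    # Total tokens over all batches *and* all ranks
    dist.all_reduce(running_num_tokens)

    # If the DP wrapper averages grads over ranks (the DDP/FSDP default), the extra
    # world_size factor turns that average back into a sum before normalizing by the
    # global token count -- this matches the world_size / num_tokens factor in the diff.
    scale_grads(model, world_size / running_num_tokens)

    optimizer.step()
    optimizer.zero_grad(set_to_none=True)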


pytorch-bot bot commented Oct 29, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1917

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit a878829 with merge base e99b890:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Oct 29, 2024
@ebsmothers ebsmothers marked this pull request as draft October 29, 2024 13:59
@@ -0,0 +1,14 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Contributor:

Does this really need its own file?

Contributor:

No

Contributor Author:

Where do you want to put it then? Otherwise I'm going to copy-paste this into every recipe, which is worse imo

@@ -722,7 +732,7 @@ def train(self) -> None:
# Update the number of steps when the weights are updated
self.global_step += 1

- loss_to_log = loss.item()
+ loss_to_log = running_loss.item() / num_tokens
Contributor Author:

Should probably normalize by local_num_tokens?

Contributor Author:

Update: I am probably gonna keep it like this since it should be representative of the loss we are actually using to step (even though it means our loss curves will look slightly different than they do today)

Contributor:

I think it makes sense. Will it break all the regression tests, though?

loss.backward()
local_num_tokens = num_tokens.detach().clone()
torch.distributed.all_reduce(num_tokens)
training.scale_grads(self._model, self._world_size / num_tokens)
Contributor:

There are so many lines taking care of the all_reduce, backward, etc. that it makes me wonder if this should be a utility.

Contributor Author:

Yeah, maybe. In this case I feel like it's important enough (and tricky enough) logic to be done very explicitly. Whatever route we go, I will ultimately make it more explicit what's happening here.



@contextlib.contextmanager
def no_sync(model: nn.Module) -> Generator[None, None, None]:
Contributor:

The name could be more descriptive, maybe no_grad_sync.

@ebsmothers ebsmothers changed the title [very wip] restore backward after each batch for grad accum Restore backward after each batch for grad accum Oct 30, 2024
@ebsmothers ebsmothers marked this pull request as ready for review October 31, 2024 00:48
@ebsmothers:

cc @andrewor14 for review of the QAT recipe changes

@ebsmothers:

also cc @lindawangg for the KD recipe changes

@felipemello1:

lgtm

@lindawangg left a comment:

KD changes look good to me
