Give example on how to handle gradient accumulation with cross-entropy #3193
base: main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Looks great! (We still need to do the gather() + division by the number of processes in the Trainer.)
Left a few nits. I think it'd be really cool if we could show full training graphs: after doing stuff with FP8, I don't fully trust just taking "the end result is the same" at face value :)
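Not the Trainer code itself, just a rough sketch of the gather step being referred to, assuming each rank has already counted its local non-padded tokens (all names and numbers are illustrative):

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()

# Each rank counts its own non-padded target tokens, then we gather and sum
# across ranks so every process normalizes its loss by the same global total.
# The extra per-process scaling mentioned above is deliberately left out here.
local_num_items = torch.tensor([36], device=accelerator.device)  # pretend local count
total_num_items = accelerator.gather(local_num_items).sum().item()
```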
Results on a single device:
```
initial model weight is tensor([-0.0075, 0.5364])
initial model clone weight is tensor([-0.0075, 0.5364])
Step 0 - Device 0 - num items in the local batch 36
Total num items 36
Device 0 - w/ accumulation, the final model weight is tensor([0.0953, 0.4337])
w/o accumulation, the final model weight is tensor([0.0953, 0.4337])
```
Results on a two-device setup:
```
initial model weight is tensor([-0.0075, 0.5364])
initial model clone weight is tensor([-0.0075, 0.5364])
Step 0 - Device 0 - num items in the local batch 52
Step 0 - Device 1 - num items in the local batch 84
Total num items 136
Device 1 - w/ accumulation, the final model weight is tensor([0.2117, 0.3172])
Device 0 - w/ accumulation, the final model weight is tensor([0.2117, 0.3172])
w/o accumulation, the final model weight is tensor([0.2117, 0.3172])
```
Honestly, if we can, let's even toss up some wandb graphs 🔥
Indeed, it'd be great, but here we only do a single global batch, so I don't think it's worth adding a graph. Maybe I should modify the current code snippet to run multiple global steps?
Or add some wandb graphs from the upcoming modification of examples/by_feature/gradient_accumulation?
model_optimizer.zero_grad()

logger.warning(f"Device {accelerator.process_index} - w/ accumulation, the final model weight is {accelerator.unwrap_model(model).weight.detach().cpu().squeeze()}", main_process_only=False)
Rather than logger.warning, we can do print() here or change the default logging level :) (Using logging.warning rather than logging.info just weirds me out.)
Nice job @ylacombe! Left a few suggestions!
num_samples_in_epoch = len(dataloader)
remainder = num_samples_in_epoch % gradient_accumulation_steps
remainder = remainder if remainder != 0 else gradient_accumulation_steps
total_gradient_updates = math.ceil(num_samples_in_epoch / gradient_accumulation_steps)

total_batched_samples = 0
for update_step in range(total_gradient_updates):
    # In order to correctly compute the total number of non-padded tokens on which we'll compute the cross-entropy loss
    # we need to pre-load the full local batch - i.e. the next per_device_batch_size * accumulation_steps samples
    batch_samples = []
    num_batches_in_step = gradient_accumulation_steps if update_step != (total_gradient_updates - 1) else remainder
    for _ in range(num_batches_in_step):
        batch_samples += [next(training_iterator)]
This only works when we know the size of the dataloader. Can we think of a solution that doesn't require this information? I think we can just iterate on the dataloader until we have gradient_accumulation_steps batches to create batch_samples; if we can't iterate anymore, we stop there too. I think that code will be easier to understand.
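A hypothetical sketch of that suggestion (the helper name and the loop shape are mine, not from the PR):

```python
from typing import Any, Iterator, List

def next_accumulation_batches(data_iter: Iterator[Any], gradient_accumulation_steps: int) -> List[Any]:
    """Pull up to `gradient_accumulation_steps` micro-batches from the iterator.

    Returns a shorter (possibly empty) list once the iterator is exhausted, so the
    caller never needs to know len(dataloader) up front.
    """
    batches: List[Any] = []
    for _ in range(gradient_accumulation_steps):
        try:
            batches.append(next(data_iter))
        except StopIteration:
            break
    return batches

# usage sketch:
# training_iterator = iter(dataloader)
# while batch_samples := next_accumulation_batches(training_iterator, gradient_accumulation_steps):
#     ...  # count the non-padded tokens in batch_samples, then run the micro-batches
```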
Yes agreed :) (What we do in the Trainer)
# Since we performed prefetching, we need to manually set sync_gradients
if total_batched_samples % gradient_accumulation_steps != 0:
    accelerator.gradient_state._set_sync_gradients(False)
else:
    accelerator.gradient_state._set_sync_gradients(True)
The issue here is due to end_of_dataloader, which incorrectly modifies sync_gradient because of the prefetching.
Maybe we can add an option to disable do_sync in accumulate? That way we won't have to put this specific piece of code under accumulate, and the user will have total control over when the gradients are synced. cc @muellerzr
Agreed, you can put it as part of this PR if you want @ylacombe
I hadn't taken this case into account, but we should also set sync_gradient=True when reaching the very last total_batched_samples, btw.
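Purely illustrative, reusing the same private `_set_sync_gradients` call as the snippet above; one way to fold in that last-micro-batch case could be a small predicate like this (the names are made up):

```python
def should_sync(total_batched_samples: int, num_samples_in_epoch: int, gradient_accumulation_steps: int) -> bool:
    """Sync when an accumulation window is complete, or when this is the very
    last micro-batch of the epoch (a possibly shorter remainder window)."""
    window_complete = total_batched_samples % gradient_accumulation_steps == 0
    last_of_epoch = total_batched_samples == num_samples_in_epoch
    return window_complete or last_of_epoch

# usage sketch, mirroring the snippet above:
# accelerator.gradient_state._set_sync_gradients(
#     should_sync(total_batched_samples, num_samples_in_epoch, gradient_accumulation_steps)
# )
```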
Not sure I understand your point exactly @SunMarc: reaching end_of_dataloader also sets accelerator.step to 0. If I disable it, we'd have issues when saving the accelerator state, right?
I think it should be fine, as we don't care about step in the Trainer either. cc @muellerzr, but we can leave that for a follow-up PR if you want!
A follow-up would be fine by me :)
Results on a single device:
Maybe we can specify the exact setup? I think we are doing the following:
- dp=1 grad_acc=2 batch_size=4 vs dp=1 grad_acc=1 batch_size=8?

If we are only doing one update, then we won't be able to get a graph. Maybe we could do this on a larger dataset where batch_size != len(data_loader) and add the graphs.
Results on a two-device setup:
On a two-device setup, the modification you made to take dp into account won't be reflected here, since we are only changing grad_acc and batch_size, so the loss will be the same regardless. However, it's nice to see that the total_num_items really changed:
- dp=2 grad_acc=2 batch_size=4 vs dp=2 grad_acc=1 batch_size=8

Maybe we should do a separate section/experiment to show that the following have the same loss graph:
- dp=2 batch_size=2 is the same as dp=1 batch_size=4. See this experiment for clarification.
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
def test_gradient_accumulation_for_autoregressive_models(self):
    testargs = ["examples/by_feature/gradient_accumulation_for_autoregressive_models.py"]
    run_command(self.launch_args + testargs)
Just a nit: this doesn't use gradient accumulation here since it uses the default of 1
"--per_device_batch_size", | ||
type=int, | ||
default=2, | ||
help="The number of minibatches to be ran before gradients are accumulated.", |
Shouldn't this be "The size of each minibatch"?
help="The number of minibatches to be ran before gradients are accumulated.", | |
help="The size of each minibatch", |
What does this PR do?
Following the recent highlights on how gradient accumulation with the cross-entropy loss is usually done incorrectly, it would be great to have this mentioned in the docs. I've thus added some code and an explanation of it in the gradient accumulation page.
cc @SunMarc and @muellerzr, let me know what you think of it or if I can make this any clearer!
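For anyone reading this thread without the diff open, here is a minimal single-device sketch of the idea being documented (illustrative shapes and values, not the code added to the page): compute the cross-entropy with reduction="sum" and divide by the number of non-padded tokens in the whole accumulated batch, instead of averaging each micro-batch independently.

```python
import torch
import torch.nn.functional as F

# Illustrative only: two micro-batches that together form one global batch.
logits_chunks = [torch.randn(4, 10, requires_grad=True) for _ in range(2)]
target_chunks = [torch.randint(0, 10, (4,)) for _ in range(2)]
pad_id = -100  # targets equal to this id are ignored and not counted

# Normalize by the token count of the whole accumulated batch, not per micro-batch.
total_num_items = sum((t != pad_id).sum().item() for t in target_chunks)

for logits, targets in zip(logits_chunks, target_chunks):
    loss = F.cross_entropy(logits, targets, ignore_index=pad_id, reduction="sum") / total_num_items
    loss.backward()  # gradients accumulate, matching a single pass over the full batch
```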