8-bit Adam #463
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/463
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit d86ec5e with merge base d1e15b4.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Regarding the LR scheduler limitation, perhaps @mlazos knows what might be going on.
@gau-nernst are you compiling the optimizer? If not, the LR does not need to be a tensor. Can you share a profile? I'm not sure that copying a scalar to cuda mem should have much impact on performance E2E. It's possible, though, that we are launching more kernels than necessary within the scheduler itself; this could be rectified by converting to a scalar before performing the calculations within the scheduler.
# if it is a python float. practically, only lr is changed during training.
# NOTE: if lr is changed at every step, moving lr to CUDA will be a bottleneck.
if not isinstance(group["lr"], Tensor):
    group["lr"] = torch.tensor(group["lr"], device=p.device)
If I understand this code correctly, the code right now initializes lr as a Tensor on the device of the current param just once per group, yes? And this fixes the recompile problem when lr changes value?
This assumes that all params in one param group have the same device, yes?
> If I understand this code correctly, the code right now initializes lr as a Tensor on the device of the current param just once per group, yes? And this fixes the recompile problem when lr changes value?

Yes.

> This assumes that all params in one param group have the same device, yes?

Yes. I hope this is a reasonable assumption. In what scenario would this assumption not be valid? e.g. FSDP?
train.py (Outdated)
@@ -0,0 +1,201 @@
# pip install timm wandb tqdm datasets |
let's put this in the benchmarks folder
Moved to benchmarks/benchmark_adam_8bit.py
# this will work with any optim state tensor subclass that implements aten.lerp.Scalar and aten.copy_.default
@torch.compile(fullgraph=True, dynamic=True)
@mlazos I'm compiling the optimizer step for each param here (called vertical fusion?). This is necessary to fuse dequant+adam_step+quant together.
@msaroufim the profile is quite big, 200MB when I ran it (or I can try compressing it). Perhaps you can run and get the profile and share it with @mlazos internally? The following command will profile the first 50 iterations; lmk if you have problems running it:
python train.py \
--model "timm/vit_base_patch16_224.augreg_in21k" \
--amp bf16 \
--optim AdamDTQ8bit \
--profile
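As a side note on the vertical-fusion point above, here is a rough, hypothetical sketch of what a compiled per-parameter step along these lines could look like. The function name, signature, and update math are illustrative only, not the actual implementation in this PR; the 8-bit state handling is assumed to be done by the tensor subclass via aten.lerp.Scalar and aten.copy_.default as noted in the diff.

```python
import torch
from torch import Tensor

# Hypothetical sketch (not the PR's code): one optimizer step for a single
# parameter, compiled so that dequant + adam update + re-quant can fuse.
@torch.compile(fullgraph=True, dynamic=True)
def single_param_adam_sketch(
    p: Tensor,
    grad: Tensor,
    step: Tensor,        # step count kept as a tensor
    exp_avg: Tensor,     # optim state; may be an 8-bit quantized tensor subclass
    exp_avg_sq: Tensor,  # optim state; may be an 8-bit quantized tensor subclass
    lr: Tensor,          # LR as a 0-dim tensor so changing its value doesn't recompile
    beta1: float,
    beta2: float,
    eps: float,
) -> None:
    # Out-of-place lerp reads (dequantizes) the state into a regular tensor...
    new_exp_avg = exp_avg.lerp(grad, 1 - beta1)
    new_exp_avg_sq = exp_avg_sq.lerp(grad.square(), 1 - beta2)
    # ...and copy_ writes (re-quantizes) the updated values back in place.
    exp_avg.copy_(new_exp_avg)
    exp_avg_sq.copy_(new_exp_avg_sq)

    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    denom = (new_exp_avg_sq / bias_correction2).sqrt() + eps
    # Assumed to be called under torch.no_grad() from the optimizer's step().
    p -= (lr / bias_correction1) * new_exp_avg / denom
```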
train.py (Outdated)
grad_scaler.scale(loss).backward()

if args.cosine_lr_scheduler:
    lr = lr_schedule.get_lr(step)
@mlazos @msaroufim To add context, this is the way I normally do LR schedules: calculate the LR (as a python float) and set it directly on the param groups. I don't use the LR schedulers from torch.optim because I prefer my LR schedule to be stateless (instead of stateful like torch.optim.lr_scheduler.LRScheduler, which keeps track of the step count internally).
LR schedulers from timm are also stateless; they calculate the LR as a python float and should have the same problem:
https://github.com/huggingface/pytorch-image-models/blob/20fe56bd9072af61d9f5404ce8b08e24ff10a807/timm/scheduler/cosine_lr.py#L81-L109
https://github.com/huggingface/pytorch-image-models/blob/20fe56bd9072af61d9f5404ce8b08e24ff10a807/timm/scheduler/scheduler.py#L77-L98
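For concreteness, a minimal sketch of the stateless style described above (the class and the loop are illustrative, not the code in train.py):

```python
import math

class CosineSchedule:
    """Stateless cosine LR schedule: the LR is a pure function of the step count."""
    def __init__(self, base_lr: float, total_steps: int, warmup_steps: int = 0):
        self.base_lr = base_lr
        self.total_steps = total_steps
        self.warmup_steps = warmup_steps

    def get_lr(self, step: int) -> float:
        if step < self.warmup_steps:
            return self.base_lr * step / self.warmup_steps
        progress = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
        return 0.5 * self.base_lr * (1 + math.cos(math.pi * progress))

lr_schedule = CosineSchedule(base_lr=1e-4, total_steps=10_000)
# in the training loop: compute a python float and assign it to the param groups
# for step in range(total_steps):
#     lr = lr_schedule.get_lr(step)
#     for param_group in optim.param_groups:
#         param_group["lr"] = lr
```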
train.py
Outdated
if args.cosine_lr_scheduler: | ||
lr = lr_schedule.get_lr(step) | ||
for param_group in optim.param_groups: | ||
param_group["lr"] = lr |
I will try changing this to param_group["lr"].copy_(lr) instead. Maybe it helps.
Yes, this was going to be my suggestion. Where does the lr get wrapped in a tensor after getting retrieved from the scheduler? I don't see it in this code. If you were calling torch.tensor(lr, ...) this would cause an allocation of a scalar on every iteration, which is probably not good. copy_ is a better solution since it will populate the existing memory with the current value, which should have minimal impact on performance.
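To make the two options concrete, a small illustrative sketch (variable names are made up; this is not code from the PR):

```python
import torch

device = "cuda"  # assumes a CUDA device, matching where the params live
schedule = [1e-3 * (0.99 ** step) for step in range(3)]  # stand-in for a real schedule

# Option A: re-wrap the python float every step.
# Each call allocates a new 0-dim tensor and does a host-to-device copy.
for lr in schedule:
    lr_tensor = torch.tensor(lr, device=device)

# Option B: allocate the tensor once, then update it in place each step.
# copy_ with a python scalar reuses the existing memory, so no new allocation.
lr_tensor = torch.tensor(schedule[0], device=device)  # e.g. done once in the optimizer
for lr in schedule:
    lr_tensor.copy_(lr)
```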
Wrapping the LR as a tensor is done on-the-fly inside the optimizer (you can scroll up to see where Jane commented).
I don't know if doing the LR schedule like this is common; it's how I normally do it for my projects. We can state this as a known limitation (and the trick to solve it is as you outlined here, once I confirm it helps).
Doesn't seem to help much, though the slowdown is less than I remembered.
- Without LR schedule: 6.99 it/s
- With LR schedule using tensor.copy_(lr) (update every step): 6.72 it/s
- With LR schedule using python float (update every step): 6.73 it/s
Hmm interesting, yeah an E2E 4% slowdown does seem like a bit much; it's not totally unexpected, because I do expect some slowdown. Perhaps the profile will shed more light on this. A screenshot of the relevant section is also fine btw, or perhaps you can narrow the region to the optimizer + LR scheduler and share that to lower the size. Also try running a single iteration and seeing if it's still large.
trace.json.gz
I gzip-ed the file so it's more manageable. Lmk if you prefer another format (for security reasons).
To my beginner's eyes, most of the optimizer's time is spent on aten::to, aten::copy and cuda stream synchronize.
🚀
To fine-tune a pre-trained ViT-Base on the resisc45 dataset with BF16 AMP, using the default Adam optimizer from PyTorch core:
- To use the bnb 8-bit optimizer, set --optim AdamBnb8bit. To use the 8-bit optimizer implemented in this PR, set --optim AdamDTQ8bit.
- To use wandb logging, set --project AdamInt8 and --run_name vit_base_bf16_amp (change as needed).
- To profile and export a chrome trace, set --profile.
- To enable the cosine learning rate scheduler, set --cosine_lr_scheduler.

Known limitation: when the learning rate is updated every step (e.g. using the cosine learning rate scheduler), training speed decreases significantly. This is because we have to convert the learning rate to a CUDA tensor (which incurs an expensive memory transfer cost), since torch.compile() will treat a Python float as a constant and trigger a recompile whenever the value changes.
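As an aside, here is a small standalone sketch (not from this PR) of the recompile behaviour described above; the exact behaviour depends on the PyTorch version and dynamo settings:

```python
import torch

# Run with TORCH_LOGS="recompiles" to see when dynamo recompiles.

@torch.compile
def apply_lr_float(p, lr):   # lr is a python float: specialized as a constant
    return p - lr * p

@torch.compile
def apply_lr_tensor(p, lr):  # lr is a 0-dim tensor: a regular graph input
    return p - lr * p

p = torch.randn(8)

# Changing the float value can trigger a recompile for each new value.
for lr in (1e-3, 9e-4, 8e-4):
    apply_lr_float(p, lr)

# Updating the tensor in place reuses the same compiled graph.
lr_t = torch.tensor(1e-3)
for lr in (1e-3, 9e-4, 8e-4):
    lr_t.copy_(lr)
    apply_lr_tensor(p, lr_t)
```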