Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

8-bit Adam #463

Merged
merged 46 commits into from
Jul 3, 2024
Merged

8-bit Adam #463

merged 46 commits into from
Jul 3, 2024

Conversation

gau-nernst
Copy link
Collaborator

To fine-tune a pre-trained ViT-Base on resisc45 dataset with BF16 AMP, using default Adam optimizer from PyTorch core

python train.py \
  --model "timm/vit_base_patch16_224.augreg_in21k" \
  --amp bf16 \
  --optim Adam

To use bnb 8-bit optimizer, set --optim AdamBnb8bit. To use 8-bit optimizer implemented in this PR, set --optim AdamDTQ8bit.

Adam impl max memory (GB) training time accuracy
PyTorch 5.26 9m 11s 93.62%
bnb 8-bit 4.78 9m 10s 93.06%
ao 8-bit 4.78 9m 15s 94.14%

To use wandb logging, set --project AdamInt8 and --run_name vit_base_bf16_amp (change as needed).
To profile and export chrome trace, set --profile
To enable cosine learning rate scheduler, set --cosine_lr_scheduler

Known limitation: when learning rate is updated every step (e.g. using cosine learning rate scheduler), training speed decreases significantly. This is because we have to convert learning rate to a CUDA tensor (which incurs expensive memory transfer cost), since torch.compile() will treat a Python float as a constant and trigger recompile whenever the value is changed.

Copy link

pytorch-bot bot commented Jul 2, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/463

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit d86ec5e with merge base d1e15b4 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jul 2, 2024
@msaroufim
Copy link
Member

Regarding the LR scheduler limitation perhaps @mlazos knows what might be going on

@mlazos
Copy link

mlazos commented Jul 2, 2024

@gau-nernst are you compiling the optimizer? If not, the LR does not need to be a tensor. Can you share a profile? I'm not sure that copying a scalar to cuda mem should have much impact on performance E2E. Although it's possible that we are launching more kernels than necessary within the scheduler itself - this could possibly be rectified by converting to scalar before performing the calculations within the scheduler.

# if it is a python float. practically, only lr is changed during training.
# NOTE: if lr is change at every step, moving lr to CUDA will be a bottleneck.
if not isinstance(group["lr"], Tensor):
group["lr"] = torch.tensor(group["lr"], device=p.device)
Copy link
Contributor

@janeyx99 janeyx99 Jul 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand this code correctly, the code right now initiates lr as a Tensor on the device of the current param just once per group, yes? And this fixes the recompile problem when lr changes value?

This assumes that all params in one param group have the same device, yes?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand this code correctly, the code right now initiates lr as a Tensor on the device of the current param just once per group, yes? And this fixes the recompile problem when lr changes value?

Yes.

This assumes that all params in one param group have the same device, yes?

Yes. I hope this is a reasonable assumption. In what scenario this assumption would not be valid? e.g. FSDP?

train.py Outdated
@@ -0,0 +1,201 @@
# pip install timm wandb tqdm datasets
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's put this in the benchmarks folder

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved to benchmarks/benchmark_adam_8bit.py



# this will work with any optim state tensor subclass that implements aten.lerp.Scalar and aten.copy_.default
@torch.compile(fullgraph=True, dynamic=True)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mlazos I'm compiling the optimizer step for each param here (called vertical fusion?). This is necessary to fuse dequant+adam_step+quant together.

@msaroufim the profile is quite big, 200MB when I ran it (or I can try compressing it). Perhaps you can run and get the profile and share with @mlazos internally? The following code will profile the first 50 iterations. Lmk if you have problems running the code

python train.py \
  --model "timm/vit_base_patch16_224.augreg_in21k" \
  --amp bf16 \
  --optim AdamDTQ8bit \
  --profile

train.py Outdated
grad_scaler.scale(loss).backward()

if args.cosine_lr_scheduler:
lr = lr_schedule.get_lr(step)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mlazos @msaroufim To add context, this is the way I normally do LR schedule. Calculate LR (as a python float) and set it directly to param groups. I don't use LR scheduler from torch.optim because I prefer my LR schedule to be stateless (instead of stateful like in torch.optim.lr_scheduler.LRScheduler, which keeps track of step count internally).

LR schedulers from timm are also stateless, calculate LR as python float and should have the same problem
https://github.com/huggingface/pytorch-image-models/blob/20fe56bd9072af61d9f5404ce8b08e24ff10a807/timm/scheduler/cosine_lr.py#L81-L109
https://github.com/huggingface/pytorch-image-models/blob/20fe56bd9072af61d9f5404ce8b08e24ff10a807/timm/scheduler/scheduler.py#L77-L98

train.py Outdated
if args.cosine_lr_scheduler:
lr = lr_schedule.get_lr(step)
for param_group in optim.param_groups:
param_group["lr"] = lr
Copy link
Collaborator Author

@gau-nernst gau-nernst Jul 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will try changing this to param_group["lr"].copy_(lr) instead. Maybe it helps.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes this was going to be my suggestion, where does the lr get wrapped in a tensor after getting retrieved from the scheduler? I don't see it in this code. If you were calling torch.tensor(lr, ..) this will cause an allocation of a scalar on every iteration, which is probably not good. copy_ is a better solution since it will populate the existing memory with the current value, which should have minimal impact on the performance.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wrapping LR as a tensor is done on-the-fly inside the optimizer (you can scroll up to see where Jane commented).
I don't know if doing LR schedule like this is common. It's how I normally do it for my projects. We can state this as a known limitation (and the trick to solve is as you outlined here - once I confirm it helps).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't seem to help much, though the slowdown is less than I rmb.

  • Without LR schedule: 6.99 it/s
  • With LR schedule using tensor.copy_(lr) (update every step): 6.72 it/s
  • With LR schedule using python float (update every step): 6.73 it/s

Copy link

@mlazos mlazos Jul 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm interesting, yeah an E2E 4% slowdown does seem like a little much, not totally unexpected because I do expect some slowdown. Perhaps the profile will shed more light on this. A screenshot of the relevant section is also fine too btw, or perhaps you can narrow the region to the optimizer + lr scheduler and share it to lower the size? Also try running a single iteration and seeing if it's still large

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

trace.json.gz
I gzip-ed the file, so it's more manageable. Lmk if you prefer other format instead (for security reasons).
From my beginner's eyes, most of the optimizer's time is spent on aten::to, aten::copy and cuda stream synchronize.

@gau-nernst gau-nernst marked this pull request as ready for review July 3, 2024 00:54
@msaroufim msaroufim self-requested a review July 3, 2024 03:16
Copy link
Member

@msaroufim msaroufim left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚀

@msaroufim msaroufim merged commit 739952b into pytorch:main Jul 3, 2024
13 checks passed
@gau-nernst gau-nernst deleted the 8bit_adam branch July 3, 2024 04:01
dbyoung18 pushed a commit to dbyoung18/ao that referenced this pull request Jul 31, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants