8-bit Adam #463


Merged
46 commits merged on Jul 3, 2024

Commits
7ed135a
add skeleton
gau-nernst Jun 26, 2024
69bf622
add quant and dequant
gau-nernst Jun 27, 2024
817be03
add enough ops coverage to work with Adam
gau-nernst Jun 27, 2024
da8d9be
add _dequant_list
gau-nernst Jun 27, 2024
1a98c71
simplify
gau-nernst Jun 27, 2024
3efc999
Merge branch 'pytorch:main' into 8bit_adam
gau-nernst Jun 27, 2024
72e00f6
add adam int8
gau-nernst Jun 27, 2024
f522ea2
flatten uint8 storage
gau-nernst Jun 27, 2024
777ae26
fully modified Adam
gau-nernst Jun 27, 2024
dc0bf6b
Merge branch 'pytorch:main' into 8bit_adam
gau-nernst Jun 30, 2024
2c20f63
update
gau-nernst Jun 30, 2024
6e5bfd2
clean
gau-nernst Jun 30, 2024
4d96360
more cleanup
gau-nernst Jun 30, 2024
e7e956a
update train.py
gau-nernst Jun 30, 2024
a7061bc
update train.py
gau-nernst Jun 30, 2024
139642c
fix
gau-nernst Jun 30, 2024
c44eca6
add device to state['step']
gau-nernst Jul 1, 2024
26821ab
update adam to avoid graph break
gau-nernst Jul 1, 2024
81c79dd
fix code dtype
gau-nernst Jul 1, 2024
c0e00c4
add note
gau-nernst Jul 1, 2024
2efa8cf
optimize copy
gau-nernst Jul 1, 2024
8f268ee
return
gau-nernst Jul 1, 2024
290209d
rename folder
gau-nernst Jul 1, 2024
2820aa7
Merge branch 'pytorch:main' into 8bit_adam
gau-nernst Jul 2, 2024
f242b97
add binary search impl
gau-nernst Jul 2, 2024
db62f98
move state increment outside adam step
gau-nernst Jul 2, 2024
be835a0
remove unused import
gau-nernst Jul 2, 2024
7628b52
add a version with torch._fused_adam_()
gau-nernst Jul 2, 2024
090f5d7
support fp32 state
gau-nernst Jul 2, 2024
7cdef75
make wandb optional
gau-nernst Jul 2, 2024
37fc9a9
switch quantize impl. print val_acc to stdout
gau-nernst Jul 2, 2024
cc99cd2
add profile flag
gau-nernst Jul 2, 2024
a8222c0
fix adam v2
gau-nernst Jul 2, 2024
4c88f90
make LR schedule optional. fix data transfer
gau-nernst Jul 2, 2024
a38a833
add weight decay to the kernel
gau-nernst Jul 2, 2024
72af512
some formatting
gau-nernst Jul 2, 2024
cfbf7e5
move file
gau-nernst Jul 2, 2024
2c7337b
remove impl using torch._fused_adam_
gau-nernst Jul 2, 2024
19b4f99
add tests. fix default values
gau-nernst Jul 3, 2024
c0cd149
move file
gau-nernst Jul 3, 2024
b4582de
add AdamW
gau-nernst Jul 3, 2024
3b3e785
fix Optional
gau-nernst Jul 3, 2024
28a66c7
add README
gau-nernst Jul 3, 2024
1f8eff0
Merge branch 'pytorch:main' into 8bit_adam
gau-nernst Jul 3, 2024
bb436e6
skip test for pytorch < 2.3
gau-nernst Jul 3, 2024
d86ec5e
rename to more user-friendly
gau-nernst Jul 3, 2024
211 changes: 211 additions & 0 deletions benchmarks/benchmark_adam_8bit.py
@@ -0,0 +1,211 @@
# pip install timm wandb tqdm datasets bitsandbytes
# To fine-tune a pre-trained ViT-Base on the resisc45 dataset with BF16 AMP, using the default Adam optimizer from PyTorch core:
#
# python benchmark_adam_8bit.py \
#     --model "timm/vit_base_patch16_224.augreg_in21k" \
#     --amp bf16 \
#     --optim Adam
#
# To use the bnb 8-bit optimizer, set --optim AdamBnb8bit. To use the 8-bit optimizer implemented in torchao, set --optim AdamDTQ8bit.
# To profile and export a Chrome trace, set --profile.
# To enable the cosine learning rate scheduler, set --cosine_lr_scheduler.

import argparse
import math
from contextlib import nullcontext
from pathlib import Path

import bitsandbytes as bnb
import datasets
import timm
import torch
import torch.nn.functional as F
from torch.profiler import ProfilerActivity, profile
from torch.utils.data import DataLoader
from torchvision.transforms import v2
from tqdm import tqdm

from torchao.prototype.optim_8bit import AdamDTQ8bit


class CosineSchedule:
    def __init__(self, lr: float, total_steps: int, warmup: float = 0.05) -> None:
        self.lr = lr
        self.final_lr = 0
        self.total_steps = total_steps
        self.warmup_steps = round(total_steps * warmup)

    def get_lr(self, step: int) -> float:
        if step < self.warmup_steps:
            return self.lr * step / self.warmup_steps
        if step < self.total_steps:
            progress = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
            return self.final_lr + 0.5 * (self.lr - self.final_lr) * (1 + math.cos(progress * math.pi))
        return self.final_lr


class WandbLogger:
    def __init__(self, args):
        if args.project is not None and not args.profile:
            import wandb

            Path("wandb_logs").mkdir(exist_ok=True)
            self.run = wandb.init(project=args.project, name=args.run_name, config=args, dir="wandb_logs")
        else:
            self.run = None

    def log(self, *args, **kwargs):
        if self.run is not None:
            self.run.log(*args, **kwargs)


def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True)

    parser.add_argument("--amp", default="none")
    parser.add_argument("--channels_last", action="store_true")
    parser.add_argument("--compile", action="store_true")

    parser.add_argument("--n_epochs", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--n_workers", type=int, default=4)

    parser.add_argument("--optim", default="Adam")
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--weight_decay", type=float, default=0)
    parser.add_argument("--cosine_lr_scheduler", action="store_true")

    parser.add_argument("--project")
    parser.add_argument("--run_name", default="debug")
    parser.add_argument("--profile", action="store_true")
    return parser


def get_dloader(args, training: bool):
    transforms = [v2.ToImage()]

    if training:
        transforms.extend([v2.RandomResizedCrop(224), v2.RandomHorizontalFlip()])
    else:
        transforms.extend([v2.Resize(256), v2.CenterCrop(224)])

    transforms.append(v2.ToDtype(torch.float32, scale=True))
    transforms.append(v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
    transforms = v2.Compose(transforms)

    # use dataset from HF so download is fast
    ds = datasets.load_dataset("timm/resisc45", split="train" if training else "validation")
    ds = ds.select_columns(["image", "label"])
    ds.set_transform(lambda x: dict(image=transforms(x["image"]), label=x["label"]))

    return DataLoader(
        ds,
        batch_size=args.batch_size,
        shuffle=training,
        num_workers=args.n_workers,
        pin_memory=training,
        drop_last=training,
    )


def get_amp_ctx(amp):
    dtype = dict(bf16=torch.bfloat16, fp16=torch.float16, none=None)[amp]
    return torch.autocast("cuda", dtype=dtype, enabled=amp != "none")


@torch.no_grad()
def evaluate_model(model, args):
    model.eval()
    val_dloader = get_dloader(args, False)

    all_labels = []
    all_preds = []

    for batch in tqdm(val_dloader, dynamic_ncols=True, desc="Evaluating"):
        all_labels.append(batch["label"].clone())
        if args.channels_last:
            batch["image"] = batch["image"].to(memory_format=torch.channels_last)

        with get_amp_ctx(args.amp):
            all_preds.append(model(batch["image"].cuda()).argmax(1).cpu())

    all_labels = torch.cat(all_labels, dim=0)
    all_preds = torch.cat(all_preds, dim=0)

    acc = (all_labels == all_preds).float().mean()
    return acc


if __name__ == "__main__":
    args = get_parser().parse_args()

    if args.profile:
        args.n_epochs = 1

    for k, v in vars(args).items():
        print(f"{k}: {v}")

    # wandb is only enabled when args.project is set and args.profile is False
    logger = WandbLogger(args)
    dloader = get_dloader(args, True)
    print(f"Train dataset: {len(dloader.dataset):,} images")

    model = timm.create_model(args.model, pretrained=True, num_classes=45).cuda()
    if args.channels_last:
        model.to(memory_format=torch.channels_last)
    if args.compile:
        model.compile(fullgraph=True)
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    OPTIM_MAP = dict(
        Adam=torch.optim.Adam,
        AdamBnb8bit=bnb.optim.Adam8bit,
        AdamDTQ8bit=AdamDTQ8bit,
    )
    optim = OPTIM_MAP[args.optim](model.parameters(), args.lr, weight_decay=args.weight_decay)
    lr_schedule = CosineSchedule(args.lr, len(dloader) * args.n_epochs)

    grad_scaler = torch.amp.GradScaler("cuda", enabled=args.amp == "fp16")

    step = 0
    for epoch_idx in range(args.n_epochs):
        model.train()
        prof = profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) if args.profile else nullcontext()

        with prof:
            for batch in tqdm(dloader, dynamic_ncols=True, desc=f"Epoch {epoch_idx + 1}/{args.n_epochs}"):
                if args.channels_last:
                    batch["image"] = batch["image"].to(memory_format=torch.channels_last)

                with get_amp_ctx(args.amp):
                    loss = F.cross_entropy(model(batch["image"].cuda()), batch["label"].cuda())
                grad_scaler.scale(loss).backward()

                if args.cosine_lr_scheduler:
                    lr = lr_schedule.get_lr(step)
                    for param_group in optim.param_groups:
                        param_group["lr"] = lr

                if step % 100 == 0:
                    logger.log(dict(loss=loss.item(), lr=optim.param_groups[0]["lr"]), step=step)

                grad_scaler.step(optim)
                grad_scaler.update()
                optim.zero_grad()

                step += 1

                if args.profile and step == 20:
                    break

        if args.profile:
            prof.export_chrome_trace("trace.json")
        else:
            val_acc = evaluate_model(model, args)
            print(f"Epoch {epoch_idx + 1}/{args.n_epochs}: val_acc={val_acc.item() * 100:.2f}")
            logger.log(dict(val_acc=val_acc), step=step)

    print(f"Max memory allocated: {torch.cuda.max_memory_allocated() / (1 << 30):.2f} GB")
87 changes: 87 additions & 0 deletions test/prototype/test_optim_8bit.py
@@ -0,0 +1,87 @@
import copy

import pytest
import torch
from torch import nn
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)
from torchao.prototype.optim_8bit import AdamDTQ8bit, AdamWDTQ8bit
from torchao.prototype.optim_8bit.subclass import quantize_8bit_with_qmap, QMAP_SIGNED
from torchao.utils import TORCH_VERSION_AFTER_2_3

try:
    import bitsandbytes as bnb
except ImportError:
    bnb = None


_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])


class TestDTQ8bit(TestCase):
    @parametrize("device", _DEVICES)
    def test_quantize_8bit_with_qmap_correctness(self, device):
        x = torch.randn(32, 1024, device=device)
        qmap = torch.tensor(QMAP_SIGNED, device=device)

        actual_codes, actual_scale = quantize_8bit_with_qmap(x, qmap, 256, implementation=1)
        expected_codes, expected_scale = quantize_8bit_with_qmap(x, qmap, 256, implementation=0)

        torch.testing.assert_close(actual_codes, expected_codes)
        torch.testing.assert_close(actual_scale, expected_scale)

    @parametrize("device", _DEVICES)
    def test_quantize_8bit_with_qmap_compile(self, device):
        x = torch.randn(32, 1024, device=device)
        qmap = torch.tensor(QMAP_SIGNED, device=device)

        actual_codes, actual_scale = torch.compile(quantize_8bit_with_qmap, fullgraph=True)(x, qmap, 256)
        expected_codes, expected_scale = quantize_8bit_with_qmap(x, qmap, 256)

        torch.testing.assert_close(actual_codes, expected_codes)
        torch.testing.assert_close(actual_scale, expected_scale)


@pytest.mark.skipif(bnb is None, reason="bitsandbytes is not available")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="bitsandbytes 8-bit Adam only works for CUDA")
@pytest.mark.xfail(not TORCH_VERSION_AFTER_2_3, reason="torch.compile() fails for PyTorch < 2.3")
class TestOptim8bit(TestCase):
    @parametrize("optim_cls,bnb_optim_cls", [
        (AdamDTQ8bit, bnb.optim.Adam8bit),
        (AdamWDTQ8bit, bnb.optim.AdamW8bit),
    ])
    def test_adam_8bit_correctness(self, optim_cls, bnb_optim_cls):
        device = "cuda"
        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
        model2 = copy.deepcopy(model1)

        optim1 = bnb_optim_cls(model1.parameters())
        optim2 = optim_cls(model2.parameters())

        for _ in range(2):
            x = torch.randn(4, 32, device=device)

            loss1 = model1(x).sum()
            loss1.backward()
            optim1.step()
            optim1.zero_grad()

            loss2 = model2(x).sum()
            loss2.backward()
            optim2.step()
            optim2.zero_grad()

        for p1, p2 in zip(model1.parameters(), model2.parameters()):
            torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5)


instantiate_parametrized_tests(TestDTQ8bit)
instantiate_parametrized_tests(TestOptim8bit)


if __name__ == "__main__":
run_tests()
1 change: 1 addition & 0 deletions torchao/prototype/README.md
@@ -10,6 +10,7 @@
- `galore/kernels` - `triton` kernels that fuse various steps of the `GaLore` algorithm
- `galore/docs` - implementation notes and discussion of issues faced in kernel design.
- [`quant_llm`](quant_llm) - FP16 x FPx mixed matmul kernel per [FP6-LLM](https://arxiv.org/abs/2401.14112)
- [`optim_8bit`](optim_8bit) - re-implementation of 8-bit optimizers from [bitsandbytes](https://github.com/TimDettmers/bitsandbytes).

#### Roadmap

38 changes: 38 additions & 0 deletions torchao/prototype/optim_8bit/README.md
@@ -0,0 +1,38 @@
# 8-bit optimizers

This folder implements 8-bit optimizers using dynamic tree quantization as outlined in https://arxiv.org/abs/2110.02861. The implementation is done entirely in Python (with a tensor subclass) and relies on `torch.compile()` to generate efficient fused kernels.
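
For intuition, here is a minimal sketch of the block-wise, map-based 8-bit quantize/dequantize round trip that such optimizers rely on. The helper names and the uniform placeholder qmap are illustrative assumptions only; the real implementation uses the dynamic tree code values and a tensor subclass, which are not reproduced here.

```python
import torch


def quantize_with_qmap_sketch(x: torch.Tensor, qmap: torch.Tensor, block_size: int = 2048):
    # Split the flattened tensor into blocks and scale each block by its absmax.
    blocks = x.float().flatten().view(-1, block_size)
    scale = blocks.abs().amax(dim=1).clip(min=1e-12)
    normalized = blocks / scale.view(-1, 1)  # values now lie in [-1, 1]
    # Map each value to a bucket index in the 256-entry qmap via binary search
    # (a crude stand-in for proper round-to-nearest-code).
    codes = torch.bucketize(normalized, qmap).clip(max=255).to(torch.uint8)
    return codes, scale


def dequantize_with_qmap_sketch(codes: torch.Tensor, scale: torch.Tensor, qmap: torch.Tensor):
    # Look up the code values and undo the per-block scaling.
    return qmap[codes.int()] * scale.view(-1, 1)


qmap = torch.linspace(-1, 1, 256)  # placeholder: a uniform grid instead of the dynamic tree values
x = torch.randn(4, 2048)
codes, scale = quantize_with_qmap_sketch(x, qmap)
x_hat = dequantize_with_qmap_sketch(codes, scale, qmap)
# Reconstruction error is bounded by the code spacing times the per-block scale.
print((x.flatten().view(-1, 2048) - x_hat).abs().max())
```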

## Usage

This is a drop-in replacement for `torch.optim.Adam`:

```python
from torchao.prototype.optim_8bit import AdamDTQ8bit

model = ...
optim = AdamDTQ8bit(model.parameters())
```

You can also change the quantization block size (default 2048) by passing `block_size=value` to the optimizer (see the sketch below).

**Other optimizers**: AdamW is also available as `AdamWDTQ8bit`.
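
A minimal sketch of both options, assuming the constructors mirror the `torch.optim` keyword arguments (the benchmark script in this PR passes `lr` and `weight_decay` the same way):

```python
from torchao.prototype.optim_8bit import AdamDTQ8bit, AdamWDTQ8bit

model = ...

# Smaller quantization blocks give finer-grained scales at a small extra memory cost.
optim = AdamDTQ8bit(model.parameters(), lr=1e-4, block_size=1024)

# AdamW variant; weight_decay is assumed to follow the torch.optim.AdamW convention.
optim = AdamWDTQ8bit(model.parameters(), lr=1e-4, weight_decay=1e-2)
```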

NOTE: this requires PyTorch >= 2.3

## Benchmarks

A benchmark script for fine-tuning a [timm](https://github.com/huggingface/pytorch-image-models) model on the [resisc45](https://huggingface.co/datasets/timm/resisc45) dataset is available at [benchmarks/benchmark_adam_8bit.py](../../../benchmarks/benchmark_adam_8bit.py).

Results for fine-tuning ViT-B with BF16 AMP on a 4070Ti SUPER:

Adam impl | max memory (GB) | training time | accuracy
----------|-----------------|---------------|----------
PyTorch | 5.26 | 9m 11s | 93.62%
bnb 8-bit | 4.78 | 9m 10s | 93.06%
ao 8-bit | 4.78 | 9m 15s | 94.14%

**Known issue**: when the learning rate is updated every step (e.g. with a cosine learning rate scheduler), training is slower. This is because we have to convert the learning rate to a CUDA tensor (which incurs an expensive memory transfer), since `torch.compile()` treats a Python float as a constant and triggers a recompile whenever the value changes.
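
A small illustration of the trade-off (an assumed workaround, not part of this folder's API): passing the learning rate as a 0-dim tensor avoids recompilation because `torch.compile()` treats tensors as graph inputs rather than constants, but refreshing that tensor from a Python float each step is exactly the host-to-device copy described above.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"


@torch.compile
def scaled_update(param, grad, lr):
    # With a Python float lr, its value is baked into the compiled graph,
    # so every new value triggers a recompile. A 0-dim tensor is a graph input.
    return param - lr * grad


param = torch.randn(16, device=device)
grad = torch.randn(16, device=device)
lr = torch.tensor(1e-4, device=device)

for step in range(3):
    lr.fill_(1e-4 * (step + 1))  # in-place update: no recompile, but the new
                                 # value still has to be copied to the device
    param = scaled_update(param, grad, lr)
```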

## Credits

Credits to Tim Dettmers for creating the wonderful bitsandbytes library.
2 changes: 2 additions & 0 deletions torchao/prototype/optim_8bit/__init__.py
@@ -0,0 +1,2 @@
from .adam import AdamDTQ8bit
from .adamw import AdamWDTQ8bit