
Commit 1b1e94c

Optimizer CPU offload for single GPU training (#584)
* initial commit
* use fused=True by default for PyTorch adam
* detach param
* try overlap D2H grad copy with backward
* add customizable profile num steps
* add v2
* fix various bugs
* fix v1 impl
* add full BF16 option
* change n_profile_steps to 5
* add v3
* fix gradient accumulation
* add note
* add deepspeed offload
* update deepspeed config
* add some notes
* update instructions. make some packages optional. change to AdamW
* add last updated ordered dict
* update deepspeed version
* remove old versions
* update docs
* say deepspeed is untuned
* add test
* add test for offload_gradients. update benchmark script
* update benchmark results. fix test. fix benchmark script
* fix language
* add save and load
* pre-allocate CPU params. add note about gradient clipping
* update README and remove unused imports
1 parent de4a1fb commit 1b1e94c
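For readers skimming the diff: the commit adds a `CPUOffloadOptimizer` (implemented in `cpu_offload.py`, not shown in this page) that keeps a pinned CPU copy of each parameter, copies gradients device-to-host while backward is still running, runs the base optimizer step on CPU, and copies updated parameters back to the GPU. Below is a minimal sketch of that idea only; the hook-plus-side-stream mechanism and the class shown here are illustrative assumptions (requires PyTorch >= 2.1 for `register_post_accumulate_grad_hook`), not the committed implementation.

```python
import torch


class NaiveCPUOffloadOptimizer:
    """Sketch: optimizer state lives on CPU; grads are copied D2H during backward."""

    def __init__(self, params, optimizer_class=torch.optim.AdamW, **kwargs):
        self.d2h_stream = torch.cuda.Stream()  # side stream so grad copies overlap with backward
        self.gpu_params = [p for p in params if p.requires_grad]
        self.cpu_params = []

        for p_gpu in self.gpu_params:
            # pre-allocated pinned CPU mirrors of the parameter and its gradient
            p_cpu = torch.empty(p_gpu.shape, dtype=p_gpu.dtype, pin_memory=True)
            p_cpu.copy_(p_gpu.detach())
            p_cpu.grad = torch.empty(p_gpu.shape, dtype=p_gpu.dtype, pin_memory=True)
            self.cpu_params.append(p_cpu)

            # fires as soon as backward finishes accumulating this parameter's gradient
            def d2h_copy(param, p_cpu=p_cpu):
                self.d2h_stream.wait_stream(torch.cuda.current_stream())
                with torch.cuda.stream(self.d2h_stream):
                    p_cpu.grad.copy_(param.grad, non_blocking=True)

            p_gpu.register_post_accumulate_grad_hook(d2h_copy)

        # the base optimizer only ever sees the CPU copies
        # (fused AdamW on CPU tensors needs PyTorch 2.4, per the README note below)
        self.base_optim = optimizer_class(self.cpu_params, **kwargs)

    @torch.no_grad()
    def step(self):
        self.d2h_stream.synchronize()  # make sure all gradients have landed on CPU
        self.base_optim.step()         # optimizer step runs on CPU
        for p_gpu, p_cpu in zip(self.gpu_params, self.cpu_params):
            p_gpu.copy_(p_cpu, non_blocking=True)  # H2D copy of updated parameters

    def zero_grad(self, set_to_none=True):
        for p in self.gpu_params:
            p.grad = None
```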

5 files changed: +330 −28 lines changed


benchmarks/benchmark_low_bit_adam.py

Lines changed: 99 additions & 27 deletions
@@ -1,46 +1,60 @@
-# pip install timm wandb tqdm datasets yacs bitsandbytes git+https://github.com/thu-ml/low-bit-optimizers.git
-# To fine-tune a pre-trained ViT-Base on resisc45 dataset with BF16 AMP, using default Adam optimizer from PyTorch core
+# pip install timm wandb tqdm datasets bitsandbytes
 #
+# optional:
+# - lpmm (4-bit optim): pip install yacs git+https://github.com/thu-ml/low-bit-optimizers.git
+# - DeepSpeed (ZeRO-Offload):
+#   sudo apt install libopenmpi-dev
+#   LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu pip install mpi4py
+#   DS_BUILD_CPU_ADAM=1 pip install deepspeed --no-cache-dir
+#
+# To fine-tune a pre-trained ViT-Base on resisc45 dataset with BF16 AMP, using default AdamW optimizer from PyTorch core
 # python benchmark_low_bit_adam.py \
 #   --model "timm/vit_base_patch16_224.augreg_in21k" \
 #   --amp bf16 \
-#   --optim Adam
+#   --optim AdamW
 #
 # See OPTIM_MAP for the available optimizer options
 # To profile and export chrome trace, set --profile
 # To enable cosine learning rate scheduler, set --cosine_lr_scheduler

 import argparse
 import datetime
+import json
 import math
 from contextlib import nullcontext
 from functools import partial
 from pathlib import Path

 import bitsandbytes as bnb
 import datasets
-import lpmm
 import timm
 import torch
 import torch.nn.functional as F
-from torch.profiler import ProfilerActivity, profile
 from torch.utils.data import DataLoader
 from torchvision.transforms import v2
 from tqdm import tqdm

 from torchao.prototype import low_bit_optim

-# lpmm doesn't have Adam, only AdamW
 OPTIM_MAP = dict(
-    Adam=torch.optim.Adam,
-    Adam8bitBnb=bnb.optim.Adam8bit,
-    Adam8bitAo=low_bit_optim.Adam8bit,
-    AdamFp8Ao=low_bit_optim.AdamFp8,
-    Adam4bitLpmm=partial(lpmm.optim.AdamW, weight_decay=0, fused=True),
-    Adam4bitAo=low_bit_optim.Adam4bit,
-    Adam4bitRank1Lpmm=partial(lpmm.optim.AdamW, weight_decay=0, qconfig=argparse.Namespace(scale_type="rank1")),
+    AdamW=partial(torch.optim.AdamW, fused=True),
+    AdamW8bitBnb=bnb.optim.AdamW8bit,
+    AdamW8bitAo=low_bit_optim.AdamW8bit,
+    AdamWFp8Ao=low_bit_optim.AdamWFp8,
+    AdamW4bitAo=low_bit_optim.AdamW4bit,
 )

+try:
+    import lpmm
+
+    OPTIM_MAP.update(
+        AdamW4bitLpmm=partial(lpmm.optim.AdamW, fused=True),
+        AdamW4bitRank1Lpmm=partial(lpmm.optim.AdamW, qconfig=argparse.Namespace(scale_type="rank1")),
+    )
+
+except ImportError:
+    pass
+

 class CosineSchedule:
     def __init__(self, lr: float, total_steps: int, warmup: float = 0.05) -> None:
@@ -77,19 +91,23 @@ def log(self, *args, **kwargs):
 def get_parser():
     parser = argparse.ArgumentParser()
     parser.add_argument("--model", required=True)
+    parser.add_argument("--model_kwargs", type=json.loads, default=dict())
+    parser.add_argument("--checkpoint_activations", action="store_true")

     parser.add_argument("--amp", default="none")
+    parser.add_argument("--full_bf16", action="store_true")
     parser.add_argument("--channels_last", action="store_true")
     parser.add_argument("--compile", action="store_true")

     parser.add_argument("--n_epochs", type=int, default=10)
     parser.add_argument("--batch_size", type=int, default=64)
     parser.add_argument("--n_workers", type=int, default=4)

-    parser.add_argument("--optim", default="Adam", choices=OPTIM_MAP.keys())
+    parser.add_argument("--optim", default="AdamW", choices=OPTIM_MAP.keys())
     parser.add_argument("--lr", type=float, default=1e-4)
     parser.add_argument("--weight_decay", type=float, default=0)
     parser.add_argument("--cosine_lr_scheduler", action="store_true")
+    parser.add_argument("--optim_cpu_offload", choices=["ao", "ao_offload_grads", "deepspeed"])

     parser.add_argument("--project")
     parser.add_argument("--run_name", default="debug")
@@ -140,6 +158,8 @@ def evaluate_model(model, args):

     for batch in tqdm(val_dloader, dynamic_ncols=True, desc=f"Evaluating"):
         all_labels.append(batch["label"].clone())
+        if args.full_bf16:
+            batch["image"] = batch["image"].bfloat16()
         if args.channels_last:
             batch["image"] = batch["image"].to(memory_format=torch.channels_last)

@@ -156,6 +176,11 @@ def evaluate_model(model, args):
 if __name__ == "__main__":
     args = get_parser().parse_args()

+    if args.full_bf16:
+        assert args.amp == "none", "When --full_bf16 is set, --amp must be none"
+    if args.optim_cpu_offload == "deepspeed":
+        assert args.amp == "none", "When using DeepSpeed ZeRO-Offload, --amp must be none"
+        assert args.optim == "AdamW", "When using DeepSpeed ZeRO-Offload, --optim must be AdamW"
     if args.profile:
         args.n_epochs = 1
     if args.seed is not None:
@@ -169,48 +194,95 @@ def evaluate_model(model, args):
     dloader = get_dloader(args, True)
     print(f"Train dataset: {len(dloader.dataset):,} images")

-    model = timm.create_model(args.model, pretrained=True, num_classes=45).cuda()
+    model = timm.create_model(args.model, pretrained=True, num_classes=45, **args.model_kwargs)
+    if args.checkpoint_activations:
+        model.set_grad_checkpointing()
+    if args.full_bf16:
+        model.bfloat16()
     if args.channels_last:
         model.to(memory_format=torch.channels_last)
+    model.cuda()  # move model to CUDA after optionally converting it to BF16
     if args.compile:
         model.compile(fullgraph=True)
     print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

-    optim = OPTIM_MAP[args.optim](model.parameters(), args.lr, weight_decay=args.weight_decay)
-    lr_schedule = CosineSchedule(args.lr, len(dloader) * args.n_epochs)
+    if args.optim_cpu_offload == "deepspeed":
+        import deepspeed
+
+        model, optim, _, _ = deepspeed.initialize(
+            model=model,
+            model_parameters=model.parameters(),
+            config=dict(
+                train_batch_size=args.batch_size,
+                optimizer=dict(
+                    type="Adam",
+                    params=dict(lr=args.lr, weight_decay=args.weight_decay, fp32_optimizer_states=False),
+                ),
+                bf16=dict(enabled=args.full_bf16),
+                zero_optimization=dict(
+                    stage=2,  # requires ZeRO-2 to enable overlap_comm
+                    overlap_comm=True,  # interleave grad D2H with backward
+                    offload_optimizer=dict(device="cpu", pin_memory=True),
+                ),
+            ),
+        )
+
+    else:
+        optim_cls = OPTIM_MAP[args.optim]
+
+        if args.optim_cpu_offload == "ao":
+            optim_cls = partial(low_bit_optim.CPUOffloadOptimizer, optimizer_class=optim_cls)
+        elif args.optim_cpu_offload == "ao_offload_grads":
+            optim_cls = partial(low_bit_optim.CPUOffloadOptimizer, optimizer_class=optim_cls, offload_gradients=True)

+        optim = optim_cls(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
+
+    lr_schedule = CosineSchedule(args.lr, len(dloader) * args.n_epochs)
     grad_scaler = torch.amp.GradScaler("cuda", enabled=args.amp == "fp16")

     step = 0
     for epoch_idx in range(args.n_epochs):
         model.train()
-        prof = profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) if args.profile else nullcontext()
+        pbar = tqdm(dloader, dynamic_ncols=True, desc=f"Epoch {epoch_idx + 1}/{args.n_epochs}")
+
         start_time = datetime.datetime.now()

-        with prof:
-            for batch in tqdm(dloader, dynamic_ncols=True, desc=f"Epoch {epoch_idx + 1}/{args.n_epochs}"):
+        with torch.profiler.profile() if args.profile else nullcontext() as prof:
+            for batch in pbar:
+                if args.full_bf16:
+                    batch["image"] = batch["image"].bfloat16()
                 if args.channels_last:
                     batch["image"] = batch["image"].to(memory_format=torch.channels_last)

                 with get_amp_ctx(args.amp):
                     loss = F.cross_entropy(model(batch["image"].cuda()), batch["label"].cuda())
-                grad_scaler.scale(loss).backward()
+
+                if args.optim_cpu_offload == "deepspeed":
+                    model.backward(loss)
+                else:
+                    grad_scaler.scale(loss).backward()

                 if args.cosine_lr_scheduler:
                     lr = lr_schedule.get_lr(step)
                     for param_group in optim.param_groups:
                         param_group["lr"] = lr

                 if step % 100 == 0:
-                    logger.log(dict(loss=loss.item(), lr=optim.param_groups[0]["lr"]), step=step)
-
-                grad_scaler.step(optim)
-                grad_scaler.update()
-                optim.zero_grad()
+                    logger.log(
+                        dict(loss=loss.item(), lr=optim.param_groups[0]["lr"]),
+                        step=step,
+                    )
+
+                if args.optim_cpu_offload == "deepspeed":
+                    model.step()
+                else:
+                    grad_scaler.step(optim)
+                    grad_scaler.update()
+                    optim.zero_grad()

                 step += 1

-                if args.profile and step == 20:
+                if args.profile and step == 5:
                     break

         if args.profile:
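Putting the new flags together: based on the arguments added above and the benchmark setup described in the README below, an invocation would look roughly like this (batch size and flag combinations are illustrative, not prescribed by the script):

```bash
# torchao optimizer CPU offload, full BF16 training, activation checkpointing
python benchmark_low_bit_adam.py \
    --model "timm/vit_giant_patch14_dinov2.lvd142m" \
    --batch_size 32 \
    --checkpoint_activations \
    --full_bf16 \
    --optim AdamW \
    --optim_cpu_offload ao

# variants:
#   --optim_cpu_offload ao_offload_grads   # also offload gradients (no gradient accumulation)
#   --optim_cpu_offload deepspeed          # ZeRO-Offload baseline (requires --optim AdamW, no AMP)
```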

test/prototype/test_low_bit_optim.py

Lines changed: 64 additions & 0 deletions
@@ -1,4 +1,5 @@
 import copy
+import tempfile

 import pytest
 import torch
@@ -157,6 +158,69 @@ def test_optim_4bit_correctness(self, optim_name):
         for p1, p2 in zip(model1.parameters(), model2.parameters()):
             torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5)

+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="optim CPU offload requires CUDA")
+    @parametrize("offload_grad,grad_accum", [(False, 1), (False, 2), (True, 1)])
+    def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
+        device = "cuda"
+        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
+        model2 = copy.deepcopy(model1)
+
+        optim1 = torch.optim.AdamW(model1.parameters())
+        optim2 = low_bit_optim.CPUOffloadOptimizer(
+            model2.parameters(), torch.optim.AdamW, offload_gradients=offload_grad,
+        )
+
+        for _ in range(2):
+            for _ in range(grad_accum):
+                x = torch.randn(4, 32, device=device)
+                model1(x).sum().backward()
+                model2(x).sum().backward()
+
+            optim1.step()
+            optim1.zero_grad()
+
+            optim2.step()
+            optim2.zero_grad()
+
+        for p1, p2 in zip(model1.parameters(), model2.parameters()):
+            torch.testing.assert_close(p2, p1)
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="optim CPU offload requires CUDA")
+    def test_optim_cpu_offload_save_load(self):
+        device = "cuda"
+        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
+        optim1 = low_bit_optim.CPUOffloadOptimizer(model1.parameters(), torch.optim.AdamW)
+
+        for _ in range(2):
+            x = torch.randn(4, 32, device=device)
+            model1(x).sum().backward()
+            optim1.step()
+            optim1.zero_grad()
+
+        # save checkpoint. make sure it can be serialized by torch.save()
+        with tempfile.NamedTemporaryFile() as file:
+            torch.save(optim1.state_dict(), file.name)
+            state_dict = torch.load(file.name)
+
+        # resume training
+        model2 = copy.deepcopy(model1)
+        optim2 = low_bit_optim.CPUOffloadOptimizer(model2.parameters(), torch.optim.AdamW)
+        optim2.load_state_dict(state_dict)
+
+        for _ in range(2):
+            x = torch.randn(4, 32, device=device)
+
+            model1(x).sum().backward()
+            optim1.step()
+            optim1.zero_grad()
+
+            model2(x).sum().backward()
+            optim2.step()
+            optim2.zero_grad()
+
+        for p1, p2 in zip(model1.parameters(), model2.parameters()):
+            torch.testing.assert_close(p2, p1)
+

 class TestFSDP2(FSDPTest):
     @property
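Assuming a CUDA-capable machine and a checkout of the repo, the new tests can be run on their own with something like:

```bash
pytest test/prototype/test_low_bit_optim.py -k cpu_offload -v
```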

torchao/prototype/low_bit_optim/README.md

Lines changed: 62 additions & 1 deletion
@@ -46,6 +46,67 @@ lpmm 4-bit (*) | 7.73 | 11m 10s | 89.71

 (*) means rank-1 normalization is used for 2nd optimizer state. Refer to [paper](https://arxiv.org/abs/2309.01507) for more details.

+## Optimizer CPU offload
+
+This folder also implements optimizer CPU offload (i.e. ZeRO-Offload) for single GPU training. For multi-GPU training, you can use FSDP's built-in CPU offload.
+
+```python
+import torch
+from torchao.prototype.low_bit_optim import CPUOffloadOptimizer
+
+model = ...
+
+# only offload optimizer state
+optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, fused=True)
+
+# offload optimizer state AND gradients
+optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, offload_gradients=True, fused=True)
+```
+
+This will reduce GPU memory usage by the optimizer state size, and additionally by the gradient size if `offload_gradients=True`. `CPUOffloadOptimizer` can wrap any base optimizer.
+
+For saving and loading `CPUOffloadOptimizer`, it is important that you load the model's weights BEFORE creating the optimizer, since we create a CPU copy of the parameters inside `CPUOffloadOptimizer.__init__()`. (TODO: we might want to have a method to synchronize CUDA and CPU params in either direction, CPU->CUDA and CUDA->CPU, in case they are out of sync.)
+
+```python
+ckpt = torch.load("checkpoint.pth")
+
+model = ...
+model.load_state_dict(ckpt["model"])
+
+optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, fused=True)
+optim.load_state_dict(ckpt["optim"])
+```
+
+NOTE:
+- Since the optimizer step is done on CPU, it is highly recommended to use a fast CPU optimizer, such as `torch.optim.AdamW(fused=True)` (requires PyTorch 2.4). For other optimizers, you can try `torch.compile()` their optimizer step.
+- To minimize the amount of CPU<->GPU data transfer, we keep a copy of parameters and pre-allocate gradient memory on CPU. Therefore, expect your RAM usage to increase by 2x model size + optimizer state (which is 2x model size for Adam).
+- It is recommended NOT to `torch.compile()` your whole model when `CPUOffloadOptimizer` is used, as it prevents us from interleaving gradient device-to-host transfer with the backward pass. To minimize such impact, you can compile parts of your model separately. See [#584](https://github.com/pytorch/ao/pull/584) for more information.
+- CPU optimizer step is often the bottleneck when optimizer CPU offload is used. To minimize the slowdown, it is recommended to (1) do full BF16 training (instead of AMP), so that parameters, gradients, and optimizer states are in BF16; and (2) give the GPU more work per optimizer step (e.g. larger batch size with activation checkpointing, gradient accumulation).
+- `offload_gradients=True` is not compatible with gradient accumulation, since we clear gradients on GPU every backward pass.
+- Gradient clipping is currently not supported.
+
+Benchmark done for `timm/vit_giant_patch14_dinov2.lvd142m` (1.1B params), eager mode, full BF16 training, activation checkpointing, batch size 32, on a 4070Ti SUPER (16GB VRAM), Ryzen 5600, DDR4 RAM. DeepSpeed is untuned.
+
+Adam offload           | Time per step | Max memory
+-----------------------|---------------|------------
+None                   | 1.27s/it      | 9.82 GB
+DeepSpeed ZeRO-Offload | 3.13s/it      | 6.85 GB
+ao                     | 1.52s/it      | 5.24 GB
+ao (offload gradients) | 1.53s/it      | 4.01 GB
+
+Ablations on AMP and `torch.compile()`
+
+Training config     | Adam offload | Time per step | Max memory
+--------------------|--------------|---------------|------------
+Full BF16, compiled | None         | 1.18s/it      | 9.90 GB
+Full BF16, compiled | ao           | 1.75s/it      | 5.33 GB
+BF16 AMP, eager     | None         | OOM           | OOM
+BF16 AMP, eager     | ao           | 2.18s/it      | 9.90 GB
+
 ## Credits

-Credits to Tim Dettmers for creating the wonderful [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) library, and [lpmm](https://github.com/thu-ml/low-bit-optimizers) authors for their work on 4-bit optimizers.
+Credits to
+
+- Tim Dettmers for creating the wonderful [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) library.
+- [lpmm](https://github.com/thu-ml/low-bit-optimizers) authors for their work on 4-bit optimizers.
+- [DeepSpeed](https://github.com/microsoft/DeepSpeed) team for [ZeRO-Offload](https://arxiv.org/abs/2101.06840).
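Two of the README notes above can be made concrete: compiling parts of the model separately (instead of the whole model) while using `CPUOffloadOptimizer`, and saving a checkpoint in the `{"model": ..., "optim": ...}` format that the loading example expects. A hedged sketch follows; `model.blocks` is an assumption about a timm ViT-style module list, not something the README prescribes:

```python
import torch
from torchao.prototype.low_bit_optim import CPUOffloadOptimizer

model = ...  # e.g. a timm ViT; `model.blocks` is assumed to be an nn.ModuleList

# compile per-block rather than model.compile(fullgraph=True), so that gradient
# device-to-host copies can still be interleaved with the backward pass
for block in model.blocks:
    block.compile()

optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, fused=True)

# ... training loop ...

# save a checkpoint in the format used by the loading example above
torch.save({"model": model.state_dict(), "optim": optim.state_dict()}, "checkpoint.pth")
```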
torchao/prototype/low_bit_optim/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,2 +1,3 @@
 from .adam import Adam8bit, Adam4bit, AdamFp8
 from .adamw import AdamW8bit, AdamW4bit, AdamWFp8
+from .cpu_offload import CPUOffloadOptimizer
