
Commit 649b00b

refactor(cpu-offload): use ruff format
Signed-off-by: dbyoung18 <yang5.yang@intel.com>
1 parent d9cce7b commit 649b00b

3 files changed: +84 -27 lines changed
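
The changes in all three files are purely mechanical: ruff format wraps calls and expressions that overrun the line-length limit (88 characters by ruff's default; the project's ruff configuration is not shown in this commit) and adds trailing commas to the exploded argument lists, with no behavioral change. A minimal before/after sketch of the rewrite, using one of the calls touched below:

# Before: a single call longer than the line limit.
optim_cls = partial(low_bit_optim.CPUOffloadOptimizer, optimizer_class=optim_cls, offload_gradients=True)

# After ruff format: one argument per line, a trailing comma, and the closing
# parenthesis dedented back to the statement's indentation.
optim_cls = partial(
    low_bit_optim.CPUOffloadOptimizer,
    optimizer_class=optim_cls,
    offload_gradients=True,
)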

benchmarks/benchmark_low_bit_adam.py

Lines changed: 50 additions & 16 deletions
@@ -53,7 +53,9 @@

     OPTIM_MAP.update(
         AdamW4bitLpmm=partial(lpmm.optim.AdamW, fused=True),
-        AdamW4bitRank1Lpmm=partial(lpmm.optim.AdamW, qconfig=argparse.Namespace(scale_type="rank1")),
+        AdamW4bitRank1Lpmm=partial(
+            lpmm.optim.AdamW, qconfig=argparse.Namespace(scale_type="rank1")
+        ),
     )

 except ImportError:
@@ -71,8 +73,12 @@ def get_lr(self, step: int) -> float:
         if step < self.warmup_steps:
             return self.lr * step / self.warmup_steps
         if step < self.total_steps:
-            progress = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
-            return self.final_lr + 0.5 * (self.lr - self.final_lr) * (1 + math.cos(progress * math.pi))
+            progress = (step - self.warmup_steps) / (
+                self.total_steps - self.warmup_steps
+            )
+            return self.final_lr + 0.5 * (self.lr - self.final_lr) * (
+                1 + math.cos(progress * math.pi)
+            )
         return self.final_lr


@@ -96,7 +102,9 @@ def get_parser():
     parser.add_argument("--weight_decay", type=float, default=0)
     parser.add_argument("--optim_kwargs", type=json.loads, default=dict())
     parser.add_argument("--cosine_lr_scheduler", action="store_true")
-    parser.add_argument("--optim_cpu_offload", choices=["ao", "ao_offload_grads", "deepspeed"])
+    parser.add_argument(
+        "--optim_cpu_offload", choices=["ao", "ao_offload_grads", "deepspeed"]
+    )

     parser.add_argument("--project")
     parser.add_argument("--run_name", default="debug")
@@ -114,11 +122,15 @@ def get_dloader(args, training: bool):
         transforms.extend([v2.Resize(256), v2.CenterCrop(224)])

     transforms.append(v2.ToDtype(torch.float32, scale=True))
-    transforms.append(v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
+    transforms.append(
+        v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+    )
     transforms = v2.Compose(transforms)

     # use dataset from HF so download is fast
-    ds = datasets.load_dataset("timm/resisc45", split="train" if training else "validation")
+    ds = datasets.load_dataset(
+        "timm/resisc45", split="train" if training else "validation"
+    )
     ds = ds.select_columns(["image", "label"])
     ds.set_transform(lambda x: dict(image=transforms(x["image"]), label=x["label"]))

@@ -168,8 +180,12 @@ def evaluate_model(model, args):
     if args.full_bf16:
         assert args.amp == "none", "When --full_bf16 is set, --amp must be none"
     if args.optim_cpu_offload == "deepspeed":
-        assert args.amp == "none", "When using DeepSpeed ZeRO-Offload, --amp must be none"
-        assert args.optim == "AdamW", "When using DeepSpeed ZeRO-Offload, --optim must be AdamW"
+        assert (
+            args.amp == "none"
+        ), "When using DeepSpeed ZeRO-Offload, --amp must be none"
+        assert (
+            args.optim == "AdamW"
+        ), "When using DeepSpeed ZeRO-Offload, --optim must be AdamW"
     if args.profile:
         args.n_epochs = 1
     if args.seed is not None:
@@ -189,7 +205,9 @@ def evaluate_model(model, args):
     dloader = get_dloader(args, True)
     print(f"Train dataset: {len(dloader.dataset):,} images")

-    model = timm.create_model(args.model, pretrained=True, num_classes=45, **args.model_kwargs)
+    model = timm.create_model(
+        args.model, pretrained=True, num_classes=45, **args.model_kwargs
+    )
     if args.checkpoint_activations:
         model.set_grad_checkpointing()
     if args.full_bf16:
@@ -231,9 +249,15 @@ def evaluate_model(model, args):
     optim_cls = OPTIM_MAP[args.optim]

     if args.optim_cpu_offload == "ao":
-        optim_cls = partial(low_bit_optim.CPUOffloadOptimizer, optimizer_class=optim_cls)
+        optim_cls = partial(
+            low_bit_optim.CPUOffloadOptimizer, optimizer_class=optim_cls
+        )
     elif args.optim_cpu_offload == "ao_offload_grads":
-        optim_cls = partial(low_bit_optim.CPUOffloadOptimizer, optimizer_class=optim_cls, offload_gradients=True)
+        optim_cls = partial(
+            low_bit_optim.CPUOffloadOptimizer,
+            optimizer_class=optim_cls,
+            offload_gradients=True,
+        )

     optim = optim_cls(
         model.parameters(),
@@ -250,17 +274,23 @@ def evaluate_model(model, args):
     step = 0
     for epoch_idx in range(args.n_epochs):
         model.train()
-        pbar = tqdm(dloader, dynamic_ncols=True, desc=f"Epoch {epoch_idx + 1}/{args.n_epochs}")
+        pbar = tqdm(
+            dloader, dynamic_ncols=True, desc=f"Epoch {epoch_idx + 1}/{args.n_epochs}"
+        )

         with torch.profiler.profile() if args.profile else nullcontext() as prof:
             for batch in pbar:
                 if args.full_bf16:
                     batch["image"] = batch["image"].bfloat16()
                 if args.channels_last:
-                    batch["image"] = batch["image"].to(memory_format=torch.channels_last)
+                    batch["image"] = batch["image"].to(
+                        memory_format=torch.channels_last
+                    )

                 with get_amp_ctx(args.amp, _DEVICE):
-                    loss = F.cross_entropy(model(batch["image"].to(_DEVICE)), batch["label"].to(_DEVICE))
+                    loss = F.cross_entropy(
+                        model(batch["image"].to(_DEVICE)), batch["label"].to(_DEVICE)
+                    )

                 if args.optim_cpu_offload == "deepspeed":
                     model.backward(loss)
@@ -279,7 +309,9 @@ def evaluate_model(model, args):
                     log_dict = dict(loss=loss.item(), lr=optim.param_groups[0]["lr"])
                     if step > 0:
                         t1 = time.perf_counter()
-                        log_dict["imgs_per_second"] = args.batch_size * log_interval / (t1 - t0)
+                        log_dict["imgs_per_second"] = (
+                            args.batch_size * log_interval / (t1 - t0)
+                        )
                         t0 = t1
                     logger.log(log_dict, step=step)

@@ -300,7 +332,9 @@ def evaluate_model(model, args):

         else:
             val_acc = evaluate_model(model, args)
-            print(f"Epoch {epoch_idx + 1}/{args.n_epochs}: val_acc={val_acc.item() * 100:.2f}")
+            print(
+                f"Epoch {epoch_idx + 1}/{args.n_epochs}: val_acc={val_acc.item() * 100:.2f}"
+            )
             logger.log(dict(val_acc=val_acc), step=step)

     peak_mem = getattr(torch, _DEVICE).max_memory_allocated() / 1e9
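
For context on the benchmark code being reformatted above: it selects a base optimizer from OPTIM_MAP and, when --optim_cpu_offload is ao or ao_offload_grads, rebinds optim_cls to low_bit_optim.CPUOffloadOptimizer through functools.partial before instantiating it with the model's parameters. A condensed sketch of that wiring, not the benchmark itself: the import path and the forwarding of extra kwargs (lr, weight_decay) to the wrapped optimizer are assumptions inferred from this commit's file layout and call sites, and the tiny model and hyperparameters are placeholders.

from functools import partial

import torch
from torchao.prototype import low_bit_optim  # import path assumed from this commit's file layout

model = torch.nn.Linear(32, 10).to("cuda")  # CPU offload requires a CUDA or XPU device

optim_cls = torch.optim.AdamW  # stands in for OPTIM_MAP[args.optim]

# --optim_cpu_offload ao_offload_grads: wrap the base optimizer so its step runs
# against CPU-resident state; per the flag name, gradients are offloaded as well.
optim_cls = partial(
    low_bit_optim.CPUOffloadOptimizer,
    optimizer_class=optim_cls,
    offload_gradients=True,
)

# the benchmark then constructs it like any optimizer; extra kwargs reach AdamW
optim = optim_cls(model.parameters(), lr=1e-4, weight_decay=0.0)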

test/prototype/test_low_bit_optim.py

Lines changed: 30 additions & 10 deletions
@@ -95,7 +95,9 @@ def test_bf16_stochastic_round(self, device, compile):
         x = torch.rand(32, device=device) * 100
         x_rep = x.view(-1, 1).repeat(1, 100_000)

-        func = torch.compile(_fp32_to_bf16_sr, fullgraph=True, dynamic=False, disable=not compile)
+        func = torch.compile(
+            _fp32_to_bf16_sr, fullgraph=True, dynamic=False, disable=not compile
+        )
         x_rep_bf16 = func(x_rep)
         assert x_rep_bf16.dtype is torch.bfloat16

@@ -170,8 +172,13 @@ def test_subclass_slice(self, subclass, shape, device):
         tensor = subclass.zeros(shape, device=device)
         offset = shape[0] // 2

-        torch.testing.assert_close(tensor.dequantize()[:offset], tensor[:offset].dequantize())
-        torch.testing.assert_close(tensor.dequantize()[offset:offset*2], tensor[offset:offset*2].dequantize())
+        torch.testing.assert_close(
+            tensor.dequantize()[:offset], tensor[:offset].dequantize()
+        )
+        torch.testing.assert_close(
+            tensor.dequantize()[offset : offset * 2],
+            tensor[offset : offset * 2].dequantize(),
+        )

     @pytest.mark.skipif(bnb is None, reason="bitsandbytes is not available")
     @pytest.mark.skipif(
@@ -189,7 +196,9 @@ def test_optim_8bit_correctness(self, optim_name):
         block_size = 256 if Version(bnb.__version__) >= Version("0.44.0") else 2048

         optim1 = getattr(bnb.optim, optim_name)(model1.parameters())
-        optim2 = getattr(low_bit_optim, optim_name)(model2.parameters(), block_size=block_size)
+        optim2 = getattr(low_bit_optim, optim_name)(
+            model2.parameters(), block_size=block_size
+        )

         for _ in range(2):
             x = torch.randn(4, 32, device=device)
@@ -246,7 +255,7 @@ def test_optim_4bit_correctness(self, optim_name):

     @pytest.mark.skipif(
         not torch.cuda.is_available() and not torch.xpu.is_available(),
-        reason="optim CPU offload requires CUDA or XPU"
+        reason="optim CPU offload requires CUDA or XPU",
     )
     @parametrize("offload_grad,grad_accum", [(False, 1), (False, 2), (True, 1)])
     def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
@@ -282,13 +291,15 @@ def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):

     @pytest.mark.skipif(
         not torch.cuda.is_available() and not torch.xpu.is_available(),
-        reason="optim CPU offload requires CUDA or XPU"
+        reason="optim CPU offload requires CUDA or XPU",
     )
     def test_optim_cpu_offload_save_load(self):
         device = _DEVICES[-1]
         model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128))
         model1.to(device)
-        optim1 = low_bit_optim.CPUOffloadOptimizer(model1.parameters(), torch.optim.AdamW)
+        optim1 = low_bit_optim.CPUOffloadOptimizer(
+            model1.parameters(), torch.optim.AdamW
+        )

         for _ in range(2):
             x = torch.randn(4, 32, device=device)
@@ -303,7 +314,9 @@ def test_optim_cpu_offload_save_load(self):

         # resume training
         model2 = copy.deepcopy(model1)
-        optim2 = low_bit_optim.CPUOffloadOptimizer(model2.parameters(), torch.optim.AdamW)
+        optim2 = low_bit_optim.CPUOffloadOptimizer(
+            model2.parameters(), torch.optim.AdamW
+        )
         optim2.load_state_dict(state_dict)

         for _ in range(2):
@@ -387,7 +400,11 @@ def _test_fsdp2(self, optim_cls):
         import torch.utils._pytree as pytree
         from torch.distributed._composable.fsdp import fully_shard
         from torch.distributed.tensor import DTensor
-        from torch.testing._internal.distributed._tensor.common_dtensor import ModelArgs, Transformer, TransformerBlock
+        from torch.testing._internal.distributed._tensor.common_dtensor import (
+            ModelArgs,
+            Transformer,
+            TransformerBlock,
+        )

         batch_size = 3
         vocab_size = 1024
@@ -460,7 +477,10 @@ def _test_fsdp2(self, optim_cls):

         subclasses = (OptimState4bit, OptimState8bit, OptimStateFp8)

-        for v1, v2 in zip(pytree.tree_iter(resumed_fsdp_optim.state_dict()), pytree.tree_iter(fsdp_optim.state_dict())):
+        for v1, v2 in zip(
+            pytree.tree_iter(resumed_fsdp_optim.state_dict()),
+            pytree.tree_iter(fsdp_optim.state_dict()),
+        ):
             assert v1.__class__ == v2.__class__, (v1.__class__, v2.__class__)
             if isinstance(v1, DTensor):
                 v1 = v1.to_local()
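
The reformatted save/load test above also documents the expected checkpoint round trip: a CPUOffloadOptimizer state_dict saved from one instance can be loaded into a fresh instance built over a copy of the model. A condensed sketch following the test's calls; the loop body is elided in the hunks above, so the forward/backward/step lines are an assumed typical fill-in, and "cuda" stands in for the test's _DEVICES[-1].

import copy

import torch
from torch import nn
from torchao.prototype import low_bit_optim  # import path assumed from this commit's file layout

device = "cuda"  # the test requires CUDA or XPU
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
optim1 = low_bit_optim.CPUOffloadOptimizer(model1.parameters(), torch.optim.AdamW)

for _ in range(2):
    x = torch.randn(4, 32, device=device)
    model1(x).sum().backward()  # assumed loop body; the diff elides it
    optim1.step()
    optim1.zero_grad()

state_dict = optim1.state_dict()

# resume training from the saved optimizer state, as the test does
model2 = copy.deepcopy(model1)
optim2 = low_bit_optim.CPUOffloadOptimizer(model2.parameters(), torch.optim.AdamW)
optim2.load_state_dict(state_dict)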

torchao/prototype/low_bit_optim/cpu_offload.py

Lines changed: 4 additions & 1 deletion
@@ -41,7 +41,10 @@ def __init__(
         self.param_d2h_map = dict()
         self.optim_dict = dict()
         self.device = get_available_devices()[-1]
-        assert self.device in ["cuda", "xpu"], "CPU Offload currently only supports CUDA & XPU"
+        assert self.device in [
+            "cuda",
+            "xpu",
+        ], "CPU Offload currently only supports CUDA & XPU"
         self.stream = getattr(torch, self.device).Stream()

         # the queue maintains the order which param we should do optim step on first.
