Enable CPU Offload for Intel GPU #1324

Merged · 11 commits · Nov 26, 2024
88 changes: 63 additions & 25 deletions benchmarks/benchmark_low_bit_adam.py
@@ -4,7 +4,7 @@
# - lpmm (4-bit optim): pip install yacs git+https://github.com/thu-ml/low-bit-optimizers.git
# - DeepSpeed (ZeRO-Offload):
# sudo apt install libopenmpi-dev
# LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu pip install mpi4p
# LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu pip install mpi4py
# DS_BUILD_CPU_ADAM=1 pip install deepspeed --no-cache-dir
#
# To fine-tune a pre-trained ViT-Base on resisc45 dataset with BF16 AMP, using default AdamW optimizer from PyTorch core
@@ -31,11 +31,15 @@
import torch.nn.functional as F
import wandb
from torch.utils.data import DataLoader
from torchao.utils import get_available_devices
from torchvision.transforms import v2
from tqdm import tqdm

from torchao.prototype import low_bit_optim

_DEVICE = get_available_devices()[-1]
assert _DEVICE in ["cuda", "xpu"], "Benchmark currently only supports CUDA & XPU(BF16)"

OPTIM_MAP = dict(
AdamW=partial(torch.optim.AdamW, fused=True),
AdamW8bitBnb=bnb.optim.AdamW8bit,
@@ -49,7 +53,9 @@

OPTIM_MAP.update(
AdamW4bitLpmm=partial(lpmm.optim.AdamW, fused=True),
AdamW4bitRank1Lpmm=partial(lpmm.optim.AdamW, qconfig=argparse.Namespace(scale_type="rank1")),
AdamW4bitRank1Lpmm=partial(
lpmm.optim.AdamW, qconfig=argparse.Namespace(scale_type="rank1")
),
)

except ImportError:
@@ -67,8 +73,12 @@ def get_lr(self, step: int) -> float:
if step < self.warmup_steps:
return self.lr * step / self.warmup_steps
if step < self.total_steps:
progress = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
return self.final_lr + 0.5 * (self.lr - self.final_lr) * (1 + math.cos(progress * math.pi))
progress = (step - self.warmup_steps) / (
self.total_steps - self.warmup_steps
)
return self.final_lr + 0.5 * (self.lr - self.final_lr) * (
1 + math.cos(progress * math.pi)
)
return self.final_lr


@@ -92,7 +102,9 @@ def get_parser():
parser.add_argument("--weight_decay", type=float, default=0)
parser.add_argument("--optim_kwargs", type=json.loads, default=dict())
parser.add_argument("--cosine_lr_scheduler", action="store_true")
parser.add_argument("--optim_cpu_offload", choices=["ao", "ao_offload_grads", "deepspeed"])
parser.add_argument(
"--optim_cpu_offload", choices=["ao", "ao_offload_grads", "deepspeed"]
)

parser.add_argument("--project")
parser.add_argument("--run_name", default="debug")
@@ -110,11 +122,15 @@ def get_dloader(args, training: bool):
transforms.extend([v2.Resize(256), v2.CenterCrop(224)])

transforms.append(v2.ToDtype(torch.float32, scale=True))
transforms.append(v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
transforms.append(
v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
)
transforms = v2.Compose(transforms)

# use dataset from HF so download is fast
ds = datasets.load_dataset("timm/resisc45", split="train" if training else "validation")
ds = datasets.load_dataset(
"timm/resisc45", split="train" if training else "validation"
)
ds = ds.select_columns(["image", "label"])
ds.set_transform(lambda x: dict(image=transforms(x["image"]), label=x["label"]))

@@ -128,9 +144,9 @@ def get_dloader(args, training: bool):
)


def get_amp_ctx(amp):
def get_amp_ctx(amp, device):
dtype = dict(bf16=torch.bfloat16, fp16=torch.float16, none=None)[amp]
return torch.autocast("cuda", dtype=dtype, enabled=amp != "none")
return torch.autocast(device, dtype=dtype, enabled=amp != "none")
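The reworked `get_amp_ctx` simply forwards the detected device type to `torch.autocast`, so the same AMP code path works on CUDA and XPU. A minimal usage sketch, assuming a PyTorch build with XPU support (the device string is chosen the same way as `_DEVICE` above):

```python
import torch

def get_amp_ctx(amp: str, device: str):
    # Map the --amp flag to a dtype; "none" disables autocast entirely.
    dtype = dict(bf16=torch.bfloat16, fp16=torch.float16, none=None)[amp]
    return torch.autocast(device, dtype=dtype, enabled=amp != "none")

device = "xpu" if torch.xpu.is_available() else "cuda"
with get_amp_ctx("bf16", device):
    pass  # a forward pass would run under BF16 autocast here
```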


@torch.no_grad()
@@ -148,8 +164,8 @@ def evaluate_model(model, args):
if args.channels_last:
batch["image"] = batch["image"].to(memory_format=torch.channels_last)

with get_amp_ctx(args.amp):
all_preds.append(model(batch["image"].cuda()).argmax(1).cpu())
with get_amp_ctx(args.amp, _DEVICE):
all_preds.append(model(batch["image"].to(_DEVICE)).argmax(1).cpu())

all_labels = torch.cat(all_labels, dim=0)
all_preds = torch.cat(all_preds, dim=0)
@@ -164,8 +180,12 @@ def evaluate_model(model, args):
if args.full_bf16:
assert args.amp == "none", "When --full_bf16 is set, --amp must be none"
if args.optim_cpu_offload == "deepspeed":
assert args.amp == "none", "When using DeepSpeed ZeRO-Offload, --amp must be none"
assert args.optim == "AdamW", "When using DeepSpeed ZeRO-Offload, --optim must be AdamW"
assert (
args.amp == "none"
), "When using DeepSpeed ZeRO-Offload, --amp must be none"
assert (
args.optim == "AdamW"
), "When using DeepSpeed ZeRO-Offload, --optim must be AdamW"
if args.profile:
args.n_epochs = 1
if args.seed is not None:
@@ -185,14 +205,16 @@
dloader = get_dloader(args, True)
print(f"Train dataset: {len(dloader.dataset):,} images")

model = timm.create_model(args.model, pretrained=True, num_classes=45, **args.model_kwargs)
model = timm.create_model(
args.model, pretrained=True, num_classes=45, **args.model_kwargs
)
if args.checkpoint_activations:
model.set_grad_checkpointing()
if args.full_bf16:
model.bfloat16()
if args.channels_last:
model.to(memory_format=torch.channels_last)
model.cuda() # move model to CUDA after optionally convert it to BF16
model.to(_DEVICE)  # move model to DEVICE after optionally converting it to BF16
if args.compile:
model.compile(fullgraph=True)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
@@ -227,9 +249,15 @@ def evaluate_model(model, args):
optim_cls = OPTIM_MAP[args.optim]

if args.optim_cpu_offload == "ao":
optim_cls = partial(low_bit_optim.CPUOffloadOptimizer, optimizer_class=optim_cls)
optim_cls = partial(
low_bit_optim.CPUOffloadOptimizer, optimizer_class=optim_cls
)
elif args.optim_cpu_offload == "ao_offload_grads":
optim_cls = partial(low_bit_optim.CPUOffloadOptimizer, optimizer_class=optim_cls, offload_gradients=True)
optim_cls = partial(
low_bit_optim.CPUOffloadOptimizer,
optimizer_class=optim_cls,
offload_gradients=True,
)

optim = optim_cls(
model.parameters(),
@@ -239,24 +267,30 @@
)

lr_schedule = CosineSchedule(args.lr, len(dloader) * args.n_epochs)
grad_scaler = torch.amp.GradScaler("cuda", enabled=args.amp == "fp16")
grad_scaler = torch.amp.GradScaler(_DEVICE, enabled=args.amp == "fp16")
log_interval = 10
t0 = time.perf_counter()

step = 0
for epoch_idx in range(args.n_epochs):
model.train()
pbar = tqdm(dloader, dynamic_ncols=True, desc=f"Epoch {epoch_idx + 1}/{args.n_epochs}")
pbar = tqdm(
dloader, dynamic_ncols=True, desc=f"Epoch {epoch_idx + 1}/{args.n_epochs}"
)

with torch.profiler.profile() if args.profile else nullcontext() as prof:
for batch in pbar:
if args.full_bf16:
batch["image"] = batch["image"].bfloat16()
if args.channels_last:
batch["image"] = batch["image"].to(memory_format=torch.channels_last)
batch["image"] = batch["image"].to(
memory_format=torch.channels_last
)

with get_amp_ctx(args.amp):
loss = F.cross_entropy(model(batch["image"].cuda()), batch["label"].cuda())
with get_amp_ctx(args.amp, _DEVICE):
loss = F.cross_entropy(
model(batch["image"].to(_DEVICE)), batch["label"].to(_DEVICE)
)

if args.optim_cpu_offload == "deepspeed":
model.backward(loss)
Expand All @@ -275,7 +309,9 @@ def evaluate_model(model, args):
log_dict = dict(loss=loss.item(), lr=optim.param_groups[0]["lr"])
if step > 0:
t1 = time.perf_counter()
log_dict["imgs_per_second"] = args.batch_size * log_interval / (t1 - t0)
log_dict["imgs_per_second"] = (
args.batch_size * log_interval / (t1 - t0)
)
t0 = t1
logger.log(log_dict, step=step)

@@ -296,9 +332,11 @@ def evaluate_model(model, args):

else:
val_acc = evaluate_model(model, args)
print(f"Epoch {epoch_idx + 1}/{args.n_epochs}: val_acc={val_acc.item() * 100:.2f}")
print(
f"Epoch {epoch_idx + 1}/{args.n_epochs}: val_acc={val_acc.item() * 100:.2f}"
)
logger.log(dict(val_acc=val_acc), step=step)

peak_mem = torch.cuda.max_memory_allocated() / 1e9
peak_mem = getattr(torch, _DEVICE).max_memory_allocated() / 1e9
print(f"Max memory used: {peak_mem:.02f} GB")
logger.log(dict(max_memory_allocated=peak_mem))
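Taken together, the benchmark now keys every formerly CUDA-specific call on the detected backend instead of hard-coding `"cuda"`. A condensed sketch of that pattern, assuming a recent PyTorch build where `torch.xpu` mirrors the `torch.cuda` memory APIs and `torchao.utils.get_available_devices` behaves as used above:

```python
import torch
from torchao.utils import get_available_devices

# Last entry is the most capable backend found (e.g. "cuda" or "xpu").
_DEVICE = get_available_devices()[-1]
assert _DEVICE in ["cuda", "xpu"]

model = torch.nn.Linear(32, 32).to(_DEVICE)
grad_scaler = torch.amp.GradScaler(_DEVICE, enabled=False)  # fp16 AMP would enable it

with torch.autocast(_DEVICE, dtype=torch.bfloat16):
    loss = model(torch.randn(4, 32, device=_DEVICE)).sum()

# Peak memory comes from the matching backend module (torch.cuda / torch.xpu).
peak_mem = getattr(torch, _DEVICE).max_memory_allocated() / 1e9
print(f"Max memory used: {peak_mem:.02f} GB")
```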
49 changes: 36 additions & 13 deletions test/prototype/test_low_bit_optim.py
@@ -26,6 +26,7 @@
from torchao.prototype.low_bit_optim.subclass_8bit import OptimState8bit
from torchao.prototype.low_bit_optim.subclass_fp8 import OptimStateFp8
from torchao.utils import (
get_available_devices,
TORCH_VERSION_AT_LEAST_2_4,
TORCH_VERSION_AT_LEAST_2_5,
TORCH_VERSION_AT_LEAST_2_6,
@@ -42,7 +43,7 @@
lpmm = None


_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
_DEVICES = get_available_devices()


class TestQuantize(TestCase):
@@ -94,7 +95,9 @@ def test_bf16_stochastic_round(self, device, compile):
x = torch.rand(32, device=device) * 100
x_rep = x.view(-1, 1).repeat(1, 100_000)

func = torch.compile(_fp32_to_bf16_sr, fullgraph=True, dynamic=False, disable=not compile)
func = torch.compile(
_fp32_to_bf16_sr, fullgraph=True, dynamic=False, disable=not compile
)
x_rep_bf16 = func(x_rep)
assert x_rep_bf16.dtype is torch.bfloat16

@@ -169,8 +172,13 @@ def test_subclass_slice(self, subclass, shape, device):
tensor = subclass.zeros(shape, device=device)
offset = shape[0] // 2

torch.testing.assert_close(tensor.dequantize()[:offset], tensor[:offset].dequantize())
torch.testing.assert_close(tensor.dequantize()[offset:offset*2], tensor[offset:offset*2].dequantize())
torch.testing.assert_close(
tensor.dequantize()[:offset], tensor[:offset].dequantize()
)
torch.testing.assert_close(
tensor.dequantize()[offset : offset * 2],
tensor[offset : offset * 2].dequantize(),
)

@pytest.mark.skipif(bnb is None, reason="bitsandbytes is not available")
@pytest.mark.skipif(
@@ -188,7 +196,9 @@ def test_optim_8bit_correctness(self, optim_name):
block_size = 256 if Version(bnb.__version__) >= Version("0.44.0") else 2048

optim1 = getattr(bnb.optim, optim_name)(model1.parameters())
optim2 = getattr(low_bit_optim, optim_name)(model2.parameters(), block_size=block_size)
optim2 = getattr(low_bit_optim, optim_name)(
model2.parameters(), block_size=block_size
)

for _ in range(2):
x = torch.randn(4, 32, device=device)
@@ -244,11 +254,12 @@ def test_optim_4bit_correctness(self, optim_name):
torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5)

@pytest.mark.skipif(
not torch.cuda.is_available(), reason="optim CPU offload requires CUDA"
not torch.cuda.is_available() and not torch.xpu.is_available(),
reason="optim CPU offload requires CUDA or XPU",
)
@parametrize("offload_grad,grad_accum", [(False, 1), (False, 2), (True, 1)])
def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
device = "cuda"
device = _DEVICES[-1]
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128))
model1.to(device)

@@ -279,13 +290,16 @@ def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
torch.testing.assert_close(p2, p1)

@pytest.mark.skipif(
not torch.cuda.is_available(), reason="optim CPU offload requires CUDA"
not torch.cuda.is_available() and not torch.xpu.is_available(),
reason="optim CPU offload requires CUDA or XPU",
)
def test_optim_cpu_offload_save_load(self):
device = "cuda"
device = _DEVICES[-1]
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128))
model1.to(device)
optim1 = low_bit_optim.CPUOffloadOptimizer(model1.parameters(), torch.optim.AdamW)
optim1 = low_bit_optim.CPUOffloadOptimizer(
model1.parameters(), torch.optim.AdamW
)

for _ in range(2):
x = torch.randn(4, 32, device=device)
@@ -300,7 +314,9 @@ def test_optim_cpu_offload_save_load(self):

# resume training
model2 = copy.deepcopy(model1)
optim2 = low_bit_optim.CPUOffloadOptimizer(model2.parameters(), torch.optim.AdamW)
optim2 = low_bit_optim.CPUOffloadOptimizer(
model2.parameters(), torch.optim.AdamW
)
optim2.load_state_dict(state_dict)

for _ in range(2):
@@ -384,7 +400,11 @@ def _test_fsdp2(self, optim_cls):
import torch.utils._pytree as pytree
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.tensor import DTensor
from torch.testing._internal.distributed._tensor.common_dtensor import ModelArgs, Transformer, TransformerBlock
from torch.testing._internal.distributed._tensor.common_dtensor import (
ModelArgs,
Transformer,
TransformerBlock,
)

batch_size = 3
vocab_size = 1024
@@ -457,7 +477,10 @@ def _test_fsdp2(self, optim_cls):

subclasses = (OptimState4bit, OptimState8bit, OptimStateFp8)

for v1, v2 in zip(pytree.tree_iter(resumed_fsdp_optim.state_dict()), pytree.tree_iter(fsdp_optim.state_dict())):
for v1, v2 in zip(
pytree.tree_iter(resumed_fsdp_optim.state_dict()),
pytree.tree_iter(fsdp_optim.state_dict()),
):
assert v1.__class__ == v2.__class__, (v1.__class__, v2.__class__)
if isinstance(v1, DTensor):
v1 = v1.to_local()
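The CPU-offload tests above now pick the last available accelerator and skip unless CUDA or XPU is present. A hedged sketch of the correctness check they perform — run the same model with a plain optimizer and with `CPUOffloadOptimizer`, then compare parameters (the layer sizes here are illustrative, not the test's exact configuration):

```python
import copy
import torch
import torch.nn as nn
from torchao.prototype import low_bit_optim
from torchao.utils import get_available_devices

device = get_available_devices()[-1]  # "cuda" or "xpu" on a machine with an accelerator

model_ref = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
model_off = copy.deepcopy(model_ref)

optim_ref = torch.optim.AdamW(model_ref.parameters())
optim_off = low_bit_optim.CPUOffloadOptimizer(model_off.parameters(), torch.optim.AdamW)

for _ in range(2):
    x = torch.randn(4, 32, device=device)
    for model, optim in ((model_ref, optim_ref), (model_off, optim_off)):
        model(x).sum().backward()
        optim.step()
        optim.zero_grad()

# The offloaded run should track the reference run closely.
for p_ref, p_off in zip(model_ref.parameters(), model_off.parameters()):
    torch.testing.assert_close(p_off, p_ref)
```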
4 changes: 2 additions & 2 deletions torchao/prototype/low_bit_optim/README.md
@@ -80,7 +80,7 @@ All of our low-bit optimizers mentioned above also support `bf16_stochastic_roun

## Optimizer CPU offload

This folder also implements optimizer CPU offload (i.e. ZeRO-Offload) for single GPU training. Only CUDA is supported. For multi-GPU training, you can use FSDP's built-in CPU offload.
This folder also implements optimizer CPU offload (i.e. ZeRO-Offload) for single GPU training. Only CUDA and XPU are supported. For multi-GPU training, you can use FSDP's built-in CPU offload.

```python
import torch
@@ -97,7 +97,7 @@ optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, offload_gradi

This will reduce GPU memory usage by optimizer state size, and additionally gradient size if `offload_gradients=True`. `CPUOffloadOptimizer` can wrap any base optimizer.

For saving and loading `CPUOffloadOptimizer`, it is important that you load model's weights BEFORE creating the optimizer, since we create a CPU copy of the parameters inside `CPUOffloadOptimizer.__init__()`. (TODO: we might want to have a method to synchronize CUDA and CPU params in either direction CPU->CUDA and CUDA->CPU, in case they are out of sync.)
For saving and loading `CPUOffloadOptimizer`, it is important that you load model's weights BEFORE creating the optimizer, since we create a CPU copy of the parameters inside `CPUOffloadOptimizer.__init__()`. (TODO: we might want to have a method to synchronize GPU and CPU params in either direction CPU->GPU and GPU->CPU, in case they are out of sync.)

```python
ckpt = torch.load("checkpoint.pth")
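The collapsed snippet at the end of the README illustrates that load-before-construct order. A minimal sketch of the intent — the checkpoint keys and model here are illustrative, not the repo's exact example:

```python
import torch
import torch.nn as nn
from torchao.prototype.low_bit_optim import CPUOffloadOptimizer

device = "xpu" if torch.xpu.is_available() else "cuda"
model = nn.Linear(32, 32).to(device)

ckpt = torch.load("checkpoint.pth")    # hypothetical checkpoint layout
model.load_state_dict(ckpt["model"])   # restore weights FIRST ...
optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW)
optim.load_state_dict(ckpt["optim"])   # ... then build the optimizer and restore its state
```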