[Low-bit optim] Support for dcp.save() and dcp.load() #1217


Merged
merged 15 commits on Nov 9, 2024
156 changes: 98 additions & 58 deletions test/prototype/test_low_bit_optim.py
@@ -1,27 +1,33 @@
import copy
import shutil
import tempfile
from pathlib import Path

import pytest
import torch
from packaging.version import Version
from torch import nn
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import (
TestCase,
instantiate_parametrized_tests,
parametrize,
run_tests,
)

from torchao.prototype import low_bit_optim
from torchao.prototype.low_bit_optim.quant_utils import (
_fp32_to_bf16_sr,
quantize_4bit_with_qmap,
quantize_8bit_with_qmap,
)
from torchao.prototype.low_bit_optim.subclass_4bit import OptimState4bit
from torchao.prototype.low_bit_optim.subclass_8bit import OptimState8bit
from torchao.prototype.low_bit_optim.subclass_fp8 import OptimStateFp8
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_3,
TORCH_VERSION_AT_LEAST_2_4,
TORCH_VERSION_AT_LEAST_2_5,
TORCH_VERSION_AT_LEAST_2_6,
)

@@ -88,23 +94,15 @@ def test_bf16_stochastic_round(self, device, compile):
x = torch.rand(32, device=device) * 100
x_rep = x.view(-1, 1).repeat(1, 100_000)

func = torch.compile(_fp32_to_bf16_sr, fullgraph=True, dynamic=False, disable=not compile)
x_rep_bf16 = func(x_rep)
assert x_rep_bf16.dtype is torch.bfloat16

# must cast BF16 tensor back to FP32 so that .mean() is accurate
torch.testing.assert_close(x_rep_bf16.float().mean(1), x, atol=3e-5, rtol=3e-5)


class TestOptim(TestCase):
@parametrize(
"optim_name",
["Adam8bit", "AdamW8bit", "Adam4bit", "AdamW4bit", "AdamFp8", "AdamWFp8"],
@@ -151,29 +149,46 @@ def test_optim_smoke(self, optim_name, dtype, device):
for p1, p2 in zip(model.parameters(), model2.parameters()):
torch.testing.assert_close(p2, p1)

# aten.slice is required for dcp.load() when the world size changes, i.e. re-sharding.
# however, it's cumbersome to test this directly, since we would need to run the distributed
# test twice with different world sizes and persist the checkpoint across the two runs.
# thus, we only test for the required op. note that future implementations of dcp.load()
# may use other ops.
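# (illustration only, not exercised by this test: going from world size 2 to 4, rank 0's
# local shard shrinks from roughly tensor[: N // 2] to tensor[: N // 4], and dcp.load()
# produces the new shard by slicing the saved one -- hence the aten.slice requirement.)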
@parametrize("subclass", [OptimState4bit, OptimState8bit, OptimStateFp8])
@parametrize("shape", [(4096,), (256, 256)])
@parametrize("device", _DEVICES)
def test_subclass_slice(self, subclass, shape, device):
if subclass == OptimStateFp8:
if device == "cpu" and len(shape) > 1 and not TORCH_VERSION_AT_LEAST_2_5:
pytest.skip("fill_cpu not implemented for Float8_e4m3fn for torch<2.5")
if device == "cuda" and not TORCH_VERSION_AT_LEAST_2_4:
pytest.skip("FP8 CUDA requires PyTorch >= 2.4")
if device == "cuda" and torch.cuda.get_device_capability() < (8, 9):
pytest.skip("FP8 CUDA requires compute capability >= 8.9")

tensor = subclass.zeros(shape, device=device)
offset = shape[0] // 2

torch.testing.assert_close(tensor.dequantize()[:offset], tensor[:offset].dequantize())
torch.testing.assert_close(tensor.dequantize()[offset:offset*2], tensor[offset:offset*2].dequantize())

@pytest.mark.skipif(bnb is None, reason="bitsandbytes is not available")
@pytest.mark.skipif(
not torch.cuda.is_available(),
reason="bitsandbytes 8-bit Adam only works for CUDA",
)
@parametrize("optim_name", ["Adam8bit", "AdamW8bit"])
def test_optim_8bit_correctness(self, optim_name):
device = "cuda"
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128))
model1.to(device)
model2 = copy.deepcopy(model1)

# https://github.com/bitsandbytes-foundation/bitsandbytes/releases/tag/v0.44.0
block_size = 256 if Version(bnb.__version__) >= Version("0.44.0") else 2048
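# (context: this means bnb >= 0.44.0 quantizes optimizer state in blocks of 256 elements
# instead of 2048, so we pass the matching block_size to our optimizer to keep the two
# implementations numerically comparable.)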

optim1 = getattr(bnb.optim, optim_name)(model1.parameters())
optim2 = getattr(low_bit_optim, optim_name)(model2.parameters(), block_size=block_size)

for _ in range(2):
x = torch.randn(4, 32, device=device)
@@ -196,15 +211,11 @@ def test_optim_8bit_correctness(self, optim_name):
@pytest.mark.skipif(
not torch.cuda.is_available(), reason="lpmm 4-bit Adam only works for CUDA"
)
@parametrize("optim_name", ["Adam4bit", "AdamW4bit"])
def test_optim_4bit_correctness(self, optim_name):
device = "cuda"
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128))
model1.to(device)
model2 = copy.deepcopy(model1)

# lpmm doesn't have Adam. use AdamW with no weight decay instead.
@@ -238,12 +249,11 @@ def test_optim_4bit_correctness(self, optim_name):
@parametrize("offload_grad,grad_accum", [(False, 1), (False, 2), (True, 1)])
def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
device = "cuda"
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128))
model1.to(device)

# make sure it can work in the presence of non-trainable params
model1[0].requires_grad_(False)
model2 = copy.deepcopy(model1)

optim1 = torch.optim.AdamW(model1.parameters())
@@ -273,12 +283,9 @@ def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
)
def test_optim_cpu_offload_save_load(self):
device = "cuda"
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128))
model1.to(device)
optim1 = low_bit_optim.CPUOffloadOptimizer(model1.parameters(), torch.optim.AdamW)

for _ in range(2):
x = torch.randn(4, 32, device=device)
@@ -293,9 +300,7 @@ def test_optim_cpu_offload_save_load(self):

# resume training
model2 = copy.deepcopy(model1)
optim2 = low_bit_optim.CPUOffloadOptimizer(model2.parameters(), torch.optim.AdamW)
optim2.load_state_dict(state_dict)

for _ in range(2):
@@ -315,16 +320,17 @@ def test_optim_cpu_offload_save_load(self):
def test_optim_bf16_stochastic_round_correctness(self):
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(2024)
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128))
model1.to(device)
model2 = copy.deepcopy(model1).bfloat16()

# small LR so that weight update is small
# when bf16_stochastic_round=False, the test will fail after 1 iteration
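# (intuition, illustration only: with lr=1e-5 most updates are smaller than one BF16 ULP
# and vanish under round-to-nearest, but stochastic rounding keeps the expected update
# unbiased, so the BF16 model stays close to the FP32 reference.)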
optim1 = torch.optim.AdamW(model1.parameters(), lr=1e-5)
optim2 = low_bit_optim._AdamW(
model2.parameters(),
lr=1e-5,
bf16_stochastic_round=True,
)

# overfit on this sample
@@ -350,10 +356,13 @@ def test_optim_bf16_stochastic_round_correctness(self):
)


_FSDP_WORLD_SIZE = 2


class TestFSDP2(FSDPTest):
@property
def world_size(self) -> int:
return _FSDP_WORLD_SIZE

@pytest.mark.skipif(
not TORCH_VERSION_AT_LEAST_2_6, reason="PyTorch>=2.6 is required."
@@ -370,12 +379,12 @@ def test_fsdp2(self):
)

def _test_fsdp2(self, optim_cls):
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.utils._pytree as pytree
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.tensor import DTensor
from torch.testing._internal.distributed._tensor.common_dtensor import ModelArgs, Transformer, TransformerBlock

batch_size = 3
vocab_size = 1024
@@ -413,9 +422,7 @@ def _test_fsdp2(self, optim_cls):
base_loss.backward()
for param in base_model.parameters():
if param.grad is not None:
dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
base_optim.step()
self.assertEqual(fsdp_loss, base_loss)

@@ -428,6 +435,39 @@ def _test_fsdp2(self, optim_cls):

self.assertEqual(base_exp_avg.dequantize(), full_fsdp_exp_avg.dequantize())

# test for compatibility with dcp.save() and .load()
checkpoint_id = f"_fsdp_low_bit_optim_{optim_cls.__name__}"
if Path(checkpoint_id).exists():
shutil.rmtree(checkpoint_id)
dcp.save(fsdp_optim.state_dict(), checkpoint_id=checkpoint_id)

# normally we would want to use dcp.state_dict.get_optimizer_state_dict() to initialize optim states.
# however, it currently does not respect the tensor-ness of the LR (pytorch/pytorch#139575),
# so we have to initialize the optim state manually here.
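# sketch of the intended path once that issue is fixed (not exercised here):
#   from torch.distributed.checkpoint.state_dict import get_optimizer_state_dict, set_optimizer_state_dict
#   optim_state = get_optimizer_state_dict(fsdp_model, resumed_fsdp_optim)
#   dcp.load(optim_state, checkpoint_id=checkpoint_id)
#   set_optimizer_state_dict(fsdp_model, resumed_fsdp_optim, optim_state_dict=optim_state)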
resumed_fsdp_optim = optim_cls(fsdp_model.parameters(), lr=1e-2)
for p in fsdp_model.parameters():
p.grad = torch.zeros_like(p)

# this will change model weights due to weight decay, but since we don't use the model anymore, it's fine.
resumed_fsdp_optim.step()

dcp.load(resumed_fsdp_optim.state_dict(), checkpoint_id=checkpoint_id)
if dist.get_rank() == 0:
shutil.rmtree(checkpoint_id)

subclasses = (OptimState4bit, OptimState8bit, OptimStateFp8)

for v1, v2 in zip(pytree.tree_iter(resumed_fsdp_optim.state_dict()), pytree.tree_iter(fsdp_optim.state_dict())):
assert v1.__class__ == v2.__class__, (v1.__class__, v2.__class__)
if isinstance(v1, DTensor):
v1 = v1.to_local()
v2 = v2.to_local()
assert v1.__class__ == v2.__class__, (v1.__class__, v2.__class__)
if isinstance(v1, subclasses):
v1 = v1.dequantize()
v2 = v2.dequantize()
self.assertEqual(v1, v2)


instantiate_parametrized_tests(TestQuantize)
instantiate_parametrized_tests(TestOptim)
51 changes: 50 additions & 1 deletion torchao/prototype/low_bit_optim/subclass_4bit.py
@@ -177,13 +177,15 @@ def _(func, types, args, kwargs):
)


@OptimState4bit.implements(
[
# required by DTensor.full_tensor()
c10d_functional.all_gather_into_tensor.default,
_c10d_functional.all_gather_into_tensor.default,
c10d_functional.wait_tensor.default,
_c10d_functional.wait_tensor.default,
# required by torch.distributed.checkpoint.save
aten.detach.default,
]
)
def _(func, types, args, kwargs):
@@ -201,6 +203,53 @@ def _(func, types, args, kwargs):
return OptimState4bit(codes, scale, x.qmap.clone(), x.signed, shape)


# required by torch.distributed.checkpoint.save
# note that we don't actually implement pin memory for this tensor subclass
# (pin_memory argument is ignored in aten._to_copy)
@OptimState4bit.implements(aten.is_pinned.default)
def _(func, types, args, kwargs):
return (
args[0].codes.is_pinned()
and args[0].scale.is_pinned()
and args[0].qmap.is_pinned()
)


# required by torch.distributed.checkpoint.load when world size changes i.e. re-sharding
@OptimState4bit.implements(aten.slice.Tensor)
def _(func, types, args, kwargs):
x, dim, start, end = args[:4]
step = args[4] if len(args) > 4 else 1

# input validation
if dim != 0:
raise ValueError("Only support aten.slice along the first dim")
if step != 1:
raise ValueError("Only support aten.slice with step=1")

block_size = x.block_size
stride = math.prod(x.shape[1:])

# for 1 increment in x along the first dim,
# (flattened) scale will increment by stride / block_size
if (start * stride) % block_size != 0 or (end * stride) % block_size != 0:
raise ValueError(
f"Invalid start or end for shape={x.shape} and block_size={block_size}. "
f"Make sure start and end align with block boundary. "
f"Received start={start}, end={end}."
)
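# worked example (illustration only, assuming the default block_size=128):
# shape=(256, 256) gives stride=256, so each row advances the flattened scale by
# 256 // 128 = 2 entries; start=128, end=256 then maps to scale[256:512] and, for the
# packed 4-bit codes below, codes[128 * 256 // 2 : 256 * 256 // 2].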

# note that for 4-bit, we store .codes as flattened buffer
# divide by 2 since we store 2x 4-bit in 1x uint8
codes = x.codes[start * stride // 2 : end * stride // 2]
scale = x.scale[start * stride // block_size : end * stride // block_size]

# adjust the first dim
shape = (x.shape[0] * codes.numel() // x.codes.numel(),) + x.shape[1:]

return OptimState4bit(codes, scale, x.qmap.clone(), x.signed, shape)


if TORCH_VERSION_AT_LEAST_2_5:
from torch.serialization import add_safe_globals
