Commit 12ac498

Add FP8 Adam (#482)

* update benchmark
* add rank1 option to lpmm
* add comma
* update readme
* remove unwanted file
* update
* add Adam fp8
* add FP8 AdamW and test
* update readme
* change reason to xfail, since 2.2 also has float8
* add guard for FP8
* update readme
* fix guard

1 parent 56d46a2 commit 12ac498

7 files changed, +173 -8 lines changed

benchmarks/benchmark_low_bit_adam.py

Lines changed: 4 additions & 3 deletions

@@ -28,15 +28,16 @@
 from torchvision.transforms import v2
 from tqdm import tqdm
 
-from torchao.prototype.low_bit_optim import Adam4bit, Adam8bit
+from torchao.prototype import low_bit_optim
 
 # lpmm doesn't have Adam, only AdamW
 OPTIM_MAP = dict(
     Adam=torch.optim.Adam,
     Adam8bitBnb=bnb.optim.Adam8bit,
-    Adam8bitAo=Adam8bit,
+    Adam8bitAo=low_bit_optim.Adam8bit,
+    AdamFp8Ao=low_bit_optim.AdamFp8,
     Adam4bitLpmm=partial(lpmm.optim.AdamW, weight_decay=0, fused=True),
-    Adam4bitAo=Adam4bit,
+    Adam4bitAo=low_bit_optim.Adam4bit,
     Adam4bitRank1Lpmm=partial(lpmm.optim.AdamW, weight_decay=0, qconfig=argparse.Namespace(scale_type="rank1")),
 )
 
test/prototype/test_low_bit_optim.py

Lines changed: 16 additions & 0 deletions

@@ -139,6 +139,22 @@ def test_optim_4bit_correctness(self, optim_name):
         for p1, p2 in zip(model1.parameters(), model2.parameters()):
             torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5)
 
+    @pytest.mark.xfail(not TORCH_VERSION_AFTER_2_3, reason="torch.compile() fails for PyTorch < 2.3")
+    @parametrize("optim_name", ["AdamFp8", "AdamWFp8"])
+    @parametrize("device", _DEVICES)
+    def test_optim_fp8_smoke(self, optim_name, device):
+        if device == "cuda" and torch.cuda.get_device_capability() < (8, 9):
+            pytest.skip("FP8 requires compute capability >= 8.9")
+
+        model = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
+        optim = getattr(low_bit_optim, optim_name)(model.parameters())
+
+        x = torch.randn(4, 32, device=device)
+        loss = model(x).sum()
+        loss.backward()
+        optim.step()
+        optim.zero_grad()
+
 
 instantiate_parametrized_tests(TestQuantize)
 instantiate_parametrized_tests(TestOptim)
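As an aside (not part of the diff), the CUDA guard in the new smoke test can also be checked up front in a script; compute capability 8.9 corresponds to Ada-generation GPUs (Hopper is 9.0), and CPU runs are not gated.

```python
# Sketch of the guard used in test_optim_fp8_smoke above: CUDA runs are skipped
# on devices with compute capability below (8, 9); CPU runs are not skipped.
import torch

def fp8_optim_supported_on_cuda() -> bool:
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)

print("FP8 optimizer state supported on this GPU:", fp8_optim_supported_on_cuda())
```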

torchao/prototype/low_bit_optim/README.md

Lines changed: 5 additions & 3 deletions

@@ -4,6 +4,7 @@ This folder implements:
 
 - 8-bit optimizers as outlined in https://arxiv.org/abs/2110.02861
 - 4-bit optimizers as outlined in https://arxiv.org/abs/2309.01507
+- FP8 optimizers using the native `torch.float8_e4m3fn` dtype (experimental)
 
 The implementation is fully done in Python (with tensor subclass) and relies on `torch.compile()` to generate efficient fused kernel.
 
@@ -18,12 +19,12 @@ model = ...
 optim = Adam8bit(model.parameters())
 ```
 
-To use 4-bit Adam, replace the above with `Adam4bit`. You can also change quantization block size by passing `block_size=value` to the optimizer. By default, block size is 2048 for 8-bit optimizers, and 128 for 4-bit optimizers.
+To use 4-bit Adam, replace the above with `Adam4bit`. Similarly for `AdamFp8`. You can also change quantization block size by passing `block_size=value` to the optimizer. By default, block size is 2048 for 8-bit and FP8 optimizers, and 128 for 4-bit optimizers.
 
-**Other optimizers**: AdamW is also available as `AdamW8bit` and `AdamW4bit`. Other optimizers can be added based on demand.
+**Other optimizers**: AdamW is also available as `AdamW8bit`, `AdamW4bit`, and `AdamWFp8`. Other optimizers can be added based on demand.
 
 NOTE:
-- The low-bit optimizers require PyTorch >= 2.3
+- The low-bit optimizers require PyTorch >= 2.3. FP8 optimizers require CUDA compute capability >= 8.9.
 - For 4-bit optimizers, we don't implement rank-1 normalization for quantizing 2nd moment as originally done in the paper.
 - **Known issue**: When learning rate is updated every step (e.g. using cosine learning rate scheduler), training speed is slower. This is because we have to convert learning rate to a CUDA tensor (which incurs expensive memory transfer cost), since torch.compile() will treat a Python float as a constant and trigger recompile whenever the value is changed.
 
@@ -38,6 +39,7 @@ Adam impl | max memory (GB) | time taken for 2nd epoch | accuracy
 PyTorch | 12.94 | 8m 18s | 91.14
 bnb 8-bit | 8.31 | 6m 50s | 90.67
 ao 8-bit | 8.32 | 9m 04s | 90.71
+ao FP8 E4M3 | 8.32 | 6m 38s | 91.08
 lpmm 4-bit | 7.72 | 5m 59s | 89.97
 ao 4-bit | 7.72 | 7m 00s | 89.94
 lpmm 4-bit (*) | 7.73 | 11m 10s | 89.71
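For orientation, here is a minimal usage sketch based on the README text above (not part of the diff). `AdamFp8`, its import path, and the `block_size` keyword come from this commit; the model, data, and device handling are made up for the example.

```python
# Illustrative sketch: swap a regular optimizer for the new FP8 variant.
import torch
from torch import nn
from torchao.prototype.low_bit_optim import AdamFp8

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)

# block_size is optional; 2048 is the default for FP8 (and 8-bit) optimizer state
optim = AdamFp8(model.parameters(), lr=1e-3, block_size=2048)

x = torch.randn(4, 32, device=device)
model(x).sum().backward()
optim.step()
optim.zero_grad()
```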
torchao/prototype/low_bit_optim/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -1,2 +1,2 @@
-from .adam import Adam8bit, Adam4bit
-from .adamw import AdamW8bit, AdamW4bit
+from .adam import Adam8bit, Adam4bit, AdamFp8
+from .adamw import AdamW8bit, AdamW4bit, AdamWFp8

torchao/prototype/low_bit_optim/adam.py

Lines changed: 20 additions & 0 deletions

@@ -6,6 +6,7 @@
 
 from .subclass_8bit import maybe_new_8bit_zero_buffer
 from .subclass_4bit import maybe_new_4bit_zero_buffer
+from .subclass_fp8 import maybe_new_fp8_zero_buffer
 
 
 class _Adam(Optimizer):
@@ -155,3 +156,22 @@ def __init__(
         super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size)
 
     _new_buffer = staticmethod(maybe_new_4bit_zero_buffer)
+
+
+class AdamFp8(_Adam):
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        betas=(0.9, 0.999),
+        eps=1e-8,
+        weight_decay=0,
+        amsgrad=False,
+        *,
+        block_size=2048
+    ) -> None:
+        super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size)
+
+    @staticmethod
+    def _new_buffer(p: Tensor, signed: bool, block_size: int):
+        return maybe_new_fp8_zero_buffer(p, block_size)

torchao/prototype/low_bit_optim/adamw.py

Lines changed: 20 additions & 0 deletions

@@ -6,6 +6,7 @@
 
 from .subclass_8bit import maybe_new_8bit_zero_buffer
 from .subclass_4bit import maybe_new_4bit_zero_buffer
+from .subclass_fp8 import maybe_new_fp8_zero_buffer
 
 
 class _AdamW(Optimizer):
@@ -154,3 +155,22 @@ def __init__(
         super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size)
 
     _new_buffer = staticmethod(maybe_new_4bit_zero_buffer)
+
+
+class AdamWFp8(_AdamW):
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        betas=(0.9, 0.999),
+        eps=1e-8,
+        weight_decay=1e-2,
+        amsgrad=False,
+        *,
+        block_size=2048
+    ) -> None:
+        super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size)
+
+    @staticmethod
+    def _new_buffer(p: Tensor, signed: bool, block_size: int):
+        return maybe_new_fp8_zero_buffer(p, block_size)
torchao/prototype/low_bit_optim/subclass_fp8.py

Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
+import torch
+from torch import Tensor
+from torchao.dtypes.utils import _implements, _ATEN_OP_OR_TORCH_FN_TABLE
+
+
+aten = torch.ops.aten
+DTYPE = torch.float8_e4m3fn
+
+
+def quantize_fp8(input: Tensor, block_size: int):
+    input = input.view(-1, block_size)
+    scale = input.abs().amax(-1).clip(1e-12) / torch.finfo(DTYPE).max
+    input = input / scale.view(-1, 1)
+    codes = input.to(DTYPE).view(-1)
+    return codes, scale
+
+
+class OptimStateFp8(Tensor):
+    implements = classmethod(_implements)
+    tensor_attrs = ["codes", "scale"]
+
+    @staticmethod
+    def __new__(cls, codes: Tensor, scale: Tensor):
+        return Tensor._make_wrapper_subclass(
+            cls,
+            codes.shape,
+            device=codes.device,
+            requires_grad=False,
+        )
+
+    def __init__(self, codes: Tensor, scale: Tensor):
+        assert codes.dtype is DTYPE
+        self.codes = codes
+        self.scale = scale
+
+    @property
+    def block_size(self):
+        return self.codes.numel() // self.scale.numel()
+
+    def __tensor_flatten__(self):
+        return self.tensor_attrs, []
+
+    @classmethod
+    def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None):
+        return cls(*[tensor_data_dict[name] for name in cls.tensor_attrs], *tensor_attributes)
+
+    def dequantize(self, output_dtype=None):
+        float_data = self.codes.float()
+        float_data = float_data.view(-1, self.block_size) * self.scale.view(-1, 1)
+
+        dtype = output_dtype or torch.get_default_dtype()
+        return float_data.view(self.codes.shape).to(dtype)
+
+    @classmethod
+    def zeros(cls, shape, block_size: int = 2048, device=None):
+        codes = torch.zeros(shape, dtype=DTYPE, device=device)
+        scale = torch.zeros(codes.numel() // block_size, device=device)
+        return cls(codes, scale)
+
+    def __repr__(self):
+        return (
+            f"{self.__class__.__name__}(block_size={self.block_size}, "
+            f"shape={tuple(self.shape)}, device={self.device}, requires_grad={self.requires_grad})"
+        )
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]:
+            return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](func, *args, **kwargs)
+
+        raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run {func}, this is not supported")
+
+
+@OptimStateFp8.implements(aten.copy_.default)
+def _(func, *args, **kwargs):
+    dst = args[0]
+    src = args[1]
+
+    if isinstance(dst, OptimStateFp8) and isinstance(src, OptimStateFp8):
+        assert dst.block_size == src.block_size
+        dst.codes.copy_(src.codes)
+        dst.scale.copy_(src.scale)
+
+    elif isinstance(dst, OptimStateFp8):
+        codes, scale = quantize_fp8(src, dst.block_size)
+        dst.codes.copy_(codes)
+        dst.scale.copy_(scale)
+
+    else:
+        dst.copy_(src.dequantize())
+
+    return dst
+
+
+@OptimStateFp8.implements(aten.lerp.Scalar)
+def _(func, *args, **kwargs):
+    args = [x.dequantize() if isinstance(x, OptimStateFp8) else x for x in args]
+    return func(*args, **kwargs)
+
+
+def maybe_new_fp8_zero_buffer(p: Tensor, block_size: int = 2048):
+    if p.numel() >= 4096 and p.numel() % block_size == 0:
+        out = OptimStateFp8.zeros(p.shape, block_size, device=p.device)
+    else:
+        out = torch.zeros_like(p)
+    return out
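To make the storage format concrete, here is a standalone round-trip sketch (not part of the commit) that mirrors quantize_fp8 and OptimStateFp8.dequantize: each block of 2048 elements gets one scale, and the scaled values are stored as torch.float8_e4m3fn codes. It assumes a PyTorch build that provides the float8_e4m3fn dtype.

```python
# Round-trip sketch mirroring quantize_fp8 / OptimStateFp8.dequantize above.
import torch

DTYPE = torch.float8_e4m3fn
block_size = 2048

x = torch.randn(4096)  # numel must be a multiple of block_size to be quantized

# quantize: per-block scale chosen so the largest element maps to the E4M3 max
blocks = x.view(-1, block_size)
scale = blocks.abs().amax(-1).clip(1e-12) / torch.finfo(DTYPE).max
codes = (blocks / scale.view(-1, 1)).to(DTYPE).view(-1)

# dequantize: upcast the FP8 codes and undo the per-block scaling
x_hat = (codes.float().view(-1, block_size) * scale.view(-1, 1)).view(x.shape)

print((x - x_hat).abs().max())  # small error, bounded by E4M3 precision per block
```

Parameters with fewer than 4096 elements, or whose size is not a multiple of the block size, skip this path entirely: maybe_new_fp8_zero_buffer falls back to a plain full-precision state buffer for them.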
