Commit 1cbabc2

Add Nf4Linear and tests
1 parent 969038f commit 1cbabc2

4 files changed (+213, −3 lines)


test/modules/test_nf4_linear.py

Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@
import logging
import unittest

import torch
from torch import nn
from torch.testing._internal.common_utils import TestCase
from torchao.modules import FrozenNF4Linear
from torchao.dtypes.nf4tensor import NF4Tensor

bnb_available = False

try:
    import bitsandbytes as bnb
    bnb_available = True
except ImportError:
    pass

logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)

class TestNF4Linear(TestCase):
    """
    Test torchao.modules.FrozenNF4Linear
    """
    def test_bias_unsupported(self):
        with self.assertRaisesRegex(RuntimeError, "does not currently support biases"):
            _ = FrozenNF4Linear(1, 1, bias=True)

    def test_non_bf16_unsupported(self):
        with self.assertRaisesRegex(RuntimeError, "only supported with bf16"):
            _ = FrozenNF4Linear(1, 1)

    def test_frozen_nf4_linear(self):
        nf4_linear = FrozenNF4Linear(512, 512, device='cpu', dtype=torch.bfloat16)
        self.assertTrue(isinstance(nf4_linear.weight, NF4Tensor))
        self.assertEqual(torch.bfloat16, nf4_linear.weight.get_original_weight().dtype)

    def test_output_bf16(self):
        # Test to ensure W4 A16 produces A16
        nf4_linear = FrozenNF4Linear(512, 512, device='cpu', dtype=torch.bfloat16)
        inp = torch.randn(2, 512, dtype=torch.bfloat16, requires_grad=True)
        out = nf4_linear(inp)
        assert out.dtype == torch.bfloat16

    def test_backward_bf16(self):
        # Test to ensure backward pass gives activation a bf16 gradient and no gradient
        # to the linear's weight, as it is frozen.
        nf4_linear = FrozenNF4Linear(512, 512, device='cpu', dtype=torch.bfloat16)
        inp = torch.randn(2, 512, dtype=torch.bfloat16, requires_grad=True)
        nf4_linear(inp).sum().backward()
        assert inp.grad is not None and inp.grad.dtype == torch.bfloat16
        assert nf4_linear.weight.grad is None


    def _build_bnb_linear(self, input_weight):
        assert bnb_available, "Needs bitsandbytes support"
        param = bnb.nn.Params4bit(input_weight, requires_grad=False, quant_type="nf4")
        bnb_linear = bnb.nn.LinearNF4(input_weight.size(0), input_weight.size(1), bias=False)
        bnb_linear.weight = param
        bnb_linear.cuda()
        return bnb_linear

    @unittest.skipIf(not bnb_available, "Need bnb available")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_fwd_bnb_parity(self):
        """
        Ensures fwd output logits are at parity w/ bnb
        """
        nf4_linear = FrozenNF4Linear(512, 512, device='cuda', dtype=torch.bfloat16)
        orig_weight = nf4_linear.weight.get_original_weight().clone().detach()
        bnb_nf4_linear = self._build_bnb_linear(input_weight=orig_weight)

        inp = torch.randn(2, 512, dtype=torch.bfloat16, device='cuda', requires_grad=True)
        with torch.no_grad():
            inp_bnb = inp.clone()
            inp_bnb.requires_grad_(True)
        out_native = nf4_linear(inp).sum()
        out_bnb = bnb_nf4_linear(inp_bnb).sum()
        self.assertEqual(out_native, out_bnb)


    @unittest.skipIf(not bnb_available, "Need bnb available")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_nf4_reconstruction_vs_bnb(self):
        """
        Ensures a BNB NF4 linear and our FrozenNF4Linear have low error when
        reconstructing the respective original weights.
        """
        dim = 512
        nf4_linear = FrozenNF4Linear(dim, dim, device='cuda', dtype=torch.bfloat16)
        orig_weight = nf4_linear.weight.get_original_weight().clone().detach()
        bnb_nf4_linear = self._build_bnb_linear(input_weight=orig_weight)

        # From https://github.com/drisspg/transformer_nuggets/blob/f05afad68ad9086d342268f46a7f344617a02314/test/test_qlora.py#L65
        bnb_reconstruction = bnb_nf4_linear(
            torch.eye(dim, dim, dtype=torch.bfloat16, device='cuda')
        )
        # Ensure nf4_linear and bnb reconstructions are close to each other.
        diff = (bnb_reconstruction.T - nf4_linear.weight.get_original_weight()).abs().max()
        assert diff.item() < 1e-2

    @unittest.skipIf(not bnb_available, "Need bnb available")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_nf4_bnb_linear(self):
        """
        This test ensures that nf4_linear is "no worse" than BNB by ensuring the
        error compared to a bf16 linear is not more than BNB's implementation.
        """
        dim = 512
        nf4_linear = FrozenNF4Linear(dim, dim, device='cuda', dtype=torch.bfloat16)
        orig_weight = nf4_linear.weight.get_original_weight().clone().detach()
        bnb_nf4_linear = self._build_bnb_linear(input_weight=orig_weight)
        bf16_linear = torch.nn.Linear(dim, dim, device='cuda', dtype=torch.bfloat16)

        inp = torch.randn(2, 512, dtype=torch.bfloat16, device='cuda')

        out_nf4 = nf4_linear(inp).sum()
        out_bnb = bnb_nf4_linear(inp).sum()
        out_ref = bf16_linear(inp).sum()

        err_bnb = (out_bnb - out_ref).abs().max()
        err_native = (out_nf4 - out_ref).abs().max()
        assert err_native.item() <= err_bnb


if __name__ == "__main__":
    unittest.main()

torchao/dtypes/nf4tensor.py

Lines changed: 48 additions & 3 deletions
@@ -2,9 +2,33 @@
 from typing import Dict, Tuple

 import torch
+from torch import Tensor
 import torch.nn.functional as F


+aten = torch.ops.aten
+c10d_functional = torch.ops.c10d_functional
+
+from typing import Any
+NF4_OPS_TABLE: Dict[Any, Any] = {}
+
+
+def implements(aten_ops):
+    """Use this decorator to implement a function for an aten op in __torch_dispatch__"""
+
+    def decorator(func):
+        for op in aten_ops:
+            NF4_OPS_TABLE[op] = func
+        return func
+
+    return decorator
+
+
+@implements([torch.ops.aten.detach.default, torch.ops.aten.detach])
+def noop_detach(func, *args, **kwargs):
+    return args[0][0]
+
+
 @dataclass
 class SubclassTensorArgs:
     original_shape: torch.Size
@@ -110,7 +134,7 @@ def from_tensor(
         assert inpt_tensor.dtype == torch.bfloat16
         assert (
             inpt_tensor.numel() % block_size == 0
-        ), "Input tensor must be divisible by block size"
+        ), f"Input tensor must be divisible by block size, got {inpt_tensor.numel()} and {block_size}"
         assert inpt_tensor.dtype == torch.bfloat16, "Input tensor must be bfloat16"
         assert inpt_tensor.is_contiguous, "Input tensor must be contiguous!"
         # I think I want do this
@@ -204,7 +228,7 @@ def double_quantize_scalers(
         # Second round of quantization
         assert (
             scalers_1.numel() % scaler_block_size == 0
-        ), "Number of scalers must be divisible by scaler block size"
+        ), f"Number of scalers must be divisible by scaler block size, got {scalers_1.numel()} scaler_block_size {scaler_block_size}"
         n_scaler_blocks = scalers_1.numel() // scaler_block_size
         scaler_blocks = scalers_1.view(n_scaler_blocks, scaler_block_size)

@@ -397,12 +421,33 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
         """TODO we are not supporting torch dispatch at the moment
         instead we have created a Autograd.Function to handle the linear
         """
-        raise NotImplementedError("NF4Tensor does not support torch dispatch")
+        # All ops in the NF4_OPS_TABLE expect NF4 Tensors as inputs
+        # And don't support mixed tensor subclasses. This will trigger the handler for
+        # the next type in the dispatch list
+        def allowed_subclasses(type):
+            return (
+                issubclass(cls, type)
+                or issubclass(torch._subclasses.fake_tensor.FakeTensor, type)
+                or issubclass(torch._subclasses.functional_tensor.FunctionalTensor, type)
+            )
+
+        if not all(allowed_subclasses(t) for t in types):
+            return NotImplemented  # let the next type in the dispatch list handle this
+
+        if func in NF4_OPS_TABLE:
+            return NF4_OPS_TABLE[func](func, args, kwargs)
+        raise NotImplementedError(
+            f"NF4Tensor dispatch: attempting to run {func}, this is not supported"
+        )

     # Do not force the Float8Tensor type on the returned tensor
     __torch_function__ = torch._C._disabled_torch_function_impl


+def to_nf4(tensor: Tensor, device: torch.device, dtype: torch.dtype):
+    # Cast/move the input to the target dtype and device, then quantize it to NF4.
+    return NF4Tensor.from_tensor(tensor.to(device=device, dtype=dtype))
+
 class LinearNF4(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input: torch.Tensor, weight: NF4Tensor):

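For readers of this diff, the NF4_OPS_TABLE plus the implements decorator is the extension point for __torch_dispatch__: any aten op registered in the table is routed to its handler, and everything else raises. Below is a hypothetical sketch (not part of this commit) of how a further op could be registered, mirroring the noop_detach handler above; the clone handler and its behavior are illustrative assumptions only.

import torch
from torchao.dtypes.nf4tensor import implements

# Hypothetical example: treat aten.clone as a no-op for the frozen, immutable NF4 weight.
# Handlers are invoked as NF4_OPS_TABLE[func](func, args, kwargs), so *args below captures
# the op's argument tuple as args[0]; args[0][0] is the NF4Tensor itself.
@implements([torch.ops.aten.clone.default])
def clone_nf4(func, *args, **kwargs):
    return args[0][0]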
torchao/modules/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
from .nf4_linear import FrozenNF4Linear

torchao/modules/nf4_linear.py

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
import torch

import torch.nn as nn
from torch import Tensor
from torchao.dtypes.nf4tensor import NF4Tensor, linear_nf4


class FrozenNF4Linear(nn.Linear):
    """
    A linear layer similar to ``torch.nn.Linear`` but uses a quantized
    NF4Tensor as its weight. This class also freezes its ``weight`` parameter
    and is meant to be used as the base Linear layer for modeling
    use cases such as QLoRA where base model parameters are frozen.

    NOTE: biases are currently not supported.
    """
    def __init__(self, in_dim: int, out_dim: int, bias: bool = False, device=None, dtype=None, **kwargs):
        if bias:
            raise RuntimeError("FrozenNF4Linear does not currently support biases!")

        super().__init__(in_dim, out_dim, bias=False, device=device, dtype=dtype, **kwargs)
        self.weight.requires_grad_(False)
        if self.weight.dtype != torch.bfloat16:
            raise RuntimeError("FrozenNF4Linear is only supported with bf16 parameter currently")

        self.nf4_weight = NF4Tensor.from_tensor(self.weight.data).to(device).to(dtype)
        # re-register self.weight as a frozen Parameter holding the quantized NF4 weight
        del self.weight
        self.weight = torch.nn.Parameter(self.nf4_weight, requires_grad=False)

        # TODO: likely need to handle state_dict save & load via hooks to properly manage
        # types.

    def forward(self, input: Tensor) -> Tensor:
        return linear_nf4(input=input, weight=self.weight)

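As a usage sketch (not part of the commit): FrozenNF4Linear is intended as a drop-in frozen base layer, for example under a trainable LoRA-style adapter as its docstring describes. The adapter class, names, and shapes below are illustrative assumptions, built only from the behavior exercised by the tests above.

import torch
import torch.nn as nn
from torchao.modules import FrozenNF4Linear

class LoRAOverFrozenNF4(nn.Module):
    """Illustrative only: a tiny LoRA-style adapter over a frozen NF4 base linear."""
    def __init__(self, dim: int, rank: int = 8):
        super().__init__()
        # Base weight is quantized to NF4 at construction and frozen (requires_grad=False).
        self.base = FrozenNF4Linear(dim, dim, device='cpu', dtype=torch.bfloat16)
        # Only the low-rank adapter weights receive gradients.
        self.lora_a = nn.Linear(dim, rank, bias=False, dtype=torch.bfloat16)
        self.lora_b = nn.Linear(rank, dim, bias=False, dtype=torch.bfloat16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.lora_b(self.lora_a(x))

# W4A16: bf16 activations in and out, with a 4-bit frozen base weight underneath.
model = LoRAOverFrozenNF4(512)
x = torch.randn(2, 512, dtype=torch.bfloat16, requires_grad=True)
model(x).sum().backward()  # grads reach x and the adapter, but not the frozen NF4 base weight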