
Commit c9b397d

Merge pull request #40 from pytorch-labs/nf4_linear
Add noop detach for Nf4 tensor and enhance nf4 testing
2 parents aa94639 + 39fde91 commit c9b397d

File tree

2 files changed: +170 additions, -4 deletions


test/modules/test_nf4_linear.py

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
import logging
import unittest

import torch
from torch import nn
from torch.testing._internal.common_utils import TestCase
from torchao.dtypes.nf4tensor import linear_nf4, NF4Tensor
import torch.nn.functional as F


bnb_available = False

try:
    import bitsandbytes as bnb

    bnb_available = True
except ImportError:
    pass

logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)


def _build_input_weight(embed_dim: int, device: torch.device):
    torch.manual_seed(0)
    input_weight = torch.empty(
        embed_dim, embed_dim, device=device, dtype=torch.bfloat16
    )
    input_weight.normal_(0, 1)
    return input_weight

def _build_bnb_linear(input_weight, device):
    assert bnb_available, "Needs bitsandbytes support"
    param = bnb.nn.Params4bit(
        input_weight, requires_grad=False, quant_type="nf4"
    ).cuda(device)
    bnb_linear = bnb.nn.LinearNF4(
        input_weight.size(0), input_weight.size(1), bias=False
    )
    bnb_linear.weight = param
    bnb_linear.to(device)
    return bnb_linear


class TestNF4Linear(TestCase):

    def test_register_nf4_as_param(self):
        nf4_tensor = NF4Tensor.from_tensor(
            inpt_tensor=torch.randn(512, 512, dtype=torch.bfloat16)
        )

        # Would raise if nn.Parameter registration failed, e.g. if there were no
        # detach() impl when calling __torch_dispatch__
        param = torch.nn.Parameter(nf4_tensor, requires_grad=False)
        assert not param.requires_grad

    def test_output_bf16(self):
        # Test to ensure W4 A16 produces A16
        inp = torch.randn(2, 512, dtype=torch.bfloat16, requires_grad=True)
        nf4_tensor = NF4Tensor.from_tensor(
            inpt_tensor=torch.randn(512, 512, dtype=torch.bfloat16)
        )
        out = linear_nf4(input=inp, weight=nf4_tensor)
        assert out.dtype == torch.bfloat16

    def test_backward_bf16(self):
        # Test to ensure the backward pass gives the activation a bf16 gradient and
        # no gradient to the linear's weight, as it is frozen.
        nf4_tensor = NF4Tensor.from_tensor(
            inpt_tensor=torch.randn(512, 512, dtype=torch.bfloat16)
        )
        inp = torch.randn(2, 512, dtype=torch.bfloat16, requires_grad=True)
        linear_nf4(inp, nf4_tensor).sum().backward()
        assert inp.grad is not None and inp.grad.dtype == torch.bfloat16
        assert nf4_tensor.grad is None

    @unittest.skipIf(not bnb_available, "Need bnb available")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_reconstruction_qlora_vs_bnb(self):
        # From https://github.com/drisspg/transformer_nuggets/blob/f05afad68ad9086d342268f46a7f344617a02314/test/test_qlora.py#L65C1-L81C47
        torch.manual_seed(0)
        device = "cuda"
        embed_dim = 512
        input_weight = _build_input_weight(embed_dim, device)
        nf4_weight = NF4Tensor.from_tensor(input_weight)
        bnb_linear = _build_bnb_linear(input_weight, device)
        bnb_reconstruction = bnb_linear(
            torch.eye(embed_dim, embed_dim, dtype=torch.bfloat16, device=device)
        )
        bnb_diff = (bnb_reconstruction.T - input_weight).abs().max()
        nugs_diff = (nf4_weight.get_original_weight() - input_weight).abs().max()
        # Since the two implementations differ subtly, we only assume that both
        # reconstruct with similar precision
        assert bnb_diff < 1
        assert nugs_diff < 1
        assert (nugs_diff - bnb_diff).abs() < 2e-1

    @unittest.skipIf(not bnb_available, "Need bnb available")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_nf4_bnb_linear(self):
        """
        This test ensures that nf4_linear is "no worse" than BNB by ensuring the
        error compared to a bf16 linear is not more than BNB's implementation.
        """
        torch.manual_seed(0)
        dim = 512
        device = "cuda"
        input_weight = _build_input_weight(dim, device)
        nf4_weight = NF4Tensor.from_tensor(input_weight)
        bnb_linear = _build_bnb_linear(input_weight, device)

        inp = torch.randn(2, 512, dtype=torch.bfloat16, device="cuda")

        out_nf4 = linear_nf4(inp, nf4_weight).sum()
        out_bnb = bnb_linear(inp).sum()
        out_ref = F.linear(inp, input_weight).sum()

        err_bnb = (out_bnb - out_ref).abs().max()
        err_native = (out_nf4 - out_ref).abs().max()
        assert err_native < 0.5 * dim
        assert err_bnb < 0.5 * dim


if __name__ == "__main__":
    unittest.main()
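
Taken together, the new tests pin down what the detach no-op enables: an NF4Tensor can be registered as a frozen nn.Parameter and driven through linear_nf4 with bf16 activations. As a rough usage sketch, not part of the commit (the FrozenNF4Linear wrapper below is hypothetical; only NF4Tensor.from_tensor, linear_nf4, and the nn.Parameter registration come from the tests above):

# Hypothetical wrapper, for illustration only; it mirrors what the tests exercise.
import torch
from torch import nn
from torchao.dtypes.nf4tensor import linear_nf4, NF4Tensor


class FrozenNF4Linear(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        weight = torch.randn(dim, dim, dtype=torch.bfloat16)
        # nn.Parameter construction calls detach() on the subclass, which is exactly
        # what the new noop_detach handler makes possible.
        self.weight = nn.Parameter(NF4Tensor.from_tensor(weight), requires_grad=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # bf16 activations against the 4-bit (NF4) quantized weight
        return linear_nf4(input=x, weight=self.weight)


out = FrozenNF4Linear(512)(torch.randn(2, 512, dtype=torch.bfloat16))
assert out.dtype == torch.bfloat16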

torchao/dtypes/nf4tensor.py

Lines changed: 44 additions & 4 deletions
@@ -2,9 +2,33 @@
 from typing import Dict, Tuple

 import torch
+from torch import Tensor
 import torch.nn.functional as F


+aten = torch.ops.aten
+c10d_functional = torch.ops.c10d_functional
+
+from typing import Any
+NF4_OPS_TABLE: Dict[Any, Any] = {}
+
+
+
+def implements(aten_ops):
+    """Use this decorator to implement a function for an aten op in __torch_dispatch__"""
+
+    def decorator(func):
+        for op in aten_ops:
+            NF4_OPS_TABLE[op] = func
+        return func
+
+    return decorator
+
+@implements([torch.ops.aten.detach.default, torch.ops.aten.detach])
+def noop_detach(func, *args, **kwargs):
+    return args[0][0]
+
+
 @dataclass
 class SubclassTensorArgs:
     original_shape: torch.Size
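
NF4_OPS_TABLE together with the implements decorator forms a small dispatch registry: handlers are keyed by aten overload and invoked by __torch_dispatch__ as handler(func, args, kwargs) (see the dispatch hunk further down), so inside a handler args is the pair (original_args, original_kwargs), which is why noop_detach returns args[0][0]. A hypothetical example of registering another op the same way (not part of this commit):

# Hypothetical handler, illustration only. The signature mirrors noop_detach:
# the registry calls handler(func, args, kwargs), so `args` here is
# (original_args, original_kwargs) and original_args[0] is the NF4Tensor.
@implements([aten.is_same_size.default])
def nf4_is_same_size(func, *args, **kwargs):
    original_args, _original_kwargs = args
    a, b = original_args
    return a.shape == b.shape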
@@ -110,7 +134,7 @@ def from_tensor(
         assert inpt_tensor.dtype == torch.bfloat16
         assert (
             inpt_tensor.numel() % block_size == 0
-        ), "Input tensor must be divisible by block size"
+        ), f"Input tensor must be divisible by block size, got {inpt_tensor.numel()} and {block_size}"
         assert inpt_tensor.dtype == torch.bfloat16, "Input tensor must be bfloat16"
         assert inpt_tensor.is_contiguous, "Input tensor must be contiguous!"
         # I think I want do this
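
For concreteness, the arithmetic behind this divisibility check, and the scaler-block check in the next hunk, for the 512x512 bf16 weights used in the new tests. The block_size and scaler_block_size values below are assumptions, since the defaults are not visible in this diff:

# Worked example of the two divisibility checks (values are assumed, not from the diff)
numel = 512 * 512                 # 262,144 elements in the test weight
block_size = 64                   # assumed per-block quantization size
scaler_block_size = 256           # assumed block size for double quantization
n_scalers = numel // block_size   # 4,096 per-block scalers
assert numel % block_size == 0              # first check: 262,144 % 64 == 0
assert n_scalers % scaler_block_size == 0   # second check: 4,096 % 256 == 0 -> 16 scaler blocks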
@@ -204,7 +228,7 @@ def double_quantize_scalers(
         # Second round of quantization
         assert (
             scalers_1.numel() % scaler_block_size == 0
-        ), "Number of scalers must be divisible by scaler block size"
+        ), f"Number of scalers must be divisible by scaler block size, got {scalers_1.numel()} and scaler_block_size {scaler_block_size}"
         n_scaler_blocks = scalers_1.numel() // scaler_block_size
         scaler_blocks = scalers_1.view(n_scaler_blocks, scaler_block_size)

@@ -397,12 +421,28 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
         """TODO we are not supporting torch dispatch at the moment
         instead we have created a Autograd.Function to handle the linear
         """
-        raise NotImplementedError("NF4Tensor does not support torch dispatch")
+        # All ops in the NF4_OPS_TABLE expect NF4 Tensors as inputs,
+        # and don't support mixed tensor subclasses. This will trigger the handler
+        # for the next type in the dispatch list.
+        def allowed_subclasses(type):
+            return (
+                issubclass(cls, type)
+                or issubclass(torch._subclasses.fake_tensor.FakeTensor, type)
+                or issubclass(torch._subclasses.functional_tensor.FunctionalTensor, type)
+            )
+
+        if not all(allowed_subclasses(t) for t in types):
+            return NotImplemented  # up to the next subclass in `types` to handle
+
+        if func in NF4_OPS_TABLE:
+            return NF4_OPS_TABLE[func](func, args, kwargs)
+        raise NotImplementedError(
+            f"NF4Tensor dispatch: attempting to run {func}, this is not supported"
+        )

     # Do not force the Float8Tensor type on the returned tensor
     __torch_function__ = torch._C._disabled_torch_function_impl

-
 class LinearNF4(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input: torch.Tensor, weight: NF4Tensor):
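
With this change, __torch_dispatch__ first bails out with NotImplemented when a call mixes NF4Tensor with an unrecognized tensor subclass (letting the next subclass handle it), then looks the aten overload up in NF4_OPS_TABLE, and only raises for ops with no handler. A small behavior sketch, not part of the commit (the multiplication is simply an op assumed to have no table entry):

# Behavior sketch of the new dispatch path (assumptions noted inline)
nf4 = NF4Tensor.from_tensor(torch.randn(512, 512, dtype=torch.bfloat16))

# detach is registered, so nn.Parameter construction (which calls detach) succeeds:
param = torch.nn.Parameter(nf4, requires_grad=False)

# Ops without a table entry still fail, but now with a descriptive message:
try:
    nf4 * 2
except NotImplementedError as err:
    print(err)  # e.g. "NF4Tensor dispatch: attempting to run aten.mul.Tensor, this is not supported"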
