|
| 1 | +import logging |
| 2 | +import unittest |
| 3 | + |
| 4 | +import torch |
| 5 | +from torch import nn |
| 6 | +from torch.testing._internal.common_utils import TestCase |
| 7 | +from torchao.dtypes.nf4tensor import linear_nf4, NF4Tensor |
| 8 | +import torch.nn.functional as F |
| 9 | + |
| 10 | + |
| 11 | +bnb_available = False |
| 12 | + |
| 13 | +try: |
| 14 | + import bitsandbytes as bnb |
| 15 | + |
| 16 | + bnb_available = True |
| 17 | +except ImportError: |
| 18 | + pass |
| 19 | + |
| 20 | +logging.basicConfig( |
| 21 | + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO |
| 22 | +) |
| 23 | + |
| 24 | + |
| 25 | +def _build_input_weight(embed_dim: int, device: torch.device): |
| 26 | + torch.manual_seed(0) |
| 27 | + input_weight = torch.empty( |
| 28 | + embed_dim, embed_dim, device=device, dtype=torch.bfloat16 |
| 29 | + ) |
| 30 | + input_weight.normal_(0, 1) |
| 31 | + return input_weight |
| 32 | + |
| 33 | +def _build_bnb_linear(input_weight, device): |
| 34 | + assert bnb_available, "Needs bitsandbytes support" |
| 35 | + param = bnb.nn.Params4bit( |
| 36 | + input_weight, requires_grad=False, quant_type="nf4" |
| 37 | + ).cuda(device) |
| 38 | + bnb_linear = bnb.nn.LinearNF4( |
| 39 | + input_weight.size(0), input_weight.size(1), bias=False |
| 40 | + ) |
| 41 | + bnb_linear.weight = param |
| 42 | + bnb_linear.to(device) |
| 43 | + return bnb_linear |
| 44 | + |
| 45 | + |
| 46 | +class TestNF4Linear(TestCase): |
| 47 | + |
| 48 | + def test_register_nf4_as_param(self): |
| 49 | + nf4_tensor = NF4Tensor.from_tensor( |
| 50 | + inpt_tensor=torch.randn(512, 512, dtype=torch.bfloat16) |
| 51 | + ) |
| 52 | + |
| 53 | + # Would raise if nn.Parameter registration fails, such as no detach() |
| 54 | + # impl when calling __torch_dispatch__ |
| 55 | + param = torch.nn.Parameter(nf4_tensor, requires_grad=False) |
| 56 | + assert not param.requires_grad |
| 57 | + |
| 58 | + def test_output_bf16(self): |
| 59 | + # Test to ensure W4 A16 produces A16 |
| 60 | + inp = torch.randn(2, 512, dtype=torch.bfloat16, requires_grad=True) |
| 61 | + nf4_tensor = NF4Tensor.from_tensor( |
| 62 | + inpt_tensor=torch.randn(512, 512, dtype=torch.bfloat16) |
| 63 | + ) |
| 64 | + out = linear_nf4(input=inp, weight=nf4_tensor) |
| 65 | + assert out.dtype == torch.bfloat16 |
| 66 | + |
| 67 | + def test_backward_bf16(self): |
| 68 | + # Test to ensure backward pass gives activation a bf16 gradient and no gradient |
| 69 | + # to the linear's weight, as it is frozen. |
| 70 | + nf4_tensor = NF4Tensor.from_tensor( |
| 71 | + inpt_tensor=torch.randn(512, 512, dtype=torch.bfloat16) |
| 72 | + ) |
| 73 | + inp = torch.randn(2, 512, dtype=torch.bfloat16, requires_grad=True) |
| 74 | + linear_nf4(inp, nf4_tensor).sum().backward() |
| 75 | + assert inp.grad is not None and inp.grad.dtype == torch.bfloat16 |
| 76 | + assert nf4_tensor.grad is None |
| 77 | + |
| 78 | + @unittest.skipIf(not bnb_available, "Need bnb availble") |
| 79 | + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") |
| 80 | + def test_reconstruction_qlora_vs_bnb(self): |
| 81 | + # From https://github.com/drisspg/transformer_nuggets/blob/f05afad68ad9086d342268f46a7f344617a02314/test/test_qlora.py#L65C1-L81C47 |
| 82 | + torch.manual_seed(0) |
| 83 | + device = "cuda" |
| 84 | + embed_dim = 512 |
| 85 | + input_weight = _build_input_weight(embed_dim, device) |
| 86 | + nf4_weight = NF4Tensor.from_tensor(input_weight) |
| 87 | + bnb_linear = _build_bnb_linear(input_weight, device) |
| 88 | + bnb_reconstruction = bnb_linear( |
| 89 | + torch.eye(embed_dim, embed_dim, dtype=torch.bfloat16, device=device) |
| 90 | + ) |
| 91 | + bnb_diff = (bnb_reconstruction.T - input_weight).abs().max() |
| 92 | + nugs_diff = (nf4_weight.get_original_weight() - input_weight).abs().max() |
| 93 | + # Since we are subtle different we assume that we both reconstruct with |
| 94 | + # a similar precision |
| 95 | + assert bnb_diff < 1 |
| 96 | + assert nugs_diff < 1 |
| 97 | + assert (nugs_diff - bnb_diff).abs() < 2e-1 |
| 98 | + |
| 99 | + @unittest.skipIf(not bnb_available, "Need bnb availble") |
| 100 | + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") |
| 101 | + def test_nf4_bnb_linear(self): |
| 102 | + """ |
| 103 | + This test ensures that nf4_linear is "no worse" than BNB by ensuring the |
| 104 | + error compared to a bf16 linear is not more than BNB's implementation. |
| 105 | + """ |
| 106 | + torch.manual_seed(0) |
| 107 | + dim = 512 |
| 108 | + device = "cuda" |
| 109 | + input_weight = _build_input_weight(dim, device) |
| 110 | + nf4_weight = NF4Tensor.from_tensor(input_weight) |
| 111 | + bnb_linear = _build_bnb_linear(input_weight, device) |
| 112 | + |
| 113 | + inp = torch.randn(2, 512, dtype=torch.bfloat16, device="cuda") |
| 114 | + |
| 115 | + out_nf4 = linear_nf4(inp, nf4_weight).sum() |
| 116 | + out_bnb = bnb_linear(inp).sum() |
| 117 | + out_ref = F.linear(inp, input_weight).sum() |
| 118 | + |
| 119 | + err_bnb = (out_bnb - out_ref).abs().max() |
| 120 | + err_native = (out_nf4 - out_ref).abs().max() |
| 121 | + assert err_native < 0.5 * dim |
| 122 | + assert err_bnb < 0.5 * dim |
| 123 | + |
| 124 | + |
| 125 | +if __name__ == "__main__": |
| 126 | + unittest.main() |
0 commit comments