# test_galore_quant.py — 91 lines (69 loc), 3.16 KB
import itertools

import pytest

# Skip the entire module when triton is not installed; importing the torchao
# galore kernels below would otherwise raise and fail CI on triton-less runners.
try:
    import triton  # noqa: F401 — imported only to probe availability
except ImportError:
    pytest.skip("triton is not installed", allow_module_level=True)

import bitsandbytes.functional as F
import torch
from torchao.prototype.galore.kernels import (
    triton_dequant_blockwise,
    triton_quantize_blockwise,
)

# Fixed seed so the random test tensors are reproducible across runs.
SEED = 0
torch.manual_seed(SEED)

# Parameter grid for the tests below: matrix dimensions, dtypes, whether the
# dynamic quantization map is signed, and the quantization block size.
DIM1 = [64, 1024, 4096]
DIM2 = [1024, 2048, 4096]
SIGNS = [True, False]
DTYPES = [torch.float32]  # , torch.float16]
BLOCKSIZE = [2048]

# Full cross-product consumed by @pytest.mark.parametrize in each test.
TEST_CONFIGS = list(itertools.product(DIM1, DIM2, DTYPES, SIGNS, BLOCKSIZE))
@pytest.mark.skip("skipping for now, see comments below")
@pytest.mark.parametrize(
    "dim1,dim2,dtype,signed,blocksize",
    TEST_CONFIGS,
)
def test_galore_quantize_blockwise(dim1, dim2, dtype, signed, blocksize):
    """Compare the triton blockwise quantizer against the bitsandbytes reference.

    When the quantized codes do not match exactly, the mismatching values are
    expected to sit half-way between two adjacent code buckets, where either
    assignment is valid; in that case we only require the two implementations'
    chosen codes to be equidistant from the normalized input value.
    """
    g = torch.randn(dim1, dim2, device="cuda", dtype=dtype) * 0.01

    qmap = F.create_dynamic_map(signed).to(g.device)
    ref_bnb, qstate = F.quantize_blockwise(g, code=qmap, blocksize=blocksize)
    # Reference normalized values: each block of `g` scaled by its absmax.
    bnb_norm = (g.reshape(-1, blocksize) / qstate.absmax[:, None]).reshape(g.shape)

    tt_q, tt_norm, tt_absmax = triton_quantize_blockwise(
        g, qmap, group_size=blocksize, return_normalized=True
    )
    tt_check = torch.allclose(ref_bnb, tt_q)

    # see notes.md under `prototype.galore.kernels` for an explanation of the following conditions
    if not tt_check:
        print(
            f"Failed quantization check for {dim1} x {dim2}, {dtype}, signed {signed}"
        )
        print(f"Absmax: {(qstate.absmax - tt_absmax).abs().max()}")
        print(f"Norm diff: {(bnb_norm - tt_norm).abs().max()}")

        # Both tensors already live on the GPU, so no device transfer is needed
        # (the original `.to("cuda")` here was a no-op).
        idx_diff = ref_bnb != tt_q
        print(f"Num code idx diffs: {idx_diff.sum()}")
        max_idx_diff = (ref_bnb - tt_q).abs().max()
        print(f"Max code idx diff: {max_idx_diff}")

        # The mismatched values fall half-way between two code buckets, where
        # bitsandbytes assigns to one and the triton implementation to the
        # other. Since either bucket is technically valid, we only check that
        # each value is equidistant from the two chosen codes — we don't
        # require triton to match bitsandbytes exactly.
        bnb_code = qmap[ref_bnb[idx_diff].tolist()]
        tt_code = qmap[tt_q[idx_diff].tolist()]
        bnb_dist = torch.abs(bnb_code - bnb_norm[idx_diff])
        torch_dist = torch.abs(tt_code - bnb_norm[idx_diff])
        # BUGFIX: sum the *absolute* per-element distance differences. The
        # original signed sum could cancel opposite-sign disagreements to ~0
        # (or go negative), passing `< 1e-4` even when distances truly differ.
        dist_sum = torch.sum((bnb_dist - torch_dist).abs())
        print(f"Distance sum: {dist_sum}")
        # Inside this branch tt_check is False, so the original
        # `tt_check or (not tt_check and dist_sum < 1e-4)` reduces to:
        assert dist_sum < 1e-4
@pytest.mark.parametrize(
    "dim1,dim2,dtype,signed,blocksize",
    TEST_CONFIGS,
)
def test_galore_dequant_blockwise(dim1, dim2, dtype, signed, blocksize):
    """Round-trip check: quantize with bitsandbytes, then verify the triton
    blockwise dequantizer reproduces the bitsandbytes dequantization."""
    grad = torch.randn(dim1, dim2, device="cuda", dtype=dtype) * 0.01
    code_map = F.create_dynamic_map(signed).to(grad.device)

    # Quantize once with the reference implementation.
    quantized, qstate = F.quantize_blockwise(grad, code=code_map, blocksize=blocksize)

    # Dequantize with both implementations and compare.
    expected = F.dequantize_blockwise(quantized, qstate)
    actual = triton_dequant_blockwise(
        quantized, code_map, qstate.absmax, group_size=blocksize
    )
    assert torch.allclose(actual, expected)