Skip to content

Commit d08f60a

Browse files
committed
Update bitsandbytes import
1 parent 9bb1b23 commit d08f60a

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

test/quantization/test_galore_quant.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
except ImportError:
99
pytest.skip("triton is not installed", allow_module_level=True)
1010

11-
import bitsandbytes.functional as F
11+
from bitsandbytes.functional import create_dynamic_map, quantize_blockwise, dequantize_blockwise
1212
import torch
1313

1414
from torchao.prototype.galore.kernels import (
@@ -36,9 +36,9 @@
3636
def test_galore_quantize_blockwise(dim1, dim2, dtype, signed, blocksize):
3737
g = torch.randn(dim1, dim2, device="cuda", dtype=dtype) * 0.01
3838

39-
qmap = F.create_dynamic_map(signed).to(g.device)
39+
qmap = create_dynamic_map(signed).to(g.device)
4040

41-
ref_bnb, qstate = F.quantize_blockwise(g, code=qmap, blocksize=blocksize)
41+
ref_bnb, qstate = quantize_blockwise(g, code=qmap, blocksize=blocksize)
4242
bnb_norm = (g.reshape(-1, blocksize) / qstate.absmax[:, None]).reshape(g.shape)
4343

4444
tt_q, tt_norm, tt_absmax = triton_quantize_blockwise(
@@ -82,10 +82,10 @@ def test_galore_quantize_blockwise(dim1, dim2, dtype, signed, blocksize):
8282
def test_galore_dequant_blockwise(dim1, dim2, dtype, signed, blocksize):
8383
g = torch.randn(dim1, dim2, device="cuda", dtype=dtype) * 0.01
8484

85-
qmap = F.create_dynamic_map(signed).to(g.device)
85+
qmap = create_dynamic_map(signed).to(g.device)
8686

87-
q, qstate = F.quantize_blockwise(g, code=qmap, blocksize=blocksize)
87+
q, qstate = quantize_blockwise(g, code=qmap, blocksize=blocksize)
8888

89-
dq_ref = F.dequantize_blockwise(q, qstate)
89+
dq_ref = dequantize_blockwise(q, qstate)
9090
dq = triton_dequant_blockwise(q, qmap, qstate.absmax, group_size=blocksize)
9191
assert torch.allclose(dq, dq_ref)

0 commit comments

Comments (0)