 except ImportError:
     pytest.skip("triton is not installed", allow_module_level=True)

-import bitsandbytes.functional as F
+from bitsandbytes.functional import create_dynamic_map, quantize_blockwise, dequantize_blockwise
 import torch

 from torchao.prototype.galore.kernels import (
...
 def test_galore_quantize_blockwise(dim1, dim2, dtype, signed, blocksize):
     g = torch.randn(dim1, dim2, device="cuda", dtype=dtype) * 0.01

-    qmap = F.create_dynamic_map(signed).to(g.device)
+    qmap = create_dynamic_map(signed).to(g.device)

-    ref_bnb, qstate = F.quantize_blockwise(g, code=qmap, blocksize=blocksize)
+    ref_bnb, qstate = quantize_blockwise(g, code=qmap, blocksize=blocksize)
     bnb_norm = (g.reshape(-1, blocksize) / qstate.absmax[:, None]).reshape(g.shape)

     tt_q, tt_norm, tt_absmax = triton_quantize_blockwise(
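For readers unfamiliar with the quantization this test exercises, here is a minimal pure-PyTorch sketch of the per-block absmax normalization and dynamic code-map lookup that both `quantize_blockwise` (bitsandbytes) and `triton_quantize_blockwise` perform. The helper name `blockwise_quantize_ref` is hypothetical and only illustrates the semantics being compared; it is not part of either library.

```python
import torch

def blockwise_quantize_ref(g: torch.Tensor, qmap: torch.Tensor, blocksize: int):
    # Assumes g.numel() is divisible by blocksize, as in the test parametrization.
    blocks = g.reshape(-1, blocksize)
    absmax = blocks.abs().max(dim=1).values    # one scale per block
    norm = blocks / absmax[:, None]            # normalize each block into [-1, 1]
    # Nearest-neighbor lookup into the 256-entry dynamic code map -> uint8 indices.
    idx = (norm[..., None] - qmap[None, None, :]).abs().argmin(dim=-1)
    return idx.to(torch.uint8).reshape(g.shape), norm.reshape(g.shape), absmax
```

The `bnb_norm` reference computed in the test above is exactly this normalization step, built from the `absmax` that bitsandbytes returns in `qstate`.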
@@ -82,10 +82,10 @@ def test_galore_quantize_blockwise(dim1, dim2, dtype, signed, blocksize):
 def test_galore_dequant_blockwise(dim1, dim2, dtype, signed, blocksize):
     g = torch.randn(dim1, dim2, device="cuda", dtype=dtype) * 0.01

-    qmap = F.create_dynamic_map(signed).to(g.device)
+    qmap = create_dynamic_map(signed).to(g.device)

-    q, qstate = F.quantize_blockwise(g, code=qmap, blocksize=blocksize)
+    q, qstate = quantize_blockwise(g, code=qmap, blocksize=blocksize)

-    dq_ref = F.dequantize_blockwise(q, qstate)
+    dq_ref = dequantize_blockwise(q, qstate)
     dq = triton_dequant_blockwise(q, qmap, qstate.absmax, group_size=blocksize)
     assert torch.allclose(dq, dq_ref)
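The inverse direction that `test_galore_dequant_blockwise` exercises, again as a rough pure-PyTorch sketch rather than the actual kernel: uint8 codes index back into the code map and are rescaled by the per-block absmax. `blockwise_dequantize_ref` is a hypothetical helper name.

```python
import torch

def blockwise_dequantize_ref(q: torch.Tensor, qmap: torch.Tensor,
                             absmax: torch.Tensor, blocksize: int) -> torch.Tensor:
    # Map uint8 codes back to normalized values, then undo the per-block scaling.
    codes = qmap[q.reshape(-1, blocksize).long()]
    return (codes * absmax[:, None]).reshape(q.shape)
```

Note that the assertion at the end of the diff compares the Triton dequantization against the bitsandbytes reference (`dq` vs `dq_ref`), not against the unquantized input, so it checks kernel parity rather than quantization error.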