Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions bitsandbytes/backends/cuda/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor
if ROCM_WARP_SIZE_64:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
else:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])

torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}")

Expand Down Expand Up @@ -272,7 +272,7 @@ def _dequantize_blockwise_impl(
if ROCM_WARP_SIZE_64:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
else:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])

torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
torch._check(
Expand Down Expand Up @@ -306,7 +306,7 @@ def _(
if ROCM_WARP_SIZE_64:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
else:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])

torch._check(quant_type in ["fp4", "nf4"])
torch._check(
Expand Down Expand Up @@ -388,7 +388,7 @@ def _dequantize_4bit_impl(
if ROCM_WARP_SIZE_64:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
else:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])

torch._check(quant_type in ["fp4", "nf4"])
torch._check(
Expand Down
4 changes: 2 additions & 2 deletions bitsandbytes/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -842,7 +842,7 @@ def quantize_4bit(
out (`torch.Tensor`, *optional*): A tensor to use to store the result.
blocksize (`int`, *optional*):
The size of the blocks. Defaults to 128 on ROCm and 64 otherwise.
Valid values are 64, 128, 256, 512, 1024, 2048, and 4096.
Valid values are 32, 64, 128, 256, 512, 1024, 2048, and 4096.
compress_statistics (`bool`, *optional*): Whether to additionally quantize the absmax values. Defaults to False.
quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to `fp4`.
quant_storage (`torch.dtype`, *optional*): The dtype of the tensor used to store the result. Defaults to `torch.uint8`.
Expand Down Expand Up @@ -953,7 +953,7 @@ def dequantize_4bit(
out (`torch.Tensor`, *optional*): A tensor to use to store the result.
blocksize (`int`, *optional*):
The size of the blocks. Defaults to 128 on ROCm and 64 otherwise.
Valid values are 64, 128, 256, 512, 1024, 2048, and 4096.
Valid values are 32, 64, 128, 256, 512, 1024, 2048, and 4096.
quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to `fp4`.

Raises:
Expand Down
105 changes: 102 additions & 3 deletions csrc/kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,90 @@ __global__ void kQuantizeBlockwise(
}
}

// Specialized kernel for blocksize=32 with 4-bit quantization
// Processes 2 blocks of 32 values per warp to maintain full thread utilization
// Uses 32 threads total: threads 0-15 handle block 0, threads 16-31 handle block 1
//
// Launch contract (see the blocksize == 32 branch in ops.cu):
//   gridDim.x = ceil(num_quant_blocks / 2), blockDim.x = 32 (one warp)
//   absmax[b] = max |A[i]| over the 32-value quantization block b
//   out       = packed 4-bit codes, two values per byte (even element in the
//               high nibble, odd element in the low nibble)
// Only DATA_TYPE == FP4 or NF4 is handled; `code`, `rand` and `rand_offset`
// are unused here and exist only so the signature mirrors kQuantizeBlockwise.
template <typename T, int DATA_TYPE>
__global__ void kQuantizeBlockwise32(
    float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand,
    const int rand_offset, const int n
) {
    constexpr int BLOCK_SIZE = 32;        // Size of each quantization block
    constexpr int NUM_PER_TH = 2;         // Values per thread (for 4-bit packing)
    constexpr int THREADS = 32;           // Total threads (full warp)
    constexpr int THREADS_PER_BLOCK = 16; // Threads handling each quantization block

    const int base_idx = blockIdx.x * BLOCK_SIZE * 2; // 2 blocks per CUDA block

    T vals[NUM_PER_TH];
    unsigned char qvals[NUM_PER_TH / 2]; // For 4-bit: 2 values per byte
    float local_abs_max = 0.0f;

    const int block_id = threadIdx.x / THREADS_PER_BLOCK;        // 0 for threads 0-15, 1 for threads 16-31
    const int local_thread_id = threadIdx.x % THREADS_PER_BLOCK; // Thread ID within the block (0-15)

    typedef cub::BlockLoad<T, THREADS, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
    typedef cub::BlockStore<unsigned char, THREADS, NUM_PER_TH / 2, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
    typedef cub::WarpReduce<float, 16>
        WarpReduce; // Logical warp size of 16: threads 0-15 and 16-31 reduce independently

    __shared__ typename LoadT::TempStorage loadt;
    __shared__ typename StoreChar::TempStorage storec;
    __shared__ typename WarpReduce::TempStorage warp_reduce[2]; // One per logical warp
    __shared__ float smem_absmax_value[2];                      // Caches the RECIPROCAL of each block's absmax

    const int i = base_idx + block_id * BLOCK_SIZE;
    // Use a flag instead of early return: BlockLoad/BlockStore/__syncthreads are cooperative
    // operations that require ALL 32 threads to participate
    const bool block_valid = (i < n);

    // All 32 threads participate in the load (out-of-bounds threads get 0.0f)
    __syncthreads();
    LoadT(loadt).Load(&(A[base_idx]), vals, min(BLOCK_SIZE * 2, n - base_idx), (T)0.0f);

    // Each thread computes max of its values
    local_abs_max = -FLT_MAX;
#pragma unroll NUM_PER_TH
    for (int j = 0; j < NUM_PER_TH; j++)
        local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j]));

    // Reduce within each logical warp of 16 threads independently; the reduced
    // value is only valid in lane 0 of each logical warp (local_thread_id == 0).
    // NOTE(review): CUB_REDUCTIONOP_MAX is presumably a project compatibility
    // macro for cub::Max() (CUDA cub vs. ROCm hipcub) — confirm it is defined
    // in both build paths.
    local_abs_max = WarpReduce(warp_reduce[block_id]).Reduce(local_abs_max, CUB_REDUCTIONOP_MAX);

    if (local_thread_id == 0) {
        if (block_valid) {
            // absmax[] receives the raw maximum; shared memory caches its
            // reciprocal so every thread multiplies instead of dividing below.
            // NOTE(review): an all-zero valid block yields local_abs_max == 0,
            // giving an inf scale and 0 * inf = NaN quantization inputs —
            // verify this matches the general kQuantizeBlockwise behavior.
            smem_absmax_value[block_id] = 1.0f / local_abs_max;
            absmax[blockIdx.x * 2 + block_id] = local_abs_max;
        } else {
            smem_absmax_value[block_id] = 0.0f;
        }
    }
    __syncthreads();

    // From here on local_abs_max holds the SCALE (1/absmax), not the absmax.
    local_abs_max = smem_absmax_value[block_id];

    switch (DATA_TYPE) {
    case FP4:
// NOTE(review): the loop runs NUM_PER_TH / 2 (= 1) iterations, so the unroll
// count is a harmless over-statement.
#pragma unroll NUM_PER_TH
        for (int j = 0; j < NUM_PER_TH / 2; j++) {
            // Pack two 4-bit codes per byte: even element high nibble, odd low.
            qvals[j] = dQuantizeFP4(((float)vals[2 * j]) * local_abs_max) << 4;
            qvals[j] |= dQuantizeFP4(((float)vals[2 * j + 1]) * local_abs_max);
        }
        break;
    case NF4:
#pragma unroll NUM_PER_TH
        for (int j = 0; j < NUM_PER_TH / 2; j++) {
            qvals[j] = dQuantizeNF4(((float)vals[2 * j]) * local_abs_max) << 4;
            qvals[j] |= dQuantizeNF4(((float)vals[2 * j + 1]) * local_abs_max);
        }
        break;
    }

    // All 32 threads participate in the store (valid_items limits the actual writes).
    // out holds one byte per two input values, hence the /2 on both offset and count.
    __syncthreads();
    StoreChar(storec).Store(&(out[base_idx / 2]), qvals, min((BLOCK_SIZE * 2 + 1) / 2, (n - base_idx + 1) / 2));
}

template <typename T, int TILE_SIZE, int THREADS, int NUM_PER_TH, int DATA_TYPE>
__global__ void
kDequantizeBlockwise(float* code, unsigned char* A, float* absmax, T* out, const int blocksize, const int n) {
Expand Down Expand Up @@ -2440,9 +2524,24 @@ MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, 0, NF4)
MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, 0, NF4)
MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, 0, NF4)

template __global__ void kDequantizeBlockwise<half, 512, 64, 8, FP4>(
float* code, unsigned char* A, float* absmax, half* out, const int blocksize, const int n
);
// Template instantiations for blocksize=32 specialized kernel (4-bit only)
#define MAKE_kQuantizeBlockwise32(dtype, data_type_name) \
template __global__ void kQuantizeBlockwise32<dtype, data_type_name>( \
float* code, dtype* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand, \
const int rand_offset, const int n \
);

// FP4 instantiations for blocksize=32
// (MAKE_kQuantizeBlockwise32 expands to a complete `template ... ;` explicit
// instantiation, semicolon included, so no separators between invocations.)
MAKE_kQuantizeBlockwise32(half, FP4) MAKE_kQuantizeBlockwise32(float, FP4) MAKE_kQuantizeBlockwise32(__nv_bfloat16, FP4)

// NF4 instantiations for blocksize=32
MAKE_kQuantizeBlockwise32(half, NF4) MAKE_kQuantizeBlockwise32(float, NF4) MAKE_kQuantizeBlockwise32(
    __nv_bfloat16, NF4
)

template __global__ void kDequantizeBlockwise<half, 512, 64, 8, FP4>(
float* code, unsigned char* A, float* absmax, half* out, const int blocksize, const int n
);
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, General8bit>(
float* code, unsigned char* A, float* absmax, half* out, const int blocksize, const int n
);
Expand Down
5 changes: 5 additions & 0 deletions csrc/kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ __global__ void kQuantizeBlockwise(
float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand,
const int rand_offset, const int n
);
// Specialized path for blocksize == 32: one warp quantizes two 32-value blocks
// (threads 0-15 -> block 0, threads 16-31 -> block 1). 4-bit (FP4/NF4) only;
// `code`, `rand`, and `rand_offset` are unused but kept so the signature
// mirrors kQuantizeBlockwise. Defined in csrc/kernels.cu.
template <typename T, int DATA_TYPE>
__global__ void kQuantizeBlockwise32(
    float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand,
    const int rand_offset, const int n
);
template <typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH, int DATA_TYPE>
__global__ void
kDequantizeBlockwise(float* code, unsigned char* A, float* absmax, T* out, const int blocksize, const int n);
Expand Down
8 changes: 8 additions & 0 deletions csrc/ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,14 @@ void quantizeBlockwise(
kQuantizeBlockwise<T, 128, 2, 0, DATA_TYPE><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
else if (blocksize == 64)
kQuantizeBlockwise<T, 64, 2, 0, DATA_TYPE><<<num_blocks, 32>>>(code, A, absmax, out, rand, rand_offset, n);
else if (blocksize == 32) {
// For 4-bit: use specialized kernel (kQuantizeBlockwise32) that processes 2 blocks per warp
// Each CUDA block handles 2 quantization blocks, so divide num_blocks by 2
if (DATA_TYPE > 0) {
int num_blocks_adjusted = (num_blocks + 1) / 2;
kQuantizeBlockwise32<T, DATA_TYPE><<<num_blocks_adjusted, 32>>>(code, A, absmax, out, rand, rand_offset, n);
}
}

CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
Expand Down
14 changes: 11 additions & 3 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1098,7 +1098,7 @@ class TestQuantize4BitFunctional:
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize(
"blocksize",
[64, 128, 256, 512, 1024, 2048, 4096] if not ROCM_WARP_SIZE_64 else [128, 256, 512, 1024, 2048, 4096],
[32, 64, 128, 256, 512, 1024, 2048, 4096] if not ROCM_WARP_SIZE_64 else [128, 256, 512, 1024, 2048, 4096],
)
def test_4bit_quant(self, device, dtype, quant_type, blocksize):
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype):
Expand All @@ -1122,6 +1122,7 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
error_dict["fp4"] = dict()
error_dict["nf4"] = dict()
error_dict["fp4"]["err"] = {
32: 0.088918,
64: 0.096545,
128: 0.102947,
256: 0.108685,
Expand All @@ -1131,6 +1132,7 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
4096: 0.129573,
}
error_dict["fp4"]["rel_err"] = {
32: 0.242380,
64: 0.260130,
128: 0.275734,
256: 0.289842,
Expand All @@ -1141,6 +1143,7 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
}

error_dict["nf4"]["err"] = {
32: 0.067745,
64: 0.072792,
128: 0.076835,
256: 0.080326,
Expand All @@ -1150,6 +1153,7 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
4096: 0.092537,
}
error_dict["nf4"]["rel_err"] = {
32: 0.189700,
64: 0.203299,
128: 0.215252,
256: 0.226044,
Expand All @@ -1168,7 +1172,9 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):

@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [64, 128] if not ROCM_WARP_SIZE_64 else [128], ids=id_formatter("blocksize"))
@pytest.mark.parametrize(
"blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [128], ids=id_formatter("blocksize")
)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=describe_dtype)
def test_4bit_compressed_stats(self, device, quant_type, blocksize, dtype):
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype):
Expand Down Expand Up @@ -1205,7 +1211,9 @@ def test_4bit_compressed_stats(self, device, quant_type, blocksize, dtype):
@pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No accelerator device")
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [64, 128] if not ROCM_WARP_SIZE_64 else [128], ids=id_formatter("blocksize"))
@pytest.mark.parametrize(
"blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [128], ids=id_formatter("blocksize")
)
def test_4bit_quant_large(self, device, dtype, quant_type, blocksize):
"""
Test that we can successfully quantize a large tensor. Note that the following limitations apply:
Expand Down
6 changes: 3 additions & 3 deletions tests/test_linear4bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def test_linear_serialization(

@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("blocksize", [64, 128] if not ROCM_WARP_SIZE_64 else [128])
@pytest.mark.parametrize("blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [128])
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
def test_copy_param(device, quant_type, blocksize, compress_statistics):
if device == "hpu" and not is_supported_on_hpu(quant_type):
Expand Down Expand Up @@ -250,7 +250,7 @@ def test_params4bit_torch_chunk_split(device, quant_type):

@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("blocksize", [64, 128] if not ROCM_WARP_SIZE_64 else [128])
@pytest.mark.parametrize("blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [128])
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
def test_deepcopy_param(device, quant_type, blocksize, compress_statistics):
if device == "hpu" and not is_supported_on_hpu(quant_type):
Expand Down Expand Up @@ -279,7 +279,7 @@ def test_deepcopy_param(device, quant_type, blocksize, compress_statistics):

@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("blocksize", [64, 128] if not ROCM_WARP_SIZE_64 else [128])
@pytest.mark.parametrize("blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [128])
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
def test_params4bit_real_serialization(device, quant_type, blocksize, compress_statistics):
if device == "hpu" and not is_supported_on_hpu(quant_type):
Expand Down
6 changes: 3 additions & 3 deletions tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ class Test4bitBlockwiseQuantOps:
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512])
@pytest.mark.parametrize("blocksize", [32, 64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512])
def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype):
pytest.skip("This configuration is not supported on HPU.")
Expand All @@ -176,7 +176,7 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512])
@pytest.mark.parametrize("blocksize", [32, 64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512])
def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype):
pytest.skip("This configuration is not supported on HPU.")
Expand Down Expand Up @@ -210,7 +210,7 @@ def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksi
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512])
@pytest.mark.parametrize("blocksize", [32, 64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512])
@pytest.mark.skipif(ROCM_WARP_SIZE_64, reason="this test is not supported on ROCm yet")
def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype):
Expand Down