Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions bitsandbytes/backends/cuda/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor
torch._check_is_size(blocksize)

if ROCM_WARP_SIZE_64:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
else:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])

Expand Down Expand Up @@ -270,7 +270,7 @@ def _dequantize_blockwise_impl(
A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
) -> None:
if ROCM_WARP_SIZE_64:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
else:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])

Expand Down Expand Up @@ -304,7 +304,7 @@ def _(
A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
if ROCM_WARP_SIZE_64:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
else:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])

Expand Down Expand Up @@ -386,7 +386,7 @@ def _dequantize_4bit_impl(
out: torch.Tensor,
) -> None:
if ROCM_WARP_SIZE_64:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
else:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])

Expand Down
102 changes: 102 additions & 0 deletions csrc/kernels.hip
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,94 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
}
}

// Specialized quantization kernel for blocksize=64 with 4-bit packing (FP4/NF4).
// Works on both warp32 and warp64 hardware.
// Each HIP block uses 64 threads to process 2 quantization blocks of 64 values:
// threads 0-31 handle block 0, threads 32-63 handle block 1, as two logical warps of 32.
// - warp32 hardware: the two logical warps are two real hardware warps, each reduces naturally
// - warp64 hardware: the single hardware warp is split into two logical warps of 32
//
// Parameters:
//   code        - unused in this kernel: FP4/NF4 quantize via the fixed
//                 dQuantizeFP4/dQuantizeNF4 codebooks; presumably kept so the
//                 signature mirrors kQuantizeBlockwise -- confirm at call site
//   A           - input values, n elements
//   absmax      - output, one float (the block's max |value|, not its reciprocal)
//                 per 64-element quantization block
//   out         - packed 4-bit output, two values per byte (even index in the
//                 high nibble), ceil(n/2) bytes
//   rand, rand_offset - unused in this kernel (no stochastic path here);
//                 kept for signature parity
//   n           - total number of input elements
template <typename T, int DATA_TYPE>
__global__ void kQuantizeBlockwise64(
    float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand,
    const int rand_offset, const int n
) {
    constexpr int BLOCK_SIZE = 64;        // Size of each quantization block
    constexpr int NUM_PER_TH = 2;         // Values per thread (for 4-bit packing)
    constexpr int THREADS = 64;           // Total threads per HIP block
    constexpr int THREADS_PER_BLOCK = 32; // Threads handling each quantization block

    const int base_idx = blockIdx.x * BLOCK_SIZE * 2; // 2 quantization blocks per HIP block

    T vals[NUM_PER_TH];
    unsigned char qvals[NUM_PER_TH / 2]; // For 4-bit: 2 values per byte
    float local_abs_max = 0.0f;          // NOTE(review): this initial value is dead -- overwritten before first use below

    const int block_id = threadIdx.x / THREADS_PER_BLOCK;        // 0 for threads 0-31, 1 for threads 32-63
    const int local_thread_id = threadIdx.x % THREADS_PER_BLOCK; // Thread ID within the quantization block (0-31)

    typedef hipcub::BlockLoad<T, THREADS, NUM_PER_TH, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
    typedef hipcub::BlockStore<unsigned char, THREADS, NUM_PER_TH / 2, hipcub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
    // Logical warp size of 32: on warp32 this matches hardware warps,
    // on warp64 this splits the single hardware warp into two independent reductions
    typedef hipcub::WarpReduce<float, 32> WarpReduce;

    __shared__ typename LoadT::TempStorage loadt;
    __shared__ typename StoreChar::TempStorage storec;
    __shared__ typename WarpReduce::TempStorage warp_reduce[2]; // One per logical warp
    __shared__ float smem_absmax_value[2];                      // Per-block reciprocal scale (0.0f for an invalid block)

    const int i = base_idx + block_id * BLOCK_SIZE;
    // Use a flag instead of early return: BlockLoad/BlockStore/__syncthreads are cooperative
    // operations that require ALL 64 threads to participate
    const bool block_valid = (i < n);

    // All 64 threads participate in the load (out-of-bounds threads get 0.0f)
    __syncthreads();
    LoadT(loadt).Load(&(A[base_idx]), vals, min(BLOCK_SIZE * 2, n - base_idx), (T)0.0f);

    // Each thread computes the max |value| over its NUM_PER_TH values
    local_abs_max = -FLT_MAX;
#pragma unroll NUM_PER_TH
    for (int j = 0; j < NUM_PER_TH; j++)
        local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j]));

    // Reduce within each logical warp of 32 threads independently.
    // The WarpReduce result is only meaningful in lane 0 of each logical warp,
    // hence the local_thread_id == 0 guard below.
    local_abs_max = WarpReduce(warp_reduce[block_id]).Reduce(local_abs_max, hipcub::Max());

    if (local_thread_id == 0) {
        if (block_valid) {
            // Publish the reciprocal for the normalization step; store the raw
            // absmax for consumers of the absmax array.
            // NOTE(review): an all-zero valid block yields 1.0f/0.0f == inf here;
            // verify this matches the generic kQuantizeBlockwise behavior.
            smem_absmax_value[block_id] = 1.0f / local_abs_max;
            absmax[blockIdx.x * 2 + block_id] = local_abs_max;
        } else {
            // Invalid trailing block: a scale of 0 makes the padded values quantize as 0
            smem_absmax_value[block_id] = 0.0f;
        }
    }
    __syncthreads();

    // Broadcast: every thread picks up its block's reciprocal scale
    local_abs_max = smem_absmax_value[block_id];

    switch (DATA_TYPE) {
    case FP4:
#pragma unroll NUM_PER_TH
        for (int j = 0; j < NUM_PER_TH / 2; j++) {
            // Pack two normalized values per byte: even index in the high nibble
            qvals[j] = dQuantizeFP4(((float)vals[2 * j]) * local_abs_max) << 4;
            qvals[j] |= dQuantizeFP4(((float)vals[2 * j + 1]) * local_abs_max);
        }
        break;
    case NF4:
#pragma unroll NUM_PER_TH
        for (int j = 0; j < NUM_PER_TH / 2; j++) {
            qvals[j] = dQuantizeNF4(((float)vals[2 * j]) * local_abs_max) << 4;
            qvals[j] |= dQuantizeNF4(((float)vals[2 * j + 1]) * local_abs_max);
        }
        break;
    }

    // All 64 threads participate in the store; valid_items = ceil(valid_values / 2)
    // bytes limits the actual writes
    __syncthreads();
    StoreChar(storec).Store(&(out[base_idx / 2]), qvals, min((BLOCK_SIZE * 2 + 1) / 2, (n - base_idx + 1) / 2));
}

template<typename T, int TILE_SIZE, int THREADS, int NUM_PER_TH, int DATA_TYPE>
__global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n)
{
Expand Down Expand Up @@ -2566,6 +2654,20 @@ MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, NF4)
MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, NF4)
#endif

// ROCm-only explicit instantiations of the specialized blocksize=64 4-bit
// quantization kernel: one per supported element type and 4-bit data type.
#define MAKE_kQuantizeBlockwise64(dtype, data_type_name)                                                               \
    template __global__ void kQuantizeBlockwise64<dtype, data_type_name>(                                              \
        float * code, dtype * __restrict__ const A, float *absmax, unsigned char *out,                                 \
        float * __restrict__ const rand, const int rand_offset, const int n);

// FP4 variants
MAKE_kQuantizeBlockwise64(half, FP4)
MAKE_kQuantizeBlockwise64(float, FP4)
MAKE_kQuantizeBlockwise64(hip_bfloat16, FP4)

// NF4 variants
MAKE_kQuantizeBlockwise64(half, NF4)
MAKE_kQuantizeBlockwise64(float, NF4)
MAKE_kQuantizeBlockwise64(hip_bfloat16, NF4)

template __global__ void kDequantizeBlockwise<half, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, General8bit>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, NF4>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
Expand Down
5 changes: 5 additions & 0 deletions csrc/kernels_hip.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ __global__ void kQuantizeBlockwise(
float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand,
const int rand_offset, const int n
);
// Specialized blocksize=64 quantization kernel for 4-bit (FP4/NF4) data types.
// Launched with 64 threads per HIP block, each block covering two 64-value
// quantization blocks (definition in kernels.hip). `code`, `rand` and
// `rand_offset` mirror the kQuantizeBlockwise signature.
template <typename T, int DATA_TYPE>
__global__ void kQuantizeBlockwise64(
    float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand,
    const int rand_offset, const int n
);
template <typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH, int DATA_TYPE>
__global__ void
kDequantizeBlockwise(float* code, unsigned char* A, float* absmax, T* out, const int blocksize, const int n);
Expand Down
11 changes: 8 additions & 3 deletions csrc/ops.hip
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,14 @@ template <typename T, int STOCHASTIC, int DATA_TYPE> void quantizeBlockwise(floa
hipLaunchKernelGGL(( kQuantizeBlockwise<T, 256, 2, 0, DATA_TYPE>), dim3(num_blocks), dim3(128), 0, 0, code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 128)
hipLaunchKernelGGL(( kQuantizeBlockwise<T, 128, 2, 0, DATA_TYPE>), dim3(num_blocks), dim3(64), 0, 0, code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 64 && BNB_WARP_SIZE == 32)
hipLaunchKernelGGL(( kQuantizeBlockwise<T, 64, 2, 0, DATA_TYPE>), dim3(num_blocks), dim3(32), 0, 0, code, A, absmax, out, rand, rand_offset, n);

else if(blocksize == 64) {
// For 4-bit (FP4/NF4): use specialized kernel that processes 2 blocks of 64 per thread block
// Works on all warp sizes (32 and 64) by using logical warps of 32
if constexpr(DATA_TYPE > 0)
hipLaunchKernelGGL(( kQuantizeBlockwise64<T, DATA_TYPE>), dim3((num_blocks + 1) / 2), dim3(64), 0, 0, code, A, absmax, out, rand, rand_offset, n);
else
hipLaunchKernelGGL(( kQuantizeBlockwise<T, 64, 2, 0, DATA_TYPE>), dim3(num_blocks), dim3(32), 0, 0, code, A, absmax, out, rand, rand_offset, n);
}

CUDA_CHECK_RETURN(hipPeekAtLastError());
}
Expand Down
6 changes: 3 additions & 3 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1098,7 +1098,7 @@ class TestQuantize4BitFunctional:
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize(
"blocksize",
[32, 64, 128, 256, 512, 1024, 2048, 4096] if not ROCM_WARP_SIZE_64 else [128, 256, 512, 1024, 2048, 4096],
[32, 64, 128, 256, 512, 1024, 2048, 4096] if not ROCM_WARP_SIZE_64 else [64, 128, 256, 512, 1024, 2048, 4096],
)
def test_4bit_quant(self, device, dtype, quant_type, blocksize):
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype):
Expand Down Expand Up @@ -1173,7 +1173,7 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize(
"blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [128], ids=id_formatter("blocksize")
"blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [64, 128], ids=id_formatter("blocksize")
)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=describe_dtype)
def test_4bit_compressed_stats(self, device, quant_type, blocksize, dtype):
Expand Down Expand Up @@ -1212,7 +1212,7 @@ def test_4bit_compressed_stats(self, device, quant_type, blocksize, dtype):
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize(
"blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [128], ids=id_formatter("blocksize")
"blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [64, 128], ids=id_formatter("blocksize")
)
def test_4bit_quant_large(self, device, dtype, quant_type, blocksize):
"""
Expand Down
6 changes: 3 additions & 3 deletions tests/test_linear4bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def test_linear_serialization(

@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [128])
@pytest.mark.parametrize("blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [64, 128])
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
def test_copy_param(device, quant_type, blocksize, compress_statistics):
if device == "hpu" and not is_supported_on_hpu(quant_type):
Expand Down Expand Up @@ -250,7 +250,7 @@ def test_params4bit_torch_chunk_split(device, quant_type):

@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [128])
@pytest.mark.parametrize("blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [64, 128])
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
def test_deepcopy_param(device, quant_type, blocksize, compress_statistics):
if device == "hpu" and not is_supported_on_hpu(quant_type):
Expand Down Expand Up @@ -279,7 +279,7 @@ def test_deepcopy_param(device, quant_type, blocksize, compress_statistics):

@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [128])
@pytest.mark.parametrize("blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [64, 128])
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
def test_params4bit_real_serialization(device, quant_type, blocksize, compress_statistics):
if device == "hpu" and not is_supported_on_hpu(quant_type):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ class Test4bitBlockwiseQuantOps:
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [32, 64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512])
@pytest.mark.parametrize("blocksize", [32, 64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [64, 128, 256, 512])
def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype):
pytest.skip("This configuration is not supported on HPU.")
Expand All @@ -176,7 +176,7 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [32, 64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512])
@pytest.mark.parametrize("blocksize", [32, 64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [64, 128, 256, 512])
def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype):
pytest.skip("This configuration is not supported on HPU.")
Expand Down