Skip to content

Commit ad10e67

Browse files
czhu-coheredebroy-rh
authored andcommitted
[Kernel] Faster pre-processing time for W4A8 (vllm-project#23972)
Signed-off-by: czhu-cohere <conway.zhu@cohere.com>
1 parent c7d6d93 commit ad10e67

File tree

1 file changed

+71
-1
lines changed

1 file changed

+71
-1
lines changed

csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
#include "cutlass_extensions/common.hpp"
2626
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
2727

28+
#include <cuda_runtime.h>
29+
2830
namespace vllm::cutlass_w4a8 {
2931

3032
using namespace cute;
@@ -393,6 +395,71 @@ torch::Tensor pack_scale_fp8(torch::Tensor const& scales) {
393395
return packed_scales;
394396
}
395397

398+
/*
399+
GPU-accelerated implementation of cutlass::unified_encode_int4b.
400+
Constructs a lookup table in constant memory to map 8 bits
401+
(two 4-bit values) at a time. Assumes memory is contiguous
402+
and pointers are 16-byte aligned.
403+
*/
404+
__constant__ uint8_t kNibbleLUT[256];
405+
406+
__global__ void unified_encode_int4b_device(const uint8_t* in, uint8_t* out,
407+
size_t nbytes) {
408+
constexpr size_t V = sizeof(uint4); // 16 bytes
409+
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
410+
const size_t nthreads = size_t(gridDim.x) * blockDim.x;
411+
const size_t nvec = nbytes / V;
412+
413+
// 1-D grid-stride loop over 16-byte chunks
414+
for (size_t vec = tid; vec < nvec; vec += nthreads) {
415+
uint4 v = reinterpret_cast<const uint4*>(in)[vec];
416+
uint8_t* b = reinterpret_cast<uint8_t*>(&v);
417+
#pragma unroll
418+
for (int i = 0; i < int(V); ++i) b[i] = kNibbleLUT[b[i]];
419+
reinterpret_cast<uint4*>(out)[vec] = v;
420+
}
421+
}
422+
423+
static bool upload_lut() {
424+
std::array<uint8_t, 256> lut{};
425+
auto map_nib = [](uint8_t v) -> uint8_t {
426+
// 1..7 -> (8 - v); keep 0 and 8..15
427+
return (v == 0 || (v & 0x8)) ? v : uint8_t(8 - v);
428+
};
429+
for (int b = 0; b < 256; ++b) {
430+
uint8_t lo = b & 0xF;
431+
uint8_t hi = (b >> 4) & 0xF;
432+
lut[b] = uint8_t((map_nib(hi) << 4) | map_nib(lo));
433+
}
434+
cudaError_t e = cudaMemcpyToSymbol(kNibbleLUT, lut.data(), lut.size(),
435+
/*offset=*/0, cudaMemcpyHostToDevice);
436+
437+
return (e == cudaSuccess);
438+
}
439+
440+
static bool unified_encode_int4b(cutlass::int4b_t const* in,
441+
cutlass::int4b_t* out, size_t num_int4_elems) {
442+
// Build/upload LUT
443+
if (!upload_lut()) return false;
444+
445+
static_assert(sizeof(typename cutlass::int4b_t::Storage) == 1,
446+
"int4 storage must be 1 byte");
447+
const size_t nbytes = num_int4_elems >> 1;
448+
449+
auto* in_bytes = reinterpret_cast<uint8_t const*>(in);
450+
auto* out_bytes = reinterpret_cast<uint8_t*>(out);
451+
452+
// kernel launch params
453+
constexpr int block = 256;
454+
const size_t nvec = nbytes / sizeof(uint4); // # of 16B vectors
455+
int grid = int((nvec + block - 1) / block);
456+
if (grid == 0) grid = 1; // ensure we still cover the tail in the kernel
457+
458+
unified_encode_int4b_device<<<grid, block>>>(in_bytes, out_bytes, nbytes);
459+
cudaError_t err = cudaGetLastError();
460+
return (err == cudaSuccess);
461+
}
462+
396463
torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) {
397464
TORCH_CHECK(B.dtype() == torch::kInt32);
398465
TORCH_CHECK(B.dim() == 2);
@@ -401,6 +468,7 @@ torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) {
401468

402469
int k = B.size(0) * PackFactor; // logical k
403470
int n = B.size(1);
471+
TORCH_CHECK((n * k) % 32 == 0, "need multiples of 32 int4s for 16B chunks");
404472

405473
auto B_ptr = static_cast<QuantType const*>(B.const_data_ptr());
406474
auto B_packed_ptr = static_cast<QuantType*>(B_packed.data_ptr());
@@ -409,7 +477,9 @@ torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) {
409477
LayoutB_Reordered layout_B_reordered =
410478
cute::tile_to_shape(LayoutAtomQuant{}, shape_B);
411479

412-
cutlass::unified_encode_int4b(B_ptr, B_packed_ptr, n * k);
480+
bool ok =
481+
vllm::cutlass_w4a8::unified_encode_int4b(B_ptr, B_packed_ptr, n * k);
482+
TORCH_CHECK(ok, "unified_encode_int4b failed");
413483
cutlass::reorder_tensor(B_packed_ptr, layout_B, layout_B_reordered);
414484

415485
return B_packed;

0 commit comments

Comments
 (0)