#include "cutlass_extensions/common.hpp"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"

+ #include <cuda_runtime.h>
+
namespace vllm::cutlass_w4a8 {

using namespace cute;
@@ -393,6 +395,71 @@ torch::Tensor pack_scale_fp8(torch::Tensor const& scales) {
  return packed_scales;
}

+ /*
+   GPU-accelerated implementation of cutlass::unified_encode_int4b.
+   Constructs a lookup table in constant memory to map 8 bits
+   (two 4-bit values) at a time. Assumes memory is contiguous
+   and pointers are 16-byte aligned.
+ */
+ __constant__ uint8_t kNibbleLUT[256];
+
+ __global__ void unified_encode_int4b_device(const uint8_t* in, uint8_t* out,
+                                             size_t nbytes) {
+   constexpr size_t V = sizeof(uint4);  // 16 bytes
+   const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+   const size_t nthreads = size_t(gridDim.x) * blockDim.x;
+   const size_t nvec = nbytes / V;
+
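+   // nbytes is a multiple of 16 here (the caller checks (n * k) % 32 == 0),
+   // so the vectorized loop covers every byte and no scalar tail is needed.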
+   // 1-D grid-stride loop over 16-byte chunks
+   for (size_t vec = tid; vec < nvec; vec += nthreads) {
+     uint4 v = reinterpret_cast<const uint4*>(in)[vec];
+     uint8_t* b = reinterpret_cast<uint8_t*>(&v);
+ #pragma unroll
+     for (int i = 0; i < int(V); ++i) b[i] = kNibbleLUT[b[i]];
+     reinterpret_cast<uint4*>(out)[vec] = v;
+   }
+ }
+
+ static bool upload_lut() {
+   std::array<uint8_t, 256> lut{};
+   auto map_nib = [](uint8_t v) -> uint8_t {
+     // 1..7 -> (8 - v); keep 0 and 8..15
+     return (v == 0 || (v & 0x8)) ? v : uint8_t(8 - v);
+   };
+   for (int b = 0; b < 256; ++b) {
+     uint8_t lo = b & 0xF;
+     uint8_t hi = (b >> 4) & 0xF;
+     lut[b] = uint8_t((map_nib(hi) << 4) | map_nib(lo));
+   }
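+   // e.g. lut[0x13] == 0x75 (hi 0x1 -> 0x7, lo 0x3 -> 0x5); bytes whose
+   // nibbles are 0 or have the high bit set, such as 0x9F, map to themselves.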
+   cudaError_t e = cudaMemcpyToSymbol(kNibbleLUT, lut.data(), lut.size(),
+                                      /*offset=*/0, cudaMemcpyHostToDevice);
+
+   return (e == cudaSuccess);
+ }
+
+ static bool unified_encode_int4b(cutlass::int4b_t const* in,
+                                  cutlass::int4b_t* out,
+                                  size_t num_int4_elems) {
+   // Build/upload LUT
+   if (!upload_lut()) return false;
+
+   static_assert(sizeof(typename cutlass::int4b_t::Storage) == 1,
+                 "int4 storage must be 1 byte");
+   const size_t nbytes = num_int4_elems >> 1;
+
+   auto* in_bytes = reinterpret_cast<uint8_t const*>(in);
+   auto* out_bytes = reinterpret_cast<uint8_t*>(out);
+
+   // kernel launch params
+   constexpr int block = 256;
+   const size_t nvec = nbytes / sizeof(uint4);  // # of 16B vectors
+   int grid = int((nvec + block - 1) / block);
+   if (grid == 0) grid = 1;  // avoid an invalid zero-sized grid launch
+
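+   // The launch is asynchronous; cudaGetLastError below only catches
+   // launch/configuration errors, not failures during kernel execution.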
+   unified_encode_int4b_device<<<grid, block>>>(in_bytes, out_bytes, nbytes);
+   cudaError_t err = cudaGetLastError();
+   return (err == cudaSuccess);
+ }
+
torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) {
  TORCH_CHECK(B.dtype() == torch::kInt32);
  TORCH_CHECK(B.dim() == 2);
@@ -401,6 +468,7 @@ torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) {

  int k = B.size(0) * PackFactor;  // logical k
  int n = B.size(1);
+   TORCH_CHECK((n * k) % 32 == 0, "need multiples of 32 int4s for 16B chunks");

  auto B_ptr = static_cast<QuantType const*>(B.const_data_ptr());
  auto B_packed_ptr = static_cast<QuantType*>(B_packed.data_ptr());
@@ -409,7 +477,9 @@ torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) {
  LayoutB_Reordered layout_B_reordered =
      cute::tile_to_shape(LayoutAtomQuant{}, shape_B);

-   cutlass::unified_encode_int4b(B_ptr, B_packed_ptr, n * k);
+   bool ok =
+       vllm::cutlass_w4a8::unified_encode_int4b(B_ptr, B_packed_ptr, n * k);
+   TORCH_CHECK(ok, "unified_encode_int4b failed");
  cutlass::reorder_tensor(B_packed_ptr, layout_B, layout_B_reordered);

  return B_packed;