From aa340c6bec7753879583abd714011f8b298546a8 Mon Sep 17 00:00:00 2001 From: Rex Date: Mon, 12 Feb 2024 11:02:17 -0800 Subject: [PATCH] Refactor 2 awq gemm kernels into m16nXk32 (#2723) Co-authored-by: Chunan Zeng --- csrc/quantization/awq/gemm_kernels.cu | 366 ++++-------------- .../model_executor/layers/quantization/awq.py | 2 +- 2 files changed, 73 insertions(+), 295 deletions(-) diff --git a/csrc/quantization/awq/gemm_kernels.cu b/csrc/quantization/awq/gemm_kernels.cu index 376c8ebfb9b7a..5aefb0bd16aef 100644 --- a/csrc/quantization/awq/gemm_kernels.cu +++ b/csrc/quantization/awq/gemm_kernels.cu @@ -27,72 +27,85 @@ __pack_half2(const half x, const half y) { return (v1 << 16) | v0; } -__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C) +template +__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32( + int G, + int split_k_iters, + half* __restrict__ A, + int* __restrict__ B, + half* __restrict__ scaling_factors, + int* __restrict__ zeros, + int M, + int IC, + int OC, + half* __restrict__ C) { + // Only support matrix n = 64 or 128 + assert(N == 64 || N == 128); #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 assert(false); #else static constexpr uint32_t ZERO = 0x0; float C_warp[32]; __shared__ half A_shared[16 * (32 + 8)]; - __shared__ half B_shared[32 * (128 + 8)]; - - __shared__ half scaling_factors_shared[128]; - __shared__ half zeros_shared[128]; + __shared__ half B_shared[32 * (N + 8)]; - int j_factors1 = ((OC + 128 - 1) / 128); + __shared__ half scaling_factors_shared[N]; + __shared__ half zeros_shared[N]; + + int j_factors1 = ((OC + N - 1) / N); int blockIdx_x = 0; int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1); int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1); half A_shared_warp[8]; - half B_shared_warp[32]; - for (int j_0_4_init = 0; j_0_4_init < 4; ++j_0_4_init) { + half B_shared_warp[N / 4]; + for (int j_0_4_init = 0; j_0_4_init < N / 32; ++j_0_4_init) { for (int i = 0; i < 8; ++i) { C_warp[(j_0_4_init * 8) + i] = 0.0; } } static constexpr int row_stride_warp = 32 * 8 / 32; - static constexpr int row_stride = 2 * 32 * 8 / 128; - bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 128; + static constexpr int row_stride = 2 * 32 * 8 / N; + bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N; // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id // bool wb_C_flag = (threadIdx.x / 4) < M; - half* A_ptr = A + half* A_ptr = A + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC + (((int)threadIdx.x) % (32 / 8)) * 8; - + int* B_ptr = B - + ((int)threadIdx.y) * (OC / 8) * 2 - + (((int)threadIdx.x) / (128 / 8)) * (OC / 8) - + (((int)blockIdx_y) % j_factors1) * (128 / 8) - + (((int)threadIdx.x) % (128 / 8)) * 1; + + ((int)threadIdx.y) * (OC / 8) * (256 / N) + + (((int)threadIdx.x) / (N / 8)) * (OC / 8) + + (((int)blockIdx_y) % j_factors1) * (N / 8) + + (((int)threadIdx.x) % (N / 8)) * 1; // Why * 1 in the above line? - - half* A_shared_ptr = A_shared - + ((int)threadIdx.y) * row_stride_warp * (32 + 8) + + half* A_shared_ptr = A_shared + + ((int)threadIdx.y) * row_stride_warp * (32 + 8) + (((int)threadIdx.x) / (32 / 8)) * (32 + 8) + (((int)threadIdx.x) % (32 / 8) ) * 8; half* B_shared_ptr = B_shared - + ((int)threadIdx.y) * (row_stride / 2) * (128 + 8) - + (((int)threadIdx.x) / (128 / 8)) * (128 + 8) - + (((int)threadIdx.x) % (128 / 8)) * 8; - + + ((int)threadIdx.y) * (row_stride / 2) * (N + 8) + + (((int)threadIdx.x) / (N / 8)) * (N + 8) + + (((int)threadIdx.x) % (N / 8)) * 8; + int* zeros_ptr = zeros - + (((int)blockIdx_y) % j_factors1) * (128 / 8) - + ((int)threadIdx.x) % (128 / 8); - + + (((int)blockIdx_y) % j_factors1) * (N / 8) + + ((int)threadIdx.x) % (N / 8); + half* scaling_factors_ptr = scaling_factors - + (((int)blockIdx_y) % j_factors1) * (128) - + (((int)threadIdx.x) % (128 / 8)) * 8; + + (((int)blockIdx_y) % j_factors1) * N + + (((int)threadIdx.x) % (N / 8)) * 8; - half* C_ptr = C + half* C_ptr = C + static_cast(blockIdx_z) * M * OC // blockIdz.x -> split_k dim - + (((int)blockIdx_y) % j_factors1) * 128 - + ((int)threadIdx.y) * 64 + + (((int)blockIdx_y) % j_factors1) * N + + ((int)threadIdx.y) * (N / 2) + (((int)threadIdx.x) % 4) * 2; // preload s.f. and zeros @@ -123,13 +136,13 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0); int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8); - for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 8; ++ax0_ax1_fused_0) { + for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) { // B: 32 x 136 (128+8) float16 // each warp: 32 x 4 // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4 // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8))); - // row stride in shared memory: (NWARPS * 32 * 8 / cta_N) + // row stride in shared memory: (NWARPS * 32 * 8 / cta_N) uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); //uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8); @@ -152,7 +165,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i */ // write back - *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (128 + 8)) = B_loaded_fp16; + *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = B_loaded_fp16; } __syncthreads(); @@ -174,13 +187,13 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i ); } - for (int ax1_0 = 0; ax1_0 < 4; ++ax1_0) { + for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) { { unsigned int addr; __asm__ __volatile__( "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" : "=r"(addr) - : "l"((void *)((&(B_shared[(((k_0_1 * 2176) + (((int)threadIdx.y) * 64)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 136) + ((((int)threadIdx.x) >> 4) * 8)))) + : "l"((void *)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + (((int)threadIdx.y) * (N / 2))) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * (N + 8)) + ((((int)threadIdx.x) >> 4) * 8)))) ); __asm__ __volatile__( "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" @@ -190,7 +203,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i ); } } - for (int j_0_4 = 0; j_0_4 < 4; ++j_0_4) { + for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 { __asm__ __volatile__( @@ -258,241 +271,6 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i #endif } - -__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 - assert(false); -#else - static constexpr uint32_t ZERO = 0x0; - float C_warp[32]; - __shared__ half A_shared[16 * (32 + 8)]; - __shared__ half B_shared[32 * (64 + 8)]; - - __shared__ half scaling_factors_shared[64]; - __shared__ half zeros_shared[64]; - - int j_factors1 = ((OC + 64 - 1) / 64); - - int blockIdx_x = 0; - int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1); - int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1); - - half A_shared_warp[8]; - half B_shared_warp[16]; - for (int j_0_4_init = 0; j_0_4_init < 2; ++j_0_4_init) { - for (int i = 0; i < 8; ++i) { - C_warp[(j_0_4_init * 8) + i] = 0.0; - } - } - - static constexpr int row_stride_warp = 32 * 8 / 32; - static constexpr int row_stride = 2 * 32 * 8 / 64; - bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 64; - // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 - bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id - // bool wb_C_flag = (threadIdx.x / 4) < M; - - half* A_ptr = A - + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC - + (((int)threadIdx.x) % (32 / 8)) * 8; - - int* B_ptr = B - + ((int)threadIdx.y) * (OC / 8) * 4 - + (((int)threadIdx.x) / (64 / 8)) * (OC / 8) - + (((int)blockIdx_y) % j_factors1) * (64 / 8) - + (((int)threadIdx.x) % (64 / 8)) * 1; -// Why * 1 in the above line? - - half* A_shared_ptr = A_shared - + ((int)threadIdx.y) * row_stride_warp * (32 + 8) - + (((int)threadIdx.x) / (32 / 8)) * (32 + 8) - + (((int)threadIdx.x) % (32 / 8) ) * 8; - - half* B_shared_ptr = B_shared - + ((int)threadIdx.y) * (row_stride / 2) * (64 + 8) - + (((int)threadIdx.x) / (64 / 8)) * (64 + 8) - + (((int)threadIdx.x) % (64 / 8)) * 8; - - int* zeros_ptr = zeros - + (((int)blockIdx_y) % j_factors1) * (64 / 8) - + ((int)threadIdx.x) % (64 / 8); - - half* scaling_factors_ptr = scaling_factors - + (((int)blockIdx_y) % j_factors1) * (64) - + (((int)threadIdx.x) % (64 / 8)) * 8; - - half* C_ptr = C - + static_cast(blockIdx_z) * M * OC // blockIdz.x -> split_k dim - + (((int)blockIdx_y) % j_factors1) * 64 - + ((int)threadIdx.y) * 32 - + (((int)threadIdx.x) % 4) * 2; - - // preload s.f. and zeros - int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters; - if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1; - for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) { - int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z; - __syncthreads(); - // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 - if (ld_A_flag) - { - *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32)); - } - else - { - *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0); - } - - // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) { - uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8)); - uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded); - uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC)); - /* - if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){ - printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w); - } - */ - // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0); - int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8); - - for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 4; ++ax0_ax1_fused_0) { - - // B: 32 x 136 (128+8) float16 - // each warp: 32 x 4 - // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4 - // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8))); - // row stride in shared memory: (NWARPS * 32 * 8 / cta_N) - uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); - uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); - //uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8); - - // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8); - // - zero and * scale - // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale. - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); - /* - if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){ - printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w); - } - */ - - // write back - *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (64 + 8)) = B_loaded_fp16; - } - __syncthreads(); - - for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) - { - { - unsigned int addr; - __asm__ __volatile__( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8)))) - ); - __asm__ __volatile__( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16" - "{%0, %1, %2, %3}, [%4];\n" - : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3]) - : "r"(addr) - ); - } - - - for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0) - { - { - unsigned int addr; - __asm__ __volatile__( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)((&(B_shared[(((k_0_1 * 1152) + (((int)threadIdx.y) * 32)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 72) + ((((int)threadIdx.x) >> 4) * 8)))) - ); - __asm__ __volatile__( - "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" - "{%0, %1, %2, %3}, [%4];\n" - : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3]) - : "r"(addr) - ); - } - } - - for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4) - { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); - } - - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); - } - - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); - } - - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); - } -#else - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" - : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); - } - - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" - : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); - } -#endif - } - } - } - -// TODO: Shang: Hoist loop invariance. - for (int ax1_0_1 = 0; ax1_0_1 < 2; ++ax1_0_1) { - for (int local_id = 0; local_id < 8; ++local_id) { - int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8; - if (row_offset < M) - { - *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]); - } - } - } -#endif -} - __global__ void __launch_bounds__(64) dequantize_weights( int* __restrict__ B, half* __restrict__ scaling_factors, @@ -526,26 +304,24 @@ __global__ void __launch_bounds__(64) dequantize_weights( int index4 = 8 * col + (int)(row / G) * N * 8; half* scaling_factors_ptr2 = scaling_factors + index4; + uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr2); + uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded); + uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr2); - uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr2); - uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded); - uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr2); -int j=0; + uint32_t B_loaded = *(uint32_t*)B_ptr2; + uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); - uint32_t B_loaded = *(uint32_t*)(B_ptr2 + j); - uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); - - *(uint4*)(B_shared_ptr2 + j) = B_loaded_fp16; + *(uint4*)B_shared_ptr2 = B_loaded_fp16; - for (int i=0; i<8; ++i) { + for (int i = 0; i < 8; ++i) { *(C_ptr2 + i) = B_shared[i]; } } @@ -650,19 +426,21 @@ torch::Tensor awq_gemm( // threadIdx.x: 32 // threadIdx.y: i_factors[2] * j_factors[2] dim3 threads_per_block(32, 2); - vllm::awq::gemm_forward_4bit_cuda_m16n128k32<<>>( - group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats); + vllm::awq::gemm_forward_4bit_cuda_m16nXk32<128><<>>( + group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, + num_out_channels, out_feats); } else if (num_out_channels % 64 == 0) { int j_factors1 = num_out_channels / 64 / 1; dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); - + // threadIdx.x: 32 // threadIdx.y: i_factors[2] * j_factors[2] dim3 threads_per_block(32, 2); - vllm::awq::gemm_forward_4bit_cuda_m16n64k32<<>>( - group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats); + vllm::awq::gemm_forward_4bit_cuda_m16nXk32<64><<>>( + group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, + num_out_channels, out_feats); } return _out_feats.sum(0); } diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 681f95821eabb..3e1c814dd233c 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -145,8 +145,8 @@ def apply_weights(self, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: qweight = weights["qweight"] - qzeros = weights["qzeros"] scales = weights["scales"] + qzeros = weights["qzeros"] pack_factor = self.quant_config.pack_factor out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, )) reshaped_x = x.reshape(-1, x.shape[-1])