From 17f61af9931d36b83795d279eded94ed6a957fc2 Mon Sep 17 00:00:00 2001
From: Casper
Date: Sat, 27 Jan 2024 08:53:17 +0100
Subject: [PATCH] AWQ: Up to 2.66x higher throughput (#2566)

---
 csrc/ops.h                                    |   8 ++
 csrc/pybind.cpp                               |   1 +
 csrc/quantization/awq/gemm_kernels.cu         | 108 ++++++++++++++++++
 .../model_executor/layers/quantization/awq.py |  11 +-
 4 files changed, 127 insertions(+), 1 deletion(-)

diff --git a/csrc/ops.h b/csrc/ops.h
index 9340a60da1417..d49619644b182 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -70,6 +70,14 @@ torch::Tensor awq_gemm(
   torch::Tensor _scaling_factors,
   torch::Tensor _zeros,
   int split_k_iters);
+
+torch::Tensor awq_dequantize(
+  torch::Tensor _kernel,
+  torch::Tensor _scaling_factors,
+  torch::Tensor _zeros,
+  int split_k_iters,
+  int thx,
+  int thy);
 #endif
 
 void squeezellm_gemm(
diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp
index e6683c446154d..88af7eac8a28f 100644
--- a/csrc/pybind.cpp
+++ b/csrc/pybind.cpp
@@ -51,6 +51,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
 #ifndef USE_ROCM
   // Quantization ops
   ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
+  ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
 #endif
   ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
   ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
diff --git a/csrc/quantization/awq/gemm_kernels.cu b/csrc/quantization/awq/gemm_kernels.cu
index 04dfe8fe9b889..376c8ebfb9b7a 100644
--- a/csrc/quantization/awq/gemm_kernels.cu
+++ b/csrc/quantization/awq/gemm_kernels.cu
@@ -493,9 +493,117 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in
 #endif
 }
 
+__global__ void __launch_bounds__(64) dequantize_weights(
+    int* __restrict__ B,
+    half* __restrict__ scaling_factors,
+    int* __restrict__ zeros,
+    half* __restrict__ C,
+    int G
+)
+{
+  int j_factors1 = 4;
+  int row_stride2 = 4;
+  int split_k_iters = 1;
+  static constexpr uint32_t ZERO = 0x0;
+  half B_shared[32 * (128 + 8)];
+
+  half* B_shared_ptr2 = B_shared;
+
+  half B_shared_warp[32];
+  int OC = 512;
+
+  int N = blockDim.x * gridDim.x; // 2
+  int col = (blockIdx.x * blockDim.x + threadIdx.x);
+  int row = blockIdx.y * blockDim.y + threadIdx.y;
+  int index1 = 8 * col + 8 * row * N;
+  half* C_ptr2 = C + index1;
+
+  int index2 = col + row * N;
+  int* B_ptr2 = B + index2;
+
+  int index3 = col + (int)(row / G) * N;
+  int* zeros_ptr2 = zeros + index3;
+  int index4 = 8 * col + (int)(row / G) * N * 8;
+  half* scaling_factors_ptr2 = scaling_factors + index4;
+
+
+  uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr2);
+  uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
+  uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr2);
+  int j = 0;
+
+  uint32_t B_loaded = *(uint32_t*)(B_ptr2 + j);
+  uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
+  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
+  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
+  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
+  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
+  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
+  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
+  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
+  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
+
+  *(uint4*)(B_shared_ptr2 + j) = B_loaded_fp16;
+
+  for (int i = 0; i < 8; ++i) {
+    *(C_ptr2 + i) = B_shared[i];
+  }
+}
+
 } // namespace awq
 } // namespace vllm
 
+torch::Tensor awq_dequantize(
+    torch::Tensor _kernel,
+    torch::Tensor _scaling_factors,
+    torch::Tensor _zeros,
+    int split_k_iters,
+    int thx,
+    int thy)
+{
+    int in_c = _kernel.size(0);
+    int qout_c = _kernel.size(1);
+    int out_c = qout_c * 8;
+    int G = in_c / _scaling_factors.size(0);
+
+    int x_thread = thx;
+    int y_thread = thy;
+
+    int x_blocks = 1;
+    int y_blocks = 1;
+    if (thx == 0) {
+      x_thread = qout_c;
+    }
+    if (thy == 0) {
+      y_thread = in_c;
+    }
+    if (thx == 0 && thy == 0) {
+      x_thread = 8;
+      y_thread = 8;
+      x_blocks = (int)(qout_c / 8);
+      y_blocks = (int)(in_c / 8);
+    }
+
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors));
+
+    auto options = torch::TensorOptions().dtype(_scaling_factors.dtype()).device(_scaling_factors.device());
+    at::Tensor _de_kernel = torch::empty({in_c, out_c}, options);
+
+    auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
+    auto de_kernel = reinterpret_cast<half*>(_de_kernel.data_ptr<at::Half>());
+    auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
+    auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
+
+    dim3 num_blocks(x_blocks, y_blocks);
+    dim3 threads_per_block(x_thread, y_thread);
+
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    vllm::awq::dequantize_weights<<<num_blocks, threads_per_block, 0, stream>>>(
+        kernel, scaling_factors, zeros, de_kernel, G);
+
+    return _de_kernel;
+}
+
 // in_feats: M, IC [float16]
 // kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
 // scaling_factors: IC // G, OC [float16]
diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py
index 831576b1d7cd7..4d3fd3ec0cc71 100644
--- a/vllm/model_executor/layers/quantization/awq.py
+++ b/vllm/model_executor/layers/quantization/awq.py
@@ -153,7 +153,16 @@ def apply_weights(self,
         pack_factor = self.quant_config.pack_factor
         out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
         reshaped_x = x.reshape(-1, x.shape[-1])
-        out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, pack_factor)
+
+        # num_tokens >= threshold
+        FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256
+
+        if FP16_MATMUL_HEURISTIC_CONDITION:
+            out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
+            out = torch.matmul(reshaped_x, out)
+        else:
+            out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
+                               pack_factor)
         if bias is not None:
             out = out + bias
         return out.reshape(out_shape)