AWQ: Up to 2.66x higher throughput #2566

Merged: 6 commits, Jan 27, 2024 (changes shown from 5 commits)
8 changes: 8 additions & 0 deletions csrc/ops.h
@@ -70,6 +70,14 @@ torch::Tensor awq_gemm(
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int split_k_iters);

torch::Tensor awq_dequantize(
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int split_k_iters,
int thx,
int thy);
#endif

void squeezellm_gemm(
1 change: 1 addition & 0 deletions csrc/pybind.cpp
@@ -51,6 +51,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#ifndef USE_ROCM
// Quantization ops
ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
#endif
ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
108 changes: 108 additions & 0 deletions csrc/quantization/awq/gemm_kernels.cu
@@ -493,9 +493,117 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in
#endif
}

__global__ void __launch_bounds__(64) dequantize_weights(
    int* __restrict__ B,
    half* __restrict__ scaling_factors,
    int* __restrict__ zeros,
    half* __restrict__ C,
    int G
)
{
  int j_factors1 = 4;
  int row_stride2 = 4;
  int split_k_iters = 1;
  static constexpr uint32_t ZERO = 0x0;
  half B_shared[32 * (128 + 8)];

  half* B_shared_ptr2 = B_shared;

  half B_shared_warp[32];
  int OC = 512;

  int N = blockDim.x * gridDim.x;  // 2
  int col = (blockIdx.x * blockDim.x + threadIdx.x);
  int row = blockIdx.y * blockDim.y + threadIdx.y;
  int index1 = 8 * col + 8 * row * N;
  half* C_ptr2 = C + index1;

  int index2 = col + row * N;
  int* B_ptr2 = B + index2;

  int index3 = col + (int)(row / G) * N;
  int* zeros_ptr2 = zeros + index3;
  int index4 = 8 * col + (int)(row / G) * N * 8;
  half* scaling_factors_ptr2 = scaling_factors + index4;

  uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr2);
  uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
  uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr2);
  int j = 0;

  uint32_t B_loaded = *(uint32_t*)(B_ptr2 + j);
  uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));

  *(uint4*)(B_shared_ptr2 + j) = B_loaded_fp16;

  for (int i = 0; i < 8; ++i) {
    *(C_ptr2 + i) = B_shared[i];
  }
}

} // namespace awq
} // namespace vllm

torch::Tensor awq_dequantize(
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int split_k_iters,
int thx,
int thy)
{
  int in_c = _kernel.size(0);
  int qout_c = _kernel.size(1);
  int out_c = qout_c * 8;
  int G = in_c / _scaling_factors.size(0);

  int x_thread = thx;
  int y_thread = thy;

  int x_blocks = 1;
  int y_blocks = 1;
  if (thx == 0) {
    x_thread = qout_c;
  }
  if (thy == 0) {
    y_thread = in_c;
  }
  if (thx == 0 && thy == 0) {
    x_thread = 8;
    y_thread = 8;
    x_blocks = (int)(qout_c / 8);
    y_blocks = (int)(in_c / 8);
  }

  const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors));

  auto options = torch::TensorOptions().dtype(_scaling_factors.dtype()).device(_scaling_factors.device());
  at::Tensor _de_kernel = torch::empty({in_c, out_c}, options);

  auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
  auto de_kernel = reinterpret_cast<half*>(_de_kernel.data_ptr<at::Half>());
  auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
  auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());

  dim3 num_blocks(x_blocks, y_blocks);
  dim3 threads_per_block(x_thread, y_thread);

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  vllm::awq::dequantize_weights<<<num_blocks, threads_per_block, 0, stream>>>(
    kernel, scaling_factors, zeros, de_kernel, G);

  return _de_kernel;
}

// in_feats: M, IC [float16]
// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
// scaling_factors: IC // G, OC [float16]
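
For intuition, here is a rough PyTorch sketch of the math the new dequantization path performs: unpack eight 4-bit values from each int32, subtract the per-group zero point, and multiply by the per-group scale, following the layout described in the comment above. It is illustrative only; the function name is made up, the shapes for qzeros are inferred from the kernel's indexing, and a plain low-nibble-first unpack order is assumed, whereas the CUDA kernel uses AWQ's interleaved nibble order via dequantize_s4_to_fp16x2, so this sketch is not bit-exact.

import torch

def awq_dequantize_reference(qweight, scales, qzeros, group_size):
    # qweight: (IC, OC // 8) int32, scales: (IC // G, OC) fp16,
    # qzeros: (IC // G, OC // 8) int32 -- shapes per the layout comment above.
    shifts = torch.arange(0, 32, 4, device=qweight.device)        # 8 nibbles per int32
    w = (qweight.unsqueeze(-1) >> shifts) & 0xF                   # (IC, OC // 8, 8)
    z = (qzeros.unsqueeze(-1) >> shifts) & 0xF                    # (IC // G, OC // 8, 8)
    w = w.reshape(qweight.shape[0], -1)                           # (IC, OC)
    z = z.reshape(qzeros.shape[0], -1)                            # (IC // G, OC)
    group = torch.arange(qweight.shape[0], device=qweight.device) // group_size
    # (weight - zero) * scale, broadcasting each group's zero/scale over its rows
    return (w - z[group]).half() * scales[group]                  # (IC, OC) fp16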
4 changes: 3 additions & 1 deletion setup.py
@@ -255,7 +255,9 @@ def get_torch_arch_list() -> Set[str]:
]

if _is_cuda():
    vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu")
    vllm_extension_sources.extend([
"csrc/quantization/awq/gemm_kernels.cu",
])

if not _is_neuron():
    vllm_extension = CUDAExtension(
11 changes: 10 additions & 1 deletion vllm/model_executor/layers/quantization/awq.py
@@ -153,7 +153,16 @@ def apply_weights(self,
        pack_factor = self.quant_config.pack_factor
        out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
        reshaped_x = x.reshape(-1, x.shape[-1])
        out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, pack_factor)

        # batch_size*seq_len >= threshold
        FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 256

        if FP16_MATMUL_HEURISTIC_CONDITION:
            out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
            out = torch.matmul(reshaped_x, out)
Comment on lines +161 to +162

Contributor:

Curious to learn: would this copy the dequantized weights back to memory before doing torch.matmul? A potential optimization could be a more efficient mixed-precision matmul that saves one data transfer to memory.

Contributor Author:

You are probably right that there is potential to eliminate overhead. Exllama runs dequantization and then directly calls cuBLAS for the matmul inside the same CUDA kernel. Definitely something to explore!
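
To put rough numbers on the reviewer's point (illustrative arithmetic only, with an assumed layer shape, not a measurement): the two-pass path materializes the full fp16 weight matrix in global memory before the matmul reads it back, which a fused dequantize-plus-GEMM kernel would avoid.

# Back-of-the-envelope for the extra global-memory traffic of the two-pass path.
# The 4096 x 4096 layer shape is assumed purely for illustration.
in_c, out_c = 4096, 4096
packed_bytes = in_c * (out_c // 8) * 4    # int32-packed 4-bit weights read either way: ~8 MiB
dequant_bytes = in_c * out_c * 2          # fp16 weights written by awq_dequantize: ~32 MiB
print(packed_bytes / 2**20, dequant_bytes / 2**20)
# The ~32 MiB is written once and then read again by torch.matmul; a kernel that
# dequantizes in registers/shared memory would skip that round trip.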

        else:
            out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
                               pack_factor)
        if bias is not None:
            out = out + bias
        return out.reshape(out_shape)
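
The dispatch above switches to the dequantize-then-fp16-matmul path once batch_size * seq_len reaches 256, where a dense fp16 GEMM tends to outpace the AWQ GEMM kernel. Below is a hedged sketch of how that threshold could be sanity-checked for one layer on a given GPU; the import path vllm._C, the helper name, and the iteration count are assumptions for illustration, not part of this change.

import torch
from vllm._C import ops  # assumed import path for the compiled extension

def compare_paths(reshaped_x, qweight, scales, qzeros, pack_factor, iters=50):
    """Time the fp16 dequant+matmul path against awq_gemm for one layer."""
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    for _ in range(iters):
        weight_fp16 = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
        torch.matmul(reshaped_x, weight_fp16)
    end.record()
    torch.cuda.synchronize()
    fp16_ms = start.elapsed_time(end) / iters

    start.record()
    for _ in range(iters):
        ops.awq_gemm(reshaped_x, qweight, scales, qzeros, pack_factor)
    end.record()
    torch.cuda.synchronize()
    awq_ms = start.elapsed_time(end) / iters
    return fp16_ms, awq_ms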