Compression config cutlass (#205)
Use CUTLASS kernels for the SmoothQuant int8 GEMM path, replacing the cuBLAS-based I8CUGEMM wrapper and the standalone dequant kernels.

---------

Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
2 people authored and dsikka committed Apr 24, 2024
1 parent b2c39a1 commit e8d1886
Showing 15 changed files with 146 additions and 1,793 deletions.
3 changes: 0 additions & 3 deletions CMakeLists.txt
@@ -167,9 +167,6 @@ set(VLLM_EXT_SRC
   "csrc/layernorm_kernels.cu"
   "csrc/quantization/squeezellm/quant_cuda_kernel.cu"
   "csrc/quantization/gptq/q_gemm.cu"
-  "csrc/quantization/smoothquant/int8gemm/cublasAlgoMap.cc"
-  "csrc/quantization/smoothquant/int8gemm/cublasINT8MMWrapper.cc"
-  "csrc/quantization/smoothquant/int8gemm/cuda_utils.cc"
   "csrc/quantization/smoothquant/fused_kernels.cu"
   "csrc/cuda_utils_kernels.cu"
   "csrc/moe_align_block_size_kernels.cu"
22 changes: 0 additions & 22 deletions csrc/pybind.cpp
@@ -1,7 +1,6 @@
 #include "cache.h"
 #include "cuda_utils.h"
 #include "ops.h"
-#include "quantization/smoothquant/int8gemm/int8_gemm.h"
 #include <torch/extension.h>
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
@@ -50,21 +49,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     "fused_add_rms_norm",
     &fused_add_rms_norm,
     "In-place fused Add and RMS Normalization");
-  ops.def(
-    "dequant",
-    py::overload_cast<
-      torch::Tensor&,
-      torch::Tensor&,
-      float>(&dequant),
-    "Dequant.");
-  ops.def(
-    "dequant",
-    py::overload_cast<
-      torch::Tensor&,
-      torch::Tensor&,
-      torch::Tensor&,
-      float>(&dequant),
-    "Per-token dequant.");
   ops.def(
     "quant",
     py::overload_cast<
@@ -104,12 +88,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
   ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
   ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
-  pybind11::class_<I8CUGEMM>(ops, "I8CUGEMM")
-    .def(pybind11::init<>())
-    .def("linear_a8_w8_o32", &I8CUGEMM::linear_a8_w8_o32)
-    .def("linear_a8_w8_o8", &I8CUGEMM::linear_a8_w8_o8)
-    .def("linear_a8_w8_o8_", &I8CUGEMM::linear_a8_w8_o8_)
-    .def("linear_a8_w8_o32_", &I8CUGEMM::linear_a8_w8_o32_);
   ops.def(
     "moe_align_block_size",
     &moe_align_block_size,
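For orientation only: the I8CUGEMM binding removed above exposed cuBLAS-backed int8 GEMMs such as linear_a8_w8_o32 (int8 activations and weights, int32 output), a role the CUTLASS kernels now take over. The sketch below is a minimal host-side reference of the arithmetic such a call performs; the function name, shapes, and row-major layout are illustrative assumptions, not the wrapper's actual API.

#include <cstdint>
#include <vector>

// Reference (non-performant) sketch of an a8/w8 -> o32 GEMM:
// int8 activations [M, K] times int8 weights [N, K], accumulated in int32.
// Layout and naming are assumptions for illustration only.
std::vector<int32_t> linear_a8_w8_o32_ref(const std::vector<int8_t>& act,
                                          const std::vector<int8_t>& weight,
                                          int M, int N, int K) {
  std::vector<int32_t> out(static_cast<size_t>(M) * N, 0);
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      int32_t acc = 0;
      for (int k = 0; k < K; ++k) {
        acc += static_cast<int32_t>(act[m * K + k]) *
               static_cast<int32_t>(weight[n * K + k]);
      }
      out[m * N + n] = acc;  // o32: keep the raw int32 accumulator
    }
  }
  return out;
}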
73 changes: 1 addition & 72 deletions csrc/quantization/smoothquant/fused_kernels.cu
@@ -7,27 +7,6 @@
 #include "quant_utils.cuh"
 
 namespace vllm {
-template <typename scalar_t, bool use_per_token_dequant>
-__global__ void dequant_kernel(
-    const int32_t* __restrict__ input,
-    scalar_t* __restrict__ out,
-    const float scale,
-    const int m,
-    const int hidden_size,
-    const int input_stride,
-    const int out_stride,
-    const float* __restrict__ act_scale = nullptr) {
-  const int tid = threadIdx.x;
-  const int token_idx = blockIdx.x;
-  float scale_ = scale;
-  if constexpr (use_per_token_dequant) {
-    scale_ = scale * act_scale[token_idx];
-  }
-  for (int i = tid; i < hidden_size; i += blockDim.x) {
-    out[token_idx * out_stride + i] =
-        (scalar_t)(((float)input[token_idx * input_stride + i]) * scale_);
-  }
-}
 
 template <typename scalar_t, typename scale_type, bool use_per_token_quant>
 __global__ void quant_kernel(
@@ -71,56 +50,6 @@ __global__ void quant_kernel(
 }
 } // namespace vllm
 
-void dequant(
-    torch::Tensor& out,    // [..., hidden_size]
-    torch::Tensor& input,  // [..., hidden_size]
-    float scale) {
-  int hidden_size = input.size(-1);
-  int num_tokens = input.numel() / hidden_size;
-  dim3 grid(num_tokens);
-  dim3 block(std::min(hidden_size, 1024));
-  int input_stride = input.stride(-2);
-  int out_stride = out.stride(-2);
-
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_FLOATING_TYPES(out.scalar_type(), "dequant_kernel", [&] {
-    vllm::dequant_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
-      input.data_ptr<int32_t>(),
-      out.data_ptr<scalar_t>(),
-      scale,
-      num_tokens,
-      hidden_size,
-      input_stride,
-      out_stride);
-  });
-}
-
-void dequant(
-    torch::Tensor& out,    // [..., hidden_size]
-    torch::Tensor& input,  // [..., hidden_size]
-    torch::Tensor& scale,
-    float weight_dequant_scale) {
-  int hidden_size = input.size(-1);
-  int num_tokens = input.numel() / hidden_size;
-  dim3 grid(num_tokens);
-  dim3 block(std::min(hidden_size, 1024));
-  int input_stride = input.stride(-2);
-  int out_stride = out.stride(-2);
-
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_FLOATING_TYPES(out.scalar_type(), "dequant_kernel", [&] {
-    vllm::dequant_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
-      input.data_ptr<int32_t>(),
-      out.data_ptr<scalar_t>(),
-      weight_dequant_scale,
-      num_tokens,
-      hidden_size,
-      input_stride,
-      out_stride,
-      scale.data_ptr<float>());
-  });
-}
-
 void quant(
     torch::Tensor& out,    // [..., hidden_size]
     torch::Tensor& input,  // [..., hidden_size]
@@ -159,4 +88,4 @@ void quant(
       scale.data_ptr<float>(),
       hidden_size);
   });
-}
+}
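The deleted dequant entry points turned int32 GEMM accumulators back into floating point: the per-tensor variant scales every element by a single scale, while the per-token variant also multiplies by that token's activation scale (scale * act_scale[token]), exactly as dequant_kernel did above. Below is a compact host-side sketch of the same math in plain C++, for illustration only; presumably this step is now handled by the CUTLASS path this commit introduces, which is not shown in this diff.

#include <cstdint>

// Reference sketch of the removed dequant math.
//   per-tensor: out[t][i] = float(in[t][i]) * weight_scale
//   per-token:  out[t][i] = float(in[t][i]) * weight_scale * act_scale[t]
// Passing act_scale == nullptr selects the per-tensor variant; names are illustrative.
void dequant_ref(const int32_t* in, float* out, int num_tokens, int hidden_size,
                 float weight_scale, const float* act_scale = nullptr) {
  for (int t = 0; t < num_tokens; ++t) {
    const float s = act_scale ? weight_scale * act_scale[t] : weight_scale;
    for (int i = 0; i < hidden_size; ++i) {
      out[t * hidden_size + i] =
          static_cast<float>(in[t * hidden_size + i]) * s;
    }
  }
}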
232 changes: 0 additions & 232 deletions csrc/quantization/smoothquant/int8gemm/allocator.h

This file was deleted.

(Diffs for the remaining changed files are not shown.)
