This repository was archived by the owner on Oct 11, 2024. It is now read-only.

vllm - quantization : DO NOT MERGE #180

Closed · wants to merge 8 commits
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -167,6 +167,7 @@ set(VLLM_EXT_SRC
"csrc/layernorm_kernels.cu"
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/smoothquant/fused_kernels.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/moe_align_block_size_kernels.cu"
"csrc/pybind.cpp")
9 changes: 5 additions & 4 deletions benchmarks/benchmark_throughput.py
@@ -259,10 +259,11 @@ def main(args: argparse.Namespace):
"output length from the dataset.")
parser.add_argument("--model", type=str, default="facebook/opt-125m")
parser.add_argument("--tokenizer", type=str, default=None)
parser.add_argument('--quantization',
'-q',
choices=['awq', 'gptq', 'squeezellm', None],
default=None)
parser.add_argument(
'--quantization',
'-q',
choices=['awq', 'gptq', 'squeezellm', 'smoothquant', None],
default=None)
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
parser.add_argument("--n",
type=int,
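
For reference, a minimal argparse sketch of the extended flag in isolation (the real script defines many more options around it; the command in the comment is illustrative only):

import argparse

# Minimal reproduction of the extended --quantization choice list.
parser = argparse.ArgumentParser()
parser.add_argument('--quantization',
                    '-q',
                    choices=['awq', 'gptq', 'squeezellm', 'smoothquant', None],
                    default=None)

# e.g. python benchmarks/benchmark_throughput.py -q smoothquant ...
args = parser.parse_args(['-q', 'smoothquant'])
assert args.quantization == 'smoothquant'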
1 change: 1 addition & 0 deletions csrc/attention/attention_dtypes.h
@@ -5,3 +5,4 @@
#include "dtype_float32.cuh"
#include "dtype_bfloat16.cuh"
#include "dtype_fp8.cuh"
#include "dtype_int8.cuh"
8 changes: 8 additions & 0 deletions csrc/attention/dtype_float32.cuh
@@ -86,6 +86,14 @@ inline __device__ float4 add(float4 a, float4 b) {
return c;
}

// Float4_ overload needed for compilation; the float4 overload above does not cover the Float4_ struct.
inline __device__ Float4_ add(Float4_ a, Float4_ b) {
Float4_ c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
return c;
}

// Vector multiplication.
template<>
inline __device__ float mul<float, float>(float a, float b) {
49 changes: 49 additions & 0 deletions csrc/attention/dtype_int8.cuh
@@ -0,0 +1,49 @@
#pragma once

#include <stdint.h>
#include "attention_generic.cuh"
#include "dtype_float32.cuh"

namespace vllm {
// define int8 vector types for quantization of kv cache

template<>
struct Vec<int8_t, 1> {
using Type = int8_t;
};

template<>
struct Vec<int8_t, 2> {
using Type = int16_t;
};

template<>
struct Vec<int8_t, 4> {
using Type = int32_t;
};

template<>
struct Vec<int8_t, 8> {
using Type = int64_t;
};

template<>
struct FloatVec<int8_t> {
using Type = float;
};

template<>
struct FloatVec<int16_t> {
using Type = float2;
};

template<>
struct FloatVec<int32_t> {
using Type = Float4_;
};

template<>
struct FloatVec<int64_t> {
using Type = Float8_;
};
}
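
To illustrate the traits above: Vec<int8_t, N> packs N int8 lanes into one wider integer register, and FloatVec maps the packed type to the float vector of matching width used after dequantization. A host-side NumPy sketch of the packing idea (not the CUDA types themselves):

import numpy as np

# Four int8 lanes viewed as a single 32-bit value, mirroring Vec<int8_t, 4>::Type = int32_t.
lanes = np.array([1, -2, 3, -4], dtype=np.int8)
packed = lanes.view(np.int32)[0]
unpacked = np.frombuffer(packed.tobytes(), dtype=np.int8)
print(packed, unpacked)                     # round-trips back to [ 1 -2  3 -4]

# FloatVec<int32_t>::Type = Float4_: those four lanes dequantize to four floats.
scale = 0.05
print(unpacked.astype(np.float32) * scale)  # ≈ [ 0.05 -0.1  0.15 -0.2 ]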
10 changes: 9 additions & 1 deletion csrc/dispatch_utils.h
@@ -9,12 +9,20 @@
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \

#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
21 changes: 21 additions & 0 deletions csrc/ops.h
@@ -131,6 +131,27 @@ void gptq_shuffle(
torch::Tensor q_perm,
int bit);

void dequant(
torch::Tensor& out,
torch::Tensor& input,
float scale);

void dequant(
torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& scale,
float weight_dequant_scale);

void quant(
torch::Tensor& out,
torch::Tensor& input,
float scale);

void quant(
torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& scale);

void moe_align_block_size(
torch::Tensor topk_ids,
int num_experts,
15 changes: 15 additions & 0 deletions csrc/pybind.cpp
@@ -49,6 +49,21 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"fused_add_rms_norm",
&fused_add_rms_norm,
"In-place fused Add and RMS Normalization");
ops.def(
"quant",
py::overload_cast<
torch::Tensor&,
torch::Tensor&,
float>(&quant),
"Quant.");
ops.def(
"quant",
py::overload_cast<
torch::Tensor&,
torch::Tensor&,
torch::Tensor&>(
&quant),
"Per-token quant.");

// Rotary embedding
ops.def(
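
A sketch of how the two new bindings could be called from Python once the extension is built. The import path is an assumption (a vllm._C extension exposing the ops submodule defined in pybind.cpp); it depends on the actual build layout, and a CUDA device is required:

import torch
from vllm._C import ops  # assumed import path; adjust to the actual build

x = torch.randn(4, 4096, dtype=torch.float16, device="cuda")
out = torch.empty_like(x, dtype=torch.int8)

# Per-tensor quant: caller supplies a single float scale.
ops.quant(out, x, 0.05)

# Per-token quant: the kernel computes one scale per token and writes it into scales.
scales = torch.empty(x.shape[0], dtype=torch.float32, device="cuda")
ops.quant(out, x, scales)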
91 changes: 91 additions & 0 deletions csrc/quantization/smoothquant/fused_kernels.cu
@@ -0,0 +1,91 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <assert.h>

#include "../../dispatch_utils.h"
#include "../../reduction_utils.cuh"
#include "quant_utils.cuh"

namespace vllm {

template <typename scalar_t, typename scale_type, bool use_per_token_quant>
__global__ void quant_kernel(
const scalar_t* __restrict__ input,
int8_t* __restrict__ out,
scale_type scale,
const int hidden_size) {
const int tid = threadIdx.x;
const int token_idx = blockIdx.x;

if constexpr (use_per_token_quant) {
float amax_val = 0.0f;
const float zero = 0.0f;

for (int i = tid; i < hidden_size; i += blockDim.x) {
float val = (float)input[token_idx * hidden_size + i];
val = val > zero ? val : -val;
if (val > amax_val)
amax_val = val;
}

__shared__ float s_amax;
const float block_amax_val = blockReduceMax(amax_val);
if (tid == 0) {
s_amax = block_amax_val;
scale[token_idx] = block_amax_val / 127.0f;
}
__syncthreads();

float tmp_scale = 127.0f / s_amax;
for (int i = tid; i < hidden_size; i += blockDim.x) {
out[token_idx * hidden_size + i] =
float_to_int8_rn(((float)input[token_idx * hidden_size + i]) * tmp_scale);
}
} else {
for (int i = tid; i < hidden_size; i += blockDim.x) {
out[token_idx * hidden_size + i] =
float_to_int8_rn(((float)input[token_idx * hidden_size + i]) / scale);
}
}
}
} // namespace vllm

void quant(
torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size]
float scale) {
assert(input.is_contiguous());
assert(out.is_contiguous());
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "quant_kernel", [&] {
vllm::quant_kernel<scalar_t, float, false><<<grid, block, 0, stream>>>(
input.data_ptr<scalar_t>(),
out.data_ptr<int8_t>(),
scale,
hidden_size);
});
}

void quant(
torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size]
torch::Tensor& scale) { // [num_tokens]
assert(input.is_contiguous());
assert(out.is_contiguous());
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "quant_kernel", [&] {
vllm::quant_kernel<scalar_t, float*, true><<<grid, block, 0, stream>>>(
input.data_ptr<scalar_t>(),
out.data_ptr<int8_t>(),
scale.data_ptr<float>(),
hidden_size);
});
}
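
For clarity, a NumPy reference of what quant_kernel computes, assuming float_to_int8_rn rounds to nearest and saturates to the int8 range (this mirrors the kernel's math, not its launch behavior):

import numpy as np

def quant_per_token_ref(x):
    """Reference for the use_per_token_quant=True branch: one scale per token."""
    amax = np.abs(x).max(axis=-1, keepdims=True)    # blockReduceMax over hidden_size
    scale = amax / 127.0                            # written to scale[token_idx]
    q = np.clip(np.rint(x * (127.0 / amax)), -128, 127).astype(np.int8)
    return q, scale.squeeze(-1)

def quant_per_tensor_ref(x, scale):
    """Reference for the per-tensor branch: out = round(x / scale)."""
    return np.clip(np.rint(x / scale), -128, 127).astype(np.int8)

x = np.random.randn(4, 16).astype(np.float32)
q, s = quant_per_token_ref(x)
print(q.dtype, s.shape)   # int8 (4,)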