Description
Here is my understanding of the existing state of things and what I think we should be doing to make our lower-bit kernels more performant at both small and larger batch sizes. I'm making this an RFC because I'm curious whether I'm paying attention to the wrong things, so if you disagree with anything below please comment!
First a quick survey of libraries
Survey of existing solutions
Interestingly enough, almost none of the solutions below are shipped as installable packages; instead they encourage users to copy-paste their code and cite it. It's also common for these libraries to be header-only to make integration easier.
And we thankfully do have the machinery to support CUDA kernels across multiple versions with minimal headache thanks to our custom CUDA extension support: https://github.com/pytorch/ao/tree/main/torchao/csrc
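For reference, the sketch below shows the general shape of wiring a vendored kernel in as a custom extension. It is not the actual torchao build setup (that lives in setup.py and torchao/csrc); the module name, file paths, and exposed function are hypothetical.

```python
# Illustrative only: JIT-compile a vendored CUDA kernel and call it from Python.
# Module name, source paths, and the exposed function are hypothetical.
from torch.utils.cpp_extension import load

marlin_ext = load(
    name="marlin_ext",                 # hypothetical extension name
    sources=[
        "csrc/marlin_cuda.cpp",        # hypothetical binding file
        "csrc/marlin_cuda_kernel.cu",  # hypothetical kernel file
    ],
    extra_cuda_cflags=["-O3"],
    verbose=True,
)

# Whatever functions the .cpp file binds (via pybind11 or torch.library) are now
# callable, e.g. a hypothetical marlin_ext.gemm_fp16_int4(A, B_packed, scales).
```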
So it's easy to merge kernels, but which ones should we actually merge?
Marlin
This is the kernel of choice in vLLM, arguably the most popular inference engine on the market. It provides fp16 x int4 kernels that target batch sizes that are small but still larger than what tinygemm and competitors handle well, and the kernels don't seem particularly affected by power limits on GPUs, something that has bitten us in the past when running internal performance benchmarks.
There's also a 2:4 sparse variant of the above kernel, which we're already working on upstreaming in #621, though I'm not sure yet whether we should merge both kernels or just the sparse one.
Regardless, the IST-DASLab group (https://github.com/IST-DASLab/marlin) consistently does excellent work and is worth following.
tinygemm
tinygemm isn't a full library in core; it's a single op, torch.ops.aten._weight_int4pack_mm, and it's the fastest thing we've found for int4 weight-only quantization (w4a16) so far. One challenge, though, is that because it's so fast it becomes a hammer and all our performance problems become nails, whereas if we could easily accelerate other dtypes we might not rely on it so much.
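For context, here is a rough sketch of how the op gets driven, in the spirit of gpt-fast's int4 path. The shapes and packing helper reflect my understanding of the current API, but the preprocessing side (_convert_weight_to_int4pack's expected dtype and layout) has been changing across PyTorch releases, so treat this as illustrative.

```python
# Rough sketch of the tinygemm int4 weight-only path; packing details vary by
# PyTorch version, so treat shapes/dtypes here as assumptions.
import torch

N, K, groupsize, inner_k_tiles = 4096, 4096, 128, 8

# Quantized weights as values in [0, 15], repacked into tinygemm's tiled layout.
w_q = torch.randint(0, 16, (N, K), dtype=torch.int32, device="cuda")
w_int4pack = torch.ops.aten._convert_weight_to_int4pack(w_q, inner_k_tiles)

# One (scale, zero_point) pair per group of `groupsize` input channels.
scales_and_zeros = torch.randn(K // groupsize, N, 2, dtype=torch.bfloat16, device="cuda")

# Activations are bf16 on CUDA; output is (batch, N) in bf16.
x = torch.randn(1, K, dtype=torch.bfloat16, device="cuda")
y = torch.ops.aten._weight_int4pack_mm(x, w_int4pack, groupsize, scales_and_zeros)
```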
CUTLASS
This work leverages the Universal GEMM operator in CUTLASS (NVIDIA/cutlass#1549); no bit packing is needed on our side since CUTLASS supports a native cutlass::int4b_t type.
There are also some open PRs in CUTLASS by @alexsamardzic for signed and unsigned int4/int8 multiplication with fp16 activations (NVIDIA/cutlass#1413).
Perhaps the main recurring con with CUTLASS is that it's hard to learn, but it's generally one of the best performance targets considering how vertically integrated it is with the NVIDIA stack. And well, maybe it's not that hard; maybe it's a bit of a skill issue on my end.
gemlite
This is a more recent project that offers GEMV acceleration (https://mobiusml.github.io/gemlite_blogpost/) by @mobicham.
The core idea is well explained in https://github.com/Bruce-Lee-LY/cuda_hgemv#optimization-method, which walks from naive implementations to ones that use shared memory and warp scheduling efficiently.
GEMV kernels are inherently solving a more restricted problem, namely bs=1 inference à la gpt-fast.
However, despite being limited to batch size 1, gemlite is quite expressive in that it allows arbitrary weight dtypes. If you look at their function name gemv_A16fWnO16f_int32packing, you can read it as: fp16 activations x n-bit weights packed into 32-bit words, with fp16 output.
The code is quite short and confined to very few files, so it's easy to reuse.
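To make the semantics concrete, here is a dequantize-then-matvec reference in plain PyTorch (my own sketch, not gemlite's kernel, and using per-output-channel scales/zeros for simplicity); a real GEMV kernel unpacks the weights in registers or shared memory instead of materializing the fp16 matrix.

```python
# Reference semantics for "fp16 x n-bit weights packed into int32 words, fp16 out".
# This is an illustrative sketch, not gemlite's implementation.
import torch

def unpack_int32(w_packed: torch.Tensor, nbits: int) -> torch.Tensor:
    # w_packed: (N, K * nbits // 32) int32; returns (N, K) values in [0, 2**nbits)
    vals_per_word = 32 // nbits
    mask = (1 << nbits) - 1
    shifts = torch.arange(vals_per_word, device=w_packed.device, dtype=torch.int32) * nbits
    unpacked = (w_packed.unsqueeze(-1) >> shifts) & mask  # (N, words, vals_per_word)
    return unpacked.flatten(1)                            # (N, K)

def gemv_ref(x_fp16, w_packed, scales, zeros, nbits):
    w = unpack_int32(w_packed, nbits).to(torch.float16)
    w = (w - zeros) * scales        # per-output-channel dequantization
    return x_fp16 @ w.t()           # (1, K) @ (K, N) -> (1, N)

x = torch.randn(1, 4096, device="cuda", dtype=torch.float16)
w_packed = torch.randint(-2**31, 2**31 - 1, (4096, 512), dtype=torch.int32, device="cuda")
scales = torch.full((4096, 1), 0.01, dtype=torch.float16, device="cuda")
zeros = torch.full((4096, 1), 8.0, dtype=torch.float16, device="cuda")
y = gemv_ref(x, w_packed, scales, zeros, nbits=4)   # (1, 4096)
```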
bitblas
https://github.com/microsoft/BitBLAS
This is the only repo in this list with a pip package, so packaging it ourselves doesn't make as much sense, although we could explore using it as an optional backend in ao for cases where we don't have the right kernel. Their support matrix is probably the most comprehensive of any repo in this list: https://github.com/microsoft/BitBLAS#support-matrix
Suggested next steps
Merge the obviously useful kernels
The somewhat obvious next steps to match the current state of things are:
- Merge and package the Marlin kernels, because we don't support int4 at medium batch sizes and we don't have a good story for fast sparsity
- Merge and package the CUTLASS kernels, because they are very fast and are full GEMM kernels rather than purely GEMV, meaning they will help at larger batch sizes, something we don't do super well yet and which has been a recurring ask from outside partners dealing with high-throughput inference
Both of the above let us work at batch sizes larger than 1, and both are industry standards whose installation experience has frustrated people, which makes them natural candidates to merge and package.
Write the non-obvious kernels
The non-obvious kernels haven't been written yet, so our strategy has typically been to:
- Cheat by using torch.compile() with clever bitpacking as a baseline (see the sketch below)
- Run end-to-end benchmarks against the best options on the market, which isn't always possible considering a lot of these kernels don't exist elsewhere
- Run speed-of-light analysis using the new profiler by @jeromeku ([FEAT] Perf Profiler Update, #690)
End-to-end benchmarks are certainly helpful, but since we're talking about kernels here we'd also need to run microbenchmarks on various shapes, as @jerryzh168 suggests.
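As a sketch of the first bullet (my own minimal version, not an existing torchao API), the idea is to store two int4 values per uint8 and let torch.compile fuse the unpack/dequantize into the matmul:

```python
# Minimal sketch of the torch.compile + bitpacking baseline; helper names and the
# packing layout are assumptions, not an existing torchao API.
import torch

def pack_int4(w_int4: torch.Tensor) -> torch.Tensor:
    # w_int4: (N, K) with values in [-8, 7]; pack pairs along K into one uint8 each
    w_u = (w_int4 + 8).to(torch.uint8)            # shift to [0, 15]
    return w_u[:, ::2] | (w_u[:, 1::2] << 4)      # (N, K // 2)

@torch.compile
def int4_weight_only_linear(x, w_packed, scale):
    lo = (w_packed & 0xF).to(x.dtype) - 8         # even K positions
    hi = (w_packed >> 4).to(x.dtype) - 8          # odd K positions
    w = torch.stack([lo, hi], dim=-1).flatten(1)  # (N, K) dequantized (pre-scale)
    return x @ (w * scale).t()

x = torch.randn(1, 4096, device="cuda", dtype=torch.bfloat16)
w_int4 = torch.randint(-8, 8, (4096, 4096), device="cuda")
scale = torch.full((4096, 1), 0.01, device="cuda", dtype=torch.bfloat16)
y = int4_weight_only_linear(x, pack_int4(w_int4), scale)
```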
For bs=1, get better performance for dtypes smaller than 4 bits
gemlite is a nice educational library supporting GEMV for a variety of dtypes, so we should leverage it not just for end-to-end performance benchmarks but also for speed-of-light calculations that help us better understand the gaps for bs=1 inference. The idea here is to ensure that performance is great for a variety of intX dtypes, as opposed to overfitting to int4 just because we have tinygemm.
@vayuda has already led some early work here by doing bitpacking with torch.compile, so we need to start baselining more heavily.
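For the speed-of-light side, a back-of-the-envelope model (my own sketch, not an existing ao utility) is enough to see what each bit width should buy us at bs=1, since GEMV is memory-bandwidth bound; the 2039 GB/s figure below assumes an A100 80GB SXM.

```python
# Back-of-the-envelope speed of light for bs=1 GEMV: runtime is bounded below by
# bytes moved / peak DRAM bandwidth. Peak bandwidth is an assumption (A100 80GB SXM).
def gemv_speed_of_light_us(K: int, N: int, weight_bits: int,
                           act_bytes: int = 2, peak_bw_gb_s: float = 2039.0) -> float:
    weight_bytes = N * K * weight_bits / 8    # packed weight matrix (dominant term)
    io_bytes = (K + N) * act_bytes            # fp16 activations in, fp16 activations out
    return (weight_bytes + io_bytes) / (peak_bw_gb_s * 1e9) * 1e6  # microseconds

# A 4096x4096 layer: int4 should be ~4x faster than fp16 at bs=1, int2 ~8x
for bits in (16, 8, 4, 2):
    print(f"{bits}-bit weights: {gemv_speed_of_light_us(4096, 4096, bits):.1f} us")
```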
For bs=n inference, start writing new kernels since they don't exist
For H100+
The biggest theme here is that instead of relying on fp16 as the activation dtype we can rely on fp8.
Some of this work was already mentioned in #663, but to add more detail:
- Show compelling perf for fp8 GEMM at a variety of batch sizes, work already started by @drisspg and @jainapurva (sketched below)
- Demonstrate compelling perf for intX weights with fp8 activations at a variety of batch sizes; this ask came directly from the team at Neural Magic
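For the first bullet, a minimal per-tensor-scaled fp8 GEMM sketch is below; it assumes an H100 and a recent PyTorch, and note that the signature of torch._scaled_mm has shifted between releases, so this is illustrative rather than the exact torchao path.

```python
# Minimal per-tensor-scaled fp8 GEMM sketch using torch._scaled_mm on H100.
# The exact _scaled_mm signature has changed across PyTorch releases; illustrative only.
import torch

def to_float8(x: torch.Tensor, dtype=torch.float8_e4m3fn):
    # per-tensor scale so the max value maps to the fp8 max; return the dequant scale
    fp8_max = torch.finfo(dtype).max
    q_scale = fp8_max / x.abs().max().clamp(min=1e-12)
    x_f8 = (x * q_scale).clamp(-fp8_max, fp8_max).to(dtype)
    return x_f8, q_scale.float().reciprocal()

M, K, N = 16, 4096, 4096
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)   # activations
w = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)   # weights

a_f8, a_scale = to_float8(a)
w_f8, w_scale = to_float8(w)

# _scaled_mm expects the second operand in column-major layout, hence the .t()
out = torch._scaled_mm(
    a_f8, w_f8.t(),                       # (M, K) x (K, N)
    scale_a=a_scale, scale_b=w_scale,
    out_dtype=torch.bfloat16,
)
```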
For A100
For A100 our options are a bit more obvious: we should be showing compelling dynamic quantization performance (quantizing the activations to int8) at larger batch sizes. gpt-fast has already been extended to support larger batch sizes: https://github.com/pytorch-labs/gpt-fast/tree/batched_generation
For this work we'd focus on int8 dynamic quantization first and then work our way down from there.
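As a reference for what that baseline looks like, here is a minimal int8 dynamic quantization sketch (my own, assuming symmetric per-row activation scales and per-channel weight scales) built on torch._int_mm, the int8 x int8 -> int32 matmul exposed in core.

```python
# Minimal int8 dynamic quantization sketch (assumptions, not torchao's kernels):
# per-row activation scales, per-channel weight scales, int8 x int8 -> int32 matmul.
import torch

def quantize_per_row_int8(x: torch.Tensor):
    # symmetric quantization with one scale per row
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / 127.0
    x_int8 = torch.round(x / scale).clamp(-127, 127).to(torch.int8)
    return x_int8, scale

M, K, N = 32, 4096, 4096
x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)   # activations
w = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)   # weights

x_int8, x_scale = quantize_per_row_int8(x)   # (M, K) int8, (M, 1) scales
w_int8, w_scale = quantize_per_row_int8(w)   # (N, K) int8, (N, 1) scales

acc = torch._int_mm(x_int8, w_int8.t().contiguous())   # (M, N) int32 accumulator
y = acc.to(torch.bfloat16) * x_scale * w_scale.t()      # rescale back to bf16
```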
Related work
- LUT-GEMM https://arxiv.org/abs/2206.09557 uses lookup tables for the weights instead of having to dequantize them
- Atom low bit quantization for efficient and accurate LLM serving https://arxiv.org/abs/2310.19102 - in particular they implemented their kernels for W8A8 and W4A16 but since we already have tinygemm this is not super relevant
- FlashInfer (everyone else is citing this work) https://github.com/flashinfer-ai/flashinfer, a kernel library for inference: bs=1 and bs=n kernels for prefill, decode, and append over different KV cache formats (paged, ragged, page table), plus compressed and quantized KV caches.
- DeepGEMM https://arxiv.org/abs/2304.09049 - 2-bit matrix multiplication is expressed as a lookup table, with 4 of these values packed into an 8-bit vector register. Their benchmarks are on x86 against QNNPACK, so this is CPU-focused rather than CUDA-specific.
- https://github.com/google/gemmlowp is an older project that no longer seems maintained and primarily targets x86 and ARM.