Simple quantization, compatible with vllm/sglang.
git clone https://github.com/LambdaLabsML/openquant.git
cd openquant
python compress_fp8.py -m Qwen/Qwen3-32B
vllm serve Qwen3-32B-FP8
Model/quantization support:
Model | fp8 | awq |
---|---|---|
Qwen3 | ✅ | ✅ |
Qwen3 MoE | ✅ | * |
Llama 3 | ✅ | ✅ |
Llama 4 | ✅ | * |
Gemma 3 | ✅ | ✅ |
Mistral | ✅ | ✅ |
* AWQ can't really handle MoE models
For contributing new model architectures, see examples in openquant/models.py.
python compress_fp8.py -m Qwen/Qwen3-32B
tl;dr:
- model size * 0.5
- throughput * 1.2ish
- (with a lot of caveats)
Models today are usually trained in `bf16`, a floating point number stored in 16 bits (2 bytes). At the billions-of-parameters scale, those bytes add up VERY quickly. The main reason for quantizing a model from `bf16` to `fp8` is memory reduction.
For example, meta-llama/Llama-3.3-70B-Instruct has 70 billion parameters, which at `bf16` is 140 billion bytes, or 140 GB of data. A single H100 GPU has 80 GB of GPU RAM, so you'd need at LEAST 2xH100 to serve it, and likely more for kv cache space. If you halve the number of bytes, the weights only take 70 GB, letting the model fit comfortably on 2xH100s and just barely on 1xH100.
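The arithmetic is simple enough to sanity check in a few lines (ignoring kv cache, activations, and other runtime overhead):

params = 70e9          # Llama-3.3-70B parameter count
h100_ram_gb = 80

for name, bytes_per_param in [("bf16", 2), ("fp8", 1)]:
    weights_gb = params * bytes_per_param / 1e9
    print(f"{name}: {weights_gb:.0f} GB of weights (~{weights_gb / h100_ram_gb:.1f}x one H100)")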
Starting with the NVIDIA H100, GPUs have hardware support for 8-bit floating point numbers (`fp8`), meaning `fp8` performance is >= `bf16` performance (mostly). This performance gain comes from a few places:
- The model takes less GPU RAM => more space for kv cache. Modern inference libraries (like vllm/sglang) have higher and more stable performance with more room for kv cache
- Model parameters are half as big => half the GPU memory bandwidth to stream them (see the rough numbers below)
- Depending on the GPU, fp8 FLOPS are simply higher than bf16 FLOPS. E.g. see the H100 specifications: bfloat16 is ~2k teraFLOPS and fp8 is ~4k teraFLOPS
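To put a rough number on the bandwidth point: at small batch sizes, decoding is largely memory-bandwidth bound, so the floor on per-token latency is roughly (weight bytes) / (memory bandwidth). A back-of-the-envelope sketch, assuming an H100 SXM at roughly 3.35 TB/s of HBM bandwidth (an assumption; check the spec sheet for your card) and the 140 GB / 70 GB figures from above:

# Rough lower bound on decode latency when streaming the weights dominates.
bandwidth_gb_per_s = 3350   # ~H100 SXM HBM bandwidth (assumption)
for name, weight_gb in [("bf16", 140), ("fp8", 70)]:
    ms_per_token = weight_gb / bandwidth_gb_per_s * 1000
    print(f"{name}: >= {ms_per_token:.0f} ms/token just to read the weights once")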
When we talk about fp8 models, we are typically only talking about the weights being fp8. The actual execution of the model is still done in `bf16`: all the intermediate tensors are still `bf16`, and it's the underlying CUDA kernels that take in `bf16` tensors and `fp8` weights. fp8 models also still use a `bf16` kv cache by default (since the kv cache stores kv values, which are intermediate tensors).
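As a mental model, here's a rough sketch (not vllm's actual kernel) of what a weight-only fp8 linear layer computes: activations stay in `bf16`, and the `fp8` weight is dequantized with its stored `scale` (explained in the next section) before a `bf16` matmul:

import torch

def fp8_weight_only_linear(x_bf16, weight_fp8, scale):
    # x_bf16:     [batch, in_features] activations, still bf16
    # weight_fp8: [out_features, in_features] quantized weights
    # scale:      the scale(s) stored alongside the weight
    w_bf16 = weight_fp8.to(torch.bfloat16) * scale   # dequantize (real kernels fuse this into the matmul)
    return x_bf16 @ w_bf16.T                         # the matmul itself runs in bf16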
There are a number of different `fp8` formats; the most common is `float8_e4m3fn`. Here are the bit patterns for the f8 and f16 formats:
Format | Bit Pattern | INF Support |
---|---|---|
float8_e4m3fn | ⚫🟩🟩🟩🟩🟥🟥🟥 | ❌ |
float8_e5m2 | ⚫🟩🟩🟩🟩🟩🟥🟥 | ✅ |
bfloat16 | ⚫🟩🟩🟩🟩🟩🟩🟩🟩🟥🟥🟥🟥🟥🟥🟥 | ✅ |
float16 | ⚫🟩🟩🟩🟩🟩🟥🟥🟥🟥🟥🟥🟥🟥🟥🟥 | ✅ |
where: ⚫ = Sign bit, 🟩 = Exponent bit, 🟥 = Mantissa (fraction) bit
Here are some facts about `float8_e4m3fn` (the sketch after this list checks them against PyTorch):
- This format has 1 sign bit, 4 bits for exponent (`e4`), and 3 bits for mantissa (`m3`)
- Values can be between `[-448, +448]`
- There are 256 representable values
- Infinity is not supported (the `fn` postfix stands for "finite numbers only"; there are other fp8 formats that do support infinity)
- `NaN` is supported
- Model parameters are typically stored using this format (note that `inf` is not usually present in pretrained model parameters anyway)
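These facts all fall out of the bit layout. Here's a small sketch that decodes each of the 256 bit patterns by hand (exponent bias of 7, and the single all-ones pattern per sign reserved for NaN) and checks the result against PyTorch:

import math
import torch

def decode_e4m3fn(byte: int) -> float:
    sign = -1.0 if (byte >> 7) & 1 else 1.0
    exponent = (byte >> 3) & 0b1111
    mantissa = byte & 0b111
    if exponent == 0b1111 and mantissa == 0b111:
        return math.nan                                       # the only NaN pattern; no inf in "fn"
    if exponent == 0:
        return sign * 2.0 ** -6 * (mantissa / 8)              # subnormals and +/-0
    return sign * 2.0 ** (exponent - 7) * (1 + mantissa / 8)  # normal values, bias = 7

reference = torch.arange(256, dtype=torch.uint8).view(dtype=torch.float8_e4m3fn).float().tolist()
for pattern, expected in enumerate(reference):
    decoded = decode_e4m3fn(pattern)
    assert decoded == expected or (math.isnan(decoded) and math.isnan(expected))

assert torch.finfo(torch.float8_e4m3fn).max == 448.0          # matches the [-448, +448] range above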
Here are all the possible `float8_e4m3fn` values:
torch.arange(256, dtype=torch.uint8).view(dtype=torch.float8_e4m3fn).tolist()
[0.0, 0.001953125, 0.00390625, 0.005859375, 0.0078125, 0.009765625, 0.01171875, 0.013671875, 0.015625, 0.017578125, 0.01953125, 0.021484375, 0.0234375, 0.025390625, 0.02734375, 0.029296875, 0.03125, 0.03515625, 0.0390625, 0.04296875, 0.046875, 0.05078125, 0.0546875, 0.05859375, 0.0625, 0.0703125, 0.078125, 0.0859375, 0.09375, 0.1015625, 0.109375, 0.1171875, 0.125, 0.140625, 0.15625, 0.171875, 0.1875, 0.203125, 0.21875, 0.234375, 0.25, 0.28125, 0.3125, 0.34375, 0.375, 0.40625, 0.4375, 0.46875, 0.5, 0.5625, 0.625, 0.6875, 0.75, 0.8125, 0.875, 0.9375, 1.0, 1.125, 1.25, 1.375, 1.5, 1.625, 1.75, 1.875, 2.0, 2.25, 2.5, 2.75, 3.0, 3.25, 3.5, 3.75, 4.0, 4.5, 5.0, 5.5, 6.0, 6.5, 7.0, 7.5, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 18.0, 20.0, 22.0, 24.0, 26.0, 28.0, 30.0, 32.0, 36.0, 40.0, 44.0, 48.0, 52.0, 56.0, 60.0, 64.0, 72.0, 80.0, 88.0, 96.0, 104.0, 112.0, 120.0, 128.0, 144.0, 160.0, 176.0, 192.0, 208.0, 224.0, 240.0, 256.0, 288.0, 320.0, 352.0, 384.0, 416.0, 448.0, nan, -0.0, -0.001953125, -0.00390625, -0.005859375, -0.0078125, -0.009765625, -0.01171875, -0.013671875, -0.015625, -0.017578125, -0.01953125, -0.021484375, -0.0234375, -0.025390625, -0.02734375, -0.029296875, -0.03125, -0.03515625, -0.0390625, -0.04296875, -0.046875, -0.05078125, -0.0546875, -0.05859375, -0.0625, -0.0703125, -0.078125, -0.0859375, -0.09375, -0.1015625, -0.109375, -0.1171875, -0.125, -0.140625, -0.15625, -0.171875, -0.1875, -0.203125, -0.21875, -0.234375, -0.25, -0.28125, -0.3125, -0.34375, -0.375, -0.40625, -0.4375, -0.46875, -0.5, -0.5625, -0.625, -0.6875, -0.75, -0.8125, -0.875, -0.9375, -1.0, -1.125, -1.25, -1.375, -1.5, -1.625, -1.75, -1.875, -2.0, -2.25, -2.5, -2.75, -3.0, -3.25, -3.5, -3.75, -4.0, -4.5, -5.0, -5.5, -6.0, -6.5, -7.0, -7.5, -8.0, -9.0, -10.0, -11.0, -12.0, -13.0, -14.0, -15.0, -16.0, -18.0, -20.0, -22.0, -24.0, -26.0, -28.0, -30.0, -32.0, -36.0, -40.0, -44.0, -48.0, -52.0, -56.0, -60.0, -64.0, -72.0, -80.0, -88.0, -96.0, -104.0, -112.0, -120.0, -128.0, -144.0, -160.0, -176.0, -192.0, -208.0, -224.0, -240.0, -256.0, -288.0, -320.0, -352.0, -384.0, -416.0, -448.0, nan]
And here is how all the representable values are distributed (notice how there are waaaaay more values closer to 0!):
So this leaves us with two questions for quantization:
- `bf16` can store values between `[-3.38953e+38, +3.38953e+38]`; how do we fit that into the fp8 range of `[-448, +448]`?
- How do we take advantage of the distribution of values in fp8?
Since `bf16` and `fp8` have different ranges, we need to scale the values to fit into the `fp8` range. This scale is based on the max value of the data at `bf16`, and is roughly computed like:
# NOTE: this will be a single value
scale = x.abs().amax() / 448
Then once we have the scale, we can quantize the `bf16` tensor:
x_quantized = (x / scale).clamp(min=-448, max=448).to(torch.float8_e4m3fn)
And to dequantize (which is essentially done on the fly at runtime inside the CUDA kernels), you do this (noting that you have to store the `scale` values for the forward pass):
x_dequantized = x_quantized.to(torch.bfloat16) * scale
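Putting those three snippets together, here's a minimal end-to-end round trip on a random tensor (the exact error depends on the data, but for Gaussian-ish weights it's typically on the order of a few percent):

import torch

FP8_MAX = 448.0  # largest float8_e4m3fn value

x = torch.randn(4096, 4096, dtype=torch.bfloat16)

scale = x.float().abs().amax() / FP8_MAX                           # single per-tensor scale
x_quantized = (x.float() / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
x_dequantized = x_quantized.to(torch.bfloat16) * scale

rel_err = ((x_dequantized.float() - x.float()).abs().mean() / x.float().abs().mean()).item()
print(f"mean relative error: {rel_err:.2%}")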
Above I showed the scale as a single value, but you can also make it a tensor; if you look at popular open-source fp8 models, they typically use this option. Why would you do this? To better preserve accuracy: a single outlier no longer dictates the scale for the entire tensor, though if the values in your tensor are all relatively close together you won't get much benefit.
Given a `weight_block_size` of `[128, 128]` and a tensor of shape `[N, K]`, the scale will have shape `[N // 128, K // 128]`. E.g. assuming `x` is 2d, we have the code:
N, K = x.shape
n, k = weight_block_size
x = x.reshape(N // n, n, K // k, k)
scale = x.abs().amax(dim=[1, 3]) / 448
assert scale.shape == torch.Size([N // n, K // k])
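Continuing that sketch, quantizing with the block scales and then undoing it just means broadcasting each block's scale back over its `[128, 128]` tile. This is illustrative code, not the exact kernels vllm uses:

import torch

FP8_MAX = 448.0
n, k = 128, 128                                   # weight_block_size
x = torch.randn(4096, 2048, dtype=torch.bfloat16)
N, K = x.shape

blocks = x.float().reshape(N // n, n, K // k, k)
scale = blocks.abs().amax(dim=(1, 3), keepdim=True) / FP8_MAX      # [N//n, 1, K//k, 1]
x_quantized = (blocks / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn).reshape(N, K)
weight_scale = scale.squeeze(1).squeeze(-1)                        # [N//n, K//k], stored next to the weight

# Dequantize: expand each block's scale back out to its 128x128 tile
scale_full = weight_scale.repeat_interleave(n, dim=0).repeat_interleave(k, dim=1)
x_dequantized = x_quantized.to(torch.float32) * scale_full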
For compatibility with things like vllm, there are a couple of things we need to do:
- Add the `weight_scale` as a parameter to each of the `Linear` layers. This basically means replacing the `Linear` layer with this `PackedLinear` class, where `weight` is the `fp8` tensor and `weight_scale` is the scale:
class PackedLinear(torch.nn.Module):
def __init__(self, weight: torch.Tensor, weight_scale: torch.Tensor):
super().__init__()
self.weight = torch.nn.Parameter(weight, requires_grad=False)
self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
- Add a `quantization_config` to the model's config. This will also appear in the `config.json` file in the model's Hugging Face repo:
model.config.quantization_config = {
"quant_method": "fp8",
"is_checkpoint_fp8_serialized": True,
"activation_scheme": "dynamic",
"weight_block_size": ..., # `None` or `[int, int]`
"ignored_layers": ..., # list of module names that are not quantized
}
And that's all we need to do for vllm!
NOTE: some models don't support all layers being quantized. For example, vllm does not support the `decoder.mlp.gate` linear layer being quantized in Qwen3 MoE models.
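Putting the two steps together, here's a hedged sketch of what the end-to-end packing might look like. The `quantize_weight` and `pack_model` helpers are illustrative, not openquant's actual API; bias handling and per-block scales are left out for brevity, and `PackedLinear` is the class defined above:

import torch
from transformers import AutoModelForCausalLM

def quantize_weight(weight):
    # Per-tensor fp8 quantization, as described above
    scale = weight.float().abs().amax() / 448.0
    q = (weight.float() / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
    return q, scale

def pack_model(model, ignored_layers):
    for name, module in list(model.named_modules()):
        # Skip non-Linear modules and anything vllm expects to stay in bf16
        if not isinstance(module, torch.nn.Linear) or name in ignored_layers:
            continue
        q, scale = quantize_weight(module.weight.data)
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, child_name, PackedLinear(q, scale))   # bias handling omitted for brevity

    model.config.quantization_config = {
        "quant_method": "fp8",
        "is_checkpoint_fp8_serialized": True,
        "activation_scheme": "dynamic",
        "weight_block_size": None,          # per-tensor scales in this sketch
        "ignored_layers": ignored_layers,
    }
    return model

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-32B", torch_dtype=torch.bfloat16)
model = pack_model(model, ignored_layers=[])
model.save_pretrained("Qwen3-32B-FP8")      # config.json now carries the quantization_config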