Skip to content

Commit

Permalink
[Kernel] FP8 support for MoE kernel / Mixtral (vllm-project#4244)
Browse files Browse the repository at this point in the history
This PR is the first step towards fixing vllm-project#3208

It implements dynamic per-tensor scaling (see vllm-project#4118), so users do not need to compute activation scales on a calibration dataset and they also don't need to convert their model checkpoints. It is enough to specify the `quantization="fp8"` argument. You can try out the PR like this:

```python
from vllm import LLM, SamplingParams

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", tensor_parallel_size=2, quantization="fp8")

outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```

**Performance**: For this PR, the focus is on making the code clean (while still trying to get reasonable performance), there is a bunch of optimizations that we will submit as a follow up PR that significantly improve the performance (similar to the numbers in vllm-project#3954). With this PR, the results are as follows:

<img width="725" alt="Screenshot 2024-04-21 at 1 31 50 PM" src="https://github.com/vllm-project/vllm/assets/113316/d8fe1118-07a0-4d4e-8530-37a77d465a03">


**Accuracy**: The accuracy with this PR on MMLU on `mistralai/Mixtral-8x7B-v0.1` is as follows:

```
|      Groups      |Version|Filter|n-shot|Metric|Value |   |Stderr|
|------------------|-------|------|-----:|------|-----:|---|-----:|
|mmlu              |N/A    |none  |     0|acc   |0.7018|±  |0.0036|
| - humanities     |N/A    |none  |     5|acc   |0.6472|±  |0.0065|
| - other          |N/A    |none  |     5|acc   |0.7673|±  |0.0072|
| - social_sciences|N/A    |none  |     5|acc   |0.8099|±  |0.0070|
| - stem           |N/A    |none  |     5|acc   |0.6131|±  |0.0083|
```
this compares favorably with the fp16 results which are
```
|      Groups      |Version|Filter|n-shot|Metric|Value |   |Stderr|
|------------------|-------|------|-----:|------|-----:|---|-----:|
|mmlu              |N/A    |none  |     0|acc   |0.7020|±  |0.1313|
| - humanities     |N/A    |none  |     5|acc   |0.6425|±  |0.1349|
| - other          |N/A    |none  |     5|acc   |0.7744|±  |0.1038|
| - social_sciences|N/A    |none  |     5|acc   |0.8131|±  |0.0695|
| - stem           |N/A    |none  |     5|acc   |0.6108|±  |0.1383|
```

Happy hacking!
  • Loading branch information
pcmoritz authored Apr 24, 2024
1 parent 3dec7b3 commit f90df6e
Show file tree
Hide file tree
Showing 10 changed files with 385 additions and 21 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ set(VLLM_EXT_SRC
"csrc/layernorm_kernels.cu"
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/fp8/fp8_cuda_kernels.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/moe_align_block_size_kernels.cu"
"csrc/pybind.cpp")
Expand Down
5 changes: 5 additions & 0 deletions csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,11 @@ void gptq_shuffle(
torch::Tensor q_perm,
int bit);

void scaled_fp8_quant(
torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& scale);

void moe_align_block_size(
torch::Tensor topk_ids,
int num_experts,
Expand Down
1 change: 1 addition & 0 deletions csrc/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
ops.def("scaled_fp8_quant", &scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor");
ops.def(
"moe_align_block_size",
&moe_align_block_size,
Expand Down
103 changes: 103 additions & 0 deletions csrc/quantization/fp8/fp8_cuda_kernels.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>

#include <cmath>

#include "cuda_compat.h"
#include "dispatch_utils.h"

namespace vllm {

__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
float old;
old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) :
__uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));

return old;
}

// Compute the absolute maximum m of the input tensor and store
// m / float8_e4m3::max() in *scale. Each thread block performs a
// reduction tree and the memory in scale is atomically updated.
// So to get the right answer, *scale needs to be initialized to
// a value <= 0.0 and we need to wait for all thread blocks to
// finish before consuming *scale.
template<typename scalar_t>
__global__ void segmented_max_reduction(
float* __restrict__ scale,
const scalar_t* __restrict__ input,
int64_t num_elems) {
__shared__ float cache[1024];
int i = blockDim.x * blockIdx.x + threadIdx.x;

// First store maximum for all values processes by
// the current thread in cache[threadIdx.x]
scalar_t tmp = 0.0;
while (i < num_elems) {
float x = static_cast<float>(input[i]);
tmp = max(tmp, fabs(x));
i += blockDim.x * gridDim.x;
}
cache[threadIdx.x] = tmp;

__syncthreads();

// Now perform parallel reduction within the thread block
int ib = blockDim.x / 2;
while (ib != 0) {
if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
cache[threadIdx.x] = cache[threadIdx.x + ib];
}
__syncthreads();
ib /= 2;
}
// Finally, since cache[0] contains the maximum for this thread block,
// atomically write the max to the target location
if (threadIdx.x == 0) {
atomicMaxFloat(scale, cache[0] / std::numeric_limits<c10::Float8_e4m3fn>::max());
}
}

template<typename scalar_t>
__global__ void scaled_fp8_quant_kernel(
c10::Float8_e4m3fn* __restrict__ out,
const scalar_t* __restrict__ input,
const float* __restrict__ scale,
int64_t num_elems) {
int i = blockDim.x * blockIdx.x + threadIdx.x;
while (i < num_elems) {
out[i] = static_cast<c10::Float8_e4m3fn>(input[i] / *scale);
i += blockDim.x * gridDim.x;
}
}

} // namespace vllm

void scaled_fp8_quant(
torch::Tensor& out, // [..., d]
torch::Tensor& input, // [..., d]
torch::Tensor& scale) // [1]
{
int64_t num_tokens = input.numel() / input.size(-1);
int64_t num_elems = input.numel();
dim3 grid(num_tokens);
dim3 block(1024);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(),
"scaled_fp8_quant_kernel",
[&] {
vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
scale.data_ptr<float>(),
input.data_ptr<scalar_t>(),
num_elems);
vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<c10::Float8_e4m3fn>(),
input.data_ptr<scalar_t>(),
scale.data_ptr<float>(),
num_elems);
});
}

10 changes: 9 additions & 1 deletion vllm/_custom_ops.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, Optional
from typing import Dict, Optional, Tuple

import torch

Expand Down Expand Up @@ -153,6 +153,14 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
size_n, size_k)


# fp8
def scaled_fp8_quant(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
vllm_ops.scaled_fp8_quant(output, input, scale)
return output, scale


# moe
def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
block_size: int, sorted_token_ids: torch.Tensor,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
}
}
Loading

0 comments on commit f90df6e

Please sign in to comment.