[RFC/WIP] First steps towards FP8 for Mixtral #3208

Status: Open. Wants to merge 26 commits into base: main.
11 changes: 4 additions & 7 deletions benchmarks/kernels/benchmark_mixtral_moe.py
@@ -12,10 +12,7 @@

 def main():
     method = fused_moe
-    for bs in [
-            1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
-            2048, 3072, 4096
-    ]:
+    for bs in [1, 2, 4, 8, 16, 24, 32, 80, 96, 128, 192, 200, 208, 216, 224, 256, 512, 1024, 1536, 2048, 3072, 4096]:
         run_grid(bs, method=method)


@@ -124,20 +121,20 @@ def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
     hidden_states = torch.rand(
         (bs, d_model),
         device="cuda:0",
-        dtype=torch.bfloat16,
+        dtype=torch.float16,
     )

     ws = torch.rand(
         (num_total_experts, 2 * shard_intermediate_size, d_model),
         device=hidden_states.device,
         dtype=hidden_states.dtype,
-    )
+    ).to(torch.float8_e4m3fn)

     w2s = torch.rand(
         (num_total_experts, d_model, shard_intermediate_size),
         device=hidden_states.device,
         dtype=hidden_states.dtype,
-    )
+    ).to(torch.float8_e4m3fn)

     gating_output = F.softmax(torch.rand(
         (num_calls, bs, num_total_experts),
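For context on the `.to(torch.float8_e4m3fn)` casts above: the E4M3 format keeps 4 exponent and 3 mantissa bits, so the cast is lossy. A minimal round-trip sketch (an illustration for this review, not part of the diff; it needs a PyTorch build with the FP8 dtypes, i.e. torch >= 2.1):

import torch

w = torch.rand(4, 8, dtype=torch.float16)
w_fp8 = w.to(torch.float8_e4m3fn)   # lossy 8-bit cast, as in the benchmark above
w_back = w_fp8.to(torch.float16)    # upcast again to inspect the rounding error
print((w - w_back).abs().max())     # small but nonzero, from the 3-bit mantissa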
40 changes: 40 additions & 0 deletions csrc/activation_kernels.cu
@@ -39,6 +39,23 @@ __device__ __forceinline__ T gelu_kernel(const T& x) {
   return (T) (f * 0.5f * (1.0f + ::erf(f * ALPHA)));
 }

+// Scaled activation and gating kernel template.
+template<typename scalar_t>
+__global__ void scaled_silu_and_mul_kernel(
+  scalar_t* __restrict__ out,
+  const scalar_t* __restrict__ input,
+  const scalar_t* __restrict__ scales,
+  const int d) {
+  const int64_t token_idx = blockIdx.x;
+  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
+    const float x = (float) input[token_idx * 2 * d + idx];
+    const float y = (float) input[token_idx * 2 * d + d + idx];
+    const float s = (float) scales[idx];
+    float r = silu_kernel(x) * y / s;
+    out[token_idx * d + idx] = (scalar_t) r;
+  }
+}
+
 } // namespace vllm

 // Launch activation and gating kernel.
@@ -66,6 +83,29 @@ void silu_and_mul(
   LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
 }

+void scaled_silu_and_mul(
+  torch::Tensor& out,      // [..., d]
+  torch::Tensor& input,    // [..., 2 * d]
+  torch::Tensor& scales)   // [2 * d]
+{
+  int d = input.size(-1) / 2;
+  int64_t num_tokens = input.numel() / input.size(-1);
+  dim3 grid(num_tokens);
+  dim3 block(std::min(d, 1024));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  VLLM_DISPATCH_FLOATING_TYPES_FP8(
+    input.scalar_type(),
+    "scaled_silu_and_mul_kernel",
+    [&] {
+      vllm::scaled_silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
+        out.data_ptr<scalar_t>(),
+        input.data_ptr<scalar_t>(),
+        scales.data_ptr<scalar_t>(),
+        d);
+    });
+}
+
 void gelu_and_mul(
   torch::Tensor& out,     // [..., d]
   torch::Tensor& input)   // [..., 2 * d]
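For readers who want the math without the CUDA, here is a hedged PyTorch reference of what `scaled_silu_and_mul_kernel` computes (an illustration written for this review, not code from the PR; it mirrors the kernel's indexing, which reads only the first `d` entries of `scales`):

import torch
import torch.nn.functional as F

def scaled_silu_and_mul_ref(input: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    # input: [..., 2 * d], holding the gate half x and the up half y back to back.
    d = input.shape[-1] // 2
    x, y = input[..., :d], input[..., d:]
    # silu(x) * y, then a per-channel division by the smoothing scales.
    return F.silu(x) * y / scales[:d]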
10 changes: 10 additions & 0 deletions csrc/dispatch_utils.h
@@ -15,6 +15,16 @@
   AT_DISPATCH_SWITCH(                                          \
     TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

+#define VLLM_DISPATCH_CASE_FLOATING_TYPES_FP8(...)             \
+  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)         \
+  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)          \
+  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)      \
+  AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__)
+
+#define VLLM_DISPATCH_FLOATING_TYPES_FP8(TYPE, NAME, ...)      \
+  AT_DISPATCH_SWITCH(                                          \
+    TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES_FP8(__VA_ARGS__))
+
 #define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...)        \
   AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)         \
   AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)          \
5 changes: 5 additions & 0 deletions csrc/ops.h
@@ -57,6 +57,11 @@ void silu_and_mul(
   torch::Tensor& out,
   torch::Tensor& input);

+void scaled_silu_and_mul(
+  torch::Tensor& out,
+  torch::Tensor& input,
+  torch::Tensor& scales);
+
 void gelu_and_mul(
   torch::Tensor& out,
   torch::Tensor& input);
4 changes: 4 additions & 0 deletions csrc/pybind.cpp
@@ -22,6 +22,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     "silu_and_mul",
     &silu_and_mul,
     "Activation function used in SwiGLU.");
+  ops.def(
+    "scaled_silu_and_mul",
+    &scaled_silu_and_mul,
+    "Compute SwiGLU activation function and scale per channel.");
   ops.def(
     "gelu_and_mul",
     &gelu_and_mul,
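Once the extension is rebuilt, the new binding should be callable from Python roughly as below. This is a sketch under the assumption that the compiled op module is importable as `vllm._C.ops`, as with the existing `silu_and_mul`; `out` is preallocated to match the launcher's signature:

import torch
from vllm._C import ops  # assumes the compiled extension module name

num_tokens, d = 16, 1024
x = torch.rand(num_tokens, 2 * d, device="cuda", dtype=torch.float16)
scales = torch.rand(2 * d, device="cuda", dtype=torch.float16)  # [2 * d] per the signature comment
out = torch.empty(num_tokens, d, device="cuda", dtype=torch.float16)
ops.scaled_silu_and_mul(out, x, scales)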
34 changes: 34 additions & 0 deletions examples/quantization/mixtral/mixtral_fp8_quantization.py
@@ -0,0 +1,34 @@
+from safetensors import safe_open
+from safetensors.torch import save_file
+import torch
+
+SMOOTH_STRENGTH = 0.5
+
+activation_scales = torch.load("/home/ray/default/mixtral_scales.pth")
+smoothquant_scales = {}
+for layer_idx in range(32):
+    key_prefix = f"model.layers.{layer_idx}.block_sparse_moe.experts"
+    target_prefix = f"model.layers.{layer_idx}.block_sparse_moe.scales"
+    for weight_name in ["w1", "w2", "w3"]:
+        tensors = [activation_scales[key_prefix + f".{expert_idx}.{weight_name}"] for expert_idx in range(8)]
+        smoothquant_scales[target_prefix + f".{weight_name}"] = torch.mean(torch.stack(tensors), dim=0)**SMOOTH_STRENGTH
+
+def rewrite_safetensors(name):
+    tensors = {}
+    with safe_open(name, framework="pt") as f:
+        for k in f.keys():
+            tensors[k] = f.get_tensor(k)
+            if "w1" in k or "w2" in k or "w3" in k:
+                name_parts = k.split(".")
+                scale_name = "model.layers." + name_parts[2] + ".block_sparse_moe.scales." + name_parts[-2]
+                print(f"scaling {k} with {scale_name}")
+                tensors[scale_name] = smoothquant_scales[scale_name]
+                tensors[k] *= smoothquant_scales[scale_name]
+                # Convert tensor to fp8
+                tensors[k] = tensors[k].to(torch.float8_e4m3fn)
+    save_file(tensors, name)
+
+for i in range(1, 20):
+    filename = f"model-{i:05}-of-00019.safetensors"
+    print(f"rewriting {filename}")
+    rewrite_safetensors(filename)
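What the script relies on: the smoothing scales folded into the weights here are divided back out of the activations at runtime by the new `scaled_silu_and_mul` kernel, so in exact arithmetic the matmul result is unchanged while the weight tensor becomes friendlier to FP8. A small numeric check of that identity (illustrative stand-in tensors, not the Mixtral shapes):

import torch

act = torch.rand(4, 8)       # activations entering a projection
w = torch.rand(8, 16)        # weight, scaled per input channel before the FP8 cast
s = torch.rand(8) + 0.5      # smoothing scales, kept away from zero

out_ref = act @ w
out_smooth = (act / s) @ (w * s[:, None])  # scale moved from activations onto weights
print(torch.allclose(out_ref, out_smooth, atol=1e-5))  # True up to float rounding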
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ requires = [
     "ninja",
     "packaging",
     "setuptools >= 49.4.0",
-    "torch == 2.1.2",
+    "torch == 2.2.0",
     "wheel",
 ]
 build-backend = "setuptools.build_meta"
4 changes: 2 additions & 2 deletions requirements-build.txt
@@ -2,5 +2,5 @@
 ninja
 packaging
 setuptools>=49.4.0
-torch==2.1.2
-wheel
+torch==2.2.0
+wheel
6 changes: 3 additions & 3 deletions requirements.txt
@@ -3,14 +3,14 @@ psutil
 ray >= 2.9
 sentencepiece # Required for LLaMA tokenizer.
 numpy
-torch == 2.1.2
+torch == 2.2.0
 transformers >= 4.38.0 # Required for Gemma.
-xformers == 0.0.23.post1 # Required for CUDA 12.1.
+xformers == 0.0.24 # Required for CUDA 12.1.
 fastapi
 uvicorn[standard]
 pydantic >= 2.0 # Required for OpenAI server.
 prometheus_client >= 0.18.0
 pynvml == 11.5.0
-triton >= 2.1.0
+triton == 2.2.0
 outlines >= 0.0.27
 cupy-cuda12x == 12.1.0 # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead.
This file was deleted.

@@ -0,0 +1,24 @@
+{
+    "1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 32, "num_warps": 2, "num_stages": 2},
+    "2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 8},
+    "4": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 8},
+    "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 512, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2},
+    "16": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
+    "24": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 2},
+    "32": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 8},
+    "80": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 8},
+    "96": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 8},
+    "128": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 8},
+    "192": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
+    "200": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 8},
+    "208": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
+    "216": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 8},
+    "224": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 2},
+    "256": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 2},
+    "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
+    "1024": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 4},
+    "1536": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
+    "2048": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
+    "3072": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 4},
+    "4096": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4}
+}
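The JSON above maps a batch size to Triton launch parameters for the fused MoE kernel. How the table is consulted is not part of this diff; below is a plausible lookup, sketched under the assumption that the runtime picks the key nearest the actual batch size:

import json

def pick_config(path: str, m: int) -> dict:
    # Load the tuned table and take the entry whose batch-size key is closest to m.
    with open(path) as f:
        configs = {int(k): v for k, v in json.load(f).items()}
    return configs[min(configs, key=lambda k: abs(k - m))]

# pick_config("fused_moe_fp8_config.json", 48)  # hypothetical filename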
@@ -1,24 +1,24 @@
 {
-    "1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4},
-    "2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
-    "4": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
-    "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 8, "num_stages": 4},
-    "16": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4},
-    "24": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4},
-    "32": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
-    "80": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
-    "96": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
-    "128": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
-    "192": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
-    "200": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 4},
-    "208": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 4},
-    "216": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
-    "224": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
-    "256": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
-    "512": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
-    "1024": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
-    "1536": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
-    "2048": {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
-    "3072": {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
-    "4096": {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4}
+    "1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 4},
+    "2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 2},
+    "4": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 2},
+    "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 2},
+    "16": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 2},
+    "24": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 2},
+    "32": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 2},
+    "80": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
+    "96": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
+    "128": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
+    "192": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
+    "200": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
+    "208": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 4},
+    "216": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
+    "224": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
+    "256": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 4},
+    "512": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
+    "1024": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 4},
+    "1536": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 4},
+    "2048": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
+    "3072": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 4},
+    "4096": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 4}
 }