Skip to content

Optimized fused MoE Kernel, take 2 #2979

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 30 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 172 additions & 0 deletions benchmarks/kernels/benchmark_mixtral_moe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
import json
import os
import sys

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

from vllm.model_executor.layers.fused_moe import fused_moe
import torch
import torch.nn.functional as F
import triton


def main():
method = fused_moe
for bs in [
1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
2048, 3072, 4096
]:
run_grid(bs, method=method)


def run_grid(bs, method):
d_model = 4096
num_total_experts = 8
top_k = 2
tp_size = 2
model_intermediate_size = 14336
num_layers = 32
num_calls = 100

num_warmup_trials = 1
num_trials = 1

configs = []
if bs <= 16:
BLOCK_SIZES_M = [16]
elif bs <= 32:
BLOCK_SIZES_M = [16, 32]
elif bs <= 64:
BLOCK_SIZES_M = [16, 32, 64]
elif bs <= 128:
BLOCK_SIZES_M = [16, 32, 64, 128]
else:
BLOCK_SIZES_M = [16, 32, 64, 128, 256]

for block_size_n in [32, 64, 128, 256]:
for block_size_m in BLOCK_SIZES_M:
for block_size_k in [64, 128, 256]:
for group_size_m in [1, 16, 32, 64]:
for num_warps in [4, 8]:
configs.append({
"BLOCK_SIZE_M": block_size_m,
"BLOCK_SIZE_N": block_size_n,
"BLOCK_SIZE_K": block_size_k,
"GROUP_SIZE_M": group_size_m,
"num_warps": num_warps,
"num_stages": 4,
})

best_config = None
best_time_us = 1e20

for config in configs:
print(f'{tp_size=} {bs=}')
print(f'{config}')
# warmup
print(f'warming up')
try:
for _ in range(num_warmup_trials):
run_timing(
num_calls=num_calls,
bs=bs,
d_model=d_model,
num_total_experts=num_total_experts,
top_k=top_k,
tp_size=tp_size,
model_intermediate_size=model_intermediate_size,
method=method,
config=config,
)
except triton.runtime.autotuner.OutOfResources:
continue

# trial
print(f'benchmarking')
for _ in range(num_trials):
kernel_dur_ms = run_timing(
num_calls=num_calls,
bs=bs,
d_model=d_model,
num_total_experts=num_total_experts,
top_k=top_k,
tp_size=tp_size,
model_intermediate_size=model_intermediate_size,
method=method,
config=config,
)

kernel_dur_us = 1000 * kernel_dur_ms
model_dur_ms = kernel_dur_ms * num_layers

if kernel_dur_us < best_time_us:
best_config = config
best_time_us = kernel_dur_us

print(
f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f} {bs=} {tp_size=} {top_k=} {num_total_experts=} {d_model=} {model_intermediate_size=} {num_layers=}'
)

print("best_time_us", best_time_us)
print("best_config", best_config)

filename = "/tmp/config.jsonl"
print(f"writing config to file {filename}")
with open(filename, "a") as f:
f.write(json.dumps({str(bs): best_config}) + "\n")


def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
top_k: int, tp_size: int, model_intermediate_size: int, method,
config) -> float:
shard_intermediate_size = model_intermediate_size // tp_size

hidden_states = torch.rand(
(bs, d_model),
device="cuda:0",
dtype=torch.bfloat16,
)

ws = torch.rand(
(num_total_experts, 2 * shard_intermediate_size, d_model),
device=hidden_states.device,
dtype=hidden_states.dtype,
)

w2s = torch.rand(
(num_total_experts, d_model, shard_intermediate_size),
device=hidden_states.device,
dtype=hidden_states.dtype,
)

gating_output = F.softmax(torch.rand(
(num_calls, bs, num_total_experts),
device=hidden_states.device,
dtype=torch.float32,
),
dim=-1)

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

start_event.record()
for i in range(num_calls):
hidden_states = method(
hidden_states=hidden_states,
w1=ws,
w2=w2s,
gating_output=gating_output[i],
topk=2,
renormalize=True,
inplace=True,
override_config=config,
)
end_event.record()
end_event.synchronize()

dur_ms = start_event.elapsed_time(end_event) / num_calls
return dur_ms


if __name__ == "__main__":
sys.exit(main())
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,9 @@ def get_requirements() -> List[str]:
return requirements


package_data = {"vllm": ["py.typed"]}
package_data = {
"vllm": ["py.typed", "model_executor/layers/fused_moe/configs/*.json"]
}
if os.environ.get("VLLM_USE_PRECOMPILED"):
ext_modules = []
package_data["vllm"].append("*.so")
Expand Down
5 changes: 5 additions & 0 deletions vllm/model_executor/layers/fused_moe/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe

__all__ = [
"fused_moe",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
{
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you provide a script to tune these parameters? No worries otherwise. I can implement it.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, happy to add the script. The process of tuning is not fully automatic at the moment and requires some manual modifications, but I will contribute what I have 😊

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added the script benchmark_mixtral_moe.py now to do the search -- in practice I wasn't using exactly this script but was doing some modifications of the script as I searched through the different batch sizes. But the script should be a good way to get started :)

Maybe by doing a more exhaustive search, we could even improve on these parameters, but I think the gap would be pretty small if we find something even better :)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
"2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 7},
"4": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 6},
"8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 7},
"16": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 7},
"24": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
"32": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
"64": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
"96": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
"128": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 6},
"192": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 6},
"256": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4},
"512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 4},
"1024": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
"1536": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4},
"2048": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
"3072": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 4},
"4096": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
{
"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4},
"2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
"4": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
"8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 8, "num_stages": 4},
"16": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4},
"24": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4},
"32": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
"80": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
"96": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
"128": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
"192": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
"200": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 4},
"208": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 4},
"216": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
"224": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
"256": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
"512": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
"1024": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
"1536": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
"2048": {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
"3072": {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
"4096": {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4}
}
10 changes: 10 additions & 0 deletions vllm/model_executor/layers/fused_moe/configs/README
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
This directory contains tuned configurations for different settings of the fused_moe kernel.
For different settings of
- E (number of experts)
- N (intermediate size)
- device_name (torch.cuda.get_device_name())
the JSON file contains a mapping from M (batch size) to the chosen configuration.

The example configurations provided are for the Mixtral model for TP2 on H100
and TP4 on A100. Mixtral has intermediate size N = 14336, i.e. for TP2 we have
N = 7168 and for TP4 we have N = 3584.
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
"""Fused MoE kernel."""
import functools
import json
import os
from typing import Any, Dict, Optional

import torch
import triton
import triton.language as tl

from vllm._C import ops
from vllm.logger import init_logger
from vllm.utils import is_hip

logger = init_logger(__name__)


@triton.jit
def fused_moe_kernel(
Expand Down Expand Up @@ -210,6 +218,34 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
)


@functools.lru_cache
def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]:
"""
Return optimized configurations for the fused MoE kernel.

The return value will be a dictionary that maps an irregular grid of batch sizes
to configurations of the fused_moe kernel. To evaluate the kernel on a given batch
size bs, the closest batch size in the grid should be picked and the associated
configuration chosen to invoke the kernel.
"""

# First look up if an optimized configuration is available in the configs directory
device_name = torch.cuda.get_device_name().replace(" ", "_")

config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs",
f"E={E},N={N},device_name={device_name}.json")
if os.path.exists(config_file_path):
with open(config_file_path) as f:
logger.info(
f"Using configuration from {config_file_path} for MoE layer.")
# If a configuration has been found, return it
return {int(key): val for key, val in json.load(f).items()}

# If no optimized configuration is available, we will use the default configuration
return None


def fused_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
Expand All @@ -218,6 +254,7 @@ def fused_moe(
topk: int,
renormalize: bool,
inplace: bool = False,
override_config: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism.
Expand All @@ -230,6 +267,7 @@ def fused_moe(
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- inplace (bool): If True, perform the operation in-place. Defaults to False.
- override_config (Optional[Dict[str, Any]]): Optional override for the kernel configuration.

Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
Expand Down Expand Up @@ -279,20 +317,31 @@ def fused_moe(
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

config = {
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
}

if topk_ids.numel() <= w1.shape[0]:
config = {
'BLOCK_SIZE_M': 16,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 64,
'GROUP_SIZE_M': 1
}
if override_config:
config = override_config
else:
# First try to load optimal config from the file
configs = get_moe_configs(E, w2.shape[2])

if configs:
# If an optimal configuration map has been found, look up the optimal config
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
else:
# Else use the default config
config = {
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
}

if M <= E:
config = {
'BLOCK_SIZE_M': 16,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 64,
'GROUP_SIZE_M': 1
}

intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
device=hidden_states.device,
Expand Down