
[core] moe fp8 block quant tuning support #14068

Merged
merged 3 commits on Mar 4, 2025
98 changes: 67 additions & 31 deletions benchmarks/kernels/benchmark_moe.py
@@ -40,6 +40,7 @@ def benchmark_config(
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
num_iters: int = 100,
+ block_quant_shape: List[int] = None,
) -> float:
init_dtype = torch.float16 if use_fp8_w8a8 else dtype
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
@@ -81,8 +82,24 @@ def benchmark_config(
dtype=torch.float32)
w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
if use_fp8_w8a8:
- w1_scale = torch.randn(num_experts, dtype=torch.float32)
- w2_scale = torch.randn(num_experts, dtype=torch.float32)
+ if block_quant_shape:
+ block_n, block_k = block_quant_shape[0], block_quant_shape[1]
+ E = num_experts
+ N = shard_intermediate_size // 2
+ K = hidden_size
+ factor_for_scale = 1e-2
+ n_tiles_w1 = (2 * N + block_n - 1) // block_n
+ n_tiles_w2 = (K + block_n - 1) // block_n
+ k_tiles_w1 = (K + block_k - 1) // block_k
+ k_tiles_w2 = (N + block_k - 1) // block_k
+ w1_scale = torch.rand((E, n_tiles_w1, k_tiles_w1),
+ dtype=torch.float32) * factor_for_scale
+ w2_scale = torch.rand((E, n_tiles_w2, k_tiles_w2),
+ dtype=torch.float32) * factor_for_scale
+ else:
+ w1_scale = torch.randn(num_experts, dtype=torch.float32)
+ w2_scale = torch.randn(num_experts, dtype=torch.float32)

a1_scale = torch.randn(1, dtype=torch.float32)
a2_scale = torch.randn(1, dtype=torch.float32)

@@ -111,6 +128,7 @@ def run():
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
+ block_shape=block_quant_shape,
)

# JIT compilation & warmup
@@ -175,7 +193,8 @@ def get_rocm_tuning_space(use_fp16):
return param_ranges


- def get_configs_compute_bound(use_fp16) -> list[dict[str, int]]:
+ def get_configs_compute_bound(use_fp16,
+ block_quant_shape) -> list[dict[str, int]]:
configs: list[BenchmarkConfig] = []

if current_platform.is_rocm():
@@ -204,17 +223,27 @@ def get_configs_compute_bound(use_fp16) -> list[dict[str, int]]:
for config_values in product(*values):
config = dict(zip(keys, config_values))
configs.append(config)

+ # Remove configs that are not compatible with fp8 block quantization
+ # BLOCK_SIZE_K must be a multiple of block_k
+ # BLOCK_SIZE_N must be a multiple of block_n
+ if block_quant_shape is not None and not use_fp16:
+ block_n, block_k = block_quant_shape[0], block_quant_shape[1]
+ for config in configs[:]:
+ if config["BLOCK_SIZE_K"] % block_k != 0 or config[
+ "BLOCK_SIZE_N"] % block_n != 0:
+ configs.remove(config)
return configs
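The filter added at the end of get_configs_compute_bound drops candidate tile sizes that cannot be used with block-quantized weights: BLOCK_SIZE_N and BLOCK_SIZE_K have to be whole multiples of the quantization block so that each kernel tile lines up with complete scale blocks. An equivalent sketch as a comprehension (the function name is illustrative, not vLLM API):

# Sketch: the same filter expressed as a list comprehension.
def filter_block_quant_configs(configs: list[dict], block_quant_shape: list[int]):
    block_n, block_k = block_quant_shape
    return [
        cfg for cfg in configs
        if cfg["BLOCK_SIZE_K"] % block_k == 0 and cfg["BLOCK_SIZE_N"] % block_n == 0
    ]

# With block_quant_shape == [128, 128], a candidate with BLOCK_SIZE_N == 64 is
# discarded, while BLOCK_SIZE_N == 128 or 256 is kept.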


def prune_rocm_search_space(num_tokens, shard_intermediate_size, hidden_size,
- search_space, is_fp16):
+ search_space, is_fp16, topk):
N1, K1 = shard_intermediate_size, hidden_size
N2, K2 = hidden_size, shard_intermediate_size // 2
- pruned_space_1 = prune_rocm_configs(num_tokens * 2, N1, K1, search_space,
- is_fp16)
- pruned_space_2 = prune_rocm_configs(num_tokens * 2, N2, K2, search_space,
- is_fp16)
+ pruned_space_1 = prune_rocm_configs(num_tokens * topk, N1, K1,
+ search_space, is_fp16)
+ pruned_space_2 = prune_rocm_configs(num_tokens * topk, N2, K2,
+ search_space, is_fp16)
search_space = merge_unique_dicts(pruned_space_1, pruned_space_2)
return search_space
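The ROCm pruning path previously sized the GEMM problem with a hard-coded num_tokens * 2 (top-2 routing, as in Mixtral); it now uses num_tokens * topk, since top-k routing expands every token into one GEMM row per selected expert. A tiny sketch of the effective M dimension:

# Sketch: rows seen by the grouped MoE GEMMs after top-k routing.
def effective_m(num_tokens: int, topk: int) -> int:
    return num_tokens * topk

assert effective_m(64, 2) == 128  # Mixtral-style top-2 (the old hard-coded "* 2")
assert effective_m(64, 8) == 512  # DeepSeek-V3-style top-8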

@@ -372,6 +401,7 @@ def tune(
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
search_space: list[dict[str, int]],
+ block_quant_shape: list[int],
) -> dict[str, int]:
best_config = None
best_time = float("inf")
@@ -380,21 +410,23 @@
search_space = prune_rocm_search_space(num_tokens,
shard_intermediate_size,
hidden_size, search_space,
- is_fp16)
+ is_fp16, topk)

with torch.cuda.device(self.device_id):
for config in tqdm(search_space):
try:
- kernel_time = benchmark_config(config,
- num_tokens,
- num_experts,
- shard_intermediate_size,
- hidden_size,
- topk,
- dtype,
- use_fp8_w8a8,
- use_int8_w8a16,
- num_iters=20)
+ kernel_time = benchmark_config(
+ config,
+ num_tokens,
+ num_experts,
+ shard_intermediate_size,
+ hidden_size,
+ topk,
+ dtype,
+ use_fp8_w8a8,
+ use_int8_w8a16,
+ num_iters=20,
+ block_quant_shape=block_quant_shape)
except triton.runtime.autotuner.OutOfResources:
# Some configurations may be invalid and fail to compile.
continue
@@ -436,16 +468,16 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:

def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int,
shard_intermediate_size: int, hidden_size: int, topk: int,
- dtype: torch.dtype, use_fp8_w8a8: bool,
- use_int8_w8a16: bool) -> None:
+ dtype: torch.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool,
+ block_quant_shape: List[int]) -> None:
dtype_str = get_config_dtype_str(dtype,
use_int8_w8a16=use_int8_w8a16,
use_fp8_w8a8=use_fp8_w8a8)

# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
filename = get_config_file_name(num_experts, shard_intermediate_size // 2,
- dtype_str)
+ dtype_str, block_quant_shape)

print(f"Writing best config to {filename}...")
with open(filename, "w") as f:
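save_configs now forwards block_quant_shape to get_config_file_name, so results tuned for block-quantized fp8 weights can be written to their own file rather than overwriting the per-tensor fp8 configs. The sketch below only illustrates the naming idea; the authoritative format is whatever vLLM's get_config_file_name produces:

# Naming sketch (assumption, for illustration only).
def example_config_file_name(E: int, N: int, device_name: str, dtype_str: str,
                             block_shape: list[int] | None) -> str:
    block = f",block_shape={block_shape}" if block_shape else ""
    return f"E={E},N={N},device_name={device_name},dtype={dtype_str}{block}.json"

# example_config_file_name(256, 256, "NVIDIA_H200", "fp8_w8a8", [128, 128])
# -> 'E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json'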
@@ -455,7 +487,7 @@ def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int,

def main(args: argparse.Namespace):
print(args)

+ block_quant_shape = None
config = AutoConfig.from_pretrained(
args.model, trust_remote_code=args.trust_remote_code)
if config.architectures[0] == "DbrxForCausalLM":
@@ -474,6 +506,7 @@ def main(args: argparse.Namespace):
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
+ block_quant_shape = config.quantization_config['weight_block_size']
else:
# Default: Mixtral.
E = config.num_local_experts
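For block-quantized DeepSeek-style checkpoints, the block shape is taken straight from the Hugging Face config: the quantization_config of fp8 block-quantized models (e.g. DeepSeek-V3/R1) carries a weight_block_size entry, typically [128, 128]. A sketch of that lookup with a fallback for models that are not block-quantized (the helper is illustrative, not part of the benchmark):

# Sketch, assuming a DeepSeek-V3-style HF config layout.
from transformers import AutoConfig

def get_weight_block_size(model_name: str, trust_remote_code: bool = True):
    config = AutoConfig.from_pretrained(model_name,
                                        trust_remote_code=trust_remote_code)
    quant_cfg = getattr(config, "quantization_config", None) or {}
    # e.g. {"quant_method": "fp8", "weight_block_size": [128, 128], ...}
    return quant_cfg.get("weight_block_size")  # [block_n, block_k] or None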
@@ -511,27 +544,30 @@ def _distribute(method: str, inputs: list[Any]) -> list[Any]:

if args.tune:
is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
- search_space = get_configs_compute_bound(is_fp16)
+ search_space = get_configs_compute_bound(is_fp16, block_quant_shape)
print(f"Start tuning over {len(search_space)} configurations...")

start = time.time()
configs = _distribute(
- "tune", [(batch_size, E, shard_intermediate_size, hidden_size,
- topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space)
- for batch_size in batch_sizes])
+ "tune",
+ [(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype,
+ use_fp8_w8a8, use_int8_w8a16, search_space, block_quant_shape)
+ for batch_size in batch_sizes])
best_configs = {
M: sort_config(config)
for M, config in zip(batch_sizes, configs)
}
save_configs(best_configs, E, shard_intermediate_size, hidden_size,
- topk, dtype, use_fp8_w8a8, use_int8_w8a16)
+ topk, dtype, use_fp8_w8a8, use_int8_w8a16,
+ block_quant_shape)
end = time.time()
print(f"Tuning took {end - start:.2f} seconds")
else:
outputs = _distribute(
- "benchmark", [(batch_size, E, shard_intermediate_size, hidden_size,
- topk, dtype, use_fp8_w8a8, use_int8_w8a16)
- for batch_size in batch_sizes])
+ "benchmark",
+ [(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype,
+ use_fp8_w8a8, use_int8_w8a16, block_quant_shape)
+ for batch_size in batch_sizes])

for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
print(f"Batch size: {batch_size}, config: {config}")
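Putting it together, tuning a block-quantized fp8 MoE model uses the same entry point as before, and the block shape is picked up from the model config automatically. An example invocation, assuming the CLI flags referenced in this file (args.model, args.tp_size, args.tune, args.trust_remote_code) plus a --dtype fp8_w8a8 selector (an assumption, inferred from the use_fp8_w8a8 flag):

python benchmarks/kernels/benchmark_moe.py \
    --model deepseek-ai/DeepSeek-V3 \
    --tp-size 8 \
    --dtype fp8_w8a8 \
    --trust-remote-code \
    --tune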
(Second changed file: a tuned fused-MoE Triton kernel config JSON, keyed by batch size; the filename is not captured in this view.)
@@ -1,28 +1,28 @@
{
    "1": {
        "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
-        "num_warps": 4,
+        "num_warps": 8,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "2": {
-        "BLOCK_SIZE_M": 32,
-        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
-        "num_warps": 2,
+        "num_warps": 8,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "4": {
        "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
-        "num_warps": 4,
+        "num_warps": 8,
        "num_stages": 2,
        "waves_per_eu": 0
    },
@@ -31,15 +31,15 @@
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
-        "num_warps": 4,
+        "num_warps": 8,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "16": {
        "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M": 4,
+        "GROUP_SIZE_M": 1,
        "num_warps": 2,
        "num_stages": 2,
        "waves_per_eu": 0
@@ -49,13 +49,13 @@
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
-        "num_warps": 4,
+        "num_warps": 2,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "32": {
        "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 4,
        "num_warps": 2,
@@ -64,7 +64,7 @@
    },
    "48": {
        "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 4,
        "num_warps": 2,
@@ -73,7 +73,7 @@
    },
    "64": {
        "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 2,
@@ -82,46 +82,82 @@
    },
    "96": {
        "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M": 4,
-        "num_warps": 4,
+        "GROUP_SIZE_M": 8,
+        "num_warps": 8,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "128": {
        "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 64,
-        "BLOCK_SIZE_K": 256,
-        "GROUP_SIZE_M": 1,
-        "num_warps": 2,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 4,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "256": {
        "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M": 4,
+        "GROUP_SIZE_M": 8,
        "num_warps": 4,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "512": {
        "BLOCK_SIZE_M": 32,
-        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 8,
-        "num_warps": 8,
+        "num_warps": 4,
        "num_stages": 2,
        "waves_per_eu": 0
    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 8,
+        "num_warps": 2,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 2,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
    "2048": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 8,
-        "num_warps": 8,
+        "num_warps": 4,
        "num_stages": 2,
        "waves_per_eu": 0
    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 8,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    }
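The JSON above is a tuned Triton launch configuration for the fused MoE kernel, keyed by batch size (M). At run time the kernel picks the entry whose key is closest to the current token count; a simplified sketch of that lookup (not vLLM's actual code):

# Simplified sketch of how a tuned config file is consumed.
import json

def pick_tuned_config(path: str, m: int) -> dict:
    with open(path) as f:
        configs = {int(k): v for k, v in json.load(f).items()}
    nearest = min(configs, key=lambda key: abs(key - m))
    return configs[nearest]

# pick_tuned_config("E=....json", m=100) would return the "96" entry above.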