Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions MaxText/maxtext_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,26 +203,41 @@ def calculate_ffn_mamtul_tflops_per_device(config, mlp_dim):
return ffn1_flops + ffn2_flops


def calculate_routed_and_shared_ffn_tflops_per_device(config):
  """Calculate FFN TFLOPs per device for routed+shared-expert (DeepSeek-style) MoE models.

  Covers decoders that mix dense and MoE layers (DeepSeek, Llama4): dense
  layers use the standard mlp_dim FFN, while each MoE layer adds router
  (gate) flops plus shared- and routed-expert FFN flops sized by moe_mlp_dim.

  Args:
    config: model config providing batch/sequence/model dimensions and the
      MoE layout fields consumed by get_dense_moe_layers.

  Returns:
    Total FFN flops per device across all decoder layers (raw flops, not
    yet divided by 10**12 — the caller does that conversion).
  """
  # Router: 2 * tokens * emb_dim * num_experts — one logit per expert per token.
  gate_flops = 2 * config.per_device_batch_size * config.max_target_length * config.emb_dim * config.num_experts
  # Due to the mixed decoder layers, the flops are multiplied by the number of
  # layers separately for the dense stack and the moe stack.
  num_dense_layers, num_moe_layers = get_dense_moe_layers(config)
  dense_ffn_flops = calculate_ffn_mamtul_tflops_per_device(config, config.mlp_dim) * num_dense_layers
  shared_experts_flops = calculate_ffn_mamtul_tflops_per_device(config, config.moe_mlp_dim) * config.shared_experts
  routed_experts_flops = calculate_ffn_mamtul_tflops_per_device(config, config.moe_mlp_dim) * config.num_experts_per_tok
  moe_ffn_flops = (gate_flops + shared_experts_flops + routed_experts_flops) * num_moe_layers
  total_ffn_flops = dense_ffn_flops + moe_ffn_flops
  return total_ffn_flops


def get_dense_moe_layers(config):
  """Return (num_dense_layers, num_moe_layers) for the configured decoder block.

  Raises:
    ValueError: if the decoder block is neither DeepSeek nor Llama4.
  """
  block_type = config.decoder_block
  total_layers = config.num_decoder_layers
  if block_type == DecoderBlockType.DEEPSEEK:
    # DeepSeek: a fixed-length dense prefix; every remaining layer is MoE.
    dense_layers = config.first_num_dense_layers
    moe_layers = total_layers - dense_layers
  elif block_type == DecoderBlockType.LLAMA4:
    # Llama4: one MoE layer every interleave_moe_layer_step layers.
    moe_layers = total_layers // config.interleave_moe_layer_step
    dense_layers = total_layers - moe_layers
  else:
    raise ValueError("Currently we only support DeepSeek and Llama4 calculation.")
  return dense_layers, moe_layers


def calculate_tflops_training_per_device(config, log=True):
"""Calculate training TFLOP"""
# MLP flops
if config.num_experts > 1:
# calculation based on dropless implementation
if config.decoder_block == DecoderBlockType.DEEPSEEK:
total_ffn_flops = calculate_deepseek_ffn_tflops_per_device(config)
if config.decoder_block == DecoderBlockType.DEEPSEEK or config.decoder_block == DecoderBlockType.LLAMA4:
total_ffn_flops = calculate_routed_and_shared_ffn_tflops_per_device(config)
else:
gate_flops = 2 * config.per_device_batch_size * config.max_target_length * config.emb_dim * config.num_experts
total_ffn_flops = (
Expand Down Expand Up @@ -263,7 +278,7 @@ def calculate_tflops_training_per_device(config, log=True):
attention_tflops, learnable_weight_tflops = calculate_gemma2_tflops_training_per_device(
config, total_ffn_flops, qkv_flops, projection_flops, embedding_flops
)
elif config.decoder_block == DecoderBlockType.DEEPSEEK:
elif config.decoder_block == DecoderBlockType.DEEPSEEK or config.decoder_block == DecoderBlockType.LLAMA4:
learnable_weight_tflops = (
(total_ffn_flops + (qkv_flops + projection_flops) * config.num_decoder_layers + embedding_flops) * 3 / 10**12
)
Expand Down
4 changes: 4 additions & 0 deletions MaxText/pyconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,10 @@ def validate_llama4_config(keys: dict):
raise ValueError("Llama4 decoder has not been tested with capacity_factor >= 0 -- please set that value to -1 for now!")
if keys["num_experts_per_tok"] > 1:
raise ValueError("Only top-1 routing is supported for Llama4 for now!")
if keys["base_num_decoder_layers"] % keys["interleave_moe_layer_step"] != 0:
raise ValueError(
f"The number of decoder layers ({keys['base_num_decoder_layers']}) must be divisible by interleave moe layer step ({keys['interleave_moe_layer_step']})"
)


def validate_model_name(s: str) -> bool:
Expand Down
Loading