21 changes: 13 additions & 8 deletions MaxText/maxtext_utils.py
@@ -14,7 +14,7 @@
limitations under the License.
"""

-# pylint: disable=bare-except, consider-using-generator
+# pylint: disable=line-too-long, disable=bare-except, consider-using-generator
""" Utils that are only interesting to MaxText. """

from typing import Optional
@@ -268,7 +268,7 @@ def calculate_tflops_training_per_device(config, log=True):

# Attention flops
if config.attention_type == "mla":
-qkv_flops, attention_flops, projection_flops = calculate_mla_tflops_per_device(config)
+qkv_flops, noncausal_attention_flops, projection_flops = calculate_mla_tflops_per_device(config)
else:
qkv_flops = (
2
@@ -278,7 +278,7 @@ def calculate_tflops_training_per_device(config, log=True):
* (config.num_query_heads + 2 * config.num_kv_heads)
* config.head_dim
)
-attention_flops = (
+noncausal_attention_flops = (
4 * config.per_device_batch_size * config.max_target_length**2 * config.num_query_heads * config.head_dim
)
projection_flops = (
@@ -290,6 +290,12 @@ def calculate_tflops_training_per_device(config, log=True):
* config.head_dim
)

+# Divide attention flops by 2 due to the causal mask
+# References:
+# NVIDIA/Megatron-LM (2025 March): https://github.com/NVIDIA/Megatron-LM/blob/250b79415dcc4b660521273c87f15334c804eeae/megatron/training/training.py#L361-L362
+# NVIDIA/NeMo (2025 April): https://github.com/NVIDIA/NeMo/blob/ba4d6d116463de512ff0cfc14641aa6cf4577a42/nemo/utils/flops_formulas.py#L259-L272
+causal_attention_flops = noncausal_attention_flops / 2
+
# Embedding flops
embedding_flops = 2 * config.per_device_batch_size * config.max_target_length * config.emb_dim * config.vocab_size

@@ -302,14 +308,13 @@ def calculate_tflops_training_per_device(config, log=True):
learnable_weight_tflops = (
(total_ffn_flops + (qkv_flops + projection_flops) * config.num_decoder_layers + embedding_flops) * 3 / 10**12
)
-attention_tflops = attention_flops * config.num_decoder_layers * 3 / 10**12
+attention_tflops = causal_attention_flops * config.num_decoder_layers * 3 / 10**12
else:
# multiply by 3 for both feed forward and back propagation flops
learnable_weight_tflops = (
((total_ffn_flops + qkv_flops + projection_flops) * config.num_decoder_layers + embedding_flops) * 3 / 10**12
)
-# megatron tflops calculation does not account for causality in attention
-attention_tflops = attention_flops * config.num_decoder_layers * 3 / 10**12
+attention_tflops = causal_attention_flops * config.num_decoder_layers * 3 / 10**12

learnable_weight_tflops = learnable_weight_tflops * config.gradient_accumulation_steps
attention_tflops = attention_tflops * config.gradient_accumulation_steps
@@ -338,7 +343,7 @@ def calculate_tflops_training_per_device(config, log=True):
def calculate_prefill_tflops_per_device(num_model_parameters, prefill_length, config, log=True):
"""Calculate training TFLOP"""
learnable_weight_tflops = 2 * num_model_parameters * prefill_length / jax.device_count() / 1e12
-noncasual_attention_flops = (
+noncausal_attention_flops = (
4
* config.num_query_heads
* config.num_decoder_layers
@@ -347,7 +352,7 @@ def calculate_prefill_tflops_per_device(num_model_parameters, prefill_length, config, log=True):
/ jax.device_count()
/ 1e12
)
-causal_attention_tflops = noncasual_attention_flops / 2 # due to causality in attention
+causal_attention_tflops = noncausal_attention_flops / 2 # due to causality in attention
total_tflops = learnable_weight_tflops + causal_attention_tflops

if log:
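For reference, below is a minimal standalone sketch of the causal-mask halving this diff applies in calculate_tflops_training_per_device. The helper name and the example values are illustrative only and are not part of MaxText; the formula mirrors the per-layer attention FLOPs expression shown in the diff above.

# Minimal sketch (illustrative, not part of the PR): per-device training
# attention TFLOPs with the causal-mask halving introduced by this change.

def attention_tflops_per_device(per_device_batch_size, max_target_length,
                                num_query_heads, head_dim, num_decoder_layers):
  """Per-device training attention TFLOPs, halved for the causal mask."""
  # Full (non-causal) attention: the QK^T scores and the attention-weighted V
  # product each cost 2 * B * S^2 * H * d FLOPs, hence the factor of 4.
  noncausal_attention_flops = (
      4 * per_device_batch_size * max_target_length**2 * num_query_heads * head_dim
  )
  # A causal mask only touches the lower triangle, roughly half of the S x S scores.
  causal_attention_flops = noncausal_attention_flops / 2
  # Multiply by 3 for the forward plus backward passes, scale by the layer
  # count, and convert FLOPs to TFLOPs.
  return causal_attention_flops * num_decoder_layers * 3 / 10**12


if __name__ == "__main__":
  # Illustrative values only, not taken from any MaxText config.
  print(attention_tflops_per_device(per_device_batch_size=4,
                                    max_target_length=2048,
                                    num_query_heads=16,
                                    head_dim=128,
                                    num_decoder_layers=32))

Plugging a real config's values into a standalone version like this is a quick way to sanity-check the attention TFLOPs that MaxText reports per step after this change.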