6 changes: 3 additions & 3 deletions MaxText/layers/decoders.py
@@ -88,9 +88,9 @@ def __call__(
cfg = self.config
mesh = self.mesh
if model_mode == MODEL_MODE_PREFILL:
-      logical_axis_names = ("activation_batch", "prefill_activation_length", "activation_embed")
+      logical_axis_names = ("activation_batch", "prefill_activation_norm_length", "activation_embed")
else:
-      logical_axis_names = ("activation_batch", "activation_length", "activation_embed")
+      logical_axis_names = ("activation_batch", "activation_norm_length", "activation_embed")

if model_mode == MODEL_MODE_PREFILL:
inputs = nn.with_logical_constraint(inputs, logical_axis_names)
@@ -610,7 +610,7 @@ def _apply_output_head(self, shared_embedding: nn.Module | nnx.Module, y, determ
logits = nn.with_logical_constraint(logits, (None, None, "activation_vocab"))
else:
logits = nn.with_logical_constraint(
-          logits, ("activation_embed_and_logits_batch", "activation_length", "activation_vocab")
+          logits, ("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_vocab")
)

if self.config.cast_logits_to_fp32:
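These edits (here and in the files below) only rename the logical axis names handed to `nn.with_logical_constraint`; the physical layout still comes from the logical-to-mesh axis rules. A minimal sketch, assuming hypothetical axis rules rather than MaxText's real sharding config, shows why the name alone matters: each logical name is resolved independently, so `activation_length_no_exp` (or `activation_norm_length`) can map to a different set of mesh axes than plain `activation_length`.

```python
# Minimal sketch, NOT MaxText's actual axis rules: it only illustrates that Flax
# resolves each logical axis name independently, so renaming the logical axis in
# with_logical_constraint is enough to change the resulting sharding spec.
import flax.linen as nn

# Hypothetical rules for illustration: the generic length axis shards over both
# "sequence" and "expert" mesh axes, while the "_no_exp" variant skips "expert".
rules = (
    ("activation_batch", "data"),
    ("activation_length", ("sequence", "expert")),
    ("activation_length_no_exp", "sequence"),
    ("activation_embed", "tensor"),
)

with nn.logical_axis_rules(rules):
  old_spec = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed"))
  new_spec = nn.logical_to_mesh_axes(("activation_batch", "activation_length_no_exp", "activation_embed"))

print(old_spec)  # -> PartitionSpec('data', ('sequence', 'expert'), 'tensor')
print(new_spec)  # -> PartitionSpec('data', 'sequence', 'tensor')
```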
2 changes: 1 addition & 1 deletion MaxText/layers/embeddings.py
@@ -152,7 +152,7 @@ def __call__(self, inputs: Array, model_mode: str = MODEL_MODE_TRAIN) -> Array:
output = jnp.asarray(embedding, self.dtype)[inputs]

output_prefill_axis_names = ("activation_embed_and_logits_batch", "prefill_activation_length", "activation_embed")
-    output_default_axis_names = ("activation_embed_and_logits_batch", "activation_length", "activation_embed")
+    output_default_axis_names = ("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_embed")

if model_mode == MODEL_MODE_PREFILL:
output = nn.with_logical_constraint(output, output_prefill_axis_names)
12 changes: 6 additions & 6 deletions MaxText/layers/gpt3.py
@@ -29,7 +29,7 @@

from MaxText import max_logging
from MaxText import max_utils
- from MaxText.common_types import Config, DType, AxisNames, BATCH, LENGTH, EMBED, HEAD, D_KV, Array, MODEL_MODE_TRAIN
+ from MaxText.common_types import Config, DType, AxisNames, BATCH, LENGTH_NO_EXP, EMBED, HEAD, D_KV, Array, MODEL_MODE_TRAIN
from MaxText.layers import initializers, nnx_wrappers
from MaxText.layers.linears import mlp_block
from MaxText.layers import models
@@ -197,11 +197,11 @@ class Gpt3MultiHeadAttention(nn.Module):
kv_quant: None | KVQuant = None
use_bias: bool = True

-  input_axis_names: AxisNames = (BATCH, LENGTH, EMBED)
-  query_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV)
-  key_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV)
-  value_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV)
-  out_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV)
+  input_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, EMBED)
+  query_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, HEAD, D_KV)
+  key_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, HEAD, D_KV)
+  value_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, HEAD, D_KV)
+  out_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, HEAD, D_KV)

def qkv_projection(self, inputs: Array, proj_name: str):
"""Fused QKV projection"""
2 changes: 1 addition & 1 deletion MaxText/layers/linears.py
@@ -474,7 +474,7 @@ def __call__(self, inputs, decode: bool = False, deterministic: bool = False):
if self.model_mode == MODEL_MODE_PREFILL:
x = nn.with_logical_constraint(x, ("activation_batch", "prefill_activation_length", "activation_mlp"))
else:
-      x = nn.with_logical_constraint(x, ("activation_batch", "activation_length", "activation_mlp"))
+      x = nn.with_logical_constraint(x, ("activation_batch", "activation_length_no_exp", "activation_mlp"))
output = self.wo(x)

output = checkpoint_name(output, "mlpwo")
32 changes: 16 additions & 16 deletions MaxText/layers/moe.py
@@ -835,19 +835,19 @@ def gmm(inputs, kernel, group_sizes, expert_assignments):
batch_logical_axis = "activation_batch_no_exp"

if self.get_tensor_transpose_parallelism_size() > 1:
-      input_partition_pspec = nn.logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", "activation_embed"))
+      input_partition_pspec = nn.logical_to_mesh_axes((batch_logical_axis, "activation_length_no_exp", "activation_embed"))
w0_bias_pspec = nn.logical_to_mesh_axes(("exp", None))
w1_bias_pspec = nn.logical_to_mesh_axes(("exp", None))
wo_bias_pspec = nn.logical_to_mesh_axes(("exp", "activation_embed"))
else:
-      input_partition_pspec = nn.logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None))
+      input_partition_pspec = nn.logical_to_mesh_axes((batch_logical_axis, "activation_length_no_exp", None))
w0_bias_pspec = nn.logical_to_mesh_axes(("exp", "activation_mlp"))
w1_bias_pspec = nn.logical_to_mesh_axes(("exp", "activation_mlp"))
wo_bias_pspec = nn.logical_to_mesh_axes(("exp", "activation_embed"))

-    gate_logits_pspec = nn.logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None))
+    gate_logits_pspec = nn.logical_to_mesh_axes((batch_logical_axis, "activation_length_no_exp", None))
if self.config.model_name.startswith("deepseek3"):
-      pre_bias_logits_pspec = nn.logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None))
+      pre_bias_logits_pspec = nn.logical_to_mesh_axes((batch_logical_axis, "activation_length_no_exp", None))
else:
# pre_bias_logits is None for non-DeepSeek v3 models
pre_bias_logits_pspec = None
@@ -875,7 +875,7 @@ def gmm(inputs, kernel, group_sizes, expert_assignments):
w1_bias_pspec,
wo_bias_pspec,
),
-        out_specs=(nn.logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", "activation_embed"))),
+        out_specs=(nn.logical_to_mesh_axes((batch_logical_axis, "activation_length_no_exp", "activation_embed"))),
check_rep=False,
)
def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias):
@@ -1107,7 +1107,7 @@ def generate_masks_subgroup(self, top_k_indices, softmax_probs):
)
expert_token_count = nn.with_logical_constraint(
expert_token_count,
("activation_batch", "activation_norm_length", None, None, None),
("activation_batch", "activation_length_no_exp", None, None, None),
)
trunc_expert_mask = expert_mask * jnp.less_equal(expert_token_count, expert_capacity_per_batch)
combined_expert_mask = jnp.sum(trunc_expert_mask, axis=3)
@@ -1195,7 +1195,7 @@ def generate_masks(self, top_k_indices, softmax_probs):
)
expert_token_count = nn.with_logical_constraint(
expert_token_count,
("activation_batch", "activation_norm_length", None, None),
("activation_batch", "activation_length_no_exp", None, None),
)
trunc_expert_mask = expert_mask * jnp.less_equal(expert_token_count, expert_capacity_per_batch)
combined_expert_mask = jnp.sum(trunc_expert_mask, axis=2)
@@ -1293,10 +1293,10 @@ def dense_matmul(
) -> tuple[jax.Array, Optional[jax.Array]]:
"""Dense matrix multiplication."""
# gate_logits: batch, length, expert
-    gate_logits = nn.with_logical_constraint(gate_logits, ("activation_batch", "activation_norm_length", None))
+    gate_logits = nn.with_logical_constraint(gate_logits, ("activation_batch", "activation_length_no_exp", None))
if self.config.model_name.startswith("deepseek3"):
# pre_bias_logits is None for non-DeepSeek v3 models
-      pre_bias_logits = nn.with_logical_constraint(pre_bias_logits, ("activation_batch", "activation_norm_length", None))
+      pre_bias_logits = nn.with_logical_constraint(pre_bias_logits, ("activation_batch", "activation_length_no_exp", None))
top_k_weights, top_k_indices = self.get_topk(gate_logits, pre_bias_logits)
is_llama4_decoder_layer = self.config.decoder_block == ctypes.DecoderBlockType.LLAMA4
if is_llama4_decoder_layer:
@@ -1323,7 +1323,7 @@ def dense_matmul(
dispatch_mask, combine_mask = self.generate_masks(
top_k_indices, weights # pylint: disable=undefined-variable,possibly-used-before-assignment
)
-        mask_axes = ("activation_batch", "activation_norm_length", None, None)
+        mask_axes = ("activation_batch", "activation_length_no_exp", None, None)
dispatch_axis = (
"activation_exp",
"activation_batch_no_exp",
@@ -1347,14 +1347,14 @@ def dense_matmul(
dispatch_mask, combine_mask = self.generate_masks_subgroup(top_k_indices, softmax_probs)
if self.get_context_autoregressive_parallelism_size() > 0 and cp == 1:
mask_axes = (
"activation_norm_length",
"activation_length_no_exp",
"activation_batch",
None,
None,
None,
)
input_axis = (
"activation_norm_length",
"activation_length_no_exp",
"activation_batch",
None,
"activation_embed",
@@ -1376,14 +1376,14 @@ def dense_matmul(
else:
mask_axes = (
"activation_batch",
"activation_norm_length",
"activation_length_no_exp",
None,
None,
None,
)
input_axis = (
"activation_batch",
"activation_norm_length",
"activation_length_no_exp",
None,
"activation_embed",
)
@@ -1423,7 +1423,7 @@ def dense_matmul(
(
None,
"activation_batch_no_exp",
"activation_norm_length",
"activation_length_no_exp",
None,
"activation_embed",
),
@@ -1512,7 +1512,7 @@ def dense_matmul(
)
return output, loss
else:
-      inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
+      inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length_no_exp", "activation_embed"))
with jax.named_scope("wi_0"):
layer_w0 = self.get_einsum(rhs_mesh_axes=self.wi_kernel_axes)(
"BSM,EMH -> BSEH", inputs, w0_kernel, precision=matmul_precision
14 changes: 8 additions & 6 deletions MaxText/layers/qwen3.py
@@ -60,7 +60,7 @@ def self_attention_with_norm(
epsilon=cfg.normalization_layer_epsilon,
kernel_axes=("norm",),
)(inputs_checkpoint)
-  lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_length", "activation_embed"))
+  lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_norm_length", "activation_embed"))

# Self-attention block
attention_layer = attentions.attention_as_linen(
@@ -94,7 +94,7 @@ def self_attention_with_norm(
model_mode=model_mode,
)
attention_output = nn.with_logical_constraint(
-      attention_output, ("activation_batch", "activation_length", "activation_embed")
+      attention_output, ("activation_batch", "activation_norm_length", "activation_embed")
)

# Residual connection after attention
@@ -109,7 +109,9 @@ def self_attention_with_norm(
epsilon=cfg.normalization_layer_epsilon,
kernel_axes=("norm",),
)(residual_after_attention)
-  hidden_states = nn.with_logical_constraint(hidden_states, ("activation_batch", "activation_length", "activation_embed"))
+  hidden_states = nn.with_logical_constraint(
+      hidden_states, ("activation_batch", "activation_norm_length", "activation_embed")
+  )

return hidden_states, residual_after_attention

@@ -167,7 +169,7 @@ def __call__(
layer_output = residual_after_attention + mlp_output
layer_output = nn.with_logical_constraint(
layer_output,
("activation_batch", "activation_length", "activation_embed"),
("activation_batch", "activation_norm_length", "activation_embed"),
)

if cfg.scan_layers:
@@ -230,13 +232,13 @@ def __call__(
if load_balance_loss is not None:
self.sow("intermediates", "moe_lb_loss", load_balance_loss)

-    mlp_output = nn.with_logical_constraint(mlp_output, ("activation_batch", "activation_length", "activation_embed"))
+    mlp_output = nn.with_logical_constraint(mlp_output, ("activation_batch", "activation_norm_length", "activation_embed"))

# Final residual connection
layer_output = residual_after_attention + mlp_output
layer_output = nn.with_logical_constraint(
layer_output,
("activation_batch", "activation_length", "activation_embed"),
("activation_batch", "activation_norm_length", "activation_embed"),
)

if cfg.scan_layers:
4 changes: 2 additions & 2 deletions MaxText/pyconfig.py
@@ -1170,12 +1170,12 @@ def using_tensor_parallelism(raw_keys) -> bool:


def using_sequence_parallelism(raw_keys) -> bool:
-  if int(raw_keys["ici_expert_parallelism"]) > 1 and int(raw_keys["dcn_expert_parallelism"]) > 1:
-    raise ValueError("Expert parallelism can only be enabled on ICI or DCN, not both.")
return int(raw_keys["ici_sequence_parallelism"]) > 1 or int(raw_keys["dcn_sequence_parallelism"]) > 1


def using_expert_parallelism(raw_keys) -> bool:
+  if int(raw_keys["ici_expert_parallelism"]) > 1 and int(raw_keys["dcn_expert_parallelism"]) > 1:
+    raise ValueError("Expert parallelism can only be enabled on ICI or DCN, not both.")
return int(raw_keys["ici_expert_parallelism"]) > 1 or int(raw_keys["dcn_expert_parallelism"]) > 1


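As rendered above, the pyconfig.py hunk relocates the ICI/DCN exclusivity check from `using_sequence_parallelism` into `using_expert_parallelism`, so the error is raised by the helper that actually reads the expert-parallelism keys. A small sketch of the resulting behavior, using a toy `raw_keys` dict rather than a real MaxText config:

```python
# Toy raw_keys dict for illustration only; a real MaxText config carries many more keys.
raw_keys = {
    "ici_expert_parallelism": 2,
    "dcn_expert_parallelism": 2,
}

def using_expert_parallelism(raw_keys) -> bool:
  # After the move, the exclusivity check lives with the expert-parallelism helper.
  if int(raw_keys["ici_expert_parallelism"]) > 1 and int(raw_keys["dcn_expert_parallelism"]) > 1:
    raise ValueError("Expert parallelism can only be enabled on ICI or DCN, not both.")
  return int(raw_keys["ici_expert_parallelism"]) > 1 or int(raw_keys["dcn_expert_parallelism"]) > 1

try:
  using_expert_parallelism(raw_keys)
except ValueError as err:
  print(err)  # Expert parallelism can only be enabled on ICI or DCN, not both.
```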
2 changes: 1 addition & 1 deletion MaxText/tests/max_utils_test.py
@@ -72,7 +72,7 @@ def test_t5x_cross_entropy(self):
# Calculate xent from custom T5X implementation
one_hot_targets = jax.nn.one_hot(targets, 4096)
t5x_xent, _ = max_utils.cross_entropy_with_logits(logits, one_hot_targets, 0.0)
-    t5x_xent = nn.with_logical_constraint(t5x_xent, ("activation_batch", "activation_length"))
+    t5x_xent = nn.with_logical_constraint(t5x_xent, ("activation_batch", "activation_length_no_exp"))

# Compare results
self.assertTrue(jax.numpy.allclose(optax_xent, t5x_xent, rtol=1e-05, atol=1e-08, equal_nan=False))
6 changes: 4 additions & 2 deletions MaxText/tests/moe_test.py
@@ -318,13 +318,15 @@ def __call__(self, inputs, deterministic: bool = False):
weights, selected_experts = jax.lax.top_k(gate_logits, self.num_experts_per_tok)
weights = jax.nn.softmax(weights.astype(jnp.float32), axis=-1).astype(self.weight_dtype)
mlp_lnx = jnp.zeros_like(inputs)
-    mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_length", "activation_embed"))
+    mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_length_no_exp", "activation_embed"))

for k in range(self.num_experts):
weights_exp = jnp.sum(jnp.multiply(selected_experts == k, weights), axis=-1)
getattr(self, f"mlp_{k}")
mlp_lnx_exp = getattr(self, f"mlp_{k}")(inputs, deterministic=deterministic)
-      mlp_lnx_exp = nn.with_logical_constraint(mlp_lnx_exp, ("activation_batch", "activation_length", "activation_embed"))
+      mlp_lnx_exp = nn.with_logical_constraint(
+          mlp_lnx_exp, ("activation_batch", "activation_length_no_exp", "activation_embed")
+      )
mlp_lnx_exp = weights_exp[:, :, None] * mlp_lnx_exp
mlp_lnx += mlp_lnx_exp

2 changes: 1 addition & 1 deletion MaxText/train.py
@@ -144,7 +144,7 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True):

one_hot_targets = jax.nn.one_hot(data["targets"], config.vocab_size)
xent, _ = max_utils.cross_entropy_with_logits(logits, one_hot_targets, 0.0)
-  xent = nn.with_logical_constraint(xent, ("activation_embed_and_logits_batch", "activation_length"))
+  xent = nn.with_logical_constraint(xent, ("activation_embed_and_logits_batch", "activation_length_no_exp"))
# Mask out paddings at the end of each example.
xent = xent * (data["targets_segmentation"] != 0)
total_loss = jnp.sum(xent)