36 changes: 20 additions & 16 deletions MaxText/configs/base.yml
@@ -227,47 +227,49 @@ jax_cache_dir: "~/jax_cache"
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu', 'gpu_multiprocess' and 'cpu'

# Parallelism
-mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'expert', 'autoregressive']
+mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'tensor_sequence', 'expert', 'autoregressive']
logical_axis_rules: [
['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']],
['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
-['activation_heads', ['tensor','sequence']],
-['activation_kv_heads', ['tensor','sequence']],
-['activation_length', 'sequence'],
+['activation_heads', ['tensor','sequence','tensor_sequence']],
+['activation_kv_heads', ['tensor','sequence','tensor_sequence']],
+['activation_length', ['sequence']],
+['activation_norm_length', ['tensor_sequence', 'sequence']],
['activation_embed', 'tensor'],
-['activation_mlp', 'tensor'],
-['activation_kv', 'tensor'],
+['activation_mlp', ['tensor', 'tensor_sequence']],
+['activation_kv', ['tensor', 'tensor_sequence']],
['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
-['activation_kv_head_dim', 'tensor'],
-['activation_vocab', ['tensor', 'sequence']],
+['activation_kv_head_dim', ['tensor', 'tensor_sequence']],
+['activation_vocab', ['tensor', 'sequence', 'tensor_sequence']],
['activation_vocab', 'tensor'],
+['activation_vocab', 'tensor_sequence'],
['activation_vocab', 'sequence'],
['activation_stage', 'stage'],
['activation_exp', 'expert'],
-['mlp', ['fsdp_transpose', 'tensor', 'autoregressive']],
-['vocab', ['tensor', 'autoregressive']],
+['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
+['vocab', ['tensor', 'tensor_sequence', 'autoregressive']],
['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'expert']],
['embed', ['fsdp', 'sequence', 'expert']],
['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence']],
['embed_no_exp', ['fsdp', 'sequence']],
-['norm', 'tensor'],
-['q_heads', ['tensor', 'autoregressive']],
-['heads', ['tensor', 'autoregressive']],
+['norm', ['tensor', 'tensor_sequence']],
+['q_heads', ['tensor', 'tensor_sequence', 'autoregressive']],
+['heads', ['tensor', 'tensor_sequence', 'autoregressive']],
['layers', 'stage'],
['kv', []],
-['kv_heads', ['tensor', 'autoregressive']],
+['kv_heads', ['tensor', 'tensor_sequence', 'autoregressive']],
['kv_head_dim', []],
['cache_batch_prefill', []],
['cache_batch', []],
-['cache_heads', ['autoregressive', 'tensor']],
+['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']],
['cache_kv', []],
['cache_sequence', []],
['exp', 'expert'],
]
# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
-data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'expert', 'autoregressive']]
+data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'tensor_sequence', 'expert', 'autoregressive']]

# sharding tolerance: float between 0.0 and 1.0 representing the allowed percentage of non-sharded parameters.
sharding_tolerance: 0.02
@@ -281,6 +283,7 @@ dcn_fsdp_parallelism: 1
dcn_fsdp_transpose_parallelism: 1
dcn_sequence_parallelism: 1 # never recommended
dcn_tensor_parallelism: 1 # never recommended
+dcn_tensor_sequence_parallelism: 1 # never recommended
dcn_pipeline_parallelism: 1
dcn_expert_parallelism: 1
dcn_autoregressive_parallelism: 1 # never recommended
@@ -289,6 +292,7 @@ ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
ici_fsdp_transpose_parallelism: 1
ici_sequence_parallelism: 1
ici_tensor_parallelism: 1
+ici_tensor_sequence_parallelism: 1
ici_autoregressive_parallelism: 1
ici_pipeline_parallelism: 1
ici_expert_parallelism: 1
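Reviewer note, not part of the diff: the new 'tensor_sequence' mesh axis and the dcn_/ici_tensor_sequence_parallelism knobs above add one more mesh dimension that the logical axis rules can shard over. Below is a minimal JAX sketch of what such a mesh axis means; the axis names and sizes are hypothetical, it assumes 8 visible devices, and it is not MaxText's actual mesh-construction code.

```python
# Hypothetical sketch of a mesh that includes a 'tensor_sequence' axis.
# Assumes 8 visible devices; sizes are illustrative only.
from jax.experimental import mesh_utils
from jax.sharding import Mesh

# Split 8 devices into 2-way fsdp x 2-way tensor x 2-way tensor_sequence.
devices = mesh_utils.create_device_mesh((2, 2, 2))
mesh = Mesh(devices, axis_names=("fsdp", "tensor", "tensor_sequence"))

# Each logical rule above (e.g. ['activation_mlp', ['tensor', 'tensor_sequence']])
# names the mesh axes a logical dimension is allowed to be sharded over.
print(dict(mesh.shape))  # {'fsdp': 2, 'tensor': 2, 'tensor_sequence': 2}
```

As with the existing knobs, the per-slice ici_* parallelism values are expected to multiply out to the number of devices in a slice.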
5 changes: 5 additions & 0 deletions MaxText/layers/attentions.py
@@ -65,6 +65,7 @@ class AttentionType(enum.Enum):
KV_BATCH = common_types.KV_BATCH
LENGTH = common_types.LENGTH
HEAD = common_types.HEAD
+EMBED = common_types.EMBED
KV_HEAD = common_types.KV_HEAD
D_KV = common_types.D_KV
KV_HEAD_DIM = common_types.KV_HEAD_DIM
@@ -1114,6 +1115,7 @@ class Attention(nn.Module):
prefill_key_axis_names: AxisNames = (PREFILL_KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM)
prefill_value_axis_names: AxisNames = (PREFILL_KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM)
query_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM)
+input_axis_names: AxisNames = (BATCH, LENGTH, EMBED)
key_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM)
value_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM)
out_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV)
@@ -1265,6 +1267,9 @@ def __call__(
Returns:
output of shape `[batch, length, q_features]`.
"""
+inputs_q = nn.with_logical_constraint(inputs_q, self.input_axis_names)
+inputs_kv = nn.with_logical_constraint(inputs_kv, self.input_axis_names)

# apply projection.
if self.config.fused_qkv:
query, key, value = self.qkv_projection(inputs_q, proj_name="qkv_proj")
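Reviewer note, not part of the diff: the new input_axis_names = (BATCH, LENGTH, EMBED) constraint pins the attention inputs to logical axis names that the logical_axis_rules in base.yml translate into mesh axes. A small sketch of that translation with a trimmed-down, hypothetical rule set, assuming the common_types constants resolve to the activation_* names used in base.yml:

```python
# Minimal sketch of how Flax turns logical axis names into a PartitionSpec.
# The rule set is a hypothetical subset of base.yml, not the full list.
import flax.linen as nn

rules = (
    ("activation_batch", ("data", "fsdp", "fsdp_transpose", "expert")),
    ("activation_length", ("sequence",)),
    ("activation_embed", ("tensor",)),
)

# Roughly what nn.with_logical_constraint(inputs_q, self.input_axis_names)
# resolves to before the sharding constraint is applied.
spec = nn.logical_to_mesh_axes(
    ("activation_batch", "activation_length", "activation_embed"), rules
)
print(spec)  # a PartitionSpec mapping the batch/length/embed dims to mesh axes
```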
12 changes: 7 additions & 5 deletions MaxText/layers/gemma.py
@@ -69,14 +69,14 @@ def __call__(
):
cfg = self.config
mesh = self.mesh
-inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed"))
+inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
inputs = checkpoint_name(inputs, "decoder_layer_input")
# inputs: embedded inputs to the decoder with shape [batch, length, emb_dim]
lnx = RMSNorm(dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="pre_self_attention_norm", kernel_axes=("norm",))(
inputs
)

-lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_length", "activation_embed"))
+lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_norm_length", "activation_embed"))

attention_layer = Attention(
config=cfg,
@@ -108,7 +108,9 @@ def __call__(
model_mode=model_mode,
)

-attention_lnx = nn.with_logical_constraint(attention_lnx, ("activation_batch", "activation_length", "activation_embed"))
+attention_lnx = nn.with_logical_constraint(
+attention_lnx, ("activation_batch", "activation_norm_length", "activation_embed")
+)
attention_lnx += inputs
residual = attention_lnx
attn_output = RMSNorm(dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="pre_ffw_norm", kernel_axes=("norm",))(
@@ -126,7 +128,7 @@ def __call__(
config=cfg,
quant=self.quant,
)(attn_output, deterministic=deterministic)
-mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_length", "activation_embed"))
+mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed"))

next_layer_addition = mlp_lnx + residual

@@ -137,7 +139,7 @@
layer_output = next_layer_addition_dropped_out
layer_output = nn.with_logical_constraint(
layer_output,
("activation_batch", "activation_length", "activation_embed"),
("activation_batch", "activation_norm_length", "activation_embed"),
)

if cfg.record_internal_nn_metrics:
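Reviewer note, not part of the diff: the recurring edit in this file and the ones below swaps activation_length for activation_norm_length on decoder-layer inputs, norms, and residuals, and base.yml now maps that logical axis to ['tensor_sequence', 'sequence']. A self-contained, hypothetical sketch of the pattern, with the rules trimmed to the two axes present in this toy mesh (it assumes 4 visible devices and is not MaxText's actual decoder code):

```python
# Hypothetical sketch of the decoder-layer pattern: constrain activations on
# "activation_norm_length" so the sequence dimension may shard over the
# 'tensor_sequence' mesh axis. Assumes 4 visible devices.
import jax
import jax.numpy as jnp
import flax.linen as nn
from jax.experimental import mesh_utils
from jax.sharding import Mesh

# Only the mesh axes used below; logical names not listed fall back to unsharded.
rules = (
    ("activation_batch", "fsdp"),
    ("activation_norm_length", "tensor_sequence"),
)

mesh = Mesh(mesh_utils.create_device_mesh((2, 2)), ("fsdp", "tensor_sequence"))

@jax.jit
def constrain(x):
  # Mirrors the nn.with_logical_constraint calls added in this diff.
  return nn.with_logical_constraint(
      x, ("activation_batch", "activation_norm_length", "activation_embed")
  )

with mesh, nn.logical_axis_rules(rules):
  y = constrain(jnp.zeros((8, 128, 64)))  # [batch, length, embed]
```

The same constraint swap repeats in gemma2.py, gpt3.py, and llama2.py below.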
24 changes: 14 additions & 10 deletions MaxText/layers/gemma2.py
@@ -69,14 +69,14 @@ def __call__(
):
cfg = self.config
mesh = self.mesh
-inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed"))
+inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
inputs = checkpoint_name(inputs, "decoder_layer_input")
# inputs: embedded inputs to the decoder with shape [batch, length, emb_dim]
lnx = RMSNorm(
dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="pre_self_attention_norm_local", kernel_axes=("norm",)
)(inputs)

-lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_length", "activation_embed"))
+lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_norm_length", "activation_embed"))

attention_layer = Attention(
config=cfg,
@@ -113,7 +113,9 @@ def __call__(
dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="post_self_attention_norm_local", kernel_axes=("norm",)
)(attention_lnx)

-attention_lnx = nn.with_logical_constraint(attention_lnx, ("activation_batch", "activation_length", "activation_embed"))
+attention_lnx = nn.with_logical_constraint(
+attention_lnx, ("activation_batch", "activation_norm_length", "activation_embed")
+)
attention_lnx += inputs
residual = attention_lnx

@@ -137,7 +139,7 @@ def __call__(
mlp_lnx = RMSNorm(dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="post_ffw_norm_local", kernel_axes=("norm",))(
mlp_lnx
)
-mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_length", "activation_embed"))
+mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed"))

next_layer_addition = mlp_lnx + residual

@@ -148,18 +150,18 @@ def __call__(
layer_output = next_layer_addition_dropped_out
layer_output = nn.with_logical_constraint(
layer_output,
("activation_batch", "activation_length", "activation_embed"),
("activation_batch", "activation_norm_length", "activation_embed"),
)

### global part
-inputs = nn.with_logical_constraint(layer_output, ("activation_batch", "activation_length", "activation_embed"))
+inputs = nn.with_logical_constraint(layer_output, ("activation_batch", "activation_norm_length", "activation_embed"))

# inputs: embedded inputs to the decoder with shape [batch, length, emb_dim]
lnx = RMSNorm(
dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="pre_self_attention_norm_global", kernel_axes=("norm",)
)(inputs)

-lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_length", "activation_embed"))
+lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_norm_length", "activation_embed"))

attention_layer = Attention(
config=cfg,
@@ -195,7 +197,9 @@ def __call__(
dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="post_self_attention_norm_global", kernel_axes=("norm",)
)(attention_lnx)

-attention_lnx = nn.with_logical_constraint(attention_lnx, ("activation_batch", "activation_length", "activation_embed"))
+attention_lnx = nn.with_logical_constraint(
+attention_lnx, ("activation_batch", "activation_norm_length", "activation_embed")
+)
attention_lnx += inputs
residual = attention_lnx

@@ -219,7 +223,7 @@ def __call__(
mlp_lnx
)

-mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_length", "activation_embed"))
+mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed"))

next_layer_addition = mlp_lnx + residual

@@ -230,7 +234,7 @@ def __call__(
layer_output = next_layer_addition_dropped_out
layer_output = nn.with_logical_constraint(
layer_output,
("activation_batch", "activation_length", "activation_embed"),
("activation_batch", "activation_norm_length", "activation_embed"),
)

if cfg.record_internal_nn_metrics:
15 changes: 10 additions & 5 deletions MaxText/layers/gpt3.py
@@ -47,6 +47,7 @@
LENGTH = common_types.LENGTH
HEAD = common_types.HEAD
D_KV = common_types.D_KV
+EMBED = common_types.EMBED

DenseGeneral = linears.DenseGeneral
NdInitializer = initializers.NdInitializer
@@ -148,6 +149,7 @@ class Gpt3MultiHeadAttention(nn.Module):
kv_quant: Optional[KVQuant] = None
use_bias: bool = True

+input_axis_names: AxisNames = (BATCH, LENGTH, EMBED)
query_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV)
key_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV)
value_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV)
@@ -213,6 +215,7 @@ def __call__(
model_mode: str = common_types.MODEL_MODE_TRAIN,
deterministic: bool = False,
):
+inputs_q = nn.with_logical_constraint(inputs_q, self.input_axis_names)
if self.fused_qkv:
query, key, value = self.qkv_projection(inputs_q, proj_name="qkv_proj")
else:
@@ -279,7 +282,7 @@ def __call__(
cfg = self.config
mesh = self.mesh

-inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed"))
+inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
inputs = checkpoint_name(inputs, "decoder_layer_input")
lnx_layer_norm = Gpt3LayerNorm(
dtype=cfg.dtype,
Expand All @@ -291,7 +294,7 @@ def __call__(
)
lnx = lnx_layer_norm(inputs)

-lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_length", "activation_embed"))
+lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_norm_length", "activation_embed"))

# Self-attention block
assert (
@@ -319,7 +322,9 @@ def __call__(
lnx, decoder_segment_ids=decoder_segment_ids, model_mode=model_mode, deterministic=deterministic
)

-attention_lnx = nn.with_logical_constraint(attention_lnx, ("activation_batch", "activation_length", "activation_embed"))
+attention_lnx = nn.with_logical_constraint(
+attention_lnx, ("activation_batch", "activation_norm_length", "activation_embed")
+)
attention_lnx += inputs

# MLP block.
@@ -335,15 +340,15 @@ def __call__(
config=cfg,
quant=self.quant,
)(attention_lnx, deterministic=deterministic)
-mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_length", "activation_embed"))
+mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed"))

layer_output = attention_lnx + mlp_lnx

layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic)

layer_output = nn.with_logical_constraint(
layer_output,
("activation_batch", "activation_length", "activation_embed"),
("activation_batch", "activation_norm_length", "activation_embed"),
)

if cfg.record_internal_nn_metrics:
16 changes: 10 additions & 6 deletions MaxText/layers/llama2.py
@@ -78,7 +78,7 @@ def __call__(
cfg = self.config
mesh = self.mesh

-inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed"))
+inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
inputs = checkpoint_name(inputs, "decoder_layer_input")
lnx_rms = models.RMSNorm(
dtype=cfg.dtype,
Expand All @@ -89,7 +89,7 @@ def __call__(
)
lnx = lnx_rms(inputs)

-lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_length", "activation_embed"))
+lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_norm_length", "activation_embed"))

# Self-attention block
attention_layer = Attention(
@@ -124,7 +124,9 @@ def __call__(
model_mode=model_mode,
)

-attention_lnx = nn.with_logical_constraint(attention_lnx, ("activation_batch", "activation_length", "activation_embed"))
+attention_lnx = nn.with_logical_constraint(
+attention_lnx, ("activation_batch", "activation_norm_length", "activation_embed")
+)
intermediate_inputs = inputs + attention_lnx

# Fully Connected
@@ -135,7 +137,9 @@ def __call__(
kernel_axes=("norm",),
epsilon=cfg.normalization_layer_epsilon,
)(intermediate_inputs)
-hidden_states = nn.with_logical_constraint(hidden_states, ("activation_batch", "activation_length", "activation_embed"))
+hidden_states = nn.with_logical_constraint(
+hidden_states, ("activation_batch", "activation_norm_length", "activation_embed")
+)

# MLP block.
mlp_lnx = linears.MlpBlock(
@@ -148,15 +152,15 @@ def __call__(
config=cfg,
quant=self.quant,
)(hidden_states, deterministic=deterministic)
-mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_length", "activation_embed"))
+mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed"))

layer_output = mlp_lnx + intermediate_inputs

layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic)

layer_output = nn.with_logical_constraint(
layer_output,
("activation_batch", "activation_length", "activation_embed"),
("activation_batch", "activation_norm_length", "activation_embed"),
)

if cfg.record_internal_nn_metrics: