
Conversation

@bvandermoon (Collaborator)

Description

Accept the inputs shape instead of the full inputs in Attention/MLA. This is needed because we are about to move Attention initialization into __init__ of the decoder layers, and the full inputs are not available at that point.

Example:

The Attention initialization needs to move to __init__ in LlamaDecoderLayer, but lnx is generated in __call__, so it won't be available in __init__.
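As a rough illustration of the pattern (a minimal Flax NNX sketch; ToyAttention, ToyDecoderLayer, and their parameters are made-up names for illustration, not the actual MaxText classes), the attention module can be constructed in __init__ from the inputs shape alone, while the activation lnx only comes into existence inside __call__:

import jax.numpy as jnp
from flax import nnx

class ToyAttention(nnx.Module):
  """Attention stub whose parameters depend only on the inputs shape."""
  def __init__(self, inputs_shape, num_heads, head_dim, rngs: nnx.Rngs):
    emb_dim = inputs_shape[-1]  # only the shape is needed, not the tensor
    self.out_proj = nnx.Linear(emb_dim, num_heads * head_dim, rngs=rngs)

  def __call__(self, x):
    return self.out_proj(x)

class ToyDecoderLayer(nnx.Module):
  def __init__(self, inputs_shape, rngs: nnx.Rngs):
    self.norm = nnx.LayerNorm(inputs_shape[-1], rngs=rngs)
    # Attention can now be built eagerly here: only inputs_shape is required.
    # The activation (lnx) does not exist until __call__ runs.
    self.attention = ToyAttention(inputs_shape, num_heads=2, head_dim=4, rngs=rngs)

  def __call__(self, inputs):
    lnx = self.norm(inputs)  # lnx is produced inside __call__
    return self.attention(lnx)

layer = ToyDecoderLayer(inputs_shape=(1, 8, 16), rngs=nnx.Rngs(0))
out = layer(jnp.ones((1, 8, 16)))  # output shape (1, 8, 8)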

Tests

  • Base model train gives same perf before/after:
python3 -m MaxText.train MaxText/configs/base.yml \
    run_name=<run_name> \
    base_output_directory=gs://<gcs_bucket> \
    dataset_type=synthetic \
    steps=10
  • Deepseek3-test gives same perf before/after (mla_naive_kvcache=False is not needed for train):
python3 -m MaxText.train MaxText/configs/base.yml \
    run_name=bvandermoon-$RANDOM \
    base_output_directory=gs://bvandermoon-multipod-maxtext \
    dataset_type=synthetic \
    steps=10 \
    model_name=deepseek3-test \
    mla_naive_kvcache=False \
    max_target_length=256 \
    per_device_batch_size=1 \
    ici_fsdp_parallelism=-1 \
    scan_layers=false \
    weight_dtype=bfloat16 opt_type=sgd

Checklist

Before submitting this PR, please make sure:

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

@bvandermoon force-pushed the bvandermoon-nnx-attention-shape branch from 1f0ebd8 to dffa82b on August 6, 2025 at 17:04.
@NuojCheng (Collaborator) left a comment:
LGTM

@richjames0 (Collaborator) left a comment:
LGTM

@copybara-service bot merged commit 3c30585 into main on Aug 6, 2025 (23 checks passed).
@copybara-service bot deleted the bvandermoon-nnx-attention-shape branch on August 6, 2025 at 18:32.