
Conversation

@ChingTsai (Collaborator) commented Jan 29, 2026

Description

This PR resolves discrepancies in loss calculation and training step counts when running SFT with gradient accumulation enabled. Currently, MaxText.sft.sft_trainer (which uses the Tunix trainer) behaves incorrectly when gradient accumulation is turned on, diverging significantly from the native implementation in MaxText.sft_trainer. This change aligns the Tunix-based SFT logic with the native behavior.

Problem Statement

  • Loss Disparity: With GA enabled, MaxText-Tunix shows a massive loss-scale disparity compared to the native implementation. MaxText-Tunix reuses the native loss_fn, and the native logic deliberately skips dividing by total_weights inside that function, deferring normalization to a later stage of its own training step. Tunix inherited the un-normalized loss but never applied the deferred division, so the reported loss was inflated (see the sketch after this list).

[Image: pr_fix_vs_original_vs_native — loss curves comparing the PR fix, the original Tunix path, and the native trainer]

  • Step Count Mismatch: MaxText-Native handles micro-batching internally by reshaping the full global batch, whereas Tunix relies on the input pipeline to provide pre-sized micro-batches. Without this adjustment, Tunix ingested a full global batch at every step, which skewed the epoch calculation and caused the run to terminate prematurely compared to the native implementation.
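
For context, here is a minimal sketch of the normalization behavior described above. The names (`loss_fn`, `total_weights`) follow the description in this PR; the snippet is illustrative and is not the actual MaxText/Tunix code.

```python
import jax.numpy as jnp
import optax

def loss_fn(logits, targets, weights):
  # Weighted per-token cross-entropy, summed but NOT divided by
  # total_weights -- the native trainer defers that normalization to a
  # later stage of its training step.
  token_losses = optax.softmax_cross_entropy_with_integer_labels(logits, targets)
  total_loss = jnp.sum(token_losses * weights)
  total_weights = jnp.sum(weights)
  return total_loss, total_weights

# To match the native behavior, the Tunix path must apply the deferred
# normalization once per optimizer step, over all accumulated micro-batches.
def normalized_loss(total_loss, total_weights):
  return total_loss / jnp.maximum(total_weights, 1.0)
```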

FIXES: b/478823561

Tests

python3 -m MaxText.sft_trainer \
    src/MaxText/configs/sft.yml \
    run_name=$RUN_NAME \
    base_output_directory=..../qwen3-4b \
    model_name=qwen3-4b \
    load_parameters_path=..../qwen3-4b/0/items \
    tokenizer_path=Qwen/qwen3-4b \
    steps=$train_step \
    profiler=xplane \
    hf_path=arrow \
    dataset_type=hf \
    train_split=train \
    hf_train_files=..../data-00000-of-00001.arrow \
    hf_eval_files=..../data-00000-of-00001.arrow \
    per_device_batch_size=4 \
    gradient_accumulation_steps=4 \
    max_target_length=1024 \
    learning_rate=1.3e-5 \
    warmup_steps_fraction=0.05 \
    data_shuffle_seed=42 \
    gradient_clipping_threshold=1 \
    learning_rate_final_fraction=0 \
    weight_dtype=bfloat16
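
For reference, a rough sketch of the batch arithmetic behind the step-count fix, using the config above. This is illustrative only: `num_devices` and `examples_in_dataset` are hypothetical values, and it assumes the global batch of one optimizer step is split evenly into `gradient_accumulation_steps` micro-batches (exact MaxText config semantics may differ).

```python
per_device_batch_size = 4         # from the command above
gradient_accumulation_steps = 4   # from the command above
num_devices = 8                   # hypothetical
examples_in_dataset = 12_800      # hypothetical

global_batch = per_device_batch_size * num_devices          # 32 examples per optimizer step
micro_batch = global_batch // gradient_accumulation_steps   # 8 examples per micro-step

# If the input pipeline hands Tunix a full global batch at every micro-step
# instead of a micro-batch, each optimizer step consumes
# gradient_accumulation_steps times more data than intended, so the dataset
# is exhausted (and the run ends) prematurely:
steps_per_epoch_expected = examples_in_dataset // global_batch                                   # 400
steps_per_epoch_observed = examples_in_dataset // (global_batch * gradient_accumulation_steps)   # 100
```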

After applying the changes, the loss graphs of both versions are now almost identical.

[Image: graph_2026-01-29_15-52-15 — loss curves of both versions after the fix]

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

ChingTsai changed the title from "Fix loss and batching when using tunix" to "Fix gradient accumulation in post training" Jan 29, 2026
ChingTsai changed the title from "Fix gradient accumulation in post training" to "Fix gradient accumulation in post training sft" Jan 29, 2026
@codecov (bot) commented Jan 29, 2026

Codecov Report

❌ Patch coverage is 25.00000% with 3 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| src/MaxText/input_pipeline/_hf_data_processing.py | 33.33% | 2 Missing ⚠️ |
| src/MaxText/train.py | 0.00% | 0 Missing and 1 partial ⚠️ |


ChingTsai force-pushed the jimmytsai/fix-ga-in-sft-trainer branch from 69e7031 to f36a364 on January 29, 2026 08:44
ChingTsai self-assigned this Jan 29, 2026
ChingTsai force-pushed the jimmytsai/fix-ga-in-sft-trainer branch from f36a364 to b891b70 on January 29, 2026 14:20