[llama4] use grouped_mm in moe for sm90 #2755
base: gh/IvanKobzarev/5/base
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2755
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 8864b5d with merge base 8195b2e. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
Enablement copy-pasted from torchtitan. There is only one difference: the output of grouped_mm is uninitialized after offsets[-1] (it is the result of at::empty).

```
out = _grouped_mm(h, w2, offs=offsets)
out[offsets[-1] :].zero_()
```

scatter_add has no way to ignore indices, and the default 0 in token_indices would copy that uninitialized data to index 0. To prevent copying uninitialized data, zero_() is applied after grouped_mm for now.

[ghstack-poisoned]
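To make the failure mode concrete, here is a minimal sketch of the combine path described above. Variable names follow the snippet; the routing layout, shapes, and the scatter_add-based combine are assumptions about the surrounding implementation, not the PR's exact code.

```python
import torch

def moe_combine(h, w2, offsets, token_indices, num_tokens):
    # h:             (padded_tokens, hidden) routed activations, grouped and padded per expert
    # w2:            (num_experts, hidden, dim) second expert projection
    # offsets:       int32 cumulative per-expert group sizes; rows past offsets[-1] are padding
    # token_indices: int64 destination token for every row of h (padding rows default to 0)
    out = torch._grouped_mm(h, w2, offs=offsets)
    # Rows beyond offsets[-1] come from at::empty and are uninitialized, and
    # scatter_add has no "ignore index", so zero them before combining.
    out[offsets[-1]:].zero_()
    combined = torch.zeros(num_tokens, out.shape[-1], dtype=out.dtype, device=out.device)
    combined.scatter_add_(0, token_indices.unsqueeze(-1).expand(-1, out.shape[-1]), out)
    return combined
```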
@@ -180,6 +181,7 @@ def llama4_decoder(
    num_experts: int = 16,
    experts_per_token: int = 1,
    use_shared_expert: bool = True,
    use_grouped_mm: bool = True,
Do we actually need to expose this as an option? Are there cases where it's supported and not the fastest solution?
I exposed it only for debugging. The grouped_mm path was added recently and there could still be some issues (NaNs etc.). I think it will be good to have a switch to test against the default version; once everything is tested we can remove it. Alternatively we could have a global config for this and not change the Modules interface.
Yeah I am OK to keep it configurable for now
Most of the NaN issues are resolved, right? Ofc due to the padding I suspect there will be slightly different numerics, but would still be good to explicitly confirm that we get similar loss curves out of the grouped_mm vs non-grouped_mm implementations. Personally I would lean towards just taking this out (especially since it's hardcoded in the top-level Scout builder anyways) cause we already have enough args as it is. But if you wanna keep it in until things stabilize I don't mind too much
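If it helps, the kind of parity check being discussed could look roughly like this. This is a hedged sketch: `experts_grouped` and `experts_looped` stand in for the two expert implementations and are not names from the PR.

```python
import torch

def check_grouped_vs_looped(experts_grouped, experts_looped, x, num_tokens_per_expert, atol=1e-2):
    # Run the grouped_mm path and the per-expert loop on identical routed input
    # and make sure the outputs agree within a bf16-friendly tolerance.
    with torch.no_grad():
        out_grouped = experts_grouped(x, num_tokens_per_expert)
        out_looped = experts_looped(x, num_tokens_per_expert)
    max_diff = (out_grouped - out_looped).abs().max().item()
    assert max_diff < atol, f"grouped_mm vs looped mismatch: {max_diff}"
    return max_diff
```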
Codecov Report
Attention: Patch coverage is

Additional details and impacted files

@@ Coverage Diff @@
## gh/IvanKobzarev/5/base #2755 +/- ##
=========================================================
Coverage ? 59.91%
=========================================================
Files ? 437
Lines ? 26848
Branches ? 0
=========================================================
Hits ? 16085
Misses ? 10763
Partials ? 0

☔ View full report in Codecov by Sentry.
Enablement copy-pasted from torchtitan. There is only one difference: the output of grouped_mm is uninitialized after offsets[-1] (it is the result of at::empty).

```
out = _grouped_mm(h, w2, offs=offsets)
out[offsets[-1] :].zero_()
```

scatter_add has no way to ignore indices, and the default 0 in token_indices would copy that uninitialized data to index 0. To prevent copying uninitialized data, zero_() is applied after grouped_mm for now.

grouped_mm in eager: eager tps_avg: 3.2182 -> 4.44
grouped_mm in compile: compile tps_avg: 4.33463 -> 8.70872

Observing NaN loss in compile :( both with and without grouped_mm - will be debugging this separately.

```
no_compile + no grouped_mm: BASELINE:
Step 1 | loss:4.149928092956543 lr:2e-05 tokens_per_second_per_gpu:2.203409433364868 peak_memory_active:28.82542610168457 peak_memory_alloc:28.82542610168457 peak_memory_reserved:40.197265625
Step 2 | loss:2.3132553100585938 lr:2e-05 tokens_per_second_per_gpu:2.492661952972412 peak_memory_active:28.804787158966064 peak_memory_alloc:28.804787158966064 peak_memory_reserved:40.251953125
Step 3 | loss:1.6078141927719116 lr:2e-05 tokens_per_second_per_gpu:3.8224945068359375 peak_memory_active:28.853728771209717 peak_memory_alloc:28.853728771209717 peak_memory_reserved:40.251953125
Step 4 | loss:1.4519065618515015 lr:2e-05 tokens_per_second_per_gpu:3.001920700073242 peak_memory_active:28.84491729736328 peak_memory_alloc:28.84491729736328 peak_memory_reserved:40.251953125
Step 5 | loss:1.2131776809692383 lr:2e-05 tokens_per_second_per_gpu:3.3134357929229736 peak_memory_active:28.804935932159424 peak_memory_alloc:28.804935932159424 peak_memory_reserved:40.251953125
Step 6 | loss:1.411360263824463 lr:2e-05 tokens_per_second_per_gpu:4.279047966003418 peak_memory_active:28.83890438079834 peak_memory_alloc:28.83890438079834 peak_memory_reserved:40.251953125
Step 7 | loss:1.1743241548538208 lr:2e-05 tokens_per_second_per_gpu:3.129912853240967 peak_memory_active:28.785582065582275 peak_memory_alloc:28.785582065582275 peak_memory_reserved:40.251953125
Step 8 | loss:1.3950272798538208 lr:2e-05 tokens_per_second_per_gpu:3.4216036796569824 peak_memory_active:28.967872619628906 peak_memory_alloc:28.967872619628906 peak_memory_reserved:40.251953125
Step 9 | loss:1.2500101327896118 lr:2e-05 tokens_per_second_per_gpu:3.410867214202881 peak_memory_active:28.854501247406006 peak_memory_alloc:28.854501247406006 peak_memory_reserved:40.251953125
Step 10 | loss:1.1264036893844604 lr:2e-05 tokens_per_second_per_gpu:3.1065995693206787 peak_memory_active:28.782616138458252 peak_memory_alloc:28.782616138458252 peak_memory_reserved:40.251953125
/tmp/torchtune/llama4_17Bx16E/full/logs
└─ $ tune_logs_tps log_1747910831.txt
tps_avg: 3.2182 n= 10 peak_memory_alloc max 28.97 peak_memory_reser max 40.25

no_compile + grouped_mm: ______ EAGER_GROUPED_MM:
└─ $ cat log_1747915295.txt
Step 1 | loss:4.163714408874512 lr:2e-05 tokens_per_second_per_gpu:2.171847343444824 peak_memory_active:27.444169998168945 peak_memory_alloc:27.444169998168945 peak_memory_reserved:38.9609375
Step 2 | loss:2.283238172531128 lr:2e-05 tokens_per_second_per_gpu:3.158111810684204 peak_memory_active:27.428916931152344 peak_memory_alloc:27.428916931152344 peak_memory_reserved:39.017578125
Step 3 | loss:1.5704680681228638 lr:2e-05 tokens_per_second_per_gpu:4.869690418243408 peak_memory_active:27.469550132751465 peak_memory_alloc:27.469550132751465 peak_memory_reserved:39.017578125
Step 4 | loss:1.4485288858413696 lr:2e-05 tokens_per_second_per_gpu:4.026949405670166 peak_memory_active:27.46179723739624
peak_memory_alloc:27.46179723739624 peak_memory_reserved:39.017578125 Step 5 | loss:1.1918020248413086 lr:2e-05 tokens_per_second_per_gpu:4.502264976501465 peak_memory_active:27.428401947021484 peak_memory_alloc:27.428401947021484 peak_memory_reserved:39.017578125 Step 6 | loss:1.3454418182373047 lr:2e-05 tokens_per_second_per_gpu:5.665891170501709 peak_memory_active:27.456215858459473 peak_memory_alloc:27.456215858459473 peak_memory_reserved:39.017578125 Step 7 | loss:1.185951590538025 lr:2e-05 tokens_per_second_per_gpu:4.356126308441162 peak_memory_active:27.412775993347168 peak_memory_alloc:27.412775993347168 peak_memory_reserved:39.017578125 Step 8 | loss:1.3995944261550903 lr:2e-05 tokens_per_second_per_gpu:4.84677267074585 peak_memory_active:27.562548637390137 peak_memory_alloc:27.562548637390137 peak_memory_reserved:39.017578125 Step 9 | loss:1.2516542673110962 lr:2e-05 tokens_per_second_per_gpu:4.785239219665527 peak_memory_active:27.469550132751465 peak_memory_alloc:27.469550132751465 peak_memory_reserved:39.017578125 Step 10 | loss:1.1245404481887817 lr:2e-05 tokens_per_second_per_gpu:4.080092906951904 peak_memory_active:27.41029644012451 peak_memory_alloc:27.41029644012451 peak_memory_reserved:39.017578125 Step 11 | loss:1.2505868673324585 lr:2e-05 tokens_per_second_per_gpu:4.96058988571167 peak_memory_active:27.416271209716797 peak_memory_alloc:27.416271209716797 peak_memory_reserved:39.0234375 Step 12 | loss:1.0976852178573608 lr:2e-05 tokens_per_second_per_gpu:4.555960655212402 peak_memory_active:27.415279388427734 peak_memory_alloc:27.415279388427734 peak_memory_reserved:39.0234375 Step 13 | loss:1.176978349685669 lr:2e-05 tokens_per_second_per_gpu:5.802102565765381 peak_memory_active:27.413253784179688 peak_memory_alloc:27.413253784179688 peak_memory_reserved:39.0234375 /tmp/torchtune/llama4_17Bx16E/full/logs └─ $ tune_logs_tps log_1747915295.txt tps_avg: 4.44474 n= 13 peak_memory_alloc max 27.56 peak_memory_reser max 39.02 compile + no_grouped_mm: COMPILED BASELINE: └─ $ cat log_1747915995.txt Step 1 | loss:4.1595354080200195 lr:2e-05 tokens_per_second_per_gpu:0.42956066131591797 peak_memory_active:28.81977891921997 peak_memory_alloc:28.81977891921997 peak_memory_reserved:40.236328125 Step 2 | loss:2.2778103351593018 lr:2e-05 tokens_per_second_per_gpu:2.819465160369873 peak_memory_active:28.79981756210327 peak_memory_alloc:28.79981756210327 peak_memory_reserved:40.2890625 Step 3 | loss:1.786258578300476 lr:2e-05 tokens_per_second_per_gpu:5.622256755828857 peak_memory_active:28.846065998077393 peak_memory_alloc:28.846065998077393 peak_memory_reserved:40.2890625 Step 4 | loss:1.600595474243164 lr:2e-05 tokens_per_second_per_gpu:2.996225357055664 peak_memory_active:28.837817668914795 peak_memory_alloc:28.837817668914795 peak_memory_reserved:40.2890625 Step 5 | loss:1.627750039100647 lr:2e-05 tokens_per_second_per_gpu:6.902056694030762 peak_memory_active:28.79964590072632 peak_memory_alloc:28.79964590072632 peak_memory_reserved:40.2890625 Step 6 | loss:1.6751512289047241 lr:2e-05 tokens_per_second_per_gpu:3.8652405738830566 peak_memory_active:28.833858966827393 peak_memory_alloc:28.833858966827393 peak_memory_reserved:40.2890625 Step 7 | loss:1.5953153371810913 lr:2e-05 tokens_per_second_per_gpu:2.8824896812438965 peak_memory_active:28.781311511993408 peak_memory_alloc:28.781311511993408 peak_memory_reserved:40.2890625 Step 8 | loss:1.6390446424484253 lr:2e-05 tokens_per_second_per_gpu:4.305164813995361 peak_memory_active:28.948315620422363 peak_memory_alloc:28.948315620422363 
peak_memory_reserved:40.2890625 Step 9 | loss:1.53915536403656 lr:2e-05 tokens_per_second_per_gpu:4.096757888793945 peak_memory_active:28.84673547744751 peak_memory_alloc:28.84673547744751 peak_memory_reserved:40.2890625 Step 10 | loss:1.4715595245361328 lr:2e-05 tokens_per_second_per_gpu:6.167354106903076 peak_memory_active:28.778648853302002 peak_memory_alloc:28.778648853302002 peak_memory_reserved:40.2890625 Step 11 | loss:nan lr:2e-05 tokens_per_second_per_gpu:4.351371765136719 peak_memory_active:28.787389278411865 peak_memory_alloc:28.787389278411865 peak_memory_reserved:40.2890625 /tmp/torchtune/llama4_17Bx16E/full/logs └─ $ tune_logs_tps log_1747915995.txt tps_avg: 4.33463 n= 12 peak_memory_alloc max 28.95 peak_memory_reser max 40.29 COMPILED + GROUPED_MM └─ $ cat log_1747916676.txt Step 1 | loss:4.163997650146484 lr:2e-05 tokens_per_second_per_gpu:0.9783019423484802 peak_memory_active:27.437111377716064 peak_memory_alloc:27.437111377716064 peak_memory_reserved:38.978515625 Step 2 | loss:2.2560360431671143 lr:2e-05 tokens_per_second_per_gpu:2.861786365509033 peak_memory_active:27.42238759994507 peak_memory_alloc:27.42238759994507 peak_memory_reserved:39.03515625 Step 3 | loss:1.8066692352294922 lr:2e-05 tokens_per_second_per_gpu:7.731247901916504 peak_memory_active:27.45892906188965 peak_memory_alloc:27.45892906188965 peak_memory_reserved:39.03515625 Step 4 | loss:1.7610063552856445 lr:2e-05 tokens_per_second_per_gpu:6.10059928894043 peak_memory_active:27.452006340026855 peak_memory_alloc:27.452006340026855 peak_memory_reserved:39.03515625 Step 5 | loss:1.4565016031265259 lr:2e-05 tokens_per_second_per_gpu:10.525087356567383 peak_memory_active:27.42187261581421 peak_memory_alloc:27.42187261581421 peak_memory_reserved:39.03515625 Step 6 | loss:nan lr:2e-05 tokens_per_second_per_gpu:8.103992462158203 peak_memory_active:27.448771476745605 peak_memory_alloc:27.448771476745605 peak_memory_reserved:39.03515625 Step 7 | loss:nan lr:2e-05 tokens_per_second_per_gpu:7.136734962463379 peak_memory_active:27.407429218292236 peak_memory_alloc:27.407429218292236 peak_memory_reserved:39.03515625 Step 8 | loss:nan lr:2e-05 tokens_per_second_per_gpu:11.029473304748535 peak_memory_active:27.543721675872803 peak_memory_alloc:27.543721675872803 peak_memory_reserved:39.03515625 Step 9 | loss:nan lr:2e-05 tokens_per_second_per_gpu:8.074372291564941 peak_memory_active:27.45892906188965 peak_memory_alloc:27.45892906188965 peak_memory_reserved:39.03515625 Step 10 | loss:nan lr:2e-05 tokens_per_second_per_gpu:8.008955955505371 peak_memory_active:27.40485429763794 peak_memory_alloc:27.40485429763794 peak_memory_reserved:39.03515625 Step 11 | loss:nan lr:2e-05 tokens_per_second_per_gpu:9.958040237426758 peak_memory_active:27.411057949066162 peak_memory_alloc:27.411057949066162 peak_memory_reserved:39.03515625 Step 12 | loss:nan lr:2e-05 tokens_per_second_per_gpu:9.930673599243164 peak_memory_active:27.410027980804443 peak_memory_alloc:27.410027980804443 peak_memory_reserved:39.03515625 Step 13 | loss:nan lr:2e-05 tokens_per_second_per_gpu:13.09176254272461 peak_memory_active:27.406914234161377 peak_memory_alloc:27.406914234161377 peak_memory_reserved:39.03515625 Step 14 | loss:nan lr:2e-05 tokens_per_second_per_gpu:1.858479619026184 peak_memory_active:27.490173816680908 peak_memory_alloc:27.490173816680908 peak_memory_reserved:39.03515625 Step 15 | loss:nan lr:2e-05 tokens_per_second_per_gpu:7.682432651519775 peak_memory_active:27.47649908065796 peak_memory_alloc:27.47649908065796 
peak_memory_reserved:39.03515625 Step 16 | loss:nan lr:2e-05 tokens_per_second_per_gpu:11.597535133361816 peak_memory_active:27.44224739074707 peak_memory_alloc:27.44224739074707 peak_memory_reserved:39.03515625 Step 17 | loss:nan lr:2e-05 tokens_per_second_per_gpu:4.915543079376221 peak_memory_active:27.41466236114502 peak_memory_alloc:27.41466236114502 peak_memory_reserved:39.03515625 Step 18 | loss:nan lr:2e-05 tokens_per_second_per_gpu:11.867782592773438 peak_memory_active:27.400733947753906 peak_memory_alloc:27.400733947753906 peak_memory_reserved:39.03515625 Step 19 | loss:nan lr:2e-05 tokens_per_second_per_gpu:11.128471374511719 peak_memory_active:27.4446063041687 peak_memory_alloc:27.4446063041687 peak_memory_reserved:39.03515625 Step 20 | loss:nan lr:2e-05 tokens_per_second_per_gpu:14.666069030761719 peak_memory_active:27.431589126586914 peak_memory_alloc:27.431589126586914 peak_memory_reserved:39.03515625 Step 21 | loss:nan lr:2e-05 tokens_per_second_per_gpu:10.821459770202637 peak_memory_active:27.424962043762207 peak_memory_alloc:27.424962043762207 peak_memory_reserved:39.037109375 Step 22 | loss:nan lr:2e-05 tokens_per_second_per_gpu:8.869871139526367 peak_memory_active:27.419812202453613 peak_memory_alloc:27.419812202453613 peak_memory_reserved:39.037109375 └─ $ tune_logs_tps log_1747916676.txt tps_avg: 8.70872 n= 24 peak_memory_alloc max 27.54 peak_memory_reser max 39.04 ``` [ghstack-poisoned]
Looks good! Left a handful of comments, lmk if anything is unclear. Excited to get grouped_mm landed!
torchtune/modules/moe/indices.py
Outdated
return permuted_indices, m_sizes, m_offsets.to(torch.int32)


# Below is for testing only
Can we put the below code in a separate unit test? E.g. tests/torchtune/modules/moe/test_indices.py?
torchtune/modules/moe/indices.py
Outdated
print(f"tokens_per_expert_group = {tokens_per_expert_group}") | ||
print(f"total_tokens_per_expert = {total_tokens_per_expert}") | ||
print(f"m_sizes = {m_sizes}") | ||
print(f"m_offsets = {m_offsets}") | ||
print(f"permuted_indices = {permuted_indices_gpu[:sum(m_sizes).item()]}") |
Can remove these too and just rely on asserts once we move to a test
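A sketch of what such a test could look like. The file path follows the suggestion above; the call signature and argument values for `generate_permute_indices` are assumptions for illustration, not taken from the diff.

```python
# tests/torchtune/modules/moe/test_indices.py (sketch)
import pytest
import torch

from torchtune.modules.moe.indices import generate_permute_indices


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Triton kernel requires a GPU")
def test_permute_indices_sizes_and_offsets():
    tokens_per_expert_group = torch.tensor([3, 0, 5, 8], dtype=torch.int32, device="cuda")
    # Assumed signature: (tokens_per_expert_group, experts_per_rank, num_ranks, max_len, alignment)
    permuted_indices, m_sizes, m_offsets = generate_permute_indices(
        tokens_per_expert_group, 4, 1, 64, 16
    )
    # Group sizes should be padded up to the alignment, and offsets should be their running sum.
    assert torch.all(m_sizes % 16 == 0)
    torch.testing.assert_close(m_offsets, torch.cumsum(m_sizes, dim=0).to(torch.int32))
```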
import triton.language as tl


__all__ = ["generate_permute_indices"]
nit: we usually put `__all__` in the `__init__.py` file, so if you want to expose this as a public API you should do it there (but personally I also don't mind keeping this as a private API)
Yes, I think it's better to keep it private.
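For reference, the convention mentioned above would look roughly like this if the symbol were made public (illustrative only; the thread concludes it stays private):

```python
# torchtune/modules/moe/__init__.py (sketch)
from torchtune.modules.moe.indices import generate_permute_indices

__all__ = [
    "generate_permute_indices",
]
```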
# reference
def fill_indices_cpu(
Is this just for testing purposes? Or do we actually need it for e.g. CPU training? (Though in that case we probably can't use flex for Llama4 anyways)
I think it is more for reference and testing. It is not used in the llama4 recipes. But I think it's better to keep it, since it is easier to read and makes it clear what the kernel does.
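As a rough illustration of what such a reference computes, here is a simplified single-rank sketch with assumed padding and sentinel conventions; the real `fill_indices_cpu` also handles the multi-rank layout discussed further down in the thread.

```python
def fill_indices_single_rank(tokens_per_expert, alignment=16, pad_value=-1):
    # Tokens are assumed to already be sorted by expert; each expert's run of
    # token indices is copied out and padded up to `alignment` with `pad_value`.
    permuted_indices, m_sizes, m_offsets = [], [], []
    start, total = 0, 0
    for n in tokens_per_expert:
        size = ((n + alignment - 1) // alignment) * alignment
        permuted_indices.extend(range(start, start + n))
        permuted_indices.extend([pad_value] * (size - n))
        m_sizes.append(size)
        total += size
        m_offsets.append(total)
        start += n
    return permuted_indices, m_sizes, m_offsets


# e.g. fill_indices_single_rank([3, 0, 5], alignment=4) pads the groups to sizes [4, 0, 8]
```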
torchtune/modules/moe/experts.py
Outdated
if num_tokens_per_expert is not None:
    # https://github.com/pytorch/pytorch/pull/150374
    # NOTE: torch._grouped_mm requires bf16 dtypes
    # and shapes to be multiple of 8
nit: elsewhere it says multiple of 16
    offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
    # grouped mm between a 2D tensor and a 3D tensor
    assert x.dim() == 2
else:
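For concreteness, the offsets produced by that cumsum partition the rows of the 2D input among experts (a worked example with made-up counts, not from the PR):

```python
import torch

num_tokens_per_expert = torch.tensor([5, 0, 3, 8])
offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
# offsets == tensor([ 5,  5,  8, 16], dtype=torch.int32)
# i.e. x[0:5] goes to expert 0, nothing to expert 1, x[5:8] to expert 2, x[8:16] to expert 3
```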
So this case is for expert choice?
self.experts.num_experts,
1,
This is just because we are fully replicating the experts over all ranks, right?
From: | rank 0 | rank 1 |
To:   | E0 | E1 | E2 | E3 | E0 | E1 | E2 | E3 |
Noob q on this diagram: does this mean that we actually have the same expert on multiple ranks (i.e. there are only 4 experts total)? I ask because the arg `experts_per_rank` seems to imply that we actually have different experts on different ranks (I know it's moot for the implementation in its current form anyways so not a huge deal).
In our case each rank runs all the experts, but all mms happen in parallel.
So experts_per_rank is always 16 (all experts).
So we are not bucketing experts by rank. I think this logic is similar to what TorchRec does with embeddings: basically each expert works as a lookup table.
@@ -0,0 +1,353 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
We should also add an acknowledgment to the original torchtitan code somewhere. @lessw2020 lmk if any preference on the best way to do this
Enablement copied from torchtitan. There is only one difference: the output of grouped_mm is uninitialized after offsets[-1] (it is the result of at::empty)

```
out = _grouped_mm(h, w2, offs=offsets)
out[offsets[-1] :].zero_()
```

scatter_add does not have an ignore index, and the default 0 entries in token_indices would cause that data to be added into index 0. To avoid copying uninitialized data, zero_() the output past offsets[-1] after the grouped_mm for now.

grouped_mm in eager: tps_avg 3.2182 -> 4.44
grouped_mm in compile: tps_avg 4.33463 -> 8.70872

Observing NaN loss under compile :( with and without grouped_mm - will be debugging this separately.

Benchmark summary (llama4_17Bx16E full finetune; tps_avg and peak-memory maxima taken from the per-step logs with tune_logs_tps):

| Run | Log | tps_avg | n | peak_memory_alloc max | peak_memory_reserved max | Loss |
|---|---|---|---|---|---|---|
| eager, no grouped_mm (baseline) | log_1747910831.txt | 3.2182 | 10 | 28.97 | 40.25 | finite throughout |
| eager + grouped_mm | log_1747915295.txt | 4.44474 | 13 | 27.56 | 39.02 | finite throughout |
| compile, no grouped_mm | log_1747915995.txt | 4.33463 | 12 | 28.95 | 40.29 | NaN from step 11 |
| compile + grouped_mm | log_1747916676.txt | 8.70872 | 24 | 27.54 | 39.04 | NaN from step 6 |

[ghstack-poisoned]
Stack from ghstack (oldest at bottom):
Enablement copied from torchtitan.
There is only one difference: the output of grouped_mm is uninitialized after offsets[-1] (it is the result of at::empty).
scatter_add does not have an ignore index, and the default 0 entries in token_indices would cause that data to be added into index 0.
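A tiny standalone example of that scatter_add behavior (values made up for illustration): the "padding" row is indexed to 0, and there is no way to tell scatter_add_ to skip it:

```python
import torch

out = torch.zeros(4, 2)
src = torch.tensor([[1.0, 1.0], [2.0, 2.0], [9.0, 9.0]])  # last row stands in for a padded slot
idx = torch.tensor([[1, 1], [3, 3], [0, 0]])              # padded slot's index defaults to 0
out.scatter_add_(0, idx, src)
print(out[0])  # tensor([9., 9.]) -- the padded row was added into index 0
```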
To avoid copying uninitialized data, zero_() the output past offsets[-1] after the grouped_mm for now.
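For context, here is a minimal sketch of the expert forward being described (a sketch under assumptions, not the torchtune module code: tensor names and shapes are illustrative, the gated projection is collapsed into a single w1, and `torch._grouped_mm` is assumed to be the private grouped-GEMM op available in recent PyTorch builds):

```python
import torch
import torch.nn.functional as F

def experts_forward(
    x: torch.Tensor,                      # [padded_num_tokens, dim], rows sorted by expert
    w1: torch.Tensor,                     # [num_experts, dim, hidden_dim]
    w2: torch.Tensor,                     # [num_experts, hidden_dim, dim]
    num_tokens_per_expert: torch.Tensor,  # [num_experts]
) -> torch.Tensor:
    # int32 offsets marking where each expert's block of rows ends.
    offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)

    # One grouped GEMM per projection instead of a Python loop over experts.
    h = F.silu(torch._grouped_mm(x, w1, offs=offsets))
    out = torch._grouped_mm(h, w2, offs=offsets)

    # Rows past offsets[-1] come from an at::empty allocation and hold
    # uninitialized memory; zero them so the combine step below adds zeros.
    out[offsets[-1]:].zero_()
    return out

def combine(out: torch.Tensor, token_indices: torch.Tensor, num_tokens: int) -> torch.Tensor:
    # token_indices (int64, [padded_num_tokens]) maps each routed row back to its
    # destination token; padded rows default to 0, and scatter_add_ has no ignore
    # index, so whatever those rows contain is added into token 0's output.
    dim = out.shape[-1]
    combined = torch.zeros(num_tokens, dim, dtype=out.dtype, device=out.device)
    index = token_indices.reshape(-1, 1).expand(-1, dim)
    return combined.scatter_add_(0, index, out)
```

With the zero_() in place, the padded rows that land on token 0 contribute zeros instead of garbage.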
grouped_mm in eager: tps_avg 3.2182 -> 4.44
grouped_mm in compile: tps_avg 4.33463 -> 8.70872
Observing NaN loss under compile :( both with and without grouped_mm - will debug this separately.
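As a generic first step for that debugging (not part of this PR; plain PyTorch, helper name made up), one can flag the first step at which a parameter or gradient goes non-finite:

```python
import torch

def report_non_finite(model: torch.nn.Module, step: int) -> None:
    # Call right after backward(): prints any parameter whose value or gradient
    # has gone NaN/Inf, which narrows the problem down to a module before the
    # loss itself turns NaN.
    for name, param in model.named_parameters():
        if param.grad is not None and not torch.isfinite(param.grad).all():
            print(f"step {step}: non-finite grad in {name}")
        if not torch.isfinite(param).all():
            print(f"step {step}: non-finite param in {name}")
```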