Skip to content

Conversation

@xmfan
Copy link
Member

@xmfan xmfan commented Nov 18, 2025

Experiments like SimpleFSDP/Compiler Toolkit/Autoparallel are all being developed at the same time, and SimpleFSDP/Compiler Toolkit both run into issues with PP that requires the PP utilities from Autoparallel. We want to land the Autoparallel experiment into main to facilitate that sharing.

wconstab and others added 30 commits July 11, 2025 12:46
TODO
- try converting model params into fake tensors
- figure out init fn
- integrate torchtitan configs for DP/TP to control autop

Hack an init_fn for llama3 and observe loss decreasing with autoparallel

"""
[rank0]:[titan] 2025-06-16 16:24:16,593 - root - INFO - Training starts at step 1.
[rank0]:[titan] 2025-06-16 16:24:23,544 - root - INFO - step:  1  loss:  8.1880  memory:  4.88GiB(6.16%)  tps: 28
[rank0]:[titan] 2025-06-16 16:24:23,545 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-06-16 16:24:23,842 - root - INFO - step:  2  loss:  8.1610  memory:  4.90GiB(6.20%)  tps: 13,785
[rank0]:[titan] 2025-06-16 16:24:24,135 - root - INFO - step:  3  loss:  8.0871  memory:  4.90GiB(6.20%)  tps: 14,006
[rank0]:[titan] 2025-06-16 16:24:24,433 - root - INFO - step:  4  loss:  7.9516  memory:  4.90GiB(6.20%)  tps: 13,770
[rank0]:[titan] 2025-06-16 16:24:24,727 - root - INFO - step:  5  loss:  7.8552  memory:  4.90GiB(6.20%)  tps: 13,959
[rank0]:[titan] 2025-06-16 16:24:25,023 - root - INFO - step:  6  loss:  7.7732  memory:  4.90GiB(6.20%)  tps: 13,859
[rank0]:[titan] 2025-06-16 16:24:25,324 - root - INFO - step:  7  loss:  7.6987  memory:  4.90GiB(6.20%)  tps: 13,664
[rank0]:[titan] 2025-06-16 16:24:25,617 - root - INFO - step:  8  loss:  7.6779  memory:  4.90GiB(6.20%)  tps: 13,985
[rank0]:[titan] 2025-06-16 16:24:25,911 - root - INFO - step:  9  loss:  7.6043  memory:  4.90GiB(6.20%)  tps: 13,962
[rank0]:[titan] 2025-06-16 16:24:26,207 - root - INFO - step: 10  loss:  7.5778  memory:  4.90GiB(6.20%)  tps: 13,891
"""

Adopt new autoparallel API with meta-init model

Allows reverting a lot of the hacks in the original integration that
were caused by not creating a model obj in the train.py due to passing a
model_fn builder to autop.

Fixes to align with latest autoparallel

Add inductor config knobs for comms optimizations to torchtitan

Make inductor always run compile passes

basically, this is an annoying workaround for debugging iteratively.

1- you run the model, it compiles, but something weird happens
2- you enable some logging or tlparse, rerun. but inductor decides not
to run your pass anymore, its results are cached.

since (2) has confused me horribly on more than one occasion, i just
disable caching for now

Drop hacky llama3_init_fn and use autop init_weights feature

Relying on meta-pytorch/autoparallel#20, this
lets us automatically apply a user's init_weights fn to the autoparallel
model.

Verified this works with

`CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4 --training.dataset c4`

```
[rank0]:[titan] 2025-07-02 16:18:02,007 - root - INFO - Training starts at step 1.
[rank0]:[titan] 2025-07-02 16:18:08,224 - root - INFO - step:  1  loss:  8.1848  memory:  1.09GiB(1.14%)  tps: 77  tflops: 0.01  mfu: 0.00%
[rank0]:[titan] 2025-07-02 16:18:08,224 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-07-02 16:18:08,310 - root - INFO - step:  2  loss:  8.1619  memory:  1.15GiB(1.21%)  tps: 48,138  tflops: 3.46  mfu: 0.35
%
[rank0]:[titan] 2025-07-02 16:18:08,356 - root - INFO - step:  3  loss:  8.1140  memory:  1.15GiB(1.21%)  tps: 88,440  tflops: 6.36  mfu: 0.64
%
[rank0]:[titan] 2025-07-02 16:18:08,406 - root - INFO - step:  4  loss:  8.0099  memory:  1.15GiB(1.21%)  tps: 82,626  tflops: 5.94  mfu: 0.60
%
[rank0]:[titan] 2025-07-02 16:18:08,457 - root - INFO - step:  5  loss:  7.8928  memory:  1.15GiB(1.21%)  tps: 81,594  tflops: 5.87  mfu: 0.59
%
[rank0]:[titan] 2025-07-02 16:18:08,508 - root - INFO - step:  6  loss:  7.7758  memory:  1.15GiB(1.21%)  tps: 79,607  tflops: 5.72  mfu: 0.58
%
[rank0]:[titan] 2025-07-02 16:18:08,559 - root - INFO - step:  7  loss:  7.6221  memory:  1.15GiB(1.21%)  tps: 81,448  tflops: 5.86  mfu: 0.59
%
[rank0]:[titan] 2025-07-02 16:18:08,611 - root - INFO - step:  8  loss:  7.5578  memory:  1.15GiB(1.21%)  tps: 79,732  tflops: 5.73  mfu: 0.58
%
[rank0]:[titan] 2025-07-02 16:18:08,659 - root - INFO - step:  9  loss:  7.3851  memory:  1.15GiB(1.21%)  tps: 85,655  tflops: 6.16  mfu: 0.62
%
[rank0]:[titan] 2025-07-02 16:18:08,709 - root - INFO - step: 10  loss:  7.3361  memory:  1.15GiB(1.21%)  tps: 81,855  tflops: 5.89  mfu: 0.60
%
[rank0]:[titan] 2025-07-02 16:18:08,709 - root - INFO - Sleeping 2 seconds for other ranks to complete
```

fix lint
lets existing torchtitan knobs which govern DP/TP mesh creation and mesh
size influence the sharding constraints of autoparallel, allowing it to
support these different sharding configurations.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
- fix passing of "none" (not None) to control bucketing passes
prints an (internal, vpn) only link for each profile trace file that's
saved to manifold.  Just search for 'trace' in your job logs on mast,
and click one of the rank links.

e.g.

[trainer37|5]:[titan] 2025-08-07 14:21:01,227 - root - INFO - Finished
dumping profiler traces in 5.22 seconds:
https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/torchtrain_datasets/tree/outputs/torchtitan-64-whc-jv2j4mp/profile_trace/iteration_20/rank37_trace.json
This PR makes bucket sizes for all-gather and reduce-scatter to be of
the same size for 1d FSDP.
IMO we should just add the loss in the model and let autoparallel
parallelize it for us. But for now, let's follow how the other models
are implemented
just add `--experimental.enable_simplefsdp_passes` and do not try to
combine it with other `bucket_*` or `reorder_*` options.
This command should now run

`CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml"
./run_train.sh --model.name deepseekv3_auto_parallel`

However it doesn't actually do anything with autoparallel yet. Next step
is to attach local_map to the model so that autoparallel can run.
Validated debugmodel llama3 works, but ds3 crashes becuase of
`build_optimizers_with_moe_load_balancing` doing stuff that traverses
the original model structure, only now its an AutoParallelModule which
isn't compatible, we'll have to disable this optimization for now and
think about what to do.

Note: paths have changed, update your run commands:

`CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml
./run_train.sh --model.name llama3_auto_parallel
--parallelism.tensor_parallel_degree 4`

Failing (ds3):
`CONFIG_FILE=././torchtitan/models/deepseek_v3/train_configs/debug_model.toml
./run_train.sh --model.name deepseekv3_auto_parallel`
as titled, this pr adds entry to simplefsdp's autobucketing pass in
autoparallel. original code is in:
pytorch/pytorch#160282

The main code for autobucketing pass will be added to autoparallel repo.
needs to merge in lock step with
meta-pytorch/autoparallel#233
@xmfan xmfan marked this pull request as ready for review November 21, 2025 23:44
@xmfan xmfan requested a review from wwwjn as a code owner November 21, 2025 23:44
Comment on lines 17 to 34
custom_import: str = ""
"""
This option enables the importation of external modules.
Currently, it only supports dotted import modules (e.g., some_package.model_x).
It is the user's responsibility to ensure that the specified path can be
successfully imported. One method to achieve this, you can place your module
inside the ``torchtitan/torchtitan`` folder and execute ``pip install -e .`` to
make it available for import.
"""

custom_args_module: str = ""
"""
DEPRECATED (moved to Job.custom_config_module). Will be removed soon.
This option allows users to extend TorchTitan's existing JobConfig by extending
a user defined JobConfig dataclass. Similar to ``--experimental.custom_import``, the user
needs to ensure that the path can be imported.
"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The first two are available in base job_config.py.

ft_manager=ft_manager,
)

def should_manual_allreduce(tokens_per_expert_by_layer):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is this trying to do? also do you really need a function for it?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tokens_per_expert_by_layer is usually a plain tensor, so you need to call dist.AR on it for stats. but if it's a dtensor, we just need to redistribute

@xmfan xmfan changed the title Autoparallel as an experiment into main Autoparallel as an experiment in main Nov 22, 2025
Comment on lines 385 to 387
tokens_per_expert_by_layer = tokens_per_expert_by_layer.redistribute(
placements=[Replicate()] * dp_cp_mesh.ndim
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure how it works in autop, what is the dimension of the mesh that tokens_per_expert_by_layer is on?

Regardless of what mesh it is on, you can just call tokens_per_expert_by_layer = tokens_per_expert_by_layer.full_tensor() as syntactic sugar to make it Replicate on every dimension of the mesh.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh i need to use the ap mesh. for ap, all params and buffers are dtensors, i need to keep this as a dtensor for the add_ below

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there's no guarantee that tokens_per_expert_by_layer's mesh is the same as dp_cp_mesh, so we should do something like

                tokens_per_expert_by_layer = tokens_per_expert_by_layer.redistribute(
                    placements=[Replicate()] * tokens_per_expert_by_layer.mesh.ndim
                )

@sanketpurandare sanketpurandare self-requested a review November 22, 2025 06:33
| [moe_symm_mem_kernels](./moe_symm_mem_kernels/) | TBA | [@kwen2501](https://github.com/kwen2501) |
| [gpt_oss](./gpt_oss/) | TBA | [@jianiw](https://github.com/jianiw) |
| [compiler_toolkit](./compiler_toolkit/) | [![Compiler Toolkit 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_compiler_toolkit.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_compiler_toolkit.yaml?query=branch%3Amain) | [@SherlockNoMad](https://github.com/SherlockNoMad) [@yiming0416](https://github.com/yiming0416) |
| [transformers_backend](./transformers_backend/) | [![Transformers backend 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_transformers_backend.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_transformers_backend.yaml?query=branch%3Amain) | [@3outeille](https://github.com/3outeille) |
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this has been renamed to the item below, may need to rebase / remove

if "flex_attn" in config:
continue

use_flex_attn = (default_args.use_flex_attn,)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this field has been renamed to attn_type

Copy link
Contributor

@fegin fegin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you plan to a an integration test for AP in this PR or it will come in anther PR?

"vlm",
"compiler_toolkit.deepseek_v3",
"compiler_toolkit.llama3",
"transformers_backend",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to be a merge conflict? transformers_modeling_backend is already declared below.

@xmfan
Copy link
Member Author

xmfan commented Nov 24, 2025

Do you plan to a an integration test for AP in this PR or it will come in anther PR?

Fast follow

Copy link
Contributor

@tianyu-l tianyu-l left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@xmfan xmfan merged commit 7e10d60 into main Nov 25, 2025
9 checks passed
@tianyu-l tianyu-l deleted the autoparallel branch November 25, 2025 04:23
@wconstab
Copy link
Contributor

Nice!

Re deleting the autop branch, was the profiler change for manifold url landed on main or do we need to restore a copy of that to the autoparallel branch for now?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

9 participants