
Commit 1105cc6

[Full DTensor] Initial skeleton for full_dtensor mode
This PR introduces an initial prototype and skeleton for fully DTensor-based training. The current codebase builds on SimpleFSDP, but we anticipate developing our own Reparameterization to better serve our specific use case. There are several reasons why SimpleFSDP's Reparameterization is insufficient; for instance, the parallelize_buffers() implementation in this PR will not function correctly when additional parallelization strategies are applied. Despite these limitations, this PR provides a starting point for experimenting with a full DTensor trainer.

Accuracy verification: HSDP SimpleFSDP vs. FSDP2

```
python3 scripts/loss_compare.py . . \
    --baseline-options='--activation_checkpoint.mode="none" --parallelism.data_parallel_replicate_degree=2' \
    --test-options='--model.name full_dtensor.llama3 --activation_checkpoint.mode="none" --parallelism.data_parallel_replicate_degree=2' \
    --test-train-file=torchtitan.experiments.full_dtensor.train \
    --steps=10 --assert-equal --no-seed-checkpoint
```

```
[LOSS_COMPARE]
[LOSS_COMPARE] Asserting losses are equal...
[LOSS_COMPARE] Baseline log: /tmp/baseline_training.log
[LOSS_COMPARE] Test log: /tmp/test_training.log
[LOSS_COMPARE] Extracted 100 steps from baseline log
[LOSS_COMPARE] Extracted 100 steps from test log
test_losses_equal (__main__.assert_losses_equal.<locals>.LossEqualityTest.test_losses_equal) ... ok

----------------------------------------------------------------------
Ran 1 test in 0.000s

OK
```

Note that `--no-seed-checkpoint` is used because we get an accuracy mismatch when a seed checkpoint is used.

ghstack-source-id: c177628
Pull-Request: #2049
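For intuition only, here is a minimal sketch of what "fully DTensor-based" parameters mean: every parameter lives as a DTensor sharded over the data-parallel mesh, so collectives are expressed through DTensor placements rather than FSDP hooks. This is not the Reparameterization planned above; the helper name `shard_params_as_dtensors` and the dim-0 `Shard(0)` placement are illustrative assumptions using only public DTensor APIs.

```
import torch.nn as nn
from torch.distributed.tensor import DeviceMesh, Shard, distribute_tensor


def shard_params_as_dtensors(model: nn.Module, dp_mesh: DeviceMesh) -> nn.Module:
    """Illustrative sketch: wrap every parameter as a dim-0 sharded DTensor so the
    whole step (forward, backward, optimizer) operates on DTensors."""
    for module in model.modules():
        for name, param in list(module.named_parameters(recurse=False)):
            sharded = distribute_tensor(param.detach(), dp_mesh, [Shard(0)])
            module.register_parameter(name, nn.Parameter(sharded))
    return model
```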
1 parent bb2ab1a commit 1105cc6

File tree

7 files changed: +476 -0 lines changed

.github/workflows/integration_test_8gpu_full_dtensor.yaml

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
name: Full DTensor 8 GPU Integration Tests

on:
  push:
    branches: [ main ]
    paths:
      - 'torchtitan/experiments/full_dtensor/**'
      - '.github/workflows/integration_test_8gpu_full_dtensor.yaml'
  pull_request:
    paths:
      - 'torchtitan/experiments/full_dtensor/**'
      - '.github/workflows/integration_test_8gpu_full_dtensor.yaml'
  schedule:
    # Runs every 12 hours
    - cron: '0 */12 * * *'

concurrency:
  group: unit-test${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }}
  cancel-in-progress: true

defaults:
  run:
    shell: bash -l -eo pipefail {0}

jobs:
  build-test:
    uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
    with:
      runner: linux.g5.48xlarge.nvidia.gpu
      gpu-arch-type: cuda
      gpu-arch-version: "12.6"
      # This image is faster to clone than the default, but it lacks CC needed by triton
      # (1m25s vs 2m37s).
      docker-image: torchtitan-ubuntu-20.04-clang12
      repository: pytorch/torchtitan
      upload-artifact: outputs
      script: |
        set -eux

        # The generic Linux job chooses to use base env, not the one setup by the image
        CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
        conda activate "${CONDA_ENV}"

        # Log CUDA driver version for debugging.
        DRIVER_VERSION=$(nvidia-smi --query-gpu=driver_version --format=csv,noheader | head -n 1 || true)
        echo "CUDA driver version: ${DRIVER_VERSION}"

        pip config --user set global.progress_bar off

        python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126

        mkdir artifacts-to-be-uploaded
        TRAIN_FILE=torchtitan.experiments.full_dtensor.train python -m torchtitan.experiments.full_dtensor.tests.integration_tests artifacts-to-be-uploaded --ngpu 8

torchtitan/experiments/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@
     "gpt_oss",
     "simple_fsdp.llama3",
     "simple_fsdp.deepseek_v3",
+    "full_dtensor.llama3",
     "vlm",
     "compiler_toolkit.deepseek_v3",
     "compiler_toolkit.llama3",
torchtitan/experiments/full_dtensor/llama3/__init__.py

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.distributed.pipeline_parallel import pipeline_llm
from torchtitan.experiments.simple_fsdp.llama3.model import SimpleFSDPTransformer
from torchtitan.hf_datasets.text_datasets import build_text_dataloader
from torchtitan.models.llama3 import llama3_args
from torchtitan.protocols.train_spec import TrainSpec

from .parallelize import parallelize_llama


def get_train_spec() -> TrainSpec:
    return TrainSpec(
        model_cls=SimpleFSDPTransformer,
        model_args=llama3_args,
        parallelize_fn=parallelize_llama,
        pipelining_fn=pipeline_llm,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_text_dataloader,
        build_tokenizer_fn=build_hf_tokenizer,
        build_loss_fn=build_cross_entropy_loss,
    )
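The spec above mirrors the SimpleFSDP Llama 3 spec, swapping in the full-DTensor `parallelize_llama`. As a rough usage sketch (not part of this diff): the import path follows the registered name `full_dtensor.llama3`, and the `"debugmodel"` flavor key is an assumption.

```
# Hypothetical sketch only; the module path and flavor key are assumptions.
from torchtitan.experiments.full_dtensor.llama3 import get_train_spec

spec = get_train_spec()
model = spec.model_cls(spec.model_args["debugmodel"])  # assumed flavor key
# spec.parallelize_fn is parallelize_llama, which shards the model as DTensors
# and (optionally) compiles it via the compiler toolkit.
```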
torchtitan/experiments/full_dtensor/llama3/parallelize.py

Lines changed: 264 additions & 0 deletions
@@ -0,0 +1,264 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import functools
from typing import Any, Callable

import torch
import torch.nn as nn
from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard

from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import ParallelDims

from torchtitan.experiments.compiler_toolkit.graph_utils import (
    CompiledModule,
    joint_graph_builder,
    make_compiler_with_passes,
)
from torchtitan.experiments.simple_fsdp.reshard_after_forward import (
    annotate_fsdp_all_gather,
)
from torchtitan.experiments.simple_fsdp.simple_fsdp import (
    data_parallel,
    MixedPrecisionPolicy,
)
from torchtitan.tools.logging import logger


def _get_dp_mesh(parallel_dims: ParallelDims) -> DeviceMesh:
    if parallel_dims.dp_replicate_enabled:
        if parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled:
            dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
        else:
            dp_mesh_dim_names = ("dp_replicate",)
    else:
        dp_mesh_dim_names = ("dp_shard_cp",)

    return parallel_dims.world_mesh[tuple(dp_mesh_dim_names)]


def _get_spmd_mesh(parallel_dims: ParallelDims) -> DeviceMesh:
    return _get_dp_mesh(parallel_dims)


def apply_dp(
    model: nn.Module,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
) -> nn.Module:
    if parallel_dims.dp_replicate_enabled:
        if parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled:
            dp_mode = "hybrid_shard"
        else:
            dp_mode = "replicate"
    else:
        dp_mode = "fully_shard"

    mp_policy = MixedPrecisionPolicy(
        param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
        reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
    )

    model = data_parallel(
        model,
        _get_dp_mesh(parallel_dims),
        mode=dp_mode,
        mp_policy=mp_policy,
        full_dtensor=True,
    )
    logger.info(
        "Applied Data Parallel (simple_fsdp) (dp mode=%s) to the model", dp_mode
    )
    return model


def parallelize_llama(
    model: nn.Module,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
) -> nn.Module:
    if parallel_dims.cp_enabled:
        # TODO: SDPA + CP enablement:
        # Dependency: https://github.com/pytorch/pytorch/pull/167381 (sharding rule fix)
        # Goal: Enable Shard(2) -> Replicate() transition on the CP mesh placement.
        #
        # Implementation options for handling the required allgather:
        # 1. Transform into ring attention (requires converting the current
        #    implementation to an async TP-like operation)
        # 2. Retain explicit allgather (approach used in Llama 4)

        # TODO: FlexAttention + CP enablement:
        # Need to resolve DTensor + FlexAttention compatibility issues.

        raise NotImplementedError("CP is not implemented yet.")

    if parallel_dims.tp_enabled:
        # TODO: TP parallelization strategy - key architectural decision:
        #
        # Option 1: Parallelize parameters directly (current design approach)
        # - Apply the TP dimension immediately at this point
        # - Requires _StridedShard for implementation
        #
        # Option 2: Record the placement and apply full placements later
        # - Record the TP dimension placement now, apply the full placement later
        #   together with the DP dimension
        # - No need to use _StridedShard; we can just use Shard()
        #
        # It is most likely that we will go with option 2, since we are going to use
        # parameterization to handle the full parameter transformation, which makes
        # option 2 more natural.
        raise NotImplementedError("TP is not implemented yet.")

    if job_config.activation_checkpoint.mode != "none":
        # TODO: Graph based AC.
        raise NotImplementedError("AC is not implemented yet.")

    # TODO: CP integration challenge:
    #
    # Problem:
    # When CP is enabled, the mesh structure becomes ["dp_replicate", "dp_shard", "cp"]
    # to maintain sequence sharding in DTensor. However, naively applying data_parallel
    # may trigger two separate allgather operations because DTensor.redistribute cannot
    # recognize that the two mesh dimensions can be flattened into a single allgather.
    #
    # Potential solution using SimpleFSDP:
    # 1. Transform mesh: ["dp_replicate", "dp_shard", "cp"] -> ["dp_replicate", "dp_shard_cp"]
    #    via to_local() and from_local()
    # 2. Redistribute placement on the ["dp_shard_cp"] dimension
    # 3. Transform mesh back: ["dp_replicate", "dp_shard_cp"] -> ["dp_replicate", "dp_shard", "cp"]
    #    via to_local() and from_local()
    #
    # Note: This solution leaves the dp_shard process group unused (we can
    # initialize it with a fake backend).
    #
    # Note: We may be able to implement this solution with parameterization directly.

    # Keep cp_enabled here to remind us that cp_enabled=True requires data_parallel
    if (
        parallel_dims.dp_replicate_enabled
        or parallel_dims.dp_shard_enabled
        or parallel_dims.cp_enabled
    ):
        model = apply_dp(model, parallel_dims, job_config)

    # Apply compilation after SPMD parallelization is complete. This differs from
    # eager mode parallelization where compilation occurs earlier.
    return apply_compile(model, parallel_dims, job_config)


def parallelize_buffers(
    model: nn.Module,
    parallel_dims: ParallelDims,
) -> nn.Module:
    # Buffer-to-mesh mapping in multi-SPMD scenarios:
    #
    # When buffers are used with different SPMD meshes (e.g., dense vs. sparse meshes),
    # we will need an explicit mapping to associate each buffer with its corresponding
    # mesh. This indicates that the current implementation is not general enough to
    # support nD meshes.
    #
    # The solution is to reparameterize the buffers together with the parameters
    # within a module.
    spmd_mesh = _get_spmd_mesh(parallel_dims)
    placements = tuple(Replicate() for _ in range(spmd_mesh.ndim))
    for m in model.modules():
        buffers = {
            name: DTensor.from_local(b, spmd_mesh, placements)
            for name, b in m.named_buffers(recurse=False)
        }
        for name, b in buffers.items():
            setattr(m, name, b)

    return model


def build_parallelize_inputs_fn(
    parallel_dims: ParallelDims,
) -> Callable[[torch.Tensor, torch.Tensor], tuple[DTensor, DTensor]]:
    spmd_mesh = _get_spmd_mesh(parallel_dims)

    # TODO: We need to make this more general to support nD meshes. But we can do
    # this after the DeviceMesh revamp PR lands.
    spmd_placements = []
    if parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled:
        spmd_placements.append(Shard(0))
    if parallel_dims.dp_replicate_enabled:
        spmd_placements.append(Shard(0))

    def parallelize_inputs(
        inputs: torch.Tensor, labels: torch.Tensor
    ) -> tuple[DTensor, DTensor]:
        inputs = DTensor.from_local(inputs, spmd_mesh, spmd_placements)
        labels = DTensor.from_local(labels, spmd_mesh, spmd_placements)
        return inputs, labels

    return parallelize_inputs


def joint_custom_pass_builder(
    parallel_dims: ParallelDims, job_config: JobConfig
) -> Callable:
    match job_config.parallelism.fsdp_reshard_after_forward:
        case "always":
            fsdp_reshard_after_forward = True
        case "never":
            fsdp_reshard_after_forward = False
        case "default":
            # For PP, by default do not reshard after forward to avoid per-microbatch
            # all-gathers, which can be expensive and non-overlapped.
            fsdp_reshard_after_forward = not parallel_dims.pp_enabled
        case _:
            raise ValueError(
                "Invalid fsdp_reshard_after_forward_policy: "
                f"{job_config.parallelism.fsdp_reshard_after_forward}."
            )

    def joint_ac_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        gm = annotate_fsdp_all_gather(gm, fsdp_reshard_after_forward)
        gm.recompile()
        return gm

    def joint_custom_pass(joint_with_descriptors) -> None:
        # TODO: Is this safe? Or should we use update_joint_with_descriptors from auto_parallel?
        joint_with_descriptors.graph_module = joint_ac_pass(
            joint_with_descriptors.graph_module
        )

    return joint_custom_pass


def apply_compile(
    model: nn.Module, parallel_dims: ParallelDims, job_config: JobConfig
) -> nn.Module:
    # TODO: This API just implements the compiler toolkit.
    # We should also add torch.compile() support.

    if not (job_config.compile.enable and "model" in job_config.compile.passes):
        return model

    compiler_passes = []
    # Create compilers with the specified passes (defaults to no passes)
    fw_compiler, bw_compiler = make_compiler_with_passes(
        compiler_passes, dump_folder=job_config.job.dump_folder
    )

    # Create a custom joint_graph_builder with llama-specific compilers and validation
    llama_joint_graph_builder = functools.partial(
        joint_graph_builder,
        fw_compiler=fw_compiler,
        bw_compiler=bw_compiler,
        joint_custom_pass=joint_custom_pass_builder(parallel_dims, job_config),
        dump_folder=job_config.job.dump_folder,
    )

    # The full DTensor trainer will convert the inputs to DTensor, so we don't
    # need CompiledModule to do it.
    def dummy_parallelize_inputs(
        mesh: DeviceMesh, args: tuple[Any, ...], kwargs: dict[str, Any]
    ) -> tuple[tuple[Any, ...], dict[str, Any]]:
        return args, kwargs

    return CompiledModule(
        model, parallel_dims, llama_joint_graph_builder, dummy_parallelize_inputs
    )
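Note that `parallelize_buffers` and `build_parallelize_inputs_fn` are not invoked in this file; they are presumably wired up by `torchtitan.experiments.full_dtensor.train`, which is referenced in the workflow above but not shown in this view. A hedged sketch of that wiring, using only names defined in this file and assuming the module path torchtitan/experiments/full_dtensor/llama3/parallelize.py; the helper name is hypothetical:

```
# Hypothetical wiring sketch (the actual train.py in this commit is not shown here).
from torchtitan.experiments.full_dtensor.llama3.parallelize import (
    build_parallelize_inputs_fn,
    parallelize_buffers,
)


def prepare_full_dtensor_step(model, parallel_dims, inputs, labels):
    # Buffers become Replicate() DTensors on the SPMD (data-parallel) mesh,
    # matching the parameters already converted by data_parallel(..., full_dtensor=True).
    model = parallelize_buffers(model, parallel_dims)
    # Inputs/labels become Shard(0) DTensors on the same mesh, so the entire
    # forward/backward runs on DTensors end to end.
    to_dtensor = build_parallelize_inputs_fn(parallel_dims)
    inputs, labels = to_dtensor(inputs, labels)
    return model, inputs, labels
```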
Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
