|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import functools |
| 8 | +from typing import Any, Callable |
| 9 | + |
| 10 | +import torch |
| 11 | +import torch.nn as nn |
| 12 | +from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard |
| 13 | + |
| 14 | +from torchtitan.config import JobConfig, TORCH_DTYPE_MAP |
| 15 | +from torchtitan.distributed import ParallelDims |
| 16 | + |
| 17 | +from torchtitan.experiments.compiler_toolkit.graph_utils import ( |
| 18 | + CompiledModule, |
| 19 | + joint_graph_builder, |
| 20 | + make_compiler_with_passes, |
| 21 | +) |
| 22 | +from torchtitan.experiments.simple_fsdp.reshard_after_forward import ( |
| 23 | + annotate_fsdp_all_gather, |
| 24 | +) |
| 25 | +from torchtitan.experiments.simple_fsdp.simple_fsdp import ( |
| 26 | + data_parallel, |
| 27 | + MixedPrecisionPolicy, |
| 28 | +) |
| 29 | +from torchtitan.tools.logging import logger |
| 30 | + |
| 31 | + |
| 32 | +def _get_dp_mesh(parallel_dims: ParallelDims) -> DeviceMesh: |
| 33 | + if parallel_dims.dp_replicate_enabled: |
| 34 | + if parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled: |
| 35 | + dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") |
| 36 | + else: |
| 37 | + dp_mesh_dim_names = ("dp_replicate",) |
| 38 | + else: |
| 39 | + dp_mesh_dim_names = ("dp_shard_cp",) |
| 40 | + |
| 41 | + return parallel_dims.world_mesh[tuple(dp_mesh_dim_names)] |
| 42 | + |
| 43 | + |
| 44 | +def _get_spmd_mesh(parallel_dims: ParallelDims) -> DeviceMesh: |
| 45 | + return _get_dp_mesh(parallel_dims) |
| 46 | + |
| 47 | + |
| 48 | +def apply_dp( |
| 49 | + model: nn.Module, |
| 50 | + parallel_dims: ParallelDims, |
| 51 | + job_config: JobConfig, |
| 52 | +) -> nn.Module: |
| 53 | + if parallel_dims.dp_replicate_enabled: |
| 54 | + if parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled: |
| 55 | + dp_mode = "hybrid_shard" |
| 56 | + else: |
| 57 | + dp_mode = "replicate" |
| 58 | + else: |
| 59 | + dp_mode = "fully_shard" |
| 60 | + |
| 61 | + mp_policy = MixedPrecisionPolicy( |
| 62 | + param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], |
| 63 | + reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], |
| 64 | + ) |
| 65 | + |
| 66 | + model = data_parallel( |
| 67 | + model, |
| 68 | + _get_dp_mesh(parallel_dims), |
| 69 | + mode=dp_mode, |
| 70 | + mp_policy=mp_policy, |
| 71 | + full_dtensor=True, |
| 72 | + ) |
| 73 | + logger.info( |
| 74 | + "Applied Data Parallel (simple_fsdp) (dp mode=%s) to the model", dp_mode |
| 75 | + ) |
| 76 | + return model |
| 77 | + |
| 78 | + |
| 79 | +def parallelize_llama( |
| 80 | + model: nn.Module, |
| 81 | + parallel_dims: ParallelDims, |
| 82 | + job_config: JobConfig, |
| 83 | +) -> nn.Module: |
| 84 | + if parallel_dims.cp_enabled: |
| 85 | + # TODO: SDPA + CP enablement: |
| 86 | + # Dependency: https://github.com/pytorch/pytorch/pull/167381 (sharding rule fix) |
| 87 | + # Goal: Enable Shard(2) -> Replicate() transition on the CP mesh placement. |
| 88 | + # |
| 89 | + # Implementation options for handling the required allgather: |
| 90 | + # 1. Transform into ring attention (requires converting current implementation |
| 91 | + # to an async TP-like operation) |
| 92 | + # 2. Retain explicit allgather (approach used in Llama 4) |
| 93 | + |
| 94 | + # TODO: FlexAttention + CP enablement: |
| 95 | + # Need to resolve DTensor + FlexAttention compatibility issues. |
| 96 | + |
| 97 | + raise NotImplementedError("CP is not implemented yet.") |
| 98 | + |
| 99 | + if parallel_dims.tp_enabled: |
| 100 | + # TODO: TP parallelization strategy - Key architectural decision: |
| 101 | + # |
| 102 | + # Option 1: Parallelize parameters directly (current design approach) |
| 103 | + # - Apply TP dimension immediately at this point |
| 104 | + # - Requires _StridedShard for implementation |
| 105 | + # |
| 106 | + # Option 2: Record the placement and apply full placements later |
| 107 | + # - Record TP dimension placement now, apply full placement later with DP dimension |
| 108 | + # - No need to use _StridedShard, we can just use Shard() |
| 109 | + # |
| 110 | + # It's mostly likely that we will go with option 2 as we are going to use |
| 111 | + # parameterization to handle the full parameters transformation, which |
| 112 | + # makes option 2 more natural. |
| 113 | + raise NotImplementedError("TP is not implemented yet.") |
| 114 | + |
| 115 | + if job_config.activation_checkpoint.mode != "none": |
| 116 | + # TODO: Graph based AC. |
| 117 | + raise NotImplementedError("AC is not implemented yet.") |
| 118 | + |
| 119 | + # TODO: CP integration challenge: |
| 120 | + # |
| 121 | + # Problem: |
| 122 | + # When CP is enabled, the mesh structure becomes ["dp_replicate", "dp_shard", "cp"] |
| 123 | + # to maintain sequence sharding in DTensor. However, naively applying data_parallel |
| 124 | + # may trigger two separate allgather operations because DTensor.redistribute cannot |
| 125 | + # recognize that the two mesh dimensions can be flattened into a single allgather. |
| 126 | + # |
| 127 | + # Potential solution using SimpleFSDP: |
| 128 | + # 1. Transform mesh: ["dp_replicate", "dp_shard", "cp"] -> ["dp_replicate", "dp_shard_cp"] |
| 129 | + # via to_local() and from_local() |
| 130 | + # 2. Redistribute placement on ["dp_shard_cp"] dimension |
| 131 | + # 3. Transform mesh back: ["dp_replicate", "dp_shard_cp"] -> ["dp_replicate", "dp_shard", "cp"] |
| 132 | + # via to_local() and from_local() |
| 133 | + # |
| 134 | + # Note: This solution leaves the dp_shard process group wasted ( |
| 135 | + # we can initialize it with fake backend). |
| 136 | + # |
| 137 | + # Note: We may be able to implement this solution with parameterization directly. |
| 138 | + |
| 139 | + # Keep cp_enabled here to remind us cp_enabled=True requires data_parallel |
| 140 | + if ( |
| 141 | + parallel_dims.dp_replicate_enabled |
| 142 | + or parallel_dims.dp_shard_enabled |
| 143 | + or parallel_dims.cp_enabled |
| 144 | + ): |
| 145 | + model = apply_dp(model, parallel_dims, job_config) |
| 146 | + |
| 147 | + # Apply compilation after SPMD parallelization is complete. This differs from |
| 148 | + # eager mode parallelization where compilation occurs earlier. |
| 149 | + return apply_compile(model, parallel_dims, job_config) |
| 150 | + |
| 151 | + |
| 152 | +def parallelize_buffers( |
| 153 | + model: nn.Module, |
| 154 | + parallel_dims: ParallelDims, |
| 155 | +) -> nn.Module: |
| 156 | + # Buffer-to-mesh mapping in multi-SPMD scenarios: |
| 157 | + # |
| 158 | + # When buffers are used with different SPMD meshes (e.g., dense vs sparse meshes), we |
| 159 | + # will need an explicit mapping to associate each buffer with its corresponding mesh. |
| 160 | + # This indicates that the current implementation is not general enough to support |
| 161 | + # nD meshes. |
| 162 | + # |
| 163 | + # The solution is that we need to reparameterize the buffers together with the |
| 164 | + # parameters within a module. |
| 165 | + spmd_mesh = _get_spmd_mesh(parallel_dims) |
| 166 | + placements = (Replicate() for _ in range(spmd_mesh.ndim)) |
| 167 | + for m in model.modules(): |
| 168 | + buffers = { |
| 169 | + name: DTensor.from_local(b, spmd_mesh, placements) |
| 170 | + for name, b in m.named_buffers(recurse=False) |
| 171 | + } |
| 172 | + for name, b in buffers.items(): |
| 173 | + setattr(m, name, b) |
| 174 | + |
| 175 | + return model |
| 176 | + |
| 177 | + |
| 178 | +def build_parallelize_inputs_fn( |
| 179 | + parallel_dims: ParallelDims, |
| 180 | +) -> Callable[[torch.Tensor, torch.Tensor], tuple[DTensor, DTensor]]: |
| 181 | + spmd_mesh = _get_spmd_mesh(parallel_dims) |
| 182 | + |
| 183 | + # TODO: We need to make this more general to support nD mesh. But we can do this |
| 184 | + # after the DeviceMesh revamp PR is landed. |
| 185 | + spmd_placements = [] |
| 186 | + if parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled: |
| 187 | + spmd_placements.append(Shard(0)) |
| 188 | + if parallel_dims.dp_replicate_enabled: |
| 189 | + spmd_placements.append(Shard(0)) |
| 190 | + |
| 191 | + def parallelize_inputs( |
| 192 | + inputs: torch.Tensor, labels: torch.Tensor |
| 193 | + ) -> tuple[DTensor, DTensor]: |
| 194 | + inputs = DTensor.from_local(inputs, spmd_mesh, spmd_placements) |
| 195 | + labels = DTensor.from_local(labels, spmd_mesh, spmd_placements) |
| 196 | + return inputs, labels |
| 197 | + |
| 198 | + return parallelize_inputs |
| 199 | + |
| 200 | + |
| 201 | +def joint_custom_pass_builder( |
| 202 | + parallel_dims: ParallelDims, job_config: JobConfig |
| 203 | +) -> Callable: |
| 204 | + match job_config.parallelism.fsdp_reshard_after_forward: |
| 205 | + case "always": |
| 206 | + fsdp_reshard_after_forward = True |
| 207 | + case "never": |
| 208 | + fsdp_reshard_after_forward = False |
| 209 | + case "default": |
| 210 | + # For PP, by default do not reshard after forward to avoid per-microbatch |
| 211 | + # all-gathers, which can be expensive and non-overlapped |
| 212 | + fsdp_reshard_after_forward = not parallel_dims.pp_enabled |
| 213 | + case _: |
| 214 | + raise ValueError( |
| 215 | + "Invalid fsdp_reshard_after_forward_policy: " |
| 216 | + f"{job_config.parallelism.fsdp_reshard_after_forward}." |
| 217 | + ) |
| 218 | + |
| 219 | + def joint_ac_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: |
| 220 | + gm = annotate_fsdp_all_gather(gm, fsdp_reshard_after_forward) |
| 221 | + gm.recompile() |
| 222 | + return gm |
| 223 | + |
| 224 | + def joint_custom_pass(joint_with_descriptors) -> None: |
| 225 | + # TODO: Is this safe? Or should we use update_joint_with_descriptors from auto_parallel? |
| 226 | + joint_with_descriptors.graph_module = joint_ac_pass( |
| 227 | + joint_with_descriptors.graph_module |
| 228 | + ) |
| 229 | + |
| 230 | + |
| 231 | +def apply_compile( |
| 232 | + model: nn.Module, parallel_dims: ParallelDims, job_config: JobConfig |
| 233 | +) -> nn.Module: |
| 234 | + # TODO: This API just implements compiler toolkit. |
| 235 | + # We should also add torch.compile() support |
| 236 | + |
| 237 | + if not (job_config.compile.enable and "model" in job_config.compile.passes): |
| 238 | + return model |
| 239 | + |
| 240 | + compiler_passes = [] |
| 241 | + # Create compilers with specified passes (defaults to no passes) |
| 242 | + fw_compiler, bw_compiler = make_compiler_with_passes( |
| 243 | + compiler_passes, dump_folder=job_config.job.dump_folder |
| 244 | + ) |
| 245 | + |
| 246 | + # Create custom joint_graph_builder with llama-specific compilers and validation |
| 247 | + llama_joint_graph_builder = functools.partial( |
| 248 | + joint_graph_builder, |
| 249 | + fw_compiler=fw_compiler, |
| 250 | + bw_compiler=bw_compiler, |
| 251 | + joint_custom_pass=joint_custom_pass_builder(parallel_dims, job_config), |
| 252 | + dump_folder=job_config.job.dump_folder, |
| 253 | + ) |
| 254 | + |
| 255 | + # Full DTensor trainer will convert the inputs to DTensor, so we don't |
| 256 | + # need CompiledModule to do it. |
| 257 | + def dummy_parallelize_inputs( |
| 258 | + mesh: DeviceMesh, args: tuple[Any, ...], kwargs: dict[str, Any] |
| 259 | + ) -> tuple[tuple[Any, ...], dict[str, Any]]: |
| 260 | + return args, kwargs |
| 261 | + |
| 262 | + return CompiledModule( |
| 263 | + model, parallel_dims, llama_joint_graph_builder, dummy_parallelize_inputs |
| 264 | + ) |
0 commit comments