|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | +# All rights reserved. |
| 4 | +# |
| 5 | +# This source code is licensed under the BSD-style license found in the |
| 6 | +# LICENSE file in the root directory of this source tree. |
| 7 | + |
| 8 | +# pyre-strict |
| 9 | + |
| 10 | +import itertools |
| 11 | +import logging |
| 12 | +from typing import Dict, Generic, Iterable, List, Optional, Tuple, TypeVar, Union |
| 13 | + |
| 14 | +import torch |
| 15 | +from torch import distributed as dist |
| 16 | +from torch.profiler import record_function |
| 17 | + |
| 18 | +from torchrec.distributed.embedding_sharding import KJTSplitsAllToAllMeta |
| 19 | +from torchrec.distributed.model_parallel import ShardedModule |
| 20 | +from torchrec.distributed.train_pipeline.pipeline_context import ( |
| 21 | + EmbeddingTrainPipelineContext, |
| 22 | + PrefetchTrainPipelineContext, |
| 23 | + TrainPipelineContext, |
| 24 | +) |
| 25 | +from torchrec.distributed.train_pipeline.types import CallArgs |
| 26 | +from torchrec.distributed.types import Awaitable, LazyNoWait |
| 27 | +from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor |
| 28 | +from torchrec.streamable import Multistreamable |
| 29 | + |
| 30 | +logger: logging.Logger = logging.getLogger(__name__) |
| 31 | + |
| 32 | +TForwardContext = TypeVar("TForwardContext", bound=TrainPipelineContext) |
| 33 | + |
| 34 | +EmbeddingModuleRetType = Union[Dict[str, JaggedTensor], KeyedTensor] |
| 35 | + |
| 36 | + |
| 37 | +class BaseForward(Generic[TForwardContext]): |
| 38 | + def __init__( |
| 39 | + self, |
| 40 | + name: str, |
| 41 | + args: CallArgs, |
| 42 | + module: ShardedModule, |
| 43 | + context: TForwardContext, |
| 44 | + stream: Optional[torch.Stream] = None, |
| 45 | + ) -> None: |
| 46 | + self._name = name |
| 47 | + self._args = args |
| 48 | + self._module = module |
| 49 | + self._context = context |
| 50 | + self._stream = stream |
| 51 | + self._device: torch.device = stream.device if stream else torch.device("cuda") |
| 52 | + |
| 53 | + @property |
| 54 | + def name(self) -> str: |
| 55 | + return self._name |
| 56 | + |
| 57 | + @property |
| 58 | + def args(self) -> CallArgs: |
| 59 | + return self._args |
| 60 | + |
| 61 | + def set_context(self, context: TForwardContext) -> None: |
| 62 | + self._context = context |
| 63 | + |
| 64 | + def get_context(self) -> TForwardContext: |
| 65 | + return self._context |
| 66 | + |
| 67 | + |
| 68 | +class PipelinedForward(BaseForward[TrainPipelineContext]): |
| 69 | + """ |
| 70 | + This pipeline is used in TrainPipelineSparseDist |
| 71 | + """ |
| 72 | + |
| 73 | + # pyre-ignore [2, 24] |
| 74 | + def __call__(self, *input, **kwargs) -> Awaitable: |
| 75 | + assert ( |
| 76 | + self._name in self._context.input_dist_tensors_requests |
| 77 | + ), "Invalid PipelinedForward usage, please do not directly call model.forward()" |
| 78 | + request = self._context.input_dist_tensors_requests.pop(self._name) |
| 79 | + assert isinstance(request, Awaitable) |
| 80 | + with record_function("## wait_sparse_data_dist ##"): |
| 81 | + # Finish waiting on the dist_stream, |
| 82 | + # in case some delayed stream scheduling happens during the wait() call. |
| 83 | + with torch.get_device_module(self._device).stream(self._stream): |
| 84 | + data = request.wait() |
| 85 | + |
| 86 | + # Make sure that both result of input_dist and context |
| 87 | + # are properly transferred to the current stream. |
| 88 | + ctx = self._context.module_contexts.pop(self._name) |
| 89 | + |
| 90 | + if self._stream is not None: |
| 91 | + torch.get_device_module(self._device).current_stream().wait_stream( |
| 92 | + self._stream |
| 93 | + ) |
| 94 | + cur_stream = torch.get_device_module(self._device).current_stream() |
| 95 | + |
| 96 | + assert isinstance( |
| 97 | + data, (torch.Tensor, Multistreamable) |
| 98 | + ), f"{type(data)} must implement Multistreamable interface" |
| 99 | + data.record_stream(cur_stream) |
| 100 | + ctx.record_stream(cur_stream) |
| 101 | + |
| 102 | + return self._module.compute_and_output_dist(ctx, data) |
| 103 | + |
| 104 | + |
| 105 | +class EmbeddingPipelinedForward(BaseForward[EmbeddingTrainPipelineContext]): |
| 106 | + """ |
| 107 | + This pipeline is used in TrainPipelineSemiSync |
| 108 | + """ |
| 109 | + |
| 110 | + def __call__( |
| 111 | + self, |
| 112 | + # pyre-ignore |
| 113 | + *input, |
| 114 | + # pyre-ignore |
| 115 | + **kwargs, |
| 116 | + ) -> Union[ |
| 117 | + Awaitable[EmbeddingModuleRetType], |
| 118 | + Tuple[ |
| 119 | + Awaitable[EmbeddingModuleRetType], Awaitable[Optional[KeyedJaggedTensor]] |
| 120 | + ], |
| 121 | + ]: |
| 122 | + assert ( |
| 123 | + self._name in self._context.embedding_a2a_requests |
| 124 | + ), "Invalid EmbeddingPipelinedForward usage, please do not directly call model.forward()" |
| 125 | + |
| 126 | + ctx = self._context.module_contexts.pop(self._name) |
| 127 | + cur_stream = torch.get_device_module(self._device).current_stream() |
| 128 | + |
| 129 | + if self._stream is not None: |
| 130 | + torch.get_device_module(self._device).current_stream().wait_stream( |
| 131 | + self._stream |
| 132 | + ) |
| 133 | + ctx.record_stream(cur_stream) |
| 134 | + |
| 135 | + awaitable = self._context.embedding_a2a_requests.pop(self._name) |
| 136 | + # in case of MC modules |
| 137 | + is_mc_module: bool = isinstance(awaitable, Iterable) |
| 138 | + remapped_kjts: Optional[KeyedJaggedTensor] = None |
| 139 | + |
| 140 | + if is_mc_module: |
| 141 | + embeddings = awaitable[0].wait() |
| 142 | + remapped_kjts = awaitable[1].wait() |
| 143 | + else: |
| 144 | + assert isinstance(awaitable, Awaitable) |
| 145 | + embeddings = ( |
| 146 | + awaitable.wait() |
| 147 | + ) # trigger awaitable manually for type checking |
| 148 | + |
| 149 | + self.detach_embeddings(embeddings=embeddings, cur_stream=cur_stream) |
| 150 | + |
| 151 | + if is_mc_module: |
| 152 | + return (LazyNoWait(embeddings), LazyNoWait(remapped_kjts)) |
| 153 | + else: |
| 154 | + return LazyNoWait(embeddings) |
| 155 | + |
| 156 | + def detach_embeddings( |
| 157 | + self, |
| 158 | + embeddings: Union[Dict[str, JaggedTensor], KeyedTensor], |
| 159 | + cur_stream: torch.Stream, |
| 160 | + ) -> None: |
| 161 | + """ |
| 162 | + detach the grad from embeddings so that the backward/opt of the embeddings |
| 163 | + won't be invoked by loss.backward(). Instead, there is a dedicated embedding_backward |
| 164 | + call in semi-sync pipeline progress. |
| 165 | + """ |
| 166 | + tensors = [] |
| 167 | + detached_tensors = [] |
| 168 | + # in case of EC, embeddings are Dict[str, JaggedTensor] |
| 169 | + if isinstance(embeddings, Dict): |
| 170 | + for jt in embeddings.values(): |
| 171 | + assert isinstance(jt, JaggedTensor) |
| 172 | + tensor = jt.values() |
| 173 | + detached_tensor = tensor.detach().requires_grad_() |
| 174 | + detached_tensor.retain_grad() |
| 175 | + jt._values = detached_tensor |
| 176 | + tensors.append(tensor) |
| 177 | + detached_tensors.append(detached_tensor) |
| 178 | + self._context.embedding_tensors.append(tensors) |
| 179 | + self._context.embedding_features.append(list(embeddings.keys())) |
| 180 | + self._context.detached_embedding_tensors.append(detached_tensors) |
| 181 | + else: |
| 182 | + # in case of EBC, embeddings are KeyedTensor |
| 183 | + assert isinstance(embeddings, KeyedTensor) |
| 184 | + embeddings.record_stream(cur_stream) |
| 185 | + tensor = embeddings.values() |
| 186 | + detached_tensor = tensor.detach().requires_grad_() |
| 187 | + detached_tensor.retain_grad() |
| 188 | + embeddings._values = detached_tensor |
| 189 | + tensors.append(tensor) |
| 190 | + detached_tensors.append(detached_tensor) |
| 191 | + self._context.embedding_tensors.append(tensors) |
| 192 | + """ |
| 193 | + KeyedTensor is returned by EmbeddingBagCollections and its variants |
| 194 | + KeyedTensor holds dense data from multiple features and .values() |
| 195 | + returns a single concatenated dense tensor. To ensure that |
| 196 | + context.embedding_tensors[i] has the same length as |
| 197 | + context.embedding_features[i], we pass in a list with a single item: |
| 198 | + a list containing all the embedding feature names. |
| 199 | + """ |
| 200 | + self._context.embedding_features.append([list(embeddings.keys())]) |
| 201 | + self._context.detached_embedding_tensors.append(detached_tensors) |
| 202 | + |
| 203 | + |
| 204 | +class InSyncEmbeddingPipelinedForward(EmbeddingPipelinedForward): |
| 205 | + """ |
| 206 | + This pipeline is used in TrainPipelineFusedSparseDist |
| 207 | + """ |
| 208 | + |
| 209 | + def detach_embeddings( |
| 210 | + self, |
| 211 | + embeddings: Union[Dict[str, JaggedTensor], KeyedTensor], |
| 212 | + cur_stream: torch.Stream, |
| 213 | + ) -> None: |
| 214 | + # doing nothing |
| 215 | + pass |
| 216 | + |
| 217 | + |
| 218 | +class PrefetchPipelinedForward(BaseForward[PrefetchTrainPipelineContext]): |
| 219 | + """ |
| 220 | + This pipeline is used in PrefetchTrainPipelineSparseDist |
| 221 | + """ |
| 222 | + |
| 223 | + def __init__( |
| 224 | + self, |
| 225 | + name: str, |
| 226 | + args: CallArgs, |
| 227 | + module: ShardedModule, |
| 228 | + context: PrefetchTrainPipelineContext, |
| 229 | + prefetch_stream: Optional[torch.Stream] = None, |
| 230 | + ) -> None: |
| 231 | + super().__init__( |
| 232 | + name=name, |
| 233 | + args=args, |
| 234 | + module=module, |
| 235 | + context=context, |
| 236 | + stream=prefetch_stream, |
| 237 | + ) |
| 238 | + |
| 239 | + # pyre-ignore [2, 24] |
| 240 | + def __call__(self, *input, **kwargs) -> Awaitable: |
| 241 | + assert ( |
| 242 | + self._name in self._context.module_input_post_prefetch |
| 243 | + ), "Invalid PrefetchPipelinedForward usage, please do not directly call model.forward()" |
| 244 | + data = self._context.module_input_post_prefetch.pop(self._name) |
| 245 | + ctx = self._context.module_contexts_post_prefetch.pop(self._name) |
| 246 | + |
| 247 | + # Make sure that both result of input_dist and context |
| 248 | + # are properly transferred to the current stream. |
| 249 | + if self._stream is not None: |
| 250 | + torch.get_device_module(self._device).current_stream().wait_stream( |
| 251 | + self._stream |
| 252 | + ) |
| 253 | + cur_stream = torch.get_device_module(self._device).current_stream() |
| 254 | + |
| 255 | + assert isinstance( |
| 256 | + data, (torch.Tensor, Multistreamable) |
| 257 | + ), f"{type(data)} must implement Multistreamable interface" |
| 258 | + data.record_stream(cur_stream) |
| 259 | + |
| 260 | + ctx.record_stream(cur_stream) |
| 261 | + |
| 262 | + return self._module.compute_and_output_dist(ctx, data) |
| 263 | + |
| 264 | + |
| 265 | +class KJTAllToAllForward: |
| 266 | + def __init__( |
| 267 | + self, pg: dist.ProcessGroup, splits: List[int], stagger: int = 1 |
| 268 | + ) -> None: |
| 269 | + self._pg = pg |
| 270 | + self._splits = splits |
| 271 | + self._stagger = stagger |
| 272 | + self._splits_cumsum: List[int] = [0] + list(itertools.accumulate(splits)) |
| 273 | + |
| 274 | + def __call__(self, input: KeyedJaggedTensor) -> KJTSplitsAllToAllMeta: |
| 275 | + with torch.no_grad(): |
| 276 | + assert len(input.keys()) == sum(self._splits) |
| 277 | + rank = dist.get_rank(self._pg) |
| 278 | + local_keys = input.keys()[ |
| 279 | + self._splits_cumsum[rank] : self._splits_cumsum[rank + 1] |
| 280 | + ] |
| 281 | + input_splits = input.dist_splits(self._splits) |
| 282 | + device = input.values().device |
| 283 | + splits_tensors = [ |
| 284 | + torch.tensor(splits, device=device) for splits in input_splits |
| 285 | + ] |
| 286 | + if not input.variable_stride_per_key(): |
| 287 | + splits_tensors.append( |
| 288 | + torch.tensor([input.stride()] * self._pg.size(), device=device) |
| 289 | + ) |
| 290 | + return KJTSplitsAllToAllMeta( |
| 291 | + pg=self._pg, |
| 292 | + _input=input, |
| 293 | + splits=self._splits, |
| 294 | + splits_tensors=splits_tensors, |
| 295 | + input_splits=input_splits, |
| 296 | + input_tensors=input.dist_tensors(), |
| 297 | + labels=input.dist_labels(), |
| 298 | + keys=local_keys, |
| 299 | + device=device, |
| 300 | + stagger=self._stagger, |
| 301 | + ) |
0 commit comments