#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
import logging
from collections import OrderedDict
from typing import Any, Dict, Iterable, Iterator, Optional, Set, Tuple, Union

import torch

from torch.nn.modules.module import _IncompatibleKeys
from torch.profiler import record_function

from torchrec.distributed.train_pipeline.pipeline_context import TrainPipelineContext
from torchrec.distributed.train_pipeline.types import CallArgs
from torchrec.streamable import Pipelineable

logger: logging.Logger = logging.getLogger(__name__)

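# NoOpStream stands in for a device stream context manager (e.g. torch.cuda.stream) when
# PipelinedPostproc is constructed without a dist stream (for instance, a CPU-only run), so
# the `with self._stream_context(...)` blocks in forward() work unchanged in either case.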
class NoOpStream:
    """No-op context manager that takes in a stream."""

    def __init__(self, stream: Optional[torch.Stream]) -> None:
        self._stream = stream

    def __enter__(self) -> "NoOpStream":
        """Return `self` upon entering the runtime context."""
        return self

    # pyre-ignore
    def __exit__(self, exc_type, exc_value, traceback) -> None:
        return None


class PipelinedPostproc(torch.nn.Module):
    """
    Wrapper around a postproc module found during model graph traversal for sparse data dist
    pipelining. In addition to the original module, it encapsulates information needed for
    execution, such as the list of ArgInfo and the current training pipeline context.

    Args:
        postproc_module (torch.nn.Module): postproc module to run
        fqn (str): fqn of the postproc module in the model being pipelined
        args (CallArgs): CallArgs for the postproc module
        context (TrainPipelineContext): Training context for the next iteration / batch
        default_stream (Optional[torch.Stream]): default stream used by the pipeline
        dist_stream (Optional[torch.Stream]): data dist stream used by the pipeline

    Returns:
        Any

    Example:
        postproc = PipelinedPostproc(
            postproc_module, fqn, args, context, default_stream, dist_stream
        )
        # module-swap with pipelined postproc
        setattr(model, fqn, postproc)
    """

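    # Presumably consumed by pipeline rewrite/checkpoint code elsewhere in torchrec to force
    # state_dict loads to go through this wrapper's load_state_dict override (which delegates
    # to the wrapped module); the exact call sites are outside this file.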
    _FORCE_STATE_DICT_LOAD = True

    def __init__(
        self,
        postproc_module: torch.nn.Module,
        fqn: str,
        args: CallArgs,
        context: TrainPipelineContext,
        # TODO: make streams non-optional - skipping now to avoid ripple effect
        default_stream: Optional[torch.Stream],
        dist_stream: Optional[torch.Stream],
    ) -> None:
        super().__init__()
        self._postproc_module = postproc_module
        self._fqn = fqn
        self._args = args
        self._context = context
        self._default_stream = default_stream
        self._dist_stream = dist_stream
        if not default_stream:
            logger.warning(
                f"Postproc module {fqn} has no default stream. This may cause race conditions and NaNs during training!"
            )
        if not dist_stream:
            logger.warning(
                f"Postproc module {fqn} has no dist stream. This may cause race conditions and NaNs during training!"
            )

        if self._dist_stream:
            device: torch.device = self._dist_stream.device
            # pyre-ignore
            self._stream_context = (
                torch.get_device_module(device).stream
                if device.type in ["cuda", "mtia"]
                else torch.cuda.stream
            )
        else:
            self._stream_context = NoOpStream

    @property
    def postproc_module(self) -> torch.nn.Module:
        return self._postproc_module

    @property
    def fqn(self) -> str:
        return self._fqn

    # pyre-ignore
    def forward(self, *input, **kwargs) -> Any:
        """
        Args:
            Any args and kwargs during model fwd
            During _start_data_dist, input[0] contains the current data
        Returns:
            Any
        """
        if self._fqn in self._context.postproc_fwd_results:
            # This should only be hit in two cases:
            # 1) During model forward
            #    During model forward, avoid duplicate work
            #    by returning the cached result from previous
            #    iteration's _start_data_dist
            # 2) During _start_data_dist when postproc module is
            #    shared by more than one args. e.g. if we have
            #    postproc_out_a = postproc_a(input)
            #    postproc_out_b = postproc_b(postproc_out_a) <- postproc_a shared
            #    postproc_out_c = postproc_c(postproc_out_a) <-^
            #    When processing postproc_b, we cache value of postproc_a(input)
            #    so when processing postproc_c, we can reuse postproc_a(input)
            res = self._context.postproc_fwd_results[self._fqn]
            return res

        # Everything below should only be called during _start_data_dist stage

        # Build up arg and kwargs from recursive call to pass to postproc module
        # Arguments to postproc module can also be a derived product
        # of another postproc module call, as long as module is pipelineable

        # Use input[0] as _start_data_dist only passes 1 arg
        args, kwargs = self._args.build_args_kwargs(input[0])
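        # (Note: build_args_kwargs is assumed here to materialize the positional/keyword args
        # recorded in CallArgs by applying each ArgInfo to the incoming batch.)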

        with record_function(f"## sdd_input_postproc {self._context.index} ##"):
            # should be no-op as we call this in dist stream
            with self._stream_context(self._dist_stream):
                res = self._postproc_module(*args, **kwargs)

            # Ensure the postproc module's output is safe to use from the default stream later
            if self._default_stream and self._dist_stream:
                self._default_stream.wait_stream(self._dist_stream)

                if isinstance(res, (torch.Tensor, Pipelineable, Iterable, Dict)):
                    # Result from module forward might be a complex type such as
                    # Tuple[KeyedJaggedTensor, Dict[str, torch.Tensor]]
                    # In this case, we need to first iterate over each element of the tuple
                    # and call record_stream on the first item, as KJT is Pipelineable;
                    # for the second item (Dict), we iterate over the values and call
                    # record_stream accordingly.

                    # pyre-ignore[6]
                    PipelinedPostproc.recursive_record_stream(res, self._default_stream)
                elif self._context.index == 0:
                    logger.warning(
                        f"Result of postproc module {self._fqn} is of type {type(res)}. We currently expect it to be a Tensor, Pipelineable, Iterable, or Dict to handle memory safety. If your output is not of this type, please add support for it above. Otherwise you might run into NaNs or CUDA Illegal Memory issues during training!"
                    )

            with self._stream_context(self._default_stream):
                # Cache results, only during _start_data_dist
                self._context.postproc_fwd_results[self._fqn] = res

            return res

    @property
    def args(self) -> CallArgs:
        return self._args

    def set_context(self, context: TrainPipelineContext) -> None:
        self._context = context

    def get_context(self) -> TrainPipelineContext:
        return self._context

    def named_modules(
        self,
        memo: Optional[Set[torch.nn.Module]] = None,
        prefix: str = "",
        remove_duplicate: bool = True,
    ) -> Iterator[Tuple[str, torch.nn.Module]]:
        if memo is None:
            memo = set()
        if self not in memo:
            if remove_duplicate:
                memo.add(self)
            # This is needed because otherwise the rewrite won't find the existing postproc,
            # and will create a new one
            # Also, `named_modules` needs to include self - see base implementation in nn.modules.Module
            yield prefix, self
            # Difference from base implementation is here - the child name (_postproc_module) is not added to the prefix
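            # e.g. a submodule "foo" of the wrapped module is reported as f"{prefix}.foo" rather
            # than f"{prefix}._postproc_module.foo", so FQNs match the un-pipelined model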
            yield from self._postproc_module.named_modules(
                memo, prefix, remove_duplicate
            )

    def named_parameters(
        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
    ) -> Iterator[Tuple[str, torch.nn.Parameter]]:
        yield from self._postproc_module.named_parameters(
            prefix,
            recurse,
            remove_duplicate,
        )

    def named_buffers(
        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
    ) -> Iterator[Tuple[str, torch.Tensor]]:
        yield from self._postproc_module.named_buffers(
            prefix, recurse, remove_duplicate
        )

    # pyre-ignore [14]
    def state_dict(
        self,
        destination: Optional[Dict[str, Any]] = None,
        prefix: str = "",
        keep_vars: bool = False,
    ) -> Dict[str, Any]:
        # Delegate to the wrapped module instead of calling super().state_dict(), so the
        # checkpoint keys match the wrapped module's own names (no "_postproc_module" prefix)
        if destination is None:
            destination = OrderedDict()
            # pyre-ignore [16]
            destination._metadata = OrderedDict()
        self._postproc_module.state_dict(
            destination=destination, prefix=prefix, keep_vars=keep_vars
        )
        return destination

    # pyre-ignore [14]
    def load_state_dict(
        self,
        state_dict: OrderedDict[str, torch.Tensor],
        strict: bool = True,
    ) -> _IncompatibleKeys:
        return self._postproc_module.load_state_dict(state_dict, strict=strict)

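    # Tensor.record_stream marks a tensor as in use by the given stream so its memory is not
    # reclaimed or reused until work queued on that stream completes; Pipelineable types
    # (e.g. KeyedJaggedTensor) expose the same method and apply it to their constituent tensors.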
    @staticmethod
    def recursive_record_stream(
        # pyre-fixme[2]: Parameter `re` must have a type that does not contain `Any`
        res: Union[torch.Tensor, Pipelineable, Iterable[Any], Dict[Any, Any]],
        stream: torch.Stream,
    ) -> None:
        if isinstance(res, torch.Tensor) and res.device.type in ["cuda", "mtia"]:
            res.record_stream(stream)
        elif isinstance(res, Pipelineable):
            res.record_stream(stream)
        elif isinstance(res, (list, tuple)):
            for v in res:
                PipelinedPostproc.recursive_record_stream(v, stream)
        elif isinstance(res, dict):
            for v in res.values():
                PipelinedPostproc.recursive_record_stream(v, stream)