
Commit 953c787

TroyGarden authored and facebook-github-bot committed
split train_pipeline.utils - tracing (#2982)
Summary:
Pull Request resolved: #2982

# context
* train_pipeline.utils file is overloaded
* split the functions, classes, etc. into three files, each under ~1,000 lines
* this diff: tracing.py, types.py, postproc.py

Reviewed By: che-sh, malaybag

Differential Revision: D74939217

fbshipit-source-id: 1143a31b39bec808d69f5b179a128677947b6f71
1 parent 114810d commit 953c787

8 files changed, +956 −899 lines
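For downstream code, the net effect of this commit is an import-path change. A minimal orientation sketch, based only on the re-routed imports visible in the diffs below (package-level imports from torchrec.distributed.train_pipeline are re-exported in __init__.py and keep working unchanged):

# Before this commit, these all lived in train_pipeline.utils:
# from torchrec.distributed.train_pipeline.utils import ArgInfo, CallArgs, Tracer

# After the split:
from torchrec.distributed.train_pipeline.types import ArgInfo, CallArgs
from torchrec.distributed.train_pipeline.tracing import ArgInfoStepFactory, Tracer
from torchrec.distributed.train_pipeline.postproc import PipelinedPostproc

# Helpers that did not move stay in utils:
from torchrec.distributed.train_pipeline.utils import _rewrite_model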

torchrec/distributed/train_pipeline/__init__.py

Lines changed: 5 additions & 4 deletions
@@ -13,6 +13,10 @@
     Out,
     TrainPipelineContext,
 )
+from torchrec.distributed.train_pipeline.tracing import (  # noqa
+    ArgInfoStepFactory,  # noqa
+    Tracer,  # noqa
+)
 from torchrec.distributed.train_pipeline.train_pipelines import (  # noqa
     EvalPipelineSparseDist,  # noqa
     PrefetchTrainPipelineSparseDist,  # noqa
@@ -25,17 +29,14 @@
     TrainPipelineSparseDist,  # noqa
     TrainPipelineSparseDistCompAutograd,  # noqa
 )
+from torchrec.distributed.train_pipeline.types import ArgInfo, CallArgs  # noqa
 from torchrec.distributed.train_pipeline.utils import (  # noqa
     _override_input_dist_forwards,  # noqa
     _rewrite_model,  # noqa
     _start_data_dist,  # noqa
     _to_device,  # noqa
     _wait_for_batch,  # noqa
-    ArgInfo,  # noqa
-    ArgInfoStepFactory,  # noqa
-    CallArgs,  # noqa
     DataLoadingThread,  # noqa
     SparseDataDistUtil,  # noqa
     StageOut,  # noqa
-    Tracer,  # noqa
 )
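Because __init__.py re-exports the moved symbols (the # noqa imports above), the old package-level path and the new module path resolve to the same objects. A quick illustrative check (hypothetical snippet, not part of the commit):

from torchrec.distributed.train_pipeline import ArgInfo as pkg_ArgInfo
from torchrec.distributed.train_pipeline.types import ArgInfo as mod_ArgInfo

# The package root simply re-exports the class from its new home, so existing
# `from torchrec.distributed.train_pipeline import ArgInfo` call sites need no change.
assert pkg_ArgInfo is mod_ArgInfo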
torchrec/distributed/train_pipeline/postproc.py

Lines changed: 257 additions & 0 deletions

@@ -0,0 +1,257 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+import logging
+from collections import OrderedDict
+from typing import Any, Dict, Iterable, Iterator, Optional, Set, Tuple, Union
+
+import torch
+
+from torch.nn.modules.module import _IncompatibleKeys
+from torch.profiler import record_function
+
+from torchrec.distributed.train_pipeline.pipeline_context import TrainPipelineContext
+from torchrec.distributed.train_pipeline.types import CallArgs
+from torchrec.streamable import Pipelineable
+
+logger: logging.Logger = logging.getLogger(__name__)
+
+
+class NoOpStream:
+    """No-Op Context manager that takes in a stream"""
+
+    def __init__(self, stream: Optional[torch.Stream]) -> None:
+        self._stream = stream
+
+    def __enter__(self) -> "NoOpStream":
+        """Return `self` upon entering the runtime context."""
+        return self
+
+    # pyre-ignore
+    def __exit__(self, exc_type, exc_value, traceback) -> None:
+        return None
+
+
+class PipelinedPostproc(torch.nn.Module):
+    """
+    Wrapper around postproc module found during model graph traversal for sparse data dist
+    pipelining. In addition to the original module, it encapsulates information needed for
+    execution such as list of ArgInfo and the current training pipeline context.
+
+    Args:
+        postproc_module (torch.nn.Module): postproc module to run
+        fqn (str): fqn of the postproc module in the model being pipelined
+        args (CallArgs): CallArgs for the postproc module
+        context (TrainPipelineContext): Training context for the next iteration / batch
+
+    Returns:
+        Any
+
+    Example:
+        postproc = PipelinedPostproc(postproc_module, fqn, args, context)
+        # module-swap with pipelined postproc
+        setattr(model, fqn, postproc)
+    """
+
+    _FORCE_STATE_DICT_LOAD = True
+
+    def __init__(
+        self,
+        postproc_module: torch.nn.Module,
+        fqn: str,
+        args: CallArgs,
+        context: TrainPipelineContext,
+        # TODO: make streams non-optional - skipping now to avoid ripple effect
+        default_stream: Optional[torch.Stream],
+        dist_stream: Optional[torch.Stream],
+    ) -> None:
+        super().__init__()
+        self._postproc_module = postproc_module
+        self._fqn = fqn
+        self._args = args
+        self._context = context
+        self._default_stream = default_stream
+        self._dist_stream = dist_stream
+        if not default_stream:
+            logger.warning(
+                f"Postproc module {fqn} has no default stream. This may cause race conditions and NaNs during training!"
+            )
+        if not dist_stream:
+            logger.warning(
+                f"Postproc module {fqn} has no dist stream. This may cause race conditions and NaNs during training!"
+            )
+
+        if self._dist_stream:
+            device: torch.device = self._dist_stream.device
+            # pyre-ignore
+            self._stream_context = (
+                torch.get_device_module(device).stream
+                if device.type in ["cuda", "mtia"]
+                else torch.cuda.stream
+            )
+        else:
+            self._stream_context = NoOpStream
+
+    @property
+    def postproc_module(self) -> torch.nn.Module:
+        return self._postproc_module
+
+    @property
+    def fqn(self) -> str:
+        return self._fqn
+
+    # pyre-ignore
+    def forward(self, *input, **kwargs) -> Any:
+        """
+        Args:
+            Any args and kwargs during model fwd
+            During _start_data_dist, input[0] contains the current data
+        Returns:
+            Any
+        """
+        if self._fqn in self._context.postproc_fwd_results:
+            # This should only be hit in two cases:
+            # 1) During model forward
+            # During model forward, avoid duplicate work
+            # by returning the cached result from previous
+            # iteration's _start_data_dist
+            # 2) During _start_data_dist when postproc module is
+            # shared by more than one args. e.g. if we have
+            # postproc_out_a = postproc_a(input)
+            # postproc_out_b = postproc_b(postproc_out_a) <- postproc_a shared
+            # postproc_out_c = postproc_c(postproc_out_a) <-^
+            # When processing postproc_b, we cache value of postproc_a(input)
+            # so when processing postproc_c, we can reuse postproc_a(input)
+            res = self._context.postproc_fwd_results[self._fqn]
+            return res
+
+        # Everything below should only be called during _start_data_dist stage
+
+        # Build up arg and kwargs from recursive call to pass to postproc module
+        # Arguments to postproc module can also be a derived product
+        # of another postproc module call, as long as module is pipelineable
+
+        # Use input[0] as _start_data_dist only passes 1 arg
+        args, kwargs = self._args.build_args_kwargs(input[0])
+
+        with record_function(f"## sdd_input_postproc {self._context.index} ##"):
+            # should be no-op as we call this in dist stream
+            with self._stream_context(self._dist_stream):
+                res = self._postproc_module(*args, **kwargs)
+
+            # Ensure postproc modules output is safe to use from default stream later
+            if self._default_stream and self._dist_stream:
+                self._default_stream.wait_stream(self._dist_stream)
+
+                if isinstance(res, (torch.Tensor, Pipelineable, Iterable, Dict)):
+                    # Result from module forward might be a complex type such as
+                    # Tuple[KeyedJaggedTensor, Dict[str, torch.Tensor]]
+                    # In this case, we need to first iterate over each element of tuple
+                    # and call record_stream on first item as KJT is Pipelineable
+                    # for the second item (Dict), we iterate over the values and call
+                    # record_stream accordingly.
+
+                    # pyre-ignore[6]
+                    PipelinedPostproc.recursive_record_stream(res, self._default_stream)
+                elif self._context.index == 0:
+                    logger.warning(
+                        f"Result of postproc module {self._fqn} is of type {type(res)}. We currently expect it to be a Tensor, Pipelineable, Iterable, or Dict to handle memory safety. If your output is not of this type, please add support for it above. Otherwise you might run into NaNs or CUDA Illegal Memory issues during training!"
+                    )
+
+            with self._stream_context(self._default_stream):
+                # Cache results, only during _start_data_dist
+                self._context.postproc_fwd_results[self._fqn] = res
+
+            return res
+
+    @property
+    def args(self) -> CallArgs:
+        return self._args
+
+    def set_context(self, context: TrainPipelineContext) -> None:
+        self._context = context
+
+    def get_context(self) -> TrainPipelineContext:
+        return self._context
+
+    def named_modules(
+        self,
+        memo: Optional[Set[torch.nn.Module]] = None,
+        prefix: str = "",
+        remove_duplicate: bool = True,
+    ) -> Iterator[Tuple[str, torch.nn.Module]]:
+        if memo is None:
+            memo = set()
+        if self not in memo:
+            if remove_duplicate:
+                memo.add(self)
+            # This is needed because otherwise the rewrite won't find the existing postproc, and will create a new one
+            # Also, `named_modules` need to include self - see base implementation in the nn.modules.Module
+            yield prefix, self
+            # Difference from base implementation is here - the child name (_postproc_module) is not added to the prefix
+            yield from self._postproc_module.named_modules(
+                memo, prefix, remove_duplicate
+            )
+
+    def named_parameters(
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
+    ) -> Iterator[Tuple[str, torch.nn.Parameter]]:
+        yield from self._postproc_module.named_parameters(
+            prefix,
+            recurse,
+            remove_duplicate,
+        )
+
+    def named_buffers(
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
+    ) -> Iterator[Tuple[str, torch.Tensor]]:
+        yield from self._postproc_module.named_buffers(
+            prefix, recurse, remove_duplicate
+        )
+
+    # pyre-ignore [14]
+    def state_dict(
+        self,
+        destination: Optional[Dict[str, Any]] = None,
+        prefix: str = "",
+        keep_vars: bool = False,
+    ) -> Dict[str, Any]:
+        # super().state_dict(destination, prefix, keep_vars)
+        if destination is None:
+            destination = OrderedDict()
+            # pyre-ignore [16]
+            destination._metadata = OrderedDict()
+        self._postproc_module.state_dict(
+            destination=destination, prefix=prefix, keep_vars=keep_vars
+        )
+        return destination
+
+    # pyre-ignore [14]
+    def load_state_dict(
+        self,
+        state_dict: OrderedDict[str, torch.Tensor],
+        strict: bool = True,
+    ) -> _IncompatibleKeys:
+        return self._postproc_module.load_state_dict(state_dict, strict=strict)
+
+    @staticmethod
+    def recursive_record_stream(
+        # pyre-fixme[2]: Parameter `res` must have a type that does not contain `Any`
+        res: Union[torch.Tensor, Pipelineable, Iterable[Any], Dict[Any, Any]],
+        stream: torch.Stream,
+    ) -> None:
+        if isinstance(res, torch.Tensor) and res.device.type in ["cuda", "mtia"]:
+            res.record_stream(stream)
+        elif isinstance(res, Pipelineable):
+            res.record_stream(stream)
+        elif isinstance(res, (list, tuple)):
+            for v in res:
+                PipelinedPostproc.recursive_record_stream(v, stream)
+        elif isinstance(res, dict):
+            for v in res.values():
+                PipelinedPostproc.recursive_record_stream(v, stream)
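The heart of PipelinedPostproc.forward is the per-iteration result cache keyed by the module's FQN: the first call (from _start_data_dist) runs the wrapped module on the dist stream and stores the result; any later call in the same iteration, including from a postproc shared by several downstream args, returns the cached value. A simplified, self-contained sketch of that caching pattern (the context class and helper here are stand-ins for illustration, not the real TorchRec types):

import torch

class SketchContext:
    """Stand-in for TrainPipelineContext: holds per-iteration postproc results."""

    def __init__(self) -> None:
        self.postproc_fwd_results: dict = {}

def cached_postproc_forward(fqn, module, context, inp):
    # Later calls in the same iteration hit the cache, which is how a
    # postproc module shared by several downstream args avoids re-running.
    if fqn in context.postproc_fwd_results:
        return context.postproc_fwd_results[fqn]
    res = module(inp)
    context.postproc_fwd_results[fqn] = res
    return res

ctx = SketchContext()
postproc_a = torch.nn.Linear(4, 4)
x = torch.randn(2, 4)
out_for_b = cached_postproc_forward("postproc_a", postproc_a, ctx, x)  # computes
out_for_c = cached_postproc_forward("postproc_a", postproc_a, ctx, x)  # cache hit
assert out_for_b is out_for_c  # same object, no duplicate work

The real wrapper adds the stream handling around this core: the module runs under the dist stream's context, the default stream waits on the dist stream, and recursive_record_stream marks the outputs so their memory is not reclaimed while the default stream still uses them.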

torchrec/distributed/train_pipeline/tests/test_train_pipelines.py

Lines changed: 7 additions & 5 deletions
@@ -48,9 +48,16 @@
     create_module_and_freeze,
 )
 from torchrec.distributed.train_pipeline.pipeline_context import TrainPipelineContext
+from torchrec.distributed.train_pipeline.postproc import PipelinedPostproc
 from torchrec.distributed.train_pipeline.tests.test_train_pipelines_base import (
     TrainPipelineSparseDistTestBase,
 )
+from torchrec.distributed.train_pipeline.tracing import (
+    GetAttrArgInfoStep,
+    GetItemArgInfoStep,
+    NoopArgInfoStep,
+    PostprocArgInfoStep,
+)
 from torchrec.distributed.train_pipeline.train_pipelines import (
     EvalPipelineSparseDist,
     PrefetchTrainPipelineSparseDist,
@@ -65,13 +72,8 @@
     DataLoadingThread,
     EmbeddingPipelinedForward,
     get_h2d_func,
-    GetAttrArgInfoStep,
-    GetItemArgInfoStep,
-    NoopArgInfoStep,
     PipelinedForward,
-    PipelinedPostproc,
     PipelineStage,
-    PostprocArgInfoStep,
     SparseDataDistUtil,
     StageOut,
 )

torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py

Lines changed: 2 additions & 3 deletions
@@ -23,15 +23,14 @@
 from torchrec.distributed.train_pipeline.tests.test_train_pipelines_base import (
     TrainPipelineSparseDistTestBase,
 )
-from torchrec.distributed.train_pipeline.utils import (
-    _rewrite_model,
+from torchrec.distributed.train_pipeline.tracing import (
     ArgInfo,
     ArgInfoStepFactory,
     CallArgs,
     NodeArgsHelper,
-    PipelinedForward,
     PipelinedPostproc,
 )
+from torchrec.distributed.train_pipeline.utils import _rewrite_model, PipelinedForward
 from torchrec.distributed.types import ShardingType
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
