
Commit d4fb333

TroyGarden authored and facebook-github-bot committed
split train_pipeline.utils - forward
Summary:
# context
* The train_pipeline.utils file is overloaded.
* Split the functions, classes, etc. into three files of roughly under 1,000 lines each.
* This diff: forwards.py.

Reviewed By: malaybag

Differential Revision: D74939567
1 parent bbb1e75 commit d4fb333

File tree

5 files changed: +325, -287 lines
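The companion diffs below update call sites to pull the forward classes from the new module. A minimal sketch of that import migration, with the runtime_forwards path taken from the updated tests and train_pipelines.py in this commit (illustrative only, not a statement about the public API surface):

# Previously these names were imported from train_pipeline.utils, e.g.:
# from torchrec.distributed.train_pipeline.utils import PipelinedForward
# After this split they come from the new module:
from torchrec.distributed.train_pipeline.runtime_forwards import (
    EmbeddingPipelinedForward,
    InSyncEmbeddingPipelinedForward,
    PipelinedForward,
    PrefetchPipelinedForward,
)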
torchrec/distributed/train_pipeline/runtime_forwards.py (new file; module path per the import updates below)

Lines changed: 301 additions & 0 deletions
@@ -0,0 +1,301 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import itertools
import logging
from typing import Dict, Generic, Iterable, List, Optional, Tuple, TypeVar, Union

import torch
from torch import distributed as dist
from torch.profiler import record_function

from torchrec.distributed.embedding_sharding import KJTSplitsAllToAllMeta
from torchrec.distributed.model_parallel import ShardedModule
from torchrec.distributed.train_pipeline.pipeline_context import (
    EmbeddingTrainPipelineContext,
    PrefetchTrainPipelineContext,
    TrainPipelineContext,
)
from torchrec.distributed.train_pipeline.types import CallArgs
from torchrec.distributed.types import Awaitable, LazyNoWait
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor
from torchrec.streamable import Multistreamable

logger: logging.Logger = logging.getLogger(__name__)

TForwardContext = TypeVar("TForwardContext", bound=TrainPipelineContext)

EmbeddingModuleRetType = Union[Dict[str, JaggedTensor], KeyedTensor]


class BaseForward(Generic[TForwardContext]):
    def __init__(
        self,
        name: str,
        args: CallArgs,
        module: ShardedModule,
        context: TForwardContext,
        stream: Optional[torch.Stream] = None,
    ) -> None:
        self._name = name
        self._args = args
        self._module = module
        self._context = context
        self._stream = stream
        self._device: torch.device = stream.device if stream else torch.device("cuda")

    @property
    def name(self) -> str:
        return self._name

    @property
    def args(self) -> CallArgs:
        return self._args

    def set_context(self, context: TForwardContext) -> None:
        self._context = context

    def get_context(self) -> TForwardContext:
        return self._context


class PipelinedForward(BaseForward[TrainPipelineContext]):
    """
    This pipeline is used in TrainPipelineSparseDist
    """

    # pyre-ignore [2, 24]
    def __call__(self, *input, **kwargs) -> Awaitable:
        assert (
            self._name in self._context.input_dist_tensors_requests
        ), "Invalid PipelinedForward usage, please do not directly call model.forward()"
        request = self._context.input_dist_tensors_requests.pop(self._name)
        assert isinstance(request, Awaitable)
        with record_function("## wait_sparse_data_dist ##"):
            # Finish waiting on the dist_stream,
            # in case some delayed stream scheduling happens during the wait() call.
            with torch.get_device_module(self._device).stream(self._stream):
                data = request.wait()

        # Make sure that both result of input_dist and context
        # are properly transferred to the current stream.
        ctx = self._context.module_contexts.pop(self._name)

        if self._stream is not None:
            torch.get_device_module(self._device).current_stream().wait_stream(
                self._stream
            )
            cur_stream = torch.get_device_module(self._device).current_stream()

            assert isinstance(
                data, (torch.Tensor, Multistreamable)
            ), f"{type(data)} must implement Multistreamable interface"
            data.record_stream(cur_stream)
            ctx.record_stream(cur_stream)

        return self._module.compute_and_output_dist(ctx, data)


class EmbeddingPipelinedForward(BaseForward[EmbeddingTrainPipelineContext]):
    """
    This pipeline is used in TrainPipelineSemiSync
    """

    def __call__(
        self,
        # pyre-ignore
        *input,
        # pyre-ignore
        **kwargs,
    ) -> Union[
        Awaitable[EmbeddingModuleRetType],
        Tuple[
            Awaitable[EmbeddingModuleRetType], Awaitable[Optional[KeyedJaggedTensor]]
        ],
    ]:
        assert (
            self._name in self._context.embedding_a2a_requests
        ), "Invalid EmbeddingPipelinedForward usage, please do not directly call model.forward()"

        ctx = self._context.module_contexts.pop(self._name)
        cur_stream = torch.get_device_module(self._device).current_stream()

        if self._stream is not None:
            torch.get_device_module(self._device).current_stream().wait_stream(
                self._stream
            )
            ctx.record_stream(cur_stream)

        awaitable = self._context.embedding_a2a_requests.pop(self._name)
        # in case of MC modules
        is_mc_module: bool = isinstance(awaitable, Iterable)
        remapped_kjts: Optional[KeyedJaggedTensor] = None

        if is_mc_module:
            embeddings = awaitable[0].wait()
            remapped_kjts = awaitable[1].wait()
        else:
            assert isinstance(awaitable, Awaitable)
            embeddings = (
                awaitable.wait()
            )  # trigger awaitable manually for type checking

        self.detach_embeddings(embeddings=embeddings, cur_stream=cur_stream)

        if is_mc_module:
            return (LazyNoWait(embeddings), LazyNoWait(remapped_kjts))
        else:
            return LazyNoWait(embeddings)

    def detach_embeddings(
        self,
        embeddings: Union[Dict[str, JaggedTensor], KeyedTensor],
        cur_stream: torch.Stream,
    ) -> None:
        """
        detach the grad from embeddings so that the backward/opt of the embeddings
        won't be invoked by loss.backward(). Instead, there is a dedicated embedding_backward
        call in semi-sync pipeline progress.
        """
        tensors = []
        detached_tensors = []
        # in case of EC, embeddings are Dict[str, JaggedTensor]
        if isinstance(embeddings, Dict):
            for jt in embeddings.values():
                assert isinstance(jt, JaggedTensor)
                tensor = jt.values()
                detached_tensor = tensor.detach().requires_grad_()
                detached_tensor.retain_grad()
                jt._values = detached_tensor
                tensors.append(tensor)
                detached_tensors.append(detached_tensor)
            self._context.embedding_tensors.append(tensors)
            self._context.embedding_features.append(list(embeddings.keys()))
            self._context.detached_embedding_tensors.append(detached_tensors)
        else:
            # in case of EBC, embeddings are KeyedTensor
            assert isinstance(embeddings, KeyedTensor)
            embeddings.record_stream(cur_stream)
            tensor = embeddings.values()
            detached_tensor = tensor.detach().requires_grad_()
            detached_tensor.retain_grad()
            embeddings._values = detached_tensor
            tensors.append(tensor)
            detached_tensors.append(detached_tensor)
            self._context.embedding_tensors.append(tensors)
            """
            KeyedTensor is returned by EmbeddingBagCollections and its variants
            KeyedTensor holds dense data from multiple features and .values()
            returns a single concatenated dense tensor. To ensure that
            context.embedding_tensors[i] has the same length as
            context.embedding_features[i], we pass in a list with a single item:
            a list containing all the embedding feature names.
            """
            self._context.embedding_features.append([list(embeddings.keys())])
            self._context.detached_embedding_tensors.append(detached_tensors)


class InSyncEmbeddingPipelinedForward(EmbeddingPipelinedForward):
    """
    This pipeline is used in TrainPipelineFusedSparseDist
    """

    def detach_embeddings(
        self,
        embeddings: Union[Dict[str, JaggedTensor], KeyedTensor],
        cur_stream: torch.Stream,
    ) -> None:
        # doing nothing
        pass


class PrefetchPipelinedForward(BaseForward[PrefetchTrainPipelineContext]):
    """
    This pipeline is used in PrefetchTrainPipelineSparseDist
    """

    def __init__(
        self,
        name: str,
        args: CallArgs,
        module: ShardedModule,
        context: PrefetchTrainPipelineContext,
        prefetch_stream: Optional[torch.Stream] = None,
    ) -> None:
        super().__init__(
            name=name,
            args=args,
            module=module,
            context=context,
            stream=prefetch_stream,
        )

    # pyre-ignore [2, 24]
    def __call__(self, *input, **kwargs) -> Awaitable:
        assert (
            self._name in self._context.module_input_post_prefetch
        ), "Invalid PrefetchPipelinedForward usage, please do not directly call model.forward()"
        data = self._context.module_input_post_prefetch.pop(self._name)
        ctx = self._context.module_contexts_post_prefetch.pop(self._name)

        # Make sure that both result of input_dist and context
        # are properly transferred to the current stream.
        if self._stream is not None:
            torch.get_device_module(self._device).current_stream().wait_stream(
                self._stream
            )
            cur_stream = torch.get_device_module(self._device).current_stream()

            assert isinstance(
                data, (torch.Tensor, Multistreamable)
            ), f"{type(data)} must implement Multistreamable interface"
            data.record_stream(cur_stream)

            ctx.record_stream(cur_stream)

        return self._module.compute_and_output_dist(ctx, data)


class KJTAllToAllForward:
    def __init__(
        self, pg: dist.ProcessGroup, splits: List[int], stagger: int = 1
    ) -> None:
        self._pg = pg
        self._splits = splits
        self._stagger = stagger
        self._splits_cumsum: List[int] = [0] + list(itertools.accumulate(splits))

    def __call__(self, input: KeyedJaggedTensor) -> KJTSplitsAllToAllMeta:
        with torch.no_grad():
            assert len(input.keys()) == sum(self._splits)
            rank = dist.get_rank(self._pg)
            local_keys = input.keys()[
                self._splits_cumsum[rank] : self._splits_cumsum[rank + 1]
            ]
            input_splits = input.dist_splits(self._splits)
            device = input.values().device
            splits_tensors = [
                torch.tensor(splits, device=device) for splits in input_splits
            ]
            if not input.variable_stride_per_key():
                splits_tensors.append(
                    torch.tensor([input.stride()] * self._pg.size(), device=device)
                )
            return KJTSplitsAllToAllMeta(
                pg=self._pg,
                _input=input,
                splits=self._splits,
                splits_tensors=splits_tensors,
                input_splits=input_splits,
                input_tensors=input.dist_tensors(),
                labels=input.dist_labels(),
                keys=local_keys,
                device=device,
                stagger=self._stagger,
            )
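A note on detach_embeddings in EmbeddingPipelinedForward above: the semi-sync pipeline cuts the autograd edge between the dense part of the model and the embedding outputs, so loss.backward() stops at that boundary and the retained gradients feed the later, dedicated embedding backward the docstring mentions. A stand-alone sketch of that detach/retain pattern on plain tensors (toy shapes and names, not torchrec API):

import torch

weight = torch.randn(10, 8, requires_grad=True)      # toy embedding table
embedding_out = weight[torch.tensor([1, 3, 5, 7])]   # toy lookup tracked by autograd

# Same pattern as detach_embeddings: detach from the producing graph, make the
# copy a leaf that requires grad, and keep its gradient around.
detached = embedding_out.detach().requires_grad_()
detached.retain_grad()

loss = detached.sum()   # stand-in for the dense forward + loss
loss.backward()

print(weight.grad)      # None: loss.backward() stopped at the detach boundary
print(detached.grad)    # all ones: held for a separate embedding backward pass

In the real code the detached tensors are stashed on the pipeline context (detached_embedding_tensors) so that separate embedding backward can be launched later.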

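And on KJTAllToAllForward: it precomputes a cumulative sum of the per-rank key splits and slices out the keys owned by the current rank. A minimal, torch-free illustration of that indexing (toy values, no ProcessGroup or KeyedJaggedTensor involved):

import itertools

splits = [2, 3, 1]  # number of keys assigned to ranks 0, 1, 2
splits_cumsum = [0] + list(itertools.accumulate(splits))  # [0, 2, 5, 6]
keys = ["f0", "f1", "f2", "f3", "f4", "f5"]  # len(keys) == sum(splits)

rank = 1
local_keys = keys[splits_cumsum[rank] : splits_cumsum[rank + 1]]
print(local_keys)  # ['f2', 'f3', 'f4']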
torchrec/distributed/train_pipeline/tests/test_train_pipelines.py

Lines changed: 4 additions & 2 deletions
@@ -49,6 +49,10 @@
 )
 from torchrec.distributed.train_pipeline.pipeline_context import TrainPipelineContext
 from torchrec.distributed.train_pipeline.postproc import PipelinedPostproc
+from torchrec.distributed.train_pipeline.runtime_forwards import (
+    EmbeddingPipelinedForward,
+    PipelinedForward,
+)
 from torchrec.distributed.train_pipeline.tests.test_train_pipelines_base import (
     TrainPipelineSparseDistTestBase,
 )
@@ -70,9 +74,7 @@
 )
 from torchrec.distributed.train_pipeline.utils import (
     DataLoadingThread,
-    EmbeddingPipelinedForward,
     get_h2d_func,
-    PipelinedForward,
     PipelineStage,
     SparseDataDistUtil,
     StageOut,

torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py

Lines changed: 2 additions & 1 deletion
@@ -19,6 +19,7 @@
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
 from torchrec.distributed.test_utils.test_model import ModelInput, TestNegSamplingModule
 from torchrec.distributed.train_pipeline.pipeline_context import TrainPipelineContext
+from torchrec.distributed.train_pipeline.runtime_forwards import PipelinedForward

 from torchrec.distributed.train_pipeline.tests.test_train_pipelines_base import (
     TrainPipelineSparseDistTestBase,
@@ -30,7 +31,7 @@
     NodeArgsHelper,
     PipelinedPostproc,
 )
-from torchrec.distributed.train_pipeline.utils import _rewrite_model, PipelinedForward
+from torchrec.distributed.train_pipeline.utils import _rewrite_model
 from torchrec.distributed.types import ShardingType
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 6 additions & 4 deletions
@@ -40,6 +40,12 @@
     PrefetchTrainPipelineContext,
     TrainPipelineContext,
 )
+from torchrec.distributed.train_pipeline.runtime_forwards import (
+    EmbeddingPipelinedForward,
+    InSyncEmbeddingPipelinedForward,
+    PipelinedForward,
+    PrefetchPipelinedForward,
+)
 from torchrec.distributed.train_pipeline.tracing import PipelinedPostproc
 from torchrec.distributed.train_pipeline.utils import (
     _override_input_dist_forwards,
@@ -52,11 +58,7 @@
     _wait_for_batch,
     _wait_for_events,
     DataLoadingThread,
-    EmbeddingPipelinedForward,
-    InSyncEmbeddingPipelinedForward,
-    PipelinedForward,
     PipelineStage,
-    PrefetchPipelinedForward,
     RunnableType,
     StageOut,
     StageOutputWithEvent,
