Commit da3104a

che-sh authored and facebook-github-bot committed
{Tests-WIP}[Torchrec] Add context manager to use next batch context for postprocs
Summary: Small refactor to reduce code repetition when setting the pipelined postprocs' context to the next batch's context and reverting it afterwards.

Differential Revision: D73824600
1 parent 83d87b1 commit da3104a
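
For context, the pattern being consolidated is: save each pipelined postproc's current context, point it at the next batch's context while work is kicked off, then restore the original context for model forward. Below is a minimal, self-contained sketch of that pattern; the `_FakePostproc` class and the helper are illustrative stand-ins, not the TorchRec API (the real helper and call sites are in the diff that follows):

import contextlib
from typing import Any, Generator, List


class _FakePostproc:
    # Toy stand-in for a pipelined postproc module that carries a per-batch context.
    def __init__(self, context: Any) -> None:
        self._context = context

    def get_context(self) -> Any:
        return self._context

    def set_context(self, context: Any) -> None:
        self._context = context


@contextlib.contextmanager
def _use_context_for_postprocs(
    postprocs: List[_FakePostproc], next_batch_context: Any
) -> Generator[None, None, None]:
    # Save the current contexts so model forward still sees them afterwards.
    originals = [p.get_context() for p in postprocs]
    # Temporarily switch every postproc to the next batch's context.
    for p in postprocs:
        p.set_context(next_batch_context)
    yield
    # Restore the original per-postproc contexts.
    for p, ctx in zip(postprocs, originals):
        p.set_context(ctx)


postprocs = [_FakePostproc("batch_0_ctx"), _FakePostproc("batch_0_ctx")]
with _use_context_for_postprocs(postprocs, "batch_1_ctx"):
    # Work that should see the next batch's context goes here.
    assert all(p.get_context() == "batch_1_ctx" for p in postprocs)
assert all(p.get_context() == "batch_0_ctx" for p in postprocs)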

2 files changed: +37 −28 lines


torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 12 additions & 28 deletions
@@ -58,6 +58,7 @@
     StageOut,
     StageOutputWithEvent,
     TrainPipelineContext,
+    use_context_for_postprocs,
 )
 from torchrec.distributed.types import Awaitable
 from torchrec.pt2.checks import is_torchdynamo_compiling
@@ -719,19 +720,9 @@ def start_sparse_data_dist(
             with self._stream_context(self._data_dist_stream):
                 _wait_for_batch(batch, self._memcpy_stream)
 
-                original_contexts = [p.get_context() for p in self._pipelined_postprocs]
-
                 # Temporarily set context for next iter to populate cache
-                for postproc_mod in self._pipelined_postprocs:
-                    postproc_mod.set_context(context)
-
-                _start_data_dist(self._pipelined_modules, batch, context)
-
-                # Restore context for model fwd
-                for module, context in zip(
-                    self._pipelined_postprocs, original_contexts
-                ):
-                    module.set_context(context)
+                with use_context_for_postprocs(self._pipelined_postprocs, context):
+                    _start_data_dist(self._pipelined_modules, batch, context)
 
     def wait_sparse_data_dist(self, context: TrainPipelineContext) -> None:
         """
@@ -1235,22 +1226,15 @@ def start_sparse_data_dist(
             return
 
         # Temporarily set context for next iter to populate cache
-        original_contexts = [p.get_context() for p in self._pipelined_postprocs]
-        for postproc_mod in self._pipelined_postprocs:
-            postproc_mod.set_context(context)
-
-        with record_function(f"## start_sparse_data_dist {context.index} ##"):
-            with self._stream_context(self._data_dist_stream):
-                _wait_for_events(batch, context, self._data_dist_stream)
-                model_input = self.extract_model_input_from_batch(batch)
-                _start_data_dist(self._pipelined_modules, model_input, context)
-                event = torch.get_device_module(self._device).Event()
-                event.record()
-                context.events.append(event)
-
-        # Restore context for model forward
-        for module, context in zip(self._pipelined_postprocs, original_contexts):
-            module.set_context(context)
+        with use_context_for_postprocs(self._pipelined_postprocs, context):
+            with record_function(f"## start_sparse_data_dist {context.index} ##"):
+                with self._stream_context(self._data_dist_stream):
+                    _wait_for_events(batch, context, self._data_dist_stream)
+                    model_input = self.extract_model_input_from_batch(batch)
+                    _start_data_dist(self._pipelined_modules, model_input, context)
+                    event = torch.get_device_module(self._device).Event()
+                    event.record()
+                    context.events.append(event)
 
     def start_embedding_lookup(
         self,

torchrec/distributed/train_pipeline/utils.py

Lines changed: 25 additions & 0 deletions
@@ -7,6 +7,7 @@
 
 # pyre-strict
 
+import contextlib
 import copy
 import itertools
 import logging
@@ -21,6 +22,7 @@
     Callable,
     cast,
     Dict,
+    Generator,
     Generic,
     Iterable,
     Iterator,
@@ -1791,6 +1793,28 @@ def _prefetch_embeddings(
     return data_per_sharded_module
 
 
+@contextlib.contextmanager
+def use_context_for_postprocs(
+    pipelined_postprocs: List[PipelinedPostproc],
+    next_batch_context: TrainPipelineContext,
+) -> Generator[None, None, None]:
+    """
+    Temporarily set pipelined postproc context for next iter to populate cache.
+    """
+    # Save original context for model fwd
+    original_contexts = [p.get_context() for p in pipelined_postprocs]
+
+    # Temporarily set context for next iter to populate cache
+    for postproc_mod in pipelined_postprocs:
+        postproc_mod.set_context(next_batch_context)
+
+    yield
+
+    # Restore context for model fwd
+    for module, context in zip(pipelined_postprocs, original_contexts):
+        module.set_context(context)
+
+
 class SparseDataDistUtil(Generic[In]):
     """
     Helper class exposing methods for sparse data dist and prefetch pipelining.
@@ -1802,6 +1826,7 @@ class SparseDataDistUtil(Generic[In]):
         apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
         prefetch_stream (Optional[torch.cuda.Stream]): Stream on which model prefetch runs
             Defaults to `None`. This needs to be passed in to enable prefetch pipelining.
+        pipeline_postproc (bool): whether to pipeline postproc modules. Defaults to `False`.
 
     Example::
         sdd = SparseDataDistUtil(
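
Given the {Tests-WIP} tag in the title, a unit test for the new helper might look roughly like the sketch below. Since `use_context_for_postprocs` only calls `get_context`/`set_context` on each postproc, mocks can stand in for `PipelinedPostproc` and for the contexts; the import path is taken from the diff above, while the test class and method names are hypothetical:

import unittest
from unittest.mock import MagicMock

from torchrec.distributed.train_pipeline.utils import use_context_for_postprocs


class UseContextForPostprocsTest(unittest.TestCase):
    def test_sets_and_restores_postproc_context(self) -> None:
        old_ctx = MagicMock(name="current_batch_ctx")
        next_ctx = MagicMock(name="next_batch_ctx")
        postprocs = [MagicMock(name=f"postproc_{i}") for i in range(2)]
        for p in postprocs:
            p.get_context.return_value = old_ctx

        with use_context_for_postprocs(postprocs, next_ctx):
            # Inside the block, every postproc was pointed at the next batch's context.
            for p in postprocs:
                p.set_context.assert_called_with(next_ctx)

        # After the block, the original context is restored for model forward.
        for p in postprocs:
            p.set_context.assert_called_with(old_ctx)


if __name__ == "__main__":
    unittest.main()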
