|
58 | 58 | StageOut, |
59 | 59 | StageOutputWithEvent, |
60 | 60 | TrainPipelineContext, |
| 61 | + use_context_for_postprocs, |
61 | 62 | ) |
62 | 63 | from torchrec.distributed.types import Awaitable |
63 | 64 | from torchrec.pt2.checks import is_torchdynamo_compiling |
@@ -792,19 +793,9 @@ def start_sparse_data_dist( |
792 | 793 | with self._stream_context(self._data_dist_stream): |
793 | 794 | _wait_for_batch(batch, self._memcpy_stream) |
794 | 795 |
|
795 | | - original_contexts = [p.get_context() for p in self._pipelined_postprocs] |
796 | | - |
797 | 796 | # Temporarily set context for next iter to populate cache |
798 | | - for postproc_mod in self._pipelined_postprocs: |
799 | | - postproc_mod.set_context(context) |
800 | | - |
801 | | - _start_data_dist(self._pipelined_modules, batch, context) |
802 | | - |
803 | | - # Restore context for model fwd |
804 | | - for module, context in zip( |
805 | | - self._pipelined_postprocs, original_contexts |
806 | | - ): |
807 | | - module.set_context(context) |
| 797 | + with use_context_for_postprocs(self._pipelined_postprocs, context): |
| 798 | + _start_data_dist(self._pipelined_modules, batch, context) |
808 | 799 |
|
809 | 800 | def wait_sparse_data_dist(self, context: TrainPipelineContext) -> None: |
810 | 801 | """ |
@@ -1325,22 +1316,15 @@ def start_sparse_data_dist( |
1325 | 1316 | return |
1326 | 1317 |
|
1327 | 1318 | # Temporarily set context for next iter to populate cache |
1328 | | - original_contexts = [p.get_context() for p in self._pipelined_postprocs] |
1329 | | - for postproc_mod in self._pipelined_postprocs: |
1330 | | - postproc_mod.set_context(context) |
1331 | | - |
1332 | | - with record_function(f"## start_sparse_data_dist {context.index} ##"): |
1333 | | - with self._stream_context(self._data_dist_stream): |
1334 | | - _wait_for_events(batch, context, self._data_dist_stream) |
1335 | | - model_input = self.extract_model_input_from_batch(batch) |
1336 | | - _start_data_dist(self._pipelined_modules, model_input, context) |
1337 | | - event = torch.get_device_module(self._device).Event() |
1338 | | - event.record() |
1339 | | - context.events.append(event) |
1340 | | - |
1341 | | - # Restore context for model forward |
1342 | | - for module, context in zip(self._pipelined_postprocs, original_contexts): |
1343 | | - module.set_context(context) |
| 1319 | + with use_context_for_postprocs(self._pipelined_postprocs, context): |
| 1320 | + with record_function(f"## start_sparse_data_dist {context.index} ##"): |
| 1321 | + with self._stream_context(self._data_dist_stream): |
| 1322 | + _wait_for_events(batch, context, self._data_dist_stream) |
| 1323 | + model_input = self.extract_model_input_from_batch(batch) |
| 1324 | + _start_data_dist(self._pipelined_modules, model_input, context) |
| 1325 | + event = torch.get_device_module(self._device).Event() |
| 1326 | + event.record() |
| 1327 | + context.events.append(event) |
1344 | 1328 |
|
1345 | 1329 | def start_embedding_lookup( |
1346 | 1330 | self, |
|
0 commit comments