
Commit 78f73b6

kddnewton authored and facebook-github-bot committed
Allow a context manager to be called around apply_jit
Summary: When running torch.jit.script over the various forward functions, you can run into issues if any other utilities are interacting with the function definitions. For example, if another JIT is running, it needs to be disabled while scripting takes place. This commit adds the ability to pass an apply_jit_context context manager wherever apply_jit is currently accepted; the context manager is entered around the application of the torch JIT.

Differential Revision: D73781040
1 parent 97f8dea commit 78f73b6

3 files changed: +59, -7 lines
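To make the intent concrete, here is a minimal, self-contained sketch (not part of the commit) of the kind of context manager the summary describes: something that pauses another utility that would interfere with torch.jit.script and restores it afterwards. The OtherJit class and the pause_other_jit helper are purely illustrative stand-ins, not torchrec or PyTorch APIs.

from contextlib import contextmanager
from typing import Generator


class OtherJit:
    """Illustrative stand-in for some other JIT-like utility that must be paused."""

    def __init__(self) -> None:
        self.enabled = True

    def disable(self) -> None:
        self.enabled = False

    def enable(self) -> None:
        self.enabled = True


other_jit = OtherJit()


@contextmanager
def pause_other_jit() -> Generator[None, None, None]:
    # Entered just before torch.jit.script is applied to the unsharded modules.
    other_jit.disable()
    try:
        yield
    finally:
        # Always re-enable, even if scripting raises.
        other_jit.enable()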

torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py

Lines changed: 25 additions & 1 deletion

@@ -10,7 +10,8 @@
 import copy
 import enum
 import unittest
-from typing import List
+from contextlib import contextmanager
+from typing import Generator, List
 from unittest.mock import MagicMock
 
 import torch
@@ -42,6 +43,29 @@ class ModelType(enum.Enum):
 
 
 class TrainPipelineUtilsTest(TrainPipelineSparseDistTestBase):
+    # pyre-fixme[56]: Pyre was not able to infer the type of argument
+    @unittest.skipIf(
+        not torch.cuda.is_available(),
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    def test_rewrite_model_apply_jit(self) -> None:
+        @contextmanager
+        def apply_jit_context(events: list[str]) -> Generator[None, None, None]:
+            events.append("__enter__")
+            yield
+            events.append("__exit__")
+
+        events = []
+        _rewrite_model(
+            model=self._setup_model(),
+            context=TrainPipelineContext(),
+            dist_stream=None,
+            apply_jit=True,
+            apply_jit_context=apply_jit_context(events),
+        )
+
+        self.assertEqual(events, ["__enter__", "__exit__"])
+
     # pyre-fixme[56]: Pyre was not able to infer the type of argument
     @unittest.skipIf(
         not torch.cuda.is_available(),

torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 23 additions & 1 deletion

@@ -11,6 +11,7 @@
 import contextlib
 import logging
 from collections import deque
+from contextlib import nullcontext
 from dataclasses import dataclass
 from typing import (
     Any,
@@ -318,6 +319,9 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         return output
 
 
+_apply_jit_context_default: ContextManager[None] = nullcontext()
+
+
 class TrainPipelineSparseDist(TrainPipeline[In, Out]):
     """
     This pipeline overlaps device transfer, and `ShardedModule.input_dist()` with
@@ -343,6 +347,8 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
         execute_all_batches (bool): executes remaining batches in pipeline after
             exhausting dataloader iterator.
         apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
+        apply_jit_context (ContextManager): a context manager that will surround the
+            application of the JIT
     """
 
     # The PipelinedForward class that is used in _rewrite_model
@@ -355,6 +361,7 @@ def __init__(
         device: torch.device,
         execute_all_batches: bool = True,
         apply_jit: bool = False,
+        apply_jit_context: ContextManager[None] = _apply_jit_context_default,
         context_type: Type[TrainPipelineContext] = TrainPipelineContext,
         # keep for backward compatibility
         pipeline_postproc: bool = False,
@@ -367,6 +374,7 @@ def __init__(
         self._device = device
         self._execute_all_batches = execute_all_batches
         self._apply_jit = apply_jit
+        self._apply_jit_context = apply_jit_context
 
         if device.type == "cuda":
             # use two data streams to support two concurrent batches
@@ -641,6 +649,7 @@ def _pipeline_model(
             default_stream=torch.get_device_module(self._device).current_stream(),
             batch=batch,
             apply_jit=self._apply_jit,
+            apply_jit_context=self._apply_jit_context,
             pipelined_forward=pipelined_forward,
             pipeline_postproc=self._pipeline_postproc,
         )
@@ -820,6 +829,8 @@ class TrainPipelineSemiSync(TrainPipelineSparseDist[In, Out]):
         execute_all_batches (bool): executes remaining batches in pipeline after
             exhausting dataloader iterator.
         apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
+        apply_jit_context (ContextManager): a context manager that will surround the
+            application of the JIT
         start_batch (int): batch to begin semi-sync training. Typically small period of synchronous training reduces early stage NEX.
         stash_gradients (bool): if True, will store gradients for each parameter to insure true "Semi-Sync"
             training. If False, will update dense optimizer as soon as gradients available (naive "Semi-Sync)
@@ -835,6 +846,7 @@ def __init__(
         device: torch.device,
         execute_all_batches: bool = True,
         apply_jit: bool = False,
+        apply_jit_context: ContextManager[None] = _apply_jit_context_default,
         start_batch: int = 900,
         stash_gradients: bool = False,
         pipeline_postproc: bool = True,
@@ -849,6 +861,7 @@ def __init__(
             device=device,
             execute_all_batches=execute_all_batches,
             apply_jit=apply_jit,
+            apply_jit_context=apply_jit_context,
             context_type=EmbeddingTrainPipelineContext,
             pipeline_postproc=pipeline_postproc,
             custom_model_fwd=custom_model_fwd,
@@ -1135,6 +1148,8 @@ class PrefetchTrainPipelineSparseDist(TrainPipelineSparseDist[In, Out]):
         execute_all_batches (bool): executes remaining batches in pipeline after
             exhausting dataloader iterator.
         apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
+        apply_jit_context (ContextManager): a context manager that will surround the
+            application of the JIT
     """
 
     # The PipelinedForward class that is used in _rewrite_model
@@ -1147,6 +1162,7 @@ def __init__(
         device: torch.device,
         execute_all_batches: bool = True,
         apply_jit: bool = False,
+        apply_jit_context: ContextManager[None] = _apply_jit_context_default,
         pipeline_postproc: bool = True,
         custom_model_fwd: Optional[
             Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
@@ -1158,6 +1174,7 @@ def __init__(
             device=device,
             execute_all_batches=execute_all_batches,
             apply_jit=apply_jit,
+            apply_jit_context=apply_jit_context,
             context_type=PrefetchTrainPipelineContext,
             pipeline_postproc=pipeline_postproc,
             custom_model_fwd=custom_model_fwd,
@@ -1292,6 +1309,8 @@ class EvalPipelineSparseDist(TrainPipelineSparseDist[In, Out]):
         device (torch.device): device where device transfer, sparse data dist, and
             forward/backward pass will happen.
         apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
+        apply_jit_context (ContextManager): a context manager that will surround the
+            application of the JIT
     """
 
     # The PipelinedForward class that is used in _rewrite_model
@@ -1303,8 +1322,9 @@ def __init__(
         optimizer: torch.optim.Optimizer,
         device: torch.device,
         apply_jit: bool = False,
+        apply_jit_context: ContextManager[None] = _apply_jit_context_default,
     ) -> None:
-        super().__init__(model, optimizer, device, True, apply_jit)
+        super().__init__(model, optimizer, device, True, apply_jit, apply_jit_context)
         self._batch_loader: Optional[DataLoadingThread[In]] = None
 
     def __del__(self) -> None:
@@ -1661,6 +1681,7 @@ def __init__(
         device: torch.device,
         execute_all_batches: bool = True,
         apply_jit: bool = False,
+        apply_jit_context: ContextManager[None] = _apply_jit_context_default,
         context_type: Type[TrainPipelineContext] = TrainPipelineContext,
         pipeline_postproc: bool = False,
         custom_model_fwd: Optional[
@@ -1673,6 +1694,7 @@ def __init__(
             device,
             execute_all_batches,
             apply_jit,
+            apply_jit_context,
             context_type,
             pipeline_postproc,
             custom_model_fwd,
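
A hedged usage sketch of the new keyword on TrainPipelineSparseDist, assuming model, optimizer, and device have already been set up as for any sharded train pipeline, and that pause_other_jit is the kind of context manager sketched near the top of this page; only apply_jit_context is new in this diff.

from torchrec.distributed.train_pipeline.train_pipelines import TrainPipelineSparseDist

pipeline = TrainPipelineSparseDist(
    model=model,
    optimizer=optimizer,
    device=device,
    apply_jit=True,
    # Entered around the torch.jit.script pass over non-pipelined modules;
    # defaults to a no-op nullcontext() when omitted.
    apply_jit_context=pause_other_jit(),
)

Because the parameter is plumbed through TrainPipelineSemiSync, PrefetchTrainPipelineSparseDist, and EvalPipelineSparseDist as well, the same keyword works on those constructors.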

torchrec/distributed/train_pipeline/utils.py

Lines changed: 11 additions & 5 deletions

@@ -11,7 +11,7 @@
 import itertools
 import logging
 from collections import defaultdict, OrderedDict
-from contextlib import AbstractContextManager
+from contextlib import AbstractContextManager, nullcontext
 from dataclasses import dataclass, field
 
 from itertools import chain
@@ -20,6 +20,7 @@
     Any,
     Callable,
     cast,
+    ContextManager,
     Dict,
     Generic,
     Iterable,
@@ -1454,13 +1455,17 @@ def _pipeline_detach_model(
         setattr(model, postproc_mod.fqn, postproc_mod.postproc_module)
 
 
+_rewrite_model_apply_jit_context_default: ContextManager[None] = nullcontext()
+
+
 # pyre-ignore[3]
 def _rewrite_model(  # noqa C901
     model: torch.nn.Module,
     context: TForwardContext,
     dist_stream: Optional[torch.Stream],
     batch: Optional[In] = None,
     apply_jit: bool = False,
+    apply_jit_context: ContextManager[None] = _rewrite_model_apply_jit_context_default,
     pipelined_forward: Type[BaseForward[TrainPipelineContext]] = PipelinedForward,
     pipeline_postproc: bool = False,
     default_stream: Optional[torch.Stream] = None,
@@ -1546,10 +1551,11 @@ def _rewrite_model(  # noqa C901
 
     # JIT script unsharded modules if applicable.
     if apply_jit:
-        graph_model = torch.fx.GraphModule(model, graph)
-        _jit_modules(graph_model, "")
-        if isinstance(input_model, DistributedModelParallel):
-            input_model.module = graph_model
+        with apply_jit_context:
+            graph_model = torch.fx.GraphModule(model, graph)
+            _jit_modules(graph_model, "")
+            if isinstance(input_model, DistributedModelParallel):
+                input_model.module = graph_model
 
     if non_pipelined_sharded_modules:
         logger.warn(
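
One note on the defaults added here: contextlib.nullcontext() is a reusable, no-op context manager, so a single module-level instance such as _rewrite_model_apply_jit_context_default can safely serve as the shared default, and callers that never pass apply_jit_context keep exactly the previous behavior. A tiny illustration of that no-op default (run_scripting is a hypothetical stand-in, not a torchrec function):

from contextlib import nullcontext
from typing import ContextManager


def run_scripting(apply_jit_context: ContextManager[None] = nullcontext()) -> None:
    # Stand-in for the GraphModule construction and _jit_modules call that
    # _rewrite_model now wraps; with the default nullcontext() this runs as before.
    with apply_jit_context:
        print("scripting would happen here")


run_scripting()  # no context manager supplied: behaves exactly as before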
