
Commit 604ddc9

kddnewton authored and facebook-github-bot committed
Allow a context manager to be called around apply_jit (pytorch#2927)
Summary: When running torch.jit.script on the various forward functions, you can run into issues if any other utilities are interacting with the function definitions. For example, if another JIT is running, you need to disable it for the duration of this process. This commit adds the ability to pass an apply_jit_context context manager wherever apply_jit is currently passed; it will be entered around the application of the torch JIT.

Reviewed By: SonicField

Differential Revision: D73781040
1 parent 57deb6e commit 604ddc9
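As a usage sketch only (not code from this commit): a caller with a competing JIT could wrap its enable/disable logic in a context manager and hand it to the pipeline alongside apply_jit. The pause_other_jit name and the commented-out constructor call are illustrative; only the apply_jit_context keyword comes from the diff below.

from contextlib import contextmanager
from typing import Generator


@contextmanager
def pause_other_jit() -> Generator[None, None, None]:
    # Placeholder for whatever "disable the other JIT" means in a given setup;
    # the real mechanism depends on which JIT is interfering with scripting.
    print("other JIT paused")
    try:
        yield
    finally:
        print("other JIT resumed")


# Hypothetical construction; model/optimizer/device setup is omitted here.
# pipeline = TrainPipelineSparseDist(
#     model=model,
#     optimizer=optimizer,
#     device=device,
#     apply_jit=True,
#     apply_jit_context=pause_other_jit(),
# )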

File tree

3 files changed: +66, -7 lines changed

torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py

Lines changed: 25 additions & 1 deletion
@@ -10,7 +10,8 @@
 import copy
 import enum
 import unittest
-from typing import List
+from contextlib import contextmanager
+from typing import Generator, List
 from unittest.mock import MagicMock

 import torch
@@ -43,6 +44,29 @@ class ModelType(enum.Enum):


 class TrainPipelineUtilsTest(TrainPipelineSparseDistTestBase):
+    # pyre-fixme[56]: Pyre was not able to infer the type of argument
+    @unittest.skipIf(
+        not torch.cuda.is_available(),
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    def test_rewrite_model_apply_jit(self) -> None:
+        @contextmanager
+        def apply_jit_context(events: list[str]) -> Generator[None, None, None]:
+            events.append("__enter__")
+            yield
+            events.append("__exit__")
+
+        events = []
+        _rewrite_model(
+            model=self._setup_model(),
+            context=TrainPipelineContext(),
+            dist_stream=None,
+            apply_jit=True,
+            apply_jit_context=apply_jit_context(events),
+        )
+
+        self.assertEqual(events, ["__enter__", "__exit__"])
+
     # pyre-fixme[56]: Pyre was not able to infer the type of argument
     @unittest.skipIf(
         not torch.cuda.is_available(),
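The test above needs torchrec internals and a GPU; the same enter/exit bookkeeping can be sketched standalone as follows (apply_jit_stub is a stand-in for the scripting step inside _rewrite_model, not part of the commit).

from contextlib import contextmanager
from typing import ContextManager, Generator, List


@contextmanager
def recording_context(events: List[str]) -> Generator[None, None, None]:
    # Record entry and exit, mirroring the apply_jit_context used in the test above.
    events.append("__enter__")
    try:
        yield
    finally:
        events.append("__exit__")


def apply_jit_stub(ctx: ContextManager[None]) -> None:
    # Stand-in for the JIT application step that _rewrite_model wraps.
    with ctx:
        pass  # torch.jit.script(...) would run here


events: List[str] = []
apply_jit_stub(recording_context(events))
assert events == ["__enter__", "__exit__"]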

torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 30 additions & 1 deletion
@@ -393,6 +393,8 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
             (applicable to 2D sharding only)
             if set and DMP collection is enabled for 2D sharding,
             sync DMPs every N batches (default to 1, i.e. every batch, None to disable)
+        apply_jit_context (Optional[ContextManager]): a context manager that
+            will surround the application of the JIT
     """

     # The PipelinedForward class that is used in _rewrite_model
@@ -413,13 +415,15 @@ def __init__(
         ] = None,
         dmp_collection_sync_interval_batches: Optional[int] = 1,
         enqueue_batch_after_forward: bool = False,
+        apply_jit_context: Optional[ContextManager[None]] = None,
     ) -> None:
         self._model = model
         self._optimizer = optimizer
         self._device = device
         self._execute_all_batches = execute_all_batches
         self._apply_jit = apply_jit
         self._enqueue_batch_after_forward = enqueue_batch_after_forward
+        self._apply_jit_context = apply_jit_context

         if device.type == "cuda":
             # use two data streams to support two concurrent batches
@@ -716,6 +720,7 @@ def _pipeline_model(
             apply_jit=self._apply_jit,
             pipelined_forward=pipelined_forward,
             pipeline_postproc=self._pipeline_postproc,
+            apply_jit_context=self._apply_jit_context,
         )
         # initializes input dist, so we can override input dist forwards
         self.start_sparse_data_dist(batch, context)
@@ -904,6 +909,8 @@ class TrainPipelineFusedSparseDist(TrainPipelineSparseDist[In, Out]):
         TODO: pipeline_postproc, custom_model_fwd, strict
         use_emb_lookuo_stream (bool): if true invoke the compute_and_output_dist
             (for batch i+1) using a new stream, else re-using the data_dist stream
+        apply_jit_context (ContextManager): a context manager that will surround the
+            application of the JIT
     """

     # The PipelinedForward class that is used in _rewrite_model
@@ -922,6 +929,7 @@ def __init__(
         ] = None,
         strict: bool = False,
         emb_lookup_stream: str = "data_dist",  # new, current, data_dist (default)
+        apply_jit_context: Optional[ContextManager[None]] = None,
     ) -> None:
         super().__init__(
             model=model,
@@ -932,6 +940,7 @@ def __init__(
             context_type=EmbeddingTrainPipelineContext,
             pipeline_postproc=pipeline_postproc,
             custom_model_fwd=custom_model_fwd,
+            apply_jit_context=apply_jit_context,
         )
         if emb_lookup_stream == "new":
             self._emb_lookup_stream: Optional[torch.Stream] = (
@@ -1066,6 +1075,8 @@ class TrainPipelineSemiSync(TrainPipelineSparseDist[In, Out]):
             (applicable to 2D sharding only)
             if set and DMP collection is enabled for 2D sharding,
             sync DMPs every N batches (default to 1, i.e. every batch, None to disable)
+        apply_jit_context (ContextManager): a context manager that will surround the
+            application of the JIT
     """

     # The PipelinedForward class that is used in _rewrite_model
@@ -1086,6 +1097,7 @@ def __init__(
         ] = None,
         strict: bool = False,
         dmp_collection_sync_interval_batches: Optional[int] = 1,
+        apply_jit_context: Optional[ContextManager[None]] = None,
     ) -> None:
         super().__init__(
             model=model,
@@ -1097,6 +1109,7 @@ def __init__(
             pipeline_postproc=pipeline_postproc,
             custom_model_fwd=custom_model_fwd,
             dmp_collection_sync_interval_batches=dmp_collection_sync_interval_batches,
+            apply_jit_context=apply_jit_context,
         )
         self._start_batch = start_batch
         self._stash_gradients = stash_gradients
@@ -1378,6 +1391,8 @@ class PrefetchTrainPipelineSparseDist(TrainPipelineSparseDist[In, Out]):
         execute_all_batches (bool): executes remaining batches in pipeline after
             exhausting dataloader iterator.
         apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
+        apply_jit_context (ContextManager): a context manager that will surround the
+            application of the JIT
     """

     # The PipelinedForward class that is used in _rewrite_model
@@ -1394,6 +1409,7 @@ def __init__(
         custom_model_fwd: Optional[
             Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
         ] = None,
+        apply_jit_context: Optional[ContextManager[None]] = None,
     ) -> None:
         super().__init__(
             model=model,
@@ -1404,6 +1420,7 @@ def __init__(
             context_type=PrefetchTrainPipelineContext,
             pipeline_postproc=pipeline_postproc,
             custom_model_fwd=custom_model_fwd,
+            apply_jit_context=apply_jit_context,
         )
         self._context = PrefetchTrainPipelineContext(version=0)
         self._prefetch_stream: Optional[torch.Stream] = (
@@ -1535,6 +1552,8 @@ class EvalPipelineSparseDist(TrainPipelineSparseDist[In, Out]):
         device (torch.device): device where device transfer, sparse data dist, and
             forward/backward pass will happen.
         apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
+        apply_jit_context (Optional[ContextManager]): a context manager that
+            will surround the application of the JIT
     """

     # The PipelinedForward class that is used in _rewrite_model
@@ -1546,8 +1565,16 @@ def __init__(
         optimizer: torch.optim.Optimizer,
         device: torch.device,
         apply_jit: bool = False,
+        apply_jit_context: Optional[ContextManager[None]] = None,
     ) -> None:
-        super().__init__(model, optimizer, device, True, apply_jit)
+        super().__init__(
+            model,
+            optimizer,
+            device,
+            True,
+            apply_jit,
+            apply_jit_context=apply_jit_context,
+        )
         self._batch_loader: Optional[DataLoadingThread[In]] = None

     def __del__(self) -> None:
@@ -1909,6 +1936,7 @@ def __init__(
         custom_model_fwd: Optional[
             Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
         ] = None,
+        apply_jit_context: Optional[ContextManager[None]] = None,
     ) -> None:
         super().__init__(
             model,
@@ -1919,6 +1947,7 @@ def __init__(
             context_type,
             pipeline_postproc,
             custom_model_fwd,
+            apply_jit_context=apply_jit_context,
         )

         torch._logging.set_logs(compiled_autograd_verbose=True)
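All of the hunks above follow one pattern: each pipeline constructor accepts apply_jit_context and forwards it unchanged until TrainPipelineSparseDist stores it for _pipeline_model. A minimal sketch of that pass-through, using placeholder class names rather than torchrec's:

from typing import ContextManager, Optional


class BasePipeline:
    def __init__(
        self,
        apply_jit: bool = False,
        apply_jit_context: Optional[ContextManager[None]] = None,
    ) -> None:
        # The base pipeline stores the context manager next to the apply_jit flag
        # and later hands both to the model-rewriting step.
        self._apply_jit = apply_jit
        self._apply_jit_context = apply_jit_context


class DerivedPipeline(BasePipeline):
    def __init__(
        self,
        apply_jit: bool = False,
        apply_jit_context: Optional[ContextManager[None]] = None,
    ) -> None:
        # Subclasses only forward the new keyword; they never enter it themselves.
        super().__init__(apply_jit=apply_jit, apply_jit_context=apply_jit_context)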

torchrec/distributed/train_pipeline/utils.py

Lines changed: 11 additions & 5 deletions
@@ -13,7 +13,7 @@
 import itertools
 import logging
 from collections import defaultdict, OrderedDict
-from contextlib import AbstractContextManager
+from contextlib import AbstractContextManager, nullcontext
 from dataclasses import dataclass, field

 from itertools import chain
@@ -22,6 +22,7 @@
     Any,
     Callable,
     cast,
+    ContextManager,
     Dict,
     Generator,
     Generic,
@@ -1540,6 +1541,7 @@ def _rewrite_model( # noqa C901
     pipelined_forward: Type[BaseForward[TrainPipelineContext]] = PipelinedForward,
     pipeline_postproc: bool = False,
     default_stream: Optional[torch.Stream] = None,
+    apply_jit_context: Optional[ContextManager[None]] = None,
 ) -> Tuple[
     List[ShardedModule],
     torch.nn.Module,
@@ -1643,10 +1645,14 @@ def _rewrite_model( # noqa C901

     # JIT script unsharded modules if applicable.
     if apply_jit:
-        graph_model = torch.fx.GraphModule(model, graph)
-        _jit_modules(graph_model, "")
-        if isinstance(input_model, DistributedModelParallel):
-            input_model.module = graph_model
+        if apply_jit_context is None:
+            apply_jit_context = nullcontext()
+
+        with apply_jit_context:
+            graph_model = torch.fx.GraphModule(model, graph)
+            _jit_modules(graph_model, "")
+            if isinstance(input_model, DistributedModelParallel):
+                input_model.module = graph_model

     if non_pipelined_sharded_modules:
         logger.warn(
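The nullcontext fallback keeps behavior identical for callers that pass nothing. A standalone sketch of the pattern, with script_modules standing in for the GraphModule construction and torch.jit.script calls:

from contextlib import nullcontext
from typing import ContextManager, Optional


def script_modules() -> None:
    # Placeholder for the GraphModule construction and scripting done in _rewrite_model.
    print("scripting unsharded modules")


def apply_jit_with_context(
    apply_jit_context: Optional[ContextManager[None]] = None,
) -> None:
    # Fall back to a no-op context manager so existing callers see no behavior change.
    if apply_jit_context is None:
        apply_jit_context = nullcontext()

    with apply_jit_context:
        script_modules()


apply_jit_with_context()  # runs the scripting step inside the no-op context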
