@@ -11,6 +11,7 @@
 import contextlib
 import logging
 from collections import deque
+from contextlib import nullcontext
 from dataclasses import dataclass
 from typing import (
     Any,
@@ -318,6 +319,9 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         return output


+_apply_jit_context_default: ContextManager[None] = nullcontext()
+
+
 class TrainPipelineSparseDist(TrainPipeline[In, Out]):
     """
     This pipeline overlaps device transfer, and `ShardedModule.input_dist()` with
@@ -343,6 +347,8 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
         execute_all_batches (bool): executes remaining batches in pipeline after
             exhausting dataloader iterator.
         apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
+        apply_jit_context (ContextManager): a context manager that will surround the
+            application of the JIT.
     """

     # The PipelinedForward class that is used in _rewrite_model
@@ -355,6 +361,7 @@ def __init__(
         device: torch.device,
         execute_all_batches: bool = True,
         apply_jit: bool = False,
+        apply_jit_context: ContextManager[None] = _apply_jit_context_default,
         context_type: Type[TrainPipelineContext] = TrainPipelineContext,
         # keep for backward compatibility
         pipeline_postproc: bool = False,
@@ -367,6 +374,7 @@ def __init__(
         self._device = device
         self._execute_all_batches = execute_all_batches
         self._apply_jit = apply_jit
+        self._apply_jit_context = apply_jit_context

         if device.type == "cuda":
             # use two data streams to support two concurrent batches
@@ -641,6 +649,7 @@ def _pipeline_model(
             default_stream=torch.get_device_module(self._device).current_stream(),
             batch=batch,
             apply_jit=self._apply_jit,
+            apply_jit_context=self._apply_jit_context,
             pipelined_forward=pipelined_forward,
             pipeline_postproc=self._pipeline_postproc,
         )
@@ -820,6 +829,8 @@ class TrainPipelineSemiSync(TrainPipelineSparseDist[In, Out]):
         execute_all_batches (bool): executes remaining batches in pipeline after
             exhausting dataloader iterator.
         apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
+        apply_jit_context (ContextManager): a context manager that will surround the
+            application of the JIT.
         start_batch (int): batch to begin semi-sync training. Typically a small period of synchronous training reduces early stage NEX.
         stash_gradients (bool): if True, will store gradients for each parameter to ensure true "Semi-Sync"
             training. If False, will update dense optimizer as soon as gradients available (naive "Semi-Sync").
@@ -835,6 +846,7 @@ def __init__(
         device: torch.device,
         execute_all_batches: bool = True,
         apply_jit: bool = False,
+        apply_jit_context: ContextManager[None] = _apply_jit_context_default,
         start_batch: int = 900,
         stash_gradients: bool = False,
         pipeline_postproc: bool = True,
@@ -849,6 +861,7 @@ def __init__(
             device=device,
             execute_all_batches=execute_all_batches,
             apply_jit=apply_jit,
+            apply_jit_context=apply_jit_context,
             context_type=EmbeddingTrainPipelineContext,
             pipeline_postproc=pipeline_postproc,
             custom_model_fwd=custom_model_fwd,
@@ -1135,6 +1148,8 @@ class PrefetchTrainPipelineSparseDist(TrainPipelineSparseDist[In, Out]):
         execute_all_batches (bool): executes remaining batches in pipeline after
             exhausting dataloader iterator.
         apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
+        apply_jit_context (ContextManager): a context manager that will surround the
+            application of the JIT.
     """

     # The PipelinedForward class that is used in _rewrite_model
@@ -1147,6 +1162,7 @@ def __init__(
         device: torch.device,
         execute_all_batches: bool = True,
         apply_jit: bool = False,
+        apply_jit_context: ContextManager[None] = _apply_jit_context_default,
         pipeline_postproc: bool = True,
         custom_model_fwd: Optional[
             Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
@@ -1158,6 +1174,7 @@ def __init__(
             device=device,
             execute_all_batches=execute_all_batches,
             apply_jit=apply_jit,
+            apply_jit_context=apply_jit_context,
             context_type=PrefetchTrainPipelineContext,
             pipeline_postproc=pipeline_postproc,
             custom_model_fwd=custom_model_fwd,
@@ -1292,6 +1309,8 @@ class EvalPipelineSparseDist(TrainPipelineSparseDist[In, Out]):
         device (torch.device): device where device transfer, sparse data dist, and
             forward/backward pass will happen.
         apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
+        apply_jit_context (ContextManager): a context manager that will surround the
+            application of the JIT.
     """

     # The PipelinedForward class that is used in _rewrite_model
@@ -1303,8 +1322,9 @@ def __init__(
         optimizer: torch.optim.Optimizer,
         device: torch.device,
         apply_jit: bool = False,
+        apply_jit_context: ContextManager[None] = _apply_jit_context_default,
     ) -> None:
-        super().__init__(model, optimizer, device, True, apply_jit)
+        super().__init__(model, optimizer, device, True, apply_jit, apply_jit_context)
         self._batch_loader: Optional[DataLoadingThread[In]] = None

     def __del__(self) -> None:
@@ -1661,6 +1681,7 @@ def __init__(
         device: torch.device,
         execute_all_batches: bool = True,
         apply_jit: bool = False,
+        apply_jit_context: ContextManager[None] = _apply_jit_context_default,
         context_type: Type[TrainPipelineContext] = TrainPipelineContext,
         pipeline_postproc: bool = False,
         custom_model_fwd: Optional[
@@ -1673,6 +1694,7 @@ def __init__(
             device,
             execute_all_batches,
             apply_jit,
+            apply_jit_context,
             context_type,
             pipeline_postproc,
             custom_model_fwd,
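
For reviewers, a minimal usage sketch of the new hook. This is illustrative only: `log_jit_scripting` and the construction at the bottom are hypothetical placeholders, not part of this diff; only the `apply_jit_context` parameter and its `nullcontext()` default come from the change above.

```python
# Hedged sketch of one way a caller might use apply_jit_context.
import contextlib
from typing import Iterator


@contextlib.contextmanager
def log_jit_scripting() -> Iterator[None]:
    # Hypothetical context manager: wraps the torch.jit.script pass that
    # the pipeline applies to non-pipelined (unsharded) modules when
    # apply_jit=True.
    print("jit scripting: start")
    try:
        yield
    finally:
        print("jit scripting: done")


# Stand-in for the scripting step the pipeline would wrap:
with log_jit_scripting():
    pass  # torch.jit.script(...) would run here inside the pipeline

# Hypothetical construction (model/optimizer/device assumed to exist):
# pipeline = TrainPipelineSparseDist(
#     model,
#     optimizer,
#     device,
#     apply_jit=True,
#     apply_jit_context=log_jit_scripting(),
# )
```

One design note: a `@contextlib.contextmanager`-created object is single-use, which is compatible with the JIT being applied once per pipeline, and the shared module-level `_apply_jit_context_default` is safe because `nullcontext()` is reusable.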