@@ -393,6 +393,8 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
             (applicable to 2D sharding only)
             if set and DMP collection is enabled for 2D sharding,
             sync DMPs every N batches (default to 1, i.e. every batch, None to disable)
+        apply_jit_context (Optional[ContextManager]): a context manager that
+            will surround the application of the JIT
     """

     # The PipelinedForward class that is used in _rewrite_model
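Any object satisfying `ContextManager[None]` can be supplied here. Below is a minimal sketch of a caller-side context manager that times the JIT pass; the name `jit_timing` and the logging are illustrative, not part of this change.

import logging
import time
from contextlib import contextmanager
from typing import Iterator

logger = logging.getLogger(__name__)

@contextmanager
def jit_timing() -> Iterator[None]:
    # Entered by the pipeline just before torch.jit.script is applied
    # to the non-pipelined modules, and exited right after.
    start = time.perf_counter()
    try:
        yield
    finally:
        logger.info("JIT scripting took %.2fs", time.perf_counter() - start)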
@@ -413,13 +415,15 @@ def __init__(
         ] = None,
         dmp_collection_sync_interval_batches: Optional[int] = 1,
         enqueue_batch_after_forward: bool = False,
+        apply_jit_context: Optional[ContextManager[None]] = None,
     ) -> None:
         self._model = model
         self._optimizer = optimizer
         self._device = device
         self._execute_all_batches = execute_all_batches
         self._apply_jit = apply_jit
         self._enqueue_batch_after_forward = enqueue_batch_after_forward
+        self._apply_jit_context = apply_jit_context
 
         if device.type == "cuda":
             # use two data streams to support two concurrent batches
@@ -716,6 +720,7 @@ def _pipeline_model(
             apply_jit=self._apply_jit,
             pipelined_forward=pipelined_forward,
             pipeline_postproc=self._pipeline_postproc,
+            apply_jit_context=self._apply_jit_context,
         )
         # initializes input dist, so we can override input dist forwards
         self.start_sparse_data_dist(batch, context)
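The body of `_rewrite_model` is not shown in this diff, but the plumbing above suggests the call-site pattern below: enter the supplied context (or a no-op default) around the `torch.jit.script` call. This is a sketch under that assumption; the helper name `_apply_jit` is hypothetical.

import contextlib
from typing import ContextManager, Optional

import torch

def _apply_jit(
    module: torch.nn.Module,
    apply_jit_context: Optional[ContextManager[None]] = None,
) -> torch.nn.Module:
    # Fall back to a no-op context when the caller did not supply one,
    # so the with-statement below is unconditional.
    ctx = apply_jit_context if apply_jit_context is not None else contextlib.nullcontext()
    with ctx:
        return torch.jit.script(module)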
@@ -904,6 +909,8 @@ class TrainPipelineFusedSparseDist(TrainPipelineSparseDist[In, Out]):
         TODO: pipeline_postproc, custom_model_fwd, strict
         use_emb_lookup_stream (bool): if true, invoke the compute_and_output_dist
             (for batch i+1) using a new stream; else reuse the data_dist stream
+        apply_jit_context (ContextManager): a context manager that will surround the
+            application of the JIT
     """
 
     # The PipelinedForward class that is used in _rewrite_model
@@ -922,6 +929,7 @@ def __init__(
         ] = None,
         strict: bool = False,
         emb_lookup_stream: str = "data_dist",  # new, current, data_dist (default)
+        apply_jit_context: Optional[ContextManager[None]] = None,
     ) -> None:
         super().__init__(
             model=model,
@@ -932,6 +940,7 @@ def __init__(
             context_type=EmbeddingTrainPipelineContext,
             pipeline_postproc=pipeline_postproc,
             custom_model_fwd=custom_model_fwd,
+            apply_jit_context=apply_jit_context,
         )
         if emb_lookup_stream == "new":
             self._emb_lookup_stream: Optional[torch.Stream] = (
@@ -1066,6 +1075,8 @@ class TrainPipelineSemiSync(TrainPipelineSparseDist[In, Out]):
             (applicable to 2D sharding only)
             if set and DMP collection is enabled for 2D sharding,
             sync DMPs every N batches (default to 1, i.e. every batch, None to disable)
+        apply_jit_context (ContextManager): a context manager that will surround the
+            application of the JIT
     """
 
     # The PipelinedForward class that is used in _rewrite_model
@@ -1086,6 +1097,7 @@ def __init__(
         ] = None,
         strict: bool = False,
         dmp_collection_sync_interval_batches: Optional[int] = 1,
+        apply_jit_context: Optional[ContextManager[None]] = None,
     ) -> None:
         super().__init__(
             model=model,
@@ -1097,6 +1109,7 @@ def __init__(
             pipeline_postproc=pipeline_postproc,
             custom_model_fwd=custom_model_fwd,
             dmp_collection_sync_interval_batches=dmp_collection_sync_interval_batches,
+            apply_jit_context=apply_jit_context,
         )
         self._start_batch = start_batch
         self._stash_gradients = stash_gradients
@@ -1378,6 +1391,8 @@ class PrefetchTrainPipelineSparseDist(TrainPipelineSparseDist[In, Out]):
         execute_all_batches (bool): executes remaining batches in pipeline after
             exhausting dataloader iterator.
         apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
+        apply_jit_context (ContextManager): a context manager that will surround the
+            application of the JIT
     """
 
     # The PipelinedForward class that is used in _rewrite_model
@@ -1394,6 +1409,7 @@ def __init__(
         custom_model_fwd: Optional[
             Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
         ] = None,
+        apply_jit_context: Optional[ContextManager[None]] = None,
     ) -> None:
         super().__init__(
             model=model,
@@ -1404,6 +1420,7 @@ def __init__(
             context_type=PrefetchTrainPipelineContext,
             pipeline_postproc=pipeline_postproc,
             custom_model_fwd=custom_model_fwd,
+            apply_jit_context=apply_jit_context,
         )
         self._context = PrefetchTrainPipelineContext(version=0)
         self._prefetch_stream: Optional[torch.Stream] = (
@@ -1535,6 +1552,8 @@ class EvalPipelineSparseDist(TrainPipelineSparseDist[In, Out]):
         device (torch.device): device where device transfer, sparse data dist, and
             forward/backward pass will happen.
         apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
+        apply_jit_context (Optional[ContextManager]): a context manager that
+            will surround the application of the JIT
     """
 
     # The PipelinedForward class that is used in _rewrite_model
@@ -1546,8 +1565,16 @@ def __init__(
         optimizer: torch.optim.Optimizer,
         device: torch.device,
         apply_jit: bool = False,
+        apply_jit_context: Optional[ContextManager[None]] = None,
     ) -> None:
-        super().__init__(model, optimizer, device, True, apply_jit)
+        super().__init__(
+            model,
+            optimizer,
+            device,
+            True,
+            apply_jit,
+            apply_jit_context=apply_jit_context,
+        )
         self._batch_loader: Optional[DataLoadingThread[In]] = None
 
     def __del__(self) -> None:
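A minimal usage sketch for the new keyword; `model`, `optimizer`, and `device` are assumed to exist, and `jit_timing` comes from the sketch near the top of this diff.

pipeline = EvalPipelineSparseDist(
    model,
    optimizer,
    device,
    apply_jit=True,
    apply_jit_context=jit_timing(),  # any ContextManager[None] works here
)

Passing `apply_jit_context` by keyword in the rewritten `super().__init__` call above appears deliberate: several positional parameters (e.g. `context_type`, `pipeline_postproc`) sit between `apply_jit` and the new parameter in `TrainPipelineSparseDist.__init__`, so a positional argument would bind to the wrong slot.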
@@ -1909,6 +1936,7 @@ def __init__(
         custom_model_fwd: Optional[
             Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
         ] = None,
+        apply_jit_context: Optional[ContextManager[None]] = None,
     ) -> None:
         super().__init__(
             model,
@@ -1919,6 +1947,7 @@ def __init__(
             context_type,
             pipeline_postproc,
             custom_model_fwd,
+            apply_jit_context=apply_jit_context,
         )
 
         torch._logging.set_logs(compiled_autograd_verbose=True)