@@ -309,7 +309,7 @@ def __init__(
309309 context_type : Type [TrainPipelineContext ] = TrainPipelineContext ,
310310 pipeline_preproc : bool = False ,
311311 custom_model_fwd : Optional [
312- Callable [[In ], Tuple [torch .Tensor , List [ torch . Tensor ] ]]
312+ Callable [[Optional [ In ]] , Tuple [torch .Tensor , Out ]]
313313 ] = None ,
314314 ) -> None :
315315 self ._model = model
@@ -363,6 +363,10 @@ def __init__(
363363 self ._dataloader_exhausted : bool = False
364364 self ._context_type : Type [TrainPipelineContext ] = context_type
365365
366+ self ._model_fwd : Callable [[Optional [In ]], Tuple [torch .Tensor , Out ]] = (
367+ custom_model_fwd if custom_model_fwd else model
368+ )
369+
366370 # DEPRECATED FIELDS
367371 self ._batch_i : Optional [In ] = None
368372 self ._batch_ip1 : Optional [In ] = None
@@ -480,9 +484,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
480484
481485 # forward
482486 with record_function ("## forward ##" ):
483- losses , output = cast (
484- Tuple [torch .Tensor , Out ], self ._model (self .batches [0 ])
485- )
487+ losses , output = self ._model_fwd (self .batches [0 ])
486488
487489 if len (self .batches ) >= 2 :
488490 self .wait_sparse_data_dist (self .contexts [1 ])
@@ -715,7 +717,7 @@ def __init__(
715717 stash_gradients : bool = False ,
716718 pipeline_preproc : bool = False ,
717719 custom_model_fwd : Optional [
718- Callable [[In ], Tuple [torch .Tensor , List [ torch . Tensor ] ]]
720+ Callable [[Optional [ In ]] , Tuple [torch .Tensor , Out ]]
719721 ] = None ,
720722 ) -> None :
721723 super ().__init__ (
@@ -726,6 +728,7 @@ def __init__(
726728 apply_jit = apply_jit ,
727729 context_type = EmbeddingTrainPipelineContext ,
728730 pipeline_preproc = pipeline_preproc ,
731+ custom_model_fwd = custom_model_fwd ,
729732 )
730733 self ._start_batch = start_batch
731734 self ._stash_gradients = stash_gradients
@@ -749,9 +752,6 @@ def __init__(
749752 self ._embedding_odd_streams : List [Optional [torch .Stream ]] = []
750753 self ._embedding_even_streams : List [Optional [torch .Stream ]] = []
751754 self ._gradients : Dict [str , torch .Tensor ] = {}
752- self ._model_fwd : Union [
753- torch .nn .Module , Callable [[In ], Tuple [torch .Tensor , List [torch .Tensor ]]]
754- ] = (custom_model_fwd if custom_model_fwd is not None else model )
755755
756756 def _grad_swap (self ) -> None :
757757 for name , param in self ._model .named_parameters ():
@@ -890,7 +890,7 @@ def _mlp_forward(
890890 _wait_for_events (
891891 batch , context , torch .get_device_module (self ._device ).current_stream ()
892892 )
893- return cast ( Tuple [ torch . Tensor , Out ], self ._model_fwd (batch ) )
893+ return self ._model_fwd (batch )
894894
895895 def embedding_backward (self , context : EmbeddingTrainPipelineContext ) -> None :
896896 default_stream = torch .get_device_module (self ._device ).current_stream ()
@@ -1017,6 +1017,10 @@ def __init__(
10171017 device : torch .device ,
10181018 execute_all_batches : bool = True ,
10191019 apply_jit : bool = False ,
1020+ pipeline_preproc : bool = False ,
1021+ custom_model_fwd : Optional [
1022+ Callable [[Optional [In ]], Tuple [torch .Tensor , Out ]]
1023+ ] = None ,
10201024 ) -> None :
10211025 super ().__init__ (
10221026 model = model ,
@@ -1025,6 +1029,8 @@ def __init__(
10251029 execute_all_batches = execute_all_batches ,
10261030 apply_jit = apply_jit ,
10271031 context_type = PrefetchTrainPipelineContext ,
1032+ pipeline_preproc = pipeline_preproc ,
1033+ custom_model_fwd = custom_model_fwd ,
10281034 )
10291035 self ._context = PrefetchTrainPipelineContext (version = 0 )
10301036 self ._prefetch_stream : Optional [torch .Stream ] = (
@@ -1081,7 +1087,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
10811087 self ._wait_sparse_data_dist ()
10821088 # forward
10831089 with record_function ("## forward ##" ):
1084- losses , output = cast ( Tuple [ torch . Tensor , Out ], self ._model (self ._batch_i ) )
1090+ losses , output = self ._model_fwd (self ._batch_i )
10851091
10861092 self ._prefetch (self ._batch_ip1 )
10871093
0 commit comments