@@ -312,7 +312,7 @@ def __init__(
312312 context_type : Type [TrainPipelineContext ] = TrainPipelineContext ,
313313 pipeline_preproc : bool = False ,
314314 custom_model_fwd : Optional [
315- Callable [[In ], Tuple [torch .Tensor , List [ torch . Tensor ] ]]
315+ Callable [[Optional [ In ]] , Tuple [torch .Tensor , Out ]]
316316 ] = None ,
317317 ) -> None :
318318 self ._model = model
@@ -366,6 +366,10 @@ def __init__(
366366 self ._dataloader_exhausted : bool = False
367367 self ._context_type : Type [TrainPipelineContext ] = context_type
368368
369+ self ._model_fwd : Callable [[Optional [In ]], Tuple [torch .Tensor , Out ]] = (
370+ custom_model_fwd if custom_model_fwd else model
371+ )
372+
369373 # DEPRECATED FIELDS
370374 self ._batch_i : Optional [In ] = None
371375 self ._batch_ip1 : Optional [In ] = None
@@ -483,9 +487,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
483487
484488 # forward
485489 with record_function ("## forward ##" ):
486- losses , output = cast (
487- Tuple [torch .Tensor , Out ], self ._model (self .batches [0 ])
488- )
490+ losses , output = self ._model_fwd (self .batches [0 ])
489491
490492 if len (self .batches ) >= 2 :
491493 self .wait_sparse_data_dist (self .contexts [1 ])
@@ -718,7 +720,7 @@ def __init__(
718720 stash_gradients : bool = False ,
719721 pipeline_preproc : bool = False ,
720722 custom_model_fwd : Optional [
721- Callable [[In ], Tuple [torch .Tensor , List [ torch . Tensor ] ]]
723+ Callable [[Optional [ In ]] , Tuple [torch .Tensor , Out ]]
722724 ] = None ,
723725 ) -> None :
724726 super ().__init__ (
@@ -729,6 +731,7 @@ def __init__(
729731 apply_jit = apply_jit ,
730732 context_type = EmbeddingTrainPipelineContext ,
731733 pipeline_preproc = pipeline_preproc ,
734+ custom_model_fwd = custom_model_fwd ,
732735 )
733736 self ._start_batch = start_batch
734737 self ._stash_gradients = stash_gradients
@@ -893,7 +896,7 @@ def _mlp_forward(
893896 _wait_for_events (
894897 batch , context , torch .get_device_module (self ._device ).current_stream ()
895898 )
896- return cast ( Tuple [ torch . Tensor , Out ], self ._model_fwd (batch ) )
899+ return self ._model_fwd (batch )
897900
898901 def embedding_backward (self , context : EmbeddingTrainPipelineContext ) -> None :
899902 default_stream = torch .get_device_module (self ._device ).current_stream ()
@@ -1020,6 +1023,10 @@ def __init__(
10201023 device : torch .device ,
10211024 execute_all_batches : bool = True ,
10221025 apply_jit : bool = False ,
1026+ pipeline_preproc : bool = False ,
1027+ custom_model_fwd : Optional [
1028+ Callable [[Optional [In ]], Tuple [torch .Tensor , Out ]]
1029+ ] = None ,
10231030 ) -> None :
10241031 super ().__init__ (
10251032 model = model ,
@@ -1028,6 +1035,8 @@ def __init__(
10281035 execute_all_batches = execute_all_batches ,
10291036 apply_jit = apply_jit ,
10301037 context_type = PrefetchTrainPipelineContext ,
1038+ pipeline_preproc = pipeline_preproc ,
1039+ custom_model_fwd = custom_model_fwd ,
10311040 )
10321041 self ._context = PrefetchTrainPipelineContext (version = 0 )
10331042 self ._prefetch_stream : Optional [torch .Stream ] = (
@@ -1084,7 +1093,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
10841093 self ._wait_sparse_data_dist ()
10851094 # forward
10861095 with record_function ("## forward ##" ):
1087- losses , output = cast ( Tuple [ torch . Tensor , Out ], self ._model (self ._batch_i ) )
1096+ losses , output = self ._model_fwd (self ._batch_i )
10881097
10891098 self ._prefetch (self ._batch_ip1 )
10901099
0 commit comments