@@ -937,6 +937,7 @@ def __init__(
937937 enable_jit_fused : bool = False ,
938938 enable_sequence_parallelism : bool = False ,
939939 enable_sequence_overlap : bool = False ,
940+ parallel_output : bool = True ,
940941 num_microbatches : Optional [int ] = None ,
941942 microbatch_size : Optional [int ] = None ,
942943 initial_scale : float = 2 ** 16 ,
@@ -961,6 +962,7 @@ def __init__(
961962 pp_style : str = "1f1b" ,
962963 num_model_chunks : int = 1 ,
963964 enable_metadata_cache : bool = True ,
965+ make_vocab_size_divisible_by : int = 128 ,
964966 ) -> None :
965967 super ().__init__ ()
966968 assert (
@@ -1033,6 +1035,8 @@ def __init__(
10331035 enable_jit_fused = self .enable_jit_fused ,
10341036 enable_sequence_parallelism = enable_sequence_parallelism ,
10351037 enable_sequence_overlap = enable_sequence_overlap ,
1038+ parallel_output = parallel_output ,
1039+ make_vocab_size_divisible_by = make_vocab_size_divisible_by ,
10361040 )
10371041 self .amp_config = dict (
10381042 initial_scale = initial_scale ,
0 commit comments