@@ -903,6 +903,7 @@ def __init__(
         fsdp_plugin=None,
         torch_tp_plugin=None,
         megatron_lm_plugin=None,
+        parallelism_config=None,
         _from_accelerator: bool = False,
         **kwargs,
     ):
@@ -917,6 +918,7 @@ def __init__(
             self.deepspeed_plugins = None
             self.use_ipex = None
             self.torch_tp_plugin = torch_tp_plugin
+            self.parallelism_config = parallelism_config
             mixed_precision = (
                 parse_choice_from_env("ACCELERATE_MIXED_PRECISION", "no")
                 if mixed_precision is None
@@ -995,13 +997,13 @@ def __init__(
                 raise ValueError(
                     "Using `cp_size>1` requires FSDP2, but the provided `fsdp_plugin` is using FSDP1. "
                 )
-            if (
-                os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true" or fsdp_plugin is not None
-            ) or (self.parallelism_config is not None and self.parallelism_config.cp_enabled):
-                self.distributed_type = DistributedType.FSDP
-                if self._mixed_precision != "no":
-                    fsdp_plugin.set_mixed_precision(self._mixed_precision)
-                self.fsdp_plugin = fsdp_plugin
+            if (os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true" or fsdp_plugin is not None) or (
+                self.parallelism_config is not None and self.parallelism_config.cp_enabled
+            ):
+                self.distributed_type = DistributedType.FSDP
+                if self._mixed_precision != "no" and fsdp_plugin is not None:
+                    fsdp_plugin.set_mixed_precision(self._mixed_precision)
+                self.fsdp_plugin = fsdp_plugin
             if os.environ.get(
                 "ACCELERATE_USE_MEGATRON_LM", "false"
             ).lower() == "true" and self.distributed_type not in [