
Commit 24c8157

Set parallelism_config in constructor due to Trainer reset of State (#3713)
1 parent 6891c57 · commit 24c8157
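
Context for the change: `transformers.Trainer` tears down and re-creates the `AcceleratorState`, so a `parallelism_config` attached to the state *after* construction was lost on that reset. Accepting it as a constructor argument means the value is re-applied every time the state is rebuilt. Below is a minimal sketch of the failure mode and the fix; `_SharedState` and its `reset` method are simplified, hypothetical stand-ins for accelerate's shared-state pattern, not its actual API.

```python
# Illustrative stand-in for AcceleratorState's shared-dict pattern;
# _SharedState and reset() are hypothetical, not accelerate's API.
class _SharedState:
    _shared_state = {}

    def __init__(self, parallelism_config=None):
        # All instances share one attribute dict.
        self.__dict__ = self._shared_state
        # Set in the constructor: re-applied whenever the state is rebuilt.
        self.parallelism_config = parallelism_config

    @classmethod
    def reset(cls):
        # Roughly what a Trainer-driven state reset amounts to.
        cls._shared_state.clear()


# An attribute attached *after* construction is lost across a reset:
state = _SharedState()
state.parallelism_config = "cp=2"
_SharedState.reset()
assert _SharedState().parallelism_config is None

# Passed *through the constructor*, it survives the rebuild:
_SharedState.reset()
assert _SharedState(parallelism_config="cp=2").parallelism_config == "cp=2"
```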

File tree

1 file changed: +9 −7 lines


src/accelerate/state.py

Lines changed: 9 additions & 7 deletions
```diff
@@ -903,6 +903,7 @@ def __init__(
         fsdp_plugin=None,
         torch_tp_plugin=None,
         megatron_lm_plugin=None,
+        parallelism_config=None,
         _from_accelerator: bool = False,
         **kwargs,
     ):
@@ -917,6 +918,7 @@ def __init__(
         self.deepspeed_plugins = None
         self.use_ipex = None
         self.torch_tp_plugin = torch_tp_plugin
+        self.parallelism_config = parallelism_config
         mixed_precision = (
             parse_choice_from_env("ACCELERATE_MIXED_PRECISION", "no")
             if mixed_precision is None
@@ -995,13 +997,13 @@ def __init__(
                     raise ValueError(
                         "Using `cp_size>1` requires FSDP2, but the provided `fsdp_plugin` is using FSDP1. "
                     )
-            if (
-                os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true" or fsdp_plugin is not None
-            ) or (self.parallelism_config is not None and self.parallelism_config.cp_enabled):
-                self.distributed_type = DistributedType.FSDP
-                if self._mixed_precision != "no":
-                    fsdp_plugin.set_mixed_precision(self._mixed_precision)
-                self.fsdp_plugin = fsdp_plugin
+            if (os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true" or fsdp_plugin is not None) or (
+                self.parallelism_config is not None and self.parallelism_config.cp_enabled
+            ):
+                self.distributed_type = DistributedType.FSDP
+                if self._mixed_precision != "no" and fsdp_plugin is not None:
+                    fsdp_plugin.set_mixed_precision(self._mixed_precision)
+                self.fsdp_plugin = fsdp_plugin
             if os.environ.get(
                 "ACCELERATE_USE_MEGATRON_LM", "false"
             ).lower() == "true" and self.distributed_type not in [
```
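One detail worth noting in the last hunk: the FSDP branch can now be entered purely because context parallelism is enabled (`parallelism_config.cp_enabled`), even when no `fsdp_plugin` was supplied, so `fsdp_plugin` may legitimately be `None` inside the branch. The added `and fsdp_plugin is not None` guard keeps the mixed-precision setup from dereferencing `None`. A hedged sketch of the guarded logic follows; `configure_fsdp` is a hypothetical free-standing helper for illustration, not part of accelerate.

```python
# Simplified sketch of the guarded branch; configure_fsdp is hypothetical.
def configure_fsdp(mixed_precision: str, fsdp_plugin, cp_enabled: bool):
    if fsdp_plugin is not None or cp_enabled:
        # Entering via cp_enabled alone leaves fsdp_plugin as None; the
        # extra None check avoids:
        #   AttributeError: 'NoneType' object has no attribute 'set_mixed_precision'
        if mixed_precision != "no" and fsdp_plugin is not None:
            fsdp_plugin.set_mixed_precision(mixed_precision)
    return fsdp_plugin


# CP-only path: no plugin supplied, no crash.
assert configure_fsdp("bf16", None, cp_enabled=True) is None
```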
