src/accelerate/accelerator.py (1 change: 0 additions & 1 deletion)

@@ -1408,7 +1408,6 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
                 "use_orig_params": fsdp_plugin.use_orig_params,
                 "param_init_fn": fsdp_plugin.param_init_fn,
                 "ignored_modules": fsdp_plugin.ignored_modules,
-                "ignored_parameters": fsdp_plugin.ignored_parameters,
                 "limit_all_gathers": fsdp_plugin.limit_all_gathers,
                 "device_id": self.device,
             }
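For context, a minimal sketch (not part of this diff; MyModel and its frozen_head submodule are hypothetical) of how kwargs like these reach the FSDP constructor, where ignored_modules remains one supported way to exclude parts of a model from sharding:

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = MyModel()  # hypothetical module with a submodule to leave unwrapped
fsdp_model = FSDP(
    model,
    ignored_modules=[model.frozen_head],  # excluded from FSDP flattening/sharding
    use_orig_params=True,
    limit_all_gathers=True,
    device_id=torch.cuda.current_device(),
)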
src/accelerate/utils/dataclasses.py (4 changes: 0 additions & 4 deletions)

@@ -790,10 +790,6 @@ class FullyShardedDataParallelPlugin:
         default=None,
         metadata={"help": "A list of modules to ignore for FSDP."},
     )
-    ignored_parameters: Optional[Iterable[torch.nn.Parameter]] = field(
-        default=None,
-        metadata={"help": "A list of parameters to ignore for FSDP."},
-    )
     state_dict_type: "typing.Any" = field(
         default=None,
         metadata={
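Downstream, a minimal usage sketch (assuming accelerate's public API as of this change) of configuring the plugin without the removed field:

from accelerate import Accelerator
from accelerate.utils import FullyShardedDataParallelPlugin

# `ignored_parameters` is gone; modules to skip are passed via `ignored_modules`.
fsdp_plugin = FullyShardedDataParallelPlugin(ignored_modules=None)
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)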
src/accelerate/utils/fsdp_utils.py (17 changes: 16 additions & 1 deletion)

@@ -33,6 +33,14 @@

 def save_fsdp_model(fsdp_plugin, accelerator, model, output_dir, model_index=0):
     os.makedirs(output_dir, exist_ok=True)
+
+    if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:
+        # FSDP raises an error when a single GPU is used with `offload_to_cpu=True`
+        # for FULL_STATE_DICT, so only enable it when num_processes > 1
+        is_multi_process = accelerator.num_processes > 1
+        fsdp_plugin.state_dict_config.offload_to_cpu = is_multi_process
+        fsdp_plugin.state_dict_config.rank0_only = is_multi_process
+
     with FSDP.state_dict_type(
         model, fsdp_plugin.state_dict_type, fsdp_plugin.state_dict_config, fsdp_plugin.optim_state_dict_config
     ):
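In isolation, a minimal sketch of the guard above (assuming torch.distributed.fsdp; `model` and `accelerator` stand in for the function arguments):

from torch.distributed.fsdp import FullStateDictConfig, StateDictType
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# With a single process, FULL_STATE_DICT + offload_to_cpu=True raises inside FSDP,
# so both flags follow whether more than one process is running.
is_multi_process = accelerator.num_processes > 1
config = FullStateDictConfig(offload_to_cpu=is_multi_process, rank0_only=is_multi_process)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, config):
    state_dict = model.state_dict()  # gathered only on rank 0 when rank0_only=True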
@@ -70,6 +78,12 @@ def save_fsdp_model(fsdp_plugin, accelerator, model, output_dir, model_index=0):

 def load_fsdp_model(fsdp_plugin, accelerator, model, input_dir, model_index=0):
     accelerator.wait_for_everyone()
+    if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:
+        # FSDP raises an error when a single GPU is used with `offload_to_cpu=True`
+        # for FULL_STATE_DICT, so only enable it when num_processes > 1
+        is_multi_process = accelerator.num_processes > 1
+        fsdp_plugin.state_dict_config.offload_to_cpu = is_multi_process
+        fsdp_plugin.state_dict_config.rank0_only = is_multi_process
     with FSDP.state_dict_type(
         model, fsdp_plugin.state_dict_type, fsdp_plugin.state_dict_config, fsdp_plugin.optim_state_dict_config
     ):
@@ -111,7 +125,8 @@ def load_fsdp_model(fsdp_plugin, accelerator, model, input_dir, model_index=0):
             )
             state_dict = state_dict["model"]
             logger.info(f"Model loaded from {ckpt_dir}")
-        model.load_state_dict(state_dict)
+        load_result = model.load_state_dict(state_dict)
+    return load_result


 def save_fsdp_optimizer(fsdp_plugin, accelerator, optimizer, model, output_dir, optimizer_index=0):
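Returning the result of load_state_dict lets callers inspect mismatches; a short usage sketch (the `model` and `input_dir` values are hypothetical):

# torch's Module.load_state_dict returns a NamedTuple with
# `missing_keys` and `unexpected_keys` fields.
load_result = load_fsdp_model(fsdp_plugin, accelerator, model, input_dir)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"Incompatible keys when loading: {load_result}")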