diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index c79a7fb061e802..43f9a434fa8f92 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -73,6 +73,7 @@
     is_torch_tpu_available,
     logging,
     replace_return_docstrings,
+    strtobool,
 )
 from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
 from .utils.import_utils import ENV_VARS_TRUE_VALUES, is_sagemaker_mp_enabled, is_torch_fx_proxy
@@ -106,6 +107,14 @@
 _init_weights = True
 
 
+def is_fsdp_enabled():
+    return strtobool(os.environ.get("ACCELERATE_USE_FSDP", "False")) == 1
+
+
+def is_fsdp_enabled_and_dist_rank_0():
+    return is_fsdp_enabled() and torch.distributed.is_initialized() and torch.distributed.get_rank() == 0
+
+
 if is_sagemaker_mp_enabled():
     import smdistributed.modelparallel.torch as smp
     from smdistributed.modelparallel import __version__ as SMP_VERSION
@@ -458,7 +467,11 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
             )
         return safe_load_file(checkpoint_file)
     try:
-        if is_deepspeed_zero3_enabled() and torch.distributed.is_initialized() and torch.distributed.get_rank() > 0:
+        if (
+            (is_deepspeed_zero3_enabled() or is_fsdp_enabled())
+            and torch.distributed.is_initialized()
+            and torch.distributed.get_rank() > 0
+        ):
             map_location = "meta"
         else:
             map_location = "cpu"
@@ -2283,6 +2296,9 @@ def from_pretrained(
         commit_hash = kwargs.pop("_commit_hash", None)
         variant = kwargs.pop("variant", None)
 
+        if is_fsdp_enabled():
+            low_cpu_mem_usage = True
+
         if use_auth_token is not None:
             warnings.warn(
                 "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
@@ -3238,7 +3254,8 @@ def _fix_key(key):
             model_buffers = {".".join([prefix, key]) for key in model_buffers}
         unexpected_keys = list(unexpected_keys - model_buffers)
 
-        if device_map is None:
+        model.tie_weights()
+        if device_map is None and not is_fsdp_enabled():
             ptrs = collections.defaultdict(list)
             for name, tensor in model.state_dict().items():
                 id_tensor = id_tensor_storage(tensor)
@@ -3443,23 +3460,35 @@ def _find_mismatched_keys(
                 )
 
                 if low_cpu_mem_usage:
-                    new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
-                        model_to_load,
-                        state_dict,
-                        loaded_keys,
-                        start_prefix,
-                        expected_keys,
-                        device_map=device_map,
-                        offload_folder=offload_folder,
-                        offload_index=offload_index,
-                        state_dict_folder=state_dict_folder,
-                        state_dict_index=state_dict_index,
-                        dtype=dtype,
-                        is_quantized=is_quantized,
-                        is_safetensors=is_safetensors,
-                        keep_in_fp32_modules=keep_in_fp32_modules,
-                    )
-                    error_msgs += new_error_msgs
+                    if not is_fsdp_enabled() or is_fsdp_enabled_and_dist_rank_0():
+                        new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
+                            model_to_load,
+                            state_dict,
+                            loaded_keys,
+                            start_prefix,
+                            expected_keys,
+                            device_map=device_map,
+                            offload_folder=offload_folder,
+                            offload_index=offload_index,
+                            state_dict_folder=state_dict_folder,
+                            state_dict_index=state_dict_index,
+                            dtype=dtype,
+                            is_quantized=is_quantized,
+                            is_safetensors=is_safetensors,
+                            keep_in_fp32_modules=keep_in_fp32_modules,
+                        )
+                        error_msgs += new_error_msgs
+                    else:
+                        for key, param in model_to_load.state_dict().items():
+                            if param.device == torch.device("meta"):
+                                if not (is_quantized):
+                                    set_module_tensor_to_device(
+                                        model, key, "cpu", torch.empty(*param.size(), dtype=dtype)
+                                    )
+                                else:
+                                    set_module_quantized_tensor_to_device(
+                                        model, key, "cpu", torch.empty(*param.size(), dtype=dtype)
+                                    )
                 else:
                     error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
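Side note on the helpers introduced above: a minimal standalone sketch (not the library code itself; names and the truthy-value set are illustrative) of how the `ACCELERATE_USE_FSDP` flag gates where each rank materializes checkpoint weights.

```python
import os


def is_fsdp_enabled() -> bool:
    # Mirrors the helper in the diff: Accelerate exports ACCELERATE_USE_FSDP=true
    # when the launcher is configured for FSDP.
    return os.environ.get("ACCELERATE_USE_FSDP", "false").lower() in ("1", "true", "yes")


def pick_map_location(rank: int) -> str:
    # Rank 0 materializes real weights on CPU; other ranks load onto the "meta"
    # device and receive the actual tensors later via sync_module_states.
    return "meta" if is_fsdp_enabled() and rank > 0 else "cpu"


if __name__ == "__main__":
    os.environ["ACCELERATE_USE_FSDP"] = "true"
    print(pick_map_location(rank=0))  # cpu
    print(pick_map_location(rank=3))  # meta
```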
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index f870a50b3fe143..c43bf19311885b 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -465,10 +465,6 @@ def __init__(
             ):
                 self.backward_prefetch = BackwardPrefetch.BACKWARD_POST
 
-            self.forward_prefetch = False
-            if self.args.fsdp_config.get("forward_prefetch", False):
-                self.forward_prefetch = True
-
             self.limit_all_gathers = False
             if self.args.fsdp_config.get("limit_all_gathers", False):
                 self.limit_all_gathers = True
@@ -1379,12 +1375,12 @@ def _wrap_model(self, model, training=True, dataloader=None):
             auto_wrapper_callable = None
             default_transformer_cls_names_to_wrap = getattr(model, "_no_split_modules", None)
             fsdp_transformer_layer_cls_to_wrap = self.args.fsdp_config.get(
-                "fsdp_transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap
+                "transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap
             )
 
-            if self.args.fsdp_config["fsdp_min_num_params"] > 0:
+            if self.args.fsdp_config["min_num_params"] > 0:
                 auto_wrap_policy = functools.partial(
-                    size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
+                    size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["min_num_params"]
                 )
             elif fsdp_transformer_layer_cls_to_wrap is not None:
                 transformer_cls_to_wrap = set()
@@ -1517,7 +1513,12 @@ def train(
             if resume_from_checkpoint is None:
                 raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")
 
-        if resume_from_checkpoint is not None and not is_sagemaker_mp_enabled() and not self.is_deepspeed_enabled:
+        if (
+            resume_from_checkpoint is not None
+            and not is_sagemaker_mp_enabled()
+            and not self.is_deepspeed_enabled
+            and not self.is_fsdp_enabled
+        ):
             self._load_from_checkpoint(resume_from_checkpoint)
 
         # If model was re-initialized, put it on the right device and update self.model_wrapped
@@ -1651,7 +1652,7 @@ def _inner_training_loop(
 
         model = self._wrap_model(self.model_wrapped)
 
-        if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:
+        if (is_sagemaker_mp_enabled() or self.is_fsdp_enabled) and resume_from_checkpoint is not None:
             self._load_from_checkpoint(resume_from_checkpoint, model)
 
         # as the model is wrapped, don't use `accelerator.prepare`
@@ -3886,7 +3887,6 @@ def create_accelerator_and_postprocess(self):
             fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get(
                 "limit_all_gathers", fsdp_plugin.limit_all_gathers
            )
-            fsdp_plugin.use_orig_params = self.args.fsdp_config.get("use_orig_params", fsdp_plugin.use_orig_params)
 
         if self.is_deepspeed_enabled:
             if getattr(self.args, "hf_deepspeed_config", None) is None:
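For readers following the `_wrap_model` change: a rough sketch, under assumed inputs (an `fsdp_config` dict and already-resolved layer classes), of how the two renamed keys select an FSDP auto-wrap policy. This is not the Trainer code itself.

```python
import functools

from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy


def build_auto_wrap_policy(fsdp_config, transformer_layer_classes):
    # Size-based wrapping: wrap any submodule holding at least `min_num_params` parameters.
    if fsdp_config.get("min_num_params", 0) > 0:
        return functools.partial(size_based_auto_wrap_policy, min_num_params=fsdp_config["min_num_params"])
    # Transformer-based wrapping: wrap whole blocks of the given classes (resolved from
    # the class names listed under `transformer_layer_cls_to_wrap`).
    if transformer_layer_classes:
        return functools.partial(transformer_auto_wrap_policy, transformer_layer_cls=set(transformer_layer_classes))
    return None
```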
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index f27c7cd0ce3473..b1013965358aad 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -436,13 +436,13 @@ class TrainingArguments:
             deepspeed json config file (e.g., `ds_config.json`) or an already loaded json file as `dict`.
 
             A List of config and its options:
-                - fsdp_min_num_params (`int`, *optional*, defaults to `0`):
+                - min_num_params (`int`, *optional*, defaults to `0`):
                     FSDP's minimum number of parameters for Default Auto Wrapping. (useful only when `fsdp` field is
                     passed).
-                - fsdp_transformer_layer_cls_to_wrap (`List[str]`, *optional*):
+                - transformer_layer_cls_to_wrap (`List[str]`, *optional*):
                     List of transformer layer class names (case-sensitive) to wrap, e.g, `BertLayer`, `GPTJBlock`,
                     `T5Block` .... (useful only when `fsdp` flag is passed).
-                - fsdp_backward_prefetch (`str`, *optional*)
+                - backward_prefetch (`str`, *optional*)
                     FSDP's backward prefetch mode. Controls when to prefetch next set of parameters (useful only when
                     `fsdp` field is passed).
 
@@ -454,7 +454,7 @@ class TrainingArguments:
                    - `"backward_post"` : This prefetches the next set of parameters after the current set of
                      parameter’s gradient computation.
-                - fsdp_forward_prefetch (`bool`, *optional*, defaults to `False`)
+                - forward_prefetch (`bool`, *optional*, defaults to `False`)
                     FSDP's forward prefetch mode (useful only when `fsdp` field is passed).
                     If `"True"`, then FSDP explicitly prefetches the next upcoming all-gather while executing in the
                     forward pass.
@@ -462,6 +462,14 @@ class TrainingArguments:
                     FSDP's limit_all_gathers (useful only when `fsdp` field is passed).
                     If `"True"`, FSDP explicitly synchronizes the CPU thread to prevent too many in-flight
                     all-gathers.
+                - use_orig_params (`bool`, *optional*, defaults to `False`)
+                    If `"True"`, allows non-uniform `requires_grad` during init, which means support for interspersed
+                    frozen and trainable parameters. Useful in cases such as parameter-efficient fine-tuning. Please
+                    refer to this
+                    [blog](https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019).
+                - sync_module_states (`bool`, *optional*, defaults to `True`)
+                    If `"True"`, each individually wrapped FSDP unit will broadcast module parameters from rank 0 to
+                    ensure they are the same across all ranks after initialization.
                 - xla (`bool`, *optional*, defaults to `False`):
                     Whether to use PyTorch/XLA Fully Sharded Data Parallel Training. This is an experimental feature
                     and its API may evolve in the future.
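To make the renamed options concrete, here is a hypothetical configuration using the un-prefixed keys documented above; the output directory, layer class, and values are placeholders, and the same dict could live in a JSON file passed as `fsdp_config`.

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="outputs",
    fsdp="full_shard auto_wrap",
    fsdp_config={
        # exactly one of the next two wrapping knobs should be set (they are mutually exclusive)
        "transformer_layer_cls_to_wrap": ["T5Block"],
        # "min_num_params": 100_000_000,
        "backward_prefetch": "backward_pre",
        "forward_prefetch": "false",
        "sync_module_states": "true",
        "use_orig_params": "false",
    },
)
```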
@@ -1520,44 +1528,44 @@ def __post_init__(self):
             self.fsdp_config = {}
 
         if isinstance(self.fsdp_config, str):
+            if len(self.fsdp) == 0:
+                warnings.warn("`--fsdp_config` is useful only when `--fsdp` is specified.")
             with io.open(self.fsdp_config, "r", encoding="utf-8") as f:
                 self.fsdp_config = json.load(f)
+                for k, v in list(self.fsdp_config.items()):
+                    if k.startswith("fsdp_"):
+                        self.fsdp_config[k.replace("fsdp_", "")] = v
+                        del self.fsdp_config[k]
 
         if self.fsdp_min_num_params > 0:
             warnings.warn("using `--fsdp_min_num_params` is deprecated. Use fsdp_config instead ", FutureWarning)
-            self.fsdp_config["fsdp_min_num_params"] = max(
-                self.fsdp_config.get("fsdp_min_num_params", 0), self.fsdp_min_num_params
-            )
+            self.fsdp_config["min_num_params"] = max(self.fsdp_config.get("min_num_params", 0), self.fsdp_min_num_params)
 
-        # if fsdp_config["fsdp_transformer_layer_cls_to_wrap"] is specified as a string, convert it to a list with a single object
-        if isinstance(self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None), str):
-            self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"] = [
-                self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]
-            ]
+        # if fsdp_config["transformer_layer_cls_to_wrap"] is specified as a string, convert it to a list with a single object
+        if isinstance(self.fsdp_config.get("transformer_layer_cls_to_wrap", None), str):
+            self.fsdp_config["transformer_layer_cls_to_wrap"] = [self.fsdp_config["transformer_layer_cls_to_wrap"]]
 
         if self.fsdp_transformer_layer_cls_to_wrap is not None:
             warnings.warn(
                 "using `--fsdp_transformer_layer_cls_to_wrap` is deprecated. Use fsdp_config instead ", FutureWarning
             )
-            self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"] = self.fsdp_config.get(
-                "fsdp_transformer_layer_cls_to_wrap", []
+            self.fsdp_config["transformer_layer_cls_to_wrap"] = self.fsdp_config.get(
+                "transformer_layer_cls_to_wrap", []
             ) + [self.fsdp_transformer_layer_cls_to_wrap]
 
-        if len(self.fsdp) == 0 and self.fsdp_config["fsdp_min_num_params"] > 0:
-            warnings.warn("`--fsdp_min_num_params` is useful only when `--fsdp` is specified.")
+        if len(self.fsdp) == 0 and self.fsdp_config["min_num_params"] > 0:
+            warnings.warn("`min_num_params` is useful only when `--fsdp` is specified.")
 
-        if len(self.fsdp) == 0 and self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
-            warnings.warn("`--fsdp_transformer_layer_cls_to_wrap` is useful only when `--fsdp` is specified.")
+        if len(self.fsdp) == 0 and self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None:
+            warnings.warn("`transformer_layer_cls_to_wrap` is useful only when `--fsdp` is specified.")
 
         if (
             len(self.fsdp) > 0
-            and self.fsdp_config["fsdp_min_num_params"] > 0
-            and self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None
+            and self.fsdp_config["min_num_params"] > 0
+            and self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None
         ):
-            raise ValueError(
-                "`--fsdp_min_num_params` and `--fsdp_transformer_layer_cls_to_wrap` are mutually exclusive."
-            )
+            raise ValueError("`min_num_params` and `transformer_layer_cls_to_wrap` are mutually exclusive.")
 
         self.fsdp_config["xla"] = self.fsdp_config.get("xla", False)
         self.fsdp_config["xla_fsdp_grad_ckpt"] = self.fsdp_config.get("xla_fsdp_grad_ckpt", False)
         if self.fsdp_config["xla"]:
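A standalone sketch of the key normalization in the hunk above: legacy `fsdp_`-prefixed keys from an existing JSON config keep working because the prefix is stripped on load (the sample dict is illustrative).

```python
# legacy config as it might appear in an old fsdp_config.json
legacy = {"fsdp_min_num_params": 2000, "fsdp_backward_prefetch": "backward_pre", "xla": False}

normalized = {}
for k, v in legacy.items():
    # drop the "fsdp_" prefix so old and new key spellings resolve to the same entry
    normalized[k[len("fsdp_"):] if k.startswith("fsdp_") else k] = v

print(normalized)  # {'min_num_params': 2000, 'backward_prefetch': 'backward_pre', 'xla': False}
```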
@@ -1583,23 +1591,29 @@ def __post_init__(self):
                 FSDP_SHARDING_STRATEGY,
             )
 
+            prefix = "FSDP_"
             for fsdp_option in self.fsdp:
                 if fsdp_option.upper() in FSDP_SHARDING_STRATEGY:
                     # set environment variable for FSDP sharding strategy
-                    os.environ["FSDP_SHARDING_STRATEGY"] = str(FSDP_SHARDING_STRATEGY.index(fsdp_option.upper()) + 1)
+                    os.environ[f"{prefix}SHARDING_STRATEGY"] = str(
+                        FSDP_SHARDING_STRATEGY.index(fsdp_option.upper()) + 1
+                    )
                 elif fsdp_option == FSDPOption.OFFLOAD:
-                    os.environ["FSDP_OFFLOAD_PARAMS"] = "true"
+                    os.environ[f"{prefix}OFFLOAD_PARAMS"] = "true"
                 elif fsdp_option == FSDPOption.AUTO_WRAP:
-                    os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0]
-                    if self.fsdp_config["fsdp_min_num_params"] > 0:
-                        os.environ["FSDP_MIN_NUM_PARAMS"] = str(self.fsdp_config["fsdp_min_num_params"])
-                        os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[1]
-                    elif self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
-                        os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = ",".join(
-                            self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]
+                    os.environ[f"{prefix}AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0]
+                    if self.fsdp_config["min_num_params"] > 0:
+                        os.environ[f"{prefix}MIN_NUM_PARAMS"] = str(self.fsdp_config["min_num_params"])
+                        os.environ[f"{prefix}AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[1]
+                    elif self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None:
+                        os.environ[f"{prefix}TRANSFORMER_CLS_TO_WRAP"] = ",".join(
+                            self.fsdp_config["transformer_layer_cls_to_wrap"]
                         )
             prefetch_policy = self.fsdp_config.get("fsdp_backward_prefetch", "NO_PREFETCH")
-            os.environ["FSDP_BACKWARD_PREFETCH"] = prefetch_policy.upper()
+            os.environ[f"{prefix}BACKWARD_PREFETCH"] = prefetch_policy.upper()
+            os.environ[f"{prefix}FORWARD_PREFETCH"] = str(self.fsdp_config.get("forward_prefetch", "false"))
+            os.environ[f"{prefix}SYNC_MODULE_STATES"] = str(self.fsdp_config.get("sync_module_states", "true"))
+            os.environ[f"{prefix}USE_ORIG_PARAMS"] = str(self.fsdp_config.get("use_orig_params", "false"))
 
         if self.tpu_metrics_debug:
             warnings.warn(
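Finally, a standalone sketch (keys and values illustrative, not Trainer code) of the environment handoff in the last hunk: the un-prefixed config entries are exported as `FSDP_*` variables, which Accelerate's FSDP plugin is expected to pick up when it is constructed.

```python
import os

# assumed, illustrative config; in the Trainer this comes from `TrainingArguments.fsdp_config`
fsdp_config = {"min_num_params": 2000, "backward_prefetch": "backward_pre", "sync_module_states": "true"}

prefix = "FSDP_"
os.environ[f"{prefix}MIN_NUM_PARAMS"] = str(fsdp_config["min_num_params"])
os.environ[f"{prefix}BACKWARD_PREFETCH"] = fsdp_config.get("backward_prefetch", "NO_PREFETCH").upper()
os.environ[f"{prefix}FORWARD_PREFETCH"] = str(fsdp_config.get("forward_prefetch", "false"))
os.environ[f"{prefix}SYNC_MODULE_STATES"] = str(fsdp_config.get("sync_module_states", "true"))
os.environ[f"{prefix}USE_ORIG_PARAMS"] = str(fsdp_config.get("use_orig_params", "false"))

print({k: v for k, v in os.environ.items() if k.startswith(prefix)})
```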