Skip to content

Commit

Permalink
add util for ram efficient loading of model when using fsdp (#25107)
Browse files Browse the repository at this point in the history
* add util for ram efficient loading of model when using fsdp

* make fix-copies

* fixes 😅

* docs

* making it further easier to use

* rename the function

* refactor to handle fsdp ram efficiency in `from_pretrained`

* fixes

* fixes

* fixes

* update

* fixes

* revert `load_pretrained_model_only_on_rank0`

* resolve `load_from_checkpoint`
  • Loading branch information
pacman100 authored Aug 17, 2023
1 parent 4e1dee0 commit c4c0cef
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 62 deletions.
67 changes: 48 additions & 19 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
is_torch_tpu_available,
logging,
replace_return_docstrings,
strtobool,
)
from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
from .utils.import_utils import ENV_VARS_TRUE_VALUES, is_sagemaker_mp_enabled, is_torch_fx_proxy
Expand Down Expand Up @@ -106,6 +107,14 @@
_init_weights = True


def is_fsdp_enabled():
return strtobool(os.environ.get("ACCELERATE_USE_FSDP", "False")) == 1


def is_fsdp_enabled_and_dist_rank_0():
return is_fsdp_enabled() and torch.distributed.is_initialized() and torch.distributed.get_rank() == 0


if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
from smdistributed.modelparallel import __version__ as SMP_VERSION
Expand Down Expand Up @@ -458,7 +467,11 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
)
return safe_load_file(checkpoint_file)
try:
if is_deepspeed_zero3_enabled() and torch.distributed.is_initialized() and torch.distributed.get_rank() > 0:
if (
(is_deepspeed_zero3_enabled() or is_fsdp_enabled)
and torch.distributed.is_initialized()
and torch.distributed.get_rank() > 0
):
map_location = "meta"
else:
map_location = "cpu"
Expand Down Expand Up @@ -2283,6 +2296,9 @@ def from_pretrained(
commit_hash = kwargs.pop("_commit_hash", None)
variant = kwargs.pop("variant", None)

if is_fsdp_enabled():
low_cpu_mem_usage = True

if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
Expand Down Expand Up @@ -3238,7 +3254,8 @@ def _fix_key(key):
model_buffers = {".".join([prefix, key]) for key in model_buffers}
unexpected_keys = list(unexpected_keys - model_buffers)

if device_map is None:
model.tie_weights()
if device_map is None and not is_fsdp_enabled():
ptrs = collections.defaultdict(list)
for name, tensor in model.state_dict().items():
id_tensor = id_tensor_storage(tensor)
Expand Down Expand Up @@ -3443,23 +3460,35 @@ def _find_mismatched_keys(
)

if low_cpu_mem_usage:
new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
model_to_load,
state_dict,
loaded_keys,
start_prefix,
expected_keys,
device_map=device_map,
offload_folder=offload_folder,
offload_index=offload_index,
state_dict_folder=state_dict_folder,
state_dict_index=state_dict_index,
dtype=dtype,
is_quantized=is_quantized,
is_safetensors=is_safetensors,
keep_in_fp32_modules=keep_in_fp32_modules,
)
error_msgs += new_error_msgs
if not is_fsdp_enabled() or is_fsdp_enabled_and_dist_rank_0():
new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
model_to_load,
state_dict,
loaded_keys,
start_prefix,
expected_keys,
device_map=device_map,
offload_folder=offload_folder,
offload_index=offload_index,
state_dict_folder=state_dict_folder,
state_dict_index=state_dict_index,
dtype=dtype,
is_quantized=is_quantized,
is_safetensors=is_safetensors,
keep_in_fp32_modules=keep_in_fp32_modules,
)
error_msgs += new_error_msgs
else:
for key, param in model_to_load.state_dict().items():
if param.device == torch.device("meta"):
if not (is_quantized):
set_module_tensor_to_device(
model, key, "cpu", torch.empty(*param.size(), dtype=dtype)
)
else:
set_module_quantized_tensor_to_device(
model, key, "cpu", torch.empty(*param.size(), dtype=dtype)
)
else:
error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix)

Expand Down
20 changes: 10 additions & 10 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,10 +465,6 @@ def __init__(
):
self.backward_prefetch = BackwardPrefetch.BACKWARD_POST

self.forward_prefetch = False
if self.args.fsdp_config.get("forward_prefetch", False):
self.forward_prefetch = True

self.limit_all_gathers = False
if self.args.fsdp_config.get("limit_all_gathers", False):
self.limit_all_gathers = True
Expand Down Expand Up @@ -1379,12 +1375,12 @@ def _wrap_model(self, model, training=True, dataloader=None):
auto_wrapper_callable = None
default_transformer_cls_names_to_wrap = getattr(model, "_no_split_modules", None)
fsdp_transformer_layer_cls_to_wrap = self.args.fsdp_config.get(
"fsdp_transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap
"transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap
)

if self.args.fsdp_config["fsdp_min_num_params"] > 0:
if self.args.fsdp_config["min_num_params"] > 0:
auto_wrap_policy = functools.partial(
size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["min_num_params"]
)
elif fsdp_transformer_layer_cls_to_wrap is not None:
transformer_cls_to_wrap = set()
Expand Down Expand Up @@ -1517,7 +1513,12 @@ def train(
if resume_from_checkpoint is None:
raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")

if resume_from_checkpoint is not None and not is_sagemaker_mp_enabled() and not self.is_deepspeed_enabled:
if (
resume_from_checkpoint is not None
and not is_sagemaker_mp_enabled()
and not self.is_deepspeed_enabled
and not self.is_fsdp_enabled
):
self._load_from_checkpoint(resume_from_checkpoint)

# If model was re-initialized, put it on the right device and update self.model_wrapped
Expand Down Expand Up @@ -1651,7 +1652,7 @@ def _inner_training_loop(

model = self._wrap_model(self.model_wrapped)

if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:
if (is_sagemaker_mp_enabled() or self.is_fsdp_enabled) and resume_from_checkpoint is not None:
self._load_from_checkpoint(resume_from_checkpoint, model)

# as the model is wrapped, don't use `accelerator.prepare`
Expand Down Expand Up @@ -3886,7 +3887,6 @@ def create_accelerator_and_postprocess(self):
fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get(
"limit_all_gathers", fsdp_plugin.limit_all_gathers
)
fsdp_plugin.use_orig_params = self.args.fsdp_config.get("use_orig_params", fsdp_plugin.use_orig_params)

if self.is_deepspeed_enabled:
if getattr(self.args, "hf_deepspeed_config", None) is None:
Expand Down
80 changes: 47 additions & 33 deletions src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,13 +436,13 @@ class TrainingArguments:
deepspeed json config file (e.g., `ds_config.json`) or an already loaded json file as `dict`.
A List of config and its options:
- fsdp_min_num_params (`int`, *optional*, defaults to `0`):
- min_num_params (`int`, *optional*, defaults to `0`):
FSDP's minimum number of parameters for Default Auto Wrapping. (useful only when `fsdp` field is
passed).
- fsdp_transformer_layer_cls_to_wrap (`List[str]`, *optional*):
- transformer_layer_cls_to_wrap (`List[str]`, *optional*):
List of transformer layer class names (case-sensitive) to wrap, e.g, `BertLayer`, `GPTJBlock`,
`T5Block` .... (useful only when `fsdp` flag is passed).
- fsdp_backward_prefetch (`str`, *optional*)
- backward_prefetch (`str`, *optional*)
FSDP's backward prefetch mode. Controls when to prefetch next set of parameters (useful only when
`fsdp` field is passed).
Expand All @@ -454,14 +454,22 @@ class TrainingArguments:
- `"backward_post"` : This prefetches the next set of parameters after the current set of
parameter’s
gradient computation.
- fsdp_forward_prefetch (`bool`, *optional*, defaults to `False`)
- forward_prefetch (`bool`, *optional*, defaults to `False`)
FSDP's forward prefetch mode (useful only when `fsdp` field is passed).
If `"True"`, then FSDP explicitly prefetches the next upcoming all-gather while executing in the
forward pass.
- limit_all_gathers (`bool`, *optional*, defaults to `False`)
FSDP's limit_all_gathers (useful only when `fsdp` field is passed).
If `"True"`, FSDP explicitly synchronizes the CPU thread to prevent too many in-flight
all-gathers.
- use_orig_params (`bool`, *optional*, defaults to `False`)
If `"True"`, allows non-uniform `requires_grad` during init, which means support for interspersed
frozen and trainable paramteres. Useful in cases such as parameter-efficient fine-tuning. Please
refer this
[blog](https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019
- sync_module_states (`bool`, *optional*, defaults to `True`)
If `"True"`, each individually wrapped FSDP unit will broadcast module parameters from rank 0 to
ensure they are the same across all ranks after initialization
- xla (`bool`, *optional*, defaults to `False`):
Whether to use PyTorch/XLA Fully Sharded Data Parallel Training. This is an experimental feature
and its API may evolve in the future.
Expand Down Expand Up @@ -1520,44 +1528,44 @@ def __post_init__(self):
self.fsdp_config = {}

if isinstance(self.fsdp_config, str):
if len(self.fsdp) == 0:
warnings.warn("`--fsdp_config` is useful only when `--fsdp` is specified.")
with io.open(self.fsdp_config, "r", encoding="utf-8") as f:
self.fsdp_config = json.load(f)
for k, v in self.fsdp_config.items():
if k.startswith("fsdp_"):
self.fsdp_config[k.replace("fsdp_", "")] = v
del self.fsdp_config[k]

if self.fsdp_min_num_params > 0:
warnings.warn("using `--fsdp_min_num_params` is deprecated. Use fsdp_config instead ", FutureWarning)

self.fsdp_config["fsdp_min_num_params"] = max(
self.fsdp_config.get("fsdp_min_num_params", 0), self.fsdp_min_num_params
)
self.fsdp_config["min_num_params"] = max(self.fsdp_config.get("min_num_params", 0), self.fsdp_min_num_params)

# if fsdp_config["fsdp_transformer_layer_cls_to_wrap"] is specified as a string, convert it to a list with a single object
if isinstance(self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None), str):
self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"] = [
self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]
]
# if fsdp_config["transformer_layer_cls_to_wrap"] is specified as a string, convert it to a list with a single object
if isinstance(self.fsdp_config.get("transformer_layer_cls_to_wrap", None), str):
self.fsdp_config["transformer_layer_cls_to_wrap"] = [self.fsdp_config["transformer_layer_cls_to_wrap"]]

if self.fsdp_transformer_layer_cls_to_wrap is not None:
warnings.warn(
"using `--fsdp_transformer_layer_cls_to_wrap` is deprecated. Use fsdp_config instead ", FutureWarning
)
self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"] = self.fsdp_config.get(
"fsdp_transformer_layer_cls_to_wrap", []
self.fsdp_config["transformer_layer_cls_to_wrap"] = self.fsdp_config.get(
"transformer_layer_cls_to_wrap", []
) + [self.fsdp_transformer_layer_cls_to_wrap]

if len(self.fsdp) == 0 and self.fsdp_config["fsdp_min_num_params"] > 0:
warnings.warn("`--fsdp_min_num_params` is useful only when `--fsdp` is specified.")
if len(self.fsdp) == 0 and self.fsdp_config["min_num_params"] > 0:
warnings.warn("`min_num_params` is useful only when `--fsdp` is specified.")

if len(self.fsdp) == 0 and self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
warnings.warn("`--fsdp_transformer_layer_cls_to_wrap` is useful only when `--fsdp` is specified.")
if len(self.fsdp) == 0 and self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None:
warnings.warn("`transformer_layer_cls_to_wrap` is useful only when `--fsdp` is specified.")

if (
len(self.fsdp) > 0
and self.fsdp_config["fsdp_min_num_params"] > 0
and self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None
and self.fsdp_config["min_num_params"] > 0
and self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None
):
raise ValueError(
"`--fsdp_min_num_params` and `--fsdp_transformer_layer_cls_to_wrap` are mutually exclusive."
)
raise ValueError("`min_num_params` and `transformer_layer_cls_to_wrap` are mutually exclusive.")
self.fsdp_config["xla"] = self.fsdp_config.get("xla", False)
self.fsdp_config["xla_fsdp_grad_ckpt"] = self.fsdp_config.get("xla_fsdp_grad_ckpt", False)
if self.fsdp_config["xla"]:
Expand All @@ -1583,23 +1591,29 @@ def __post_init__(self):
FSDP_SHARDING_STRATEGY,
)

prefix = "FSDP_"
for fsdp_option in self.fsdp:
if fsdp_option.upper() in FSDP_SHARDING_STRATEGY:
# set environment variable for FSDP sharding strategy
os.environ["FSDP_SHARDING_STRATEGY"] = str(FSDP_SHARDING_STRATEGY.index(fsdp_option.upper()) + 1)
os.environ[f"{prefix}SHARDING_STRATEGY"] = str(
FSDP_SHARDING_STRATEGY.index(fsdp_option.upper()) + 1
)
elif fsdp_option == FSDPOption.OFFLOAD:
os.environ["FSDP_OFFLOAD_PARAMS"] = "true"
os.environ[f"{prefix}OFFLOAD_PARAMS"] = "true"
elif fsdp_option == FSDPOption.AUTO_WRAP:
os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0]
if self.fsdp_config["fsdp_min_num_params"] > 0:
os.environ["FSDP_MIN_NUM_PARAMS"] = str(self.fsdp_config["fsdp_min_num_params"])
os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[1]
elif self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = ",".join(
self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]
os.environ[f"{prefix}AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0]
if self.fsdp_config["min_num_params"] > 0:
os.environ[f"{prefix}MIN_NUM_PARAMS"] = str(self.fsdp_config["min_num_params"])
os.environ[f"{prefix}AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[1]
elif self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None:
os.environ[f"{prefix}TRANSFORMER_CLS_TO_WRAP"] = ",".join(
self.fsdp_config["transformer_layer_cls_to_wrap"]
)
prefetch_policy = self.fsdp_config.get("fsdp_backward_prefetch", "NO_PREFETCH")
os.environ["FSDP_BACKWARD_PREFETCH"] = prefetch_policy.upper()
os.environ[f"{prefix}BACKWARD_PREFETCH"] = prefetch_policy.upper()
os.environ[f"{prefix}FORWARD_PREFETCH"] = self.fsdp_config.get("forward_prefect", "false")
os.environ[f"{prefix}SYNC_MODULE_STATES"] = self.fsdp_config.get("sync_module_states", "true")
os.environ[f"{prefix}USE_ORIG_PARAMS"] = self.fsdp_config.get("use_orig_params", "false")

if self.tpu_metrics_debug:
warnings.warn(
Expand Down

0 comments on commit c4c0cef

Please sign in to comment.