diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py
index 8ce7c02..41ea2d6 100644
--- a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py
+++ b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py
@@ -38,6 +38,7 @@
     register_tensors_as_parameters_patch_rule,
     requires_installation_on_all_linears,
 )
+from .fsdp_utils import put_selected_meta_tensors_on_cpu
 
 
 class AutoGPTQAccelerationPlugin(AccelerationPlugin):
@@ -219,6 +220,11 @@ def model_loader(self, model_name: str, **kwargs):
         # replace
         AutoModelForCausalLM.from_config = _old_from_config
 
+        # in low_cpu_mem_mode, if certain tensors like embeddings
+        # are on the meta device, then certain operations like
+        # embedding resizing will fail
+        put_selected_meta_tensors_on_cpu(model)
+
         # AutoGPTQ does not set the torch_dtype of the model carefully
         model.config.torch_dtype = torch_dtype
diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_bnb.py b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_bnb.py
index 3a4c12a..9b24d0a 100644
--- a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_bnb.py
+++ b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_bnb.py
@@ -28,6 +28,9 @@
 from transformers.utils.import_utils import _is_package_available
 import torch
 
+# Local
+from .fsdp_utils import put_selected_meta_tensors_on_cpu
+
 
 # this is a modified copy of the function from peft.utils.other, that we
 # will instead use
@@ -154,6 +157,27 @@ def model_loader(self, model_name: str, **kwargs):
             attn_implementation=attn_implementation,
         )
 
+        if (
+            world_size > 1
+            and os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"
+        ):
+            config_kwargs["bnb_4bit_quant_storage"] = torch_dtype
+
+        _, _transformers_version = _is_package_available(
+            "transformers", return_version=True
+        )
+        _trl_installed, _trl_version = _is_package_available(
+            "trl", return_version=True
+        )
+
+        if _transformers_version >= "4.45" and (
+            not _trl_installed or (_trl_installed and _trl_version >= "0.12")
+        ):
+            # in low_cpu_mem_mode, if certain tensors like embeddings
+            # are on the meta device, then certain operations like
+            # embedding resizing will fail
+            put_selected_meta_tensors_on_cpu(model)
+
         return model
 
     @property
diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/fsdp_utils.py b/plugins/accelerated-peft/src/fms_acceleration_peft/fsdp_utils.py
index 3086cf7..e747e05 100644
--- a/plugins/accelerated-peft/src/fms_acceleration_peft/fsdp_utils.py
+++ b/plugins/accelerated-peft/src/fms_acceleration_peft/fsdp_utils.py
@@ -2,6 +2,8 @@
 from collections import defaultdict
 
 # Third Party
+from accelerate.utils import set_module_tensor_to_device
+from transformers import PreTrainedModel
 import torch
 
 # Copyright The IBM Tuning Team
@@ -70,3 +72,27 @@ def param_init_fn_tied_param(module: torch.nn.Module):
             return module
 
     return param_init_fn_tied_param
+
+
+# utility to put selected meta tensors on the cpu
+def put_selected_meta_tensors_on_cpu(model: PreTrainedModel):
+
+    done = {}
+    # - for now we only put input and output embeddings
+    for module in [
+        model.get_input_embeddings(),
+        model.get_output_embeddings(),
+    ]:
+
+        for param_name, param in module.named_parameters(recurse=False):
+            param_id = id(param)
+
+            if param.device == torch.device("meta"):
+                if param_id not in done:
+                    value = torch.empty(*param.size(), dtype=param.dtype)
+                    done[param_id] = value  # memoize
+                else:
+                    # this is a tied weight, get back the previous value
+                    value = done[param_id]
+
+                set_module_tensor_to_device(module, param_name, "cpu", value)
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py
index 0558d6b..c825ebb 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py
@@ -48,10 +48,6 @@ def _all_reduce_hook(grad):
         A = mod.lora_A.default
         B = mod.lora_B.default
 
-        # install hooks on the adapters
-        A.weight.register_hook(_all_reduce_hook)
-        B.weight.register_hook(_all_reduce_hook)
-
         # because we will ignore these from FSDP, we need to manually
         # move them to gpu if they are already not on them
         # - if the adapters are on meta, we assume that this is for FSDP
@@ -80,6 +76,11 @@ def _all_reduce_hook(grad):
     if is_fsdp_enabled():
         dist.broadcast(B.weight, src=0)
 
+    # install hooks on the adapters
+    # - this has to be done after all weight replacement happens
+    A.weight.register_hook(_all_reduce_hook)
+    B.weight.register_hook(_all_reduce_hook)
+
 def register_foak_model_patch_rules(base_type):
     # Third Party
     from fms_acceleration.model_patcher import (  # pylint: disable=import-outside-toplevel
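Reviewer note (not part of the patch): a minimal sketch of what the new put_selected_meta_tensors_on_cpu utility does for a single embedding module, using accelerate's set_module_tensor_to_device. The toy nn.Embedding below is an illustrative assumption, not code from the repository.

    import torch
    from accelerate.utils import set_module_tensor_to_device

    # an embedding whose weight lives on the meta device, as can happen under
    # low_cpu_mem_mode / deferred FSDP initialization (toy example)
    emb = torch.nn.Embedding(8, 4, device="meta")
    assert emb.weight.device == torch.device("meta")

    # allocate an empty CPU tensor of the same shape/dtype and swap it in,
    # mirroring the torch.empty(...) + set_module_tensor_to_device call that
    # the patch adds to fsdp_utils.py
    value = torch.empty(*emb.weight.size(), dtype=emb.weight.dtype)
    set_module_tensor_to_device(emb, "weight", "cpu", value)
    assert emb.weight.device == torch.device("cpu")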