Fix Issue with Resizing Parameters on the Meta Device in Low CPU Mem Mode (#96)

* fix: quant models on meta device cannot have embedding resized

Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>

* fix: grad reduce hook

Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>

---------

Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
fabianlim authored Oct 28, 2024
1 parent 98fcd2e commit 28eb168
Showing 4 changed files with 61 additions and 4 deletions.
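For context, a minimal sketch (not part of the commit) of the failure mode the title refers to, assuming embedding weights that low-CPU-memory loading left on the meta device: resizing preserves the old rows by copying them into a freshly allocated table, and a meta tensor has no data to copy from. The new fsdp_utils helper introduced below sidesteps this by materializing the selected tensors on CPU first.

# Hypothetical illustration only; not code from this repository.
import torch

old_weight = torch.empty(10, 4, device="meta")  # embedding weight left on the meta device
new_weight = torch.empty(12, 4)                 # resized embedding table on CPU

try:
    # embedding resizing preserves existing rows by copying them over,
    # but there is nothing to read from a meta tensor
    new_weight[:10].copy_(old_weight)
except NotImplementedError as err:
    print(err)  # e.g. "Cannot copy out of meta tensor; no data!"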
@@ -38,6 +38,7 @@
register_tensors_as_parameters_patch_rule,
requires_installation_on_all_linears,
)
from .fsdp_utils import put_selected_meta_tensors_on_cpu


class AutoGPTQAccelerationPlugin(AccelerationPlugin):
@@ -219,6 +220,11 @@ def model_loader(self, model_name: str, **kwargs):
# replace
AutoModelForCausalLM.from_config = _old_from_config

# in low_cpu_mem_mode, if certain tensors like embeddings
# are on the meta device, then certain operations like
# embedding resizing will fail
put_selected_meta_tensors_on_cpu(model)

# AutoGPTQ does not set the torch_dtype of the model carefully
model.config.torch_dtype = torch_dtype

@@ -28,6 +28,9 @@
from transformers.utils.import_utils import _is_package_available
import torch

# Local
from .fsdp_utils import put_selected_meta_tensors_on_cpu


# this is a modified copy of the function from peft.utils.other, that we
# will instead use
@@ -154,6 +157,27 @@ def model_loader(self, model_name: str, **kwargs):
attn_implementation=attn_implementation,
)

if (
world_size > 1
and os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"
):
config_kwargs["bnb_4bit_quant_storage"] = torch_dtype

_, _transformers_version = _is_package_available(
"transformers", return_version=True
)
_trl_installed, _trl_version = _is_package_available(
"trl", return_version=True
)

if _transformers_version >= "4.45" and (
not _trl_installed or (_trl_installed and _trl_version >= "0.12")
):
# in low_cpu_mem_mode, if certain tensors like embeddings
# are on the meta device, then certain operations like
# embedding resizing will fail
put_selected_meta_tensors_on_cpu(model)

return model

@property
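As an aside, the gate above compares version strings directly; below is a purely illustrative restatement with packaging's Version objects (packaging and the materialize_embeddings name are assumptions, not part of the plugin):

# Illustrative restatement of the version gate above; not code from this repository.
from packaging import version
from transformers.utils.import_utils import _is_package_available

_, _transformers_version = _is_package_available("transformers", return_version=True)
_trl_installed, _trl_version = _is_package_available("trl", return_version=True)

# same intent as the hunk above: transformers >= 4.45, with trl either absent
# or itself at >= 0.12, before meta embeddings are materialized on CPU
materialize_embeddings = version.parse(_transformers_version) >= version.parse("4.45") and (
    not _trl_installed or version.parse(_trl_version) >= version.parse("0.12")
)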
26 changes: 26 additions & 0 deletions plugins/accelerated-peft/src/fms_acceleration_peft/fsdp_utils.py
@@ -2,6 +2,8 @@
from collections import defaultdict

# Third Party
from accelerate.utils import set_module_tensor_to_device
from transformers import PreTrainedModel
import torch

# Copyright The IBM Tuning Team
@@ -70,3 +72,27 @@ def param_init_fn_tied_param(module: torch.nn.Module):
return module

return param_init_fn_tied_param


# utility to put tensors on the cpu
def put_selected_meta_tensors_on_cpu(model: PreTrainedModel):

done = {}
# - for now we only put input and output embeddings
for module in [
model.get_input_embeddings(),
model.get_output_embeddings(),
]:

for param_name, param in module.named_parameters(recurse=False):
param_id = id(param)

if param.device == torch.device("meta"):
if param_id not in done:
value = torch.empty(*param.size(), dtype=param.dtype)
done[param_id] = value # memoize
else:
# this is a tied weight, get back the previous value
value = done[param_id]

set_module_tensor_to_device(module, param_name, "cpu", value)
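A usage sketch for the new helper (the tiny GPT-2 configuration, the meta-device construction, and the sizes are all placeholders standing in for the FSDP low-CPU-memory path): only the input/output embeddings are materialized on CPU, tied parameters are memoized by id so both modules receive the same replacement data, and resizing then has real data to copy.

# Usage sketch; the model configuration and sizes are placeholders.
import torch
from transformers import GPT2Config, GPT2LMHeadModel

from fms_acceleration_peft.fsdp_utils import put_selected_meta_tensors_on_cpu

# tiny skeleton built entirely on the meta device, standing in for the
# low-CPU-memory / FSDP path that leaves weights unmaterialized
config = GPT2Config(vocab_size=100, n_positions=32, n_embd=8, n_layer=1, n_head=2)
with torch.device("meta"):
    model = GPT2LMHeadModel(config)
print(model.get_input_embeddings().weight.device)  # meta

# materialize only the input/output embeddings on CPU; GPT-2 ties them, so the
# memoization by parameter id hands both modules the same replacement data
put_selected_meta_tensors_on_cpu(model)
print(model.get_input_embeddings().weight.device)  # cpu

# embedding resizing now copies from real CPU data instead of meta tensors
model.resize_token_embeddings(128)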
@@ -48,10 +48,6 @@ def _all_reduce_hook(grad):
A = mod.lora_A.default
B = mod.lora_B.default

# install hooks on the adapters
A.weight.register_hook(_all_reduce_hook)
B.weight.register_hook(_all_reduce_hook)

# because we will ignore these from FSDP, we need to manually
# move them to gpu if they are already not on them
# - if the adapters are on meta, we assume that this is for FSDP
@@ -80,6 +76,11 @@ def _all_reduce_hook(grad):
if is_fsdp_enabled():
dist.broadcast(B.weight, src=0)

# install hooks on the adapters
# - this has to be done after all weight replacement happens
A.weight.register_hook(_all_reduce_hook)
B.weight.register_hook(_all_reduce_hook)

def register_foak_model_patch_rules(base_type):
# Third Party
from fms_acceleration.model_patcher import ( # pylint: disable=import-outside-toplevel
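The last hunk moves the register_hook calls below the weight replacement and broadcast logic. A standalone sketch (not from the repository) of why the order matters: a gradient hook is attached to a particular tensor object, so replacing the parameter afterwards silently discards it.

# Illustration only: hooks registered on a tensor do not survive replacing it.
import torch

lin = torch.nn.Linear(4, 4, bias=False)
calls = []
lin.weight.register_hook(lambda grad: calls.append("fired"))

# replace the parameter, as the adapter re-materialization above does;
# the new tensor carries no hooks
lin.weight = torch.nn.Parameter(lin.weight.detach().clone())

lin(torch.randn(2, 4)).sum().backward()
print(calls)  # [] -> the all-reduce hook would never run if registered too early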
