Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Low CPU Memory Mode Issues for Quantized Peft #90

Merged
merged 8 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions plugins/accelerated-peft/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ Currently only supports LoRA-related techniques, but more are in the pipeline to

Plugin | Description | Depends | Loading | Augmentation | Callbacks
--|--|--|--|--|--
[autogptq](./src/fms_acceleration_peft/framework_plugin_autogptq.py) | Loads 4bit GPTQ-LoRA with quantized GPTQ as base | AutoGPTQ | ✅ | ✅
[bnb](./src/fms_acceleration_peft/framework_plugin_bnb.py) | Loads 4bit QLoRA with quantized bitsandbytes Linear4 | Huggingface<br>bitsandbytes | ✅ | ✅
[autogptq](./src/fms_acceleration_peft/framework_plugin_autogptq.py) | Loads 4bit GPTQ-LoRA with quantized GPTQ as base | AutoGPTQ | ✅ | ✅ | ✅
[bnb](./src/fms_acceleration_peft/framework_plugin_bnb.py) | Loads 4bit QLoRA with quantized bitsandbytes Linear4 | Huggingface<br>bitsandbytes | ✅ | ✅ | ✅


### Key Points
Expand Down Expand Up @@ -43,6 +43,7 @@ GPTQ-LORA depends on an AutoGPTQ backend to run. There are 2 backend options

## Known Issues

<!--
- Models with sliding windows (e.g., Mistral, Mixtral) will have [memory and throughout issues](https://github.com/huggingface/transformers/issues/30461).
-->
- GPTQ-LORA sometimes observed to have `nan` grad norms in the begining of training, but training proceeds well otherwise.
- `low_cpu_mem_usage` temporarily disabled for AutoGPTQ until bug with `make_sure_no_tensor_in_meta_device` is resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from peft.tuners.lora.model import LoraModel
from transformers import AutoModelForCausalLM, TrainingArguments
from transformers.modeling_utils import is_fsdp_enabled
from transformers.utils.import_utils import _is_package_available
import torch
import torch.distributed

Expand Down Expand Up @@ -61,10 +62,6 @@ def __init__(self, configurations: Dict[str, Dict]):
)

if self.use_external_lib:
# Third Party
from transformers.utils.import_utils import ( # pylint: disable=import-outside-toplevel
_is_package_available,
)

assert _is_package_available("auto_gptq") is True, (
"Unable to use external library, auto_gptq module not found. "
Expand Down Expand Up @@ -351,6 +348,48 @@ def augmentation(

return model, modifiable_args

def get_callbacks_and_ready_for_train(
anhuong marked this conversation as resolved.
Show resolved Hide resolved
self, model: torch.nn.Module = None, accelerator=None
):
callbacks = []
if (
accelerator is not None
and getattr(accelerator.state, "fsdp_plugin", None) is not None
):
_, _transformers_version = _is_package_available(
"transformers", return_version=True
)
_trl_installed, _trl_version = _is_package_available(
"trl", return_version=True
)

# the meta device fix for quantized models is since this transformers version
# or if trl is installed then its only for this version
if _transformers_version >= "4.45" and (
not _trl_installed or (_trl_installed and _trl_version >= "0.12")
):
# guarded
# NOTE: replace this later with a more specific accelerate version check
try:
# Third Party
# pylint: disable=import-outside-toplevel
from torch.distributed.utils import ensure_weights_retied

# then its handled internally and there is nothing to do
except ImportError:
# need to use our internal version
# Local
from .fsdp_utils import ( # pylint: disable=import-outside-toplevel
ensure_weights_retied,
)

accelerator.state.fsdp_plugin.param_init_fn = ensure_weights_retied(
accelerator.state.fsdp_plugin.param_init_fn,
model.get_base_model(),
accelerator.device,
)
return callbacks


# register
AccelerationPlugin.register_plugin(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,64 +116,12 @@ def model_loader(self, model_name: str, **kwargs):
except ValueError:
world_size = 1 # pg not init

patched_is_local_dist_rank_0 = None
if (
world_size > 1
and os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"
):
config_kwargs["bnb_4bit_quant_storage"] = torch_dtype

# - of course assume that this package must exist, simply need the version
_, _transformers_version = _is_package_available(
"transformers", return_version=True
)

# this is a workaround that disables low_cpu_mem_mode for quant QLORA
# - this issue was introduced in https://github.com/huggingface/transformers/pull/33154
# whereby the low_cpu_mem_mode was actually fixed.
# - However fixing it causes some problems with the current impl.
# 1. For lora fused ops, the adapters cannot be managed by FSDP, as
# forwards are not called. This causes issue 2) in
# https://github.com/foundation-model-stack/fms-acceleration/issues/83
# where the adapters are still sharded when passed in the fused-ops.
# However, if low_cpu_mem_mode=True, then we NEED FSDP to intialize
# their state, which contradicts the above point.
#
# 2. We have observed,
# see https://github.com/foundation-model-stack/fms-acceleration/pull/86
# that low_cpu_mem_mode=True can cause torch distributed primitives
# to hang.

if _transformers_version >= "4.45":

# pylint: disable=import-outside-toplevel
# Third Party
from fms_acceleration.model_patcher import patch_target_module
import transformers.modeling_utils

def _truthy():
return (
True # use this to always return True to is_local_dist_rank_0
)

# - we cannot use the model patcher and this needs to be called immediately below
# at the model_loader
# - but we immediately revert the patch after loading
patched_is_local_dist_rank_0 = (
transformers.modeling_utils.is_local_dist_rank_0
)
patch_target_module(
"transformers.modeling_utils.is_local_dist_rank_0",
_truthy,
)

warnings.warn(
"Disabling low_cpu_mem_mode in the BNBAccelerationPlugin as this may "
"potentiall cause problems with: "
"1. the fused-ops-and-kernels package, and, "
"2. the syncing of FSDP modules across devices."
)

elif world_size > 1:
warnings.warn(
"Running in distributed mode but bnb_4bit_quant_storage is not set. "
Expand Down Expand Up @@ -206,13 +154,6 @@ def _truthy():
attn_implementation=attn_implementation,
)

if patched_is_local_dist_rank_0 is not None:
# replace it
patch_target_module(
"transformers.modeling_utils.is_local_dist_rank_0",
patched_is_local_dist_rank_0,
)

return model

@property
Expand Down Expand Up @@ -252,6 +193,49 @@ def augmentation(
modifiable_args = (None,) # return a None
return model, modifiable_args

def get_callbacks_and_ready_for_train(
self, model: torch.nn.Module = None, accelerator=None
):
callbacks = []
if (
accelerator is not None
and getattr(accelerator.state, "fsdp_plugin", None) is not None
):
_, _transformers_version = _is_package_available(
"transformers", return_version=True
)
_trl_installed, _trl_version = _is_package_available(
"trl", return_version=True
)

# the meta device fix for quantized models is since this transformers version
# or if trl is installed then its only for this version
if _transformers_version >= "4.45" and (
not _trl_installed or (_trl_installed and _trl_version >= "0.12")
):
# guarded
# NOTE: replace this later with a more specific accelerate version check
try:
# Third Party
# pylint: disable=import-outside-toplevel
from torch.distributed.utils import ensure_weights_retied

# then its handled internally and there is nothing to do
except ImportError:
# need to use our internal version
# Local
from .fsdp_utils import ( # pylint: disable=import-outside-toplevel
ensure_weights_retied,
)

accelerator.state.fsdp_plugin.param_init_fn = ensure_weights_retied(
accelerator.state.fsdp_plugin.param_init_fn,
model if self._no_peft_model else model.get_base_model(),
accelerator.device,
)

return callbacks


# register
AccelerationPlugin.register_plugin(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Standard
from collections import defaultdict

# Third Party
import torch

# Copyright The IBM Tuning Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# SPDX-License-Identifier: Apache-2.0
# https://spdx.dev/learn/handling-license-info/


def ensure_weights_retied(
param_init_fn, model: torch.nn.Module, device: torch.cuda.device
):

_tied_names = model._tied_weights_keys
if not _tied_names:
# if no tied names just passthrough
return param_init_fn

# get map of parameter instances to params.
# - needed for replacement later
_tied_params = {}
for name in _tied_names:
name = name.split(".")
name, param_name = ".".join(name[:-1]), name[-1]
mod = model.get_submodule(name)
param = getattr(mod, param_name)

_tied_params[id(param)] = None # placeholder for the param first

# build param_init_fn for the case with tied params
def param_init_fn_tied_param(module: torch.nn.Module):

# track which params to tie
# - usually only 1, but for completeness consider > 1
params_to_tie = defaultdict(list)
for n, param in module.named_parameters(recurse=False):
if id(param) in _tied_params:
params_to_tie[id(param)].append(n)

# call the param init fn, which potentially re-allocates the
# parameters
module = param_init_fn(module)

# search the parameters again and tie them up again
for id_key, _param_names in params_to_tie.items():
for param_name in _param_names:
param = _tied_params[id_key]
if param is None:
# everything will be tied to the first time the
# param is observed
_tied_params[id_key] = getattr(module, param_name)
else:
setattr(module, param_name, param) # tie

return module

return param_init_fn_tied_param
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@
PretrainedConfig,
PreTrainedModel,
)
from transformers.modeling_utils import no_init_weights, shard_checkpoint
from transformers.modeling_utils import (
is_local_dist_rank_0,
no_init_weights,
shard_checkpoint,
)
from transformers.utils.generic import ContextManagers
import accelerate
import torch
Expand Down Expand Up @@ -1105,45 +1109,50 @@ def skip(*args, **kwargs):
# prepares the model on gpu in `trainer.train` to avoid unnecessary gpu usage
device_map = {"": "cpu"}

load_checkpoint_in_model = False
# compat: runtime convert checkpoint gptq(v1) to gptq_v2 format
if quantize_config.format == FORMAT.GPTQ:
accelerate.load_checkpoint_in_model(
model,
dtype=torch_dtype,
# This is very hacky but works due to https://github.com/huggingface/accelerate/blob/bd72a5f1a80d5146554458823f8aeda0a9db5297/src/accelerate/utils/modeling.py#L292
checkpoint=model_save_name,
device_map=device_map,
)
# validate sym=False v1 loading needs to be protected for models produced with new v2 format codebase
if (
not quantize_config.sym
and not quantize_config.is_quantized_or_packed_by_v2()
):
raise ValueError(
f"Loading of a sym=False model with format={FORMAT.GPTQ} is only supported if produced by gptqmodel version >= {MIN_VERSION_WITH_V2}"
# low_cpu_mem_usage fix by flim@sg.ibm.com
# - load the checkpoint only if not low_cpu_mem_usage
# - or if low_cpu_mem_usage then only in the rank_0
if not low_cpu_mem_usage or is_local_dist_rank_0():
fabianlim marked this conversation as resolved.
Show resolved Hide resolved
load_checkpoint_in_model = False
# compat: runtime convert checkpoint gptq(v1) to gptq_v2 format
if quantize_config.format == FORMAT.GPTQ:
accelerate.load_checkpoint_in_model(
model,
dtype=torch_dtype,
# This is very hacky but works due to https://github.com/huggingface/accelerate/blob/bd72a5f1a80d5146554458823f8aeda0a9db5297/src/accelerate/utils/modeling.py#L292
checkpoint=model_save_name,
device_map=device_map,
)
# validate sym=False v1 loading needs to be protected for models produced with new v2 format codebase
if (
not quantize_config.sym
and not quantize_config.is_quantized_or_packed_by_v2()
):
raise ValueError(
f"Loading of a sym=False model with format={FORMAT.GPTQ} is only supported if produced by gptqmodel version >= {MIN_VERSION_WITH_V2}"
)

logger.info(
f"Compatibility: converting `{FORMAT_FIELD_JSON}` from `{FORMAT.GPTQ}` to `{FORMAT.GPTQ_V2}`."
)
model = convert_gptq_v1_to_v2_format(
model,
quantize_config=quantize_config,
qlinear_kernel=preload_qlinear_kernel,
)
load_checkpoint_in_model = True
quantize_config.format = FORMAT.GPTQ_V2
logger.info(
f"Compatibility: converting `{FORMAT_FIELD_JSON}` from `{FORMAT.GPTQ}` to `{FORMAT.GPTQ_V2}`."
)
model = convert_gptq_v1_to_v2_format(
model,
quantize_config=quantize_config,
qlinear_kernel=preload_qlinear_kernel,
)
load_checkpoint_in_model = True
quantize_config.format = FORMAT.GPTQ_V2

if not load_checkpoint_in_model and backend == Backend.TRITON:
accelerate.load_checkpoint_in_model(
model,
dtype=torch_dtype, # This is very hacky but works due to https://github.com/huggingface/accelerate/blob/bd72a5f1a80d5146554458823f8aeda0a9db5297/src/accelerate/utils/modeling.py#L292
checkpoint=model_save_name,
device_map=device_map,
)

if not load_checkpoint_in_model and backend == Backend.TRITON:
accelerate.load_checkpoint_in_model(
model,
dtype=torch_dtype, # This is very hacky but works due to https://github.com/huggingface/accelerate/blob/bd72a5f1a80d5146554458823f8aeda0a9db5297/src/accelerate/utils/modeling.py#L292
checkpoint=model_save_name,
device_map=device_map,
)
# TODO: Why are we using this custom function and not dispatch_model?
model = simple_dispatch_model(model, device_map)
# TODO: Why are we using this custom function and not dispatch_model?
model = simple_dispatch_model(model, device_map)

qlinear_kernel = select_quant_linear(
bits=quantize_config.bits,
Expand Down
Loading
Loading