Fix FSDP when performing GPTQ-LoRA with Triton V2 (#15)
* wrap in parameters and torch view to correct dtype

Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>

* refactor to apply patch only on FSDP and simplify

Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>

---------

Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
fabianlim authored May 21, 2024
1 parent d510ceb commit 2003a3e
Showing 2 changed files with 71 additions and 1 deletion.
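For context before the diff: the Triton V2 QuantLinear stores its quantized weights (qweight, qzeros) as plain int32 tensors, so the fix wraps them as non-trainable Parameters viewed as the float torch_dtype (so FSDP can shard them) and views them back to int32 inside a patched forward. A minimal sketch of the bit-preserving view round trip, illustrative only and not part of this commit:

# illustrative sketch (not from this commit): Tensor.view(dtype) reinterprets the
# underlying bytes, so int32 -> float16 -> int32 is lossless
import torch

qweight = torch.randint(0, 1000, (4, 8), dtype=torch.int32)

# reinterpret as float16 and wrap as a non-trainable Parameter so FSDP can shard it
as_float = torch.nn.Parameter(qweight.view(torch.float16), requires_grad=False)

# the patched forward views it back to int32 before the Triton kernel runs
restored = as_float.data.view(torch.int32)
assert torch.equal(restored, qweight)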
First of two changed files: the helper module imported below as .autogptq_utils.
@@ -18,7 +18,7 @@
# Third Party
from peft import LoraConfig
from peft.tuners.lora.gptq import QuantLinear as LoraLinearGPTQ
from transformers.utils.import_utils import _is_package_available
from typing import List, Callable
import torch


@@ -54,3 +54,32 @@ def create_new_module_peft(

# if module cannot be found, return None which results in a raise in the call-stack
return new_module

# consider moving this somewhere more general
def patch_forward_to_view_attributes_before_call(
    old_forward: Callable,
    attribute_names: List[str], torch_dtype,
):
    # patch old_forward to view the named attributes as torch_dtype
    # before the call

    def _forward(self, *args, **kwargs):
        # perform a view on all these attributes
        for attr_name in attribute_names:

            # the view should be a passthrough
            # if attr.dtype == torch_dtype
            attr = getattr(self, attr_name)

            # perform view
            attr = attr.view(torch_dtype)

            try:
                setattr(self, attr_name, attr)
            except TypeError:
                # attr_name is already registered as a Parameter,
                # so assign directly into __dict__
                self.__dict__[attr_name] = attr

        return old_forward(*args, **kwargs)

    return _forward
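A hedged usage sketch of the helper above: ToyQuantLinear and its shapes are made up for illustration, and the helper is assumed to be in scope, but the MethodType binding mirrors how the plugin applies the patch in the second file below.

# hypothetical toy module, for illustration only; assumes
# patch_forward_to_view_attributes_before_call (defined above) is in scope
from types import MethodType

import torch


class ToyQuantLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # int32 data registered as a float16 Parameter, as done for FSDP sharding
        qweight = torch.zeros(4, 8, dtype=torch.int32)
        self.qweight = torch.nn.Parameter(
            qweight.view(torch.float16), requires_grad=False
        )

    def forward(self, x):
        # the kernel-facing code expects qweight back in int32
        assert self.qweight.dtype == torch.int32
        return x


mod = ToyQuantLinear()
patched = patch_forward_to_view_attributes_before_call(
    mod.forward, attribute_names=["qweight"], torch_dtype=torch.int32,
)
mod.forward = MethodType(patched, mod)
mod(torch.randn(2, 2))  # the int32 view happens transparently on every call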
Second of two changed files: the AutoGPTQ acceleration plugin.
@@ -25,8 +25,10 @@
from fms_acceleration import AccelerationPlugin
from peft import LoraConfig, prepare_model_for_kbit_training
from peft.tuners.lora.model import LoraModel
import torch.distributed
from transformers import AutoModelForCausalLM, TrainingArguments
import torch
import os


class AutoGPTQAccelerationPlugin(AccelerationPlugin):
@@ -50,6 +52,8 @@ def model_loader(self, model_name: str, **kwargs):
# guarded imports
# Third Party
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear, QuantLinearFunction
from .autogptq_utils import patch_forward_to_view_attributes_before_call

# Currently we allow only a quantized checkpoint to be loaded, we do not
# implement the quantization process here.
@@ -121,6 +125,43 @@ def model_loader(self, model_name: str, **kwargs):
device_map=device_map,
)

# https://github.com/foundation-model-stack/fms-acceleration/pull/15
# if distributed with FSDP, the AutoGPTQ model's quantization
# attributes (plain int32 tensors) need to be converted to
# torch.nn.Parameters so FSDP can shard them, and they need to be
# stored viewed as a float dtype

try:
    world_size = torch.distributed.get_world_size()
except ValueError:
    world_size = 1  # process group not initialized

if (
    world_size > 1
    and os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"
):
    # these parameters are to be patched for triton v2
    # consider making a map if patching more kernels
    PATCH_FOR_FSDP_TRITON_V2 = ['qweight', 'qzeros']

    # patch all the QuantLinear base layers
    for mod in model.modules():
        if isinstance(mod, QuantLinear):

            # convert all patched attributes to Parameters of torch_dtype
            # so FSDP can shard them
            for attr_name in PATCH_FOR_FSDP_TRITON_V2:
                attr = getattr(mod, attr_name)
                attr = torch.nn.Parameter(attr.view(torch_dtype), requires_grad=False)
                setattr(mod, attr_name, attr)

            # this patches the forward to convert them back to the original
            # type (i.e. int32) before the call into the kernels
            _forward = patch_forward_to_view_attributes_before_call(
                mod.forward, attribute_names=PATCH_FOR_FSDP_TRITON_V2,
                torch_dtype=torch.int32,  # view back to int32 before the kernel call
            )
            mod.forward = MethodType(_forward, mod)

# replace
AutoModelForCausalLM.from_config = _old_from_config

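For reference, the world-size and ACCELERATE_USE_FSDP gate above can be read as a small standalone predicate; a sketch only, with _running_under_fsdp a name made up here and not part of the commit.

# sketch: the same FSDP-detection check factored into a helper
# (_running_under_fsdp is a hypothetical name, not in the commit)
import os

import torch.distributed


def _running_under_fsdp() -> bool:
    try:
        world_size = torch.distributed.get_world_size()
    except ValueError:  # default process group not initialized
        world_size = 1
    return (
        world_size > 1
        and os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"
    )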