From 343c852908f5e1984003ca78546452d9f9ab28dc Mon Sep 17 00:00:00 2001
From: Yu Chin Fabian Lim
Date: Mon, 14 Oct 2024 03:11:49 +0000
Subject: [PATCH] fmt + lint

Signed-off-by: Yu Chin Fabian Lim
---
 .../framework_plugin_autogptq.py              | 31 +++++++++++--------
 .../framework_plugin_bnb.py                   | 27 ++++++++++------
 .../src/fms_acceleration_peft/fsdp_utils.py   | 25 ++++++++-------
 .../gptqmodel/models/base.py                  |  6 +++-
 4 files changed, 55 insertions(+), 34 deletions(-)

diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py
index 8fd10df..e1fd277 100644
--- a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py
+++ b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py
@@ -62,10 +62,6 @@ def __init__(self, configurations: Dict[str, Dict]):
         )

         if self.use_external_lib:
-            # Third Party
-            from transformers.utils.import_utils import (  # pylint: disable=import-outside-toplevel
-                _is_package_available,
-            )

             assert _is_package_available("auto_gptq") is True, (
                 "Unable to use external library, auto_gptq module not found. "
@@ -360,28 +356,37 @@ def get_callbacks_and_ready_for_train(
             accelerator is not None
             and getattr(accelerator.state, "fsdp_plugin", None) is not None
         ):
-            _, _transformers_version = _is_package_available("transformers", return_version=True)
-            _trl_installed, _trl_version = _is_package_available("trl", return_version=True)
+            _, _transformers_version = _is_package_available(
+                "transformers", return_version=True
+            )
+            _trl_installed, _trl_version = _is_package_available(
+                "trl", return_version=True
+            )

             # the meta device fix for quantized models is since this transformers version
             # or if trl is installed then its only for this version
-            if (
-                _transformers_version >= "4.45" and (
-                    not _trl_installed or (_trl_installed and _trl_version >= "0.12")
-                )
+            if _transformers_version >= "4.45" and (
+                not _trl_installed or (_trl_installed and _trl_version >= "0.12")
             ):
                 # guarded
                 # NOTE: replace this later with a more specific accelerate version check
                 try:
+                    # Third Party
+                    # pylint: disable=import-outside-toplevel
                     from torch.distributed.utils import ensure_weights_retied
+
                     # then its handled internally and there is nothing to do
                 except ImportError:
                     # need to use our internal version
-                    from .fsdp_utils import ensure_weights_retied
+                    # Local
+                    from .fsdp_utils import (  # pylint: disable=import-outside-toplevel
+                        ensure_weights_retied,
+                    )
+
                 accelerator.state.fsdp_plugin.param_init_fn = ensure_weights_retied(
-                    accelerator.state.fsdp_plugin.param_init_fn, 
+                    accelerator.state.fsdp_plugin.param_init_fn,
                     model.get_base_model(),
-                    accelerator.device
+                    accelerator.device,
                 )

         return callbacks
diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_bnb.py b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_bnb.py
index 8560bd4..3a4c12a 100644
--- a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_bnb.py
+++ b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_bnb.py
@@ -201,28 +201,37 @@ def get_callbacks_and_ready_for_train(
             accelerator is not None
             and getattr(accelerator.state, "fsdp_plugin", None) is not None
         ):
-            _, _transformers_version = _is_package_available("transformers", return_version=True)
-            _trl_installed, _trl_version = _is_package_available("trl", return_version=True)
+            _, _transformers_version = _is_package_available(
+                "transformers", return_version=True
+            )
+            _trl_installed, _trl_version = _is_package_available(
+                "trl", return_version=True
+            )

             # the meta device fix for quantized models is since this transformers version
             # or if trl is installed then its only for this version
-            if (
-                _transformers_version >= "4.45" and (
-                    not _trl_installed or (_trl_installed and _trl_version >= "0.12")
-                )
+            if _transformers_version >= "4.45" and (
+                not _trl_installed or (_trl_installed and _trl_version >= "0.12")
             ):
                 # guarded
                 # NOTE: replace this later with a more specific accelerate version check
                 try:
+                    # Third Party
+                    # pylint: disable=import-outside-toplevel
                     from torch.distributed.utils import ensure_weights_retied
+
                     # then its handled internally and there is nothing to do
                 except ImportError:
                     # need to use our internal version
-                    from .fsdp_utils import ensure_weights_retied
+                    # Local
+                    from .fsdp_utils import (  # pylint: disable=import-outside-toplevel
+                        ensure_weights_retied,
+                    )
+
                 accelerator.state.fsdp_plugin.param_init_fn = ensure_weights_retied(
-                    accelerator.state.fsdp_plugin.param_init_fn, 
+                    accelerator.state.fsdp_plugin.param_init_fn,
                     model if self._no_peft_model else model.get_base_model(),
-                    accelerator.device
+                    accelerator.device,
                 )

         return callbacks
diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/fsdp_utils.py b/plugins/accelerated-peft/src/fms_acceleration_peft/fsdp_utils.py
index 95a22eb..3086cf7 100644
--- a/plugins/accelerated-peft/src/fms_acceleration_peft/fsdp_utils.py
+++ b/plugins/accelerated-peft/src/fms_acceleration_peft/fsdp_utils.py
@@ -1,6 +1,8 @@
+# Standard
 from collections import defaultdict

-import torch
+# Third Party
+import torch

 # Copyright The IBM Tuning Team
 #
@@ -19,6 +21,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # https://spdx.dev/learn/handling-license-info/

+
 def ensure_weights_retied(
     param_init_fn, model: torch.nn.Module, device: torch.cuda.device
 ):
@@ -28,28 +31,28 @@ def ensure_weights_retied(
         # if no tied names just passthrough
         return param_init_fn

-    # get map of parameter instances to params. 
+    # get map of parameter instances to params.
     # - needed for replacement later
     _tied_params = {}
     for name in _tied_names:
-        name = name.split('.')
-        name, param_name = '.'.join(name[:-1]), name[-1]
+        name = name.split(".")
+        name, param_name = ".".join(name[:-1]), name[-1]
         mod = model.get_submodule(name)
         param = getattr(mod, param_name)

-        _tied_params[id(param)] = None # placeholder for the param first
-
+        _tied_params[id(param)] = None  # placeholder for the param first
+
     # build param_init_fn for the case with tied params
     def param_init_fn_tied_param(module: torch.nn.Module):
-        # track which params to tie 
+        # track which params to tie
         # - usually only 1, but for completeness consider > 1
         params_to_tie = defaultdict(list)
         for n, param in module.named_parameters(recurse=False):
             if id(param) in _tied_params:
                 params_to_tie[id(param)].append(n)

-        # call the param init fn, which potentially re-allocates the 
+        # call the param init fn, which potentially re-allocates the
         # parameters
         module = param_init_fn(module)
@@ -62,8 +65,8 @@ def param_init_fn_tied_param(module: torch.nn.Module):
                     # param is observed
                     _tied_params[id_key] = getattr(module, param_name)
                 else:
-                    setattr(module, param_name, param) # tie
-
+                    setattr(module, param_name, param)  # tie
+
         return module

-    return param_init_fn_tied_param
\ No newline at end of file
+    return param_init_fn_tied_param
diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/models/base.py b/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/models/base.py
index 8782744..85c17ee 100644
--- a/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/models/base.py
+++ b/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/models/base.py
@@ -32,7 +32,11 @@
     PretrainedConfig,
     PreTrainedModel,
 )
-from transformers.modeling_utils import no_init_weights, shard_checkpoint, is_local_dist_rank_0
+from transformers.modeling_utils import (
+    is_local_dist_rank_0,
+    no_init_weights,
+    shard_checkpoint,
+)
 from transformers.utils.generic import ContextManagers
 import accelerate
 import torch
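Illustrative note on the helper reformatted above: an FSDP param_init_fn re-allocates a module's parameters when a model is materialized from the meta device. Applied module by module, that silently breaks weight ties, for example an LM head sharing its weight with the embedding, which is why both plugins wrap it with ensure_weights_retied. The sketch below is a minimal, single-process illustration of that failure mode and of the re-tying wrapper pattern. TiedLM, reallocating_init_fn and the simplified retied helper are hypothetical stand-ins written for this note; they only mirror the visible logic of fsdp_utils.ensure_weights_retied and are not code from the repository.

# Standard
from collections import defaultdict

# Third Party
import torch


class TiedLM(torch.nn.Module):
    """Toy model whose output head shares its weight with the embedding."""

    def __init__(self, vocab: int = 8, dim: int = 4):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab, dim)
        self.head = torch.nn.Linear(dim, vocab, bias=False)
        self.head.weight = self.emb.weight  # tie the two parameters


def reallocating_init_fn(module: torch.nn.Module):
    # stand-in for an FSDP param_init_fn that re-allocates parameters
    # module by module (e.g. when materializing from the meta device);
    # done naively, this breaks ties that span modules
    for name, param in list(module.named_parameters(recurse=False)):
        setattr(module, name, torch.nn.Parameter(param.detach().clone()))
    return module


def retied(param_init_fn, model: torch.nn.Module, tied_names):
    # simplified version of the wrapper in fsdp_utils.py: remember the
    # identity of each tied parameter, then, after the wrapped init fn
    # runs, point every alias at the first re-allocated copy
    _tied_params = {}
    for name in tied_names:
        mod_name, _, param_name = name.rpartition(".")
        param = getattr(model.get_submodule(mod_name), param_name)
        _tied_params[id(param)] = None  # filled in on first sight

    def param_init_fn_tied_param(module: torch.nn.Module):
        # track which of this module's own params belong to a tie
        params_to_tie = defaultdict(list)
        for n, p in module.named_parameters(recurse=False):
            if id(p) in _tied_params:
                params_to_tie[id(p)].append(n)

        module = param_init_fn(module)  # may re-allocate parameters

        # re-point every alias at the first re-allocated copy
        for id_key, names in params_to_tie.items():
            for param_name in names:
                if _tied_params[id_key] is None:
                    _tied_params[id_key] = getattr(module, param_name)
                else:
                    setattr(module, param_name, _tied_params[id_key])  # tie
        return module

    return param_init_fn_tied_param


if __name__ == "__main__":
    model = TiedLM()
    init_fn = retied(reallocating_init_fn, model, tied_names=["head.weight"])
    for m in model.modules():
        init_fn(m)
    # the tie survives the per-module re-allocation
    assert model.head.weight is model.emb.weight

Dropping the retied wrapper and calling reallocating_init_fn directly on each module leaves head.weight and emb.weight as two separate tensors. That is the situation the patch guards against when it installs ensure_weights_retied as accelerator.state.fsdp_plugin.param_init_fn.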