From 98885648f4f596f91759c3b558d4c06650ca7c1f Mon Sep 17 00:00:00 2001 From: 1000850000 user Date: Mon, 27 May 2024 18:45:06 +0000 Subject: [PATCH] linting and formatting changes --- .github/workflows/format.yml | 2 +- plugins/accelerated-peft/.pylintrc | 649 ++++++++++++++++++ .../fms_acceleration_peft/autogptq_utils.py | 13 +- .../framework_plugin_autogptq.py | 57 +- .../framework_plugin_bnb.py | 26 +- plugins/accelerated-peft/tests/__init__.py | 13 + .../tests/test_peft_plugins.py | 2 +- plugins/accelerated-peft/tox.ini | 13 +- 8 files changed, 726 insertions(+), 49 deletions(-) create mode 100644 plugins/accelerated-peft/.pylintrc create mode 100644 plugins/accelerated-peft/tests/__init__.py diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml index 2ec2bbca..294a0f6d 100644 --- a/.github/workflows/format.yml +++ b/.github/workflows/format.yml @@ -27,7 +27,7 @@ jobs: matrix: plugin_name: - "framework" - # - "accelerated-peft" # enable later + - "accelerated-peft" steps: - uses: actions/checkout@v4 diff --git a/plugins/accelerated-peft/.pylintrc b/plugins/accelerated-peft/.pylintrc new file mode 100644 index 00000000..45da4212 --- /dev/null +++ b/plugins/accelerated-peft/.pylintrc @@ -0,0 +1,649 @@ +[MAIN] + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + +# Clear in-memory caches upon conclusion of linting. Useful if running pylint +# in a server-like mode. +clear-cache-post-run=no + +# Load and enable all available extensions. Use --list-extensions to see a list +# all available extensions. +#enable-all-extensions= + +# In error mode, messages with a category besides ERROR or FATAL are +# suppressed, and no reports are done by default. Error mode is compatible with +# disabling specific errors. +#errors-only= + +# Always return a 0 (non-error) status code, even if lint errors are found. +# This is primarily useful in continuous integration scripts. +#exit-zero= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. +extension-pkg-allow-list= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. (This is an alternative name to extension-pkg-allow-list +# for backward compatibility.) +extension-pkg-whitelist= + +# Return non-zero exit code if any of these messages/categories are detected, +# even if score is above --fail-under value. Syntax same as enable. Messages +# specified are enabled, while categories only check already-enabled messages. +fail-on= + +# Specify a score threshold under which the program will exit with error. +fail-under=10 + +# Interpret the stdin as a python script, whose filename needs to be passed as +# the module_or_package argument. +#from-stdin= + +# Files or directories to be skipped. They should be base names, not paths. +ignore=CVS,protobufs + +# Add files or directories matching the regular expressions patterns to the +# ignore-list. The regex matches against paths and can be in Posix or Windows +# format. Because '\\' represents the directory delimiter on Windows systems, +# it can't be used as an escape character. +ignore-paths= + +# Files or directories matching the regular expression patterns are skipped. +# The regex matches against base names, not paths. The default value ignores +# Emacs file locks +ignore-patterns=^\.# + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis). It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules= + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +#init-hook= + +# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the +# number of processors available to use, and will cap the count on Windows to +# avoid hangs. +jobs=1 + +# Control the amount of potential inferred values when inferring a single +# object. This can help the performance when dealing with large functions or +# complex, nested conditions. +limit-inference-results=100 + +# List of plugins (as comma separated values of python module names) to load, +# usually to register additional checkers. +load-plugins= + +# Pickle collected data for later comparisons. +persistent=yes + +# Minimum Python version to use for version dependent checks. Will default to +# the version used to run pylint. +py-version=3.9 + +# Discover python modules and packages in the file system subtree. +recursive=no + +# When enabled, pylint would attempt to guess common misconfiguration and emit +# user-friendly hints instead of false-positive error messages. +suggestion-mode=yes + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + +# In verbose mode, extra non-checker-related info will be displayed. +#verbose= + + +[BASIC] + +# Naming style matching correct argument names. +argument-naming-style=snake_case + +# Regular expression matching correct argument names. Overrides argument- +# naming-style. If left empty, argument names will be checked with the set +# naming style. +#argument-rgx= + +# Naming style matching correct attribute names. +attr-naming-style=snake_case + +# Regular expression matching correct attribute names. Overrides attr-naming- +# style. If left empty, attribute names will be checked with the set naming +# style. +#attr-rgx= + +# Bad variable names which should always be refused, separated by a comma. +bad-names=foo, + bar, + baz, + toto, + tutu, + tata + +# Bad variable names regexes, separated by a comma. If names match any regex, +# they will always be refused +bad-names-rgxs= + +# Naming style matching correct class attribute names. +class-attribute-naming-style=any + +# Regular expression matching correct class attribute names. Overrides class- +# attribute-naming-style. If left empty, class attribute names will be checked +# with the set naming style. +#class-attribute-rgx= + +# Naming style matching correct class constant names. +class-const-naming-style=UPPER_CASE + +# Regular expression matching correct class constant names. Overrides class- +# const-naming-style. If left empty, class constant names will be checked with +# the set naming style. +#class-const-rgx= + +# Naming style matching correct class names. +class-naming-style=PascalCase + +# Regular expression matching correct class names. Overrides class-naming- +# style. If left empty, class names will be checked with the set naming style. +#class-rgx= + +# Naming style matching correct constant names. +const-naming-style=UPPER_CASE + +# Regular expression matching correct constant names. Overrides const-naming- +# style. If left empty, constant names will be checked with the set naming +# style. +#const-rgx= + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=-1 + +# Naming style matching correct function names. +function-naming-style=snake_case + +# Regular expression matching correct function names. Overrides function- +# naming-style. If left empty, function names will be checked with the set +# naming style. +#function-rgx= + +# Good variable names which should always be accepted, separated by a comma. +good-names=i, + j, + k, + ex, + Run, + _ + +# Good variable names regexes, separated by a comma. If names match any regex, +# they will always be accepted +good-names-rgxs= + +# Include a hint for the correct naming format with invalid-name. +include-naming-hint=no + +# Naming style matching correct inline iteration names. +inlinevar-naming-style=any + +# Regular expression matching correct inline iteration names. Overrides +# inlinevar-naming-style. If left empty, inline iteration names will be checked +# with the set naming style. +#inlinevar-rgx= + +# Naming style matching correct method names. +method-naming-style=snake_case + +# Regular expression matching correct method names. Overrides method-naming- +# style. If left empty, method names will be checked with the set naming style. +#method-rgx= + +# Naming style matching correct module names. +module-naming-style=snake_case + +# Regular expression matching correct module names. Overrides module-naming- +# style. If left empty, module names will be checked with the set naming style. +#module-rgx= + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=^_ + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. +# These decorators are taken in consideration only for invalid-name. +property-classes=abc.abstractproperty + +# Regular expression matching correct type variable names. If left empty, type +# variable names will be checked with the set naming style. +#typevar-rgx= + +# Naming style matching correct variable names. +variable-naming-style=snake_case + +# Regular expression matching correct variable names. Overrides variable- +# naming-style. If left empty, variable names will be checked with the set +# naming style. +#variable-rgx= + + +[CLASSES] + +# Warn about protected attribute access inside special methods +check-protected-access-in-special-methods=no + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp, + __post_init__ + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict, + _fields, + _replace, + _source, + _make + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs + + +[DESIGN] + +# List of regular expressions of class ancestor names to ignore when counting +# public methods (see R0903) +exclude-too-few-public-methods= + +# List of qualified class names to ignore when counting class parents (see +# R0901) +ignored-parents= + +# Maximum number of arguments for function / method. +max-args=5 + +# Maximum number of attributes for a class (see R0902). +max-attributes=7 + +# Maximum number of boolean expressions in an if statement (see R0916). +max-bool-expr=5 + +# Maximum number of branch for function / method body. +max-branches=12 + +# Maximum number of locals for function / method body. +max-locals=15 + +# Maximum number of parents for a class (see R0901). +max-parents=7 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + +# Maximum number of return / yield for function / method body. +max-returns=6 + +# Maximum number of statements in function / method body. +max-statements=50 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=2 + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when caught. +overgeneral-exceptions=builtins.BaseException,builtins.Exception + + +[FORMAT] + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=^\s*(# )??$ + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + +# Maximum number of characters on a single line. +max-line-length=100 + +# Maximum number of lines in a module. +max-module-lines=1100 + +# Allow the body of a class to be on the same line as the declaration if body +# contains single statement. +single-line-class-stmt=no + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=no + + +[IMPORTS] + +# List of modules that can be imported at any level, not just the top level +# one. +allow-any-import-level= + +# Allow explicit reexports by alias from a package __init__. +allow-reexport-from-package=no + +# Allow wildcard imports from modules that define __all__. +allow-wildcard-with-all=no + +# Deprecated modules which should not be used, separated by a comma. +deprecated-modules= + +# Output a graph (.gv or any supported image format) of external dependencies +# to the given file (report RP0402 must not be disabled). +ext-import-graph= + +# Output a graph (.gv or any supported image format) of all (i.e. internal and +# external) dependencies to the given file (report RP0402 must not be +# disabled). +import-graph= + +# Output a graph (.gv or any supported image format) of internal dependencies +# to the given file (report RP0402 must not be disabled). +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant + +# Couples of modules and preferred modules, separated by a comma. +preferred-modules= + + +[LOGGING] + +# The type of string formatting that logging methods do. `old` means using % +# formatting, `new` is for `{}` formatting. +logging-format-style=old + +# Logging modules to check that the string format arguments are in logging +# function parameter format. +logging-modules=logging + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE, +# UNDEFINED. +confidence=HIGH, + CONTROL_FLOW, + INFERENCE, + INFERENCE_FAILURE, + UNDEFINED + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once). You can also use "--disable=all" to +# disable everything first and then re-enable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use "--disable=all --enable=classes +# --disable=W". +disable=raw-checker-failed, + bad-inline-option, + locally-disabled, + file-ignored, + suppressed-message, + useless-suppression, + deprecated-pragma, + # Added messages + use-symbolic-message-instead, + invalid-name, + missing-class-docstring, + missing-module-docstring, + missing-function-docstring, + consider-using-f-string, + inconsistent-return-statements, + no-member, + too-many-arguments, + too-many-locals, + too-many-branches, + too-many-statements, + cyclic-import, + too-few-public-methods, + protected-access, + fixme, + logging-format-interpolation, + logging-too-many-args, + attribute-defined-outside-init, + abstract-method, + pointless-statement, + wrong-import-order, + duplicate-code, + unbalanced-tuple-unpacking, + unused-argument + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +enable=c-extension-no-member + + +[METHOD_ARGS] + +# List of qualified names (i.e., library.method) which require a timeout +# parameter e.g. 'requests.api.get,requests.api.post' +timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=FIXME, + XXX, + TODO + +# Regular expression of note tags to take in consideration. +notes-rgx= + + +[REFACTORING] + +# Maximum number of nested blocks for function / method body +max-nested-blocks=5 + +# Complete name of functions that never returns. When checking for +# inconsistent-return-statements if a never returning function is called then +# it will be considered as an explicit return statement and no message will be +# printed. +never-returning-functions=sys.exit,argparse.parse_error + + +[REPORTS] + +# Python expression which should return a score less than or equal to 10. You +# have access to the variables 'fatal', 'error', 'warning', 'refactor', +# 'convention', and 'info' which contain the number of messages in each +# category, as well as 'statement' which is the total number of statements +# analyzed. This score is used by the global evaluation report (RP0004). +evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details. +msg-template= + +# Set the output format. Available formats are text, parseable, colorized, json +# and msvs (visual studio). You can also give a reporter class, e.g. +# mypackage.mymodule.MyReporterClass. +output-format=text + +# Tells whether to display a full report or only the messages. +reports=yes + +# Activate the evaluation score. +score=yes + + +[SIMILARITIES] + +# Comments are removed from the similarity computation +ignore-comments=yes + +# Docstrings are removed from the similarity computation +ignore-docstrings=yes + +# Imports are removed from the similarity computation +ignore-imports=yes + +# Signatures are removed from the similarity computation +ignore-signatures=yes + +# Minimum lines number of a similarity. +min-similarity-lines=4 + + +[SPELLING] + +# Limits count of emitted suggestions for spelling mistakes. +max-spelling-suggestions=4 + +# Spelling dictionary name. Available dictionaries: none. To make it work, +# install the 'python-enchant' package. +spelling-dict= + +# List of comma separated words that should be considered directives if they +# appear at the beginning of a comment and should not be checked. +spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy: + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains the private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to the private dictionary (see the +# --spelling-private-dict-file option) instead of raising a message. +spelling-store-unknown-words=no + + +[STRING] + +# This flag controls whether inconsistent-quotes generates a warning when the +# character used as a quote delimiter is used inconsistently within a module. +check-quote-consistency=no + +# This flag controls whether the implicit-str-concat should generate a warning +# on implicit string concatenation in sequences defined over several lines. +check-str-concat-over-line-jumps=no + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members= + +# Tells whether to warn about missing members when the owner of the attribute +# is inferred to be None. +ignore-none=yes + +# This flag controls whether pylint should warn about no-member and similar +# checks whenever an opaque object is returned when inferring. The inference +# can return multiple potential results while evaluating a Python object, but +# some branches might not be evaluated, which results in partial inference. In +# that case, it might be useful to still emit no-member and other checks for +# the rest of the inferred objects. +ignore-on-opaque-inference=yes + +# List of symbolic message names to ignore for Mixin members. +ignored-checks-for-mixins=no-member, + not-async-context-manager, + not-context-manager, + attribute-defined-outside-init + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace + +# Show a hint with possible names when a member name was not found. The aspect +# of finding the hint is based on edit distance. +missing-member-hint=yes + +# The minimum edit distance a name should have in order to be considered a +# similar match for a missing member name. +missing-member-hint-distance=1 + +# The total number of similar names that should be taken in consideration when +# showing a hint for a missing member. +missing-member-max-choices=1 + +# Regex pattern to define which classes are considered mixins. +mixin-class-rgx=.*[Mm]ixin + +# List of decorators that change the signature of a decorated function. +signature-mutators= + + +[VARIABLES] + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid defining new builtins when possible. +additional-builtins= + +# Tells whether unused global variables should be treated as a violation. +allow-global-unused-variables=yes + +# List of names allowed to shadow builtins +allowed-redefined-builtins= + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_, + _cb + +# A regular expression matching the name of dummy variables (i.e. expected to +# not be used). +dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ + +# Argument names that match this expression will be ignored. +ignored-argument-names=_.*|^ignored_|^unused_ + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py b/plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py index e3b2dc6d..c4b497fc 100644 --- a/plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py +++ b/plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py @@ -15,10 +15,12 @@ # SPDX-License-Identifier: Apache-2.0 # https://spdx.dev/learn/handling-license-info/ +# Standard +from typing import Callable, List + # Third Party from peft import LoraConfig from peft.tuners.lora.gptq import QuantLinear as LoraLinearGPTQ -from typing import List, Callable import torch @@ -55,10 +57,12 @@ def create_new_module_peft( # if module cannot be found, return None which results in a raise in the call-stack return new_module + # consider to move this somewhere more general def patch_forward_to_view_attributes_before_call( old_forward: Callable, - attribute_names: List[str], torch_dtype, + attribute_names: List[str], + torch_dtype, ): # patch old_forward to view attribtues to torch_dype # before call @@ -67,7 +71,7 @@ def _forward(self, *args, **kwargs): # perform a view on all these attributes for attr_name in attribute_names: - # the view should be a passthrough + # the view should be a passthrough # if attr.dtype == torch_dtype attr = getattr(self, attr_name) @@ -80,6 +84,7 @@ def _forward(self, *args, **kwargs): # this means already have attr_name as a parameter, then # just assign this way self.__dict__[attr_name] = attr - + return old_forward(*args, **kwargs) + return _forward diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py index fa6082ab..734d8110 100644 --- a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py +++ b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py @@ -20,15 +20,15 @@ from functools import partial from types import MethodType from typing import Dict, Tuple +import os # Third Party from fms_acceleration import AccelerationPlugin from peft import LoraConfig, prepare_model_for_kbit_training from peft.tuners.lora.model import LoraModel -import torch.distributed from transformers import AutoModelForCausalLM, TrainingArguments import torch -import os +import torch.distributed class AutoGPTQAccelerationPlugin(AccelerationPlugin): @@ -51,9 +51,11 @@ def model_loader(self, model_name: str, **kwargs): # guarded imports # Third Party - from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig - from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear, QuantLinearFunction - from .autogptq_utils import patch_forward_to_view_attributes_before_call + from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig #pylint: disable=import-outside-toplevel + from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear #pylint: disable=import-outside-toplevel + + # Local + from .autogptq_utils import patch_forward_to_view_attributes_before_call #pylint: disable=import-outside-toplevel # Currently we allow only a quantized checkpoint to be loaded, we do not # implement the quantization process here. @@ -61,16 +63,18 @@ def model_loader(self, model_name: str, **kwargs): # The quantization process is used to convert a non-quantized checkpoint # (provided in model_name) into a quantized one. This entails # 1. providing a BaseQuantizeConfig with the appropriate quantization settings - # 2. calling BaseGPTQForCausalLM.quantize to run the quantization algorithm (may take time, e.g. hours) + # 2. calling BaseGPTQForCausalLM.quantize to run the quantization algorithm + # (may take time, e.g. hours) # 3. calling BaseGPTQForCausalLM.save_pretrained to save a quantized checkpoint # # The reasons for not implementing the flow at this point are. # 1. The quantization can take very long for large models. As such, it is more appropriate - # to run it once outside of training, and save the checkpoint to be used for multiple runs. + # to run it once outside of training, and save the checkpoint to be used for multiple runs. # 2. Requires some API changes to point to where the quantized checkpoint should be saved. # Can be confusing to the user since it will be different from model_name # NOTE: there will be a warning that can be ignored - # "WARNING - QuantLinear with the exllama backend not does support the trainable mode yet, switching to cuda/cuda_old/triton backend." + # "WARNING - QuantLinear with the exllama backend not does support the trainable mode yet, + # switching to cuda/cuda_old/triton backend." # assume model_name points to a quantized checkpoint. Thus we load the quantization # config directly from the checkpoint. quantize_config = BaseQuantizeConfig.from_pretrained(model_name) @@ -81,8 +85,10 @@ def model_loader(self, model_name: str, **kwargs): attn_implementation = kwargs.get("attn_implementation") if low_cpu_mem_usage: - # Note that low_cpu_mem_usage is typically set to transformers.modeling_utils.is_fsdp_enabled. - # e.g., https://github.com/huggingface/transformers/blob/a98c41798cf6ed99e1ff17e3792d6e06a2ff2ff3/src/transformers/modeling_utils.py#L2989-L2990 + # Note that low_cpu_mem_usage is typically set to + # transformers.modeling_utils.is_fsdp_enabled. + # e.g., + # https://github.com/huggingface/transformers/blob/a98c41798cf6ed99e1ff17e3792d6e06a2ff2ff3/src/transformers/modeling_utils.py#L2989-L2990 # but not doing that now as AutoGPTQ will call make_sure_no_tensor_in_meta_device # https://github.com/AutoGPTQ/AutoGPTQ/blob/ea829c7bbe83561c2b1de26795b6592992373ef7/auto_gptq/modeling/_base.py#L982C17-L982C51 # which does not properly check if a QuantLayer has a bias set or not, @@ -95,16 +101,16 @@ def model_loader(self, model_name: str, **kwargs): # there are some kwargs that we wont be passed to AutoModel, so we need # to patch them in _old_from_config = AutoModelForCausalLM.from_config - # Standard - from functools import partial _from_config = partial( _old_from_config, attn_implementation=attn_implementation ) AutoModelForCausalLM.from_config = _from_config # patch - # NOTE: need to set the device map as below as we want to use AutoGPTQ for training. - # device_map is for inference only https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference + # NOTE: need to set the device map as below as we want to + # use AutoGPTQ for training. + # device_map is for inference only + # ref: https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference # Thus we set it as below to effectively disable it. device_map = ( {"": torch.cuda.current_device()} if torch.cuda.is_available() else None @@ -119,14 +125,14 @@ def model_loader(self, model_name: str, **kwargs): low_cpu_mem_usage=low_cpu_mem_usage, use_marlin=False, # disable, cannot be used for training (no forward+backward) disable_exllama=True, # disable, cannot be used for training (no backward) - warmup_triton=False, # disable for now, because it will try to run the warmup while on CPU + warmup_triton=False, # disable for now as it will try to run the warmup while on CPU use_tritonv2=True, trainable=True, # only support trainable mode device_map=device_map, ) # https://github.com/foundation-model-stack/fms-acceleration/pull/15 - # if FSDP distributed need to convert the AutoGPTQ model's + # if FSDP distributed need to convert the AutoGPTQ model's # parameters (in tensors) to parameters. Also need to # store the int32 tensors in a float type @@ -141,7 +147,7 @@ def model_loader(self, model_name: str, **kwargs): ): # these parameters are to be patched for triton v2 # consider making a map if patching more kernels - PATCH_FOR_FSDP_TRITON_V2 = ['qweight', 'qzeros'] + PATCH_FOR_FSDP_TRITON_V2 = ["qweight", "qzeros"] # patch all the QuantLinear base layers for mod in model.modules(): @@ -151,14 +157,17 @@ def model_loader(self, model_name: str, **kwargs): # so FSDP can shard them for attr_name in PATCH_FOR_FSDP_TRITON_V2: attr = getattr(mod, attr_name) - attr = torch.nn.Parameter(attr.view(torch_dtype), requires_grad=False) + attr = torch.nn.Parameter( + attr.view(torch_dtype), requires_grad=False + ) setattr(mod, attr_name, attr) - # this patches the forward to convert them back to original + # this patches the forward to convert them back to original # type (i.e. int32) before the function call into the kernels _forward = patch_forward_to_view_attributes_before_call( - mod.forward, attribute_names=PATCH_FOR_FSDP_TRITON_V2, - torch_dtype=torch.int32, # patch it back to + mod.forward, + attribute_names=PATCH_FOR_FSDP_TRITON_V2, + torch_dtype=torch.int32, # patch it back to ) mod.forward = MethodType(_forward, mod) @@ -193,11 +202,11 @@ def augmentation( ): # guarded imports # Third Party - from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear - from auto_gptq.utils.peft_utils import GPTQLoraModel, get_gptq_peft_model + from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear #pylint: disable=import-outside-toplevel + from auto_gptq.utils.peft_utils import GPTQLoraModel, get_gptq_peft_model #pylint: disable=import-outside-toplevel # Local - from .autogptq_utils import create_new_module_peft, replace_module_peft + from .autogptq_utils import create_new_module_peft, replace_module_peft #pylint: disable=import-outside-toplevel (peft_config,) = modifiable_args # unpack modifiable args diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_bnb.py b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_bnb.py index dfd5fbc8..6e71d11a 100644 --- a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_bnb.py +++ b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_bnb.py @@ -23,7 +23,7 @@ # Third Party from fms_acceleration import AccelerationPlugin -from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training +from peft import LoraConfig, get_peft_model from transformers import AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments import torch @@ -41,7 +41,7 @@ def _prepare_model_for_kbit_training( if gradient_checkpointing_kwargs is None: gradient_checkpointing_kwargs = {} - for name, param in model.named_parameters(): + for _, param in model.named_parameters(): # freeze base model's layers param.requires_grad = False @@ -56,22 +56,24 @@ def _prepare_model_for_kbit_training( model.enable_input_require_grads() else: - def make_inputs_require_grad(module, input, output): + def make_inputs_require_grad(_module, _input, output): output.requires_grad_(True) model.get_input_embeddings().register_forward_hook( make_inputs_require_grad ) - # To support older transformers versions, check if the model supports gradient_checkpointing_kwargs + # To support older transformers versions, + # check if the model supports gradient_checkpointing_kwargs _supports_gc_kwargs = "gradient_checkpointing_kwargs" in list( inspect.signature(model.gradient_checkpointing_enable).parameters ) if not _supports_gc_kwargs and len(gradient_checkpointing_kwargs) > 0: warnings.warn( - "gradient_checkpointing_kwargs is not supported in this version of transformers. The passed kwargs will be ignored." - " if you want to use that feature, please upgrade to the latest version of transformers.", + "gradient_checkpointing_kwargs is not supported in this version of transformers.", + "The passed kwargs will be ignored. if you want to use that feature,", + "please upgrade to the latest version of transformers.", FutureWarning, ) @@ -124,16 +126,14 @@ def model_loader(self, model_name: str, **kwargs): "If running in FSDP, this is probably because accelerate is not used. " "This will most probably result in error." ) - elif ( - world_size == 1 - and self._no_peft_model == True - ): + elif world_size == 1 and self._no_peft_model is True: warnings.warn( """Running on single device and setting plugin config `no_peft_model` as `True` - PEFT preparation will be managed by SFTTrainer and will cause a slowdown in training speed - due to extraneous dtype casting when SFTTrainer prepares the model using + PEFT preparation will be managed by SFTTrainer and + will cause a slowdown in training speed due to + extraneous dtype casting when SFTTrainer prepares the model using https://github.com/huggingface/trl/blob/e90e8d91d2265e484f229c45a5eb8982f94a2936/trl/trainer/sft_trainer.py#L210""" - ) + ) bnb_config = BitsAndBytesConfig( load_in_4bit=True, diff --git a/plugins/accelerated-peft/tests/__init__.py b/plugins/accelerated-peft/tests/__init__.py new file mode 100644 index 00000000..38a9531e --- /dev/null +++ b/plugins/accelerated-peft/tests/__init__.py @@ -0,0 +1,13 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/plugins/accelerated-peft/tests/test_peft_plugins.py b/plugins/accelerated-peft/tests/test_peft_plugins.py index 42404ddc..bb0621e5 100644 --- a/plugins/accelerated-peft/tests/test_peft_plugins.py +++ b/plugins/accelerated-peft/tests/test_peft_plugins.py @@ -134,7 +134,7 @@ def test_configure_bnb_plugin(): require_packages_check=False, ): # check flags and callbacks - assert (not correct_value)==framework.requires_agumentation + assert (not correct_value) == framework.requires_agumentation # attempt to activate plugin with configuration pointing to wrong path # - raise with message that no plugins can be configured diff --git a/plugins/accelerated-peft/tox.ini b/plugins/accelerated-peft/tox.ini index b79d0691..0c303352 100644 --- a/plugins/accelerated-peft/tox.ini +++ b/plugins/accelerated-peft/tox.ini @@ -4,23 +4,24 @@ envlist = py, lint [testenv] deps = pytest>=7 - # for the tests, we need to install the deps ourselves # as the package will install the github version -e {toxinidir}/../framework -skip_install = true +skip_install = true # set skip install as it will install AutoGPTQ before deps, will throw error when AutoGPTQ needs torch commands = - # install the current package pip install --no-deps {toxinidir} - pytest {posargs:tests} -[testenv:lint] +[testenv:lint] description = run linters deps = + -e {toxinidir}/../framework + pytest>=7 pylint>=2.16.2,<=3.1.0 -commands = pylint src tests +commands = + pip install -e {toxinidir} + pylint src tests allowlist_externals = pylint [testenv:fmt]