From c3c1cddabb679bfe8dbbd4cebda06b808183062f Mon Sep 17 00:00:00 2001 From: achew010 <165894159+achew010@users.noreply.github.com> Date: Tue, 28 May 2024 15:18:48 +0800 Subject: [PATCH 1/8] Linting and Formatting for FMS-Acceleration-Peft package (#24) * linting and formatting changes * removed AutoGPTQ dep in linting * added additional comments in tox --- .github/workflows/format.yml | 2 +- plugins/accelerated-peft/.pylintrc | 649 ++++++++++++++++++ .../fms_acceleration_peft/autogptq_utils.py | 13 +- .../framework_plugin_autogptq.py | 57 +- .../framework_plugin_bnb.py | 26 +- plugins/accelerated-peft/tests/__init__.py | 13 + .../tests/test_peft_plugins.py | 2 +- plugins/accelerated-peft/tox.ini | 16 +- 8 files changed, 729 insertions(+), 49 deletions(-) create mode 100644 plugins/accelerated-peft/.pylintrc create mode 100644 plugins/accelerated-peft/tests/__init__.py diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml index 2ec2bbca..294a0f6d 100644 --- a/.github/workflows/format.yml +++ b/.github/workflows/format.yml @@ -27,7 +27,7 @@ jobs: matrix: plugin_name: - "framework" - # - "accelerated-peft" # enable later + - "accelerated-peft" steps: - uses: actions/checkout@v4 diff --git a/plugins/accelerated-peft/.pylintrc b/plugins/accelerated-peft/.pylintrc new file mode 100644 index 00000000..45da4212 --- /dev/null +++ b/plugins/accelerated-peft/.pylintrc @@ -0,0 +1,649 @@ +[MAIN] + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + +# Clear in-memory caches upon conclusion of linting. Useful if running pylint +# in a server-like mode. +clear-cache-post-run=no + +# Load and enable all available extensions. Use --list-extensions to see a list +# all available extensions. +#enable-all-extensions= + +# In error mode, messages with a category besides ERROR or FATAL are +# suppressed, and no reports are done by default. Error mode is compatible with +# disabling specific errors. +#errors-only= + +# Always return a 0 (non-error) status code, even if lint errors are found. +# This is primarily useful in continuous integration scripts. +#exit-zero= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. +extension-pkg-allow-list= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. (This is an alternative name to extension-pkg-allow-list +# for backward compatibility.) +extension-pkg-whitelist= + +# Return non-zero exit code if any of these messages/categories are detected, +# even if score is above --fail-under value. Syntax same as enable. Messages +# specified are enabled, while categories only check already-enabled messages. +fail-on= + +# Specify a score threshold under which the program will exit with error. +fail-under=10 + +# Interpret the stdin as a python script, whose filename needs to be passed as +# the module_or_package argument. +#from-stdin= + +# Files or directories to be skipped. They should be base names, not paths. +ignore=CVS,protobufs + +# Add files or directories matching the regular expressions patterns to the +# ignore-list. 
The regex matches against paths and can be in Posix or Windows +# format. Because '\\' represents the directory delimiter on Windows systems, +# it can't be used as an escape character. +ignore-paths= + +# Files or directories matching the regular expression patterns are skipped. +# The regex matches against base names, not paths. The default value ignores +# Emacs file locks +ignore-patterns=^\.# + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis). It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules= + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +#init-hook= + +# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the +# number of processors available to use, and will cap the count on Windows to +# avoid hangs. +jobs=1 + +# Control the amount of potential inferred values when inferring a single +# object. This can help the performance when dealing with large functions or +# complex, nested conditions. +limit-inference-results=100 + +# List of plugins (as comma separated values of python module names) to load, +# usually to register additional checkers. +load-plugins= + +# Pickle collected data for later comparisons. +persistent=yes + +# Minimum Python version to use for version dependent checks. Will default to +# the version used to run pylint. +py-version=3.9 + +# Discover python modules and packages in the file system subtree. +recursive=no + +# When enabled, pylint would attempt to guess common misconfiguration and emit +# user-friendly hints instead of false-positive error messages. +suggestion-mode=yes + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + +# In verbose mode, extra non-checker-related info will be displayed. +#verbose= + + +[BASIC] + +# Naming style matching correct argument names. +argument-naming-style=snake_case + +# Regular expression matching correct argument names. Overrides argument- +# naming-style. If left empty, argument names will be checked with the set +# naming style. +#argument-rgx= + +# Naming style matching correct attribute names. +attr-naming-style=snake_case + +# Regular expression matching correct attribute names. Overrides attr-naming- +# style. If left empty, attribute names will be checked with the set naming +# style. +#attr-rgx= + +# Bad variable names which should always be refused, separated by a comma. +bad-names=foo, + bar, + baz, + toto, + tutu, + tata + +# Bad variable names regexes, separated by a comma. If names match any regex, +# they will always be refused +bad-names-rgxs= + +# Naming style matching correct class attribute names. +class-attribute-naming-style=any + +# Regular expression matching correct class attribute names. Overrides class- +# attribute-naming-style. If left empty, class attribute names will be checked +# with the set naming style. +#class-attribute-rgx= + +# Naming style matching correct class constant names. +class-const-naming-style=UPPER_CASE + +# Regular expression matching correct class constant names. Overrides class- +# const-naming-style. If left empty, class constant names will be checked with +# the set naming style. +#class-const-rgx= + +# Naming style matching correct class names. 
+class-naming-style=PascalCase + +# Regular expression matching correct class names. Overrides class-naming- +# style. If left empty, class names will be checked with the set naming style. +#class-rgx= + +# Naming style matching correct constant names. +const-naming-style=UPPER_CASE + +# Regular expression matching correct constant names. Overrides const-naming- +# style. If left empty, constant names will be checked with the set naming +# style. +#const-rgx= + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=-1 + +# Naming style matching correct function names. +function-naming-style=snake_case + +# Regular expression matching correct function names. Overrides function- +# naming-style. If left empty, function names will be checked with the set +# naming style. +#function-rgx= + +# Good variable names which should always be accepted, separated by a comma. +good-names=i, + j, + k, + ex, + Run, + _ + +# Good variable names regexes, separated by a comma. If names match any regex, +# they will always be accepted +good-names-rgxs= + +# Include a hint for the correct naming format with invalid-name. +include-naming-hint=no + +# Naming style matching correct inline iteration names. +inlinevar-naming-style=any + +# Regular expression matching correct inline iteration names. Overrides +# inlinevar-naming-style. If left empty, inline iteration names will be checked +# with the set naming style. +#inlinevar-rgx= + +# Naming style matching correct method names. +method-naming-style=snake_case + +# Regular expression matching correct method names. Overrides method-naming- +# style. If left empty, method names will be checked with the set naming style. +#method-rgx= + +# Naming style matching correct module names. +module-naming-style=snake_case + +# Regular expression matching correct module names. Overrides module-naming- +# style. If left empty, module names will be checked with the set naming style. +#module-rgx= + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=^_ + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. +# These decorators are taken in consideration only for invalid-name. +property-classes=abc.abstractproperty + +# Regular expression matching correct type variable names. If left empty, type +# variable names will be checked with the set naming style. +#typevar-rgx= + +# Naming style matching correct variable names. +variable-naming-style=snake_case + +# Regular expression matching correct variable names. Overrides variable- +# naming-style. If left empty, variable names will be checked with the set +# naming style. +#variable-rgx= + + +[CLASSES] + +# Warn about protected attribute access inside special methods +check-protected-access-in-special-methods=no + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp, + __post_init__ + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict, + _fields, + _replace, + _source, + _make + +# List of valid names for the first argument in a class method. 
+valid-classmethod-first-arg=cls + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs + + +[DESIGN] + +# List of regular expressions of class ancestor names to ignore when counting +# public methods (see R0903) +exclude-too-few-public-methods= + +# List of qualified class names to ignore when counting class parents (see +# R0901) +ignored-parents= + +# Maximum number of arguments for function / method. +max-args=5 + +# Maximum number of attributes for a class (see R0902). +max-attributes=7 + +# Maximum number of boolean expressions in an if statement (see R0916). +max-bool-expr=5 + +# Maximum number of branch for function / method body. +max-branches=12 + +# Maximum number of locals for function / method body. +max-locals=15 + +# Maximum number of parents for a class (see R0901). +max-parents=7 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + +# Maximum number of return / yield for function / method body. +max-returns=6 + +# Maximum number of statements in function / method body. +max-statements=50 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=2 + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when caught. +overgeneral-exceptions=builtins.BaseException,builtins.Exception + + +[FORMAT] + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=^\s*(# )??$ + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + +# Maximum number of characters on a single line. +max-line-length=100 + +# Maximum number of lines in a module. +max-module-lines=1100 + +# Allow the body of a class to be on the same line as the declaration if body +# contains single statement. +single-line-class-stmt=no + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=no + + +[IMPORTS] + +# List of modules that can be imported at any level, not just the top level +# one. +allow-any-import-level= + +# Allow explicit reexports by alias from a package __init__. +allow-reexport-from-package=no + +# Allow wildcard imports from modules that define __all__. +allow-wildcard-with-all=no + +# Deprecated modules which should not be used, separated by a comma. +deprecated-modules= + +# Output a graph (.gv or any supported image format) of external dependencies +# to the given file (report RP0402 must not be disabled). +ext-import-graph= + +# Output a graph (.gv or any supported image format) of all (i.e. internal and +# external) dependencies to the given file (report RP0402 must not be +# disabled). +import-graph= + +# Output a graph (.gv or any supported image format) of internal dependencies +# to the given file (report RP0402 must not be disabled). +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant + +# Couples of modules and preferred modules, separated by a comma. +preferred-modules= + + +[LOGGING] + +# The type of string formatting that logging methods do. `old` means using % +# formatting, `new` is for `{}` formatting. 
+logging-format-style=old + +# Logging modules to check that the string format arguments are in logging +# function parameter format. +logging-modules=logging + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE, +# UNDEFINED. +confidence=HIGH, + CONTROL_FLOW, + INFERENCE, + INFERENCE_FAILURE, + UNDEFINED + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once). You can also use "--disable=all" to +# disable everything first and then re-enable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use "--disable=all --enable=classes +# --disable=W". +disable=raw-checker-failed, + bad-inline-option, + locally-disabled, + file-ignored, + suppressed-message, + useless-suppression, + deprecated-pragma, + # Added messages + use-symbolic-message-instead, + invalid-name, + missing-class-docstring, + missing-module-docstring, + missing-function-docstring, + consider-using-f-string, + inconsistent-return-statements, + no-member, + too-many-arguments, + too-many-locals, + too-many-branches, + too-many-statements, + cyclic-import, + too-few-public-methods, + protected-access, + fixme, + logging-format-interpolation, + logging-too-many-args, + attribute-defined-outside-init, + abstract-method, + pointless-statement, + wrong-import-order, + duplicate-code, + unbalanced-tuple-unpacking, + unused-argument + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +enable=c-extension-no-member + + +[METHOD_ARGS] + +# List of qualified names (i.e., library.method) which require a timeout +# parameter e.g. 'requests.api.get,requests.api.post' +timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=FIXME, + XXX, + TODO + +# Regular expression of note tags to take in consideration. +notes-rgx= + + +[REFACTORING] + +# Maximum number of nested blocks for function / method body +max-nested-blocks=5 + +# Complete name of functions that never returns. When checking for +# inconsistent-return-statements if a never returning function is called then +# it will be considered as an explicit return statement and no message will be +# printed. +never-returning-functions=sys.exit,argparse.parse_error + + +[REPORTS] + +# Python expression which should return a score less than or equal to 10. You +# have access to the variables 'fatal', 'error', 'warning', 'refactor', +# 'convention', and 'info' which contain the number of messages in each +# category, as well as 'statement' which is the total number of statements +# analyzed. This score is used by the global evaluation report (RP0004). 
+evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details. +msg-template= + +# Set the output format. Available formats are text, parseable, colorized, json +# and msvs (visual studio). You can also give a reporter class, e.g. +# mypackage.mymodule.MyReporterClass. +output-format=text + +# Tells whether to display a full report or only the messages. +reports=yes + +# Activate the evaluation score. +score=yes + + +[SIMILARITIES] + +# Comments are removed from the similarity computation +ignore-comments=yes + +# Docstrings are removed from the similarity computation +ignore-docstrings=yes + +# Imports are removed from the similarity computation +ignore-imports=yes + +# Signatures are removed from the similarity computation +ignore-signatures=yes + +# Minimum lines number of a similarity. +min-similarity-lines=4 + + +[SPELLING] + +# Limits count of emitted suggestions for spelling mistakes. +max-spelling-suggestions=4 + +# Spelling dictionary name. Available dictionaries: none. To make it work, +# install the 'python-enchant' package. +spelling-dict= + +# List of comma separated words that should be considered directives if they +# appear at the beginning of a comment and should not be checked. +spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy: + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains the private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to the private dictionary (see the +# --spelling-private-dict-file option) instead of raising a message. +spelling-store-unknown-words=no + + +[STRING] + +# This flag controls whether inconsistent-quotes generates a warning when the +# character used as a quote delimiter is used inconsistently within a module. +check-quote-consistency=no + +# This flag controls whether the implicit-str-concat should generate a warning +# on implicit string concatenation in sequences defined over several lines. +check-str-concat-over-line-jumps=no + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members= + +# Tells whether to warn about missing members when the owner of the attribute +# is inferred to be None. +ignore-none=yes + +# This flag controls whether pylint should warn about no-member and similar +# checks whenever an opaque object is returned when inferring. The inference +# can return multiple potential results while evaluating a Python object, but +# some branches might not be evaluated, which results in partial inference. In +# that case, it might be useful to still emit no-member and other checks for +# the rest of the inferred objects. +ignore-on-opaque-inference=yes + +# List of symbolic message names to ignore for Mixin members. 
+ignored-checks-for-mixins=no-member, + not-async-context-manager, + not-context-manager, + attribute-defined-outside-init + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace + +# Show a hint with possible names when a member name was not found. The aspect +# of finding the hint is based on edit distance. +missing-member-hint=yes + +# The minimum edit distance a name should have in order to be considered a +# similar match for a missing member name. +missing-member-hint-distance=1 + +# The total number of similar names that should be taken in consideration when +# showing a hint for a missing member. +missing-member-max-choices=1 + +# Regex pattern to define which classes are considered mixins. +mixin-class-rgx=.*[Mm]ixin + +# List of decorators that change the signature of a decorated function. +signature-mutators= + + +[VARIABLES] + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid defining new builtins when possible. +additional-builtins= + +# Tells whether unused global variables should be treated as a violation. +allow-global-unused-variables=yes + +# List of names allowed to shadow builtins +allowed-redefined-builtins= + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_, + _cb + +# A regular expression matching the name of dummy variables (i.e. expected to +# not be used). +dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ + +# Argument names that match this expression will be ignored. +ignored-argument-names=_.*|^ignored_|^unused_ + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# List of qualified module names which can have objects that can redefine +# builtins. 
+redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py b/plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py index e3b2dc6d..c4b497fc 100644 --- a/plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py +++ b/plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py @@ -15,10 +15,12 @@ # SPDX-License-Identifier: Apache-2.0 # https://spdx.dev/learn/handling-license-info/ +# Standard +from typing import Callable, List + # Third Party from peft import LoraConfig from peft.tuners.lora.gptq import QuantLinear as LoraLinearGPTQ -from typing import List, Callable import torch @@ -55,10 +57,12 @@ def create_new_module_peft( # if module cannot be found, return None which results in a raise in the call-stack return new_module + # consider to move this somewhere more general def patch_forward_to_view_attributes_before_call( old_forward: Callable, - attribute_names: List[str], torch_dtype, + attribute_names: List[str], + torch_dtype, ): # patch old_forward to view attribtues to torch_dype # before call @@ -67,7 +71,7 @@ def _forward(self, *args, **kwargs): # perform a view on all these attributes for attr_name in attribute_names: - # the view should be a passthrough + # the view should be a passthrough # if attr.dtype == torch_dtype attr = getattr(self, attr_name) @@ -80,6 +84,7 @@ def _forward(self, *args, **kwargs): # this means already have attr_name as a parameter, then # just assign this way self.__dict__[attr_name] = attr - + return old_forward(*args, **kwargs) + return _forward diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py index fa6082ab..89ea2862 100644 --- a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py +++ b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py @@ -20,15 +20,15 @@ from functools import partial from types import MethodType from typing import Dict, Tuple +import os # Third Party from fms_acceleration import AccelerationPlugin from peft import LoraConfig, prepare_model_for_kbit_training from peft.tuners.lora.model import LoraModel -import torch.distributed from transformers import AutoModelForCausalLM, TrainingArguments import torch -import os +import torch.distributed class AutoGPTQAccelerationPlugin(AccelerationPlugin): @@ -51,9 +51,11 @@ def model_loader(self, model_name: str, **kwargs): # guarded imports # Third Party - from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig - from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear, QuantLinearFunction - from .autogptq_utils import patch_forward_to_view_attributes_before_call + from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig #pylint: disable=import-outside-toplevel,import-error + from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear #pylint: disable=import-outside-toplevel,import-error + + # Local + from .autogptq_utils import patch_forward_to_view_attributes_before_call #pylint: disable=import-outside-toplevel # Currently we allow only a quantized checkpoint to be loaded, we do not # implement the quantization process here. @@ -61,16 +63,18 @@ def model_loader(self, model_name: str, **kwargs): # The quantization process is used to convert a non-quantized checkpoint # (provided in model_name) into a quantized one. This entails # 1. 
providing a BaseQuantizeConfig with the appropriate quantization settings - # 2. calling BaseGPTQForCausalLM.quantize to run the quantization algorithm (may take time, e.g. hours) + # 2. calling BaseGPTQForCausalLM.quantize to run the quantization algorithm + # (may take time, e.g. hours) # 3. calling BaseGPTQForCausalLM.save_pretrained to save a quantized checkpoint # # The reasons for not implementing the flow at this point are. # 1. The quantization can take very long for large models. As such, it is more appropriate - # to run it once outside of training, and save the checkpoint to be used for multiple runs. + # to run it once outside of training, and save the checkpoint to be used for multiple runs. # 2. Requires some API changes to point to where the quantized checkpoint should be saved. # Can be confusing to the user since it will be different from model_name # NOTE: there will be a warning that can be ignored - # "WARNING - QuantLinear with the exllama backend not does support the trainable mode yet, switching to cuda/cuda_old/triton backend." + # "WARNING - QuantLinear with the exllama backend not does support the trainable mode yet, + # switching to cuda/cuda_old/triton backend." # assume model_name points to a quantized checkpoint. Thus we load the quantization # config directly from the checkpoint. quantize_config = BaseQuantizeConfig.from_pretrained(model_name) @@ -81,8 +85,10 @@ def model_loader(self, model_name: str, **kwargs): attn_implementation = kwargs.get("attn_implementation") if low_cpu_mem_usage: - # Note that low_cpu_mem_usage is typically set to transformers.modeling_utils.is_fsdp_enabled. - # e.g., https://github.com/huggingface/transformers/blob/a98c41798cf6ed99e1ff17e3792d6e06a2ff2ff3/src/transformers/modeling_utils.py#L2989-L2990 + # Note that low_cpu_mem_usage is typically set to + # transformers.modeling_utils.is_fsdp_enabled. + # e.g., + # https://github.com/huggingface/transformers/blob/a98c41798cf6ed99e1ff17e3792d6e06a2ff2ff3/src/transformers/modeling_utils.py#L2989-L2990 # but not doing that now as AutoGPTQ will call make_sure_no_tensor_in_meta_device # https://github.com/AutoGPTQ/AutoGPTQ/blob/ea829c7bbe83561c2b1de26795b6592992373ef7/auto_gptq/modeling/_base.py#L982C17-L982C51 # which does not properly check if a QuantLayer has a bias set or not, @@ -95,16 +101,16 @@ def model_loader(self, model_name: str, **kwargs): # there are some kwargs that we wont be passed to AutoModel, so we need # to patch them in _old_from_config = AutoModelForCausalLM.from_config - # Standard - from functools import partial _from_config = partial( _old_from_config, attn_implementation=attn_implementation ) AutoModelForCausalLM.from_config = _from_config # patch - # NOTE: need to set the device map as below as we want to use AutoGPTQ for training. - # device_map is for inference only https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference + # NOTE: need to set the device map as below as we want to + # use AutoGPTQ for training. + # device_map is for inference only + # ref: https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference # Thus we set it as below to effectively disable it. 
device_map = ( {"": torch.cuda.current_device()} if torch.cuda.is_available() else None @@ -119,14 +125,14 @@ def model_loader(self, model_name: str, **kwargs): low_cpu_mem_usage=low_cpu_mem_usage, use_marlin=False, # disable, cannot be used for training (no forward+backward) disable_exllama=True, # disable, cannot be used for training (no backward) - warmup_triton=False, # disable for now, because it will try to run the warmup while on CPU + warmup_triton=False, # disable for now as it will try to run the warmup while on CPU use_tritonv2=True, trainable=True, # only support trainable mode device_map=device_map, ) # https://github.com/foundation-model-stack/fms-acceleration/pull/15 - # if FSDP distributed need to convert the AutoGPTQ model's + # if FSDP distributed need to convert the AutoGPTQ model's # parameters (in tensors) to parameters. Also need to # store the int32 tensors in a float type @@ -141,7 +147,7 @@ def model_loader(self, model_name: str, **kwargs): ): # these parameters are to be patched for triton v2 # consider making a map if patching more kernels - PATCH_FOR_FSDP_TRITON_V2 = ['qweight', 'qzeros'] + PATCH_FOR_FSDP_TRITON_V2 = ["qweight", "qzeros"] # patch all the QuantLinear base layers for mod in model.modules(): @@ -151,14 +157,17 @@ def model_loader(self, model_name: str, **kwargs): # so FSDP can shard them for attr_name in PATCH_FOR_FSDP_TRITON_V2: attr = getattr(mod, attr_name) - attr = torch.nn.Parameter(attr.view(torch_dtype), requires_grad=False) + attr = torch.nn.Parameter( + attr.view(torch_dtype), requires_grad=False + ) setattr(mod, attr_name, attr) - # this patches the forward to convert them back to original + # this patches the forward to convert them back to original # type (i.e. int32) before the function call into the kernels _forward = patch_forward_to_view_attributes_before_call( - mod.forward, attribute_names=PATCH_FOR_FSDP_TRITON_V2, - torch_dtype=torch.int32, # patch it back to + mod.forward, + attribute_names=PATCH_FOR_FSDP_TRITON_V2, + torch_dtype=torch.int32, # patch it back to ) mod.forward = MethodType(_forward, mod) @@ -193,11 +202,11 @@ def augmentation( ): # guarded imports # Third Party - from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear - from auto_gptq.utils.peft_utils import GPTQLoraModel, get_gptq_peft_model + from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear #pylint: disable=import-outside-toplevel,import-error + from auto_gptq.utils.peft_utils import GPTQLoraModel, get_gptq_peft_model #pylint: disable=import-outside-toplevel,import-error # Local - from .autogptq_utils import create_new_module_peft, replace_module_peft + from .autogptq_utils import create_new_module_peft, replace_module_peft #pylint: disable=import-outside-toplevel (peft_config,) = modifiable_args # unpack modifiable args diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_bnb.py b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_bnb.py index dfd5fbc8..6e71d11a 100644 --- a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_bnb.py +++ b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_bnb.py @@ -23,7 +23,7 @@ # Third Party from fms_acceleration import AccelerationPlugin -from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training +from peft import LoraConfig, get_peft_model from transformers import AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments import torch @@ -41,7 +41,7 @@ def _prepare_model_for_kbit_training( if 
gradient_checkpointing_kwargs is None: gradient_checkpointing_kwargs = {} - for name, param in model.named_parameters(): + for _, param in model.named_parameters(): # freeze base model's layers param.requires_grad = False @@ -56,22 +56,24 @@ def _prepare_model_for_kbit_training( model.enable_input_require_grads() else: - def make_inputs_require_grad(module, input, output): + def make_inputs_require_grad(_module, _input, output): output.requires_grad_(True) model.get_input_embeddings().register_forward_hook( make_inputs_require_grad ) - # To support older transformers versions, check if the model supports gradient_checkpointing_kwargs + # To support older transformers versions, + # check if the model supports gradient_checkpointing_kwargs _supports_gc_kwargs = "gradient_checkpointing_kwargs" in list( inspect.signature(model.gradient_checkpointing_enable).parameters ) if not _supports_gc_kwargs and len(gradient_checkpointing_kwargs) > 0: warnings.warn( - "gradient_checkpointing_kwargs is not supported in this version of transformers. The passed kwargs will be ignored." - " if you want to use that feature, please upgrade to the latest version of transformers.", + "gradient_checkpointing_kwargs is not supported in this version of transformers.", + "The passed kwargs will be ignored. if you want to use that feature,", + "please upgrade to the latest version of transformers.", FutureWarning, ) @@ -124,16 +126,14 @@ def model_loader(self, model_name: str, **kwargs): "If running in FSDP, this is probably because accelerate is not used. " "This will most probably result in error." ) - elif ( - world_size == 1 - and self._no_peft_model == True - ): + elif world_size == 1 and self._no_peft_model is True: warnings.warn( """Running on single device and setting plugin config `no_peft_model` as `True` - PEFT preparation will be managed by SFTTrainer and will cause a slowdown in training speed - due to extraneous dtype casting when SFTTrainer prepares the model using + PEFT preparation will be managed by SFTTrainer and + will cause a slowdown in training speed due to + extraneous dtype casting when SFTTrainer prepares the model using https://github.com/huggingface/trl/blob/e90e8d91d2265e484f229c45a5eb8982f94a2936/trl/trainer/sft_trainer.py#L210""" - ) + ) bnb_config = BitsAndBytesConfig( load_in_4bit=True, diff --git a/plugins/accelerated-peft/tests/__init__.py b/plugins/accelerated-peft/tests/__init__.py new file mode 100644 index 00000000..38a9531e --- /dev/null +++ b/plugins/accelerated-peft/tests/__init__.py @@ -0,0 +1,13 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
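For reference, the `BitsAndBytesConfig` construction that the BNB plugin hunk above truncates typically looks something like the following; the nf4 quantization type, bfloat16 compute dtype and double-quant setting shown here are illustrative assumptions, not values taken from the plugin.

```python
# Illustrative 4-bit quantization config for the BNB path; the specific
# quant type, compute dtype and double-quant flag are assumptions.
import torch
from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
```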
diff --git a/plugins/accelerated-peft/tests/test_peft_plugins.py b/plugins/accelerated-peft/tests/test_peft_plugins.py index 42404ddc..bb0621e5 100644 --- a/plugins/accelerated-peft/tests/test_peft_plugins.py +++ b/plugins/accelerated-peft/tests/test_peft_plugins.py @@ -134,7 +134,7 @@ def test_configure_bnb_plugin(): require_packages_check=False, ): # check flags and callbacks - assert (not correct_value)==framework.requires_agumentation + assert (not correct_value) == framework.requires_agumentation # attempt to activate plugin with configuration pointing to wrong path # - raise with message that no plugins can be configured diff --git a/plugins/accelerated-peft/tox.ini b/plugins/accelerated-peft/tox.ini index b79d0691..eb53996e 100644 --- a/plugins/accelerated-peft/tox.ini +++ b/plugins/accelerated-peft/tox.ini @@ -4,23 +4,27 @@ envlist = py, lint [testenv] deps = pytest>=7 - # for the tests, we need to install the deps ourselves # as the package will install the github version -e {toxinidir}/../framework -skip_install = true +# set skip package installation as it will install package pyproject.toml before deps, will throw error when AutoGPTQ needs torch +skip_install = true commands = - # install the current package pip install --no-deps {toxinidir} - pytest {posargs:tests} -[testenv:lint] +[testenv:lint] description = run linters deps = + -e {toxinidir}/../framework + pytest>=7 pylint>=2.16.2,<=3.1.0 -commands = pylint src tests +commands = + # installs package without autogptq dep to lint without CUDA, + # autogptq pylint import-errors are disabled inline + pip install --no-deps {toxinidir} + pylint src tests allowlist_externals = pylint [testenv:fmt] From 25171a048ceb433151e645b8d26464869afb5a66 Mon Sep 17 00:00:00 2001 From: achew010 <165894159+achew010@users.noreply.github.com> Date: Wed, 29 May 2024 18:10:16 +0800 Subject: [PATCH 2/8] Workaround Low-Mem-Mode Patch for GPTQ-LoRA (#26) * workaround low-mem patch * resolve conflicts and define patch function * resolve conflicts and define patch function * Apply suggestions from code review Co-authored-by: Yu Chin Fabian Lim * revert hack to avoid low memory bug in HF memory metrics calculation * reversed formatting * reverse more formatting --------- Co-authored-by: Yu Chin Fabian Lim --- .../fms_acceleration_peft/autogptq_utils.py | 65 +++++++++++++++++- .../framework_plugin_autogptq.py | 66 ++++++++++++------- scripts/benchmarks/README.md | 3 +- scripts/benchmarks/benchmark.py | 4 +- 4 files changed, 109 insertions(+), 29 deletions(-) diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py b/plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py index c4b497fc..31fc9a74 100644 --- a/plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py +++ b/plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py @@ -16,7 +16,8 @@ # https://spdx.dev/learn/handling-license-info/ # Standard -from typing import Callable, List +from typing import Any, Callable, List +import importlib # Third Party from peft import LoraConfig @@ -24,6 +25,68 @@ import torch +# This function may be moved after merging +# https://github.com/foundation-model-stack/fms-acceleration/pull/25 +def _patch_target_module( + to_patch: str, + replace_with: Any, + target_module: str = None, +): + to_patch = to_patch.split(".") + assert len(to_patch) > 1, "must have an object to patch" + + to_patch, obj_name_to_patch = to_patch[:-1], to_patch[-1] + to_patch = ".".join(to_patch) + source = 
importlib.import_module(to_patch) + original_obj = getattr(source, obj_name_to_patch) + setattr(source, obj_name_to_patch, replace_with) + + if target_module is not None: + # reload and this should get the patched object + target_module = importlib.import_module(target_module) + importlib.reload(target_module) + + # replace it + setattr(source, obj_name_to_patch, original_obj) + + +def make_sure_no_tensor_in_meta_device( + model, + use_triton: bool, + desc_act: bool, + group_size: int, + bits: int, + disable_exllama: bool, + disable_exllamav2: bool, + use_marlin: bool = False, + use_tritonv2: bool = False, +): + # Third Party + # guarded import + from auto_gptq.utils.import_utils import ( # pylint: disable=import-outside-toplevel,import-error + dynamically_import_QuantLinear, + ) + + QuantLinear = dynamically_import_QuantLinear( + use_triton, + desc_act, + group_size, + bits=bits, + disable_exllama=disable_exllama, + disable_exllamav2=disable_exllamav2, + use_marlin=use_marlin, + use_tritonv2=use_tritonv2, + ) + for _, m in model.named_modules(): + bias = getattr(m, "bias", None) + if bias: + if isinstance(m, QuantLinear) and bias.device == torch.device("meta"): + m.register_buffer( + "bias", + torch.zeros((m.outfeatures), dtype=torch.float16, device="cpu"), + ) + + def replace_module_peft(self, parent_module, child_name, new_module, old_module): # replace the lora linear diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py index 89ea2862..62e7abe3 100644 --- a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py +++ b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py @@ -27,6 +27,7 @@ from peft import LoraConfig, prepare_model_for_kbit_training from peft.tuners.lora.model import LoraModel from transformers import AutoModelForCausalLM, TrainingArguments +from transformers.modeling_utils import is_fsdp_enabled import torch import torch.distributed @@ -48,14 +49,15 @@ def __init__(self, configurations: Dict[str, Dict]): ) def model_loader(self, model_name: str, **kwargs): - # guarded imports # Third Party from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig #pylint: disable=import-outside-toplevel,import-error from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear #pylint: disable=import-outside-toplevel,import-error # Local - from .autogptq_utils import patch_forward_to_view_attributes_before_call #pylint: disable=import-outside-toplevel + from .autogptq_utils import ( # pylint: disable=import-outside-toplevel + patch_forward_to_view_attributes_before_call, + ) # Currently we allow only a quantized checkpoint to be loaded, we do not # implement the quantization process here. @@ -84,20 +86,6 @@ def model_loader(self, model_name: str, **kwargs): low_cpu_mem_usage = kwargs.get("low_cpu_mem_usage") attn_implementation = kwargs.get("attn_implementation") - if low_cpu_mem_usage: - # Note that low_cpu_mem_usage is typically set to - # transformers.modeling_utils.is_fsdp_enabled. 
- # e.g., - # https://github.com/huggingface/transformers/blob/a98c41798cf6ed99e1ff17e3792d6e06a2ff2ff3/src/transformers/modeling_utils.py#L2989-L2990 - # but not doing that now as AutoGPTQ will call make_sure_no_tensor_in_meta_device - # https://github.com/AutoGPTQ/AutoGPTQ/blob/ea829c7bbe83561c2b1de26795b6592992373ef7/auto_gptq/modeling/_base.py#L982C17-L982C51 - # which does not properly check if a QuantLayer has a bias set or not, - # https://github.com/AutoGPTQ/AutoGPTQ/blob/ea829c7bbe83561c2b1de26795b6592992373ef7/auto_gptq/modeling/_utils.py#L514 - raise ValueError( - "low_cpu_mem_usage set to True. This may raise error if model has no bias, " - "due to AutoGPTQ bug. Not supporting at the moment." - ) - # there are some kwargs that we wont be passed to AutoModel, so we need # to patch them in _old_from_config = AutoModelForCausalLM.from_config @@ -107,14 +95,40 @@ def model_loader(self, model_name: str, **kwargs): ) AutoModelForCausalLM.from_config = _from_config # patch - # NOTE: need to set the device map as below as we want to - # use AutoGPTQ for training. + # this is a HF method that checks if the low_cpu_mem mode is enabled + # via HF accelerate + if is_fsdp_enabled(): + # Local + from .autogptq_utils import ( # pylint: disable=import-outside-toplevel + _patch_target_module, + make_sure_no_tensor_in_meta_device, + ) + + # We patch `make_sure_no_tensor_in_meta_device` + # from autogptq to avoid errors on models without bias + _patch_target_module( + to_patch="auto_gptq.modeling._utils.make_sure_no_tensor_in_meta_device", + replace_with=make_sure_no_tensor_in_meta_device, + target_module="auto_gptq.modeling._base", + ) + low_cpu_mem_usage = True + + # NOTE: need to set the device map as below as we want to use AutoGPTQ for training. # device_map is for inference only - # ref: https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference - # Thus we set it as below to effectively disable it. - device_map = ( - {"": torch.cuda.current_device()} if torch.cuda.is_available() else None - ) + # https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference + # For low_cpu_mem_usage = True, we have to set the device map to load checkpoints to "cpu" + # to avoid gpu consumption before train + # This approach will divert consumption to cpu memory, + # a better approach would be to load the checkpoints to meta device + # QLoRA is currently implemented by the former approach and will encounter the same issue. 
+ # see https://github.com/huggingface/transformers/pull/25107#issuecomment-2134833262 + device_map = { + "": ( + (torch.cuda.current_device() if not low_cpu_mem_usage else "cpu") + if torch.cuda.is_available() + else None + ) + } # currently only enable triton_v2, because the triton kernels are the only ones # that have backwards @@ -204,9 +218,11 @@ def augmentation( # Third Party from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear #pylint: disable=import-outside-toplevel,import-error from auto_gptq.utils.peft_utils import GPTQLoraModel, get_gptq_peft_model #pylint: disable=import-outside-toplevel,import-error - # Local - from .autogptq_utils import create_new_module_peft, replace_module_peft #pylint: disable=import-outside-toplevel + from .autogptq_utils import ( # pylint: disable=import-outside-toplevel + create_new_module_peft, + replace_module_peft, + ) (peft_config,) = modifiable_args # unpack modifiable args diff --git a/scripts/benchmarks/README.md b/scripts/benchmarks/README.md index 115719b7..269d3ead 100644 --- a/scripts/benchmarks/README.md +++ b/scripts/benchmarks/README.md @@ -164,6 +164,7 @@ We currently compute the memory values in the report by taking the largest of su For allocated memory value ``` max([ + stage0_mem, stage0_mem + stage1_allocated_delta, stage0_mem + stage1_allocated_delta + stage2_allocated_delta, ... @@ -173,13 +174,13 @@ max([ For peak memory value ``` max([ + stage0_mem, stage0_mem + stage1_allocated_delta + stage1_peaked_delta, stage0_mem + stage1_allocated_delta + stage2_allocated_delta + stage2_peaked_delta, ... ]) ``` -Notice that we do not include `stage0_mem` alone when computing the max value. This is to avoid misleading comparisons between GPTQ-LoRA and others. GPTQ-LoRA + FSDP currently does not support low-memory mode as mentioned [here](https://github.com/foundation-model-stack/fms-acceleration/issues/18). The `stage0_mem` value of GPTQ-LoRA + FSDP will reflect a larger than expected value as it is loaded fully before the trainer is initialized and then subsequently will be sharded internally in `trainer.prepare`. This might cause some misleading comparisons when other variants are loaded in low-memory mode and have smaller `stage0_mem` memory consumption than GPTQ-LoRA + FSDP. Once low-memory mode is supported for GPTQ-LoRA, we will include `stage0_mem` back inside the max computation We compare memory values between Nvidia-SMI and Torch in this PR - [Memory Benchmarking](https://github.com/foundation-model-stack/fms-acceleration/pull/14). 
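As a concrete illustration of the aggregation described in the benchmarks README above, the sketch below recomputes both values from per-stage deltas; the function and variable names are illustrative only and are not the identifiers used in `scripts/benchmarks/benchmark.py`.

```python
# Minimal sketch of the memory aggregation described above.
# `stage0_mem` is the GPU memory measured before trainer initialization and
# `stages` is an ordered list of (alloc_delta, peaked_delta) pairs reported
# by the HF Trainer memory metrics for each subsequent stage.
def aggregate_gpu_memory(stage0_mem, stages):
    allocated_candidates = [stage0_mem]
    peak_candidates = [stage0_mem]
    running = stage0_mem
    for alloc_delta, peaked_delta in stages:
        running += alloc_delta
        allocated_candidates.append(running)
        peak_candidates.append(running + peaked_delta)
    return max(allocated_candidates), max(peak_candidates)

# e.g. init adds 0.5 GiB (peaking 0.2 GiB higher), training adds 2.0 GiB (peaking 1.5 GiB higher)
alloc, peak = aggregate_gpu_memory(1.0, [(0.5, 0.2), (2.0, 1.5)])
assert (alloc, peak) == (3.5, 5.0)
```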
diff --git a/scripts/benchmarks/benchmark.py b/scripts/benchmarks/benchmark.py index afbf61cf..64e34b0c 100644 --- a/scripts/benchmarks/benchmark.py +++ b/scripts/benchmarks/benchmark.py @@ -112,8 +112,8 @@ def extract_gpu_memory_metrics(output_metrics) -> Tuple[float]: return 0, 0 trainer_stage_order = [ - (HF_TRAINER_LOG_GPU_STAGE_BEFORE_INIT, False), - (HF_TRAINER_LOG_GPU_STAGE_INIT, False), + (HF_TRAINER_LOG_GPU_STAGE_BEFORE_INIT, True), + (HF_TRAINER_LOG_GPU_STAGE_INIT, True), (HF_TRAINER_LOG_GPU_STAGE_TRAIN, True), ] alloc_running_sum = 0 From 70cdf71626da1bf3ee7b2e27892fc24c885dba0a Mon Sep 17 00:00:00 2001 From: achew010 <165894159+achew010@users.noreply.github.com> Date: Thu, 30 May 2024 10:24:39 +0800 Subject: [PATCH 3/8] Group memory field names with prefix and minor fixes (#27) * group memory field names with prefix and minor fixes * change to drop index on index reset --- README.md | 1 - scripts/benchmarks/benchmark.py | 6 +++--- scripts/benchmarks/display_bench_results.py | 8 ++++++-- scripts/run_benchmarks.sh | 2 +- 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index c068f023..9b8eb699 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,6 @@ For example: - GPTQ-LoRA: 22-44 % token throughput increase on 1 GPU as compared to using Hugging Face BNB QLoRA - GPTQ-LoRA: Straightforward integration with multiple GPU as compared to using Hugging Face BNB QLoRA -*Huggingface BNB QLoRA numbers taken with legacy approaches, but we are aware of [this issue](https://github.com/foundation-model-stack/fms-acceleration/issues/10) and will update our benches*. *The above includes numbers using fusedOps-and-kernels and actual impl coming soon, see below*. **This package is in BETA and is under development. Expect breaking changes!** diff --git a/scripts/benchmarks/benchmark.py b/scripts/benchmarks/benchmark.py index 64e34b0c..2594089d 100644 --- a/scripts/benchmarks/benchmark.py +++ b/scripts/benchmarks/benchmark.py @@ -77,7 +77,7 @@ GPU_LOG_USED_MEM_COLUMN_NAME = "memory.used [MiB]" GPU_LOG_METRIC_SUFFIX = " MiB" GPU_TABLE = "timestamp,name,index,memory.used" -RESULT_FIELD_RESERVED_GPU_MEM = "nvidia_mem_reserved" +RESULT_FIELD_RESERVED_GPU_MEM = "mem_nvidia_mem_reserved" RESULT_FIELD_DEVICE_NAME = "gpu_device_name" HF_TRAINER_LOG_GPU_STAGE_BEFORE_INIT = "before_init_mem_gpu" @@ -86,8 +86,8 @@ KEYWORD_PEAKED_DELTA = "peaked_delta" KEYWORD_ALLOC_DELTA = "alloc_delta" HF_ARG_SKIP_MEMORY_METRIC = "--skip_memory_metrics" -RESULT_FIELD_ALLOCATED_GPU_MEM = "torch_mem_alloc_in_bytes" -RESULT_FIELD_PEAK_ALLOCATED_GPU_MEM = "peak_torch_mem_alloc_in_bytes" +RESULT_FIELD_ALLOCATED_GPU_MEM = "mem_torch_mem_alloc_in_bytes" +RESULT_FIELD_PEAK_ALLOCATED_GPU_MEM = "mem_peak_torch_mem_alloc_in_bytes" def extract_gpu_memory_metrics(output_metrics) -> Tuple[float]: diff --git a/scripts/benchmarks/display_bench_results.py b/scripts/benchmarks/display_bench_results.py index b590f26c..1de9b2a5 100644 --- a/scripts/benchmarks/display_bench_results.py +++ b/scripts/benchmarks/display_bench_results.py @@ -22,7 +22,7 @@ def main(*directories: str, output_filename: str = "results.csv", remove_columns df = df.loc[df.error_messages.isna()] except: pass - df = df.reset_index().drop("output_dir", axis=1) + df = df.reset_index(drop=True).drop("output_dir", axis=1) df.reindex(sorted(df.columns), axis=1).to_csv(output_filename, index=False) print("***************** Report Created ******************") print(f"Total lines: '{len(df)}'") @@ -55,4 +55,8 @@ def main(*directories: str, 
output_filename: str = "results.csv", remove_columns ) args = parser.parse_args() - main(args.bench_outputs, output_filename=args.result_file, remove_columns=args.remove_columns) + main( + args.bench_outputs, + output_filename=args.result_file, + remove_columns=args.remove_columns, + ) diff --git a/scripts/run_benchmarks.sh b/scripts/run_benchmarks.sh index e08125b3..8cbd8587 100644 --- a/scripts/run_benchmarks.sh +++ b/scripts/run_benchmarks.sh @@ -38,7 +38,7 @@ PIP_REQUIREMENTS_FILE=requirements.txt DRY_RUN=${DRY_RUN:-"false"} NO_DATA_PROCESSING=${NO_DATA_PROCESSING:-"false"} NO_OVERWRITE=${NO_OVERWRITE:-"false"} -MEMORY_LOGGING=${MEMORY_LOGGING:-"huggingface"} +MEMORY_LOGGING=${MEMORY_LOGGING:-"all"} # inputs NUM_GPUS_MATRIX=${1-"1 2"} From 06a1af97f55ff3a31c38912dd843c0867a6d5150 Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Thu, 30 May 2024 16:21:00 +0800 Subject: [PATCH 4/8] Initial Addition of FusedOps and Kernels Plugin With Model Patcher (#25) * initial commit Signed-off-by: Yu Chin Fabian Lim * add fast quantized plugin Signed-off-by: Yu Chin Fabian Lim * add mistral and fix plugin Signed-off-by: Yu Chin Fabian Lim * add licensing notices and instructions for adding new plugin. Signed-off-by: Yu Chin Fabian Lim * handle linting, formatting Signed-off-by: Yu Chin Fabian Lim * 2nd round of linting Signed-off-by: Yu Chin Fabian Lim * activate workflow and some more lint fixes Signed-off-by: Yu Chin Fabian Lim * add sample config Signed-off-by: Yu Chin Fabian Lim * updates to benchmark, scenarios Signed-off-by: Yu Chin Fabian Lim * fix tests Signed-off-by: Yu Chin Fabian Lim --------- Signed-off-by: Yu Chin Fabian Lim --- .github/workflows/format.yml | 1 + README.md | 2 +- plugins/framework/README.md | 50 ++ .../src/fms_acceleration/constants.py | 1 + plugins/fused-ops-and-kernels/.isort.cfg | 13 + plugins/fused-ops-and-kernels/.pylintrc | 650 +++++++++++++++ plugins/fused-ops-and-kernels/README.md | 60 ++ .../configs/fast_quantized_peft.yaml | 27 + plugins/fused-ops-and-kernels/pyproject.toml | 31 + .../src/fms_acceleration_foak/__init__.py | 16 + .../framework_plugin_fast_quantized_peft.py | 169 ++++ .../fused_ops/__init__.py | 13 + .../fused_ops/unsloth_lora/__init__.py | 22 + .../fused_ops/unsloth_lora/bnb/__init__.py | 24 + .../fused_ops/unsloth_lora/bnb/fast_lora.py | 396 ++++++++++ .../fused_ops/unsloth_lora/geglu.py | 202 +++++ .../fused_ops/unsloth_lora/gptq/__init__.py | 3 + .../fused_ops/unsloth_lora/gptq/fast_lora.py | 737 ++++++++++++++++++ .../unsloth_lora/gptq/triton/__init__.py | 3 + .../unsloth_lora/gptq/triton/kernels.py | 149 ++++ .../unsloth_lora/gptq/triton/layers.py | 170 ++++ .../unsloth_lora/gptq/triton/tuner.py | 425 ++++++++++ .../fused_ops/unsloth_lora/swiglu.py | 98 +++ .../fused_ops/unsloth_lora/utils.py | 247 ++++++ .../fms_acceleration_foak/kernels/__init__.py | 13 + .../kernels/unsloth/__init__.py | 17 + .../kernels/unsloth/cross_entropy_loss.py | 292 +++++++ .../kernels/unsloth/rms_layernorm.py | 192 +++++ .../kernels/unsloth/rope_embedding.py | 138 ++++ .../kernels/unsloth/utils.py | 29 + .../fms_acceleration_foak/models/__init__.py | 24 + .../src/fms_acceleration_foak/models/llama.py | 88 +++ .../fms_acceleration_foak/models/mistral.py | 94 +++ .../models/model_patcher.py | 470 +++++++++++ .../src/fms_acceleration_foak/models/utils.py | 164 ++++ .../fused-ops-and-kernels/tests/__init__.py | 13 + .../tests/test_foak_plugins.py | 84 ++ plugins/fused-ops-and-kernels/tox.ini | 42 + sample-configurations/CONTENTS.yaml | 8 +- 
...ft-autogptq-foak-sample-configuration.yaml | 44 ++ scripts/benchmarks/benchmark.py | 2 +- scripts/benchmarks/scenarios.yaml | 2 +- scripts/generate_sample_configurations.py | 20 +- scripts/run_benchmarks.sh | 2 + tox.ini | 1 + 45 files changed, 5239 insertions(+), 9 deletions(-) create mode 100644 plugins/fused-ops-and-kernels/.isort.cfg create mode 100644 plugins/fused-ops-and-kernels/.pylintrc create mode 100644 plugins/fused-ops-and-kernels/README.md create mode 100644 plugins/fused-ops-and-kernels/configs/fast_quantized_peft.yaml create mode 100644 plugins/fused-ops-and-kernels/pyproject.toml create mode 100644 plugins/fused-ops-and-kernels/src/fms_acceleration_foak/__init__.py create mode 100644 plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py create mode 100644 plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/__init__.py create mode 100644 plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/__init__.py create mode 100644 plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb/__init__.py create mode 100644 plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb/fast_lora.py create mode 100644 plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/geglu.py create mode 100644 plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/__init__.py create mode 100644 plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/fast_lora.py create mode 100644 plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/triton/__init__.py create mode 100644 plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/triton/kernels.py create mode 100644 plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/triton/layers.py create mode 100644 plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/triton/tuner.py create mode 100644 plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/swiglu.py create mode 100644 plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/utils.py create mode 100644 plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/__init__.py create mode 100644 plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/__init__.py create mode 100644 plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/cross_entropy_loss.py create mode 100644 plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/rms_layernorm.py create mode 100644 plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/rope_embedding.py create mode 100644 plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/utils.py create mode 100644 plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/__init__.py create mode 100644 plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py create mode 100644 plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py create mode 100644 plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/model_patcher.py create mode 100644 plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/utils.py create mode 100644 plugins/fused-ops-and-kernels/tests/__init__.py create mode 100644 plugins/fused-ops-and-kernels/tests/test_foak_plugins.py create mode 100644 
plugins/fused-ops-and-kernels/tox.ini create mode 100644 sample-configurations/accelerated-peft-autogptq-foak-sample-configuration.yaml diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml index 294a0f6d..f0bab9d6 100644 --- a/.github/workflows/format.yml +++ b/.github/workflows/format.yml @@ -28,6 +28,7 @@ jobs: plugin_name: - "framework" - "accelerated-peft" + - "fused-ops-and-kernels" steps: - uses: actions/checkout@v4 diff --git a/README.md b/README.md index 9b8eb699..a7534ed1 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,7 @@ Plugin | Description | Depends | License | Status --|--|--|--|-- [framework](./plugins/framework/README.md) | This acceleration framework for integration with huggingface trainers | | | Beta [accelerated-peft](./plugins/accelerated-peft/README.md) | For PEFT-training, e.g., 4bit QLoRA. | Huggingface
AutoGPTQ | Apache 2.0
MIT | Beta - fusedOps-and-kernels | Fused LoRA and triton kernels (e.g., fast cross-entropy, rms, rope) | -- | Apache 2.0 with exclusions. | Coming Soon +[fused-op-and-kernels](./plugins/fused-ops-and-kernels/README.md) | Fused LoRA and triton kernels (e.g., fast cross-entropy, rms, rope) | -- | Apache 2.0 with exclusions. | Coming Soon MOE-training-acceleration | [MegaBlocks](https://github.com/databricks/megablocks) inspired triton Kernels and acclerations for Mixture-of-Expert models | | Apache 2.0 | Coming Soon ## Usage with FMS HF Tuning diff --git a/plugins/framework/README.md b/plugins/framework/README.md index 2fe9cba0..2794f4fb 100644 --- a/plugins/framework/README.md +++ b/plugins/framework/README.md @@ -88,3 +88,53 @@ Each [package](#packages) in this monorepo: - When instantiating `fms_acceleration.AccelerationFramework`, it internally parses through the configuration stanzas. For plugins that are installed, it will instantiate them; for those that are not, it will simply *passthrough*. - `AccelerationFramework` will manage plugins transparently for user. User only needs to call `AccelerationFramework.model_loader` and `AccelerationFramework.augmentation`. + +## Adding New Plugins + +To add new plugins: + +1. Create an appropriately `pip`-packaged plugin in `plugins`; the package needs to be named like `fms-acceleration-` . +2. For `framework` to properly load and manage plugin, add the package `` to [constants.py](./src/fms_acceleration/constants.py): + + ```python + PLUGINS = [ + "peft", + "unsloth", + "", + ] + ``` +3. Create a sample template YAML file inside the `/configs` to demonstrate how to configure the plugin. As an example, reference the [sample config for accelerated peft](../accelerated-peft/configs/autogptq.yaml). +4. Update [generate_sample_configurations.py](../../scripts/generate_sample_configurations.py) and run `tox -e gen-configs` on the top level directory to generate the sample configurations. + + ```python + KEY_AUTO_GPTQ = "auto_gptq" + KEY_BNB_NF4 = "bnb-nf4" + PLUGIN_A = "" + + CONFIGURATIONS = { + KEY_AUTO_GPTQ: "plugins/accelerated-peft/configs/autogptq.yaml", + KEY_BNB_NF4: ( + "plugins/accelerated-peft/configs/bnb.yaml", + [("peft.quantization.bitsandbytes.quant_type", "nf4")], + ), + PLUGIN_A: ( + "plugins//configs/plugin_config.yaml", + [ + (<1st field in plugin_config.yaml>, ), + (<2nd field in plugin_config.yaml>, ), + ] + ) + } + + # Passing a tuple of configuration keys will combine the templates together + COMBINATIONS = [ + ("accelerated-peft-autogptq", (KEY_AUTO_GPTQ,)), + ("accelerated-peft-bnb-nf4", (KEY_BNB_NF4,)), + (<"combined name with your plugin">), (KEY_AUTO_GPTQ, PLUGIN_A) + (<"combined name with your plugin">), (KEY_BNB_NF4, PLUGIN_A) + ] + ``` +5. After sample configuration is generated by `tox -e gen-configs`, update [CONTENTS.yaml](../../sample-configurations/CONTENTS.yaml) with the shortname and the configuration fullpath. +6. Update [scenarios YAML](../../scripts/benchmarks/scenarios.yaml) to configure benchmark test scenarios that will be triggered when running `tox -e run-benches` on the top level directory. +7. Update the [top-level tox.ini](../../tox.ini) to install the plugin for the `run-benches`. 
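+
+   As a reference for step 1, the plugin package itself typically subclasses `AccelerationPlugin` and registers itself against the configuration stanza it owns, following the existing plugins in this repo. Below is a minimal sketch; the class name `NewPlugin`, the architecture list, and the configuration path `peft.quantization.new_plugin` are illustrative placeholders, not prescribed names.
+
+   ```python
+   # First Party
+   from fms_acceleration import AccelerationPlugin
+
+
+   class NewPlugin(AccelerationPlugin):
+       # illustrative: restrict to the architectures the plugin supports
+       restricted_model_archs = ["LlamaForCausalLM"]
+
+       def __init__(self, configurations):
+           super().__init__(configurations)
+
+
+   # register the plugin against the configuration stanza it owns
+   AccelerationPlugin.register_plugin(
+       NewPlugin,
+       configuration_and_paths=["peft.quantization.new_plugin"],
+   )
+   ```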
+ diff --git a/plugins/framework/src/fms_acceleration/constants.py b/plugins/framework/src/fms_acceleration/constants.py index 7fe2688a..9b5fa9cc 100644 --- a/plugins/framework/src/fms_acceleration/constants.py +++ b/plugins/framework/src/fms_acceleration/constants.py @@ -21,4 +21,5 @@ PLUGINS = [ "peft", + "foak" ] diff --git a/plugins/fused-ops-and-kernels/.isort.cfg b/plugins/fused-ops-and-kernels/.isort.cfg new file mode 100644 index 00000000..4aa62fac --- /dev/null +++ b/plugins/fused-ops-and-kernels/.isort.cfg @@ -0,0 +1,13 @@ +[settings] +profile=black +from_first=true +import_heading_future=Future +import_heading_stdlib=Standard +import_heading_thirdparty=Third Party +import_heading_firstparty=First Party +import_heading_localfolder=Local +known_firstparty= +known_localfolder=tuning + +# skip code imported from unsloth +skip_glob=**/unsloth*/** diff --git a/plugins/fused-ops-and-kernels/.pylintrc b/plugins/fused-ops-and-kernels/.pylintrc new file mode 100644 index 00000000..31cb902c --- /dev/null +++ b/plugins/fused-ops-and-kernels/.pylintrc @@ -0,0 +1,650 @@ +[MAIN] + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + +# Clear in-memory caches upon conclusion of linting. Useful if running pylint +# in a server-like mode. +clear-cache-post-run=no + +# Load and enable all available extensions. Use --list-extensions to see a list +# all available extensions. +#enable-all-extensions= + +# In error mode, messages with a category besides ERROR or FATAL are +# suppressed, and no reports are done by default. Error mode is compatible with +# disabling specific errors. +#errors-only= + +# Always return a 0 (non-error) status code, even if lint errors are found. +# This is primarily useful in continuous integration scripts. +#exit-zero= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. +extension-pkg-allow-list= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. (This is an alternative name to extension-pkg-allow-list +# for backward compatibility.) +extension-pkg-whitelist= + +# Return non-zero exit code if any of these messages/categories are detected, +# even if score is above --fail-under value. Syntax same as enable. Messages +# specified are enabled, while categories only check already-enabled messages. +fail-on= + +# Specify a score threshold under which the program will exit with error. +fail-under=10 + +# Interpret the stdin as a python script, whose filename needs to be passed as +# the module_or_package argument. +#from-stdin= + +# Files or directories to be skipped. They should be base names, not paths. +ignore=CVS,protobufs + +# Add files or directories matching the regular expressions patterns to the +# ignore-list. The regex matches against paths and can be in Posix or Windows +# format. Because '\\' represents the directory delimiter on Windows systems, +# it can't be used as an escape character. +# NOTE: do not lint code imported from unsloth +ignore-paths=.*fused_ops/unsloth_lora.*,.*kernels/unsloth* + +# Files or directories matching the regular expression patterns are skipped. 
+# The regex matches against base names, not paths. The default value ignores +# Emacs file locks +ignore-patterns=^\.# + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis). It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules= + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +#init-hook= + +# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the +# number of processors available to use, and will cap the count on Windows to +# avoid hangs. +jobs=1 + +# Control the amount of potential inferred values when inferring a single +# object. This can help the performance when dealing with large functions or +# complex, nested conditions. +limit-inference-results=100 + +# List of plugins (as comma separated values of python module names) to load, +# usually to register additional checkers. +load-plugins= + +# Pickle collected data for later comparisons. +persistent=yes + +# Minimum Python version to use for version dependent checks. Will default to +# the version used to run pylint. +py-version=3.9 + +# Discover python modules and packages in the file system subtree. +recursive=no + +# When enabled, pylint would attempt to guess common misconfiguration and emit +# user-friendly hints instead of false-positive error messages. +suggestion-mode=yes + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + +# In verbose mode, extra non-checker-related info will be displayed. +#verbose= + + +[BASIC] + +# Naming style matching correct argument names. +argument-naming-style=snake_case + +# Regular expression matching correct argument names. Overrides argument- +# naming-style. If left empty, argument names will be checked with the set +# naming style. +#argument-rgx= + +# Naming style matching correct attribute names. +attr-naming-style=snake_case + +# Regular expression matching correct attribute names. Overrides attr-naming- +# style. If left empty, attribute names will be checked with the set naming +# style. +#attr-rgx= + +# Bad variable names which should always be refused, separated by a comma. +bad-names=foo, + bar, + baz, + toto, + tutu, + tata + +# Bad variable names regexes, separated by a comma. If names match any regex, +# they will always be refused +bad-names-rgxs= + +# Naming style matching correct class attribute names. +class-attribute-naming-style=any + +# Regular expression matching correct class attribute names. Overrides class- +# attribute-naming-style. If left empty, class attribute names will be checked +# with the set naming style. +#class-attribute-rgx= + +# Naming style matching correct class constant names. +class-const-naming-style=UPPER_CASE + +# Regular expression matching correct class constant names. Overrides class- +# const-naming-style. If left empty, class constant names will be checked with +# the set naming style. +#class-const-rgx= + +# Naming style matching correct class names. +class-naming-style=PascalCase + +# Regular expression matching correct class names. Overrides class-naming- +# style. If left empty, class names will be checked with the set naming style. +#class-rgx= + +# Naming style matching correct constant names. 
+const-naming-style=UPPER_CASE + +# Regular expression matching correct constant names. Overrides const-naming- +# style. If left empty, constant names will be checked with the set naming +# style. +#const-rgx= + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=-1 + +# Naming style matching correct function names. +function-naming-style=snake_case + +# Regular expression matching correct function names. Overrides function- +# naming-style. If left empty, function names will be checked with the set +# naming style. +#function-rgx= + +# Good variable names which should always be accepted, separated by a comma. +good-names=i, + j, + k, + ex, + Run, + _ + +# Good variable names regexes, separated by a comma. If names match any regex, +# they will always be accepted +good-names-rgxs= + +# Include a hint for the correct naming format with invalid-name. +include-naming-hint=no + +# Naming style matching correct inline iteration names. +inlinevar-naming-style=any + +# Regular expression matching correct inline iteration names. Overrides +# inlinevar-naming-style. If left empty, inline iteration names will be checked +# with the set naming style. +#inlinevar-rgx= + +# Naming style matching correct method names. +method-naming-style=snake_case + +# Regular expression matching correct method names. Overrides method-naming- +# style. If left empty, method names will be checked with the set naming style. +#method-rgx= + +# Naming style matching correct module names. +module-naming-style=snake_case + +# Regular expression matching correct module names. Overrides module-naming- +# style. If left empty, module names will be checked with the set naming style. +#module-rgx= + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=^_ + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. +# These decorators are taken in consideration only for invalid-name. +property-classes=abc.abstractproperty + +# Regular expression matching correct type variable names. If left empty, type +# variable names will be checked with the set naming style. +#typevar-rgx= + +# Naming style matching correct variable names. +variable-naming-style=snake_case + +# Regular expression matching correct variable names. Overrides variable- +# naming-style. If left empty, variable names will be checked with the set +# naming style. +#variable-rgx= + + +[CLASSES] + +# Warn about protected attribute access inside special methods +check-protected-access-in-special-methods=no + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp, + __post_init__ + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict, + _fields, + _replace, + _source, + _make + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls + +# List of valid names for the first argument in a metaclass class method. 
+valid-metaclass-classmethod-first-arg=mcs + + +[DESIGN] + +# List of regular expressions of class ancestor names to ignore when counting +# public methods (see R0903) +exclude-too-few-public-methods= + +# List of qualified class names to ignore when counting class parents (see +# R0901) +ignored-parents= + +# Maximum number of arguments for function / method. +max-args=5 + +# Maximum number of attributes for a class (see R0902). +max-attributes=7 + +# Maximum number of boolean expressions in an if statement (see R0916). +max-bool-expr=5 + +# Maximum number of branch for function / method body. +max-branches=12 + +# Maximum number of locals for function / method body. +max-locals=15 + +# Maximum number of parents for a class (see R0901). +max-parents=7 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + +# Maximum number of return / yield for function / method body. +max-returns=6 + +# Maximum number of statements in function / method body. +max-statements=50 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=2 + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when caught. +overgeneral-exceptions=builtins.BaseException,builtins.Exception + + +[FORMAT] + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=^\s*(# )??$ + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + +# Maximum number of characters on a single line. +max-line-length=100 + +# Maximum number of lines in a module. +max-module-lines=1100 + +# Allow the body of a class to be on the same line as the declaration if body +# contains single statement. +single-line-class-stmt=no + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=no + + +[IMPORTS] + +# List of modules that can be imported at any level, not just the top level +# one. +allow-any-import-level= + +# Allow explicit reexports by alias from a package __init__. +allow-reexport-from-package=no + +# Allow wildcard imports from modules that define __all__. +allow-wildcard-with-all=no + +# Deprecated modules which should not be used, separated by a comma. +deprecated-modules= + +# Output a graph (.gv or any supported image format) of external dependencies +# to the given file (report RP0402 must not be disabled). +ext-import-graph= + +# Output a graph (.gv or any supported image format) of all (i.e. internal and +# external) dependencies to the given file (report RP0402 must not be +# disabled). +import-graph= + +# Output a graph (.gv or any supported image format) of internal dependencies +# to the given file (report RP0402 must not be disabled). +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant + +# Couples of modules and preferred modules, separated by a comma. +preferred-modules= + + +[LOGGING] + +# The type of string formatting that logging methods do. `old` means using % +# formatting, `new` is for `{}` formatting. +logging-format-style=old + +# Logging modules to check that the string format arguments are in logging +# function parameter format. 
+logging-modules=logging + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE, +# UNDEFINED. +confidence=HIGH, + CONTROL_FLOW, + INFERENCE, + INFERENCE_FAILURE, + UNDEFINED + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once). You can also use "--disable=all" to +# disable everything first and then re-enable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use "--disable=all --enable=classes +# --disable=W". +disable=raw-checker-failed, + bad-inline-option, + locally-disabled, + file-ignored, + suppressed-message, + useless-suppression, + deprecated-pragma, + # Added messages + use-symbolic-message-instead, + invalid-name, + missing-class-docstring, + missing-module-docstring, + missing-function-docstring, + consider-using-f-string, + inconsistent-return-statements, + no-member, + too-many-arguments, + too-many-locals, + too-many-branches, + too-many-statements, + cyclic-import, + too-few-public-methods, + protected-access, + fixme, + logging-format-interpolation, + logging-too-many-args, + attribute-defined-outside-init, + abstract-method, + pointless-statement, + wrong-import-order, + duplicate-code, + unbalanced-tuple-unpacking, + unused-argument + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +enable=c-extension-no-member + + +[METHOD_ARGS] + +# List of qualified names (i.e., library.method) which require a timeout +# parameter e.g. 'requests.api.get,requests.api.post' +timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=FIXME, + XXX, + TODO + +# Regular expression of note tags to take in consideration. +notes-rgx= + + +[REFACTORING] + +# Maximum number of nested blocks for function / method body +max-nested-blocks=5 + +# Complete name of functions that never returns. When checking for +# inconsistent-return-statements if a never returning function is called then +# it will be considered as an explicit return statement and no message will be +# printed. +never-returning-functions=sys.exit,argparse.parse_error + + +[REPORTS] + +# Python expression which should return a score less than or equal to 10. You +# have access to the variables 'fatal', 'error', 'warning', 'refactor', +# 'convention', and 'info' which contain the number of messages in each +# category, as well as 'statement' which is the total number of statements +# analyzed. This score is used by the global evaluation report (RP0004). +evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)) + +# Template used to display messages. 
This is a python new-style format string +# used to format the message information. See doc for all details. +msg-template= + +# Set the output format. Available formats are text, parseable, colorized, json +# and msvs (visual studio). You can also give a reporter class, e.g. +# mypackage.mymodule.MyReporterClass. +output-format=text + +# Tells whether to display a full report or only the messages. +reports=yes + +# Activate the evaluation score. +score=yes + + +[SIMILARITIES] + +# Comments are removed from the similarity computation +ignore-comments=yes + +# Docstrings are removed from the similarity computation +ignore-docstrings=yes + +# Imports are removed from the similarity computation +ignore-imports=yes + +# Signatures are removed from the similarity computation +ignore-signatures=yes + +# Minimum lines number of a similarity. +min-similarity-lines=4 + + +[SPELLING] + +# Limits count of emitted suggestions for spelling mistakes. +max-spelling-suggestions=4 + +# Spelling dictionary name. Available dictionaries: none. To make it work, +# install the 'python-enchant' package. +spelling-dict= + +# List of comma separated words that should be considered directives if they +# appear at the beginning of a comment and should not be checked. +spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy: + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains the private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to the private dictionary (see the +# --spelling-private-dict-file option) instead of raising a message. +spelling-store-unknown-words=no + + +[STRING] + +# This flag controls whether inconsistent-quotes generates a warning when the +# character used as a quote delimiter is used inconsistently within a module. +check-quote-consistency=no + +# This flag controls whether the implicit-str-concat should generate a warning +# on implicit string concatenation in sequences defined over several lines. +check-str-concat-over-line-jumps=no + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members= + +# Tells whether to warn about missing members when the owner of the attribute +# is inferred to be None. +ignore-none=yes + +# This flag controls whether pylint should warn about no-member and similar +# checks whenever an opaque object is returned when inferring. The inference +# can return multiple potential results while evaluating a Python object, but +# some branches might not be evaluated, which results in partial inference. In +# that case, it might be useful to still emit no-member and other checks for +# the rest of the inferred objects. +ignore-on-opaque-inference=yes + +# List of symbolic message names to ignore for Mixin members. +ignored-checks-for-mixins=no-member, + not-async-context-manager, + not-context-manager, + attribute-defined-outside-init + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. 
+ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace + +# Show a hint with possible names when a member name was not found. The aspect +# of finding the hint is based on edit distance. +missing-member-hint=yes + +# The minimum edit distance a name should have in order to be considered a +# similar match for a missing member name. +missing-member-hint-distance=1 + +# The total number of similar names that should be taken in consideration when +# showing a hint for a missing member. +missing-member-max-choices=1 + +# Regex pattern to define which classes are considered mixins. +mixin-class-rgx=.*[Mm]ixin + +# List of decorators that change the signature of a decorated function. +signature-mutators= + + +[VARIABLES] + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid defining new builtins when possible. +additional-builtins= + +# Tells whether unused global variables should be treated as a violation. +allow-global-unused-variables=yes + +# List of names allowed to shadow builtins +allowed-redefined-builtins= + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_, + _cb + +# A regular expression matching the name of dummy variables (i.e. expected to +# not be used). +dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ + +# Argument names that match this expression will be ignored. +ignored-argument-names=_.*|^ignored_|^unused_ + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io diff --git a/plugins/fused-ops-and-kernels/README.md b/plugins/fused-ops-and-kernels/README.md new file mode 100644 index 00000000..a1b01d94 --- /dev/null +++ b/plugins/fused-ops-and-kernels/README.md @@ -0,0 +1,60 @@ +# FMS Acceleration for Fused Operations and Kernels + +This library contains fused operations and custom kernels, to be expanded over time. Currently it contains the following: + + +1. Fused operations and kernels are extracted from [unsloth](#extracted-code-from-unsloth). + - Low-Rank Adapter Fused Operations + - Fast RoPE Triton Kernels + - Fast RMS LayerNorm Triton Kernels + - Fast Cross Entropy Triton Kernels + +## Plugins + +Plugin | Description | Depends | Loading | Augmentation | Callbacks +--|--|--|--|--|-- +[fast_quantized_peft](./src/fms_accelerate_foak/framework_plugin_fast_quantized_peft.py) | Loads fused lora, fast cross-entropy, fast rms, fast RoPE | | | ✅ + +### Code Extracted from Unsloth + + + +Notes on the extraction of code from [unsloth](https://github.com/unslothai/unsloth): +- while unsloth is released under Apache 2.0, there are [exceptions to the permissive licenses scattered in the code base](https://github.com/unslothai/unsloth/blob/ec19e61c854dcf9104386fa63fc6c4f2944d4f35/unsloth/models/llama.py#L1140-L1143). + ``` + it would require a commercial license if used to run on more than 4 GPUs, see + https://github.com/unslothai/unsloth/blob/d215fd902cf28feb8abcfde2d25281d0fbf9d28c/unsloth/models/llama.py#L1140-L1143 + ``` +- these exceptions appear around [Feb 2024 Release](https://github.com/unslothai/unsloth/commit/3e4c5a323c16bbda2c92212b790073c4e99c2a55), around the model files (namely `llama.py`, `mistral.py`, etc). + * These model files are **not extracted**. 
+- All code extracted here before the Feb 2024 Release, see table below. + +Path | Description | Extracted From | Modifications | Date +--|--|--|--|-- +[fused_ops/unsloth_lora](./src/fms_acceleration_foak/fused_ops/unsloth_lora) | QLoRA fast dequant, activation kernels | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc) | | 28 Jan 2024 +[fused_ops/unsloth_lora/bnb](./src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb) | BNB fast lora | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc) | | 28 Jan 2024 +[fused_ops/unsloth_lora/gptq](./src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq) | GPTQ fast dequant (triton_v2) | `jeromeku/main` @ [2839d39](https://github.com/jeromeku/unsloth/commit/2839d390ef3bb318904289bfb9a7751a782c4e44) | `fast_lora.py`
`triton/layers.py` | 6 Feb 2024 +[kernels/unsloth](./src/fms_acceleration_foak/kernels/unsloth) | Fast RMS, RoPE, CrossEnt kernels | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc) | `cross_entropy_loss.py` | 28 Jan 2024 + + + + +## Known Issues + +- MixedPrecision `--fp16` should be used `fast_lora`. Also consider loading the model in `torch.float16`. +- `fast_lora` has issues with FSDP with the `peft` style of FSDP wrapping. + * This is because the adapter's forward functions are bypassed in the fused ops. + * For AutoGPTQ this is addressed by distributing the adapters using DDP so they will be unsharded in time for the fused ops. + * However for QLoRA this is not yet done https://github.com/foundation-model-stack/fms-acceleration/issues/3. +- `fast_rope_embeddings` does not work with position_ids. Currently `position_ids` are ignored and could give wrong results. \ No newline at end of file diff --git a/plugins/fused-ops-and-kernels/configs/fast_quantized_peft.yaml b/plugins/fused-ops-and-kernels/configs/fast_quantized_peft.yaml new file mode 100644 index 00000000..2151beb3 --- /dev/null +++ b/plugins/fused-ops-and-kernels/configs/fast_quantized_peft.yaml @@ -0,0 +1,27 @@ +# PEFT-related acceleration +peft: + + # quantization-releated acceleration + # e.g., kernels for quantized base weights + quantization: + + fused_ops_and_kernels: + + # load unsloth optimizations for these 4bit base layer weights. + # currently only support "auto_gptq" and "bitsandbytes" + base_layer: auto_gptq + + # activate various unsloth optimizations + # NOTE: currently supports only all-or-nothing. + + # fused kernels for lora linear layers + fused_lora: True + + # fast loss triton kernels + fast_loss: True + + # fast rms norm triton kernels + fast_rsm_layernorm: True + + # fast RoPE embedding triton kernels + fast_rope_embeddings: True \ No newline at end of file diff --git a/plugins/fused-ops-and-kernels/pyproject.toml b/plugins/fused-ops-and-kernels/pyproject.toml new file mode 100644 index 00000000..2b2aef78 --- /dev/null +++ b/plugins/fused-ops-and-kernels/pyproject.toml @@ -0,0 +1,31 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "fms-acceleration-foak" +version = '0.0.1' +description = "FMS Acceleration using Fused Operations and Kernels" +authors = [ + {name = "Fabian Lim", email = "flim@sg.ibm.com"}, + {name = "Aaron Chew", email = "aaron.chew1@ibm.com"}, +] +license = {text = "Apache-2.0"} +readme = "README.md" +requires-python = "~=3.9" +keywords = ['fms-hf-tuning', 'acceleration', 'fused-ops', 'triton'] +classifiers=[ + "License :: OSI Approved :: Apache Software License", + "Development Status :: 4 - Beta", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", +] +dependencies = ['pandas'] + +[tool.hatch.build.targets.wheel] +only-include = ["src/fms_acceleration_foak"] + +[tool.hatch.build.targets.wheel.sources] +"src" = "" diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/__init__.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/__init__.py new file mode 100644 index 00000000..edf3f23d --- /dev/null +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/__init__.py @@ -0,0 +1,16 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except 
in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Local +from .framework_plugin_fast_quantized_peft import FastQuantizedPeftAccelerationPlugin diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py new file mode 100644 index 00000000..ad0a399c --- /dev/null +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py @@ -0,0 +1,169 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +from typing import Callable, Dict, Tuple + +# Third Party +from fms_acceleration import AccelerationPlugin +from peft import LoraConfig +from peft.tuners.lora.layer import LoraLayer +from transformers import TrainingArguments +from transformers.utils import logging +import torch +import torch.distributed as dist + +# want to use the transformers logger, but a bit of pain +logger = logging.get_logger(__name__) # pylint: disable=invalid-name +logger.setLevel(logging._get_default_logging_level()) +logger.addHandler(logging._default_handler) + + +def log_patch_summary( + logging_func: Callable = None, +): + if logging_func is None: + logging_func = print + + # this is a guarded import, because the model rule registration + # does not need to be loaded unless patch_model is required + # Local + from .models.model_patcher import ( # pylint: disable=import-outside-toplevel + patch_model_summary, + ) + + for line in patch_model_summary().split("\n"): + logging_func(line) + + +# consider moving this somewhere else later +def lora_adapters_switch_ddp_from_fsdp(modules, fsdp_plugin): + """ + This function installs hooks on the target adapter parameters and + reduces the accumulated gradients across devices + """ + + fsdp_plugin.ignored_modules = modules + + def _all_reduce_hook(grad): + if grad is not None: + grad = grad.contiguous() + dist.all_reduce(grad, op=dist.ReduceOp.AVG, group=None) + return grad + + for mod in modules: + # install hooks on the adapters + mod.lora_A.default.weight.register_hook(_all_reduce_hook) + mod.lora_B.default.weight.register_hook(_all_reduce_hook) + + +class FastQuantizedPeftAccelerationPlugin(AccelerationPlugin): + + # NOTE: may remove this when we have generic model rules + restricted_model_archs = [ + "MixtralForCausalLM", + "LlamaForCausalLM", + "MistralForCausalLM", + ] + + def __init__(self, configurations: Dict[str, Dict]): + super().__init__(configurations) + + self._base_layer = self._check_config_and_maybe_check_values( + 
key="peft.quantization.fused_ops_and_kernels.base_layer", + values=[ + "auto_gptq", + # "bitsandbytes" # enable later when we have BNB implemented + ], + ) + + # only support these at the moment + self._check_config_equal( + key="peft.quantization.fused_ops_and_kernels.fused_lora", value=True + ) + self._check_config_equal( + key="peft.quantization.fused_ops_and_kernels.fast_loss", value=True + ) + self._check_config_equal( + key="peft.quantization.fused_ops_and_kernels.fast_rsm_layernorm", + value=True, + ) + self._check_config_equal( + key="peft.quantization.fused_ops_and_kernels.fast_rope_embeddings", + value=True, + ) + + @property + def requires_agumentation(self): + return True + + def augmentation( + self, + model, + train_args: TrainingArguments, + modifiable_args: Tuple[LoraConfig], + ): + # NOTE: how do I check this now that the modifiable args are missing + # assert peft_config.lora_dropout == 0, \ + # "Fused Attention requires lora_dropout argument to be set to 0" + + # need to check why this is needed + assert ( + model.dtype == torch.float16 and train_args.fp16 + ), "need to run in fp16 mixed precision or load model in fp16" + + # this is a guarded import, because the model rule registration + # does not need to be loaded unless patch_model is required + # Local + from .models.model_patcher import ( # pylint: disable=import-outside-toplevel + patch_model, + ) + + model = patch_model(model, base_type=self._base_layer) + return model, modifiable_args + + def get_callbacks_and_ready_for_train( + self, model: torch.nn.Module = None, accelerator=None + ): + + # if this is moved to framework, it can be handled as the same way as + # log_initialization_message + # log the patch summary + if accelerator is not None and accelerator.is_main_process: + log_patch_summary(logging_func=logger.info) + + callbacks = [] + if ( + accelerator is not None + and getattr(accelerator.state, "fsdp_plugin", None) is not None + ): + # This function installs grad reduction hooks on adapters if + # FSDP is detected. Because of incompatibility between FSDP and + # fused modules, adapters are not sharded - instead + # accumulated gradients from adapters in each device are reduced + # in these grad reduce hooks + # This function might be removed in future if the incompatiblity + # is resolved + lora_adapters_switch_ddp_from_fsdp( + [mod for mod in model.modules() if isinstance(mod, LoraLayer)], + accelerator.state.fsdp_plugin, + ) + return callbacks + + +# register +AccelerationPlugin.register_plugin( + FastQuantizedPeftAccelerationPlugin, + configuration_and_paths=["peft.quantization.fused_ops_and_kernels"], +) diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/__init__.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/__init__.py new file mode 100644 index 00000000..b994759e --- /dev/null +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/__init__.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/__init__.py new file mode 100644 index 00000000..a35f05f9 --- /dev/null +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .swiglu import swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel +from .geglu import ( + geglu_exact_forward_kernel, + geglu_exact_backward_kernel, + geglu_approx_forward_kernel, + geglu_approx_backward_kernel, +) +from .utils import fast_dequantize, fast_gemv, QUANT_STATE, fast_linear_forward, matmul_lora \ No newline at end of file diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb/__init__.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb/__init__.py new file mode 100644 index 00000000..a5c556b4 --- /dev/null +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + + +from .fast_lora import ( + get_lora_parameters, + apply_lora_mlp_swiglu, + apply_lora_mlp_geglu_exact, + apply_lora_mlp_geglu_approx, + apply_lora_qkv, + apply_lora_o, +) diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb/fast_lora.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb/fast_lora.py new file mode 100644 index 00000000..82f78f74 --- /dev/null +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb/fast_lora.py @@ -0,0 +1,396 @@ +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from ..utils import fast_dequantize, QUANT_STATE, get_lora_parameters, matmul_lora + + +class LoRA_MLP(torch.autograd.Function): + """ + ### LoRA weights + G = G + Ag @ Bg + U = U + Au @ Bu + W = W + Aw @ Bw + + ### SwiGLU(X) + e = X @ G + f = e * sigmoid(e) + g = X @ U + h = f * g + i = h @ W + + ### Backpropagation chain rule + See our blog post for more details + + df = sigmoid(e) * (1 - f) + f + dC/dW = h.T @ dY + dC/dU = X.T @ (D @ W.T * f) + dC/dG = X.T @ (D @ W.T * df * g) + + ### Down projection LoRA weights + dC/dAw = dC/dW @ B.T + dC/dBw = A.T @ dC/dW + dC/dAw = h.T @ dY @ B.T + dC/dBw = A.T @ h.T @ dY + + ### Up projection LoRA weights + dC/dAu = X.T @ (D @ W.T * f) @ B.T + dC/dBu = A.T @ X.T @ (D @ W.T * f) + + ### Gate projection LoRA weights + dC/dAg = X.T @ (D @ W.T * df * g) @ B.T + dC/dBg = A.T @ X.T @ (D @ W.T * df * g) + + Don't forget to see our blog post for more details! + """ + @staticmethod + @torch.cuda.amp.custom_fwd + def forward(ctx, X : torch.Tensor, + gateW, gateW_quant, gateA, gateB, gateS, + upW, upW_quant, upA, upB, upS, + downW, downW_quant, downA, downB, downS, + _forward_function, _backward_function,): + dtype = X.dtype + + e = matmul_lora(X, gateW, gateW_quant, gateA, gateB, gateS) + g = matmul_lora(X, upW, upW_quant, upA, upB, upS) + h = _forward_function(e, g) + i = matmul_lora(h, downW, downW_quant, downA, downB, downS) + + ctx.custom_saved_tensors = ( + gateW, gateW_quant, gateS, + upW, upW_quant, upS, + downW, downW_quant, downS, + _backward_function, + ) + ctx.save_for_backward(gateA, gateB, upA, upB, downA, downB, + X, e, g) + return i + pass + + + @staticmethod + @torch.cuda.amp.custom_bwd + def backward(ctx, dY : torch.Tensor): + gateW, gateW_quant, gateS, upW, upW_quant, upS, downW, downW_quant, downS, \ + _backward_function = ctx.custom_saved_tensors + gateA, gateB, upA, upB, downA, downB, \ + X, e, g = ctx.saved_tensors + + gateA, gateB, upA, upB, downA, downB = \ + gateA.t(), gateB.t(), upA.t(), upB.t(), downA.t(), downB.t() + + batch, seq_len, hd = X.shape + dY = dY.view(-1, dY.shape[-1]) + X = X .view(-1, X .shape[-1]) + e = e .view(-1, e .shape[-1]) + g = g .view(-1, g .shape[-1]) + dtype = X.dtype + + DW = matmul_lora(dY, downW.t(), downW_quant, downB, downA, downS) + DW, e, g = _backward_function(DW, e, g) + h, df, de = DW, e, g + + # Down projection LoRA weights + d_downA = h.t() @ (dY @ downB.t()) + d_downB = (downA.t() @ h.t()) @ dY + d_downA *= downS + d_downB *= downS + + # Up projection LoRA weights + d_upA = X.t() @ (df @ upB.t()) + d_upB = (upA.t() @ X.t()) @ df + d_upA *= upS + d_upB *= upS + + # Gate projection LoRA weights + d_gateA = X.t() @ (de @ gateB.t()) + d_gateB = (gateA.t() @ X.t()) @ de + d_gateA *= gateS + d_gateB *= gateS + + # dX = matmul_lora(df, upW.t(), upW_quant, upB, upA, upS) + # dX += matmul_lora(de, gateW.t(), gateW_quant, gateB, gateA, gateS) + upW = fast_dequantize(upW.t(), upW_quant) + dX = torch.matmul(df, upW.t(), out = X) + del upW + dX += df @ upB.to(dtype).t() @ (upS * upA.to(dtype).t()) + + gateW = fast_dequantize(gateW.t(), gateW_quant) + dX += de @ gateW.t() + del gateW 
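+        # add the gate projection's LoRA term to dX (the dequantized base-weight term was added just above)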
+ dX += de @ gateB.to(dtype).t() @ (gateS * gateA.to(dtype).t()) + + # gateW, gateW_quant, gateA, gateB, gateS, + # upW, upW_quant, upA, upB, upS, + # downW, downW_quant, downA, downB, downS, + return dX.view(batch, seq_len, hd), \ + None, None, d_gateA.t(), d_gateB.t(), None, \ + None, None, d_upA.t(), d_upB.t(), None, \ + None, None, d_downA.t(), d_downB.t(), None, \ + None, None, # _backward and _forward + pass +pass + + +from ..swiglu import swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel +def apply_lora_mlp_swiglu(self, X): + gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj) + upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj) + downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj) + out = LoRA_MLP.apply(X, + gateW, gateW_quant, gateA, gateB, gateS, + upW, upW_quant, upA, upB, upS, + downW, downW_quant, downA, downB, downS, + swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel,) + return out +pass + + +from ..geglu import geglu_exact_forward_kernel, geglu_exact_backward_kernel +def apply_lora_mlp_geglu_exact(self, X): + gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj) + upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj) + downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj) + out = LoRA_MLP.apply(X, + gateW, gateW_quant, gateA, gateB, gateS, + upW, upW_quant, upA, upB, upS, + downW, downW_quant, downA, downB, downS, + geglu_exact_forward_kernel, geglu_exact_backward_kernel,) + return out +pass + + +from ..geglu import geglu_approx_forward_kernel, geglu_approx_backward_kernel +def apply_lora_mlp_geglu_approx(self, X): + gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj) + upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj) + downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj) + out = LoRA_MLP.apply(X, + gateW, gateW_quant, gateA, gateB, gateS, + upW, upW_quant, upA, upB, upS, + downW, downW_quant, downA, downB, downS, + geglu_approx_forward_kernel, geglu_approx_backward_kernel,) + return out +pass + + +class LoRA_QKV(torch.autograd.Function): + """ + ### LoRA weights + Wq = Wq + Aq @ Bq + Wk = Wk + Ak @ Bk + Wv = Wv + Av @ Bv + Q = X @ Wq = X @ Wq + X @ Aq @ Bq + K = X @ Wk = X @ Wk + X @ Ak @ Bk + V = X @ Wv = X @ Wv + X @ Av @ Bv + + ### Backpropagation chain rule + See our blogpost for more details. 
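+    (Here D(Wq), D(Wk), D(Wv) denote the upstream gradients dQ, dK, dV arriving at each projection output.)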
+ + dC/dWq = X.T @ D(Wq) + dC/dWk = X.T @ D(Wk) + dC/dWv = X.T @ D(Wv) + We then sum them all find dC/dX + + ### Q projection LoRA weights + dC/dAq = X.T @ D(Wq) @ B.T + dC/dBq = A.T @ X.T @ D(Wq) + + ### K projection LoRA weights + dC/dAk = X.T @ D(Wk) @ B.T + dC/dBk = A.T @ X.T @ D(Wk) + + ### V projection LoRA weights + dC/dAv = X.T @ D(Wv) @ B.T + dC/dBv = A.T @ X.T @ D(Wv) + """ + @staticmethod + @torch.cuda.amp.custom_fwd + def forward(ctx, X : torch.Tensor, + QW, QW_quant, QA, QB, QS, + KW, KW_quant, KA, KB, KS, + VW, VW_quant, VA, VB, VS,): + dtype = X.dtype + + Q = matmul_lora(X, QW, QW_quant, QA, QB, QS) + K = matmul_lora(X, KW, KW_quant, KA, KB, KS) + V = matmul_lora(X, VW, VW_quant, VA, VB, VS) + + ctx.custom_saved_tensors = ( + QW, QW_quant, QS, + KW, KW_quant, KS, + VW, VW_quant, VS, + ) + ctx.save_for_backward(X, QA, QB, KA, KB, VA, VB,) + return Q, K, V + pass + + @staticmethod + @torch.cuda.amp.custom_bwd + def backward(ctx, dQ, dK, dV): + QW, QW_quant, QS, KW, KW_quant, KS, VW, VW_quant, VS = \ + ctx.custom_saved_tensors + X, QA, QB, KA, KB, VA, VB, = ctx.saved_tensors + + QA, QB, KA, KB, VA, VB = \ + QA.t(), QB.t(), KA.t(), KB.t(), VA.t(), VB.t() + + batch, seq_len, hd = X.shape + dQ = dQ.view(-1, dQ.shape[-1]) + dK = dK.reshape(-1, dK.shape[-1]) # view doesn't work on K.T + dV = dV.view(-1, dV.shape[-1]) + X = X .view(-1, X .shape[-1]) + dtype = X.dtype + + ### Weight projection LoRA weights + # See our blogpost for more details. + + # Q Projection + d_QA = X.t() @ (dQ @ QB.t()) + d_QB = (QA.t() @ X.t()) @ dQ + d_QA *= QS + d_QB *= QS + + # K Projection + d_KA = X.t() @ (dK @ KB.t()) + d_KB = (KA.t() @ X.t()) @ dK + d_KA *= KS + d_KB *= KS + + # V Projection + d_VA = X.t() @ (dV @ VB.t()) + d_VB = (VA.t() @ X.t()) @ dV + d_VA *= VS + d_VB *= VS + + # Combine derivatives to find dX + # dQ + QW = fast_dequantize(QW.t(), QW_quant) + dX = torch.matmul(dQ, QW.t(), out = X) + del QW + dX += (dQ @ QB.to(dtype).t() @ (QS * QA.to(dtype).t())) + + # dK + KW = fast_dequantize(KW.t(), KW_quant) + dX += dK @ KW.t() + del KW + dX += dK @ KB.to(dtype).t() @ (KS * KA.to(dtype).t()) + + # dV + VW = fast_dequantize(VW.t(), VW_quant) + dX += dV @ VW.t() + del VW + dX += dV @ VB.to(dtype).t() @ (VS * VA.to(dtype).t()) + + # QW, QW_quant, QA, QB, QS, + # KW, KW_quant, KA, KB, KS, + # VW, VW_quant, VA, VB, VS, + return dX.view(batch, seq_len, hd), \ + None, None, d_QA.t(), d_QB.t(), None, \ + None, None, d_KA.t(), d_KB.t(), None, \ + None, None, d_VA.t(), d_VB.t(), None + pass +pass + + +def apply_lora_qkv(self, X): + QW, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj) + KW, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj) + VW, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj) + Q, K, V = LoRA_QKV.apply(X, + QW, QW_quant, QA, QB, QS, + KW, KW_quant, KA, KB, KS, + VW, VW_quant, VA, VB, VS, + ) + return Q, K, V +pass + + +class LoRA_W(torch.autograd.Function): + """ + ### LoRA weights + Wq = Wq + Aq @ Bq + Wk = Wk + Ak @ Bk + Wv = Wv + Av @ Bv + Q = X @ Wq = X @ Wq + X @ Aq @ Bq + K = X @ Wk = X @ Wk + X @ Ak @ Bk + V = X @ Wv = X @ Wv + X @ Av @ Bv + + ### Backpropagation chain rule + dC/dWq = X.T @ D(Wq) + dC/dWk = X.T @ D(Wk) + dC/dWv = X.T @ D(Wv) + + ### Q projection LoRA weights + dC/dAq = X.T @ D(Wq) @ B.T + dC/dBq = A.T @ X.T @ D(Wq) + + ### K projection LoRA weights + dC/dAk = X.T @ D(Wk) @ B.T + dC/dBk = A.T @ X.T @ D(Wk) + + ### V projection LoRA weights + dC/dAv = X.T @ D(Wv) @ B.T + dC/dBv = A.T @ X.T @ D(Wv) + """ + @staticmethod + @torch.cuda.amp.custom_fwd 
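+    # forward: XW = X @ W + (X @ A) @ (s * B), with the quantized base weight W handled inside matmul_lora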
+ def forward(ctx, X : torch.Tensor, + W, W_quant, A, B, S): + dtype = X.dtype + XW = matmul_lora(X, W, W_quant, A, B, S) + ctx.custom_saved_tensors = (W, W_quant, S,) + ctx.save_for_backward(A, B, X) + return XW + pass + + @staticmethod + @torch.cuda.amp.custom_bwd + def backward(ctx, dY : torch.Tensor): + W, W_quant, S = ctx.custom_saved_tensors + A, B, X = ctx.saved_tensors + + A, B = A.t(), B.t() + + batch, seq_len, hd = X.shape + dY = dY.reshape(-1, dY.shape[-1]) # Must be reshape + X = X .reshape(-1, X .shape[-1]) # Must be reshape + dtype = X.dtype + + ### Weight projection LoRA weights + # Weight projection + d_A = X.t() @ (dY @ B.t()) + d_B = (A.t() @ X.t()) @ dY + d_A *= S + d_B *= S + + # Get derivative for dX + W = fast_dequantize(W.t(), W_quant) + dX = dY @ W.t() + del W + dX += dY @ B.to(dtype).t() @ (S * A.to(dtype).t()) + + # W, W_quant, A, B, S + return dX.view(batch, seq_len, hd), \ + None, None, d_A.t(), d_B.t(), None + pass +pass + + +def apply_lora_o(self, X): + OW, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj) + O = LoRA_W.apply(X, OW, OW_quant, OA, OB, OS) + return O +pass diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/geglu.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/geglu.py new file mode 100644 index 00000000..3441c59d --- /dev/null +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/geglu.py @@ -0,0 +1,202 @@ +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
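+# Triton kernels for the GEGLU activation: exact and tanh-approximation variants,
+# each with a fused forward (h = GELU(e) * g) and matching backward pass.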
+ +import triton +import triton.language as tl +import torch + + +@triton.jit +def _exact_forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,): + block_idx = tl.program_id(0) + offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + # f = 1/2 * e * (1 + erf(1/sqrt(2) * e)) + # h = f * up + e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32) + g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32) + + f_row = 0.5 * e_row * (tl.math.erf(tl.math.rsqrt(2.0) * e_row) + 1.0) + f_row = f_row.to(g_row.dtype) # Exact copy from HF + h_row = f_row * g_row + + # Store h + tl.store(h + offsets, h_row, mask = mask) +pass + + +def geglu_exact_forward_kernel(gate, up): + batch, seq_len, hd = gate.shape + n_elements = gate.numel() + out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda") + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + _exact_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,) + return out +pass + + +@triton.jit +def _exact_backward_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,): + """ + f = 1/2 * e * (1 + erf(1/sqrt(2) * e)) + h = f * up + + df/de (with help of Wolfram :) + df/de = 1/2 * (1 + erf(1/sqrt(2) * e)) + 1/sqrt(2*pi) * e * exp(-1/2 * e^2) + + Reuse via + f = 1/2 * (1 + erf(1/sqrt(2) * e)) * e + """ + block_idx = tl.program_id(0) + offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32) + e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32) + g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32) + + # Break e_row away for re-use + # f = 1/2 * e * (1 + erf(1/sqrt(2) * e)) + f_partial_row = 0.5 * (tl.math.erf(tl.math.rsqrt(2.0) * e_row) + 1.0) + f_row = f_partial_row * e_row + + f_row = f_row.to(DW_row.dtype) + # h = f * g + h_row = f_row * g_row + # df = DW * f + df_row = DW_row * f_row + # dg = DW * g + dg_row = DW_row * g_row + + # df/de = 1/2 * (1 + erf(1/sqrt(2) * e)) + 1/sqrt(2*pi) * e * exp(-1/2 * e^2) + t = 0.3989422804014327 # 1/sqrt(2*pi) + df_de = f_partial_row + t * e_row * tl.exp(-0.5 * e_row * e_row) + + de_row = dg_row.to(tl.float32) * df_de + de_row = de_row.to(DW_row.dtype) + + # Store derivatives in buffers + tl.store(DW + offsets, h_row, mask = mask) # h = f * g + tl.store(e + offsets, df_row, mask = mask) # df = DW * f + tl.store(g + offsets, de_row, mask = mask) # de +pass + + +def geglu_exact_backward_kernel(DW, e, g): + batch_seq_len, hd = e.shape + n_elements = e.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + _exact_backward_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,) + return DW, e, g +pass + + +@triton.jit +def _approx_forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,): + block_idx = tl.program_id(0) + offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + # f = 1/2 * e * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3 ) )) + # f = 1/2 * e * (1 + tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) )) + # h = f * up + s = 0.7978845608028654 # math.sqrt(2 / math.pi) + + e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32) + g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32) + + f_row = 0.5 * e_row * ( + tl.math.tanh(s * e_row * (1.0 + 0.044715 * e_row * e_row)) \ + + 1.0 + ) + f_row = f_row.to(g_row.dtype) # Exact copy from HF + h_row = f_row * g_row + + # Store h + 
tl.store(h + offsets, h_row, mask = mask) +pass + + +def geglu_approx_forward_kernel(gate, up): + batch, seq_len, hd = gate.shape + n_elements = gate.numel() + out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda") + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + _approx_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,) + return out +pass + + +@triton.jit +def _approx_backward_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,): + """ + f = 1/2 * e * (1 + tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) )) + h = f * up + + df/de (with help from https://arxiv.org/pdf/2305.12073.pdf :)) + df/de = 1/2 * [1 + tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) )] + + 1/2 * sech^2 [ sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) ] * \ + ( sqrt(2/pi) * x * (1 + 0.044715 * x^2 * 3 ) ) + + Notice sech^2(x) = 1 - tanh^2(x) + So reuse tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) ) + + See https://www.desmos.com/calculator/nqprfoni6x + """ + block_idx = tl.program_id(0) + offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32) + e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32) + g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32) + + # See https://www.desmos.com/calculator/nqprfoni6x + s = 0.7978845608028654 # math.sqrt(2 / math.pi) + a = s * e_row # a = sqrt(2 / pi) * x + b = a * 0.044715 * e_row * e_row # b = a * 0.044715 * x^2 + T = 1.0 + tl.math.tanh(a + b) + T2 = 0.5 * T + # Q = 0.5 * -T * (T - 2.0) * (a + 3.0 * b) + Q2 = -T2 * (T - 2.0) * (a + 3.0 * b) + df_de = T2 + Q2 # 1/2 * (T + Q) + + # f = 1/2 * e * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3 ) )) + f_row = T2 * e_row + f_row = f_row.to(DW_row.dtype) + # h = f * g + h_row = f_row * g_row + # df = DW * f + df_row = DW_row * f_row + # dg = DW * g + dg_row = DW_row * g_row + + de_row = dg_row.to(tl.float32) * df_de + de_row = de_row.to(DW_row.dtype) + + # Store derivatives in buffers + tl.store(DW + offsets, h_row, mask = mask) # h = f * g + tl.store(e + offsets, df_row, mask = mask) # df = DW * f + tl.store(g + offsets, de_row, mask = mask) # de +pass + + +def geglu_approx_backward_kernel(DW, e, g): + batch_seq_len, hd = e.shape + n_elements = e.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + _approx_backward_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,) + return DW, e, g +pass diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/__init__.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/__init__.py new file mode 100644 index 00000000..b9b793a0 --- /dev/null +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/__init__.py @@ -0,0 +1,3 @@ +# taken from +# https://github.com/jeromeku/unsloth/commit/ +# 2839d390ef3bb318904289bfb9a7751a782c4e44 \ No newline at end of file diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/fast_lora.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/fast_lora.py new file mode 100644 index 00000000..3808fba7 --- /dev/null +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/fast_lora.py @@ -0,0 +1,737 @@ +# taken from +# https://github.com/jeromeku/unsloth/commit/ +# 2839d390ef3bb318904289bfb9a7751a782c4e44 + +import math +from dataclasses import 
dataclass +from logging import getLogger +from typing import Optional + +import torch +from torch.cuda.amp import custom_bwd, custom_fwd + +from .triton.kernels import dequant248 +from ..swiglu import swiglu_DWf_DW_dfg_kernel, swiglu_fg_kernel + +logger = getLogger(__name__) + + +@dataclass +class GPTQuantState: + """ + Stores params for GPTQ linear layer quantization + """ + + infeatures: int + outfeatures: int + + bits: int + group_size: int + maxq: int + qweight: torch.Tensor + qzeros: torch.Tensor + scales: torch.Tensor + g_idx: torch.Tensor + + # cuda_kernel params (not used currently) + kernel_switch_threshold: int + autogptq_cuda_available: bool = False + autogptq_cuda: bool = False + + wf: Optional[torch.Tensor] = None + use_cuda_fp16: bool = False + + bias: Optional[torch.Tensor] = None + trainable: bool = True + + +def unpack_gptqstate(qstate): + qweight, scales, qzeros, g_idx, bits = ( + qstate.qweight, + qstate.scales, + qstate.qzeros, + qstate.g_idx, + qstate.bits, + ) + return qweight, scales, qzeros, g_idx, bits + + +def extract_gptq_state(qmodule): + if hasattr(qmodule, "base_layer"): + qmodule = qmodule.base_layer + + def check_bias(qmodule): + if hasattr(qmodule, "bias") and qmodule.bias is not None: + if qmodule.bias.count_nonzero() > 0: + return qmodule.bias + return None + + return GPTQuantState( + infeatures=qmodule.infeatures, + outfeatures=qmodule.outfeatures, + bits=qmodule.bits, + group_size=qmodule.group_size, + maxq=qmodule.maxq, + qweight=qmodule.qweight.cuda(), + qzeros=qmodule.qzeros.cuda(), + scales=qmodule.scales.cuda(), + g_idx=qmodule.g_idx.cuda(), + bias=check_bias(qmodule), + wf=qmodule.wf.cuda() if hasattr(qmodule, "wf") else None, + kernel_switch_threshold=( + qmodule.kernel_switch_threshold + if hasattr(qmodule, "kernel_switch_threshold") + else None + ), + autogptq_cuda_available=( # fixed by @aaron.chew1@sg.ibm.com + qmodule.autogptq_cuda_available + if hasattr(qmodule, "autogptq_cuda_available") else False + ), + # use_cuda_fp16=qmodule.use_cuda_fp16, + ) + + +def get_lora_parameters(proj): + # For DPO or disabled adapters + base_layer = proj.base_layer if hasattr(proj, "base_layer") else proj + qstate = extract_gptq_state(base_layer) + + if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged: + return qstate, None, None, None + + active_adapter = ( + proj.active_adapters[0] + if hasattr(proj, "active_adapters") + else proj.active_adapter + ) + A = proj.lora_A[active_adapter].weight + B = proj.lora_B[active_adapter].weight + s = proj.scaling[active_adapter] + return qstate, A, B, s + + +def matmul_lora_canonicalized(X, W, A, B, s): + """ + X: rank-2 tensor (batch, seq_len) x (din) + W: rank-2 tensor (din, dout) + out: rank-2 tensor (batch, seq_len) x (dout) + din = X.shape[1] + dout = W.shape[1] + """ + + out = torch.matmul(X, W) + + A, B = A.t(), B.t() + out += (X @ A) @ (s * B) + + return out + + +def matmul_lora(X, W, A, B, s, out=None): + dtype = X.dtype + + if X.dim() == 3: + batch, seq_len, d = X.shape + X = X.view(-1, X.shape[-1]) + reshape = True + else: + reshape = False + + out = torch.matmul(X, W, out=out) + + if A is not None: + # LoRA is enabled + A, B = A.t(), B.t() + out += (X @ A.to(dtype)) @ (s * B.to(dtype)) + + return out.view(batch, seq_len, -1) if reshape else out + + +class LoRA_MLP(torch.autograd.Function): + """ + ### LoRA weights + G = G + Ag @ Bg + U = U + Au @ Bu + W = W + Aw @ Bw + + ### SwiGLU(X) + e = X @ G + f = e * sigmoid(e) + g = X @ U + h = f * g + i = h @ W + + ### Backpropagation chain 
rule + See our blog post for more details + + df = sigmoid(e) * (1 - f) + f + dC/dW = h.T @ dY + dC/dU = X.T @ (D @ W.T * f) + dC/dG = X.T @ (D @ W.T * df * g) + + ### Down projection LoRA weights + dC/dAw = dC/dW @ B.T + dC/dBw = A.T @ dC/dW + dC/dAw = h.T @ dY @ B.T + dC/dBw = A.T @ h.T @ dY + + ### Up projection LoRA weights + dC/dAu = X.T @ (D @ W.T * f) @ B.T + dC/dBu = A.T @ X.T @ (D @ W.T * f) + + ### Gate projection LoRA weights + dC/dAg = X.T @ (D @ W.T * df * g) @ B.T + dC/dBg = A.T @ X.T @ (D @ W.T * df * g) + + Don't forget to see our blog post for more details! + """ + + @staticmethod + @torch.cuda.amp.custom_fwd + def forward( + ctx, + X: torch.Tensor, + gate_qweight, + gate_scales, + gate_qzeros, + gate_g_idx, + gate_bits, + gateA, + gateB, + gateS, + up_qweight, + up_scales, + up_qzeros, + up_g_idx, + up_bits, + upA, + upB, + upS, + down_qweight, + down_scales, + down_qzeros, + down_g_idx, + down_bits, + downA, + downB, + downS, + ): + dtype = X.dtype + + # Separate dequant248 from matmul + gateW = dequant248( + gate_qweight, gate_scales, gate_qzeros, gate_g_idx, gate_bits + ) + e = matmul_lora(X, gateW, gateA, gateB, gateS) + upW = dequant248(up_qweight, up_scales, up_qzeros, up_g_idx, up_bits) + g = matmul_lora(X, upW, upA, upB, upS) + # f = torch.nn.functional.silu(e) + # h = f * g + h = swiglu_fg_kernel(e, g) + + downW = dequant248( + down_qweight, down_scales, down_qzeros, down_g_idx, down_bits + ) + i = matmul_lora(h, downW, downA, downB, downS) + + ctx.custom_saved_tensors = ( + gate_qweight, + gate_scales, + gate_qzeros, + gate_g_idx, + gate_bits, + gateS, + up_qweight, + up_scales, + up_qzeros, + up_g_idx, + up_bits, + upS, + down_qweight, + down_scales, + down_qzeros, + down_g_idx, + down_bits, + downS, + ) + ctx.save_for_backward(gateA, gateB, upA, upB, downA, downB, X, e, g) + return i + + @staticmethod + @torch.cuda.amp.custom_bwd + def backward(ctx, dY: torch.Tensor): + ( + gate_qweight, + gate_scales, + gate_qzeros, + gate_g_idx, + gate_bits, + gateS, + up_qweight, + up_scales, + up_qzeros, + up_g_idx, + up_bits, + upS, + down_qweight, + down_scales, + down_qzeros, + down_g_idx, + down_bits, + downS, + ) = ctx.custom_saved_tensors + gateA, gateB, upA, upB, downA, downB, X, e, g = ctx.saved_tensors + + gateA, gateB, upA, upB, downA, downB = ( + gateA.t(), + gateB.t(), + upA.t(), + upB.t(), + downA.t(), + downB.t(), + ) + + batch, seq_len, hd = X.shape + dY = dY.view(-1, dY.shape[-1]) + X = X.view(-1, X.shape[-1]) + e = e.view(-1, e.shape[-1]) + g = g.view(-1, g.shape[-1]) + dtype = X.dtype + + downW = dequant248( + down_qweight, down_scales, down_qzeros, down_g_idx, down_bits + ) + DW = matmul_lora(dY, downW.t(), downB, downA, downS) + # e = e.float() + # se = 1.0 / (1.0 + torch.exp(-e)) + # f = (se * e).to(dtype) + # h = f * g + # df = DW * f + # dg = DW * g + # de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype) + DW, e, g = swiglu_DWf_DW_dfg_kernel(DW, e, g) + h, df, de = DW, e, g + + # Down projection LoRA weights + d_downA = h.t() @ (dY @ downB.t()) + d_downB = (downA.t() @ h.t()) @ dY + d_downA *= downS + d_downB *= downS + + # Up projection LoRA weights + d_upA = X.t() @ (df @ upB.t()) + d_upB = (upA.t() @ X.t()) @ df + d_upA *= upS + d_upB *= upS + + # Gate projection LoRA weights + d_gateA = X.t() @ (de @ gateB.t()) + d_gateB = (gateA.t() @ X.t()) @ de + d_gateA *= gateS + d_gateB *= gateS + + # dX = matmul_lora(df, upW.t(), upW_quant, upB, upA, upS) + # dX += matmul_lora(de, gateW.t(), gateW_quant, gateB, gateA, gateS) + upW = 
dequant248(up_qweight, up_scales, up_qzeros, up_g_idx, up_bits) + dX = torch.matmul(df, upW.t()) # , out=X) + del upW + dX += df @ upB.to(dtype).t() @ (upS * upA.to(dtype).t()) + + gateW = dequant248( + gate_qweight, gate_scales, gate_qzeros, gate_g_idx, gate_bits + ) + dX += de @ gateW.t() + del gateW + dX += de @ gateB.to(dtype).t() @ (gateS * gateA.to(dtype).t()) + + # qweight, scales, qzeros, g_idx, bits + # upW, upW_quant, upA, upB, upS, + # downW, downW_quant, downA, downB, downS, + return ( + dX.view(batch, seq_len, hd), + None, # qweight + None, # scales + None, # qzeros + None, # g_idx + None, # bits + d_gateA.t(), + d_gateB.t(), + None, + None, + None, + None, + None, + None, + d_upA.t(), + d_upB.t(), + None, # dS + None, + None, + None, + None, + None, + d_downA.t(), + d_downB.t(), + None, + ) + + +def apply_lora_mlp(self, X): + gateQstate, gateA, gateB, gateS = get_lora_parameters(self.gate_proj) + upQState, upA, upB, upS = get_lora_parameters(self.up_proj) + downQState, downA, downB, downS = get_lora_parameters(self.down_proj) + out = LoRA_MLP.apply( + X, + *unpack_gptqstate(gateQstate), + gateA, + gateB, + gateS, + *unpack_gptqstate(upQState), + upA, + upB, + upS, + *unpack_gptqstate(downQState), + downA, + downB, + downS, + ) + return out + + +class LoRA_QKV(torch.autograd.Function): + """ + ### LoRA weights + Wq = Wq + Aq @ Bq + Wk = Wk + Ak @ Bk + Wv = Wv + Av @ Bv + Q = X @ Wq = X @ Wq + X @ Aq @ Bq + K = X @ Wk = X @ Wk + X @ Ak @ Bk + V = X @ Wv = X @ Wv + X @ Av @ Bv + + ### Backpropagation chain rule + See our blogpost for more details. + + dC/dWq = X.T @ D(Wq) + dC/dWk = X.T @ D(Wk) + dC/dWv = X.T @ D(Wv) + We then sum them all find dC/dX + + ### Q projection LoRA weights + dC/dAq = X.T @ D(Wq) @ B.T + dC/dBq = A.T @ X.T @ D(Wq) + + ### K projection LoRA weights + dC/dAk = X.T @ D(Wk) @ B.T + dC/dBk = A.T @ X.T @ D(Wk) + + ### V projection LoRA weights + dC/dAv = X.T @ D(Wv) @ B.T + dC/dBv = A.T @ X.T @ D(Wv) + """ + + @staticmethod + @torch.cuda.amp.custom_fwd + def forward( + ctx, + X: torch.Tensor, + Q_qweight, + Q_scales, + Q_qzeros, + Q_g_idx, + Q_bits, + QA, + QB, + QS, + K_qweight, + K_scales, + K_qzeros, + K_g_idx, + K_bits, + KA, + KB, + KS, + V_qweight, + V_scales, + V_qzeros, + V_g_idx, + V_bits, + VA, + VB, + VS, + ): + dtype = X.dtype + + QW = dequant248(Q_qweight, Q_scales, Q_qzeros, Q_g_idx, Q_bits) + KW = dequant248(K_qweight, K_scales, K_qzeros, K_g_idx, K_bits) + VW = dequant248(V_qweight, V_scales, V_qzeros, V_g_idx, V_bits) + Q = matmul_lora(X, QW, QA, QB, QS) + K = matmul_lora(X, KW, KA, KB, KS) + V = matmul_lora(X, VW, VA, VB, VS) + + ctx.custom_saved_tensors = ( + Q_qweight, + Q_scales, + Q_qzeros, + Q_g_idx, + Q_bits, + QS, + K_qweight, + K_scales, + K_qzeros, + K_g_idx, + K_bits, + KS, + V_qweight, + V_scales, + V_qzeros, + V_g_idx, + V_bits, + VS, + ) + ctx.save_for_backward( + X, + QA, + QB, + KA, + KB, + VA, + VB, + ) + return Q, K, V + + @staticmethod + @torch.cuda.amp.custom_bwd + def backward(ctx, dQ, dK, dV): + ( + Q_qweight, + Q_scales, + Q_qzeros, + Q_g_idx, + Q_bits, + QS, + K_qweight, + K_scales, + K_qzeros, + K_g_idx, + K_bits, + KS, + V_qweight, + V_scales, + V_qzeros, + V_g_idx, + V_bits, + VS, + ) = ctx.custom_saved_tensors + ( + X, + QA, + QB, + KA, + KB, + VA, + VB, + ) = ctx.saved_tensors + + QA, QB, KA, KB, VA, VB = QA.t(), QB.t(), KA.t(), KB.t(), VA.t(), VB.t() + + batch, seq_len, hd = X.shape + dQ = dQ.view(-1, dQ.shape[-1]) + dK = dK.reshape(-1, dK.shape[-1]) # view doesn't work on K.T + dV = dV.view(-1, dV.shape[-1]) + 
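+        # X is flattened to 2D the same way just below; each pair of LoRA grads
+        # (d_QA/d_QB, d_KA/d_KB, d_VA/d_VB) then reduces to plain matmuls scaled
+        # by the corresponding LoRA scaling factor.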
X = X.view(-1, X.shape[-1]) + dtype = X.dtype + + ### Weight projection LoRA weights + # See our blogpost for more details. + + # Q Projection + d_QA = X.t() @ (dQ @ QB.t()) + d_QB = (QA.t() @ X.t()) @ dQ + d_QA *= QS + d_QB *= QS + + # K Projection + d_KA = X.t() @ (dK @ KB.t()) + d_KB = (KA.t() @ X.t()) @ dK + d_KA *= KS + d_KB *= KS + + # V Projection + d_VA = X.t() @ (dV @ VB.t()) + d_VB = (VA.t() @ X.t()) @ dV + d_VA *= VS + d_VB *= VS + + # Combine derivatives to find dX + # dQ + QW = dequant248(Q_qweight, Q_scales, Q_qzeros, Q_g_idx, Q_bits) + dX = torch.matmul(dQ, QW.t()) # , out=X) + del QW + dX += dQ @ QB.to(dtype).t() @ (QS * QA.to(dtype).t()) + + # dK + KW = dequant248(K_qweight, K_scales, K_qzeros, K_g_idx, K_bits) + dX += dK @ KW.t() + del KW + dX += dK @ KB.to(dtype).t() @ (KS * KA.to(dtype).t()) + + # dV + VW = dequant248(V_qweight, V_scales, V_qzeros, V_g_idx, V_bits) + dX += dV @ VW.t() + del VW + dX += dV @ VB.to(dtype).t() @ (VS * VA.to(dtype).t()) + + # Q_qweight, Q_scales, Q_qzeros, Q_wf, Q_g_idx, Q_bits, QA, QB, QS, + # K_qweight, K_scales, K_qzeros, K_wf, K_g_idx, K_bits, KA, KB, KS, + # V_qweight, V_scales, V_qzeros, V_wf, V_g_idx, V_bits, VA, VB, VS, + return ( + dX.view(batch, seq_len, hd), + None, + None, + None, + None, + None, + d_QA.t(), + d_QB.t(), + None, # d_QS.t(), + None, + None, + None, + None, + None, + d_KA.t(), + d_KB.t(), + None, # d_KS.t(), + None, + None, + None, + None, + None, + d_VA.t(), + d_VB.t(), + None, + ) + + +def apply_lora_qkv(self, X): + Qqstate, QA, QB, QS = get_lora_parameters(self.q_proj) + Kqstate, KA, KB, KS = get_lora_parameters(self.k_proj) + Vqstate, VA, VB, VS = get_lora_parameters(self.v_proj) + Q, K, V = LoRA_QKV.apply( + X, + *unpack_gptqstate(Qqstate), + QA, + QB, + QS, + *unpack_gptqstate(Kqstate), + KA, + KB, + KS, + *unpack_gptqstate(Vqstate), + VA, + VB, + VS, + ) + return Q, K, V + + +class LoRA_W(torch.autograd.Function): + """ + ### LoRA weights + Wq = Wq + Aq @ Bq + Wk = Wk + Ak @ Bk + Wv = Wv + Av @ Bv + Q = X @ Wq = X @ Wq + X @ Aq @ Bq + K = X @ Wk = X @ Wk + X @ Ak @ Bk + V = X @ Wv = X @ Wv + X @ Av @ Bv + + ### Backpropagation chain rule + dC/dWq = X.T @ D(Wq) + dC/dWk = X.T @ D(Wk) + dC/dWv = X.T @ D(Wv) + + ### Q projection LoRA weights + dC/dAq = X.T @ D(Wq) @ B.T + dC/dBq = A.T @ X.T @ D(Wq) + + ### K projection LoRA weights + dC/dAk = X.T @ D(Wk) @ B.T + dC/dBk = A.T @ X.T @ D(Wk) + + ### V projection LoRA weights + dC/dAv = X.T @ D(Wv) @ B.T + dC/dBv = A.T @ X.T @ D(Wv) + """ + + @staticmethod + @torch.cuda.amp.custom_fwd + def forward( + ctx, + X: torch.Tensor, + O_qweight, + O_scales, + O_qzeros, + O_g_idx, + O_bits, + A, + B, + S, + ): + W = dequant248(O_qweight, O_scales, O_qzeros, O_g_idx, O_bits) + XW = matmul_lora(X, W, A, B, S) + del W + ctx.custom_saved_tensors = ( + O_qweight, + O_scales, + O_qzeros, + O_g_idx, + O_bits, + S, + ) + ctx.save_for_backward(A, B, X) + return XW + + @staticmethod + @torch.cuda.amp.custom_bwd + def backward(ctx, dY: torch.Tensor): + O_qweight, O_scales, O_qzeros, O_g_idx, O_bits, S = ctx.custom_saved_tensors + A, B, X = ctx.saved_tensors + + A, B = A.t(), B.t() + + batch, seq_len, hd = X.shape + dY = dY.reshape(-1, dY.shape[-1]) # Must be reshape + X = X.reshape(-1, X.shape[-1]) # Must be reshape + dtype = X.dtype + + ### Weight projection LoRA weights + # Weight projection + d_A = X.t() @ (dY @ B.t()) + d_B = (A.t() @ X.t()) @ dY + d_A *= S + d_B *= S + + # Get derivative for dX + W = dequant248(O_qweight, O_scales, O_qzeros, O_g_idx, O_bits) + dX = dY @ W.t() + 
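+        # The dequantized W is only needed for this single matmul; it is freed
+        # below before the low-rank LoRA term is added, keeping peak memory low.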
del W + dX += dY @ B.to(dtype).t() @ (S * A.to(dtype).t()) + + # O_qweight, O_scales, O_qzeros, O_wf, O_g_idx, O_bits, A, B, S + return ( + dX.view(batch, seq_len, hd), + None, + None, + None, + None, + None, + d_A.t(), + d_B.t(), + None, + ) + + +def apply_lora_o(self, X): + Oqstate, OA, OB, OS = get_lora_parameters(self.o_proj) + O = LoRA_W.apply(X, *unpack_gptqstate(Oqstate), OA, OB, OS) + return O diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/triton/__init__.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/triton/__init__.py new file mode 100644 index 00000000..b9b793a0 --- /dev/null +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/triton/__init__.py @@ -0,0 +1,3 @@ +# taken from +# https://github.com/jeromeku/unsloth/commit/ +# 2839d390ef3bb318904289bfb9a7751a782c4e44 \ No newline at end of file diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/triton/kernels.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/triton/kernels.py new file mode 100644 index 00000000..c252d26d --- /dev/null +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/triton/kernels.py @@ -0,0 +1,149 @@ +# taken from +# https://github.com/jeromeku/unsloth/commit/ +# 2839d390ef3bb318904289bfb9a7751a782c4e44 + +import itertools +from logging import getLogger + +import torch +import triton +import triton.language as tl + +logger = getLogger(__name__) + + +def dequant_ref(qstate): + # assert bits == 4, "Only 4-bit quantization is supported" + qweight, scales, qzeros, wf, g_idx, bits = ( + qstate.qweight, + qstate.scales, + qstate.qzeros, + qstate.wf, + qstate.g_idx, + qstate.bits, + ) + + zeros = torch.bitwise_right_shift( + torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // bits), wf.unsqueeze(0) + ).to(torch.int16 if bits == 8 else torch.int8) + zeros = torch.bitwise_and(zeros, (2**bits) - 1) + + zeros = zeros + 1 + zeros = zeros.reshape(scales.shape) + + weights = torch.bitwise_right_shift( + torch.unsqueeze(qweight, 1).expand(-1, 32 // bits, -1), wf.unsqueeze(-1) + ).to(torch.int16 if bits == 8 else torch.int8) + weights = torch.bitwise_and(weights, (2**bits) - 1) + weights = weights.reshape(weights.shape[0] * weights.shape[1], weights.shape[2]) + weights = scales[g_idx] * (weights - zeros[g_idx]) + return weights + + +def make_dequant_configs(block_sizes, num_warps): + configs = [] + for bs, ws in itertools.product(block_sizes, num_warps): + configs.append(triton.Config({"X_BLOCK": bs}, num_warps=ws)) + return configs + + +DEFAULT_DEQUANT_CONFIGS = make_dequant_configs([128, 256, 512, 1024], [4, 8]) + + +@triton.autotune(DEFAULT_DEQUANT_CONFIGS, key=["numels"]) +@triton.jit +def dequant_kernel_248( + g_idx_ptr, + scales_ptr, + qweight_ptr, + qzeros_ptr, + out_ptr, + numels, + maxq: tl.constexpr, + bits: tl.constexpr, + outfeatures: tl.constexpr, + num_groups: tl.constexpr, + X_BLOCK: tl.constexpr = 1024, +): + # Block indexing + xoffset = tl.program_id(0) * X_BLOCK + x_index = xoffset + tl.arange(0, X_BLOCK) + xmask = x_index < numels + row_idx = x_index // outfeatures + col_idx = x_index % outfeatures + + elements_per_feature: tl.constexpr = 32 // bits + + # Load parameters + g_idx = tl.load(g_idx_ptr + (row_idx), None, eviction_policy="evict_last") + qweights = tl.load( + qweight_ptr + (col_idx + (outfeatures * (row_idx // elements_per_feature))), + None, + ) + + 
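+    # Each packed int32 holds 32 // bits quantized values. wf_weights / wf_zeros
+    # below are the bit offsets of this element inside its packed word: qweight
+    # packs values along rows, qzeros packs them along columns.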
wf_weights = (row_idx % elements_per_feature) * bits + + wf_zeros = (col_idx % elements_per_feature) * bits + + tmp1 = g_idx + num_groups + tmp2 = g_idx < 0 + tl.device_assert(g_idx >= 0, "index out of bounds: 0 <= tmp0 < 0") + groups = tl.where(tmp2, tmp1, g_idx) # tmp3 are g_idx + + scales = tl.load(scales_ptr + (col_idx + (outfeatures * groups)), None).to( + tl.float32 + ) + + # Unpack weights + weights = qweights >> wf_weights # bit shift qweight + + weights = weights & maxq + + # Unpack zeros + qzero_ncols: tl.constexpr = outfeatures // elements_per_feature + qzeros = tl.load( + qzeros_ptr + ((qzero_ncols * groups) + (col_idx // elements_per_feature)), + None, + eviction_policy="evict_last", + ) + zeros = qzeros >> wf_zeros + zeros = zeros & maxq + + # Dequantize + zeros = zeros + 1 + weights = weights - zeros + weights = weights.to(tl.float32) + weights = scales * weights + + tl.store(out_ptr + (x_index), weights, mask=xmask) + + +def dequant248(qweight, scales, qzeros, g_idx, bits, maxq=None): + """Launcher for triton dequant kernel + Only valid for bits = 2, 4, 8 + + """ + + assert bits in [2, 4, 8], "Only 2, 4, 8-bit GPTQ quantization is supported" + num_groups = scales.shape[0] + outfeatures = scales.shape[1] + infeatures = g_idx.shape[0] + + out = torch.empty((infeatures, outfeatures), device="cuda", dtype=torch.float16) + numels = out.numel() + maxq = 2**bits - 1 if maxq is None else maxq + grid = lambda meta: (triton.cdiv(numels, meta["X_BLOCK"]),) + + dequant_kernel_248[grid]( + g_idx, + scales, + qweight, + qzeros, + out, + numels, + maxq=maxq, + bits=bits, + outfeatures=outfeatures, + num_groups=num_groups, + ) + return out diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/triton/layers.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/triton/layers.py new file mode 100644 index 00000000..d8ed096c --- /dev/null +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/triton/layers.py @@ -0,0 +1,170 @@ +# taken from +# https://github.com/jeromeku/unsloth/commit/ +# 2839d390ef3bb318904289bfb9a7751a782c4e44 + +import logging + +import torch +import torch.nn as nn +from auto_gptq.nn_modules.qlinear.qlinear_triton import ( + QuantLinearInferenceOnlyFunction, + quant_matmul_inference_only_248, + transpose_quant_matmul_248, +) +# fixed by aaron.chew1@sg.ibm.com +from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import ( + QuantLinearFunction, quant_matmul_248 +) + +logger = logging.getLogger(__name__) +import math + +""" +For testing only -- replaces HuggingFace default GPTQ QLinear layer (`cuda / torch` -> `triton`) +""" + + +# Adapted from https://github.com/AutoGPTQ/AutoGPTQ/blob/main/auto_gptq/nn_modules/qlinear/__init__.py +class GPTQuantLinear(nn.Linear): + def __init__(self, quant_linear_module, trainable=True): + if hasattr(quant_linear_module, "base_layer"): + quant_linear_module = quant_linear_module.base_layer + + bias = ( + True + if hasattr(quant_linear_module, "bias") + and quant_linear_module.bias.count_nonzero() > 0 + else False + ) + + super().__init__( + in_features=quant_linear_module.infeatures, + out_features=quant_linear_module.outfeatures, + bias=bias, + ) + + self.infeatures = quant_linear_module.infeatures + self.outfeatures = quant_linear_module.outfeatures + self.bits = quant_linear_module.bits + self.group_size = quant_linear_module.group_size + self.maxq = quant_linear_module.maxq + + self.weight.requires_grad = False + + 
self.weight.data = quant_linear_module.qweight + self.register_buffer("qweight", quant_linear_module.qweight) + if bias: + self.bias.data = quant_linear_module.bias + self.bias.requires_grad = False + + self.qweight.requires_grad = False + + self.register_buffer("qzeros", quant_linear_module.qzeros) + self.register_buffer("scales", quant_linear_module.scales) + self.register_buffer("g_idx", quant_linear_module.g_idx) + + if hasattr(quant_linear_module, "wf"): + self.wf = quant_linear_module.wf + if hasattr(quant_linear_module, "kernel_switch_threshold"): + self.kernel_switch_threshold = quant_linear_module.kernel_switch_threshold + if hasattr(quant_linear_module, "autogptq_cuda_available"): + self.autogptq_cuda_available = quant_linear_module.autogptq_cuda_available + + self.trainable = trainable + self.QUANT_TYPE = "triton" + + def forward(self, x): + out_shape = x.shape[:-1] + (self.outfeatures,) + quant_linear_fn = ( + QuantLinearFunction if self.trainable else QuantLinearInferenceOnlyFunction + ) + out = quant_linear_fn.apply( + x.reshape(-1, x.shape[-1]), + self.qweight, + self.scales, + self.qzeros, + self.g_idx, + self.bits, + self.maxq, + ) + out = out.half().reshape(out_shape) + out = out + self.bias if self.bias is not None else out + + return out + + @classmethod + def warmup(cls, model, transpose=True, seqlen=2048): + """ + Pre-tunes the quantized kernel + """ + from tqdm import tqdm + + assert cls.QUANT_TYPE == "triton" + + kn_values = {} + + for _, m in model.named_modules(): + if not isinstance(m, cls): + continue + + k = m.infeatures + n = m.outfeatures + + if (k, n) not in kn_values: + kn_values[(k, n)] = ( + m.qweight, + m.scales, + m.qzeros, + m.g_idx, + m.bits, + m.maxq, + ) + + logger.info(f"Found {len(kn_values)} unique KN Linear values.") + logger.info("Warming up autotune cache ...") + with torch.no_grad(): + for m in tqdm(range(0, math.ceil(math.log2(seqlen)) + 1)): + m = 2**m + for (k, n), ( + qweight, + scales, + qzeros, + g_idx, + bits, + maxq, + ) in kn_values.items(): + if transpose: + a = torch.randn(m, k, dtype=torch.float16, device=model.device) + quant_matmul_248(a, qweight, scales, qzeros, g_idx, bits, maxq) + a = torch.randn(m, n, dtype=torch.float16, device=model.device) + transpose_quant_matmul_248( + a, qweight, scales, qzeros, g_idx, bits, maxq + ) + else: + a = torch.randn(m, k, dtype=torch.float16, device=model.device) + quant_matmul_inference_only_248( + a, qweight, scales, qzeros, g_idx, bits, maxq + ) + del kn_values + + @classmethod + def inject_to_model(cls, model, target_module_type, **kwargs): + count = 0 + for name, m in model.named_modules(): + if not isinstance(m, target_module_type): + continue + new_m = cls(m, **kwargs) + if "." in name: + parent_name = name.rsplit(".", 1)[0] + child_name = name[len(parent_name) + 1 :] + parent = model.get_submodule(parent_name) + else: + parent_name = "" + parent = model + child_name = name + + setattr(parent, child_name, new_m) + count += 1 + logger.warning_once( + f"Injected {count} triton qlinear layers in place of {target_module_type} layers." 
+ ) diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/triton/tuner.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/triton/tuner.py new file mode 100644 index 00000000..9c68bd61 --- /dev/null +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/triton/tuner.py @@ -0,0 +1,425 @@ +# taken from +# https://github.com/jeromeku/unsloth/commit/ +# 2839d390ef3bb318904289bfb9a7751a782c4e44 + +import builtins +import heapq +import math +import time +from typing import Dict + +import triton + +# code based on https://github.com/fpgaminer/GPTQ-triton +""" +Mostly the same as the autotuner in Triton, but with a few changes like using 40 runs instead of 100. +""" + + +def matmul248_kernel_config_pruner(configs, nargs): + """ + The main purpose of this function is to shrink BLOCK_SIZE_* when the corresponding dimension is smaller. + """ + m = max(2 ** int(math.ceil(math.log2(nargs["M"]))), 16) + n = max(2 ** int(math.ceil(math.log2(nargs["N"]))), 16) + k = max(2 ** int(math.ceil(math.log2(nargs["K"]))), 16) + + used = set() + for config in configs: + block_size_m = min(m, config.kwargs["BLOCK_SIZE_M"]) + block_size_n = min(n, config.kwargs["BLOCK_SIZE_N"]) + block_size_k = min(k, config.kwargs["BLOCK_SIZE_K"]) + group_size_m = config.kwargs["GROUP_SIZE_M"] + + if ( + block_size_m, + block_size_n, + block_size_k, + group_size_m, + config.num_stages, + config.num_warps, + ) in used: + continue + + used.add( + ( + block_size_m, + block_size_n, + block_size_k, + group_size_m, + config.num_stages, + config.num_warps, + ) + ) + yield triton.Config( + { + "BLOCK_SIZE_M": block_size_m, + "BLOCK_SIZE_N": block_size_n, + "BLOCK_SIZE_K": block_size_k, + "GROUP_SIZE_M": group_size_m, + }, + num_stages=config.num_stages, + num_warps=config.num_warps, + ) + + +CUSTOM_MATMUL_AUTOTUNE_CONFIGS = dict( + configs=[ + triton.Config( + { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), # 3090 + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), # 3090 + triton.Config( + { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + }, + num_stages=2, + num_warps=4, + ), # 3090 + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), # 3090 + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), # 3090 + ], + key=["M", "N", "K"], + nearest_power_of_two=True, + prune_configs_by={ + 
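+        # Only early pruning is used here: the pruner shrinks BLOCK_SIZE_* when
+        # M/N/K are small. perf_model / top_k based pruning stays disabled.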
"early_config_prune": matmul248_kernel_config_pruner, + "perf_model": None, + "top_k": None, + }, + warmup=25, + rep=40, +) + +CUSTOM_MATMUL_TRANSPOSE_AUTOTUNE_CONFIGS = dict( + configs=[ + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + }, + num_stages=2, + num_warps=8, + ), + ], + key=["M", "N", "K"], + nearest_power_of_two=True, + warmup=25, + rep=40, +) + + +class CustomizedTritonAutoTuner(triton.KernelInterface): + def __init__( + self, + fn, + arg_names, + configs, + key, + reset_to_zero, + prune_configs_by: Dict = None, + nearest_power_of_two: bool = False, + warmup=25, + rep=40, + ): + if not configs: + self.configs = [triton.Config({}, num_warps=4, num_stages=2)] + else: + self.configs = configs + self.key_idx = [arg_names.index(k) for k in key] + self.nearest_power_of_two = nearest_power_of_two + self.cache = {} + # hook to reset all required tensor to zeros before relaunching a kernel + self.hook = lambda args: 0 + if reset_to_zero is not None: + self.reset_idx = [arg_names.index(k) for k in reset_to_zero] + + def _hook(args): + for i in self.reset_idx: + args[i].zero_() + + self.hook = _hook + self.arg_names = arg_names + # prune configs + if prune_configs_by: + perf_model, top_k = ( + prune_configs_by["perf_model"], + prune_configs_by["top_k"], + ) + if "early_config_prune" in prune_configs_by: + early_config_prune = prune_configs_by["early_config_prune"] + else: + perf_model, top_k, early_config_prune = None, None, None + self.perf_model, self.configs_top_k = perf_model, top_k + self.early_config_prune = early_config_prune + self.fn = fn + self.warmup = warmup + self.rep = rep + + def _bench(self, *args, config, **meta): + # check for conflicts, i.e. meta-parameters both provided + # as kwargs and by the autotuner + conflicts = meta.keys() & config.kwargs.keys() + if conflicts: + raise ValueError( + f"Conflicting meta-parameters: {', '.join(conflicts)}." + " Make sure that you don't re-define auto-tuned symbols." 
+ ) + # augment meta-parameters with tunable ones + current = dict(meta, **config.kwargs) + + def kernel_call(): + if config.pre_hook: + config.pre_hook(self.nargs) + self.hook(args) + self.fn.run( + *args, + num_warps=config.num_warps, + num_stages=config.num_stages, + **current, + ) + + try: + # In testings using only 40 reps seems to be close enough and it appears to be what PyTorch uses + # PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default + return triton.testing.do_bench( + kernel_call, quantiles=(0.5, 0.2, 0.8), rep=self.rep, warmup=self.warmup + ) + except triton.OutOfResources: + return (float("inf"), float("inf"), float("inf")) + + def run(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + if len(self.configs) > 1: + key = tuple(args[i] for i in self.key_idx) + + # This reduces the amount of autotuning by rounding the keys to the nearest power of two + # In my testing this gives decent results, and greatly reduces the amount of tuning required + if self.nearest_power_of_two: + key = tuple([2 ** int(math.log2(x) + 0.5) for x in key]) + + if key not in self.cache: + # prune configs + pruned_configs = self.prune_configs(kwargs) + bench_start = time.time() + timings = { + config: self._bench(*args, config=config, **kwargs) + for config in pruned_configs + } + bench_end = time.time() + self.bench_time = bench_end - bench_start + self.cache[key] = builtins.min(timings, key=timings.get) + self.hook(args) + self.configs_timings = timings + config = self.cache[key] + else: + config = self.configs[0] + self.best_config = config + if config.pre_hook is not None: + config.pre_hook(self.nargs) + return self.fn.run( + *args, + num_warps=config.num_warps, + num_stages=config.num_stages, + **kwargs, + **config.kwargs, + ) + + def prune_configs(self, kwargs): + pruned_configs = self.configs + if self.early_config_prune: + pruned_configs = self.early_config_prune(self.configs, self.nargs) + if self.perf_model: + top_k = self.configs_top_k + if isinstance(top_k, float) and top_k <= 1.0: + top_k = int(len(self.configs) * top_k) + if len(pruned_configs) > top_k: + est_timing = { + config: self.perf_model( + **self.nargs, + **kwargs, + **config.kwargs, + num_stages=config.num_stages, + num_warps=config.num_warps, + ) + for config in pruned_configs + } + pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[ + :top_k + ] + return pruned_configs + + def warmup(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + for config in self.prune_configs(kwargs): + self.fn.warmup( + *args, + num_warps=config.num_warps, + num_stages=config.num_stages, + **kwargs, + **config.kwargs, + ) + self.nargs = None + + +def custom_autotune( + configs, + key, + prune_configs_by=None, + reset_to_zero=None, + nearest_power_of_two=False, + warmup=25, + rep=40, +): + def decorator(fn): + return CustomizedTritonAutoTuner( + fn, + fn.arg_names, + configs, + key, + reset_to_zero, + prune_configs_by, + nearest_power_of_two, + warmup=warmup, + rep=rep, + ) + + return decorator diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/swiglu.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/swiglu.py new file mode 100644 index 00000000..fca96782 --- /dev/null +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/swiglu.py @@ -0,0 +1,98 @@ +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import triton +import triton.language as tl +import torch + + +@triton.jit +def _fg_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,): + block_idx = tl.program_id(0) + offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32) + g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32) + + # f = e * sigmoid(e) + f_row = e_row * tl.sigmoid(e_row) # e_row / (1 + tl.exp(-e_row)) + f_row = f_row.to(g_row.dtype) # Exact copy from HF + # h = f * g + h_row = f_row * g_row + + # Store h + tl.store(h + offsets, h_row, mask = mask) +pass + + +def swiglu_fg_kernel(e, g): + batch, seq_len, hd = e.shape + n_elements = e.numel() + h = torch.empty((batch, seq_len, hd), dtype = e.dtype, device = "cuda") + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + _fg_kernel[grid](e, g, h, n_elements, BLOCK_SIZE = 1024,) + return h +pass + + +@triton.jit +def _DWf_DW_dfg_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,): + """ + e = e.float() + se = 1.0 / (1.0 + torch.exp(-e)) + f = (se * e).to(dtype) + h = f * g + df = DW * f + dg = DW * g + de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype) + """ + block_idx = tl.program_id(0) + offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32) + e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32) + g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32) + + # e = e.float() + # se = 1.0 / (1.0 + torch.exp(-e)) + se_row = tl.sigmoid(e_row) # 1.0 / (1.0 + tl.exp(-e_row)) + # f = (se * e).to(dtype) + f_row = se_row * e_row + f_row = f_row.to(DW_row.dtype) + # h = f * g + h_row = f_row * g_row + # df = DW * f + df_row = DW_row * f_row + # dg = DW * g + dg_row = DW_row * g_row + # de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype) + de_row = dg_row.to(tl.float32) * se_row * (1.0 + e_row * (1.0 - se_row)) + de_row = de_row.to(DW_row.dtype) + + # Store derivatives in buffers + tl.store(DW + offsets, h_row, mask = mask) # h = f * g + tl.store(e + offsets, df_row, mask = mask) # df = DW * f + tl.store(g + offsets, de_row, mask = mask) # de +pass + + +def swiglu_DWf_DW_dfg_kernel(DW, e, g): + batch_seq_len, hd = e.shape + n_elements = e.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + _DWf_DW_dfg_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,) + return DW, e, g +pass diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/utils.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/utils.py new file mode 100644 index 00000000..6ea90780 --- /dev/null +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/utils.py @@ -0,0 +1,247 @@ +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import triton +MAX_FUSED_SIZE = 65536 +next_power_of_2 = triton.next_power_of_2 + +def calculate_settings(n): + BLOCK_SIZE = next_power_of_2(n) + if BLOCK_SIZE > MAX_FUSED_SIZE: + raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\ + f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.") + num_warps = 4 + if BLOCK_SIZE >= 32768: num_warps = 32 + elif BLOCK_SIZE >= 8192: num_warps = 16 + elif BLOCK_SIZE >= 2048: num_warps = 8 + return BLOCK_SIZE, num_warps +pass + +# import guard added by flim@sg.ibm.com +from transformers.utils.import_utils import _bitsandbytes_available +if _bitsandbytes_available: + import bitsandbytes as bnb + get_ptr = bnb.functional.get_ptr + import ctypes + import torch + cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32 + cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4 + cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4 + cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemm_4bit_inference_naive_fp16 + cgemm_4bit_inference_naive_bf16 = bnb.functional.lib.cgemm_4bit_inference_naive_bf16 + + +def QUANT_STATE(W): + return getattr(W, "quant_state", None) +pass + + +def get_lora_parameters(proj): + # For DPO or disabled adapters + base_layer = (proj.base_layer if hasattr(proj, "base_layer") else proj) + W = base_layer.weight + + if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged: + return W, QUANT_STATE(W), None, None, None + pass + + active_adapter = proj.active_adapters[0] if \ + hasattr(proj, "active_adapters") else proj.active_adapter + A = proj.lora_A [active_adapter].weight + B = proj.lora_B [active_adapter].weight + s = proj.scaling[active_adapter] + return W, QUANT_STATE(W), A, B, s +pass + + +def fast_dequantize(W, quant_state = None, out = None): + if quant_state is None: return W + if type(quant_state) is not list: + # New quant_state as a class + # https://github.com/TimDettmers/bitsandbytes/pull/763/files + absmax = quant_state.absmax + shape = quant_state.shape + dtype = quant_state.dtype + blocksize = quant_state.blocksize + offset = quant_state.offset + state2 = quant_state.state2 + absmax2 = state2.absmax + code2 = state2.code + blocksize2 = state2.blocksize + else: + # Old quant_state as a list of lists + absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state + offset, state2 = compressed_stats + absmax2, code2, blocksize2, _, _, _, _ = state2 + pass + + # Create weight matrix + if out is None: + out = torch.empty(shape, dtype = dtype, device = "cuda") + else: + assert(out.shape == shape) + assert(out.dtype == dtype) + + # NF4 dequantization of statistics + n_elements_absmax = absmax.numel() + out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda") + + # Do dequantization + ptr_out_absmax = get_ptr(out_absmax) + cdequantize_blockwise_fp32( + get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax, + 
ctypes.c_int(blocksize2), ctypes.c_int(n_elements_absmax) + ) + out_absmax += offset + + fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \ + cdequantize_blockwise_bf16_nf4 + fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out), + ctypes.c_int(blocksize), ctypes.c_int(out.numel())) + + # Careful returning transposed data + is_transposed = (True if W.shape[0] == 1 else False) + return out.t() if is_transposed else out +pass + + +def fast_gemv(X, W, quant_state, out = None): + if quant_state is None: return torch.matmul(X, W, out = out) + # For fast X @ W where seq_len == 1 + # From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469 + _, q_len, hd = X.shape + # assert(q_len == 1) + + if type(quant_state) is not list: + # https://github.com/TimDettmers/bitsandbytes/pull/763/files + absmax = quant_state.absmax + shape = quant_state.shape + dtype = quant_state.dtype + blocksize = quant_state.blocksize + stats = quant_state.code + offset = quant_state.offset + state2 = quant_state.state2 + absmax2 = state2.absmax + code2 = state2.code + blocksize2 = state2.blocksize + else: + absmax, shape, dtype, blocksize, compressed_stats, quant_type, stats = quant_state + offset, state2 = compressed_stats + absmax2, code2, blocksize2, _, _, _, _ = state2 + pass + # assert(dtype == X.dtype) + bout = shape[0] + + if out is None: + out = torch.empty((1, 1, bout,), dtype = dtype, device = "cuda") + # else: + # assert(out.shape == (1, 1, bout,)) + # pass + + n = 1 + m = shape[0] + k = shape[1] + lda = shape[0] + ldc = shape[0] + ldb = (hd+1)//2 + m = ctypes.c_int32(m) + n = ctypes.c_int32(n) + k = ctypes.c_int32(k) + lda = ctypes.c_int32(lda) + ldb = ctypes.c_int32(ldb) + ldc = ctypes.c_int32(ldc) + + df = torch.empty(absmax.shape, dtype = torch.float32, device = "cuda") + cdequantize_blockwise_fp32( + get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df), + ctypes.c_int(blocksize2), ctypes.c_int(df.numel()), + ) + df += offset + absmax = df + + fx = cgemm_4bit_inference_naive_fp16 if dtype == torch.float16 else \ + cgemm_4bit_inference_naive_bf16 + + blocksize = ctypes.c_int32(blocksize) + fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out), + lda, ldb, ldc, blocksize) + + return out +pass + + +def fast_linear_forward(proj, X, temp_lora = None, out = None): + + W, W_quant, lora_A, lora_B, lora_S = get_lora_parameters(proj) + bsz, q_len, in_dim = X.shape + if q_len != 1: return matmul_lora(X, W, W_quant, lora_A, lora_B, lora_S) + + if W_quant is None: + out = torch.matmul(X, W.t(), out = out) + elif bsz == 1 and q_len == 1: + out = fast_gemv(X, W, W_quant, out = out) + else: + W = fast_dequantize(W.t(), W_quant) + out = torch.matmul(X, W, out = out) + pass + + # Add in LoRA weights + if lora_A is not None: + out_dim = out.shape[2] + dtype = X.dtype + + if not hasattr(lora_A, "_fast_lora"): + lora_A._fast_lora = lora_A.to(dtype) + lora_B._fast_lora = lora_B.to(dtype) + pass + + if bsz == 1: + out = out.view(out_dim) + temp_lora = torch.mv(lora_A._fast_lora, X.ravel(), out = temp_lora) + out.addmv_(lora_B._fast_lora, temp_lora, alpha = lora_S) + else: + out = out.view(bsz, out_dim) + temp_lora = torch.mm(X.view(bsz, in_dim), lora_A._fast_lora.t(), out = temp_lora) + out.addmm_(temp_lora, lora_B._fast_lora.t(), alpha = lora_S) + pass + out = out.view(bsz, 1, out_dim) + pass + + return out +pass + + +def matmul_lora(X, W, W_quant, A, B, s, out = None): + dtype = X.dtype + W = fast_dequantize(W.t(), W_quant) + + if 
X.dim() == 3: + batch, seq_len, d = X.shape + X = X.view(-1, X.shape[-1]) + reshape = True + else: + reshape = False + pass + + out = torch.matmul(X, W, out = out) + if W_quant is not None: del W + + if A is not None: + # LoRA is enabled + A, B = A.t(), B.t() + out += (X @ A.to(dtype)) @ (s * B.to(dtype)) + pass + + return out.view(batch, seq_len, -1) if reshape else out +pass diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/__init__.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/__init__.py new file mode 100644 index 00000000..b994759e --- /dev/null +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/__init__.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/__init__.py new file mode 100644 index 00000000..0c5c2706 --- /dev/null +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .cross_entropy_loss import fast_cross_entropy_loss +from .rms_layernorm import fast_rms_layernorm +from .rope_embedding import fast_rope_embedding \ No newline at end of file diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/cross_entropy_loss.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/cross_entropy_loss.py new file mode 100644 index 00000000..ebf6f3d0 --- /dev/null +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/cross_entropy_loss.py @@ -0,0 +1,292 @@ +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
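The kernels in this cross_entropy_loss.py compute per-token cross entropy directly from the logits with a numerically stable logsumexp, treating label -100 as padding. A compact eager-PyTorch sketch of the same computation is given below for orientation; the function name is illustrative only, and it materialises a single logsumexp per row rather than chunking large vocabularies the way the Triton path does.

import torch

def cross_entropy_reference(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # logits: (batch * seq_len, vocab_size), labels: (batch * seq_len,) with -100 for padding
    logsumexp = torch.logsumexp(logits.float(), dim=-1)
    picked = logits.float().gather(1, labels.clamp(min=0).unsqueeze(1)).squeeze(1)
    losses = logsumexp - picked                       # CE_i = logsumexp - x_label
    losses = losses.masked_fill(labels == -100, 0.0)  # ignore padding positions
    return losses.sum() / (labels != -100).sum()

The Triton forward kernels also save the row-wise logsumexp so that the backward kernel can rebuild softmax(x) = exp(x - logsumexp) and write the gradient in place over the logits buffer.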
+ +import triton +import triton.language as tl +import torch +from .utils import calculate_settings, MAX_FUSED_SIZE + + +@triton.jit +def _cross_entropy_forward( + logits_ptr, logits_row_stride, + loss_ptr, + logsumexp_ptr, + labels_ptr, + VOCAB_SIZE : tl.constexpr, + BLOCK_SIZE : tl.constexpr, +): + """ + Cross Entropy Loss = 1/n sum [ -yi log(Pi) ] + Pi = exp(xi) / sum(exp(xi)) + CE_i = -y log(p) = -y log[ exp(x) / sum(exp(x)) ] + = -y [ x - log[sum(exp(x))] ] + = y * (log[sum(exp(x))] - x) + If y == 0: CE_i = 0 + If y == 1: CE_i = logsumexp - x + + logsumexp is also stable + Take y = log[sum(exp(x))] + exp(y) = sum(exp(x)) + exp(y) = sum(exp(x - c)*exp(c)) Since e^(x-c)*e^c = e^x + exp(y) = exp(c)*sum(exp(x - c)) + y = log(exp(c)*sum(exp(x - c))) + y = c + log[sum(exp(x - c))] + This means we can set c = max(x) to make sure + exp(x - c) always is exp(x - max(x)). + This ensures exp(x - max(x))'s maximum is 1 as exp(0) = 1. + """ + row_idx = tl.program_id(0) + logits_ptr += row_idx * logits_row_stride.to(tl.int64) + loss_ptr += row_idx + logsumexp_ptr += row_idx + labels_ptr += row_idx + + col_offsets = tl.arange(0, BLOCK_SIZE) + mask = col_offsets < VOCAB_SIZE + + label_idx = tl.load(labels_ptr).to(tl.int32) + logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32) + c = tl.max(logits, 0) + logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0)) + + if label_idx != -100: + x = tl.load(logits_ptr + label_idx).to(tl.float32) + loss = logsumexp - x + else: + loss = 0.0 + tl.store(logsumexp_ptr, logsumexp) + tl.store(loss_ptr, loss) +pass + + +@triton.jit +def _chunked_cross_entropy_forward( + logits_ptr, logits_row_stride, + loss_ptr, + logsumexp_ptr, + labels_ptr, + VOCAB_SIZE : tl.constexpr, + N_CHUNKS : tl.constexpr, + BLOCK_SIZE : tl.constexpr, +): + """ + 256K vocab divided in 4 chunks + + |-65536-| |-65536-| |-65536-| |-65536-| + |-------| |-------| |-------| |-------| + |-------| |-------| |-------| |-------| + + If y == 0: CE_i = 0 + If y == 1: CE_i = logsumexp - x + + Notice we can do logsumexp for each chunk and then + logsumexp[chunk_sum(logsumexp)] == logsumexp + + chunk_sum = log[chunk_sum(logsumexp)] + = log[exp(logsumexp(a)) + ... + exp(logsumexp(z))] + = log[exp(log[sum(exp(a))]) + ... + exp(log[sum(exp(z))])] + = log[sum(exp(a)) + ... + sum(exp(z))] + = logsumexp(x) + + This means we can perform a logsumexp for each chunk, then do a + final logsumexp reduction! 
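+
+    Concretely, with two chunks a = [x1, x2] and b = [x3, x4]:
+        logsumexp([logsumexp(a), logsumexp(b)])
+            = log( exp(log(e^x1 + e^x2)) + exp(log(e^x3 + e^x4)) )
+            = log( e^x1 + e^x2 + e^x3 + e^x4 )
+            = logsumexp([x1, x2, x3, x4])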
+ + Ie do: logsumexp(chunked_logsumexp) - x + """ + row_idx = tl.program_id(0) + chunk_idx = tl.program_id(1) + logits_ptr += row_idx * logits_row_stride.to(tl.int64) + loss_ptr += row_idx + logsumexp_ptr += row_idx * N_CHUNKS + chunk_idx + labels_ptr += row_idx + + col_offsets = chunk_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = col_offsets < VOCAB_SIZE + + label_idx = tl.load(labels_ptr).to(tl.int32) + logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32) + c = tl.max(logits, 0) + logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0)) + + if chunk_idx == 0: + # logsumexp(chunked_logsumexp) - x + # Do the -x separately + if label_idx != -100: + x = tl.load(logits_ptr + label_idx).to(tl.float32) + loss = -1.0 * x + else: + loss = 0.0 + tl.store(loss_ptr, loss) + pass + tl.store(logsumexp_ptr, logsumexp) +pass + + +@triton.jit +def _cross_entropy_backward( + logits_ptr, logits_row_stride, + dloss_ptr, dloss_row_stride, + logsumexp_ptr, + labels_ptr, + VOCAB_SIZE : tl.constexpr, + BLOCK_SIZE : tl.constexpr, +): + """ + CE_i = -y log(P) = y * (log[sum(exp(x))] - x) + dC/dx = d/dx (y * log[sum(exp(x))] - x * y) + + From https://en.wikipedia.org/wiki/LogSumExp + d/dx logsumexp = exp(x) / sum(exp(x)) = softmax(x) + + dC/dx = y * exp(x) / sum(exp(x)) - d/dx (x * y) + dC/dx = y * exp[ log[exp(x) / sum(exp(x))] ] using x = exp(log(x)) trick + dC/dx = y * exp[x - logsumexp] - d/dx (x * y) + + If y == 0: dC/dx = 0 + If y == 1 and x == label: dC/dlabel = exp[x - logsumexp] - 1 + If y == 1 and x != label: dC/dx = exp[x - logsumexp] + """ + row_idx = tl.program_id(0) + block_idx = tl.program_id(1) + + logits_ptr += row_idx * logits_row_stride.to(tl.int64) + dloss_ptr += row_idx * dloss_row_stride + col_offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = col_offsets < VOCAB_SIZE + label_idx = tl.load(labels_ptr + row_idx).to(tl.int32) + + if label_idx != -100: + dloss = tl.load(dloss_ptr) + else: + dloss = 0.0 + + x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32) + logsumexp = tl.load(logsumexp_ptr + row_idx) + y = tl.exp(x - logsumexp) + y = tl.where( + col_offsets == label_idx, + y - 1.0, # exp(x - logsumexp) - 1 + y, # exp(x - logsumexp) + ) + + # If y == 0: dC/dx = 0 ==> we already masked it to be = 0, so dloss = 0. 
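+    # dloss * y is the gradient dC/dx (softmax minus one-hot, scaled by the
+    # incoming grad); it is written in place over the logits buffer, which is
+    # why Fast_CrossEntropyLoss.backward returns `logits` directly.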
+ tl.store(logits_ptr + col_offsets, dloss * y, mask = mask) +pass + + +MAX_FUSED_SIZE = 65536 # 2**16 + +class Fast_CrossEntropyLoss(torch.autograd.Function): + @staticmethod + def forward(ctx, logits, labels): + n_rows, vocab_size = logits.shape + + div, mod = divmod(vocab_size, MAX_FUSED_SIZE) + n_chunks = div + (mod != 0) + losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda") + + if n_chunks == 1: + # For small vocabs <= 65336 like Llama, Mistral + BLOCK_SIZE, num_warps = calculate_settings(vocab_size) + logsumexp = torch.empty(n_rows, dtype = torch.float32, device = "cuda") + + _cross_entropy_forward[(n_rows,)]( + logits, logits.stride(0), + losses, + logsumexp, + labels, + VOCAB_SIZE = vocab_size, + BLOCK_SIZE = BLOCK_SIZE, + num_warps = num_warps, + ) + else: + # For large vocabs > 65336 like Gemma 256K + logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = "cuda") + + _chunked_cross_entropy_forward[(n_rows, n_chunks,)]( + logits, logits.stride(0), + losses, + logsumexp, + labels, + VOCAB_SIZE = vocab_size, + N_CHUNKS = n_chunks, + BLOCK_SIZE = MAX_FUSED_SIZE, + num_warps = 32, + ) + # logsumexp(chunked_logsumexp) - x + # Do the -x separately + logsumexp = torch.logsumexp(logsumexp, dim = 1) # Row sum + losses += logsumexp + losses.masked_fill_(labels == -100, 0) # Don't forget to mask padding out! + pass + + ctx.save_for_backward(logits, logsumexp, labels) + return losses + pass + + @staticmethod + def backward(ctx, dlosses): + logits, logsumexp, labels = ctx.saved_tensors + n_rows, vocab_size = logits.shape + + BLOCK_SIZE = 4096 + div, mod = divmod(vocab_size, BLOCK_SIZE) + n_blocks = div + (mod != 0) + + _cross_entropy_backward[(n_rows, n_blocks,)]( + logits, logits.stride(0), + dlosses, dlosses.stride(0), + logsumexp, + labels, + VOCAB_SIZE = vocab_size, + BLOCK_SIZE = BLOCK_SIZE, + num_warps = 8, + ) + return logits, None, None, + pass +pass + + +def fast_cross_entropy_loss(logits, labels): + """ + Arguments: + logits: (batch, seq_len, vocab_size) + labels: (batch, seq_len,) + Returns: + losses: float + """ + batch, seq_len, d = logits.shape + assert(labels.shape == (batch, seq_len)) + + loss = Fast_CrossEntropyLoss.apply( + logits.view(batch*seq_len, d), + labels.view(-1), + ) + n_items = torch.count_nonzero(labels != -100) + return loss.sum() / n_items +pass + +# added by flim@sg.ibm.com +class FastCrossEntropyLoss(torch.nn.CrossEntropyLoss): + + def __init__(self): + super().__init__() + + def forward(self, input, target): + loss = Fast_CrossEntropyLoss.apply( + input, target + ) + n_items = torch.count_nonzero(target != -100) + return loss.sum() / n_items diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/rms_layernorm.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/rms_layernorm.py new file mode 100644 index 00000000..4db89b78 --- /dev/null +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/rms_layernorm.py @@ -0,0 +1,192 @@ +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import triton +import triton.language as tl +import torch +from .utils import calculate_settings + + +@triton.jit +def _rms_layernorm_forward( + Y, Y_row_stride, + X, X_row_stride, + W, W_row_stride, + r, r_row_stride, + n_cols, eps, + BLOCK_SIZE : tl.constexpr +): + """ + Fast RMS Layernorm kernel + Inspiration from a Triton tutorial: + https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html + """ + row_idx = tl.program_id(0) + col_offsets = tl.arange(0, BLOCK_SIZE) + mask = col_offsets < n_cols + + Y += row_idx * Y_row_stride + X += row_idx * X_row_stride + r += row_idx * r_row_stride + + X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32) + W_row = tl.load(W + col_offsets, mask = mask, other = 0)#.to(tl.float32) + + row_var = tl.sum(X_row * X_row, axis = 0) / n_cols + inv_var = tl.math.rsqrt(row_var + eps) + tl.store(r, inv_var) + normed = X_row * inv_var + normed = normed.to(W_row.dtype) # Exact copy from HF + output = normed * W_row + tl.store(Y + col_offsets, output, mask = mask) +pass + + +@triton.heuristics({"GEMMA": lambda args: args["GEMMA"],}) +@triton.jit +def _rms_layernorm_backward( + dY, dY_row_stride, + X, X_row_stride, + W, W_row_stride, + r, r_row_stride, + dW, dW_row_stride, + n_cols, eps, + GEMMA : tl.constexpr, + BLOCK_SIZE : tl.constexpr, +): + """ + Fast RMS Layernorm kernel for the backward pass + Inspiration from a Triton tutorial: + https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html + """ + row_idx = tl.program_id(0) + col_offsets = tl.arange(0, BLOCK_SIZE) + mask = col_offsets < n_cols + + dY += row_idx * dY_row_stride + X += row_idx * X_row_stride + r += row_idx * r_row_stride + + dY_row = tl.load(dY + col_offsets, mask = mask, other = 0).to(tl.float32) + X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32) + W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32) + + # Get saved row variance + inv_var = tl.load(r).to(tl.float32) + normed = X_row * inv_var + + if GEMMA: dY_W = dY_row * (W_row + 1.0) + else: dY_W = dY_row * W_row + + rowsum_dY_normed = tl.sum(dY_W * normed, axis = 0) + output = inv_var/n_cols * (n_cols*dY_W - normed*rowsum_dY_normed) + tl.store(dY + col_offsets, output, mask = mask) +pass + + +@triton.jit +def _gemma_rms_layernorm_forward( + Y, Y_row_stride, + X, X_row_stride, + W, W_row_stride, + r, r_row_stride, + n_cols, eps, + BLOCK_SIZE : tl.constexpr, +): + # Copies https://github.com/google-deepmind/gemma/blob/main/gemma/layers.py#L31 + # and https://github.com/keras-team/keras-nlp/blob/v0.8.2/keras_nlp/models/gemma/rms_normalization.py#L33 + # exactly. Essentially all in float32! 
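+    # Differences from _rms_layernorm_forward above: W is also upcast to
+    # float32, 1.0 / tl.sqrt is used rather than tl.math.rsqrt, and the
+    # learned weight is applied as (W + 1) without downcasting first.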
+ row_idx = tl.program_id(0) + col_offsets = tl.arange(0, BLOCK_SIZE) + mask = col_offsets < n_cols + + Y += row_idx * Y_row_stride + X += row_idx * X_row_stride + r += row_idx * r_row_stride + + X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32) + W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32) + + row_var = tl.sum(X_row * X_row, axis = 0) / n_cols + inv_var = 1.0 / tl.sqrt(row_var + eps) # Must be 1/sqrt to match Deepmind's impl + tl.store(r, inv_var) + normed = X_row * inv_var + output = normed * (W_row + 1.0) + + tl.store(Y + col_offsets, output, mask = mask) +pass + + +class Fast_RMS_Layernorm(torch.autograd.Function): + @staticmethod + def forward(ctx, X, W, eps, gemma = False): + shape = X.shape + dim = shape[-1] + X = X.view(-1, dim) + n_rows, n_cols = X.shape + BLOCK_SIZE, num_warps = calculate_settings(n_cols) + + Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda") + r = torch.empty(n_rows, dtype = torch.float32, device = "cuda") + + fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward + fx[(n_rows,)]( + Y, Y.stride(0), + X, X.stride(0), + W, W.stride(0), + r, r.stride(0), + n_cols, eps, + BLOCK_SIZE = BLOCK_SIZE, + num_warps = num_warps, + ) + ctx.eps = eps + ctx.BLOCK_SIZE = BLOCK_SIZE + ctx.num_warps = num_warps + ctx.GEMMA = gemma + ctx.save_for_backward(X, W, r) + return Y.view(*shape) + pass + + @staticmethod + def backward(ctx, dY): + shape = dY.shape + dim = shape[-1] + dY = dY.view(-1, dim) + X, W, r = ctx.saved_tensors + n_rows, n_cols = dY.shape + dW = X + + _rms_layernorm_backward[(n_rows,)]( + dY, dY.stride(0), + X, X .stride(0), + W, W .stride(0), + r, r .stride(0), + dW, dW.stride(0), + n_cols, ctx.eps, + GEMMA = ctx.GEMMA, + BLOCK_SIZE = ctx.BLOCK_SIZE, + num_warps = ctx.num_warps, + ) + dX = dY.view(*shape) + return dX, None, None, None + pass +pass + + +def fast_rms_layernorm(layernorm, X, gemma = False): + W = layernorm.weight + eps = layernorm.variance_epsilon + out = Fast_RMS_Layernorm.apply(X, W, eps, gemma) + return out +pass diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/rope_embedding.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/rope_embedding.py new file mode 100644 index 00000000..49b04fce --- /dev/null +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/rope_embedding.py @@ -0,0 +1,138 @@ +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
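+
+# NOTE: the kernel below rotates Q (or K) in place, and the backward pass
+# reuses the same kernel with the sign of sin flipped (BACKWARD_PASS=True).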
+ +import triton +import triton.language as tl +import torch +from .utils import calculate_settings + +ROPE_GROUP_SIZE = 4 + +@triton.heuristics({"BACKWARD_PASS": lambda args: args["BACKWARD_PASS"],}) +@triton.jit +def _rope_embedding( + Q, Q_row_stride, + cos, cos_row_stride, + sin, sin_row_stride, + seqlen, + head_dim : tl.constexpr, + n_heads : tl.constexpr, + BACKWARD_PASS : tl.constexpr, + BLOCK_SIZE : tl.constexpr, +): + """ + Calculates the RoPE Embedding quickly + RoPE is Q * cos + rotate_half(Q) * sin + See our blog post for more info + """ + row_position = tl.program_id(0) + group_head_position = tl.program_id(1) + col_offsets = tl.arange(0, BLOCK_SIZE) + half_head_dim = head_dim // 2 + mask = col_offsets < half_head_dim + + sin1 = tl.load(sin + (row_position % seqlen)*sin_row_stride + \ + half_head_dim*0 + col_offsets, mask = mask, other = 0) + cos1 = tl.load(cos + (row_position % seqlen)*cos_row_stride + \ + half_head_dim*0 + col_offsets, mask = mask, other = 0) + + if BACKWARD_PASS: + # See our blog post for more info. + sin1 = -sin1 + pass + + # [TODO] Autotune ROPE_GROUP_SIZE to be 1, 2, 4, 8 + head_start = group_head_position * ROPE_GROUP_SIZE + head_end = min((head_start + ROPE_GROUP_SIZE), n_heads) + + # 10% Faster kernel from [HuyNguyen-hust](https://github.com/unslothai/unsloth/pull/238) + for k in range(head_start, head_end): + offs_q1 = row_position * Q_row_stride + k * head_dim + col_offsets + offs_q2 = row_position * Q_row_stride + k * head_dim + col_offsets + half_head_dim + + # For Gemma - sometimes RoPE must be done in float32 and not bfloat16 + Q1 = tl.load(Q + offs_q1, mask = mask, other = 0).to(sin1.dtype) + Q2 = tl.load(Q + offs_q2, mask = mask, other = 0).to(sin1.dtype) + + tl.store(Q + offs_q1, Q1*cos1 - Q2*sin1, mask = mask) + tl.store(Q + offs_q2, Q2*cos1 + Q1*sin1, mask = mask) + pass +pass + + +class Fast_RoPE_Embedding(torch.autograd.Function): + @staticmethod + def forward(ctx, Q, cos, sin): + cos, sin = cos.squeeze(), sin.squeeze() + batch, seq_len, n_heads, head_dim = Q.shape + Q = Q.view(batch*seq_len, n_heads*head_dim) + n_rows, n_cols = Q.shape + assert(seq_len <= cos.shape[0]) + + # [TODO] Changing blocksize to head_dim//2 seems to have + # some concurrency / un-deterministic issues. + BLOCK_SIZE, num_warps = calculate_settings(head_dim//2) # (head_dim//2) + + # group_size = 4 # 4 or 8, too large group_size can hurt performance. 
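+        # each program along the second grid axis handles up to ROPE_GROUP_SIZE
+        # heads, so the launch grid below is (n_rows, ceil(n_heads / ROPE_GROUP_SIZE))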
+ div, mod = divmod(n_heads, ROPE_GROUP_SIZE) + n_groups = div + (mod != 0) + + _rope_embedding[(n_rows, n_groups, )]( + Q, Q.stride(0), + cos, cos.stride(0), + sin, sin.stride(0), + seq_len, + head_dim, n_heads, + BACKWARD_PASS = False, + BLOCK_SIZE = BLOCK_SIZE, + num_warps = num_warps, + ) + ctx.BLOCK_SIZE = BLOCK_SIZE + ctx.num_warps = num_warps + ctx.n_groups = n_groups + ctx.cos = cos + ctx.sin = sin + return Q.view(batch, seq_len, n_heads, head_dim) + pass + + @staticmethod + def backward(ctx, dY): + batch, seq_len, n_heads, head_dim = dY.shape + dY = dY.reshape(batch*seq_len, n_heads*head_dim) + # Must be reshape not view + n_rows, n_cols = dY.shape + + cos = ctx.cos + sin = ctx.sin + + _rope_embedding[(n_rows, ctx.n_groups, )]( + dY, dY .stride(0), + cos, cos.stride(0), + sin, sin.stride(0), + seq_len, head_dim, n_heads, + BACKWARD_PASS = True, + BLOCK_SIZE = ctx.BLOCK_SIZE, + num_warps = ctx.num_warps, + ) + dY = dY.view(batch, seq_len, n_heads, head_dim) + return dY, None, None, + pass +pass + + +def fast_rope_embedding(Q, K, cos, sin): + Q = Fast_RoPE_Embedding.apply(Q.transpose(1, 2), cos, sin).transpose(1, 2) + K = Fast_RoPE_Embedding.apply(K.transpose(1, 2), cos, sin).transpose(1, 2) + return Q, K +pass \ No newline at end of file diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/utils.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/utils.py new file mode 100644 index 00000000..8d4aa881 --- /dev/null +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/utils.py @@ -0,0 +1,29 @@ +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import triton +MAX_FUSED_SIZE = 65536 +next_power_of_2 = triton.next_power_of_2 + +def calculate_settings(n): + BLOCK_SIZE = next_power_of_2(n) + if BLOCK_SIZE > MAX_FUSED_SIZE: + raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\ + f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.") + num_warps = 4 + if BLOCK_SIZE >= 32768: num_warps = 32 + elif BLOCK_SIZE >= 8192: num_warps = 16 + elif BLOCK_SIZE >= 2048: num_warps = 8 + return BLOCK_SIZE, num_warps +pass diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/__init__.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/__init__.py new file mode 100644 index 00000000..7d6df3bc --- /dev/null +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/__init__.py @@ -0,0 +1,24 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# Local +from .model_patcher import ModelPatcher + +PATCHES = [".models.llama", ".models.mistral"] +PLUGIN_PREFIX = "fms_acceleration_foak" + +# TODO: remove the need for the prefix +ModelPatcher.load_patches( + [f"{PLUGIN_PREFIX}{postfix}" for postfix in PATCHES], +) diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py new file mode 100644 index 00000000..3d01311a --- /dev/null +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py @@ -0,0 +1,88 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +from functools import partial + +# Third Party +from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm + +# Local +from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss +from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm +from ..kernels.unsloth.rope_embedding import fast_rope_embedding +from .model_patcher import ModelPatcher, ModelPatcherRule, ModelPatcherTrigger +from .utils import build_lora_fused_ops, trigger_fused_ops + +# TODO: have a generic version of this rule +# - do regex on RMSNorm class name +# - check on the tensors required for fast_rms_layernorm +ModelPatcher.register( + ModelPatcherRule( + rule_id="llama-rms", + trigger=ModelPatcherTrigger(check=LlamaRMSNorm), + forward=fast_rms_layernorm, + ), +) + +# TODO: have a generic version of this rule +# - do regex on Attention class name +# - have a set of qkv / o module names and check on that +ModelPatcher.register( + ModelPatcherRule( + rule_id="llama-qkvo", + trigger=ModelPatcherTrigger( + check=partial( + trigger_fused_ops, + attn_cls=LlamaAttention, + qkv_module_names=["q_proj", "k_proj", "v_proj"], + o_module_name="o_proj", + ) + ), + forward_builder=partial( + build_lora_fused_ops, + qkv_module_names=["q_proj", "k_proj", "v_proj"], + o_module_name="o_proj", + ), + forward_builder_args=["base_type"], + ) +) + +# TODO: have a generic version of this rule +# - get the module_name and reload on that +ModelPatcher.register( + ModelPatcherRule( + rule_id="llama-cross-ent", + import_and_maybe_reload=( + "torch.nn.CrossEntropyLoss", + FastCrossEntropyLoss, + "transformers.models.llama.modeling_llama", + ), + ) +) + +# TODO: have a generic version of this rule +# - get the module name +# - check if "apply_rotary_pos_emb" exists +# - patch +ModelPatcher.register( + ModelPatcherRule( + rule_id="llama-rope", + import_and_maybe_reload=( + "transformers.models.llama.modeling_llama.apply_rotary_pos_emb", + fast_rope_embedding, + None, + ), + ) +) diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py new file mode 100644 index 00000000..a8e6795f --- /dev/null +++ 
b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py @@ -0,0 +1,94 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +from functools import partial + +# Third Party +from transformers.models.mistral.modeling_mistral import ( + MistralAttention, + MistralRMSNorm, +) + +# Local +from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss +from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm +from ..kernels.unsloth.rope_embedding import fast_rope_embedding as _fast_rope_embedding +from .model_patcher import ModelPatcher, ModelPatcherRule, ModelPatcherTrigger +from .utils import build_lora_fused_ops, trigger_fused_ops + + +# NOTE: fast_rope_embedding does not work with position_ids +# currently they are ignored +def fast_rope_embedding(Q, K, cos, sin, position_ids=None): + return _fast_rope_embedding(Q, K, cos, sin) + + +# - do regex on RMSNorm class name +# - check on the tensors required for fast_rms_layernorm +ModelPatcher.register( + ModelPatcherRule( + rule_id="mistral-rms", + trigger=ModelPatcherTrigger(check=MistralRMSNorm), + forward=fast_rms_layernorm, + ), +) + +# - do regex on Attention class name +# - have a set of qkv / o module names and check on that +ModelPatcher.register( + ModelPatcherRule( + rule_id="mistral-qkvo", + trigger=ModelPatcherTrigger( + check=partial( + trigger_fused_ops, + attn_cls=MistralAttention, + qkv_module_names=["q_proj", "k_proj", "v_proj"], + o_module_name="o_proj", + ) + ), + forward_builder=partial( + build_lora_fused_ops, + qkv_module_names=["q_proj", "k_proj", "v_proj"], + o_module_name="o_proj", + ), + forward_builder_args=["base_type"], + ) +) + +# - get the module_name and reload on that +ModelPatcher.register( + ModelPatcherRule( + rule_id="mistral-cross-ent", + import_and_maybe_reload=( + "torch.nn.CrossEntropyLoss", + FastCrossEntropyLoss, + "transformers.models.mistral.modeling_mistral", + ), + ) +) + +# - get the module name +# - check if "apply_rotary_pos_emb" exists +# - patch +ModelPatcher.register( + ModelPatcherRule( + rule_id="mistral-rope", + import_and_maybe_reload=( + "transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb", + fast_rope_embedding, + None, + ), + ) +) diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/model_patcher.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/model_patcher.py new file mode 100644 index 00000000..3355aa67 --- /dev/null +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/model_patcher.py @@ -0,0 +1,470 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +from dataclasses import asdict, dataclass +from enum import Enum +from types import MethodType +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union +import importlib +import inspect + +# Third Party +import pandas as pd +import torch + +# ------------------------ helpers ----------------------- + + +def _patch_target_module( + to_patch: str, + replace_with: Any, + target_module: str = None, +): + to_patch = to_patch.split(".") + assert len(to_patch) > 1, "must have an object to patch" + + to_patch, obj_name_to_patch = to_patch[:-1], to_patch[-1] + to_patch = ".".join(to_patch) + source = importlib.import_module(to_patch) + original_obj = getattr(source, obj_name_to_patch) + setattr(source, obj_name_to_patch, replace_with) + + if target_module is not None: + # reload and this should get the patched object + target_module = importlib.import_module(target_module) + importlib.reload(target_module) + + # replace it + setattr(source, obj_name_to_patch, original_obj) + + +# ------------------------ classes ----------------------- + +# Rules will trigger on either +# - module class, which triggers on isinstance +# - callable, which will be useful to trigger on custom checks +# - (consider): adding a regex will will apply on the name +# ModelPatcherTrigger = Union[ +# torch.nn.Module, # trigger on isinstance +# Callable[[torch.nn.Module], bool] # trigger on callable +# ] +# NOTE: triggering on instance checks will not be robust to reloading + + +class ModelPatcherTriggerType(Enum): + module = 1 + callable = 2 + + +@dataclass +class ModelPatcherTrigger: + "Holds the triggering logic for the model patcher rule." + + # the trigger operation + check: Union[ + torch.nn.Module, # trigger on isinstance + Callable[[torch.nn.Module], bool], # trigger on callable + ] + + # holds the type of the trigger + # - type is None that it will be a single call + type: ModelPatcherTriggerType = None + + # if the trigger is specific to model name + module_name: str = None + + def is_triggered( + self, + module: torch.nn.Module, + module_name: str, + ): + "Check if trigger returns truthful." 
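+        # the optional module_name guard is checked first; module triggers
+        # then use an isinstance check, while callable triggers invoke the
+        # check function and treat any raised exception as "not triggered"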
+ + if self.module_name is not None and module_name != self.module_name: + return False + + if self.type == ModelPatcherTriggerType.module and isinstance( + module, self.check + ): + return True + + try: + # the function call may raise + if self.type == ModelPatcherTriggerType.callable and self.check(module): + return True + except Exception: # pylint: disable=broad-exception-caught + # NOTE: not sure if its good idea to let the exception pass through + pass + + return False + + def __post_init__(self): + + if self.type is None: + if inspect.isclass(self.check) and issubclass(self.check, torch.nn.Module): + self.type = ModelPatcherTriggerType.module + else: + self.type = ModelPatcherTriggerType.callable + + +# type for model forward +ModelForward = Callable + + +@dataclass +class ModelPatcherRule: + # id, must be unique + rule_id: str + + # trigger + # - if trigger is none, then it will be a model file patching + trigger: ModelPatcherTrigger = None + + # takes in the torch module to build the forward. + # will be helpful to + # - do any pre-modification on the torch module + + # this is mutually exclusive from forward_builder + forward: ModelForward = None + + # returns either + # - a callable, which will be patched on the triggered module + # - a list of trigger-forward tuples + forward_builder: Callable[ + [torch.nn.Module], + Union[ModelForward, List[Tuple[ModelPatcherTrigger, ModelForward]]], + ] = None + + # if specified, these will be passed on frrom ModelPatcher.patch + # (if they exist) + forward_builder_args: List[str] = None + + # this is mutually exclusive from forward and forward builder + import_and_maybe_reload: Tuple[ + str, # path to the object to be patched (e.g., 'torch.nn.CrossEntropyLoss') + Type, # replacement object (e.g., FastCrossEntropyLoss) + Optional[ + str + ], # path to module to be reloaded (e.g., transformers.models.llama.modeling_llama) + ] = None + + def __post_init__(self): + if ( + self.forward is not None + and self.forward_builder is not None + and self.import_and_maybe_reload is not None + ): + raise ValueError( + f"Rule '{self.rule_id}' must only have only one of forward, " + "foward builder, or import_and_maybe_reload, specified." + ) + + if self.import_and_maybe_reload is not None and self.trigger is not None: + raise ValueError( + f"Rule '{self.rule_id}' has import_and_maybe_reload specified, " + "and trigger must be None." + ) + + if self.forward_builder_args is not None and self.forward_builder is None: + raise ValueError( + f"Rule '{self.rule_id}' has forward_builder_args but no " + "forward_builder." + ) + + +# helpful to keep a history of all patching that has been done +@dataclass +class ModelPatcherHistory: + # instance id of the class that was patched + instance: int + + # class of the torch.nn.Module that was patched + cls: str + + # parent class of the torch.nn.Module that was patched + parent_cls: str + + # module name + module_name: str + + # parent + parent_module_name: str + + # name of the rule that was applied + rule_id: str + + +# singleton class for patching models +class ModelPatcher: + + # singleton history of patches + history: List[ModelPatcherHistory] = [] + + # singleton list of rules that have been registered + rules: Dict[str, ModelPatcherRule] = {} + + @staticmethod + def load_patches(module_names: List[str], reload: bool = False): + # each patch should be in a module that calls + # ModelPatcher.register. 
So these will search + # and load all the modules it can find + + # reload will trigger the register in that module + for plugin_name in module_names: + if importlib.util.find_spec(plugin_name): + m = importlib.import_module(plugin_name) + + # attempt a reload of imported patch modules if requested + # NOTE: but this is brittle as triggering on instance types is + # not robust to reloading + if reload: + try: + importlib.reload(m) + except AssertionError: + # this is if it was loaded already + pass + + @staticmethod + def register(rule: ModelPatcherRule): + # raise if added rule in duplicity + assert ( + rule.rule_id not in ModelPatcher.rules + ), f"patch rule '{rule.rule_id}' already exists" + + ModelPatcher.rules[rule.rule_id] = rule + + @staticmethod + def did_rule_trigger(module: torch.nn.Module, module_name: str): + for name, rule in ModelPatcher.rules.items(): + + # if there is no trigger + if rule.trigger is None: + continue + + if rule.trigger.is_triggered(module, module_name): + return name, rule + + return None, None + + @staticmethod + def _import_and_reload(model: torch.nn.Module): + # each rule.import_and_maybe_reload is a triple + # - path to be patched + # - replacement object + # - path to be reloaded + + # USE CASE 1: + # from a import A # <- want to replace A by A_patched + # def func(): + # obj = A() + + # USE CASE 2: + # from a import + # def A(): # <- want to replace A by A_patched + # ... + + # for 1: requires a reload of the func def. + # - the patch of A does not need to be perm + # for 2: just requires a patch of a.A. + # - the patch of a.A needs to be perm + # - once a.A has been patched, 'a' cannot be reloaded + + # so for simplicity: + # - only allow a single reload + # - this is to allow the reload to happen first + # - any forward patches that happen after / before + # this import and reload should not be affected + + # (a more advanced version could be considered) + # targets that have a reload path as a prefix, then + # the reload path happens first + + # this will be the path to the module + module_path = model.__module__ + + # activate the one time rules (i.e. 
those with no trigger) + _with_reload = [] + _no_reload = [] + for rule in ModelPatcher.rules.values(): + if rule.import_and_maybe_reload is not None: + _target, _, _reload = rule.import_and_maybe_reload + if _reload and _reload.startswith(module_path): + _with_reload.append(rule) + elif _target.startswith(module_path): + _no_reload.append(rule) + + assert len(_with_reload) <= 1, "cannot have have at most one rule with reload" + + # handle those with reload first + for rule in _with_reload + _no_reload: + _target, _object, _reload = rule.import_and_maybe_reload + _patch_target_module(_target, _object, _reload) + ModelPatcher.history.append( + ModelPatcherHistory( + instance=id(model), + cls=model.__class__.__name__, + parent_cls="", + module_name="", + parent_module_name="", + rule_id=rule.rule_id, + ) + ) + + @staticmethod + def _patch_forwards( + model: torch.nn.Module, + patch_kwargs: Dict = None, + visited: Set = None, + parent_prefix: str = None, + parent_mcn: str = None, + ): + # NOTE: should we avoid repatching of the forwards + + if patch_kwargs is None: + patch_kwargs = {} + + if visited is None: + visited = set() + + for name, mod in model.named_modules(): + + # some stats + mod_id = id(mod) + mod_class_name = mod.__class__.__name__ + name = name.split(".") + if len(name) > 2: + parent_module_name, module_name = ".".join(name[:-1]), name[-1] + parent_mod = model.get_submodule(parent_module_name) + parent_mod_class_name = parent_mod.__class__.__name__ + else: + # patching on model itself + module_name = name[0] + parent_mod_class_name = parent_module_name = "" + if parent_prefix is not None: + parent_module_name = parent_prefix + "." + parent_module_name + if parent_mcn is not None: + parent_mod_class_name = parent_mcn + + rule_id, rule = ModelPatcher.did_rule_trigger(mod, module_name) + if rule_id is None: + continue + + # otherwise triggered + if rule.forward is not None: + forward = rule.forward + else: + fba = {} + if rule.forward_builder_args is not None: + fba = { + k: w + for k, w in patch_kwargs.items() + if rule.forward_builder_args + } + forward = rule.forward_builder(mod, **fba) + + if isinstance(forward, list): + # this will be list of tuples case + + # will descend down but + # - clear old rules + # - replace new rules + old_rules = ModelPatcher.rules + ModelPatcher.rules = {} + for i, (trig, forw) in enumerate(forward): + ModelPatcher.register( + ModelPatcherRule( + rule_id=f"{rule_id}-{i+1}", + trigger=trig, + forward=forw, + ) + ) + + # this is an isolated patch + ModelPatcher.patch( + mod, + patch_kwargs=patch_kwargs, + visited=visited, + parent_prefix=parent_module_name, + parent_mcn=parent_mod_class_name, + ) + + # replace the rules + ModelPatcher.rules = old_rules + + # done + continue + + # otherwise + mod.forward = MethodType(forward, mod) + ModelPatcher.history.append( + ModelPatcherHistory( + instance=mod_id, + cls=mod_class_name, + parent_cls=parent_mod_class_name, + module_name=module_name, + parent_module_name=parent_module_name, + rule_id=rule_id, + ) + ) + + @staticmethod + def patch(model: torch.nn.Module, **kwargs): + # NOTE: for a set of rules, this patch function should be called + # only once. 
We do not have any checks for this at the moment + try: + ModelPatcher._import_and_reload(model.get_base_model()) + except AttributeError: + ModelPatcher._import_and_reload(model) + + # this will patch the forwards + ModelPatcher._patch_forwards(model, patch_kwargs=kwargs) + + @staticmethod + def summary(raw: bool = False): + df = pd.DataFrame([asdict(entry) for entry in ModelPatcher.history]) + if raw: + return df + + if len(df) == 0: + return "" + + # summarize and return string + df = ( + df.groupby(["rule_id", "module_name", "cls"])["instance"] + .count() + .reset_index() + ) + result = [] + result.append("***************** Module Forwards Patching *************") + for x in df.to_dict("records"): + result.append( + "Rule: {0:15s} Module: {1:25s} Class: {2:15s} Num: {3:2d}".format( + x["rule_id"], x["module_name"], x["cls"], x["instance"] + ) + ) + + return "\n".join(result) + + +# ------------------------ function ----------------------- + + +def patch_model(model: torch.nn.Module, **kwargs): + ModelPatcher.patch(model, **kwargs) + return model + + +def patch_model_summary(): + return ModelPatcher.summary() diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/utils.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/utils.py new file mode 100644 index 00000000..2840aacd --- /dev/null +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/utils.py @@ -0,0 +1,164 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +from typing import Callable, List, Type + +# Third Party +import torch + +# Local +# GPTQ imports +from ..fused_ops.unsloth_lora.gptq.fast_lora import LoRA_W as LoRA_W_gptq +from ..fused_ops.unsloth_lora.gptq.fast_lora import apply_lora_qkv as fused_op_qkv_gptq +from ..fused_ops.unsloth_lora.gptq.fast_lora import ( + get_lora_parameters as get_lora_parameters_gptq, +) +from ..fused_ops.unsloth_lora.gptq.fast_lora import unpack_gptqstate +from .model_patcher import ModelPatcherTrigger + + +# simple utility function to guess if its lora layer +def _is_loralayer(module: torch.nn.Module, names: List[str] = None): + if names is None: + names = ["lora_A", "lora_B", "base_layer"] + return all(hasattr(module, x) for x in names) + + +# builds a triple of forward functions, that each can be attached +# on a series of QKV's, where if the first one is called, will call the +# fused op +# NOTE: this is not thread-safe (issue warning?) +# NOTE: the unsloth fused_operation "apply_lora_qkv" assumes that the +# modules are called q_proj, k_proj, and v_proj, respectively. +# the fused operation can be changed, depending on what the base layer is +# i.e. gptq or bnb +def _build_qkv_forwards( + attn: torch.nn.Module, + fused_operation: Callable = fused_op_qkv_gptq, + module_names: List[str] = None, +): + if module_names is None: + module_names = ["q_proj", "k_proj", "v_proj"] + + Q = K = V = None + + # the fused operation will be called on first one that passes in the + # input X. 
+ # - populates the triple Q, K, V + # - subsequent calls will be a no-op until ALL Q, K, V get reset to None + def _fused_op(X): + nonlocal Q, K, V + if Q is None and K is None and V is None: + Q, K, V = fused_operation(attn, X) + + # each of these functions + # - calls the fused op + # - + error_msg = ( + "QKV fused_op needs to be first reset with sequential calls to each of them" + ) + + def _forward_q(self, X): + nonlocal Q + _fused_op(X) + assert Q is not None, error_msg + out, Q = Q, None # unload + return out + + def _forward_k(self, X): + nonlocal K + _fused_op(X) + assert K is not None, error_msg + out, K = K, None # unload + return out + + def _forward_v(self, X): + nonlocal V + _fused_op(X) + assert V is not None, error_msg + out, V = V, None # unload + return out + + return zip(module_names, [_forward_q, _forward_k, _forward_v]) + + +# fused ops for outputs for GPTQ +def fused_op_o_gptq(self, X): + Oqstate, OA, OB, OS = get_lora_parameters_gptq(self) + O = LoRA_W_gptq.apply(X, *unpack_gptqstate(Oqstate), OA, OB, OS) + return O + + +# TODO: add the MLP +def build_lora_fused_ops( + attn: torch.nn.Module, + base_type: str = "auto_gptq", + qkv_module_names: List[str] = None, + o_module_name: str = "o_proj", +): + if qkv_module_names is None: + qkv_module_names = ["q_proj", "k_proj", "v_proj"] + + # handle the QKVs + if base_type == "auto_gptq": + _qkv_fused_op = fused_op_qkv_gptq + _o_fused_op = fused_op_o_gptq + else: + raise NotImplementedError( + f"Cannot build fused ops for base type '{base_type}'." + ) + + trigger_and_forwards = [ + (ModelPatcherTrigger(check=_is_loralayer, module_name=name), forward) + for name, forward in _build_qkv_forwards( + attn, + fused_operation=_qkv_fused_op, + module_names=qkv_module_names, + ) + ] + + # handle the self-attn output + _output_module = getattr(attn, o_module_name) + if _is_loralayer(_output_module): + trigger_and_forwards.append( + ( + ModelPatcherTrigger(check=_is_loralayer, module_name=o_module_name), + _o_fused_op, + ) + ) + + # return + return trigger_and_forwards + + +# trigger if either of the conditions are met +# 1. qkv all have LoRA adapters for a fused op +# 2. o has a lora adapter for the fused op +def trigger_fused_ops( + module: torch.nn.Module, + attn_cls: Type, + qkv_module_names: List[str] = None, + o_module_name: str = "o_proj", +): + if qkv_module_names is None: + qkv_module_names = ["q_proj", "k_proj", "v_proj"] + + _o = getattr(module, o_module_name) + _qkv = [getattr(module, x) for x in qkv_module_names] + + # trigger on the attention layer + return isinstance(module, attn_cls) and ( + all(_is_loralayer(x) for x in _qkv) or _is_loralayer(_o) + ) diff --git a/plugins/fused-ops-and-kernels/tests/__init__.py b/plugins/fused-ops-and-kernels/tests/__init__.py new file mode 100644 index 00000000..38a9531e --- /dev/null +++ b/plugins/fused-ops-and-kernels/tests/__init__.py @@ -0,0 +1,13 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
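For orientation before the tests below: importing the `fms_acceleration_foak.models` package registers the rules defined above, and `patch_model` applies them to an already-loaded model. The following is a minimal sketch, assuming `model` is an already-loaded PEFT-wrapped Llama or Mistral model with an auto_gptq base layer; the direct call here is purely for illustration.

    # illustrative sketch only -- not part of this patch
    import fms_acceleration_foak.models  # importing the package registers the llama / mistral rules
    from fms_acceleration_foak.models.model_patcher import patch_model, patch_model_summary

    model = patch_model(model, base_type="auto_gptq")  # base_type is forwarded to the forward builders
    print(patch_model_summary())  # reports which rules patched which modules
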
diff --git a/plugins/fused-ops-and-kernels/tests/test_foak_plugins.py b/plugins/fused-ops-and-kernels/tests/test_foak_plugins.py new file mode 100644 index 00000000..dd7b472d --- /dev/null +++ b/plugins/fused-ops-and-kernels/tests/test_foak_plugins.py @@ -0,0 +1,84 @@ +# Copyright The IBM Tuning Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-License-Identifier: Apache-2.0 +# https://spdx.dev/learn/handling-license-info/ + +# Standard +import os + +# Third Party +from fms_acceleration import AccelerationPluginConfigError +from fms_acceleration.utils import ( + instantiate_framework, + read_configuration, + update_configuration_contents, +) +import pytest # pylint: disable=import-error + +# instantiate_fromwork will handle registering and activating AutoGPTQAccelerationPlugin + +# configuration +DIRNAME = os.path.dirname(__file__) +CONFIG_PATH_AUTO_GPTQ_FOAK = os.path.join( + DIRNAME, "../configs/fast_quantized_peft.yaml" +) + + +def test_configure_gptq_foak_plugin(): + "test foak plugin loads correctly" + + # test that provided configuration correct correct instantiates plugin + with instantiate_framework( + read_configuration(CONFIG_PATH_AUTO_GPTQ_FOAK), require_packages_check=False + ) as framework: + + # check flags and callbacks + assert framework.requires_custom_loading is False + assert framework.requires_agumentation + assert len(framework.get_callbacks_and_ready_for_train()) == 0 + + # attempt to activate plugin with configuration pointing to wrong path + # - raise with message that no plugins can be configured + with pytest.raises(ValueError) as e: + with instantiate_framework( + update_configuration_contents( + read_configuration(CONFIG_PATH_AUTO_GPTQ_FOAK), + "peft.quantization.fused_ops_and_kernels", + "something", + ), + ): + pass + + e.match("No plugins could be configured") + + # NOTE: currently only have all-or-one until address the generic patching + # rules + # attempt to actiavte plugin with unsupported settings + # - raise with appropriate message complaining about wrong setting + for key, wrong_value in [ + ("peft.quantization.fused_ops_and_kernels.fused_lora", False), + ("peft.quantization.fused_ops_and_kernels.fast_loss", False), + ("peft.quantization.fused_ops_and_kernels.fast_rsm_layernorm", False), + ("peft.quantization.fused_ops_and_kernels.fast_rope_embeddings", False), + ]: + with pytest.raises(AccelerationPluginConfigError) as e: + with instantiate_framework( + update_configuration_contents( + read_configuration(CONFIG_PATH_AUTO_GPTQ_FOAK), key, wrong_value + ), + ): + pass + + e.match(f"FastQuantizedPeftAccelerationPlugin: Value at '{key}'") diff --git a/plugins/fused-ops-and-kernels/tox.ini b/plugins/fused-ops-and-kernels/tox.ini new file mode 100644 index 00000000..8b6e5930 --- /dev/null +++ b/plugins/fused-ops-and-kernels/tox.ini @@ -0,0 +1,42 @@ +[tox] +envlist = py, lint + +[testenv] +deps = + pytest>=7 + -e {toxinidir}/../framework +commands = pytest {posargs:tests} + +[testenv:lint] +description = run linters +deps = + -e 
{toxinidir}/../framework + pylint>=2.16.2,<=3.1.0 +commands = pylint src tests +allowlist_externals = pylint + +[testenv:fmt] +description = format +skip_install = true +deps = + black>=22.12 + isort>=5.11 +commands = + # exclude the code ported from unsloth + black --exclude .*unsloth.* src + black --exclude .*unsloth.* tests + isort . + +# [testenv:build] +# description = build wheel +# deps = +# build +# commands = python -m build -w +# skip_install = True +# +# [testenv:twinecheck] +# description = check wheel +# deps = +# twine +# commands = twine check dist/* +# skip_install = True diff --git a/sample-configurations/CONTENTS.yaml b/sample-configurations/CONTENTS.yaml index 8d45bedf..c43a5adf 100644 --- a/sample-configurations/CONTENTS.yaml +++ b/sample-configurations/CONTENTS.yaml @@ -19,4 +19,10 @@ framework_configs: - shortname: baseline-peft-bnb plugins: - accelerated-peft - filename: baseline-peft-bnb-nf4-sample-configuration.yaml \ No newline at end of file + filename: baseline-peft-bnb-nf4-sample-configuration.yaml + + - shortname: accelerated-peft-autogptq-foak + plugins: + - accelerated-peft + - fused-ops-and-kernels + filename: accelerated-peft-autogptq-foak-sample-configuration.yaml \ No newline at end of file diff --git a/sample-configurations/accelerated-peft-autogptq-foak-sample-configuration.yaml b/sample-configurations/accelerated-peft-autogptq-foak-sample-configuration.yaml new file mode 100644 index 00000000..1eb38df3 --- /dev/null +++ b/sample-configurations/accelerated-peft-autogptq-foak-sample-configuration.yaml @@ -0,0 +1,44 @@ +# FMS Acceleration Plugin Configuration. +# +# Each stanza incorporates various configurations for +# different fine-tuning / training tasks. +plugins: + # PEFT-related acceleration + peft: + + # quantization-releated acceleration + # e.g., kernels for quantized base weights + quantization: + + # AutoGPTQ quantized base weights. + auto_gptq: + + # Kernel to be used for GPTQ linear laeyer + # NOTE: Not all kernels are suitable for PEFT training; need to use + # kernels that support autograd forward / backward. The best + # recommendation at the moment is "triton_v2". + kernel: triton_v2 + + # If true, then will already expect quantized checkpoint + # passed into TrainingArguments.model_name_or_path + from_quantized: true + fused_ops_and_kernels: + + # load unsloth optimizations for these 4bit base layer weights. + # currently only support "auto_gptq" and "bitsandbytes" + base_layer: auto_gptq + + # activate various unsloth optimizations + # NOTE: currently supports only all-or-nothing. 
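+      # i.e. all four flags below must currently be set to true; the plugin
+      # will reject the configuration if any one of them is disabled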
+ + # fused kernels for lora linear layers + fused_lora: true + + # fast loss triton kernels + fast_loss: true + + # fast rms norm triton kernels + fast_rsm_layernorm: true + + # fast RoPE embedding triton kernels + fast_rope_embeddings: true diff --git a/scripts/benchmarks/benchmark.py b/scripts/benchmarks/benchmark.py index 2594089d..651b227a 100644 --- a/scripts/benchmarks/benchmark.py +++ b/scripts/benchmarks/benchmark.py @@ -582,7 +582,7 @@ class DryRunExperiment(Experiment): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def run(self, run_cmd: str, environment_variables: Dict = None): + def run(self, run_cmd: str, environment_variables: Dict = None, **kwargs): def _dummy(*args, **kwargs): pass diff --git a/scripts/benchmarks/scenarios.yaml b/scripts/benchmarks/scenarios.yaml index 248eacb2..ae62930d 100644 --- a/scripts/benchmarks/scenarios.yaml +++ b/scripts/benchmarks/scenarios.yaml @@ -68,7 +68,7 @@ scenarios: - name: accelerated-peft-gptq framework_config: - - accelerated-peft-autogptq + - accelerated-peft-autogptq-foak arguments: learning_rate: 2e-4 fp16: True diff --git a/scripts/generate_sample_configurations.py b/scripts/generate_sample_configurations.py index 67ad4058..fd51d965 100644 --- a/scripts/generate_sample_configurations.py +++ b/scripts/generate_sample_configurations.py @@ -139,9 +139,10 @@ def read_configuration(path: str) -> Dict: # # NOTE: an augmentation (path, value) will augment a config at the # specified key path, with the value. -KEY_AUTO_GPTQ = "auto_gptq" +KEY_AUTO_GPTQ = "auto-gptq" KEY_BNB_NF4 = "bnb-nf4" KEY_BNB_NF4_BASELINE = "baseline-bnb-nf4" +KEY_AUTO_GPTQ_FOAK = "auto-gptq-foak" CONFIGURATIONS = { KEY_AUTO_GPTQ: "plugins/accelerated-peft/configs/autogptq.yaml", @@ -156,6 +157,10 @@ def read_configuration(path: str) -> Dict: ("peft.quantization.bitsandbytes.no_peft_model", True), ], ), + KEY_AUTO_GPTQ_FOAK: ( + "plugins/fused-ops-and-kernels/configs/fast_quantized_peft.yaml", + [("peft.quantization.fused_ops_and_kernels.base_layer", "auto_gptq")], + ), } # list of (tag, combi) tuples @@ -167,19 +172,24 @@ def read_configuration(path: str) -> Dict: ("accelerated-peft-autogptq", (KEY_AUTO_GPTQ,)), ("accelerated-peft-bnb-nf4", (KEY_BNB_NF4,)), ("baseline-peft-bnb-nf4", (KEY_BNB_NF4_BASELINE,)), + ("accelerated-peft-autogptq-foak", (KEY_AUTO_GPTQ, KEY_AUTO_GPTQ_FOAK)), ] - # TODO: throw error if merge conflicts def merge_configs(config_contents: List[Dict]): "helper function to merge configuration contents." 
# merge in place def _merge(result: Dict, new_contents: Dict): - for k in new_contents: + for k, v in new_contents.items(): if k not in result: - result[k] = {} - _merge(result[k], new_contents) + # if k is not in result, it means v does not + # exist as a subtree under result, so we just do + # an assingment + result[k] = v + else: + # otherwise we call the merge + _merge(result[k], v) if len(config_contents) == 0: return {} diff --git a/scripts/run_benchmarks.sh b/scripts/run_benchmarks.sh index 8cbd8587..798138bf 100644 --- a/scripts/run_benchmarks.sh +++ b/scripts/run_benchmarks.sh @@ -98,6 +98,8 @@ elif [ "$MEMORY_LOGGING" = "all" ]; then fi # dump out the environment +echo "Creating $RESULT_DIR" +mkdir -p $RESULT_DIR pip freeze > $PIP_REQUIREMENTS_FILE # run the bench diff --git a/tox.ini b/tox.ini index d719cb3e..e8d8aa92 100644 --- a/tox.ini +++ b/tox.ini @@ -36,6 +36,7 @@ commands = # install the plugins for test # NOTE: when there are more plugins install here python -m fms_acceleration.cli install -e {toxinidir}/plugins/accelerated-peft + python -m fms_acceleration.cli install -e {toxinidir}/plugins/fused-ops-and-kernels # run the benchmark script bash scripts/run_benchmarks.sh {posargs:"1 2" benchmark_outputs} From b2b8fe66bf4490367f36cb17b466d255c3d2e864 Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Thu, 30 May 2024 16:37:26 +0800 Subject: [PATCH 5/8] fix fsdp casting issue for autogptq and fused ops (#28) Signed-off-by: Yu Chin Fabian Lim --- .../fms_acceleration_peft/autogptq_utils.py | 48 ++++++++++++------- .../framework_plugin_autogptq.py | 8 ++-- .../src/fms_acceleration/framework.py | 11 +++-- .../src/fms_acceleration_foak/models/utils.py | 39 +++++++++++++++ scripts/benchmarks/scenarios.yaml | 1 + 5 files changed, 81 insertions(+), 26 deletions(-) diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py b/plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py index 31fc9a74..b8a7558d 100644 --- a/plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py +++ b/plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py @@ -24,6 +24,9 @@ from peft.tuners.lora.gptq import QuantLinear as LoraLinearGPTQ import torch +# these parameters are to be patched for triton v2 +# consider making a map if patching more kernels +PATCH_FOR_FSDP_TRITON_V2 = ["qweight", "qzeros"] # This function may be moved after merging # https://github.com/foundation-model-stack/fms-acceleration/pull/25 @@ -120,34 +123,47 @@ def create_new_module_peft( # if module cannot be found, return None which results in a raise in the call-stack return new_module - # consider to move this somewhere more general def patch_forward_to_view_attributes_before_call( old_forward: Callable, attribute_names: List[str], - torch_dtype, + torch_dtype: torch.dtype, + submodule_names: str = None, + is_method_forward: bool = True, ): # patch old_forward to view attribtues to torch_dype # before call + + if submodule_names is None: + submodule_names = '' + if isinstance(submodule_names, str): + submodule_names = [submodule_names] def _forward(self, *args, **kwargs): - # perform a view on all these attributes - for attr_name in attribute_names: - # the view should be a passthrough - # if attr.dtype == torch_dtype - attr = getattr(self, attr_name) + for sub_name in submodule_names: + mod = self.get_submodule(sub_name) + + # perform a view on all these attributes + for attr_name in attribute_names: + + # the view should be a passthrough + # if attr.dtype == 
torch_dtype + attr = getattr(mod, attr_name) - # perform view - attr = attr.view(torch_dtype) + # perform view + attr = attr.view(torch_dtype) - try: - setattr(self, attr_name, attr) - except TypeError: - # this means already have attr_name as a parameter, then - # just assign this way - self.__dict__[attr_name] = attr + try: + setattr(mod, attr_name, attr) + except TypeError: + # this means already have attr_name as a parameter, then + # just assign this way + mod.__dict__[attr_name] = attr - return old_forward(*args, **kwargs) + if is_method_forward: + # in this case, the self is already bound + return old_forward(*args, **kwargs) + return old_forward(self, *args, **kwargs) return _forward diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py index 62e7abe3..30492a2b 100644 --- a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py +++ b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py @@ -55,8 +55,9 @@ def model_loader(self, model_name: str, **kwargs): from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear #pylint: disable=import-outside-toplevel,import-error # Local - from .autogptq_utils import ( # pylint: disable=import-outside-toplevel - patch_forward_to_view_attributes_before_call, + from .autogptq_utils import ( #pylint: disable=import-outside-toplevel + patch_forward_to_view_attributes_before_call, + PATCH_FOR_FSDP_TRITON_V2 ) # Currently we allow only a quantized checkpoint to be loaded, we do not @@ -159,9 +160,6 @@ def model_loader(self, model_name: str, **kwargs): world_size > 1 and os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true" ): - # these parameters are to be patched for triton v2 - # consider making a map if patching more kernels - PATCH_FOR_FSDP_TRITON_V2 = ["qweight", "qzeros"] # patch all the QuantLinear base layers for mod in model.modules(): diff --git a/plugins/framework/src/fms_acceleration/framework.py b/plugins/framework/src/fms_acceleration/framework.py index 6d545ac7..ff83dd0c 100644 --- a/plugins/framework/src/fms_acceleration/framework.py +++ b/plugins/framework/src/fms_acceleration/framework.py @@ -179,11 +179,12 @@ def get_callbacks_and_ready_for_train( self, model: torch.nn.Module = None, accelerator: Accelerator = None ): # show the initialized message - log_initialization_message( - {x for x, _ in self.active_plugins}, - PLUGIN_REGISTRATIONS, - logging_func=logger.info, - ) + if accelerator is not None and accelerator.is_main_process: + log_initialization_message( + {x for x, _ in self.active_plugins}, + PLUGIN_REGISTRATIONS, + logging_func=logger.info, + ) cbks = [] for _, plugin in self.active_plugins: diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/utils.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/utils.py index 2840aacd..b048b8e4 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/utils.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/utils.py @@ -17,6 +17,7 @@ # Third Party import torch +import os # Local # GPTQ imports @@ -59,6 +60,7 @@ def _build_qkv_forwards( # - populates the triple Q, K, V # - subsequent calls will be a no-op until ALL Q, K, V get reset to None def _fused_op(X): + nonlocal Q, K, V if Q is None and K is None and V is None: Q, K, V = fused_operation(attn, X) @@ -115,6 +117,43 @@ def build_lora_fused_ops( if base_type == 
"auto_gptq": _qkv_fused_op = fused_op_qkv_gptq _o_fused_op = fused_op_o_gptq + + # this is required due to this FSDP fix + # https://github.com/foundation-model-stack/fms-acceleration/pull/15 + try: + world_size = torch.distributed.get_world_size() + except ValueError: + world_size = 1 # pg not init + + if ( + world_size > 1 + and os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true" + ): + + # guarded import + from fms_acceleration_peft.autogptq_utils import ( #pylint: disable=import-outside-toplevel + patch_forward_to_view_attributes_before_call, + PATCH_FOR_FSDP_TRITON_V2 + ) + + # patch each of the fused ops to view the attributes + # back into torch.int32 + # TODO: add the MLP fused op also + _qkv_fused_op = patch_forward_to_view_attributes_before_call( + _qkv_fused_op, + PATCH_FOR_FSDP_TRITON_V2, torch.int32, + submodule_names=[ + n + '.base_layer' for n in qkv_module_names + ], + is_method_forward=False, + ) + _o_fused_op = patch_forward_to_view_attributes_before_call( + _o_fused_op, + PATCH_FOR_FSDP_TRITON_V2, torch.int32, + submodule_names='base_layer', + is_method_forward=False, + ) + else: raise NotImplementedError( f"Cannot build fused ops for base type '{base_type}'." diff --git a/scripts/benchmarks/scenarios.yaml b/scripts/benchmarks/scenarios.yaml index ae62930d..c935ac31 100644 --- a/scripts/benchmarks/scenarios.yaml +++ b/scripts/benchmarks/scenarios.yaml @@ -68,6 +68,7 @@ scenarios: - name: accelerated-peft-gptq framework_config: + - accelerated-peft-autogptq - accelerated-peft-autogptq-foak arguments: learning_rate: 2e-4 From 810323822cf38f1f1a1ecae5363555cd75636714 Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Sun, 2 Jun 2024 14:46:07 +0800 Subject: [PATCH 6/8] Add MLP & QLoRA Fused Ops and Kernels, Mixtral (#29) * refactor Signed-off-by: Yu Chin Fabian Lim * fixes Signed-off-by: Yu Chin Fabian Lim * refactor mistral Signed-off-by: Yu Chin Fabian Lim * add mixtral Signed-off-by: Yu Chin Fabian Lim * some refactoring after introducing mlp Signed-off-by: Yu Chin Fabian Lim * remove extranous files Signed-off-by: Yu Chin Fabian Lim * add bnb Signed-off-by: Yu Chin Fabian Lim * lint + fmt and improvements to readme Signed-off-by: Yu Chin Fabian Lim * bench fixes * need to handle lora adapters device due to #26 * allow replay of failed benches, addressing comment in #14 * update benches (remove l40) --------- Signed-off-by: Yu Chin Fabian Lim --- README.md | 3 +- .../fms_acceleration_peft/autogptq_utils.py | 6 +- .../framework_plugin_autogptq.py | 25 +- plugins/fused-ops-and-kernels/README.md | 36 +-- .../framework_plugin_fast_quantized_peft.py | 21 +- .../fused_ops/unsloth_lora/bnb/fast_lora.py | 7 + .../fused_ops/unsloth_lora/gptq/fast_lora.py | 7 + .../kernels/unsloth/rope_embedding.py | 5 +- .../fms_acceleration_foak/models/__init__.py | 2 +- .../src/fms_acceleration_foak/models/llama.py | 62 ++++- .../fms_acceleration_foak/models/mistral.py | 72 ++++-- .../fms_acceleration_foak/models/mixtral.py | 104 ++++++++ .../models/model_patcher.py | 25 ++ .../src/fms_acceleration_foak/models/utils.py | 229 +++++++++--------- sample-configurations/CONTENTS.yaml | 8 +- ...eft-bnb-nf4-foak-sample-configuration.yaml | 44 ++++ scripts/benchmarks/benchmark.py | 26 +- scripts/benchmarks/display_bench_results.py | 40 ++- scripts/benchmarks/refs/a100_80gb.csv | 143 ++++++----- scripts/benchmarks/refs/l40_40gb.csv | 49 ---- scripts/benchmarks/scenarios.yaml | 3 +- scripts/generate_sample_configurations.py | 15 +- scripts/run_benchmarks.sh | 20 +- 23 files changed, 626 
insertions(+), 326 deletions(-) create mode 100644 plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mixtral.py create mode 100644 sample-configurations/accelerated-peft-bnb-nf4-foak-sample-configuration.yaml delete mode 100644 scripts/benchmarks/refs/l40_40gb.csv diff --git a/README.md b/README.md index a7534ed1..707c8662 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,7 @@ Plugin | Description | Depends | License | Status --|--|--|--|-- [framework](./plugins/framework/README.md) | This acceleration framework for integration with huggingface trainers | | | Beta [accelerated-peft](./plugins/accelerated-peft/README.md) | For PEFT-training, e.g., 4bit QLoRA. | Huggingface
AutoGPTQ | Apache 2.0
MIT | Beta -[fused-op-and-kernels](./plugins/fused-ops-and-kernels/README.md) | Fused LoRA and triton kernels (e.g., fast cross-entropy, rms, rope) | -- | Apache 2.0 with exclusions. | Coming Soon +[fused-op-and-kernels](./plugins/fused-ops-and-kernels/README.md) | Fused LoRA and triton kernels (e.g., fast cross-entropy, rms, rope) | -- | Apache 2.0 [(contains extracted code)](./plugins/fused-ops-and-kernels/README.md#code-extracted-from-unsloth)| Beta MOE-training-acceleration | [MegaBlocks](https://github.com/databricks/megablocks) inspired triton Kernels and acclerations for Mixture-of-Expert models | | Apache 2.0 | Coming Soon ## Usage with FMS HF Tuning @@ -174,7 +174,6 @@ The benchmarks can be reproduced [with the provided scripts](./scripts/benchmark See below CSV files for various results: - [A100-80GB](./scripts/benchmarks/refs/a100_80gb.csv). -- [L40-40GB](./scripts/benchmarks/refs/l40_40gb.csv). ### Code Architecture diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py b/plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py index b8a7558d..913a6b7e 100644 --- a/plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py +++ b/plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py @@ -28,6 +28,7 @@ # consider making a map if patching more kernels PATCH_FOR_FSDP_TRITON_V2 = ["qweight", "qzeros"] + # This function may be moved after merging # https://github.com/foundation-model-stack/fms-acceleration/pull/25 def _patch_target_module( @@ -123,6 +124,7 @@ def create_new_module_peft( # if module cannot be found, return None which results in a raise in the call-stack return new_module + # consider to move this somewhere more general def patch_forward_to_view_attributes_before_call( old_forward: Callable, @@ -133,9 +135,9 @@ def patch_forward_to_view_attributes_before_call( ): # patch old_forward to view attribtues to torch_dype # before call - + if submodule_names is None: - submodule_names = '' + submodule_names = "" if isinstance(submodule_names, str): submodule_names = [submodule_names] diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py index 30492a2b..7928d9a9 100644 --- a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py +++ b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py @@ -51,13 +51,18 @@ def __init__(self, configurations: Dict[str, Dict]): def model_loader(self, model_name: str, **kwargs): # guarded imports # Third Party - from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig #pylint: disable=import-outside-toplevel,import-error - from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear #pylint: disable=import-outside-toplevel,import-error + from auto_gptq import ( # pylint: disable=import-outside-toplevel,import-error + AutoGPTQForCausalLM, + BaseQuantizeConfig, + ) + from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import ( # pylint: disable=import-outside-toplevel,import-error + QuantLinear, + ) # Local - from .autogptq_utils import ( #pylint: disable=import-outside-toplevel - patch_forward_to_view_attributes_before_call, - PATCH_FOR_FSDP_TRITON_V2 + from .autogptq_utils import ( # pylint: disable=import-outside-toplevel + PATCH_FOR_FSDP_TRITON_V2, + patch_forward_to_view_attributes_before_call, ) # Currently we allow only a quantized checkpoint to be loaded, we do not @@ -214,8 +219,14 @@ def 
augmentation( ): # guarded imports # Third Party - from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear #pylint: disable=import-outside-toplevel,import-error - from auto_gptq.utils.peft_utils import GPTQLoraModel, get_gptq_peft_model #pylint: disable=import-outside-toplevel,import-error + from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import ( # pylint: disable=import-outside-toplevel,import-error + QuantLinear, + ) + from auto_gptq.utils.peft_utils import ( # pylint: disable=import-outside-toplevel,import-error + GPTQLoraModel, + get_gptq_peft_model, + ) + # Local from .autogptq_utils import ( # pylint: disable=import-outside-toplevel create_new_module_peft, diff --git a/plugins/fused-ops-and-kernels/README.md b/plugins/fused-ops-and-kernels/README.md index a1b01d94..a1777671 100644 --- a/plugins/fused-ops-and-kernels/README.md +++ b/plugins/fused-ops-and-kernels/README.md @@ -3,7 +3,7 @@ This library contains fused operations and custom kernels, to be expanded over time. Currently it contains the following: -1. Fused operations and kernels are extracted from [unsloth](#extracted-code-from-unsloth). +1. Fused operations and kernels extracted from [unsloth](#extracted-code-from-unsloth). - Low-Rank Adapter Fused Operations - Fast RoPE Triton Kernels - Fast RMS LayerNorm Triton Kernels @@ -13,42 +13,28 @@ This library contains fused operations and custom kernels, to be expanded over t Plugin | Description | Depends | Loading | Augmentation | Callbacks --|--|--|--|--|-- -[fast_quantized_peft](./src/fms_accelerate_foak/framework_plugin_fast_quantized_peft.py) | Loads fused lora, fast cross-entropy, fast rms, fast RoPE | | | ✅ +[fast_quantized_peft](./src/fms_accelerate_foak/framework_plugin_fast_quantized_peft.py) | LoRA fused ops, fast cross-entropy, fast rms, fast RoPE | Contains extracted code | | ✅ ### Code Extracted from Unsloth - Notes on the extraction of code from [unsloth](https://github.com/unslothai/unsloth): -- while unsloth is released under Apache 2.0, there are [exceptions to the permissive licenses scattered in the code base](https://github.com/unslothai/unsloth/blob/ec19e61c854dcf9104386fa63fc6c4f2944d4f35/unsloth/models/llama.py#L1140-L1143). +- While unsloth is [released under Apache 2.0](https://github.com/unslothai/unsloth/blob/main/LICENSE), there are comments indicating some exceptions strewn throughout the code base, see [an example here](https://github.com/unslothai/unsloth/blob/ec19e61c854dcf9104386fa63fc6c4f2944d4f35/unsloth/models/llama.py#L1140-L1143). ``` - it would require a commercial license if used to run on more than 4 GPUs, see - https://github.com/unslothai/unsloth/blob/d215fd902cf28feb8abcfde2d25281d0fbf9d28c/unsloth/models/llama.py#L1140-L1143 + it would require a commercial license if used to run on more than 4 GPUs ... ``` -- these exceptions appear around [Feb 2024 Release](https://github.com/unslothai/unsloth/commit/3e4c5a323c16bbda2c92212b790073c4e99c2a55), around the model files (namely `llama.py`, `mistral.py`, etc). - * These model files are **not extracted**. -- All code extracted here before the Feb 2024 Release, see table below. +- These exceptions appear to be located around the trainer improvements, see [another example here](https://github.com/unslothai/unsloth/blob/ec19e61c854dcf9104386fa63fc6c4f2944d4f35/unsloth/models/llama.py#L1177-L1183). 
+- These exceptions appear around the [Feb 2024 Release](https://github.com/unslothai/unsloth/commit/3e4c5a323c16bbda2c92212b790073c4e99c2a55); any code that appears in any file where such exceptions occur **is not extracted**.
+- In its place, we adopt a model-patching approach, as opposed to unsloth's approach of rewriting the model. Our approach is novel and **completely rewritten from scratch**.
+- All extracted code appears before the Feb 2024 Release.
+- In the table below we record what was extracted, and the exact commit from which it was taken.
Path | Description | Extracted From | Modifications | Date
--|--|--|--|--
[fused_ops/unsloth_lora](./src/fms_acceleration_foak/fused_ops/unsloth_lora) | QLoRA fast dequant, activation kernels | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc) | | 28 Jan 2024
-[fused_ops/unsloth_lora/bnb](./src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb) | BNB fast lora | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc) | | 28 Jan 2024
+[fused_ops/unsloth_lora/bnb](./src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb) | BNB fast lora | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc) | `fast_lora.py` | 28 Jan 2024
[fused_ops/unsloth_lora/gptq](./src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq) | GPTQ fast dequant (triton_v2) | `jeromeku/main` @ [2839d39](https://github.com/jeromeku/unsloth/commit/2839d390ef3bb318904289bfb9a7751a782c4e44) | `fast_lora.py`
`triton/layers.py` | 6 Feb 2024 -[kernels/unsloth](./src/fms_acceleration_foak/kernels/unsloth) | Fast RMS, RoPE, CrossEnt kernels | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc) | `cross_entropy_loss.py` | 28 Jan 2024 - - - +[kernels/unsloth](./src/fms_acceleration_foak/kernels/unsloth) | Fast RMS, RoPE, CrossEnt kernels | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc) | `cross_entropy_loss.py`
`rms_layernorm.py` | 28 Jan 2024 ## Known Issues diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py index ad0a399c..7eab87f0 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py @@ -16,6 +16,7 @@ from typing import Callable, Dict, Tuple # Third Party +from accelerate.utils import set_module_tensor_to_device from fms_acceleration import AccelerationPlugin from peft import LoraConfig from peft.tuners.lora.layer import LoraLayer @@ -63,9 +64,20 @@ def _all_reduce_hook(grad): return grad for mod in modules: + # NOTE: assuming lora has no bias + A = mod.lora_A.default + B = mod.lora_B.default + # install hooks on the adapters - mod.lora_A.default.weight.register_hook(_all_reduce_hook) - mod.lora_B.default.weight.register_hook(_all_reduce_hook) + A.weight.register_hook(_all_reduce_hook) + B.weight.register_hook(_all_reduce_hook) + + # because we will ignore these from FSDP, we need to manually + # move them to gpu if they are already not on them + if not A.weight.is_cuda: + set_module_tensor_to_device(A, "weight", "cuda") + if not B.weight.is_cuda: + set_module_tensor_to_device(B, "weight", "cuda") class FastQuantizedPeftAccelerationPlugin(AccelerationPlugin): @@ -82,10 +94,7 @@ def __init__(self, configurations: Dict[str, Dict]): self._base_layer = self._check_config_and_maybe_check_values( key="peft.quantization.fused_ops_and_kernels.base_layer", - values=[ - "auto_gptq", - # "bitsandbytes" # enable later when we have BNB implemented - ], + values=["auto_gptq", "bitsandbytes"], ) # only support these at the moment diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb/fast_lora.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb/fast_lora.py index 82f78f74..71d7070c 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb/fast_lora.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb/fast_lora.py @@ -394,3 +394,10 @@ def apply_lora_o(self, X): O = LoRA_W.apply(X, OW, OW_quant, OA, OB, OS) return O pass + +# added by flim@sg.ibm.com +# this will be patchable on the actual module +def apply_lora_o_v2(self, X): + OW, OW_quant, OA, OB, OS = get_lora_parameters(self) + O = LoRA_W.apply(X, OW, OW_quant, OA, OB, OS) + return O \ No newline at end of file diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/fast_lora.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/fast_lora.py index 3808fba7..ee5055ed 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/fast_lora.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/fast_lora.py @@ -735,3 +735,10 @@ def apply_lora_o(self, X): Oqstate, OA, OB, OS = get_lora_parameters(self.o_proj) O = LoRA_W.apply(X, *unpack_gptqstate(Oqstate), OA, OB, OS) return O + +# added by flim@sg.ibm.com +# this version can be directly patched on the output linear +def apply_lora_o_v2(self, X): + Oqstate, OA, OB, OS = get_lora_parameters(self) + O = LoRA_W.apply(X, *unpack_gptqstate(Oqstate), OA, OB, OS) + return O diff --git 
a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/rope_embedding.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/rope_embedding.py index 49b04fce..3577b586 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/rope_embedding.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/rope_embedding.py @@ -130,8 +130,9 @@ def backward(ctx, dY): pass pass - -def fast_rope_embedding(Q, K, cos, sin): +# modified by flim@sg.ibm.com +# NOTE: fast_rope embeddings currently does not account for position ids +def fast_rope_embedding(Q, K, cos, sin, position_ids=None): Q = Fast_RoPE_Embedding.apply(Q.transpose(1, 2), cos, sin).transpose(1, 2) K = Fast_RoPE_Embedding.apply(K.transpose(1, 2), cos, sin).transpose(1, 2) return Q, K diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/__init__.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/__init__.py index 7d6df3bc..ebd49924 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/__init__.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/__init__.py @@ -15,7 +15,7 @@ # Local from .model_patcher import ModelPatcher -PATCHES = [".models.llama", ".models.mistral"] +PATCHES = [".models.llama", ".models.mistral", ".models.mixtral"] PLUGIN_PREFIX = "fms_acceleration_foak" # TODO: remove the need for the prefix diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py index 3d01311a..290d1217 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py @@ -16,14 +16,24 @@ from functools import partial # Third Party -from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaMLP, + LlamaRMSNorm, +) # Local from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm from ..kernels.unsloth.rope_embedding import fast_rope_embedding -from .model_patcher import ModelPatcher, ModelPatcherRule, ModelPatcherTrigger -from .utils import build_lora_fused_ops, trigger_fused_ops +from .model_patcher import ( + ModelPatcher, + ModelPatcherRule, + ModelPatcherTrigger, + combine_functions, + combine_triggers, +) +from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops # TODO: have a generic version of this rule # - do regex on RMSNorm class name @@ -42,18 +52,54 @@ ModelPatcher.register( ModelPatcherRule( rule_id="llama-qkvo", + trigger=combine_triggers( + ModelPatcherTrigger( + check=partial( + trigger_fused_ops, + attn_cls=LlamaAttention, + submodule_names=["q_proj", "k_proj", "v_proj"], + ) + ), + ModelPatcherTrigger( + check=partial( + trigger_fused_ops, + attn_cls=LlamaAttention, + submodule_names=["o_proj"], + ) + ), + logic="OR", + ), + forward_builder=combine_functions( + partial( + build_lora_fused_ops, + submodule_names=["q_proj", "k_proj", "v_proj"], + fused_op=KEY_QKV, + ), + partial( + build_lora_fused_ops, + submodule_names=["o_proj"], + fused_op=KEY_O, + ), + logic="APPEND", + ), + forward_builder_args=["base_type"], + ) +) + +ModelPatcher.register( + ModelPatcherRule( + rule_id="llama-mlp", trigger=ModelPatcherTrigger( check=partial( trigger_fused_ops, - 
attn_cls=LlamaAttention, - qkv_module_names=["q_proj", "k_proj", "v_proj"], - o_module_name="o_proj", + attn_cls=LlamaMLP, + submodule_names=["up_proj", "down_proj", "gate_proj"], ) ), forward_builder=partial( build_lora_fused_ops, - qkv_module_names=["q_proj", "k_proj", "v_proj"], - o_module_name="o_proj", + submodule_names=["up_proj", "down_proj", "gate_proj"], + fused_op=KEY_MLP, ), forward_builder_args=["base_type"], ) diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py index a8e6795f..37809fd1 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py @@ -18,22 +18,22 @@ # Third Party from transformers.models.mistral.modeling_mistral import ( MistralAttention, + MistralMLP, MistralRMSNorm, ) # Local from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm -from ..kernels.unsloth.rope_embedding import fast_rope_embedding as _fast_rope_embedding -from .model_patcher import ModelPatcher, ModelPatcherRule, ModelPatcherTrigger -from .utils import build_lora_fused_ops, trigger_fused_ops - - -# NOTE: fast_rope_embedding does not work with position_ids -# currently they are ignored -def fast_rope_embedding(Q, K, cos, sin, position_ids=None): - return _fast_rope_embedding(Q, K, cos, sin) - +from ..kernels.unsloth.rope_embedding import fast_rope_embedding +from .model_patcher import ( + ModelPatcher, + ModelPatcherRule, + ModelPatcherTrigger, + combine_functions, + combine_triggers, +) +from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops # - do regex on RMSNorm class name # - check on the tensors required for fast_rms_layernorm @@ -45,29 +45,62 @@ def fast_rope_embedding(Q, K, cos, sin, position_ids=None): ), ) -# - do regex on Attention class name -# - have a set of qkv / o module names and check on that ModelPatcher.register( ModelPatcherRule( rule_id="mistral-qkvo", + trigger=combine_triggers( + ModelPatcherTrigger( + check=partial( + trigger_fused_ops, + attn_cls=MistralAttention, + submodule_names=["q_proj", "k_proj", "v_proj"], + ) + ), + ModelPatcherTrigger( + check=partial( + trigger_fused_ops, + attn_cls=MistralAttention, + submodule_names=["o_proj"], + ) + ), + logic="OR", + ), + forward_builder=combine_functions( + partial( + build_lora_fused_ops, + submodule_names=["q_proj", "k_proj", "v_proj"], + fused_op=KEY_QKV, + ), + partial( + build_lora_fused_ops, + submodule_names=["o_proj"], + fused_op=KEY_O, + ), + logic="APPEND", + ), + forward_builder_args=["base_type"], + ) +) + +ModelPatcher.register( + ModelPatcherRule( + rule_id="mistral-mlp", trigger=ModelPatcherTrigger( check=partial( trigger_fused_ops, - attn_cls=MistralAttention, - qkv_module_names=["q_proj", "k_proj", "v_proj"], - o_module_name="o_proj", + attn_cls=MistralMLP, + submodule_names=["up_proj", "down_proj", "gate_proj"], ) ), forward_builder=partial( build_lora_fused_ops, - qkv_module_names=["q_proj", "k_proj", "v_proj"], - o_module_name="o_proj", + submodule_names=["up_proj", "down_proj", "gate_proj"], + fused_op=KEY_MLP, ), forward_builder_args=["base_type"], ) ) -# - get the module_name and reload on that ModelPatcher.register( ModelPatcherRule( rule_id="mistral-cross-ent", @@ -79,9 +112,6 @@ def fast_rope_embedding(Q, K, cos, sin, position_ids=None): ) ) -# - get the module name -# - check if 
"apply_rotary_pos_emb" exists -# - patch ModelPatcher.register( ModelPatcherRule( rule_id="mistral-rope", diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mixtral.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mixtral.py new file mode 100644 index 00000000..1522ef8d --- /dev/null +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mixtral.py @@ -0,0 +1,104 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +from functools import partial + +# Third Party +from transformers.models.mixtral.modeling_mixtral import ( + MixtralAttention, + MixtralRMSNorm, +) + +# Local +from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss +from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm +from ..kernels.unsloth.rope_embedding import fast_rope_embedding +from .model_patcher import ( + ModelPatcher, + ModelPatcherRule, + ModelPatcherTrigger, + combine_functions, + combine_triggers, +) +from .utils import KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops + +# - do regex on RMSNorm class name +# - check on the tensors required for fast_rms_layernorm +ModelPatcher.register( + ModelPatcherRule( + rule_id="mixtral-rms", + trigger=ModelPatcherTrigger(check=MixtralRMSNorm), + forward=fast_rms_layernorm, + ), +) + +ModelPatcher.register( + ModelPatcherRule( + rule_id="mixtral-qkvo", + trigger=combine_triggers( + ModelPatcherTrigger( + check=partial( + trigger_fused_ops, + attn_cls=MixtralAttention, + submodule_names=["q_proj", "k_proj", "v_proj"], + ) + ), + ModelPatcherTrigger( + check=partial( + trigger_fused_ops, + attn_cls=MixtralAttention, + submodule_names=["o_proj"], + ) + ), + logic="OR", + ), + forward_builder=combine_functions( + partial( + build_lora_fused_ops, + submodule_names=["q_proj", "k_proj", "v_proj"], + fused_op=KEY_QKV, + ), + partial( + build_lora_fused_ops, + submodule_names=["o_proj"], + fused_op=KEY_O, + ), + logic="APPEND", + ), + forward_builder_args=["base_type"], + ) +) + +ModelPatcher.register( + ModelPatcherRule( + rule_id="mixtral-cross-ent", + import_and_maybe_reload=( + "torch.nn.CrossEntropyLoss", + FastCrossEntropyLoss, + "transformers.models.mixtral.modeling_mixtral", + ), + ) +) + +ModelPatcher.register( + ModelPatcherRule( + rule_id="mixtral-rope", + import_and_maybe_reload=( + "transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb", + fast_rope_embedding, + None, + ), + ) +) diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/model_patcher.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/model_patcher.py index 3355aa67..7f803330 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/model_patcher.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/model_patcher.py @@ -468,3 +468,28 @@ def patch_model(model: torch.nn.Module, **kwargs): def patch_model_summary(): return ModelPatcher.summary() + + +def combine_triggers(*triggers: 
ModelPatcherTrigger, logic: str = "OR"): + assert logic == "OR", "only OR logic implemented for combining triggers" + + # NOTE: this can be probably simplified + def _or_logic(*args, **kwargs): + for trig in triggers: + if trig.check(*args, **kwargs): + return True + return False + + return ModelPatcherTrigger(check=_or_logic) + + +def combine_functions(*funcs: Callable, logic: str = "APPEND"): + assert logic == "APPEND", "only APPEND logic implemented for combining functions" + + def _append(*args, **kwargs): + results = [] + for f in funcs: + results += f(*args, **kwargs) + return results + + return _append diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/utils.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/utils.py index b048b8e4..10819fc0 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/utils.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/utils.py @@ -1,34 +1,40 @@ -# Copyright The FMS HF Tuning Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - # Standard +from functools import partial from typing import Callable, List, Type +import os # Third Party import torch -import os # Local -# GPTQ imports -from ..fused_ops.unsloth_lora.gptq.fast_lora import LoRA_W as LoRA_W_gptq -from ..fused_ops.unsloth_lora.gptq.fast_lora import apply_lora_qkv as fused_op_qkv_gptq -from ..fused_ops.unsloth_lora.gptq.fast_lora import ( - get_lora_parameters as get_lora_parameters_gptq, +# NOTE: the default activation is swiglu in both cases +from ..fused_ops.unsloth_lora.bnb.fast_lora import ( + apply_lora_mlp_swiglu as fused_op_mlp_bnb, ) -from ..fused_ops.unsloth_lora.gptq.fast_lora import unpack_gptqstate +from ..fused_ops.unsloth_lora.bnb.fast_lora import apply_lora_o_v2 as fused_op_o_bnb +from ..fused_ops.unsloth_lora.bnb.fast_lora import apply_lora_qkv as fused_op_qkv_bnb +from ..fused_ops.unsloth_lora.gptq.fast_lora import apply_lora_mlp as fused_op_mlp_gptq +from ..fused_ops.unsloth_lora.gptq.fast_lora import apply_lora_o_v2 as fused_op_o_gptq +from ..fused_ops.unsloth_lora.gptq.fast_lora import apply_lora_qkv as fused_op_qkv_gptq from .model_patcher import ModelPatcherTrigger +KEY_QKV = "qkv" +KEY_O = "o" +KEY_MLP = "mlp" + +FUSED_OPS = { + "auto_gptq": { + KEY_QKV: fused_op_qkv_gptq, + KEY_O: fused_op_o_gptq, + KEY_MLP: fused_op_mlp_gptq, + }, + "bitsandbytes": { + KEY_QKV: fused_op_qkv_bnb, + KEY_O: fused_op_o_bnb, + KEY_MLP: fused_op_mlp_bnb, + }, +} + # simple utility function to guess if its lora layer def _is_loralayer(module: torch.nn.Module, names: List[str] = None): @@ -45,15 +51,15 @@ def _is_loralayer(module: torch.nn.Module, names: List[str] = None): # modules are called q_proj, k_proj, and v_proj, respectively. # the fused operation can be changed, depending on what the base layer is # i.e. 
gptq or bnb -def _build_qkv_forwards( +def _build_fused_forwards( attn: torch.nn.Module, fused_operation: Callable = fused_op_qkv_gptq, - module_names: List[str] = None, + submodule_names: List[str] = None, ): - if module_names is None: - module_names = ["q_proj", "k_proj", "v_proj"] + # fused opts expected to produce singular or multiple results + # module names must be passed in order of what the fused - Q = K = V = None + outs = {} # the fused operation will be called on first one that passes in the # input X. @@ -61,62 +67,52 @@ def _build_qkv_forwards( # - subsequent calls will be a no-op until ALL Q, K, V get reset to None def _fused_op(X): - nonlocal Q, K, V - if Q is None and K is None and V is None: - Q, K, V = fused_operation(attn, X) + # if all of the outs are not yet populated + if all(x not in outs for x in submodule_names): + fused_outs = fused_operation(attn, X) + try: + fused_outs = list(fused_outs) # not sure if this is correct + except TypeError: + # if fused_outs is not iterable + fused_outs = [fused_outs] + for n, x in zip(submodule_names, fused_outs): + outs[n] = x # each of these functions # - calls the fused op # - - error_msg = ( - "QKV fused_op needs to be first reset with sequential calls to each of them" - ) - - def _forward_q(self, X): - nonlocal Q - _fused_op(X) - assert Q is not None, error_msg - out, Q = Q, None # unload - return out - - def _forward_k(self, X): - nonlocal K - _fused_op(X) - assert K is not None, error_msg - out, K = K, None # unload - return out - def _forward_v(self, X): - nonlocal V + def _forward(self, X, name: str): _fused_op(X) - assert V is not None, error_msg - out, V = V, None # unload - return out - - return zip(module_names, [_forward_q, _forward_k, _forward_v]) - + assert ( + name in outs + ), "Fused_op needs to be first reset with sequential calls to each of them" + V = outs[name] + del outs[name] + return V -# fused ops for outputs for GPTQ -def fused_op_o_gptq(self, X): - Oqstate, OA, OB, OS = get_lora_parameters_gptq(self) - O = LoRA_W_gptq.apply(X, *unpack_gptqstate(Oqstate), OA, OB, OS) - return O + return zip(submodule_names, [partial(_forward, name=n) for n in submodule_names]) -# TODO: add the MLP def build_lora_fused_ops( attn: torch.nn.Module, base_type: str = "auto_gptq", - qkv_module_names: List[str] = None, - o_module_name: str = "o_proj", + submodule_names: List[str] = None, + fused_op: str = KEY_QKV, ): - if qkv_module_names is None: - qkv_module_names = ["q_proj", "k_proj", "v_proj"] - # handle the QKVs + assert ( + len(submodule_names) > 0 + ), "When building lora fused ops requires more than one submodule." 
+ + if submodule_names is None: + submodule_names = ["q_proj", "k_proj", "v_proj"] + + # get the fused op + fused_operation = FUSED_OPS[base_type][fused_op] + + # handle casting issues if base_type == "auto_gptq": - _qkv_fused_op = fused_op_qkv_gptq - _o_fused_op = fused_op_o_gptq # this is required due to this FSDP fix # https://github.com/foundation-model-stack/fms-acceleration/pull/15 @@ -131,55 +127,60 @@ def build_lora_fused_ops( ): # guarded import - from fms_acceleration_peft.autogptq_utils import ( #pylint: disable=import-outside-toplevel - patch_forward_to_view_attributes_before_call, - PATCH_FOR_FSDP_TRITON_V2 + # pylint: disable=import-outside-toplevel,import-error + # Third Party + from fms_acceleration_peft.autogptq_utils import ( + PATCH_FOR_FSDP_TRITON_V2, + patch_forward_to_view_attributes_before_call, ) # patch each of the fused ops to view the attributes # back into torch.int32 - # TODO: add the MLP fused op also - _qkv_fused_op = patch_forward_to_view_attributes_before_call( - _qkv_fused_op, - PATCH_FOR_FSDP_TRITON_V2, torch.int32, - submodule_names=[ - n + '.base_layer' for n in qkv_module_names - ], - is_method_forward=False, - ) - _o_fused_op = patch_forward_to_view_attributes_before_call( - _o_fused_op, - PATCH_FOR_FSDP_TRITON_V2, torch.int32, - submodule_names='base_layer', + # - if there are multiple submodules, then we assume that + # 'fused_operation' will be called on module that has + # submodules specified in 'submodule_names'. + # - otherwise if there is only a single 'submodule_name', then + # assume that 'fused_operation' called on the submodule specified + # by 'submodule_name' itself + if len(submodule_names) > 1: + patched_submodule_names = [n + ".base_layer" for n in submodule_names] + else: + # otherwise assume calling on the 'submodule_name' itself + # so its just the base layer. + patched_submodule_names = "base_layer" + + fused_operation = patch_forward_to_view_attributes_before_call( + fused_operation, + PATCH_FOR_FSDP_TRITON_V2, + torch.int32, + submodule_names=patched_submodule_names, is_method_forward=False, ) - else: - raise NotImplementedError( - f"Cannot build fused ops for base type '{base_type}'." 
- ) - - trigger_and_forwards = [ - (ModelPatcherTrigger(check=_is_loralayer, module_name=name), forward) - for name, forward in _build_qkv_forwards( - attn, - fused_operation=_qkv_fused_op, - module_names=qkv_module_names, - ) - ] - - # handle the self-attn output - _output_module = getattr(attn, o_module_name) - if _is_loralayer(_output_module): - trigger_and_forwards.append( + if fused_op == KEY_QKV: + return [ + (ModelPatcherTrigger(check=_is_loralayer, module_name=name), forward) + for name, forward in _build_fused_forwards( + attn, + fused_operation=fused_operation, + submodule_names=submodule_names, + ) + ] + if fused_op == KEY_O: + # otherwise its just a single op + submodule_names = submodule_names[0] + return [ ( - ModelPatcherTrigger(check=_is_loralayer, module_name=o_module_name), - _o_fused_op, + ModelPatcherTrigger(check=_is_loralayer, module_name=submodule_names), + fused_operation, ) - ) + ] + if fused_op == KEY_MLP: + # otherwise just return the fused_op that should be attached at the + # top MLP level + return fused_operation - # return - return trigger_and_forwards + raise NotImplementedError(f"Unknown fused op '{fused_op}'") # trigger if either of the conditions are met @@ -188,16 +189,10 @@ def build_lora_fused_ops( def trigger_fused_ops( module: torch.nn.Module, attn_cls: Type, - qkv_module_names: List[str] = None, - o_module_name: str = "o_proj", + submodule_names: List[str], ): - if qkv_module_names is None: - qkv_module_names = ["q_proj", "k_proj", "v_proj"] - - _o = getattr(module, o_module_name) - _qkv = [getattr(module, x) for x in qkv_module_names] - # trigger on the attention layer - return isinstance(module, attn_cls) and ( - all(_is_loralayer(x) for x in _qkv) or _is_loralayer(_o) - ) + # trigger if the module meets the attn class and the submodules + # are all loralayers + _mods = [getattr(module, x) for x in submodule_names] + return isinstance(module, attn_cls) and all(_is_loralayer(x) for x in _mods) diff --git a/sample-configurations/CONTENTS.yaml b/sample-configurations/CONTENTS.yaml index c43a5adf..75f7279b 100644 --- a/sample-configurations/CONTENTS.yaml +++ b/sample-configurations/CONTENTS.yaml @@ -25,4 +25,10 @@ framework_configs: plugins: - accelerated-peft - fused-ops-and-kernels - filename: accelerated-peft-autogptq-foak-sample-configuration.yaml \ No newline at end of file + filename: accelerated-peft-autogptq-foak-sample-configuration.yaml + + - shortname: accelerated-peft-bnb-foak + plugins: + - accelerated-peft + - fused-ops-and-kernels + filename: accelerated-peft-bnb-nf4-foak-sample-configuration.yaml \ No newline at end of file diff --git a/sample-configurations/accelerated-peft-bnb-nf4-foak-sample-configuration.yaml b/sample-configurations/accelerated-peft-bnb-nf4-foak-sample-configuration.yaml new file mode 100644 index 00000000..fcb9bb14 --- /dev/null +++ b/sample-configurations/accelerated-peft-bnb-nf4-foak-sample-configuration.yaml @@ -0,0 +1,44 @@ +# FMS Acceleration Plugin Configuration. +# +# Each stanza incorporates various configurations for +# different fine-tuning / training tasks. +plugins: + # PEFT-related acceleration + peft: + + # quantization-releated acceleration + # e.g., kernels for quantized base weights + quantization: + + # For loading BitsAndBytes quantized layers + # to serve as 4bit base-weights for LoRA PEFT-tuning. + # NOTE: currently AutoGPTQ is not properly integrated into huggingface / + # bitsandbytes, thus recommended quant_type to be either "nf4" + # or "fp4". 
+ # bitsandbytes: + bitsandbytes: + quant_type: nf4 + + # If True, then no get_peft_model and prepare_model_for_kbit_training + # will be called. + no_peft_model: false + fused_ops_and_kernels: + + # load unsloth optimizations for these 4bit base layer weights. + # currently only support "auto_gptq" and "bitsandbytes" + base_layer: bitsandbytes + + # activate various unsloth optimizations + # NOTE: currently supports only all-or-nothing. + + # fused kernels for lora linear layers + fused_lora: true + + # fast loss triton kernels + fast_loss: true + + # fast rms norm triton kernels + fast_rsm_layernorm: true + + # fast RoPE embedding triton kernels + fast_rope_embeddings: true diff --git a/scripts/benchmarks/benchmark.py b/scripts/benchmarks/benchmark.py index 651b227a..f5ff4a54 100644 --- a/scripts/benchmarks/benchmark.py +++ b/scripts/benchmarks/benchmark.py @@ -1,5 +1,6 @@ # Standard from itertools import product +from time import sleep from typing import Any, Callable, Dict, List, Tuple, Union import argparse import json @@ -88,6 +89,7 @@ HF_ARG_SKIP_MEMORY_METRIC = "--skip_memory_metrics" RESULT_FIELD_ALLOCATED_GPU_MEM = "mem_torch_mem_alloc_in_bytes" RESULT_FIELD_PEAK_ALLOCATED_GPU_MEM = "mem_peak_torch_mem_alloc_in_bytes" +ERROR_MESSAGES = "error_messages" def extract_gpu_memory_metrics(output_metrics) -> Tuple[float]: @@ -357,6 +359,17 @@ def __init__( self.results_filename = os.path.join(self.save_dir, FILE_RESULTS) self.gpu_log_filename = os.path.join(self.save_dir, FILE_MEM) + @property + def is_completed(self): + if not os.path.exists(self.results_filename): + return False + # otherwise open it and check for errors + with open(self.results_filename) as f: + results = json.load(f) + + # return complete only if no errors + return not ERROR_MESSAGES in results + def run( self, run_cmd: str, @@ -552,7 +565,7 @@ def write_result(self): **self.get_experiment_final_metrics(), } else: - other_results = {"error_messages": maybe_error_messages} + other_results = {ERROR_MESSAGES: maybe_error_messages} # combine the final thing save_result = {**save_result, **other_results} @@ -781,6 +794,14 @@ def main(args): log_memory_in_trainer=args.log_memory_hf, ) ): + # store pointer to file for future result retrival + experiment_stats[experiment.tag] = experiment.results_filename + + if experiment.is_completed: + # if completed, dont proceed + sleep(0.1) # sleep a bit to allow the tqdm to update + continue + if experiment.num_gpus > 1: prefix = COMMAND_ACCELERATE.format( accelerate_config_path=args.accelerate_config, @@ -806,10 +827,9 @@ def main(args): log_nvidia_smi=args.log_nvidia_smi, ) - # write results and store pointers to files + # write results experiment.write_result() experiment.write_shell_command() - experiment_stats[experiment.tag] = experiment.results_filename # 4. 
Consolidates the experiment results into a summary for tag, path in experiment_stats.items(): diff --git a/scripts/benchmarks/display_bench_results.py b/scripts/benchmarks/display_bench_results.py index 1de9b2a5..51ba5642 100644 --- a/scripts/benchmarks/display_bench_results.py +++ b/scripts/benchmarks/display_bench_results.py @@ -1,18 +1,21 @@ # Standard +from typing import List import argparse # First Party # import this because of alot of internal contants -from scripts.benchmarks.benchmark import gather_report, DIR_SAMP_CONFIGS -from typing import List +from scripts.benchmarks.benchmark import DIR_SAMP_CONFIGS, gather_report -def main(*directories: str, output_filename: str = "results.csv", remove_columns: List[str] = None): + +def main( + *directories: str, + output_filename: str = "results.csv", + remove_columns: List[str] = None, + keep_columns: List[str] = None, +): "gather outputs from a list of directories and output to a csv" - df, constant = gather_report(*directories, raw=False) - # filter result columns to keep by the inverse of remove_columns - if remove_columns: - df = df[df.columns[~df.columns.isin(remove_columns)]] + df, constant = gather_report(directories, raw=False) errors = [] try: @@ -22,12 +25,25 @@ def main(*directories: str, output_filename: str = "results.csv", remove_columns df = df.loc[df.error_messages.isna()] except: pass + + # filter result columns to keep by the inverse of remove_columns + if remove_columns: + df = df[df.columns[~df.columns.isin(remove_columns)]] + + # assume keep and remove are disjoint + kept = 0 + if keep_columns: + for c in keep_columns: + if c in constant: + df[c] = constant[c] + kept += 1 + df = df.reset_index(drop=True).drop("output_dir", axis=1) df.reindex(sorted(df.columns), axis=1).to_csv(output_filename, index=False) print("***************** Report Created ******************") print(f"Total lines: '{len(df)}'") print(f"Number columns included: '{len(df.columns)}'") - print(f"Number columns excluded: '{len(constant)}'") + print(f"Number columns excluded: '{len(constant)-kept}'") print(f"Excluding number of exceptions caught: '{len(errors)}'") print(f"Written report to '{output_filename}'") @@ -53,10 +69,16 @@ def main(*directories: str, output_filename: str = "results.csv", remove_columns nargs="*", help="list of columns to ignore from results.csv", ) + parser.add_argument( + "--keep_columns", + nargs="*", + help="list of columns to always include into results.csv", + ) args = parser.parse_args() main( - args.bench_outputs, + *args.bench_outputs, output_filename=args.result_file, remove_columns=args.remove_columns, + keep_columns=args.keep_columns, ) diff --git a/scripts/benchmarks/refs/a100_80gb.csv b/scripts/benchmarks/refs/a100_80gb.csv index 4434d864..b83549a7 100644 --- a/scripts/benchmarks/refs/a100_80gb.csv +++ b/scripts/benchmarks/refs/a100_80gb.csv @@ -1,61 +1,82 @@ -epoch,fp16,framework_config,index,learning_rate,lora_alpha,lora_dropout,model_name_or_path,num_gpus,nvidia_mem_reserved,peak_torch_mem_alloc_in_bytes,peft_method,per_device_train_batch_size,r,target_modules,torch_mem_alloc_in_bytes,train_loss,train_runtime,train_samples_per_second,train_steps_per_second,train_tokens_per_second -0.04,,none,0,2e-5,,,mistralai/Mistral-7B-v0.1,1,77705.0,72971724288.0,,4,,,44004763136.0,0.9278398831685384,177.1092,0.678,0.169,2775.237 -0.04,,none,1,2e-5,,,mistralai/Mistral-7B-v0.1,2,44706.0,36762859520.0,,2,,,29521119232.0,0.8970902442932129,91.086,1.317,0.329,2698.11 
-0.09,,none,2,2e-5,,,mistralai/Mistral-7B-v0.1,1,74383.0,72972117504.0,,8,,,44005156352.0,0.9879656155904134,322.458,0.744,0.093,3048.583 -0.09,,none,3,2e-5,,,mistralai/Mistral-7B-v0.1,2,53907.0,36763056128.0,,4,,,29521315840.0,0.9259945551554362,167.7727,1.431,0.179,2929.678 -,,none,4,2e-5,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,81043.0,,,4,,,,,,,, -,,none,5,2e-5,,,mistralai/Mixtral-8x7B-Instruct-v0.1,2,79353.0,,,2,,,,,,,, -,,none,6,2e-5,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,81043.0,,,8,,,,,,,, -,,none,7,2e-5,,,mistralai/Mixtral-8x7B-Instruct-v0.1,2,79827.0,,,4,,,,,,,, -,,none,8,2e-5,,,NousResearch/Llama-2-70b-hf,1,80837.0,,,4,,,,,,,, -,,none,9,2e-5,,,NousResearch/Llama-2-70b-hf,2,80830.0,,,2,,,,,,,, -,,none,10,2e-5,,,NousResearch/Llama-2-70b-hf,1,80837.0,,,8,,,,,,,, -,,none,11,2e-5,,,NousResearch/Llama-2-70b-hf,2,80834.5,,,4,,,,,,,, -0.04,,none,12,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,29731.0,26108963328.0,lora,4,16,q_proj k_proj v_proj o_proj,15119590912.0,0.9096682230631511,136.624,0.878,0.22,3597.611 -0.04,,none,13,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,18697.0,15123161088.0,lora,2,16,q_proj k_proj v_proj o_proj,7850391552.0,0.8918854713439941,82.0311,1.463,0.366,2995.936 -0.09,,none,14,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,43195.0,37098695168.0,lora,8,16,q_proj k_proj v_proj o_proj,15119984128.0,0.962119706471761,270.6301,0.887,0.111,3632.412 -0.09,,none,15,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,26235.0,21433753600.0,lora,4,16,q_proj k_proj v_proj o_proj,7850588160.0,0.9218235015869141,143.8184,1.669,0.209,3417.643 -,,none,16,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,80955.0,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,, -0.04,,none,17,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,62617.0,57540387840.0,lora,2,16,q_proj k_proj v_proj o_proj,47311452160.0,0.9361546834309896,179.3128,0.669,0.167,1370.566 -,,none,18,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,80955.0,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,, -0.09,,none,19,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,69848.0,64347637760.0,lora,4,16,q_proj k_proj v_proj o_proj,47311648768.0,0.9383139928181966,280.8919,0.854,0.107,1749.855 -,,none,20,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,80917.0,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,, -,,none,21,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,80894.0,,lora,2,16,q_proj k_proj v_proj o_proj,,,,,, -,,none,22,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,80917.0,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,, -,,none,23,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,80979.0,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,, -0.04,True,baseline-peft-bnb,24,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,27023.0,22825932800.0,lora,4,16,q_proj k_proj v_proj o_proj,5368221184.0,0.9589527130126954,178.8061,0.671,0.168,2748.9 -0.04,True,baseline-peft-bnb,25,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,13530.0,9974622720.0,lora,2,16,q_proj k_proj v_proj o_proj,2727018496.0,0.9154380798339844,87.3652,1.374,0.343,2813.02 -0.09,True,baseline-peft-bnb,26,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,47145.0,40278956032.0,lora,8,16,q_proj k_proj v_proj o_proj,5368614400.0,0.9702634493509928,341.2286,0.703,0.088,2880.884 -0.09,True,baseline-peft-bnb,27,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,21502.0,16587205120.0,lora,4,16,q_proj k_proj v_proj o_proj,2727215104.0,0.914565912882487,149.9341,1.601,0.2,3278.241 -0.04,True,baseline-peft-bnb,28,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,48313.0,46419968512.0,lora,4,16,q_proj k_proj v_proj 
o_proj,25726225920.0,0.9744932492574055,351.8623,0.341,0.085,1396.91 -0.04,True,baseline-peft-bnb,29,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,25549.0,21922782720.0,lora,2,16,q_proj k_proj v_proj o_proj,13219233792.0,0.9303209940592448,171.4299,0.7,0.175,1433.589 -0.09,True,baseline-peft-bnb,30,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,69931.0,67089150464.0,lora,8,16,q_proj k_proj v_proj o_proj,25726619136.0,0.9745417594909668,629.837,0.381,0.048,1560.785 -0.09,True,baseline-peft-bnb,31,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,32957.0,29384115200.0,lora,4,16,q_proj k_proj v_proj o_proj,13219430400.0,0.9310146331787109,300.5119,0.799,0.1,1635.609 -,True,baseline-peft-bnb,32,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,80893.0,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,, -0.04,True,baseline-peft-bnb,33,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,52634.0,46524471808.0,lora,2,16,q_proj k_proj v_proj o_proj,19172741120.0,1.0399916648864747,584.3145,0.205,0.051,420.595 -,True,baseline-peft-bnb,34,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,79557.0,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,, -,True,baseline-peft-bnb,35,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,80749.0,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,, -0.04,True,accelerated-peft-bnb,36,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,19931.0,15860019712.0,lora,4,16,q_proj k_proj v_proj o_proj,4843384320.0,0.9652111371358235,143.3569,0.837,0.209,3428.645 -0.04,True,accelerated-peft-bnb,37,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,13497.0,9974622720.0,lora,2,16,q_proj k_proj v_proj o_proj,2727018496.0,0.9277165730794271,86.4307,1.388,0.347,2843.435 -0.09,True,accelerated-peft-bnb,38,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,34355.0,26849751552.0,lora,8,16,q_proj k_proj v_proj o_proj,4843777536.0,0.9493892669677735,279.7156,0.858,0.107,3514.427 -0.09,True,accelerated-peft-bnb,39,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,21479.0,16587205120.0,lora,4,16,q_proj k_proj v_proj o_proj,2727215104.0,0.9110882759094239,149.3914,1.607,0.201,3290.15 -0.04,True,accelerated-peft-bnb,40,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,38405.0,36218024448.0,lora,4,16,q_proj k_proj v_proj o_proj,25201389056.0,0.9741149584452311,278.5888,0.431,0.108,1764.32 -0.04,True,accelerated-peft-bnb,41,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,25592.0,21906697728.0,lora,2,16,q_proj k_proj v_proj o_proj,13219233792.0,0.9300654411315918,172.7359,0.695,0.174,1422.75 -0.09,True,accelerated-peft-bnb,42,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,50875.0,47207756288.0,lora,8,16,q_proj k_proj v_proj o_proj,25201782272.0,0.9748441060384114,512.2298,0.469,0.059,1919.139 -0.09,True,accelerated-peft-bnb,43,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,32957.0,29369087488.0,lora,4,16,q_proj k_proj v_proj o_proj,13219430400.0,0.9301350593566895,287.6381,0.834,0.104,1708.814 -0.04,True,accelerated-peft-bnb,44,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,72829.0,68159977472.0,lora,4,16,q_proj k_proj v_proj o_proj,37346815488.0,1.118430455525716,1075.2044,0.112,0.028,457.141 -0.04,True,accelerated-peft-bnb,45,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,52632.0,46524471808.0,lora,2,16,q_proj k_proj v_proj o_proj,19172741120.0,1.040946865081787,586.651,0.205,0.051,418.92 -,True,accelerated-peft-bnb,46,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,80405.0,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,, -,True,accelerated-peft-bnb,47,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,80954.0,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,, 
-0.04,True,accelerated-peft-autogptq,48,2e-4,16,0.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,20453.0,15890329088.0,lora,4,16,q_proj k_proj v_proj o_proj,4873693696.0,1.3805528958638509,151.0359,0.795,0.199,3254.326 -0.04,True,accelerated-peft-autogptq,49,2e-4,16,0.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,17198.0,9952175616.0,lora,2,16,q_proj k_proj v_proj o_proj,3005709312.0,1.1706618309020995,87.4109,1.373,0.343,2811.548 -0.09,True,accelerated-peft-autogptq,50,2e-4,16,0.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,34247.0,26880060928.0,lora,8,16,q_proj k_proj v_proj o_proj,4874086912.0,1.2741642634073893,282.6391,0.849,0.106,3478.076 -0.09,True,accelerated-peft-autogptq,51,2e-4,16,0.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,24783.0,16262768128.0,lora,4,16,q_proj k_proj v_proj o_proj,3005905920.0,1.043952751159668,152.5473,1.573,0.197,3222.083 -0.04,True,accelerated-peft-autogptq,52,2e-4,16,0.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,37461.0,35528093184.0,lora,4,16,q_proj k_proj v_proj o_proj,24511457792.0,0.9936613400777181,263.6066,0.455,0.114,1864.597 -0.04,True,accelerated-peft-autogptq,53,2e-4,16,0.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,46641.0,25708175360.0,lora,2,16,q_proj k_proj v_proj o_proj,12788874240.0,0.9420519828796386,167.065,0.718,0.18,1471.045 -0.09,True,accelerated-peft-autogptq,54,2e-4,16,0.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,49925.0,46517825024.0,lora,8,16,q_proj k_proj v_proj o_proj,24511851008.0,0.9855653127034505,498.9022,0.481,0.06,1970.406 -0.09,True,accelerated-peft-autogptq,55,2e-4,16,0.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,52358.0,27739090432.0,lora,4,16,q_proj k_proj v_proj o_proj,12789070848.0,0.9389812151590983,281.8034,0.852,0.106,1744.195 -0.04,True,accelerated-peft-autogptq,56,2e-4,16,0.0,TheBloke/Llama-2-70B-GPTQ,1,71565.0,65895347200.0,lora,4,16,q_proj k_proj v_proj o_proj,36290144768.0,1.0755928039550782,1060.8387,0.113,0.028,463.331 -0.04,True,accelerated-peft-autogptq,57,2e-4,16,0.0,TheBloke/Llama-2-70B-GPTQ,2,80387.0,45397678592.0,lora,2,16,q_proj k_proj v_proj o_proj,18649885696.0,1.0256956418355305,576.0422,0.208,0.052,426.635 -,True,accelerated-peft-autogptq,58,2e-4,16,0.0,TheBloke/Llama-2-70B-GPTQ,1,80293.0,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,, -0.08,True,accelerated-peft-autogptq,59,2e-4,16,0.0,TheBloke/Llama-2-70B-GPTQ,2,80363.0,70667573760.0,lora,4,16,q_proj k_proj v_proj o_proj,18650082304.0,1.0266701062520345,1089.3291,0.22,0.028,451.214 +epoch,fp16,framework_config,learning_rate,lora_alpha,lora_dropout,mem_nvidia_mem_reserved,mem_peak_torch_mem_alloc_in_bytes,mem_torch_mem_alloc_in_bytes,model_name_or_path,num_gpus,peft_method,per_device_train_batch_size,r,target_modules,torch_dtype,train_loss,train_runtime,train_samples_per_second,train_steps_per_second,train_tokens_per_second +0.15,True,baseline-peft-bnb,2e-4,16,0.0,25995.0,22825932800,5368221184,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8676117706298828,584.6749,0.684,0.171,2802.241 +0.15,True,baseline-peft-bnb,2e-4,16,0.0,12512.0,9974622720,2727018496,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8593511199951172,279.9917,1.429,0.357,2925.801 +0.29,True,baseline-peft-bnb,2e-4,16,0.0,46117.0,40278956032,5368614400,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.86837890625,1149.6017,0.696,0.087,2850.378 +0.29,True,baseline-peft-bnb,2e-4,16,0.0,20435.0,16587205120,2727215104,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8526134586334229,496.2449,1.612,0.202,3301.596 
+0.15,True,baseline-peft-bnb,2e-4,16,0.0,47079.0,46427906560,25726225920,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8966263771057129,1169.4078,0.342,0.086,1401.051 +0.15,True,baseline-peft-bnb,2e-4,16,0.0,24609.0,21937980416,13219233792,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8650046825408936,564.3075,0.709,0.177,1451.691 +0.29,True,baseline-peft-bnb,2e-4,16,0.0,68071.0,67121147392,25726619136,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8866284656524658,2118.0176,0.378,0.047,1547.107 +0.29,True,baseline-peft-bnb,2e-4,16,0.0,32054.0,29375012352,13219430400,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8636721038818359,959.452,0.834,0.104,1707.641 +,True,baseline-peft-bnb,2e-4,16,0.0,80631.0,0,0,NousResearch/Llama-2-70b-hf,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,,,,, +0.14,True,baseline-peft-bnb,2e-4,16,0.0,51579.0,46524471808,19172741120,NousResearch/Llama-2-70b-hf,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9462522315979004,1951.2462,0.205,0.051,419.834 +,True,baseline-peft-bnb,2e-4,16,0.0,79555.0,0,0,NousResearch/Llama-2-70b-hf,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,, +0.28,True,baseline-peft-bnb,2e-4,16,0.0,80801.0,72398346752,19172937728,NousResearch/Llama-2-70b-hf,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.935322732925415,3737.7987,0.214,0.027,438.333 +0.15,True,accelerated-peft-bnb,2e-4,16,0.0,18903.0,15860019712,4843384320,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8679532146453858,480.1165,0.833,0.208,3412.505 +0.15,True,accelerated-peft-bnb,2e-4,16,0.0,12477.0,9974622720,2727018496,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8598325538635254,281.0553,1.423,0.356,2914.729 +0.29,True,accelerated-peft-bnb,2e-4,16,0.0,33327.0,26849751552,4843777536,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8708646774291993,944.515,0.847,0.106,3469.294 +0.29,True,accelerated-peft-bnb,2e-4,16,0.0,20417.0,16587205120,2727215104,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8568318557739257,498.8375,1.604,0.2,3284.436 +0.15,True,accelerated-peft-bnb,2e-4,16,0.0,37321.0,36218024448,25201389056,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8979199028015137,923.4329,0.433,0.108,1774.249 +0.15,True,accelerated-peft-bnb,2e-4,16,0.0,24783.0,21940224000,13219233792,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8649028778076172,564.1011,0.709,0.177,1452.222 +0.29,True,accelerated-peft-bnb,2e-4,16,0.0,49847.0,47207756288,25201782272,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8877867794036866,1717.1699,0.466,0.058,1908.256 +0.29,True,accelerated-peft-bnb,2e-4,16,0.0,31907.0,29336790016,13219430400,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8623861598968506,952.2959,0.84,0.105,1720.474 +0.14,True,accelerated-peft-bnb,2e-4,16,0.0,71801.0,68159977472,37346815488,NousResearch/Llama-2-70b-hf,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.999151840209961,3662.4376,0.109,0.027,447.352 +0.14,True,accelerated-peft-bnb,2e-4,16,0.0,51579.0,46524471808,19172741120,NousResearch/Llama-2-70b-hf,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9392572689056397,1950.7659,0.205,0.051,419.938 
+,True,accelerated-peft-bnb,2e-4,16,0.0,79375.0,0,0,NousResearch/Llama-2-70b-hf,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,, +0.28,True,accelerated-peft-bnb,2e-4,16,0.0,80866.0,72398346752,19172937728,NousResearch/Llama-2-70b-hf,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9258937835693359,3744.4001,0.214,0.027,437.56 +0.15,True,accelerated-peft-autogptq,2e-4,16,0.0,19425.0,15890329088,4873693696,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,1.0217428588867188,477.2159,0.838,0.21,3433.247 +0.15,True,accelerated-peft-autogptq,2e-4,16,0.0,12056.0,9690031616,2743565312,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9701251029968262,278.7874,1.435,0.359,2938.44 +0.29,True,accelerated-peft-autogptq,2e-4,16,0.0,33219.0,26880060928,4874086912,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.9569056987762451,941.1761,0.85,0.106,3481.601 +0.29,True,accelerated-peft-autogptq,2e-4,16,0.0,19530.0,16000624128,2743761920,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9303163433074951,494.3287,1.618,0.202,3314.394 +0.15,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,19065.0,13631990784,4873693696,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9736110210418701,411.3906,0.972,0.243,3982.589 +0.15,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,11506.0,9174099456,2405399552,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,1.0141907215118409,248.8178,1.608,0.402,3292.368 +0.29,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,32721.0,22390647808,4874086912,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.9668986797332764,809.2016,0.989,0.124,4049.424 +0.29,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,18635.0,15282316800,2405596160,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.942121753692627,444.2322,1.801,0.225,3688.162 +0.15,True,accelerated-peft-autogptq,2e-4,16,0.0,36435.0,35528093184,24511457792,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9004004192352295,879.8344,0.455,0.114,1862.169 +0.15,True,accelerated-peft-autogptq,2e-4,16,0.0,22962.5,20697435648,12526730240,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8698519325256348,537.8597,0.744,0.186,1523.074 +0.29,True,accelerated-peft-autogptq,2e-4,16,0.0,48941.0,46517825024,24511851008,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8974114608764648,1669.3163,0.479,0.06,1962.959 +0.29,True,accelerated-peft-autogptq,2e-4,16,0.0,29756.0,27484941824,12526926848,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8667408466339112,924.2282,0.866,0.108,1772.722 +0.15,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,36613.0,33671981056,24511457792,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9003233146667481,814.7613,0.491,0.123,2010.896 +0.15,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,22421.0,20108989952,12191160320,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.867002067565918,506.3203,0.79,0.198,1617.948 +0.29,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,49691.0,42742948864,24511851008,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.897435302734375,1534.4874,0.521,0.065,2135.436 
+0.29,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,28865.0,26629788672,12191300608,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.866525583267212,877.2087,0.912,0.114,1867.742 +0.14,True,accelerated-peft-autogptq,2e-4,16,0.0,71177.0,65895347200,36290144768,TheBloke/Llama-2-70B-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.99012770652771,3600.8607,0.111,0.028,455.002 +0.14,True,accelerated-peft-autogptq,2e-4,16,0.0,49455.0,44873390592,18125597696,TheBloke/Llama-2-70B-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9539268207550049,1890.9021,0.212,0.053,433.232 +,True,accelerated-peft-autogptq,2e-4,16,0.0,79265.0,0,0,TheBloke/Llama-2-70B-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,, +0.28,True,accelerated-peft-autogptq,2e-4,16,0.0,79283.0,70143285760,18125794304,TheBloke/Llama-2-70B-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9549467945098877,3679.8651,0.217,0.027,445.234 +0.14,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,71223.0,65086305280,36290144768,TheBloke/Llama-2-70B-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9903428840637207,3295.1075,0.121,0.03,497.222 +0.14,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,46207.0,41579411968,15105330176,TheBloke/Llama-2-70B-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9634347057342529,1740.6214,0.23,0.057,470.637 +,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,80949.0,0,0,TheBloke/Llama-2-70B-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,, +0.28,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,74507.0,66445605376,15105526784,TheBloke/Llama-2-70B-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9590920734405518,3441.8985,0.232,0.029,476.016 +0.15,,none,2e-5,,,76679.0,72971724288,44004763136,mistralai/Mistral-7B-v0.1,1,,4,,,float16,0.9002080440521241,558.4193,0.716,0.179,2933.996 +0.15,,none,2e-5,,,43695.0,36762859520,29521119232,mistralai/Mistral-7B-v0.1,2,,2,,,float16,0.8854282188415528,302.5551,1.322,0.331,2707.606 +0.29,,none,2e-5,,,73761.0,72972117504,44005156352,mistralai/Mistral-7B-v0.1,1,,8,,,float16,1.0202219200134277,1085.5804,0.737,0.092,3018.478 +0.29,,none,2e-5,,,52923.0,36763056128,29521315840,mistralai/Mistral-7B-v0.1,2,,4,,,float16,0.8920887660980225,561.8731,1.424,0.178,2915.961 +,,none,2e-5,,,79961.0,0,0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,4,,,float16,,,,, +,,none,2e-5,,,80925.0,0,0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,2,,,float16,,,,, +,,none,2e-5,,,80969.0,0,0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,8,,,float16,,,,, +,,none,2e-5,,,80703.0,0,0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,4,,,float16,,,,, +,,none,2e-5,,,80987.0,0,0,NousResearch/Llama-2-70b-hf,1,,4,,,float16,,,,, +,,none,2e-5,,,80922.0,0,0,NousResearch/Llama-2-70b-hf,2,,2,,,float16,,,,, +,,none,2e-5,,,80987.0,0,0,NousResearch/Llama-2-70b-hf,1,,8,,,float16,,,,, +,,none,2e-5,,,80782.0,0,0,NousResearch/Llama-2-70b-hf,2,,4,,,float16,,,,, +0.15,,none,2e-4,16,0.0,28703.0,26108963328,15119590912,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8848505210876465,456.0676,0.877,0.219,3592.45 +0.15,,none,2e-4,16,0.0,17655.0,15123161088,7850391552,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8546714687347412,267.0472,1.498,0.374,3067.623 +0.29,,none,2e-4,16,0.0,42167.0,37098695168,15119984128,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,1.0078722095489503,909.6399,0.879,0.11,3602.305 
+0.29,,none,2e-4,16,0.0,25207.0,21433753600,7850588160,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8803257846832275,477.2486,1.676,0.21,3433.012 +,,none,2e-4,16,0.0,78871.0,0,0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,,,,, +0.15,,none,2e-4,16,0.0,61532.0,57531527168,47311452160,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8628986740112304,545.0419,0.734,0.183,1503.004 +,,none,2e-4,16,0.0,80991.0,0,0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,, +0.29,,none,2e-4,16,0.0,68811.0,64348470272,47311648768,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8795901584625244,919.9512,0.87,0.109,1780.964 +,,none,2e-4,16,0.0,80617.0,0,0,NousResearch/Llama-2-70b-hf,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,,,,, +,,none,2e-4,16,0.0,80760.0,0,0,NousResearch/Llama-2-70b-hf,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,,,,, +,,none,2e-4,16,0.0,80617.0,0,0,NousResearch/Llama-2-70b-hf,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,, +,,none,2e-4,16,0.0,80987.0,0,0,NousResearch/Llama-2-70b-hf,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,,,,, +0.15,True,accelerated-peft-bnb-foak,2e-4,16,0.0,19257.0,13636909056,4843384320,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8704845142364502,417.5391,0.958,0.239,3923.944 +,True,accelerated-peft-bnb-foak,2e-4,16,0.0,5527.0,0,0,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,,,,, +0.29,True,accelerated-peft-bnb-foak,2e-4,16,0.0,32209.0,22430791680,4843777536,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8942180156707764,818.5228,0.977,0.122,4003.309 +,True,accelerated-peft-bnb-foak,2e-4,16,0.0,5675.0,0,0,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,,,,, +0.15,True,accelerated-peft-bnb-foak,2e-4,16,0.0,37301.0,35622334464,25201389056,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.887912654876709,861.4969,0.464,0.116,1901.806 +0.29,True,accelerated-peft-bnb-foak,2e-4,16,0.0,49955.0,46024318976,25201782272,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8887538051605225,1590.7501,0.503,0.063,2059.909 +0.14,True,accelerated-peft-bnb-foak,2e-4,16,0.0,71995.0,67350935552,37346815488,NousResearch/Llama-2-70b-hf,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,1.0002326488494873,3357.4377,0.119,0.03,487.991 +,True,accelerated-peft-bnb-foak,2e-4,16,0.0,80303.0,0,0,NousResearch/Llama-2-70b-hf,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,, +,True,accelerated-peft-bnb-foak,2e-4,16,0.0,21095.0,0,0,NousResearch/Llama-2-70b-hf,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,,,,, diff --git a/scripts/benchmarks/refs/l40_40gb.csv b/scripts/benchmarks/refs/l40_40gb.csv deleted file mode 100644 index 2158c782..00000000 --- a/scripts/benchmarks/refs/l40_40gb.csv +++ /dev/null @@ -1,49 +0,0 @@ -acceleration_framework_config_file,epoch,error_messages,fp16,framework_config,index,learning_rate,lora_alpha,lora_dropout,model_name_or_path,num_gpus,output_dir,peft_method,per_device_train_batch_size,r,target_modules,train_loss,train_runtime,train_samples_per_second,train_steps_per_second,train_tokens_per_second,training_data_path -,,,,none,0,2e-5,,,mistralai/Mistral-7B-v0.1,1,,,4,,,,,,,,benchmark_outputs/data/cache.json 
-,0.03,,,none,1,2e-5,,,mistralai/Mistral-7B-v0.1,2,,,2,,,0.9020393848419189,102.4493,0.781,0.195,1599.23,benchmark_outputs/data/cache.json -,,,,none,2,2e-5,,,mistralai/Mistral-7B-v0.1,1,,,8,,,,,,,,benchmark_outputs/data/cache.json -,0.06,,,none,3,2e-5,,,mistralai/Mistral-7B-v0.1,2,,,4,,,0.936076545715332,170.7722,0.937,0.117,1918.814,benchmark_outputs/data/cache.json -,,,,none,4,2e-5,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,,4,,,,,,,,benchmark_outputs/data/cache.json -,,,,none,5,2e-5,,,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,,2,,,,,,,,benchmark_outputs/data/cache.json -,,,,none,6,2e-5,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,,8,,,,,,,,benchmark_outputs/data/cache.json -,,,,none,7,2e-5,,,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,,4,,,,,,,,benchmark_outputs/data/cache.json -,,,,none,8,2e-5,,,NousResearch/Llama-2-70b-hf,1,,,4,,,,,,,,benchmark_outputs/data/cache.json -,,,,none,9,2e-5,,,NousResearch/Llama-2-70b-hf,2,,,2,,,,,,,,benchmark_outputs/data/cache.json -,,,,none,10,2e-5,,,NousResearch/Llama-2-70b-hf,1,,,8,,,,,,,,benchmark_outputs/data/cache.json -,,,,none,11,2e-5,,,NousResearch/Llama-2-70b-hf,2,,,4,,,,,,,,benchmark_outputs/data/cache.json -,0.03,,,none,12,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,,lora,4,16,q_proj k_proj v_proj o_proj,0.9326287746429444,120.2794,0.665,0.166,2724.324,benchmark_outputs/data/cache.json -,0.03,,,none,13,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,,lora,2,16,q_proj k_proj v_proj o_proj,0.9157441139221192,78.5825,1.018,0.255,2084.943,benchmark_outputs/data/cache.json -,0.06,,,none,14,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,,lora,8,16,q_proj k_proj v_proj o_proj,1.0113807678222657,241.3246,0.663,0.083,2715.679,benchmark_outputs/data/cache.json -,0.06,,,none,15,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,,lora,4,16,q_proj k_proj v_proj o_proj,0.9433841228485107,133.2158,1.201,0.15,2459.768,benchmark_outputs/data/cache.json -,,,,none,16,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json -,,,,none,17,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,lora,2,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json -,,,,none,18,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json -,,,,none,19,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json -,,,,none,20,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json -,,,,none,21,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,,lora,2,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json -,,,,none,22,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json -,,,,none,23,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json -sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,0.03,,True,accelerated-peft-autogptq,36,2e-4,16,0.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,,lora,4,16,q_proj k_proj v_proj o_proj,1.6183419704437256,137.2634,0.583,0.146,2387.235,benchmark_outputs/data/cache.json -sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,0.03,,True,accelerated-peft-autogptq,37,2e-4,16,0.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,,lora,2,16,q_proj k_proj v_proj o_proj,1.7251328945159912,73.906,1.082,0.271,2216.871,benchmark_outputs/data/cache.json 
-sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,0.06,,True,accelerated-peft-autogptq,38,2e-4,16,0.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,,lora,8,16,q_proj k_proj v_proj o_proj,1.5904263019561768,272.1958,0.588,0.073,2407.679,benchmark_outputs/data/cache.json -sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,0.06,,True,accelerated-peft-autogptq,39,2e-4,16,0.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,,lora,4,16,q_proj k_proj v_proj o_proj,1.515465259552002,138.6152,1.154,0.144,2363.954,benchmark_outputs/data/cache.json -sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,0.03,,True,accelerated-peft-autogptq,40,2e-4,16,0.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,,lora,4,16,q_proj k_proj v_proj o_proj,1.012540912628174,227.0536,0.352,0.088,1443.183,benchmark_outputs/data/cache.json -sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,0.03,,True,accelerated-peft-autogptq,41,2e-4,16,0.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,,lora,2,16,q_proj k_proj v_proj o_proj,1.0235525131225587,121.7118,0.657,0.164,1346.13,benchmark_outputs/data/cache.json -sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,,,True,accelerated-peft-autogptq,42,2e-4,16,0.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json -sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,0.06,,True,accelerated-peft-autogptq,43,2e-4,16,0.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,,lora,4,16,q_proj k_proj v_proj o_proj,1.0152217864990234,229.6679,0.697,0.087,1426.756,benchmark_outputs/data/cache.json -sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,,,True,accelerated-peft-autogptq,44,2e-4,16,0.0,TheBloke/Nous-Hermes-Llama2-70B-GPTQ,1,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json -sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,,,True,accelerated-peft-autogptq,45,2e-4,16,0.0,TheBloke/Nous-Hermes-Llama2-70B-GPTQ,2,,lora,2,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json -sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,,,True,accelerated-peft-autogptq,46,2e-4,16,0.0,TheBloke/Nous-Hermes-Llama2-70B-GPTQ,1,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json -sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,,,True,accelerated-peft-autogptq,47,2e-4,16,0.0,TheBloke/Nous-Hermes-Llama2-70B-GPTQ,2,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json -sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,0.03,,True,accelerated-peft-bnb,0,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,,lora,4,16,q_proj k_proj v_proj o_proj,0.9979345798492432,130.1845,0.615,0.154,2517.044,benchmark_bnb_outputs/data/cache.json -sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,0.03,,True,accelerated-peft-bnb,1,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,,lora,2,16,q_proj k_proj v_proj o_proj,0.942676591873169,69.8209,1.146,0.286,2346.575,benchmark_bnb_outputs/data/cache.json -sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,0.06,,True,accelerated-peft-bnb,2,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,,lora,8,16,q_proj k_proj v_proj o_proj,0.9919514656066895,259.8776,0.616,0.077,2521.802,benchmark_bnb_outputs/data/cache.json 
-sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,0.06,,True,accelerated-peft-bnb,3,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,,lora,4,16,q_proj k_proj v_proj o_proj,0.933735466003418,133.6157,1.197,0.15,2452.406,benchmark_bnb_outputs/data/cache.json -sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,0.03,,True,accelerated-peft-bnb,4,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,lora,4,16,q_proj k_proj v_proj o_proj,1.015654945373535,218.3215,0.366,0.092,1500.906,benchmark_bnb_outputs/data/cache.json -sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,0.03,,True,accelerated-peft-bnb,5,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,lora,2,16,q_proj k_proj v_proj o_proj,0.9546889305114746,173.2373,0.462,0.115,945.755,benchmark_bnb_outputs/data/cache.json -sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,,,True,accelerated-peft-bnb,6,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_bnb_outputs/data/cache.json -sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,0.06,,True,accelerated-peft-bnb,7,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,lora,4,16,q_proj k_proj v_proj o_proj,0.9585415840148925,273.4507,0.585,0.073,1198.315,benchmark_bnb_outputs/data/cache.json -sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,,,True,accelerated-peft-bnb,8,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_bnb_outputs/data/cache.json -sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,,,True,accelerated-peft-bnb,9,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,,lora,2,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_bnb_outputs/data/cache.json -sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,,,True,accelerated-peft-bnb,10,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_bnb_outputs/data/cache.json -sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,,,True,accelerated-peft-bnb,11,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_bnb_outputs/data/cache.json diff --git a/scripts/benchmarks/scenarios.yaml b/scripts/benchmarks/scenarios.yaml index c935ac31..42f7c753 100644 --- a/scripts/benchmarks/scenarios.yaml +++ b/scripts/benchmarks/scenarios.yaml @@ -52,6 +52,7 @@ scenarios: - name: accelerated-peft-bnb framework_config: - accelerated-peft-bnb + - accelerated-peft-bnb-foak arguments: fp16: True learning_rate: 2e-4 @@ -82,4 +83,4 @@ scenarios: model_name_or_path: - 'TheBloke/Mistral-7B-v0.1-GPTQ' - 'TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ' - - 'TheBloke/Llama-2-70B-GPTQ' \ No newline at end of file + - 'TheBloke/Llama-2-70B-GPTQ' diff --git a/scripts/generate_sample_configurations.py b/scripts/generate_sample_configurations.py index fd51d965..b3485e3c 100644 --- a/scripts/generate_sample_configurations.py +++ b/scripts/generate_sample_configurations.py @@ -143,6 +143,7 @@ def read_configuration(path: str) -> Dict: KEY_BNB_NF4 = "bnb-nf4" KEY_BNB_NF4_BASELINE = "baseline-bnb-nf4" KEY_AUTO_GPTQ_FOAK = "auto-gptq-foak" +KEY_BNB_NF4_FOAK = "bnb-nf4-foak" CONFIGURATIONS = { KEY_AUTO_GPTQ: "plugins/accelerated-peft/configs/autogptq.yaml", @@ -153,14 +154,18 @@ def read_configuration(path: str) -> Dict: KEY_BNB_NF4_BASELINE: ( "plugins/accelerated-peft/configs/bnb.yaml", [ - 
("peft.quantization.bitsandbytes.quant_type", "nf4"), - ("peft.quantization.bitsandbytes.no_peft_model", True), + ("peft.quantization.bitsandbytes.quant_type", "nf4"), + ("peft.quantization.bitsandbytes.no_peft_model", True), ], ), KEY_AUTO_GPTQ_FOAK: ( "plugins/fused-ops-and-kernels/configs/fast_quantized_peft.yaml", [("peft.quantization.fused_ops_and_kernels.base_layer", "auto_gptq")], ), + KEY_BNB_NF4_FOAK: ( + "plugins/fused-ops-and-kernels/configs/fast_quantized_peft.yaml", + [("peft.quantization.fused_ops_and_kernels.base_layer", "bitsandbytes")], + ), } # list of (tag, combi) tuples @@ -173,8 +178,10 @@ def read_configuration(path: str) -> Dict: ("accelerated-peft-bnb-nf4", (KEY_BNB_NF4,)), ("baseline-peft-bnb-nf4", (KEY_BNB_NF4_BASELINE,)), ("accelerated-peft-autogptq-foak", (KEY_AUTO_GPTQ, KEY_AUTO_GPTQ_FOAK)), + ("accelerated-peft-bnb-nf4-foak", (KEY_BNB_NF4, KEY_BNB_NF4_FOAK)), ] + # TODO: throw error if merge conflicts def merge_configs(config_contents: List[Dict]): "helper function to merge configuration contents." @@ -183,10 +190,10 @@ def merge_configs(config_contents: List[Dict]): def _merge(result: Dict, new_contents: Dict): for k, v in new_contents.items(): if k not in result: - # if k is not in result, it means v does not + # if k is not in result, it means v does not # exist as a subtree under result, so we just do # an assingment - result[k] = v + result[k] = v else: # otherwise we call the merge _merge(result[k], v) diff --git a/scripts/run_benchmarks.sh b/scripts/run_benchmarks.sh index 798138bf..8f8a1f9b 100644 --- a/scripts/run_benchmarks.sh +++ b/scripts/run_benchmarks.sh @@ -58,10 +58,10 @@ if [ -n "$RESULT_DIR" ]; then echo "Results dir $RESULT_DIR is not empty, but NO_OVERWRITE=true" echo "If intending to overwrite please delete the folder manually" echo "or do not set NO_OVERWRITE" - exit 1 + else + echo "Deleting $RESULT_DIR" + rm -rf $RESULT_DIR fi - echo "Deleting $RESULT_DIR" - rm -rf $RESULT_DIR fi # tag on the directories @@ -98,9 +98,11 @@ elif [ "$MEMORY_LOGGING" = "all" ]; then fi # dump out the environment -echo "Creating $RESULT_DIR" -mkdir -p $RESULT_DIR -pip freeze > $PIP_REQUIREMENTS_FILE +if [ ! "$NO_OVERWRITE" = "true" ]; then + echo "Creating $RESULT_DIR" + mkdir -p $RESULT_DIR + pip freeze > $PIP_REQUIREMENTS_FILE +fi # run the bench python $WORKING_DIR/benchmark.py \ @@ -116,8 +118,10 @@ python $WORKING_DIR/benchmark.py \ # this will write to the BENCH_RESULT_FILE # Remove the columns with values already represented by other metrics in the summary report PYTHONPATH=. \ - python $WORKING_DIR/display_bench_results.py benchmark_outputs \ + python $WORKING_DIR/display_bench_results.py $RESULT_DIR \ --result_file $BENCH_RESULT_FILE \ + --keep_columns \ + 'torch_dtype' \ --remove_columns \ 'before_init_mem_cpu' \ 'before_init_mem_gpu' \ @@ -129,5 +133,7 @@ PYTHONPATH=. 
\ 'train_mem_cpu_peaked_delta' \ 'train_mem_gpu_alloc_delta' \ 'train_mem_gpu_peaked_delta' \ + 'training_data_path' \ + 'error_messages' \ 'acceleration_framework_config_file' From 00febdc8a6b03d2fc509469f2e9e4619dc96a372 Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Fri, 7 Jun 2024 15:02:43 +0800 Subject: [PATCH 7/8] Address Incorrect Ignoring of Base Layer Modules for FSDP with Kernels (#31) * properly ignore lora adapters * handle qlora quant state * improve fix * further simplification of fix * updated benchmark reference (#34) --------- Co-authored-by: achew010 <165894159+achew010@users.noreply.github.com> --- .../framework_plugin_fast_quantized_peft.py | 7 +- .../fused_ops/unsloth_lora/utils.py | 18 ++- scripts/benchmarks/refs/a100_80gb.csv | 153 +++++++++--------- 3 files changed, 96 insertions(+), 82 deletions(-) diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py index 7eab87f0..01a5b4b7 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py @@ -55,7 +55,11 @@ def lora_adapters_switch_ddp_from_fsdp(modules, fsdp_plugin): reduces the accumulated gradients across devices """ - fsdp_plugin.ignored_modules = modules + # NOTE: assuming lora has no bias + fsdp_plugin.ignored_modules = [] + for mod in modules: + fsdp_plugin.ignored_modules.append(mod.lora_A) + fsdp_plugin.ignored_modules.append(mod.lora_B) def _all_reduce_hook(grad): if grad is not None: @@ -64,7 +68,6 @@ def _all_reduce_hook(grad): return grad for mod in modules: - # NOTE: assuming lora has no bias A = mod.lora_A.default B = mod.lora_B.default diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/utils.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/utils.py index 6ea90780..5354670f 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/utils.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/utils.py @@ -41,19 +41,27 @@ def calculate_settings(n): cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemm_4bit_inference_naive_fp16 cgemm_4bit_inference_naive_bf16 = bnb.functional.lib.cgemm_4bit_inference_naive_bf16 +# modified by flim@sg.ibm.com +def QUANT_STATE(W, base_layer): -def QUANT_STATE(W): - return getattr(W, "quant_state", None) -pass + # if the weights has quant_state just take it from there + if hasattr(W, 'quant_state'): + return W.quant_state + # otherwise fall back to checking if it is on the base layer + # This is needed when FSDP shards the parameters, and destroys the original + # weight matrix, so we can get the quant state back + return getattr(base_layer, 'quant_state', None) +pass +# modified by flim@sg.ibm.com def get_lora_parameters(proj): # For DPO or disabled adapters base_layer = (proj.base_layer if hasattr(proj, "base_layer") else proj) W = base_layer.weight if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged: - return W, QUANT_STATE(W), None, None, None + return W, QUANT_STATE(W, base_layer), None, None, None pass active_adapter = proj.active_adapters[0] if \ @@ -61,7 +69,7 @@ def get_lora_parameters(proj): A = proj.lora_A [active_adapter].weight B = proj.lora_B [active_adapter].weight s = 
proj.scaling[active_adapter] - return W, QUANT_STATE(W), A, B, s + return W, QUANT_STATE(W, base_layer), A, B, s pass diff --git a/scripts/benchmarks/refs/a100_80gb.csv b/scripts/benchmarks/refs/a100_80gb.csv index b83549a7..45cdf125 100644 --- a/scripts/benchmarks/refs/a100_80gb.csv +++ b/scripts/benchmarks/refs/a100_80gb.csv @@ -1,82 +1,85 @@ epoch,fp16,framework_config,learning_rate,lora_alpha,lora_dropout,mem_nvidia_mem_reserved,mem_peak_torch_mem_alloc_in_bytes,mem_torch_mem_alloc_in_bytes,model_name_or_path,num_gpus,peft_method,per_device_train_batch_size,r,target_modules,torch_dtype,train_loss,train_runtime,train_samples_per_second,train_steps_per_second,train_tokens_per_second -0.15,True,baseline-peft-bnb,2e-4,16,0.0,25995.0,22825932800,5368221184,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8676117706298828,584.6749,0.684,0.171,2802.241 -0.15,True,baseline-peft-bnb,2e-4,16,0.0,12512.0,9974622720,2727018496,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8593511199951172,279.9917,1.429,0.357,2925.801 -0.29,True,baseline-peft-bnb,2e-4,16,0.0,46117.0,40278956032,5368614400,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.86837890625,1149.6017,0.696,0.087,2850.378 -0.29,True,baseline-peft-bnb,2e-4,16,0.0,20435.0,16587205120,2727215104,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8526134586334229,496.2449,1.612,0.202,3301.596 -0.15,True,baseline-peft-bnb,2e-4,16,0.0,47079.0,46427906560,25726225920,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8966263771057129,1169.4078,0.342,0.086,1401.051 -0.15,True,baseline-peft-bnb,2e-4,16,0.0,24609.0,21937980416,13219233792,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8650046825408936,564.3075,0.709,0.177,1451.691 -0.29,True,baseline-peft-bnb,2e-4,16,0.0,68071.0,67121147392,25726619136,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8866284656524658,2118.0176,0.378,0.047,1547.107 -0.29,True,baseline-peft-bnb,2e-4,16,0.0,32054.0,29375012352,13219430400,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8636721038818359,959.452,0.834,0.104,1707.641 -,True,baseline-peft-bnb,2e-4,16,0.0,80631.0,0,0,NousResearch/Llama-2-70b-hf,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,,,,, -0.14,True,baseline-peft-bnb,2e-4,16,0.0,51579.0,46524471808,19172741120,NousResearch/Llama-2-70b-hf,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9462522315979004,1951.2462,0.205,0.051,419.834 -,True,baseline-peft-bnb,2e-4,16,0.0,79555.0,0,0,NousResearch/Llama-2-70b-hf,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,, -0.28,True,baseline-peft-bnb,2e-4,16,0.0,80801.0,72398346752,19172937728,NousResearch/Llama-2-70b-hf,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.935322732925415,3737.7987,0.214,0.027,438.333 -0.15,True,accelerated-peft-bnb,2e-4,16,0.0,18903.0,15860019712,4843384320,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8679532146453858,480.1165,0.833,0.208,3412.505 -0.15,True,accelerated-peft-bnb,2e-4,16,0.0,12477.0,9974622720,2727018496,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8598325538635254,281.0553,1.423,0.356,2914.729 -0.29,True,accelerated-peft-bnb,2e-4,16,0.0,33327.0,26849751552,4843777536,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8708646774291993,944.515,0.847,0.106,3469.294 
-0.29,True,accelerated-peft-bnb,2e-4,16,0.0,20417.0,16587205120,2727215104,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8568318557739257,498.8375,1.604,0.2,3284.436 -0.15,True,accelerated-peft-bnb,2e-4,16,0.0,37321.0,36218024448,25201389056,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8979199028015137,923.4329,0.433,0.108,1774.249 -0.15,True,accelerated-peft-bnb,2e-4,16,0.0,24783.0,21940224000,13219233792,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8649028778076172,564.1011,0.709,0.177,1452.222 -0.29,True,accelerated-peft-bnb,2e-4,16,0.0,49847.0,47207756288,25201782272,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8877867794036866,1717.1699,0.466,0.058,1908.256 -0.29,True,accelerated-peft-bnb,2e-4,16,0.0,31907.0,29336790016,13219430400,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8623861598968506,952.2959,0.84,0.105,1720.474 -0.14,True,accelerated-peft-bnb,2e-4,16,0.0,71801.0,68159977472,37346815488,NousResearch/Llama-2-70b-hf,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.999151840209961,3662.4376,0.109,0.027,447.352 -0.14,True,accelerated-peft-bnb,2e-4,16,0.0,51579.0,46524471808,19172741120,NousResearch/Llama-2-70b-hf,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9392572689056397,1950.7659,0.205,0.051,419.938 -,True,accelerated-peft-bnb,2e-4,16,0.0,79375.0,0,0,NousResearch/Llama-2-70b-hf,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,, -0.28,True,accelerated-peft-bnb,2e-4,16,0.0,80866.0,72398346752,19172937728,NousResearch/Llama-2-70b-hf,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9258937835693359,3744.4001,0.214,0.027,437.56 -0.15,True,accelerated-peft-autogptq,2e-4,16,0.0,19425.0,15890329088,4873693696,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,1.0217428588867188,477.2159,0.838,0.21,3433.247 -0.15,True,accelerated-peft-autogptq,2e-4,16,0.0,12056.0,9690031616,2743565312,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9701251029968262,278.7874,1.435,0.359,2938.44 -0.29,True,accelerated-peft-autogptq,2e-4,16,0.0,33219.0,26880060928,4874086912,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.9569056987762451,941.1761,0.85,0.106,3481.601 -0.29,True,accelerated-peft-autogptq,2e-4,16,0.0,19530.0,16000624128,2743761920,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9303163433074951,494.3287,1.618,0.202,3314.394 -0.15,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,19065.0,13631990784,4873693696,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9736110210418701,411.3906,0.972,0.243,3982.589 -0.15,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,11506.0,9174099456,2405399552,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,1.0141907215118409,248.8178,1.608,0.402,3292.368 -0.29,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,32721.0,22390647808,4874086912,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.9668986797332764,809.2016,0.989,0.124,4049.424 -0.29,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,18635.0,15282316800,2405596160,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.942121753692627,444.2322,1.801,0.225,3688.162 
-0.15,True,accelerated-peft-autogptq,2e-4,16,0.0,36435.0,35528093184,24511457792,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9004004192352295,879.8344,0.455,0.114,1862.169 -0.15,True,accelerated-peft-autogptq,2e-4,16,0.0,22962.5,20697435648,12526730240,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8698519325256348,537.8597,0.744,0.186,1523.074 -0.29,True,accelerated-peft-autogptq,2e-4,16,0.0,48941.0,46517825024,24511851008,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8974114608764648,1669.3163,0.479,0.06,1962.959 -0.29,True,accelerated-peft-autogptq,2e-4,16,0.0,29756.0,27484941824,12526926848,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8667408466339112,924.2282,0.866,0.108,1772.722 -0.15,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,36613.0,33671981056,24511457792,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9003233146667481,814.7613,0.491,0.123,2010.896 -0.15,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,22421.0,20108989952,12191160320,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.867002067565918,506.3203,0.79,0.198,1617.948 -0.29,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,49691.0,42742948864,24511851008,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.897435302734375,1534.4874,0.521,0.065,2135.436 -0.29,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,28865.0,26629788672,12191300608,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.866525583267212,877.2087,0.912,0.114,1867.742 -0.14,True,accelerated-peft-autogptq,2e-4,16,0.0,71177.0,65895347200,36290144768,TheBloke/Llama-2-70B-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.99012770652771,3600.8607,0.111,0.028,455.002 -0.14,True,accelerated-peft-autogptq,2e-4,16,0.0,49455.0,44873390592,18125597696,TheBloke/Llama-2-70B-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9539268207550049,1890.9021,0.212,0.053,433.232 -,True,accelerated-peft-autogptq,2e-4,16,0.0,79265.0,0,0,TheBloke/Llama-2-70B-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,, -0.28,True,accelerated-peft-autogptq,2e-4,16,0.0,79283.0,70143285760,18125794304,TheBloke/Llama-2-70B-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9549467945098877,3679.8651,0.217,0.027,445.234 -0.14,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,71223.0,65086305280,36290144768,TheBloke/Llama-2-70B-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9903428840637207,3295.1075,0.121,0.03,497.222 -0.14,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,46207.0,41579411968,15105330176,TheBloke/Llama-2-70B-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9634347057342529,1740.6214,0.23,0.057,470.637 -,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,80949.0,0,0,TheBloke/Llama-2-70B-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,, -0.28,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,74507.0,66445605376,15105526784,TheBloke/Llama-2-70B-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9590920734405518,3441.8985,0.232,0.029,476.016 -0.15,,none,2e-5,,,76679.0,72971724288,44004763136,mistralai/Mistral-7B-v0.1,1,,4,,,float16,0.9002080440521241,558.4193,0.716,0.179,2933.996 
-0.15,,none,2e-5,,,43695.0,36762859520,29521119232,mistralai/Mistral-7B-v0.1,2,,2,,,float16,0.8854282188415528,302.5551,1.322,0.331,2707.606 -0.29,,none,2e-5,,,73761.0,72972117504,44005156352,mistralai/Mistral-7B-v0.1,1,,8,,,float16,1.0202219200134277,1085.5804,0.737,0.092,3018.478 -0.29,,none,2e-5,,,52923.0,36763056128,29521315840,mistralai/Mistral-7B-v0.1,2,,4,,,float16,0.8920887660980225,561.8731,1.424,0.178,2915.961 -,,none,2e-5,,,79961.0,0,0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,4,,,float16,,,,, -,,none,2e-5,,,80925.0,0,0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,2,,,float16,,,,, +0.15,,none,2e-5,,,76679.0,72971724288,44004763136,mistralai/Mistral-7B-v0.1,1,,4,,,float16,0.9112484455108643,565.9213,0.707,0.177,2895.102 +0.15,,none,2e-5,,,43702.0,36762859520,29521119232,mistralai/Mistral-7B-v0.1,2,,2,,,float16,0.8622726058959961,307.6782,1.3,0.325,2662.522 +0.29,,none,2e-5,,,70669.0,72972117504,44005156352,mistralai/Mistral-7B-v0.1,1,,8,,,float16,1.017976951599121,1094.9632,0.731,0.091,2992.612 +0.29,,none,2e-5,,,52882.0,36763056128,29521315840,mistralai/Mistral-7B-v0.1,2,,4,,,float16,0.8944576263427735,576.1931,1.388,0.174,2843.491 +,,none,2e-5,,,80969.0,0,0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,4,,,float16,,,,, +,,none,2e-5,,,79169.0,0,0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,2,,,float16,,,,, ,,none,2e-5,,,80969.0,0,0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,8,,,float16,,,,, -,,none,2e-5,,,80703.0,0,0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,4,,,float16,,,,, +,,none,2e-5,,,80083.0,0,0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,4,,,float16,,,,, ,,none,2e-5,,,80987.0,0,0,NousResearch/Llama-2-70b-hf,1,,4,,,float16,,,,, -,,none,2e-5,,,80922.0,0,0,NousResearch/Llama-2-70b-hf,2,,2,,,float16,,,,, +,,none,2e-5,,,80923.0,0,0,NousResearch/Llama-2-70b-hf,2,,2,,,float16,,,,, ,,none,2e-5,,,80987.0,0,0,NousResearch/Llama-2-70b-hf,1,,8,,,float16,,,,, -,,none,2e-5,,,80782.0,0,0,NousResearch/Llama-2-70b-hf,2,,4,,,float16,,,,, -0.15,,none,2e-4,16,0.0,28703.0,26108963328,15119590912,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8848505210876465,456.0676,0.877,0.219,3592.45 -0.15,,none,2e-4,16,0.0,17655.0,15123161088,7850391552,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8546714687347412,267.0472,1.498,0.374,3067.623 -0.29,,none,2e-4,16,0.0,42167.0,37098695168,15119984128,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,1.0078722095489503,909.6399,0.879,0.11,3602.305 -0.29,,none,2e-4,16,0.0,25207.0,21433753600,7850588160,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8803257846832275,477.2486,1.676,0.21,3433.012 -,,none,2e-4,16,0.0,78871.0,0,0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,,,,, -0.15,,none,2e-4,16,0.0,61532.0,57531527168,47311452160,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8628986740112304,545.0419,0.734,0.183,1503.004 -,,none,2e-4,16,0.0,80991.0,0,0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,, -0.29,,none,2e-4,16,0.0,68811.0,64348470272,47311648768,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8795901584625244,919.9512,0.87,0.109,1780.964 +,,none,2e-5,,,81006.0,0,0,NousResearch/Llama-2-70b-hf,2,,4,,,float16,,,,, +0.15,,none,2e-4,16,0.0,28703.0,26108963328,15119590912,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8818108749389648,458.2667,0.873,0.218,3575.21 
+0.15,,none,2e-4,16,0.0,17669.0,15123161088,7850391552,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8540384006500245,270.1999,1.48,0.37,3031.829 +0.29,,none,2e-4,16,0.0,42167.0,37098695168,15119984128,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,1.0028394603729247,912.5081,0.877,0.11,3590.982 +0.29,,none,2e-4,16,0.0,25207.0,21433753600,7850588160,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8833828353881836,482.6901,1.657,0.207,3394.311 +,,none,2e-4,16,0.0,80990.0,0,0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,,,,, +0.15,,none,2e-4,16,0.0,61532.0,57546370048,47311452160,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8696129798889161,561.2483,0.713,0.178,1459.604 +,,none,2e-4,16,0.0,80207.0,0,0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,, +0.29,,none,2e-4,16,0.0,69171.0,64398757376,47311648768,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.885084867477417,938.9714,0.852,0.106,1744.888 ,,none,2e-4,16,0.0,80617.0,0,0,NousResearch/Llama-2-70b-hf,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,,,,, -,,none,2e-4,16,0.0,80760.0,0,0,NousResearch/Llama-2-70b-hf,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,,,,, +,,none,2e-4,16,0.0,80907.0,0,0,NousResearch/Llama-2-70b-hf,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,,,,, ,,none,2e-4,16,0.0,80617.0,0,0,NousResearch/Llama-2-70b-hf,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,, -,,none,2e-4,16,0.0,80987.0,0,0,NousResearch/Llama-2-70b-hf,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,,,,, -0.15,True,accelerated-peft-bnb-foak,2e-4,16,0.0,19257.0,13636909056,4843384320,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8704845142364502,417.5391,0.958,0.239,3923.944 -,True,accelerated-peft-bnb-foak,2e-4,16,0.0,5527.0,0,0,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,,,,, -0.29,True,accelerated-peft-bnb-foak,2e-4,16,0.0,32209.0,22430791680,4843777536,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8942180156707764,818.5228,0.977,0.122,4003.309 -,True,accelerated-peft-bnb-foak,2e-4,16,0.0,5675.0,0,0,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,,,,, -0.15,True,accelerated-peft-bnb-foak,2e-4,16,0.0,37301.0,35622334464,25201389056,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.887912654876709,861.4969,0.464,0.116,1901.806 -0.29,True,accelerated-peft-bnb-foak,2e-4,16,0.0,49955.0,46024318976,25201782272,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8887538051605225,1590.7501,0.503,0.063,2059.909 -0.14,True,accelerated-peft-bnb-foak,2e-4,16,0.0,71995.0,67350935552,37346815488,NousResearch/Llama-2-70b-hf,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,1.0002326488494873,3357.4377,0.119,0.03,487.991 +,,none,2e-4,16,0.0,80783.0,0,0,NousResearch/Llama-2-70b-hf,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,,,,, +0.15,True,baseline-peft-bnb,2e-4,16,0.0,25995.0,22825932800,5368221184,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8698946189880371,586.9178,0.682,0.17,2791.532 +0.15,True,baseline-peft-bnb,2e-4,16,0.0,12476.0,9974622720,2727018496,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8552890300750733,284.376,1.407,0.352,2880.693 
+0.29,True,baseline-peft-bnb,2e-4,16,0.0,46117.0,40278956032,5368614400,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8654958820343017,1148.1408,0.697,0.087,2854.005 +0.29,True,baseline-peft-bnb,2e-4,16,0.0,20405.0,16587205120,2727215104,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8869294357299805,503.0597,1.59,0.199,3256.87 +0.15,True,baseline-peft-bnb,2e-4,16,0.0,47189.0,46475660288,25726225920,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8893787956237793,1185.2488,0.337,0.084,1382.326 +0.15,True,baseline-peft-bnb,2e-4,16,0.0,24751.0,21932720128,13219233792,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8617707204818725,568.5808,0.704,0.176,1440.78 +0.29,True,baseline-peft-bnb,2e-4,16,0.0,68683.0,67165218816,25726619136,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8893123245239258,2124.0668,0.377,0.047,1542.701 +0.29,True,baseline-peft-bnb,2e-4,16,0.0,32064.0,29353074176,13219430400,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8585504531860352,962.8971,0.831,0.104,1701.532 +,True,baseline-peft-bnb,2e-4,16,0.0,80121.0,0,0,NousResearch/Llama-2-70b-hf,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,,,,, +0.14,True,baseline-peft-bnb,2e-4,16,0.0,51701.0,46524471808,19172741120,NousResearch/Llama-2-70b-hf,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9204118633270264,1981.2518,0.202,0.05,413.476 +,True,baseline-peft-bnb,2e-4,16,0.0,79555.0,0,0,NousResearch/Llama-2-70b-hf,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,, +0.28,True,baseline-peft-bnb,2e-4,16,0.0,80394.0,72398346752,19172937728,NousResearch/Llama-2-70b-hf,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9444941711425782,3760.1788,0.213,0.027,435.724 +0.15,True,accelerated-peft-bnb,2e-4,16,0.0,18903.0,15860019712,4843384320,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8704616069793701,479.6819,0.834,0.208,3415.597 +0.15,True,accelerated-peft-bnb,2e-4,16,0.0,12533.0,9974622720,2727018496,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8528211212158203,282.8845,1.414,0.354,2895.882 +0.29,True,accelerated-peft-bnb,2e-4,16,0.0,33327.0,26849751552,4843777536,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8675907611846924,945.5376,0.846,0.106,3465.542 +0.29,True,accelerated-peft-bnb,2e-4,16,0.0,20423.0,16587205120,2727215104,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.854712610244751,502.3584,1.592,0.199,3261.417 +0.15,True,accelerated-peft-bnb-foak,2e-4,16,0.0,19257.0,13636909056,4843384320,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8722561931610108,420.8819,0.95,0.238,3892.778 +0.15,True,accelerated-peft-bnb-foak,2e-4,16,0.0,12118.0,9796856320,2727018496,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8581914234161377,232.51,1.72,0.43,3523.289 +0.29,True,accelerated-peft-bnb-foak,2e-4,16,0.0,32209.0,22430791680,4843777536,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8683128643035889,821.991,0.973,0.122,3986.418 +0.29,True,accelerated-peft-bnb-foak,2e-4,16,0.0,19463.0,16207063552,2727215104,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.852388572692871,427.1268,1.873,0.234,3835.864 
+0.15,True,accelerated-peft-bnb,2e-4,16,0.0,37417.0,36218024448,25201389056,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8887558174133301,913.0381,0.438,0.11,1794.449 +0.15,True,accelerated-peft-bnb,2e-4,16,0.0,24952.0,21921468928,13219233792,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8612120914459228,572.3054,0.699,0.175,1431.404 +0.29,True,accelerated-peft-bnb,2e-4,16,0.0,49893.0,47207756288,25201782272,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8909227275848388,1711.7453,0.467,0.058,1914.303 +0.29,True,accelerated-peft-bnb,2e-4,16,0.0,32207.0,29359173632,13219430400,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8591176319122314,959.9538,0.833,0.104,1706.749 +0.15,True,accelerated-peft-bnb-foak,2e-4,16,0.0,37547.0,35651058176,25201389056,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8895366668701172,854.9879,0.468,0.117,1916.284 +0.15,True,accelerated-peft-bnb-foak,2e-4,16,0.0,24572.0,21746056192,13219233792,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8630767631530761,514.5553,0.777,0.194,1592.054 +0.29,True,accelerated-peft-bnb-foak,2e-4,16,0.0,49861.0,46058696192,25201782272,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8951810073852539,1601.6113,0.499,0.062,2045.94 +0.29,True,accelerated-peft-bnb-foak,2e-4,16,0.0,31701.0,29043888640,13219430400,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8600863265991211,880.114,0.909,0.114,1861.577 +0.14,True,accelerated-peft-bnb,2e-4,16,0.0,71801.0,68159977472,37346815488,NousResearch/Llama-2-70b-hf,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9996430969238281,3700.3604,0.108,0.027,442.768 +0.14,True,accelerated-peft-bnb,2e-4,16,0.0,51579.0,46524471808,19172741120,NousResearch/Llama-2-70b-hf,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9264963436126709,1955.4907,0.205,0.051,418.923 +,True,accelerated-peft-bnb,2e-4,16,0.0,79375.0,0,0,NousResearch/Llama-2-70b-hf,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,, +0.28,True,accelerated-peft-bnb,2e-4,16,0.0,80815.0,72398346752,19172937728,NousResearch/Llama-2-70b-hf,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9262647342681884,3714.7153,0.215,0.027,441.057 +0.14,True,accelerated-peft-bnb-foak,2e-4,16,0.0,71995.0,67350935552,37346815488,NousResearch/Llama-2-70b-hf,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9998687934875489,3351.04,0.119,0.03,488.923 +0.14,True,accelerated-peft-bnb-foak,2e-4,16,0.0,51141.0,46250760704,19172741120,NousResearch/Llama-2-70b-hf,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9389877033233642,1747.6289,0.229,0.057,468.749 ,True,accelerated-peft-bnb-foak,2e-4,16,0.0,80303.0,0,0,NousResearch/Llama-2-70b-hf,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,, -,True,accelerated-peft-bnb-foak,2e-4,16,0.0,21095.0,0,0,NousResearch/Llama-2-70b-hf,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,,,,, +0.28,True,accelerated-peft-bnb-foak,2e-4,16,0.0,79861.0,71720933888,19172937728,NousResearch/Llama-2-70b-hf,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9403298473358155,3375.4111,0.237,0.03,485.393 +0.15,True,accelerated-peft-autogptq,2e-4,16,0.0,19425.0,15890329088,4873693696,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,1.009563512802124,491.6352,0.814,0.203,3332.552 
+0.15,True,accelerated-peft-autogptq,2e-4,16,0.0,12230.0,9690031616,2743565312,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9266629409790039,294.4237,1.359,0.34,2782.385 +0.29,True,accelerated-peft-autogptq,2e-4,16,0.0,33219.0,26880060928,4874086912,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.9904310989379883,953.3973,0.839,0.105,3436.972 +0.29,True,accelerated-peft-autogptq,2e-4,16,0.0,19477.0,16000624128,2743761920,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8998308277130127,506.1818,1.58,0.198,3236.781 +0.15,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,19065.0,13631990784,4873693696,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,1.003525791168213,414.297,0.965,0.241,3954.651 +0.15,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,11879.0,9512265216,2743565312,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9293491744995117,224.6767,1.78,0.445,3646.128 +0.29,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,32721.0,22390647808,4874086912,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.992929859161377,810.9726,0.986,0.123,4040.581 +0.29,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,19063.0,15620482560,2743761920,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9045120429992676,418.8226,1.91,0.239,3911.919 +0.15,True,accelerated-peft-autogptq,2e-4,16,0.0,36389.0,35528093184,24511457792,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.89991379737854,897.8879,0.445,0.111,1824.727 +0.15,True,accelerated-peft-autogptq,2e-4,16,0.0,22882.0,20691720192,12526730240,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8638970375061035,557.2929,0.718,0.179,1469.963 +0.29,True,accelerated-peft-autogptq,2e-4,16,0.0,48959.0,46517825024,24511851008,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.893577823638916,1673.2594,0.478,0.06,1958.334 +0.29,True,accelerated-peft-autogptq,2e-4,16,0.0,29704.0,27482931712,12526926848,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.864154224395752,938.3626,0.853,0.107,1746.02 +0.15,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,36607.0,33649802752,24511457792,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8993340969085694,811.6061,0.493,0.123,2018.713 +0.15,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,22801.0,20438869504,12526730240,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8660580062866211,478.0288,0.837,0.209,1713.704 +0.29,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,49669.0,42707730944,24511851008,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8937735366821289,1533.2657,0.522,0.065,2137.138 +0.29,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,29370.0,26951336960,12526926848,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8651807403564453,838.8338,0.954,0.119,1953.188 +0.14,True,accelerated-peft-autogptq,2e-4,16,0.0,71177.0,65895347200,36290144768,TheBloke/Llama-2-70B-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9811842250823974,3639.6437,0.11,0.027,450.154 
+0.14,True,accelerated-peft-autogptq,2e-4,16,0.0,49475.0,44873390592,18125597696,TheBloke/Llama-2-70B-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9557892894744873,1923.445,0.208,0.052,425.902 +,True,accelerated-peft-autogptq,2e-4,16,0.0,79265.0,0,0,TheBloke/Llama-2-70B-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,, +0.28,True,accelerated-peft-autogptq,2e-4,16,0.0,79187.0,70143285760,18125794304,TheBloke/Llama-2-70B-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9580207633972168,3685.3642,0.217,0.027,444.569 +0.14,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,71223.0,65086305280,36290144768,TheBloke/Llama-2-70B-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.981500825881958,3273.1958,0.122,0.031,500.551 +0.14,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,49187.0,44599679488,18125597696,TheBloke/Llama-2-70B-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9558010864257812,1682.0158,0.238,0.059,487.035 +,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,80945.0,0,0,TheBloke/Llama-2-70B-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,, +0.28,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,78208.0,69465872896,18125794304,TheBloke/Llama-2-70B-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9556115436553955,3298.135,0.243,0.03,496.766 From bfde526a7c8e55def4c054c9e84db76f9ebf1bec Mon Sep 17 00:00:00 2001 From: achew010 <165894159+achew010@users.noreply.github.com> Date: Fri, 7 Jun 2024 15:23:35 +0800 Subject: [PATCH 8/8] Shift GPU Memory Computation to End of Benchmarking Script (#30) * shift gpu mem computation to gather_report * addressed comments --- scripts/benchmarks/benchmark.py | 125 ++++++++++++++++++-------------- 1 file changed, 69 insertions(+), 56 deletions(-) diff --git a/scripts/benchmarks/benchmark.py b/scripts/benchmarks/benchmark.py index f5ff4a54..ec601c43 100644 --- a/scripts/benchmarks/benchmark.py +++ b/scripts/benchmarks/benchmark.py @@ -493,38 +493,6 @@ def maybe_get_experiment_error_traceback(self): return None if len(results) == 0 else results - def get_peak_mem_usage_by_device_id(self): - """ - This function retrieves the raw measurements of reserved GPU memory per device across the experiment - - computing the peak value for each gpu and then performing a simple calibration (subtracts peak values by the first reading). - Returns: - - pd.Series of peak memory usage per device id - - the device name as string - e.g. 
"NVIDIA A100-SXM4-80GB" - - Example: For 2 devices with GPU Indices 0,1 - it will return the max measurement value (in MiB) of each device as a Series: - - - pd.Series - index - 0 52729.0 - 1 52783.0 - Name: memory.used [MiB], dtype: float64 - """ - - # group the gpu readings into device ids - gpu_logs = pd.read_csv(self.gpu_log_filename, skipinitialspace=True) - # assume that all the devices have the same device name - device_name = gpu_logs.name.iloc[-1] - # extract and convert the gpu memory usage as float values - gpu_logs[GPU_LOG_USED_MEM_COLUMN_NAME] = gpu_logs[ - GPU_LOG_USED_MEM_COLUMN_NAME - ].apply(lambda x: float(x.replace(GPU_LOG_METRIC_SUFFIX, ""))) - mem_usage_by_device_id = gpu_logs.groupby("index")[GPU_LOG_USED_MEM_COLUMN_NAME] - # Calibrate values by subtracting out the initial values of the GPU readings - # to ensure no existing memory is counted in addition with the experiment - initial_values = mem_usage_by_device_id.first() - peak_values = mem_usage_by_device_id.max() - return peak_values.sub(initial_values), device_name - def write_result(self): "Function to write a json result file" @@ -532,30 +500,6 @@ def write_result(self): save_result = ConfigUtils.convert_args_to_dict(self.experiment_args_str) save_result["num_gpus"] = self.num_gpus - # if a gpu log file exist, process the raw nvidia logs and write to result - if os.path.isfile(self.gpu_log_filename): - # Add GPU info and measurements into the result saving - peak_mem_usage_by_device_id, device_name = ( - self.get_peak_mem_usage_by_device_id() - ) - save_result[RESULT_FIELD_DEVICE_NAME] = device_name - # Memory usage is averaged across all devices in the final result - save_result[RESULT_FIELD_RESERVED_GPU_MEM] = ( - peak_mem_usage_by_device_id.mean() - ) - - # process gpu mem from output metrics and write to result - # check if HF_ARG_SKIP_MEMORY_METRIC is set to False in experiment arg - # this arg is specified explicitly inside `def generate_list_of_experiments`` - argument_idx = self.experiment_arg.index(HF_ARG_SKIP_MEMORY_METRIC) - write_memory_metric = not self.experiment_arg[argument_idx + 1] - if write_memory_metric: - peak_gpu_mem, gpu_allocated_mem = extract_gpu_memory_metrics( - self.get_experiment_final_metrics() - ) - save_result[RESULT_FIELD_PEAK_ALLOCATED_GPU_MEM] = peak_gpu_mem - save_result[RESULT_FIELD_ALLOCATED_GPU_MEM] = gpu_allocated_mem - # if there is an error we save the error message else we save the final result maybe_error_messages = self.maybe_get_experiment_error_traceback() if maybe_error_messages is None: @@ -613,6 +557,37 @@ def maybe_get_experiment_error_traceback(self): return None +def get_peak_mem_usage_by_device_id(gpu_logs: pd.DataFrame): + """ + This function retrieves the raw measurements of reserved GPU memory per device across the experiment - + computing the peak value for each gpu and then performing a simple calibration (subtracts peak values by the first reading). + Returns: + - pd.Series of peak memory usage per device id + - the device name as string - e.g. 
"NVIDIA A100-SXM4-80GB" + + Example: For 2 devices with GPU Indices 0,1 - it will return the max measurement value (in MiB) of each device as a Series: + + - pd.Series + index + 0 52729.0 + 1 52783.0 + Name: memory.used [MiB], dtype: float64 + """ + + # assume that all the devices have the same device name + device_name = gpu_logs.name.iloc[-1] + # extract and convert the gpu memory usage as float values + gpu_logs[GPU_LOG_USED_MEM_COLUMN_NAME] = gpu_logs[ + GPU_LOG_USED_MEM_COLUMN_NAME + ].apply(lambda x: float(x.replace(GPU_LOG_METRIC_SUFFIX, ""))) + mem_usage_by_device_id = gpu_logs.groupby("index")[GPU_LOG_USED_MEM_COLUMN_NAME] + # Calibrate values by subtracting out the initial values of the GPU readings + # to ensure no existing memory is counted in addition with the experiment + initial_values = mem_usage_by_device_id.first() + peak_values = mem_usage_by_device_id.max() + return peak_values.sub(initial_values), device_name + + def prepare_arguments(args): defaults = ConfigUtils.read_yaml(args.defaults_config_path) defaults["training_data_path"] = args.dataset_save_path @@ -712,6 +687,8 @@ def _gather(rdir): x for x in os.listdir(rdir) if x.startswith(DIR_PREFIX_EXPERIMENT) ] for tag in exper_dirs: + gpu_log_filename = os.path.join(rdir, tag, FILE_MEM) + try: with open(os.path.join(rdir, tag, FILE_RESULTS)) as f: tag = tag.replace(DIR_PREFIX_EXPERIMENT + "_", "") @@ -719,6 +696,42 @@ def _gather(rdir): experiment_stats[tag] = json.load(f) except FileNotFoundError: pass + + if script_args["log_nvidia_smi"]: + gpu_logs = pd.read_csv(gpu_log_filename, skipinitialspace=True) + peak_nvidia_mem_by_device_id, device_name = ( + get_peak_mem_usage_by_device_id(gpu_logs) + ) + experiment_stats[tag].update( + { + # Report the mean peak memory across all gpu device ids + RESULT_FIELD_RESERVED_GPU_MEM: peak_nvidia_mem_by_device_id.mean(), + RESULT_FIELD_DEVICE_NAME: device_name, + } + ) + + if script_args["log_memory_hf"] and tag in experiment_stats.keys(): + memory_metrics_prefixes = [ + HF_TRAINER_LOG_GPU_STAGE_BEFORE_INIT, + HF_TRAINER_LOG_GPU_STAGE_INIT, + HF_TRAINER_LOG_GPU_STAGE_TRAIN, + ] + memory_metrics = { + k: v + for k, v in experiment_stats[tag].items() + if any([prefix in k for prefix in memory_metrics_prefixes]) + } + if len(memory_metrics.keys()) > 0: + peak_torch_gpu_mem, torch_gpu_mem = extract_gpu_memory_metrics( + memory_metrics + ) + experiment_stats[tag].update( + { + RESULT_FIELD_PEAK_ALLOCATED_GPU_MEM: peak_torch_gpu_mem, + RESULT_FIELD_ALLOCATED_GPU_MEM: torch_gpu_mem, + } + ) + df = pd.DataFrame.from_dict(experiment_stats, orient="index").sort_index() try: df["framework_config"] = df["acceleration_framework_config_file"].map(