diff --git a/plugins/framework/pyproject.toml b/plugins/framework/pyproject.toml
index 6d71bca3..0eb0c7eb 100644
--- a/plugins/framework/pyproject.toml
+++ b/plugins/framework/pyproject.toml
@@ -24,7 +24,7 @@ classifiers=[
 dependencies = [
     "numpy<2.0", # numpy needs to be bounded due to incompatibility with current torch<2.3
     "torch>2.2",
-    "transformers",
+    "transformers @ git+https://github.com/huggingface/transformers.git@9230d78e76611cfa38c845213021aeb185362d10",
     "peft",
     "accelerate",
     "pandas",
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py
index 7fe5a898..17305bec 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py
@@ -33,11 +33,15 @@ def register_foak_model_patch_rules2(base_type: str, filter_endswith: Set[str] =
     # Local
     from .models import (  # pylint: disable=import-outside-toplevel
+        gpt_bigcode,
         llama,
         mistral,
         mixtral,
+        granite,
     )

     rules = [
+        *gpt_bigcode.get_mp_rules(base_type),
+        *granite.get_mp_rules(base_type),
         *llama.get_mp_rules(base_type),
         *mistral.get_mp_rules(base_type),
         *mixtral.get_mp_rules(base_type),
@@ -55,6 +59,7 @@ def register_foak_model_patch_rules2(base_type: str, filter_endswith: Set[str] =

 # maybe we should define envvars for this
 FILTER_MAP = {
+    "base_layer": set(),
     "fused_lora": {"qkvo", "mlp"},
     "fast_loss": "cross-ent",
     "fast_rsm_layernorm": "rms",
@@ -65,6 +70,8 @@ class FastKernelsAccelerationPlugin(AccelerationPlugin):

     # NOTE: may remove this when we have generic model rules
     restricted_model_archs = [
+        "GraniteForCausalLM",
+        "GPTBigCodeForCausalLM",
         "MixtralForCausalLM",
         "LlamaForCausalLM",
         "MistralForCausalLM",
@@ -112,7 +119,10 @@ def augmentation(
         train_args: TrainingArguments,
         modifiable_args: Tuple[LoraConfig],
     ):
-
+        # This is designed to be a passthrough if the training scenario is
+        # full finetuning or standard peft; fused-lora rules (only meant for
+        # qpeft) will still be installed, but they are never triggered
+        # if no peft layer is detected at the point of patching.
         terms = set()
         for k, v in self.configurations.items():
             if v:
@@ -124,8 +134,10 @@ def augmentation(
             # wrapper function to register foak patches
             # NOTE: we never take the lora modules so just set arbitrarily
             # to "auto_gptq"
+            _base_layer = self.configurations['base_layer'] if 'base_layer' \
+                in self.configurations else 'auto_gptq'
             register_foak_model_patch_rules2(
-                base_type="auto_gptq", filter_endswith=terms
+                base_type=_base_layer, filter_endswith=terms
             )
         return model, modifiable_args
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py
index ff67229c..d13ae255 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py
@@ -162,7 +162,7 @@ def get_callbacks_and_ready_for_train(

 # register
-AccelerationPlugin.register_plugin(
-    FastQuantizedPeftAccelerationPlugin,
-    configuration_and_paths=["peft.quantization.fused_ops_and_kernels"],
-)
+# AccelerationPlugin.register_plugin(
+#     FastQuantizedPeftAccelerationPlugin,
+#     configuration_and_paths=["peft.quantization.fused_ops_and_kernels"],
+# )
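To make the FILTER_MAP plumbing in framework_plugin_fast_kernels.py above concrete, the following sketch walks through how `augmentation` turns a plugin configuration into filter terms and a base type. It is an illustrative reconstruction, not the plugin code itself: the `configurations` dict, the "bitsandbytes" value, and the set-normalization step are assumptions for the example.

# Illustrative sketch (hypothetical configuration) of the term-collection
# loop in `augmentation` and the `_base_layer` fallback.
FILTER_MAP = {
    "base_layer": set(),
    "fused_lora": {"qkvo", "mlp"},
    "fast_loss": "cross-ent",
    "fast_rsm_layernorm": "rms",
}

configurations = {"base_layer": "bitsandbytes", "fused_lora": True, "fast_loss": True}

terms = set()
for k, v in configurations.items():
    if v and k in FILTER_MAP:
        entry = FILTER_MAP[k]
        terms.update(entry if isinstance(entry, set) else {entry})
# terms == {"qkvo", "mlp", "cross-ent"}; "base_layer" maps to an empty set,
# so it selects the base type without contributing a filter term of its own.

base_type = configurations.get("base_layer", "auto_gptq")  # falls back when unset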
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/gpt_bigcode.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/gpt_bigcode.py
new file mode 100644
index 00000000..9eb2cf64
--- /dev/null
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/gpt_bigcode.py
@@ -0,0 +1,40 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from fms_acceleration.model_patcher import (
+    ModelPatcherRule,
+)
+
+# Local
+from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
+
+def get_mp_rules(base_type: str):
+    """
+    Function to access all patch rules in this module.
+    If it is a forward_builder rule with `base_type` in
+    its forward builder argument, wrap the forward_builder
+    function as a partial function with the base_type argument
+    """
+    return [
+        # TODO: have a generic version of this rule
+        # - get the module_name and reload on that
+        ModelPatcherRule(
+            rule_id="gpt-bigcode-cross-ent",
+            import_and_maybe_reload=(
+                "torch.nn.CrossEntropyLoss",
+                FastCrossEntropyLoss,
+                "transformers.models.gpt_bigcode.modeling_gpt_bigcode",
+            ),
+        ),
+    ]
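The `import_and_maybe_reload` triple in the rule above reads as (path of the object to replace, replacement object, module to reload after patching). A rough sketch of the resulting effect using plain importlib follows; this is an assumed reading of the triple, not the ModelPatcher implementation, which lives in fms_acceleration.model_patcher.

# Rough sketch of the patch effect, assuming standard importlib behaviour.
import importlib

import torch

from fms_acceleration_foak.kernels.unsloth.cross_entropy_loss import (
    FastCrossEntropyLoss,
)

# Rebind the loss class at its source ...
torch.nn.CrossEntropyLoss = FastCrossEntropyLoss
# ... then reload the modeling module so that any name it bound at import
# time now resolves to the fast kernel.
import transformers.models.gpt_bigcode.modeling_gpt_bigcode as mod
importlib.reload(mod)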
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/granite.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/granite.py
new file mode 100644
index 00000000..778b6211
--- /dev/null
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/granite.py
@@ -0,0 +1,130 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Standard
+from functools import partial
+
+# Third Party
+from fms_acceleration.model_patcher import (
+    ModelPatcherRule,
+    ModelPatcherTrigger,
+    combine_functions,
+    combine_triggers,
+)
+from transformers.models.granite.modeling_granite import (
+    GraniteAttention,
+    GraniteMLP,
+    GraniteRMSNorm,
+)
+
+# Local
+from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
+from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
+from ..kernels.unsloth.rope_embedding import fast_rope_embedding
+from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops
+
+
+def get_mp_rules(base_type: str):
+    """
+    Function to access all patch rules in this module.
+    If it is a forward_builder rule with `base_type` in
+    its forward builder argument, wrap the forward_builder
+    function as a partial function with the base_type argument
+    """
+    return [
+        # TODO: have a generic version of this rule
+        # - do regex on RMSNorm class name
+        # - check on the tensors required for fast_rms_layernorm
+        ModelPatcherRule(
+            rule_id="granite-rms",
+            trigger=ModelPatcherTrigger(check=GraniteRMSNorm),
+            forward=fast_rms_layernorm,
+        ),
+        # TODO: have a generic version of this rule
+        # - do regex on Attention class name
+        # - have a set of qkv / o module names and check on that
+        ModelPatcherRule(
+            rule_id="granite-qkvo",
+            trigger=combine_triggers(
+                ModelPatcherTrigger(
+                    check=partial(
+                        trigger_fused_ops,
+                        attn_cls=GraniteAttention,
+                        submodule_names=["q_proj", "k_proj", "v_proj"],
+                    )
+                ),
+                ModelPatcherTrigger(
+                    check=partial(
+                        trigger_fused_ops,
+                        attn_cls=GraniteAttention,
+                        submodule_names=["o_proj"],
+                    )
+                ),
+                logic="OR",
+            ),
+            forward_builder=combine_functions(
+                partial(
+                    build_lora_fused_ops,
+                    submodule_names=["q_proj", "k_proj", "v_proj"],
+                    fused_op=KEY_QKV,
+                    base_type=base_type,
+                ),
+                partial(
+                    build_lora_fused_ops,
+                    submodule_names=["o_proj"],
+                    fused_op=KEY_O,
+                    base_type=base_type,
+                ),
+                logic="APPEND",
+            ),
+        ),
+        ModelPatcherRule(
+            rule_id="granite-mlp",
+            trigger=ModelPatcherTrigger(
+                check=partial(
+                    trigger_fused_ops,
+                    attn_cls=GraniteMLP,
+                    submodule_names=["up_proj", "down_proj", "gate_proj"],
+                )
+            ),
+            forward_builder=partial(
+                build_lora_fused_ops,
+                submodule_names=["up_proj", "down_proj", "gate_proj"],
+                fused_op=KEY_MLP,
+                base_type=base_type,
+            ),
+        ),
+        # TODO: have a generic version of this rule
+        # - get the module_name and reload on that
+        ModelPatcherRule(
+            rule_id="granite-cross-ent",
+            import_and_maybe_reload=(
+                "torch.nn.CrossEntropyLoss",
+                FastCrossEntropyLoss,
+                "transformers.models.granite.modeling_granite",
+            ),
+        ),
+        # TODO: have a generic version of this rule
+        # - get the module name
+        # - check if "apply_rotary_pos_emb" exists
+        # - patch
+        ModelPatcherRule(
+            rule_id="granite-rope",
+            import_and_maybe_reload=(
+                "transformers.models.granite.modeling_granite.apply_rotary_pos_emb",
+                fast_rope_embedding,
+                None,
+            ),
+        ),
+    ]
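The granite-qkvo rule above composes two attention triggers with logic="OR" and two LoRA fused-op builders with logic="APPEND". In rough terms, and assuming the obvious reading of the two combinators (their real implementations are in fms_acceleration.model_patcher), the composition behaves like:

# Assumed semantics of the combinators used by granite-qkvo; a sketch only.
def or_trigger(check_a, check_b):
    # logic="OR": the rule fires if either constituent check matches the module
    return lambda module: check_a(module) or check_b(module)

def append_builder(build_a, build_b):
    # logic="APPEND": run both builders and concatenate their patch outputs
    return lambda module, **kwargs: build_a(module, **kwargs) + build_b(module, **kwargs)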
diff --git a/scripts/benchmarks/scenarios.yaml b/scripts/benchmarks/scenarios.yaml
index bd020400..89502227 100644
--- a/scripts/benchmarks/scenarios.yaml
+++ b/scripts/benchmarks/scenarios.yaml
@@ -38,23 +38,25 @@ scenarios:
   - name: full-finetuning
     framework_config:
-      - null
+      -
+      - foak-fast-kernels
     arguments:
       learning_rate: 2e-5
       model_name_or_path:
+        - 'bigcode/gpt_bigcode-santacoder'
+        - 'ibm/PowerLM-3b'
         - 'mistralai/Mistral-7B-v0.1'
         - 'mistralai/Mixtral-8x7B-Instruct-v0.1'
         - 'NousResearch/Llama-2-70b-hf'
-      torch_dtype: float16
+      torch_dtype: bfloat16

   - name: standard-peft
     framework_config:
-      - null
+      -
+      - foak-fast-kernels
     arguments:
       learning_rate: 2e-4
-      torch_dtype: float16
+      torch_dtype: bfloat16
       peft_method: lora
       r: 16
       lora_alpha: 16
@@ -71,7 +73,7 @@ scenarios:
     arguments:
       fp16: True
       learning_rate: 2e-4
-      torch_dtype: float16
+      torch_dtype: bfloat16
       peft_method: lora
       r: 16
       lora_alpha: 16
@@ -89,13 +91,15 @@ scenarios:
     arguments:
       fp16: True
       learning_rate: 2e-4
-      torch_dtype: float16
+      torch_dtype: bfloat16
       peft_method: lora
       r: 16
       lora_alpha: 16
       lora_dropout: 0.1
       target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
       model_name_or_path:
+        - 'ibm/PowerLM-3b'
+        - 'bigcode/gpt_bigcode-santacoder'
         - 'mistralai/Mistral-7B-v0.1'
         - 'mistralai/Mixtral-8x7B-Instruct-v0.1'
         - 'NousResearch/Llama-2-70b-hf'
@@ -107,7 +111,7 @@ scenarios:
     arguments:
       learning_rate: 2e-4
       fp16: True
-      torch_dtype: float16
+      torch_dtype: bfloat16
       peft_method: lora
       r: 16
       lora_alpha: 16
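A note on the float16 to bfloat16 switch across the scenarios: bfloat16 keeps float32's 8-bit exponent range at the cost of mantissa precision, which makes long training runs far less prone to overflow than float16. A quick way to see the difference:

# bfloat16 retains float32's exponent range, so values that overflow
# float16 (max 65504) remain representable.
import torch

print(torch.finfo(torch.float16).max)   # 65504.0
print(torch.finfo(torch.bfloat16).max)  # ~3.39e+38
print(torch.finfo(torch.float32).max)   # ~3.40e+38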