Skip to content

Commit

Permalink
minor fixes to foak full
Browse files Browse the repository at this point in the history
Signed-off-by: 1000850000 user <aaron.chew1@ibm.com>
  • Loading branch information
achew010 committed Sep 6, 2024
1 parent 82e1873 commit bea4dda
Show file tree
Hide file tree
Showing 6 changed files with 200 additions and 14 deletions.
2 changes: 1 addition & 1 deletion plugins/framework/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ classifiers=[
dependencies = [
"numpy<2.0", # numpy needs to be bounded due to incompatiblity with current torch<2.3
"torch>2.2",
"transformers",
"git+https://github.com/huggingface/transformers.git@9230d78e76611cfa38c845213021aeb185362d10",
"peft",
"accelerate",
"pandas",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,15 @@ def register_foak_model_patch_rules2(base_type: str, filter_endswith: Set[str] =

# Local
from .models import ( # pylint: disable=import-outside-toplevel
gpt_bigcode,
llama,
mistral,
mixtral,
granite,
)
rules = [
*gpt_bigcode.get_mp_rules(base_type),
*granite.get_mp_rules(base_type),
*llama.get_mp_rules(base_type),
*mistral.get_mp_rules(base_type),
*mixtral.get_mp_rules(base_type),
Expand All @@ -55,6 +59,7 @@ def register_foak_model_patch_rules2(base_type: str, filter_endswith: Set[str] =

# maybe this we should define envvars
FILTER_MAP = {
"base_layer": set(),
"fused_lora": {"qkvo", "mlp"},
"fast_loss": "cross-ent",
"fast_rsm_layernorm": "rms",
Expand All @@ -65,6 +70,8 @@ class FastKernelsAccelerationPlugin(AccelerationPlugin):

# NOTE: may remove this when we have generic model rules
restricted_model_archs = [
"GraniteForCausalLM",
"GPTBigCodeForCausalLM",
"MixtralForCausalLM",
"LlamaForCausalLM",
"MistralForCausalLM",
Expand Down Expand Up @@ -112,7 +119,10 @@ def augmentation(
train_args: TrainingArguments,
modifiable_args: Tuple[LoraConfig],
):

# This is designed to be a passthrough if training scenario is
# full finetuning or standard peft fused-lora rules (only meant for qpeft)
# will still be installed but never triggered
# if no peft layer is detected at the point of patching
terms = set()
for k, v in self.configurations.items():
if v:
Expand All @@ -124,8 +134,10 @@ def augmentation(
# wrapper function to register foak patches
# NOTE: we never take the lora modules so just set arbitrarily
# to "auto_gptq"
_base_layer = self.configurations['base_layer'] if 'base_layer' \
in self.configurations else 'auto_gptq'
register_foak_model_patch_rules2(
base_type="auto_gptq", filter_endswith=terms
base_type=_base_layer, filter_endswith=terms
)
return model, modifiable_args

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def get_callbacks_and_ready_for_train(


# register
AccelerationPlugin.register_plugin(
FastQuantizedPeftAccelerationPlugin,
configuration_and_paths=["peft.quantization.fused_ops_and_kernels"],
)
# AccelerationPlugin.register_plugin(
# FastQuantizedPeftAccelerationPlugin,
# configuration_and_paths=["peft.quantization.fused_ops_and_kernels"],
# )
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from fms_acceleration.model_patcher import (
ModelPatcherRule,
)

# Local
from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss

def get_mp_rules(base_type: str):
"""
Function to access all patch rules in this module.
If it is a forward_builder rule with `base_type` in
its forward builder argument, wrap the forward_builder
function as a partial function with the base_type argument
"""
return [
# TODO: have a generic version of this rule
# - get the module_name and reload on that
ModelPatcherRule(
rule_id="gpt-bigcode-cross-ent",
import_and_maybe_reload=(
"torch.nn.CrossEntropyLoss",
FastCrossEntropyLoss,
"transformers.models.gpt_bigcode.modeling_gpt_bigcode",
),
),
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Standard
from functools import partial

# Third Party
from fms_acceleration.model_patcher import (
ModelPatcherRule,
ModelPatcherTrigger,
combine_functions,
combine_triggers,
)
from transformers.models.granite.modeling_granite import (
GraniteAttention,
GraniteMLP,
GraniteRMSNorm,
)

# Local
from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
from ..kernels.unsloth.rope_embedding import fast_rope_embedding
from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops


def get_mp_rules(base_type: str):
"""
Function to access all patch rules in this module.
If it is a forward_builder rule with `base_type` in
its forward builder argument, wrap the forward_builder
function as a partial function with the base_type argument
"""
return [
# TODO: have a generic version of this rule
# - do regex on RMSNorm class name
# - check on the tensors required for fast_rms_layernorm
ModelPatcherRule(
rule_id="granite-rms",
trigger=ModelPatcherTrigger(check=GraniteRMSNorm),
forward=fast_rms_layernorm,
),
# TODO: have a generic version of this rule
# - do regex on Attention class name
# - have a set of qkv / o module names and check on that
ModelPatcherRule(
rule_id="granite-qkvo",
trigger=combine_triggers(
ModelPatcherTrigger(
check=partial(
trigger_fused_ops,
attn_cls=GraniteAttention,
submodule_names=["q_proj", "k_proj", "v_proj"],
)
),
ModelPatcherTrigger(
check=partial(
trigger_fused_ops,
attn_cls=GraniteAttention,
submodule_names=["o_proj"],
)
),
logic="OR",
),
forward_builder=combine_functions(
partial(
build_lora_fused_ops,
submodule_names=["q_proj", "k_proj", "v_proj"],
fused_op=KEY_QKV,
base_type=base_type,
),
partial(
build_lora_fused_ops,
submodule_names=["o_proj"],
fused_op=KEY_O,
base_type=base_type,
),
logic="APPEND",
),
),
ModelPatcherRule(
rule_id="granite-mlp",
trigger=ModelPatcherTrigger(
check=partial(
trigger_fused_ops,
attn_cls=GraniteMLP,
submodule_names=["up_proj", "down_proj", "gate_proj"],
)
),
forward_builder=partial(
build_lora_fused_ops,
submodule_names=["up_proj", "down_proj", "gate_proj"],
fused_op=KEY_MLP,
base_type=base_type,
),
),
# TODO: have a generic version of this rule
# - get the module_name and reload on that
ModelPatcherRule(
rule_id="granite-cross-ent",
import_and_maybe_reload=(
"torch.nn.CrossEntropyLoss",
FastCrossEntropyLoss,
"transformers.models.granite.modeling_granite",
),
),
# TODO: have a generic version of this rule
# - get the module name
# - check if "apply_rotary_pos_emb" exists
# - patch
ModelPatcherRule(
rule_id="granite-rope",
import_and_maybe_reload=(
"transformers.models.granite.modeling_granite.apply_rotary_pos_emb",
fast_rope_embedding,
None,
),
),
]
18 changes: 11 additions & 7 deletions scripts/benchmarks/scenarios.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,23 +38,25 @@
scenarios:
- name: full-finetuning
framework_config:
- null
-
- foak-fast-kernels
arguments:
learning_rate: 2e-5
model_name_or_path:
- 'bigcode/gpt_bigcode-santacoder'
- 'ibm/PowerLM-3b'
- 'mistralai/Mistral-7B-v0.1'
- 'mistralai/Mixtral-8x7B-Instruct-v0.1'
- 'NousResearch/Llama-2-70b-hf'
torch_dtype: float16
torch_dtype: bfloat16

- name: standard-peft
framework_config:
- null
-
- foak-fast-kernels
arguments:
learning_rate: 2e-4
torch_dtype: float16
torch_dtype: bfloat16
peft_method: lora
r: 16
lora_alpha: 16
Expand All @@ -71,7 +73,7 @@ scenarios:
arguments:
fp16: True
learning_rate: 2e-4
torch_dtype: float16
torch_dtype: bfloat16
peft_method: lora
r: 16
lora_alpha: 16
Expand All @@ -89,13 +91,15 @@ scenarios:
arguments:
fp16: True
learning_rate: 2e-4
torch_dtype: float16
torch_dtype: bfloat16
peft_method: lora
r: 16
lora_alpha: 16
lora_dropout: 0.1
target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
model_name_or_path:
- 'ibm/PowerLM-3b'
- 'bigcode/gpt_bigcode-santacoder'
- 'mistralai/Mistral-7B-v0.1'
- 'mistralai/Mixtral-8x7B-Instruct-v0.1'
- 'NousResearch/Llama-2-70b-hf'
Expand All @@ -107,7 +111,7 @@ scenarios:
arguments:
learning_rate: 2e-4
fp16: True
torch_dtype: float16
torch_dtype: bfloat16
peft_method: lora
r: 16
lora_alpha: 16
Expand Down

0 comments on commit bea4dda

Please sign in to comment.