
Commit

linting fixes
Signed-off-by: 1000960000 user <aaron.chew1@ibm.com>
achew010 committed Aug 1, 2024
1 parent 840b314 commit bff3128
Showing 5 changed files with 76 additions and 54 deletions.
2 changes: 1 addition & 1 deletion plugins/framework/src/fms_acceleration/framework.py
@@ -194,7 +194,7 @@ def augmentation(
modifiable_args: Tuple,
):
# get the config
- archs = model.config.architectures
+ archs = model.config.architectures
model_archs = set(archs if archs is not None else [])

# NOTE: this assumes that augmentation order does not matter
2 changes: 1 addition & 1 deletion plugins/instruct-lab/src/fms_acceleration_ilab/__init__.py
@@ -13,4 +13,4 @@
# limitations under the License.

# Local
- from .framework_plugin_padding_free import PaddingFreeAccelerationPlugin
+ from .framework_plugin_padding_free import PaddingFreeAccelerationPlugin
27 changes: 22 additions & 5 deletions plugins/instruct-lab/src/fms_acceleration_ilab/flash_attn.py
@@ -17,7 +17,7 @@
from types import MethodType

if is_flash_attn_2_available():
- from flash_attn import flash_attn_varlen_func
+ from flash_attn import flash_attn_varlen_func  # pylint: disable=import-error

def prepare_fa2_from_position_ids(query, key, value, position_ids, query_length):
query = query.view(-1, query.size(-2), query.size(-1))
@@ -51,18 +51,35 @@ def forward(self, *args, **kwargs):
return out, *others

def _flash_attention_forward(
- self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None, **kwargs,
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ query_length,
+ dropout=0.0,
+ softmax_scale=None,
+ **kwargs,
):
# if not self._flash_attn_uses_top_left_mask:
# causal = self.is_causal
# else:
- # # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
+ # # TODO: Remove the `query_length != 1`
+ # # check once Flash Attention for RoCm is bumped to 2.1.
+ # # For details, please see the comment in LlamaFlashAttention2 __init__.
# causal = self.is_causal and query_length != 1

assert attention_mask is None, "should not be using attention mask"
assert position_ids is not None, "should be expecting position ids"
batch_size = query_states.size(0)
- query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
+ (
+ query_states,
+ key_states,
+ value_states,
+ _,
+ cu_seq_lens,
+ max_seq_lens,
+ ) = prepare_fa2_from_position_ids(
query_states, key_states, value_states, position_ids, query_length
)

@@ -83,7 +100,7 @@ def _flash_attention_forward(
)

return attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))

# do this replace
attention._flash_attention_forward = MethodType(_flash_attention_forward, attention)

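The hunk above patches _flash_attention_forward to call flash_attn_varlen_func, with prepare_fa2_from_position_ids turning flattened position ids into the cumulative sequence lengths (cu_seq_lens) and maximum length that delimit each packed example. A minimal sketch of that derivation, assuming position ids restart at 0 for every packed sequence; the helper name below is illustrative and is not code from this repository:

# Illustrative sketch: derive cu_seqlens/max_seqlen for flash_attn_varlen_func
# from flattened position ids such as [0, 1, 2, 0, 1] (two packed sequences).
import torch

def cu_seqlens_from_position_ids(position_ids: torch.Tensor):
    # flatten [batch, seq] position ids into one long row
    flat = position_ids.flatten()
    # a new packed sequence starts wherever the position id resets to 0
    starts = torch.nonzero(flat == 0, as_tuple=True)[0]
    cu_seqlens = torch.cat(
        [starts, torch.tensor([flat.numel()], device=flat.device)]
    ).to(torch.int32)
    max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())
    return cu_seqlens, max_seqlen

pos = torch.tensor([[0, 1, 2, 0, 1]])  # two packed sequences: lengths 3 and 2
print(cu_seqlens_from_position_ids(pos))  # (tensor([0, 3, 5], dtype=torch.int32), 3)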
71 changes: 38 additions & 33 deletions plugins/instruct-lab/src/fms_acceleration_ilab/framework_plugin_padding_free.py
@@ -20,7 +20,11 @@
# Third Party
from fms_acceleration import AccelerationPlugin
from peft import LoraConfig
- from transformers import TrainingArguments, __version__ as transformers_version, DataCollatorForSeq2Seq
+ from transformers import (
+ TrainingArguments,
+ __version__ as transformers_version,
+ DataCollatorForSeq2Seq,
+ )
from accelerate import Accelerator
import torch
from types import MethodType
@@ -30,7 +34,7 @@
TRANSFORMERS_VERSION = "4.44"

class PaddingFreeAccelerationPlugin(AccelerationPlugin):

require_packages = ["flash_attn"]

def __init__(self, configurations: Dict[str, Dict]):
@@ -111,37 +115,38 @@ def get_callbacks_and_ready_for_train(
def _patch_dataloader(
self,
accelerator: Accelerator,
- ):
- """
- Hijacks the accelorator prepare inside `Trainer.train`
- - If it is a single argument. it is assumed to be the prepare call on the dataloader
- - we replace the collate function in the dataloader to flatten the batch into a long
- sequence with special tokens to define the attention computation boundaries
- """
- # Check if transformers already supports a collator that flattens the batch
- # Otherwise, use the locally implemented DataCollatorWithFlattening
- if version.parse(transformers_version) < version.parse(TRANSFORMERS_VERSION):
- from .ilab_utils import DataCollatorWithFlattening
- else:
- from transformers import DataCollatorWithFlattening
-
- # hijack the dataloader in accelerator.prepare to replace the collate_fn
- _old_prepare = accelerator.prepare
- def prepare(self, *args, device_placement=None):
- if len(args) > 1 or not isinstance(args[0], DataLoader):
- return _old_prepare(*args, device_placement=device_placement)
- dataloader = args[0]
-
- if not isinstance(dataloader.collate_fn, DataCollatorForSeq2Seq):
- raise Exception("The padding-free plugin currently only works with a `DataCollatorForSeq2Seq` collate_fn, \
- otherwise the collation can be unreliable")
-
- # Replace the collate_fn in dataloader
- dataloader.collate_fn = DataCollatorWithFlattening()
-
- return dataloader
-
- accelerator.prepare = MethodType(prepare, accelerator)
+ ):
+ """
+ Hijacks the accelorator prepare inside `Trainer.train`
+ - If it is a single argument. it is assumed to be the prepare call on the dataloader
+ - we replace the collate function in the dataloader to flatten the batch into a long
+ sequence with special tokens to define the attention computation boundaries
+ """
+ # Check if transformers already supports a collator that flattens the batch
+ # Otherwise, use the locally implemented DataCollatorWithFlattening
+ if version.parse(transformers_version) < version.parse(TRANSFORMERS_VERSION):
+ from .ilab_utils import DataCollatorWithFlattening  # pylint: disable=import-outside-toplevel
+ else:
+ from transformers import DataCollatorWithFlattening  # pylint: disable=import-outside-toplevel,no-name-in-module
+
+ # hijack the dataloader in accelerator.prepare to replace the collate_fn
+ _old_prepare = accelerator.prepare
+ def prepare(self, *args, device_placement=None):
+ if len(args) > 1 or not isinstance(args[0], DataLoader):
+ return _old_prepare(*args, device_placement=device_placement)
+ dataloader = args[0]
+
+ if not isinstance(dataloader.collate_fn, DataCollatorForSeq2Seq):
+ raise TypeError("The padding-free plugin currently only works with a \
+ `DataCollatorForSeq2Seq` collate_fn, \
+ otherwise the collation can be unreliable")
+
+ # Replace the collate_fn in dataloader
+ dataloader.collate_fn = DataCollatorWithFlattening()
+
+ return dataloader
+
+ accelerator.prepare = MethodType(prepare, accelerator)

# register
AccelerationPlugin.register_plugin(
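The docstring in the hunk above describes the padding-free idea: rather than padding each example, the swapped-in collator flattens the whole batch into one long row, with per-example position ids marking the boundaries so attention and loss never cross them. A rough sketch of that collation behaviour, assuming the usual padding-free convention; flatten_batch below is illustrative and is not the repository's DataCollatorWithFlattening:

# Illustrative sketch: flatten a batch of variable-length examples into a single
# row with per-example position ids and each boundary token's label masked to -100.
from typing import Dict, List
import torch

def flatten_batch(features: List[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
    input_ids, labels, position_ids = [], [], []
    for f in features:
        ids = f["input_ids"]
        input_ids += ids
        labels += [-100] + f["labels"][1:]   # never learn across an example boundary
        position_ids += list(range(len(ids)))
    return {
        "input_ids": torch.tensor([input_ids]),
        "labels": torch.tensor([labels]),
        "position_ids": torch.tensor([position_ids]),
    }

# two examples of lengths 3 and 2 become one padding-free row of length 5
batch = flatten_batch([
    {"input_ids": [5, 6, 7], "labels": [5, 6, 7]},
    {"input_ids": [8, 9], "labels": [8, 9]},
])
print(batch["position_ids"])  # tensor([[0, 1, 2, 0, 1]])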
28 changes: 14 additions & 14 deletions plugins/instruct-lab/tox.ini
@@ -20,7 +20,7 @@ deps =
-e {toxinidir}/../framework
pylint>=2.16.2,<=3.1.0
commands =
- pylint src tests
+ pylint src
allowlist_externals = pylint

[testenv:fmt]
@@ -35,16 +35,16 @@ commands =
black --exclude .*unsloth.* tests
isort .

- # [testenv:build]
- # description = build wheel
- # deps =
- # build
- # commands = python -m build -w
- # skip_install = True
- #
- # [testenv:twinecheck]
- # description = check wheel
- # deps =
- # twine
- # commands = twine check dist/*
- # skip_install = True
+ [testenv:build]
+ description = build wheel
+ deps =
+ build
+ commands = python -m build -w
+ skip_install = True
+
+ [testenv:twinecheck]
+ description = check wheel
+ deps =
+ twine
+ commands = twine check dist/*
+ skip_install = True
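
With the build and twinecheck environments re-enabled above, the wheel can be produced and validated locally by running `tox -e build` followed by `tox -e twinecheck` (assuming tox is installed); the first invokes `python -m build -w` and the second runs `twine check dist/*`.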
