additional fmt fixes
Signed-off-by: 1000850000 user <aaron.chew1@ibm.com>
achew010 committed Aug 29, 2024
1 parent 9cb999e commit 29362a4
Showing 3 changed files with 10 additions and 5 deletions.
2 changes: 1 addition & 1 deletion tests/acceleration/test_acceleration_dataclasses.py
@@ -25,8 +25,8 @@
 )
 from tuning.config.acceleration_configs.attention_and_distributed_packing import (
     AttentionAndDistributedPackingConfig,
-    PaddingFree,
     MultiPack,
+    PaddingFree,
 )
 from tuning.config.acceleration_configs.fused_ops_and_kernels import (
     FastKernelsConfig,
9 changes: 6 additions & 3 deletions tests/acceleration/test_acceleration_framework.py
@@ -40,8 +40,8 @@
 )
 from tuning.config.acceleration_configs.attention_and_distributed_packing import (
     AttentionAndDistributedPackingConfig,
-    PaddingFree,
     MultiPack,
+    PaddingFree,
 )
 from tuning.config.acceleration_configs.fused_ops_and_kernels import (
     FastKernelsConfig,
@@ -514,7 +514,8 @@ def test_framework_initialize_and_trains_with_aadp():
         model_args,
         data_args,
         train_args,
-        attention_and_distributed_packing_config=attention_and_distributed_packing_config,
+        attention_and_distributed_packing_config=\
+            attention_and_distributed_packing_config,
     )
 
     # spy inside the train to ensure that the ilab plugin is called
@@ -573,7 +574,8 @@ def test_padding_free_plugin_raises_error_with_untokenized_dataset():
         model_args,
         data_args,
         train_args,
-        attention_and_distributed_packing_config=attention_and_distributed_packing_config,
+        attention_and_distributed_packing_config=\
+            attention_and_distributed_packing_config,
     )
 
 
@@ -596,6 +598,7 @@ def test_error_raised_with_paddingfree_and_flash_attn_disabled():
         attention_and_distributed_packing_config=attention_and_distributed_packing_config,
     )
 
+
 def test_error_raised_with_multipack_and_paddingfree_disabled():
     """Ensure error raised when padding-free is not used with flash attention"""
     with pytest.raises(
4 changes: 3 additions & 1 deletion tuning/config/acceleration_configs/attention_and_distributed_packing.py
@@ -123,7 +123,9 @@ def _verify_configured_dataclasses(self):
         # this also ensures that the attention implementation for multipack
         # will be flash attention as sfttrainer will enforce flash attn to be
         # set for padding free
-        assert self.padding_free is not None, "`--multipack` is currently only supported with `--padding_free`"
+        assert (
+            self.padding_free is not None
+        ), "`--multipack` is currently only supported with `--padding_free`"
 
     @staticmethod
     def from_dataclasses(*dataclasses: Type):
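
For context, here is a minimal runnable sketch of the constraint this assert enforces. The class names and the assert message come from the diff above; the field names, defaults, and `__post_init__` wiring are assumptions, not the repository's exact code:

```python
# Minimal sketch of the multipack/padding-free validation (assumed fields
# and defaults; only the class names and the assert message are taken from
# the diff above).
from dataclasses import dataclass
from typing import Optional


@dataclass
class PaddingFree:
    method: str = "huggingface"  # assumed default


@dataclass
class MultiPack:
    num_processes: int = 16  # assumed default


@dataclass
class AttentionAndDistributedPackingConfig:
    padding_free: Optional[PaddingFree] = None
    multipack: Optional[MultiPack] = None

    def __post_init__(self):
        # multipack requires padding-free, which in turn means sfttrainer
        # will enforce flash attention (see the comments in the hunk above)
        if self.multipack is not None:
            assert (
                self.padding_free is not None
            ), "`--multipack` is currently only supported with `--padding_free`"


# OK: both dataclasses configured
AttentionAndDistributedPackingConfig(
    padding_free=PaddingFree(), multipack=MultiPack()
)

# Raises AssertionError: multipack without padding-free
# AttentionAndDistributedPackingConfig(multipack=MultiPack())
```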
