From 29362a4b8b280af99e899c6e06a28f3f8e8c3a10 Mon Sep 17 00:00:00 2001 From: 1000850000 user Date: Thu, 29 Aug 2024 09:29:04 +0000 Subject: [PATCH] additional fmt fixes Signed-off-by: 1000850000 user --- tests/acceleration/test_acceleration_dataclasses.py | 2 +- tests/acceleration/test_acceleration_framework.py | 9 ++++++--- .../acceleration_framework_config.py | 4 +++- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/tests/acceleration/test_acceleration_dataclasses.py b/tests/acceleration/test_acceleration_dataclasses.py index 4be78e556..130159933 100644 --- a/tests/acceleration/test_acceleration_dataclasses.py +++ b/tests/acceleration/test_acceleration_dataclasses.py @@ -25,8 +25,8 @@ ) from tuning.config.acceleration_configs.attention_and_distributed_packing import ( AttentionAndDistributedPackingConfig, - PaddingFree, MultiPack, + PaddingFree, ) from tuning.config.acceleration_configs.fused_ops_and_kernels import ( FastKernelsConfig, diff --git a/tests/acceleration/test_acceleration_framework.py b/tests/acceleration/test_acceleration_framework.py index fafb04e74..05465aab0 100644 --- a/tests/acceleration/test_acceleration_framework.py +++ b/tests/acceleration/test_acceleration_framework.py @@ -40,8 +40,8 @@ ) from tuning.config.acceleration_configs.attention_and_distributed_packing import ( AttentionAndDistributedPackingConfig, - PaddingFree, MultiPack, + PaddingFree, ) from tuning.config.acceleration_configs.fused_ops_and_kernels import ( FastKernelsConfig, @@ -514,7 +514,8 @@ def test_framework_initialize_and_trains_with_aadp(): model_args, data_args, train_args, - attention_and_distributed_packing_config=attention_and_distributed_packing_config, + attention_and_distributed_packing_config=\ + attention_and_distributed_packing_config, ) # spy inside the train to ensure that the ilab plugin is called @@ -573,7 +574,8 @@ def test_padding_free_plugin_raises_error_with_untokenized_dataset(): model_args, data_args, train_args, - attention_and_distributed_packing_config=attention_and_distributed_packing_config, + attention_and_distributed_packing_config=\ + attention_and_distributed_packing_config, ) @@ -596,6 +598,7 @@ def test_error_raised_with_paddingfree_and_flash_attn_disabled(): attention_and_distributed_packing_config=attention_and_distributed_packing_config, ) + def test_error_raised_with_multipack_and_paddingfree_disabled(): """Ensure error raised when padding-free is not used with flash attention""" with pytest.raises( diff --git a/tuning/config/acceleration_configs/acceleration_framework_config.py b/tuning/config/acceleration_configs/acceleration_framework_config.py index ef14f7932..1aad6abcb 100644 --- a/tuning/config/acceleration_configs/acceleration_framework_config.py +++ b/tuning/config/acceleration_configs/acceleration_framework_config.py @@ -123,7 +123,9 @@ def _verify_configured_dataclasses(self): # this also ensures that the attention implementation for multipack # will be flash attention as sfttrainer will enforce flash attn to be # set for padding free - assert self.padding_free is not None, "`--multipack` is currently only supported with `--padding_free`" + assert ( + self.padding_free is not None + ), "`--multipack` is currently only supported with `--padding_free`" @staticmethod def from_dataclasses(*dataclasses: Type):