
feat: Add DataClass Arguments to Activate Padding-Free and MultiPack Plugin and FastKernels #280

Merged
plugin rename
Signed-off-by: 1000850000 user <aaron.chew1@ibm.com>
achew010 committed Sep 20, 2024
commit 6e766338ab991d99067c9b91854c1802282a8cbe
6 changes: 3 additions & 3 deletions tests/acceleration/test_acceleration_dataclasses.py
@@ -27,8 +27,8 @@
FastKernelsConfig,
FusedLoraConfig,
)
from tuning.config.acceleration_configs.instruct_lab_config import (
InstructLabConfig,
from tuning.config.acceleration_configs.attention_and_distributed_packing import (
AttentionAndDistributedPackingConfig,
PaddingFree,
)
from tuning.config.acceleration_configs.quantized_lora_config import (
@@ -70,7 +70,7 @@ def test_dataclass_parse_successfully():
assert isinstance(cfg.bnb_qlora, BNBQLoraConfig)

# 3. Specifying "--padding_free" will parse a PaddingFree class
parser = transformers.HfArgumentParser(dataclass_types=InstructLabConfig)
parser = transformers.HfArgumentParser(dataclass_types=AttentionAndDistributedPackingConfig)
(cfg,) = parser.parse_args_into_dataclasses(
["--padding_free", "huggingface"],
)
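For orientation, a short usage sketch of the renamed dataclass (mirroring test_dataclass_parse_successfully above; assumes transformers and this package are importable):

import transformers

from tuning.config.acceleration_configs.attention_and_distributed_packing import (
    AttentionAndDistributedPackingConfig,
    PaddingFree,
)

# "--padding_free huggingface" is parsed into the renamed dataclass, with the
# nested PaddingFree member populated, as the test above expects.
parser = transformers.HfArgumentParser(
    dataclass_types=AttentionAndDistributedPackingConfig
)
(cfg,) = parser.parse_args_into_dataclasses(["--padding_free", "huggingface"])
assert isinstance(cfg.padding_free, PaddingFree)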
45 changes: 35 additions & 10 deletions tests/acceleration/test_acceleration_framework.py
@@ -42,8 +42,8 @@
FastKernelsConfig,
FusedLoraConfig,
)
from tuning.config.acceleration_configs.instruct_lab_config import (
InstructLabConfig,
from tuning.config.acceleration_configs.attention_and_distributed_packing import (
AttentionAndDistributedPackingConfig,
PaddingFree,
)
from tuning.config.acceleration_configs.quantized_lora_config import (
@@ -467,8 +467,9 @@ def test_framework_intialized_properly_foak():


@pytest.mark.skipif(
not is_fms_accelerate_available(plugins="ilab"),
reason="Only runs if fms-accelerate is installed along with instruct-lab plugin",
not is_fms_accelerate_available(plugins="aadp"),
reason="Only runs if fms-accelerate is installed along with \
attention_and_distributed_packing plugin",
)
def test_framework_initialize_and_trains_with_ilab():
"""
@@ -489,7 +490,7 @@ def test_framework_initialize_and_trains_with_ilab():
data_args.dataset_text_field = None

# initialize a config
instruct_lab_config = InstructLabConfig(
attention_and_distributed_packing_config = AttentionAndDistributedPackingConfig(
padding_free=PaddingFree(method="huggingface")
)

@@ -512,7 +513,8 @@ def test_framework_initialize_and_trains_with_ilab():
model_args,
data_args,
train_args,
instruct_lab_config=instruct_lab_config,
attention_and_distributed_packing_config=\
attention_and_distributed_packing_config,
)

# spy inside the train to ensure that the ilab plugin is called
@@ -522,8 +524,9 @@


@pytest.mark.skipif(
not is_fms_accelerate_available(plugins="ilab"),
reason="Only runs if fms-accelerate is installed along with instruct-lab plugin",
not is_fms_accelerate_available(plugins="aadp"),
reason="Only runs if fms-accelerate is installed along with \
attention_and_distributed_packing plugin",
)
def test_padding_free_plugin_raises_error_with_untokenized_dataset():
"""
@@ -547,7 +550,7 @@ def test_padding_free_plugin_raises_error_with_untokenized_dataset():
data_args.dataset_text_field = "output"

# initialize a config
instruct_lab_config = InstructLabConfig(
attention_and_distributed_packing_config = AttentionAndDistributedPackingConfig(
padding_free=PaddingFree(method="huggingface")
)

@@ -570,5 +573,27 @@ def test_padding_free_plugin_raises_error_with_untokenized_dataset():
model_args,
data_args,
train_args,
instruct_lab_config=instruct_lab_config,
attention_and_distributed_packing_config=\
attention_and_distributed_packing_config,
)

def test_error_raised_with_paddingfree_and_flash_attn_disabled():
"""Ensure error raised when padding-free is not used with flash attention"""
with pytest.raises(
ValueError,
match="`--padding_free` argument was called without enabling flash attention, \
ensure `use_flash_attn = True` to use padding-free flash attention",
):
attention_and_distributed_packing_config = \
AttentionAndDistributedPackingConfig(
padding_free=PaddingFree(method="huggingface")
)
model_args = copy.deepcopy(MODEL_ARGS)
model_args.use_flash_attn = False
sft_trainer.train(
model_args,
DATA_ARGS,
TRAIN_ARGS,
attention_and_distributed_packing_config=\
attention_and_distributed_packing_config
)
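A hedged sketch of the call path these tests exercise; model_args, data_args and train_args stand in for the MODEL_ARGS / DATA_ARGS / TRAIN_ARGS style objects used throughout this file and are not defined in this diff, and the sft_trainer import is assumed from the test module:

from tuning import sft_trainer
from tuning.config.acceleration_configs.attention_and_distributed_packing import (
    AttentionAndDistributedPackingConfig,
    PaddingFree,
)

aadp_config = AttentionAndDistributedPackingConfig(
    padding_free=PaddingFree(method="huggingface")
)

# Padding-free needs flash attention enabled; otherwise train() raises the
# ValueError exercised by test_error_raised_with_paddingfree_and_flash_attn_disabled.
model_args.use_flash_attn = True
# The padding-free collator works on a pre-tokenized dataset, hence
# dataset_text_field is unset (see the untokenized-dataset test above).
data_args.dataset_text_field = None

sft_trainer.train(
    model_args,
    data_args,
    train_args,
    attention_and_distributed_packing_config=aadp_config,
)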
2 changes: 1 addition & 1 deletion tuning/config/acceleration_configs/__init__.py
@@ -15,5 +15,5 @@
# Local
from .acceleration_framework_config import AccelerationFrameworkConfig
from .fused_ops_and_kernels import FusedOpsAndKernelsConfig
from .instruct_lab_config import InstructLabConfig
from .attention_and_distributed_packing import AttentionAndDistributedPackingConfig
from .quantized_lora_config import QuantizedLoraConfig
tuning/config/acceleration_configs/acceleration_framework_config.py
@@ -22,7 +22,7 @@

# Local
from .fused_ops_and_kernels import FastKernelsConfig, FusedLoraConfig
from .instruct_lab_config import PaddingFree
from .attention_and_distributed_packing import PaddingFree
from .quantized_lora_config import AutoGPTQLoraConfig, BNBQLoraConfig
from tuning.utils.import_utils import is_fms_accelerate_available

@@ -104,7 +104,7 @@ class AccelerationFrameworkConfig:
ConfigAnnotation(
path="training.attention",
experimental=True,
required_packages=["ilab"],
required_packages=["aadp"],
),
] = None

tuning/config/acceleration_configs/attention_and_distributed_packing.py (renamed from instruct_lab_config.py)
@@ -17,7 +17,7 @@ def __post_init__(self):


@dataclass
class InstructLabConfig:
class AttentionAndDistributedPackingConfig:

padding_free: PaddingFree = None

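A minimal construction sketch for the renamed dataclass, using the values the tests in this PR pass in:

from tuning.config.acceleration_configs.attention_and_distributed_packing import (
    AttentionAndDistributedPackingConfig,
    PaddingFree,
)

# padding_free defaults to None, so the plugin path stays inactive unless it is set.
aadp = AttentionAndDistributedPackingConfig(
    padding_free=PaddingFree(method="huggingface")
)
assert aadp.padding_free.method == "huggingface"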
30 changes: 18 additions & 12 deletions tuning/sft_trainer.py
@@ -45,7 +45,7 @@
from tuning.config.acceleration_configs import (
AccelerationFrameworkConfig,
FusedOpsAndKernelsConfig,
InstructLabConfig,
AttentionAndDistributedPackingConfig,
QuantizedLoraConfig,
)
from tuning.config.tracker_configs import (
@@ -87,7 +87,7 @@ def train(
exp_metadata: Optional[Dict] = None,
quantized_lora_config: Optional[QuantizedLoraConfig] = None,
fusedops_kernels_config: Optional[FusedOpsAndKernelsConfig] = None,
instruct_lab_config: Optional[InstructLabConfig] = None,
attention_and_distributed_packing_config: Optional[AttentionAndDistributedPackingConfig] = None,
):
"""Call the SFTTrainer

@@ -115,7 +115,7 @@
fusedops_kernels_config: tuning.config.acceleration_configs.FusedOpsAndKernelsConfig \
Should be used in combination with quantized_lora_config. Also currently
fused_lora and fast_kernels must be used together (may change in future). \
instruct_lab_config: Used for padding free and multipack.
attention_and_distributed_packing_config: Used for padding-free attention and multipack.
"""

train_args, logger = set_log_level(train_args, "sft_trainer_train")
@@ -130,6 +130,12 @@
):
raise ValueError("gradient_accumulation_steps has to be an integer >= 1")

if (attention_and_distributed_packing_config is not None
and attention_and_distributed_packing_config.padding_free is not None
and model_args.use_flash_attn is False
):
raise ValueError("`--padding_free` argument was called without enabling \
flash attention, ensure `use_flash_attn = True` to use padding-free flash attention")

task_type = "CAUSAL_LM"
additional_metrics = {}

@@ -181,7 +187,7 @@
trainer_callbacks.append(cb)

framework = AccelerationFrameworkConfig.from_dataclasses(
quantized_lora_config, fusedops_kernels_config, instruct_lab_config
quantized_lora_config, fusedops_kernels_config, attention_and_distributed_packing_config
).get_framework()

model_loader = AutoModelForCausalLM.from_pretrained
@@ -439,7 +445,7 @@ def get_parser():
AimConfig,
QuantizedLoraConfig,
FusedOpsAndKernelsConfig,
InstructLabConfig,
AttentionAndDistributedPackingConfig,
)
)
parser.add_argument(
@@ -503,7 +509,7 @@ def parse_arguments(parser, json_config=None):
aim_config,
quantized_lora_config,
fusedops_kernels_config,
instruct_lab_config,
attention_and_distributed_packing_config,
) = parser.parse_dict(json_config, allow_extra_keys=True)
peft_method = json_config.get("peft_method")
exp_metadata = json_config.get("exp_metadata")
@@ -519,7 +525,7 @@
aim_config,
quantized_lora_config,
fusedops_kernels_config,
instruct_lab_config,
attention_and_distributed_packing_config,
additional,
_,
) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
@@ -544,7 +550,7 @@
aim_config,
quantized_lora_config,
fusedops_kernels_config,
instruct_lab_config,
attention_and_distributed_packing_config,
exp_metadata,
)

@@ -565,7 +571,7 @@ def main():
aim_config,
quantized_lora_config,
fusedops_kernels_config,
instruct_lab_config,
attention_and_distributed_packing_config,
exp_metadata,
) = parse_arguments(parser, job_config)

@@ -577,7 +583,7 @@
model_args %s, data_args %s, training_args %s, trainer_controller_args %s, \
tune_config %s, file_logger_config, %s aim_config %s, \
quantized_lora_config %s, fusedops_kernels_config %s, \
instruct_lab_config %s exp_metadata %s",
attention_and_distributed_packing_config %s exp_metadata %s",
model_args,
data_args,
training_args,
@@ -587,7 +593,7 @@
aim_config,
quantized_lora_config,
fusedops_kernels_config,
instruct_lab_config,
attention_and_distributed_packing_config,
exp_metadata,
)
except Exception as e: # pylint: disable=broad-except
@@ -629,7 +635,7 @@ def main():
exp_metadata=metadata,
quantized_lora_config=quantized_lora_config,
fusedops_kernels_config=fusedops_kernels_config,
instruct_lab_config=instruct_lab_config,
attention_and_distributed_packing_config=attention_and_distributed_packing_config,
)
except (MemoryError, OutOfMemoryError) as e:
logger.error(traceback.format_exc())
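Putting the wiring above together: train() folds the three acceleration dataclasses into a single framework object via from_dataclasses. A hedged sketch of that call with only the new config supplied (assumes fms-accelerate and its aadp plugin are installed; passing None for the unused slots mirrors what train() does when those CLI arguments are absent):

from tuning.config.acceleration_configs import (
    AccelerationFrameworkConfig,
    AttentionAndDistributedPackingConfig,
)
from tuning.config.acceleration_configs.attention_and_distributed_packing import (
    PaddingFree,
)

aadp_config = AttentionAndDistributedPackingConfig(
    padding_free=PaddingFree(method="huggingface")
)

# quantized_lora_config and fusedops_kernels_config are simply None here,
# just as they are inside train() when the corresponding flags are not used.
framework = AccelerationFrameworkConfig.from_dataclasses(
    None,  # quantized_lora_config
    None,  # fusedops_kernels_config
    aadp_config,
).get_framework()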