From efdd436663436e78d8ad3213d11325d86578db95 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Tue, 20 Feb 2024 12:45:08 +0100 Subject: [PATCH] FIX [`PEFT` / `Trainer` ] Handle better peft + quantized compiled models (#29055) * handle peft + compiled models * add tests * fixup * adapt from suggestions * clarify comment --- src/transformers/trainer.py | 6 ++++++ tests/trainer/test_trainer.py | 37 +++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 4994aef3af8133..a2436dadc1a812 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -429,6 +429,12 @@ def __init__( getattr(model, "hf_quantizer", None) is not None and model.hf_quantizer.is_trainable ) + # Filter out quantized + compiled models + if _is_quantized_and_base_model and hasattr(model, "_orig_mod"): + raise ValueError( + "You cannot fine-tune quantized model with `torch.compile()` make sure to pass a non-compiled model when fine-tuning a quantized model with PEFT" + ) + # At this stage the model is already loaded if _is_quantized_and_base_model and not _is_peft_model(model): raise ValueError( diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index b64e93a2d17494..65eeb6d6238431 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -62,6 +62,7 @@ require_deepspeed, require_intel_extension_for_pytorch, require_optuna, + require_peft, require_ray, require_safetensors, require_sentencepiece, @@ -873,6 +874,42 @@ def test_number_of_steps_in_training_with_ipex(self): train_output = trainer.train() self.assertEqual(train_output.global_step, 10) + @require_peft + @require_bitsandbytes + def test_bnb_compile(self): + from peft import LoraConfig, get_peft_model + + # Simply tests if initializing a Trainer with a PEFT + compiled model works out of the box + # QLoRA + torch compile is not really supported yet, but we should at least support the model + # loading and let torch throw the + tiny_model = AutoModelForCausalLM.from_pretrained( + "hf-internal-testing/tiny-random-LlamaForCausalLM", load_in_4bit=True + ) + + peft_config = LoraConfig( + r=8, + lora_alpha=32, + target_modules=["q_proj", "k_proj", "v_proj"], + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + tiny_model = get_peft_model(tiny_model, peft_config) + + tiny_model = torch.compile(tiny_model) + + x = torch.randint(0, 100, (128,)) + train_dataset = RepeatDataset(x) + + with tempfile.TemporaryDirectory() as tmp_dir: + args = TrainingArguments( + tmp_dir, + learning_rate=1e-9, + logging_steps=5, + ) + with self.assertRaises(ValueError): + _ = Trainer(tiny_model, args, train_dataset=train_dataset) # noqa + @require_bitsandbytes def test_rmsprop_bnb(self): config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)