Commit d1809fe

BenjaminBossan authored and SunMarc committed
FIX FSDP plugin update for QLoRA (huggingface#36720)
The _fsdp_qlora_plugin_updates method checks for LoraConfig, but other PEFT methods, e.g. VeRA, can also support quantized models. The isinstance check therefore now looks for PeftConfig in general.

Moreover, the fsdp_plugin variable may be undefined when the second if condition is reached, leading to an `UnboundLocalError`. This is fixed by not assigning the variable at all and accessing self.accelerator.state.fsdp_plugin directly instead.

I checked for tests that may need updating but only found test_fsdp_config_transformers_auto_wrap associated with this change. AFAICT, this test does not cover the changed code, since it does not start the training loop. Therefore, I haven't updated any tests. LMK if/how this fix should be tested.

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
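The `UnboundLocalError` fixed here follows a common Python pattern: a variable assigned only inside one branch is later read in a sibling branch. A minimal standalone sketch of the pattern (hypothetical function and placeholder values, not the actual Trainer code):

```python
def plugin_updates(is_matching_config, needs_mixed_precision):
    """Sketch of the bug: fsdp_plugin is assigned only in the first branch."""
    if is_matching_config:
        # Stands in for: fsdp_plugin = self.accelerator.state.fsdp_plugin
        fsdp_plugin = "plugin"
    if needs_mixed_precision:
        # If the first branch was skipped, fsdp_plugin was never assigned,
        # so reading it raises UnboundLocalError at runtime.
        return fsdp_plugin
    return None


try:
    plugin_updates(is_matching_config=False, needs_mixed_precision=True)
except UnboundLocalError as e:
    print("bug reproduced:", e)
```

The fix avoids the local variable entirely and dereferences the attribute path in each branch that needs it, so neither branch depends on the other having run.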
1 parent 26b4193 commit d1809fe

File tree

1 file changed: +4 −5 lines changed

src/transformers/trainer.py (4 additions, 5 deletions)

@@ -5202,18 +5202,17 @@ def propagate_args_to_deepspeed(self, auto_find_batch_size=False):
 
     def _fsdp_qlora_plugin_updates(self):
         if self.is_fsdp_enabled and _is_peft_model(self.model):
-            from peft import LoraConfig
+            from peft import PeftConfig
             from peft.utils.other import fsdp_auto_wrap_policy
 
-            if isinstance(self.model.active_peft_config, LoraConfig):
-                fsdp_plugin = self.accelerator.state.fsdp_plugin
-                fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(self.model)
+            if isinstance(self.model.active_peft_config, PeftConfig):
+                self.accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(self.model)
             if (
                 getattr(self.model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
                 and self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage.is_floating_point
                 and version.parse(accelerate_version) > version.parse("0.27.0")
             ):
-                fsdp_plugin.set_mixed_precision(
+                self.accelerator.state.fsdp_plugin.set_mixed_precision(
                     self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage, override=True
                 )