
ValueError: Attempting to unscale FP16 gradients. #1031

Open
hengjiUSTC opened this issue Jan 2, 2024 · 12 comments
Labels
bug Something isn't working

hengjiUSTC commented Jan 2, 2024

Please check that this issue hasn't been reported before.

  • I searched previous Bug Reports and didn't find any similar reports.

Expected Behavior

Training should run without crashing.

Current behaviour

Training crashes at the start of the first step with the traceback below.

wandb: WARNING Saving files without folders. If you want to preserve sub directories pass base_path to wandb.save, i.e. wandb.save("/mnt/folder/file.h5", base_path="/mnt")
[2024-01-02 12:15:40,565] [INFO] [axolotl.callbacks.on_train_begin:572] [PID:9425] [RANK:0] The Axolotl config has been saved to the WandB run under files.
  0%|                                                                                                                                                            | 0/90 [00:00<?, ?it/s]Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/ubuntu/axolotl/src/axolotl/cli/train.py", line 38, in <module>
    fire.Fire(do_cli)
  File "/home/ubuntu/axolotl/venv/lib/python3.10/site-packages/fire/core.py", line 141, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/home/ubuntu/axolotl/venv/lib/python3.10/site-packages/fire/core.py", line 475, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/home/ubuntu/axolotl/venv/lib/python3.10/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/home/ubuntu/axolotl/src/axolotl/cli/train.py", line 34, in do_cli
    train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
  File "/home/ubuntu/axolotl/src/axolotl/train.py", line 136, in train
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
  File "/home/ubuntu/axolotl/venv/lib/python3.10/site-packages/transformers/trainer.py", line 1537, in train
    return inner_training_loop(
  File "/home/ubuntu/axolotl/venv/lib/python3.10/site-packages/transformers/trainer.py", line 1896, in _inner_training_loop
    self.accelerator.clip_grad_norm_(
  File "/home/ubuntu/axolotl/venv/lib/python3.10/site-packages/accelerate/accelerator.py", line 2124, in clip_grad_norm_
    self.unscale_gradients()
  File "/home/ubuntu/axolotl/venv/lib/python3.10/site-packages/accelerate/accelerator.py", line 2087, in unscale_gradients
    self.scaler.unscale_(opt)
  File "/home/ubuntu/axolotl/venv/lib/python3.10/site-packages/torch/cuda/amp/grad_scaler.py", line 284, in unscale_
    optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False)
  File "/home/ubuntu/axolotl/venv/lib/python3.10/site-packages/torch/cuda/amp/grad_scaler.py", line 212, in _unscale_grads_
    raise ValueError("Attempting to unscale FP16 gradients.")
ValueError: Attempting to unscale FP16 gradients.

Steps to reproduce

I use the following config:

base_model: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
is_mistral_derived_model: true
trust_remote_code: true

load_in_8bit: false
load_in_4bit: false
strict: false

chat_template: chatml
datasets:
  - path: HenryJJ/tangshi
    type:
      system_prompt: ""
      field_system: system
      format: |-
        {instruction}
        Input: {input}
        Output: 
      no_input_format: "[INST] {instruction} [/INST]"
dataset_prepared_path:
val_set_size: 0.1
output_dir: ./out

adapter: lora
lora_model_dir:

sequence_len: 1024
sample_packing: false
pad_to_sequence_len: true
eval_sample_packing: false

lora_r: 64
lora_alpha: 16
lora_dropout: 0.1
lora_target_linear: true
lora_modules_to_save:
  - embed_tokens
  - lm_head


lora_fan_in_fan_out:

wandb_project: tangshi
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch
lr_scheduler: constant
learning_rate: 0.0001

train_on_inputs: false
group_by_length: false
bf16: false
fp16: true
tf32: false

device_map: auto

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 2
xformers_attention:
flash_attention: false

warmup_steps: 10
eval_steps: 2
eval_batch_size: 4
eval_table_size:
eval_table_max_new_tokens:
save_steps: 10
save_total_limit: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  pad_token: "<unk>"

and run it with: python3 -m axolotl.cli.train mix_tangshi/config.yml

Config yaml

No response

Possible solution

No response

Which Operating Systems are you using?

  • Linux
  • macOS
  • Windows

Python Version

3.10

axolotl branch-commit

main commit 3678a6c

Acknowledgements

  • My issue title is concise, descriptive, and in title casing.
  • I have searched the existing issues to make sure this bug has not been reported yet.
  • I am using the latest version of axolotl.
  • I have provided enough information for the maintainers to reproduce and diagnose the issue.
hengjiUSTC added the bug label on Jan 2, 2024
hengjiUSTC (Author) commented:

However, I am able to run LoRA with fp16 in my other experiments (https://github.com/hengjiUSTC/learn-llm/blob/main/trl_finetune.py#L316), so I am not sure what the expected behavior is.

hengjiUSTC (Author) commented Jan 4, 2024

I found that the bug happens when I set:

lora_modules_to_save:
  - embed_tokens
  - lm_head

  1. Why did I set it? I set it because of the detection in https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/utils/models.py#L153. However, since my special_tokens entry only sets pad_token to <unk>, which is already in the vocabulary, I feel this detection shouldn't be triggered.

  2. I am not sure why setting lora_modules_to_save together with fp16 leads to a crash (see the diagnostic sketch after this list).

  3. Another problem is at https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/utils/models.py#L123: when flash_attention is false and is_mistral_derived_model is true, the Mistral padding side is not set to left, which is incorrect for Mistral training.
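
For context on point 2: the error comes from torch's GradScaler, which refuses to unscale fp16 gradients, and gradients take the dtype of their parameter, so it fires whenever any trainable parameter is fp16. Below is a minimal diagnostic sketch, not axolotl code; the PEFT options are only a rough mapping of the YAML above, with lora_target_linear approximated by an explicit target_modules list.

    # Diagnostic sketch (assumes transformers + peft installed): list which
    # trainable parameters end up in fp16 after the LoRA wrapping.
    # GradScaler raises "Attempting to unscale FP16 gradients." whenever a
    # trainable parameter (and therefore its gradient) is fp16.
    import torch
    from transformers import AutoModelForCausalLM
    from peft import LoraConfig, get_peft_model

    model = AutoModelForCausalLM.from_pretrained(
        "mistralai/Mistral-7B-v0.1",
        torch_dtype=torch.float16,  # what gets loaded when fp16: true
    )
    lora_config = LoraConfig(
        r=64,
        lora_alpha=16,
        lora_dropout=0.1,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
        modules_to_save=["embed_tokens", "lm_head"],
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)

    for name, param in model.named_parameters():
        if param.requires_grad and param.dtype == torch.float16:
            # every hit here would trigger the unscale error under fp16 training
            print(name, param.dtype)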

winglian (Collaborator) commented Jan 4, 2024

I'm wondering if we are even supposed to be recasting to fp16. The original qlora code only recasts when bf16 is used: https://github.com/artidoro/qlora/blame/main/qlora.py#L396-L405
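
For reference, the pattern at that link is roughly the following (a paraphrased sketch, not the verbatim qlora code; the LoraLayer import path may differ between peft versions): LoRA layers and the embed/lm_head weights are cast to bf16 only when bf16 is enabled, and there is no equivalent cast for fp16.

    # Paraphrased sketch of the referenced recasting logic.
    import torch
    from peft.tuners.lora import LoraLayer  # import path may vary by peft version

    def recast_modules(model, use_bf16: bool):
        for name, module in model.named_modules():
            if isinstance(module, LoraLayer) and use_bf16:
                module.to(torch.bfloat16)   # adapters follow bf16 only
            if "norm" in name:
                module.to(torch.float32)    # norms kept in fp32 for stability
            if ("lm_head" in name or "embed_tokens" in name) and hasattr(module, "weight"):
                if use_bf16 and module.weight.dtype == torch.float32:
                    module.to(torch.bfloat16)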

winglian (Collaborator) commented Jan 4, 2024

@hengjiUSTC if you comment out these lines for your configuration above, does that fix the issue?

hengjiUSTC (Author) commented:

I am using LoRA instead of QLoRA, so these lines won't be triggered: https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/utils/models.py#L554-L561

    if (cfg.adapter == "lora" and load_in_8bit) or (
        cfg.adapter == "qlora" and cfg.load_in_4bit
    ):

Both load_in_8bit and load_in_4bit are false in my config, so neither branch applies.

hengjiUSTC (Author) commented Jan 7, 2024

See the relevant discussions in:
huggingface/transformers#23165
huggingface/peft#341

Here are some experiments.
This combination breaks with raise ValueError("Attempting to unscale FP16 gradients."):

    model = AutoModelForCausalLM.from_pretrained(
        ...
        torch_dtype=torch.float16,
    )
    training_args = TrainingArguments(
        fp16=True,
        ...
    )

There is no error for the two configs below:

    model = AutoModelForCausalLM.from_pretrained(
        ...
        torch_dtype=torch.float32,
    )
    training_args = TrainingArguments(
        fp16=True,
        ...
    )
    model = AutoModelForCausalLM.from_pretrained(
        ...
        torch_dtype=torch.float16,
    )
    training_args = TrainingArguments(
        fp16=False,
        ...
    )

I am a bit new to these settings; does anyone know the reason? (I am using a T4 GPU, so I am not able to use bf16.)
How should we handle this error in axolotl?

hengjiUSTC (Author) commented Jan 11, 2024

I got confirmation that we should not load the model in float16 when fp16 is enabled for PEFT training: huggingface/peft#341 (comment). But I do see a lot of code in other finetuning repos doing this, and it is the reason the error is raised in Axolotl: when fp16 is true in config.yml, the model is loaded in float16 and fp16 training is enabled for the PEFT model.
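
In other words, the combination that does not crash keeps the weights that receive gradients in fp32 and lets the Trainer's fp16 mode handle autocast and gradient scaling. A minimal sketch of that combination, mirroring the second experiment above (hyperparameters loosely copied from the config in this issue; not the axolotl code path):

    # Sketch of the non-crashing combination: base model in float32, fp16 handled
    # by the Trainer (autocast + GradScaler with fp32 master weights).
    import torch
    from transformers import AutoModelForCausalLM, TrainingArguments
    from peft import LoraConfig, get_peft_model

    model = AutoModelForCausalLM.from_pretrained(
        "mistralai/Mistral-7B-v0.1",
        torch_dtype=torch.float32,  # trainable weights stay fp32, so unscaling works
    )
    model = get_peft_model(
        model,
        LoraConfig(r=64, lora_alpha=16, lora_dropout=0.1,
                   target_modules=["q_proj", "v_proj"], task_type="CAUSAL_LM"),
    )

    training_args = TrainingArguments(
        output_dir="./out",
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        learning_rate=1e-4,
        fp16=True,  # forward/backward run in fp16 via autocast; optimizer sees fp32
    )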

ehartford (Collaborator) commented:

I also have these lines because I am using ChatML and adding new tokens to the base model:

lora_modules_to_save:
  - embed_tokens
  - lm_head

NanoCode012 (Collaborator) commented Mar 30, 2024

Based on what @hengjiUSTC linked, if I understand it correctly, fp16 adapter training must use fp32 for trainable parameters and fp16 for non-trainable ones. PEFT provides a utility function, cast_mixed_precision_params(peft_model, dtype), for this, but since we also handle gate/norm layers, we may need to adjust the approach ourselves.
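
If the helper works as described, it casts trainable parameters to fp32 and everything else to the requested half precision, which matches the requirement above. A rough usage sketch, assuming it is importable as peft.helpers.cast_mixed_precision_params (recent PEFT releases) and reusing module names from the config in this issue:

    # Hedged sketch: load the base model in fp16 for memory, add LoRA plus
    # modules_to_save, then upcast only the trainable parameters to fp32 so the
    # fp16 GradScaler has fp32 gradients to unscale.
    import torch
    from transformers import AutoModelForCausalLM
    from peft import LoraConfig, get_peft_model
    from peft.helpers import cast_mixed_precision_params  # assumed import path

    model = AutoModelForCausalLM.from_pretrained(
        "mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16
    )
    model = get_peft_model(
        model,
        LoraConfig(r=64, lora_alpha=16, target_modules=["q_proj", "v_proj"],
                   modules_to_save=["embed_tokens", "lm_head"],
                   task_type="CAUSAL_LM"),
    )

    # trainable -> fp32, frozen -> fp16
    cast_mixed_precision_params(model, dtype=torch.float16)

    for name, param in model.named_parameters():
        expected = torch.float32 if param.requires_grad else torch.float16
        assert param.dtype == expected, (name, param.dtype)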

RSDP101 commented Aug 3, 2024

It worked for me by setting:

--mixed_precision="bf16"

maxh2018 commented:

Setting bf16 to true also works.
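
A quick sanity check before flipping the config to bf16, since bf16 only helps on GPUs that support it (the T4 mentioned earlier in this thread does not):

    # Check bf16 support before setting bf16: true in the config.
    # Ampere (compute capability 8.0) and newer GPUs support bf16; a T4 (7.5) does not.
    import torch

    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        print("bf16 is supported on this GPU")
    else:
        print("no bf16 support; use fp16 and keep trainable params in fp32")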

LBJ6666 commented Oct 12, 2024, quoting @hengjiUSTC's comment above:

See the relevant discussions in: huggingface/transformers#23165 huggingface/peft#341

Here are some experiments. This combination breaks with raise ValueError("Attempting to unscale FP16 gradients."):

    model = AutoModelForCausalLM.from_pretrained(
        ...
        torch_dtype=torch.float16,
    )
    training_args = TrainingArguments(
        fp16=True,
        ...
    )

There is no error for the two configs below:

    model = AutoModelForCausalLM.from_pretrained(
        ...
        torch_dtype=torch.float32,
    )
    training_args = TrainingArguments(
        fp16=True,
        ...
    )
    model = AutoModelForCausalLM.from_pretrained(
        ...
        torch_dtype=torch.float16,
    )
    training_args = TrainingArguments(
        fp16=False,
        ...
    )

I am a bit new to these settings; does anyone know the reason? (I am using a T4 GPU, so I am not able to use bf16.) How should we handle this error in axolotl?

I am using Kaggle's notebook environment with the following specifications:

  • transformers version: 4.45.1
  • Platform: Linux-5.15.154+-x86_64-with-glibc2.35
  • Python version: 3.10.14
  • Huggingface_hub version: 0.25.1
  • Safetensors version: 0.4.5
  • Accelerate version: 0.34.2
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.4.0 (True)
  • Tensorflow version (GPU?): 2.16.1 (True)
  • Flax version (CPU?/GPU?/TPU?): 0.8.4 (gpu)
  • Jax version: 0.4.26
  • JaxLib version: 0.4.26.dev20240620
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?:
  • GPU type: Tesla T4

I attempted full fine-tuning, not LoRA.
Setting 1: loaded the model with torch_dtype=torch.float16 and set fp16=True in TrainingArguments; trainer.train() did not throw an error, but the first step had a normal loss and the second step had a loss of 0.
Setting 2: loaded the model with torch_dtype=torch.float16 and set bf16=True in TrainingArguments; trainer.train() did not throw an error, with the same issue of a normal loss for the first step and a loss of 0 for the second.
Is it possible that the newer version no longer raises this error, but the loss silently collapses to 0 instead?
Also, how can I load the model directly in fp16 to avoid the excessive GPU memory usage that leads to OOM?
