
OOM when loading 300B models with AutoModelForCausalLM.from_pretrained and BitsAndBytesConfig quantization. #31577

Closed
@Neo9061

Description

System Info

My goal is to follow the Distributed fine-tuning blog post with FSDP and test distributed fine-tuning on a larger model such as the ~300B Grok-1.

For context, I have tried g5.48xlarge (8 GPUs with 192 GB total GPU memory and 768 GB CPU memory) and p4d.24xlarge (8 GPUs with 320 GB total GPU memory and 1152 GB CPU memory). There are two issues, listed below.

Transformers version: transformers==4.40.0


Issue 1
When I try to load the model with 4-bit quantization using the code below (WITHOUT FSDP, purely on a single g5.48xlarge EC2 instance), the total GPU memory required should be around 150 GB (since the model is the ~300B-parameter Grok-1), which is smaller than the 192 GB of GPU memory on g5.48xlarge, but I hit OOM. If I turn on low_cpu_mem_usage=True, the model can be successfully loaded on CPU on the g5.48xlarge instance. The same error happens on p4d.24xlarge, where loading with 4-bit quantization also fails.

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    set_seed,
)
import torch
import os
   
torch_dtype = torch.bfloat16
quant_storage_dtype = torch.bfloat16

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch_dtype,
    bnb_4bit_quant_storage=quant_storage_dtype,
)

model = AutoModelForCausalLM.from_pretrained(
    "keyfan/grok-1-hf",
    quantization_config=quantization_config,
    torch_dtype=quant_storage_dtype,
    use_cache=False,
    trust_remote_code=True,
)
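
For reference, here is a minimal sketch of the load that does succeed on CPU, assuming the quantization config is dropped and low_cpu_mem_usage=True is turned on (it mirrors the snippet above rather than being the exact script I ran):

from transformers import AutoModelForCausalLM
import torch

# CPU-only load, no bitsandbytes quantization; low_cpu_mem_usage=True avoids
# allocating a second full copy of the weights while the shards are read.
model = AutoModelForCausalLM.from_pretrained(
    "keyfan/grok-1-hf",
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    use_cache=False,
    trust_remote_code=True,
)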

Issue 2

Continuing from Issue 1, I think I found a path forward: load the model onto CPU by setting low_cpu_mem_usage=True. Following the blog post above, I started a SageMaker training job that loads the model using the default qlora_fsdp script shown in the blog. I also disabled quantization (since quantization loads the model onto the GPUs, which failed in Issue 1). When FSDP is enabled, low_cpu_mem_usage=True is used by default according to this line. However, I hit a timeout even after raising the ddp_timeout training argument to 10800 (see the sketch below).
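
A minimal sketch of how I pass the longer timeout to the trainer; only ddp_timeout matters for the timeout issue, and the other arguments here are illustrative rather than the exact values from the blog post:

from transformers import TrainingArguments

# ddp_timeout raises the collective timeout used when the process group is
# initialized, from the 1800 s default to 3 hours.
training_args = TrainingArguments(
    output_dir="/opt/ml/model",          # SageMaker's default model directory
    ddp_timeout=10800,
    bf16=True,
    per_device_train_batch_size=1,
    gradient_checkpointing=True,
)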

The model checkpoints are loaded twice, and the run fails during the second load.

  return self.fget.__get__(instance, owner)()
Loading checkpoint shards:   5%|▌         | 1/19 [00:00<00:03,  5.36it/s]
Loading checkpoint shards:  11%|█         | 2/19 [00:00<00:03,  5.28it/s]
Loading checkpoint shards:  16%|█▌        | 3/19 [00:00<00:03,  5.24it/s]
Loading checkpoint shards:  21%|██        | 4/19 [00:00<00:02,  5.23it/s]
Loading checkpoint shards:  26%|██▋       | 5/19 [00:00<00:02,  5.29it/s]
Loading checkpoint shards:  32%|███▏      | 6/19 [00:01<00:02,  5.27it/s]
Loading checkpoint shards:  37%|███▋      | 7/19 [00:01<00:02,  5.25it/s]
Loading checkpoint shards:  42%|████▏     | 8/19 [00:01<00:02,  5.25it/s]
Loading checkpoint shards:  47%|████▋     | 9/19 [00:01<00:01,  5.23it/s]
Loading checkpoint shards:  53%|█████▎    | 10/19 [00:01<00:01,  5.21it/s]
Loading checkpoint shards:  58%|█████▊    | 11/19 [00:02<00:01,  5.20it/s]
Loading checkpoint shards:  63%|██████▎   | 12/19 [00:02<00:01,  5.20it/s]
Loading checkpoint shards:  68%|██████▊   | 13/19 [00:02<00:01,  5.19it/s]
Loading checkpoint shards:  74%|███████▎  | 14/19 [00:02<00:00,  5.21it/s]
Loading checkpoint shards:  79%|███████▉  | 15/19 [00:02<00:00,  5.20it/s]
Loading checkpoint shards:  84%|████████▍ | 16/19 [00:03<00:00,  5.19it/s]
Loading checkpoint shards:  89%|████████▉ | 17/19 [00:03<00:00,  5.19it/s]
Loading checkpoint shards:  95%|█████████▍| 18/19 [00:03<00:00,  5.24it/s]
Loading checkpoint shards: 100%|██████████| 19/19 [00:03<00:00,  5.27it/s]
Loading checkpoint shards: 100%|██████████| 19/19 [00:03<00:00,  5.23it/s]
Generating train split: 0 examples [00:00, ? examples/s]
Generating train split: 1 examples [00:00,  8.21 examples/s]
Generating train split: 827 examples [00:00, 1985.28 examples/s]
Generating train split: 1641 examples [00:00, 3495.17 examples/s]
Generating train split: 2496 examples [00:00, 3041.11 examples/s]
Generating train split: 3324 examples [00:01, 3366.71 examples/s]
Generating train split: 4001 examples [00:01, 3996.93 examples/s]
Generating train split: 4797 examples [00:01, 4292.10 examples/s]
Generating train split: 5698 examples [00:01, 4238.81 examples/s]
Generating train split: 6060 examples [00:01, 3625.63 examples/s]
Generating train split: 0 examples [00:00, ? examples/s]
Generating train split: 324 examples [00:00, 4303.12 examples/s]
Loading checkpoint shards:   5%|▌         | 1/19 [01:51<33:35, 111.97s/it]
Loading checkpoint shards:   5%|▌         | 1/19 [01:56<34:51, 116.18s/it]
Loading checkpoint shards:   5%|▌         | 1/19 [01:55<34:45, 115.86s/it]
Loading checkpoint shards:   5%|▌         | 1/19 [01:55<34:46, 115.93s/it]
Loading checkpoint shards:   5%|▌         | 1/19 [01:55<34:46, 115.89s/it]
Loading checkpoint shards:   5%|▌         | 1/19 [01:55<34:47, 115.98s/it]
Loading checkpoint shards:   5%|▌         | 1/19 [01:57<35:21, 117.86s/it]
Loading checkpoint shards:  11%|█         | 2/19 [04:45<42:02, 148.38s/it]
Loading checkpoint shards:  11%|█         | 2/19 [04:49<42:21, 149.50s/it]
Loading checkpoint shards:  11%|█         | 2/19 [04:48<42:19, 149.37s/it]
Loading checkpoint shards:  11%|█         | 2/19 [04:48<42:20, 149.42s/it]
Loading checkpoint shards:  11%|█         | 2/19 [04:48<42:19, 149.39s/it]
Loading checkpoint shards:  11%|█         | 2/19 [04:50<42:32, 150.16s/it]
Loading checkpoint shards:  11%|█         | 2/19 [04:51<42:45, 150.92s/it]
Loading checkpoint shards:  16%|█▌        | 3/19 [07:27<40:58, 153.63s/it]
Loading checkpoint shards:  16%|█▌        | 3/19 [07:27<41:10, 154.41s/it]
Loading checkpoint shards:  16%|█▌        | 3/19 [07:28<41:01, 153.85s/it]
Loading checkpoint shards:  16%|█▌        | 3/19 [07:27<41:00, 153.78s/it]
Loading checkpoint shards:  16%|█▌        | 3/19 [07:27<41:00, 153.78s/it]
Loading checkpoint shards:  16%|█▌        | 3/19 [07:29<41:13, 154.57s/it]
Loading checkpoint shards:  16%|█▌        | 3/19 [07:31<41:20, 155.04s/it]
Loading checkpoint shards:  21%|██        | 4/19 [10:21<40:24, 161.67s/it]
Loading checkpoint shards:  21%|██        | 4/19 [10:21<40:24, 161.63s/it]
Loading checkpoint shards:  21%|██        | 4/19 [10:22<40:28, 161.92s/it]
Loading checkpoint shards:  21%|██        | 4/19 [10:22<40:36, 162.40s/it]
Loading checkpoint shards:  21%|██        | 4/19 [10:21<40:28, 161.87s/it]
Loading checkpoint shards:  21%|██        | 4/19 [10:23<40:28, 161.92s/it]
Loading checkpoint shards:  21%|██        | 4/19 [10:26<40:43, 162.87s/it]
Loading checkpoint shards:  26%|██▋       | 5/19 [13:01<37:33, 160.99s/it]
Loading checkpoint shards:  26%|██▋       | 5/19 [13:01<37:27, 160.56s/it]
Loading checkpoint shards:  26%|██▋       | 5/19 [13:02<37:37, 161.28s/it]
Loading checkpoint shards:  26%|██▋       | 5/19 [13:02<37:37, 161.22s/it]
Loading checkpoint shards:  26%|██▋       | 5/19 [13:02<37:37, 161.26s/it]
Loading checkpoint shards:  26%|██▋       | 5/19 [13:02<37:45, 161.83s/it]
Loading checkpoint shards:  26%|██▋       | 5/19 [13:06<37:46, 161.88s/it]
Loading checkpoint shards:  32%|███▏      | 6/19 [15:56<35:54, 165.76s/it]
Loading checkpoint shards:  32%|███▏      | 6/19 [15:56<35:53, 165.62s/it]
Loading checkpoint shards:  32%|███▏      | 6/19 [15:56<35:53, 165.67s/it]
Loading checkpoint shards:  32%|███▏      | 6/19 [15:57<36:00, 166.18s/it]
Loading checkpoint shards:  32%|███▏      | 6/19 [15:57<35:56, 165.90s/it]
Loading checkpoint shards:  32%|███▏      | 6/19 [15:57<36:00, 166.17s/it]
Loading checkpoint shards:  32%|███▏      | 6/19 [16:01<36:01, 166.23s/it]
Loading checkpoint shards:  37%|███▋      | 7/19 [18:36<32:47, 164.00s/it]
Loading checkpoint shards:  37%|███▋      | 7/19 [18:38<32:54, 164.56s/it]
Loading checkpoint shards:  37%|███▋      | 7/19 [18:38<32:55, 164.67s/it]
Loading checkpoint shards:  37%|███▋      | 7/19 [18:38<32:53, 164.44s/it]
Loading checkpoint shards:  37%|███▋      | 7/19 [18:39<32:55, 164.63s/it]
Loading checkpoint shards:  37%|███▋      | 7/19 [18:40<33:00, 165.05s/it]
Loading checkpoint shards:  37%|███▋      | 7/19 [18:45<33:05, 165.47s/it]
Loading checkpoint shards:  42%|████▏     | 8/19 [21:36<30:58, 168.96s/it]
Loading checkpoint shards:  42%|████▏     | 8/19 [21:36<30:55, 168.64s/it]
Loading checkpoint shards:  42%|████▏     | 8/19 [21:36<30:56, 168.80s/it]
Loading checkpoint shards:  42%|████▏     | 8/19 [21:36<30:56, 168.78s/it]
Loading checkpoint shards:  42%|████▏     | 8/19 [21:36<30:57, 168.91s/it]
Loading checkpoint shards:  42%|████▏     | 8/19 [21:38<31:01, 169.27s/it]
Loading checkpoint shards:  42%|████▏     | 8/19 [21:44<31:09, 169.91s/it]
Loading checkpoint shards:  47%|████▋     | 9/19 [24:19<27:49, 166.91s/it]
Loading checkpoint shards:  47%|████▋     | 9/19 [24:21<27:57, 167.72s/it]
Loading checkpoint shards:  47%|████▋     | 9/19 [24:21<27:57, 167.71s/it]
Loading checkpoint shards:  47%|████▋     | 9/19 [24:21<27:56, 167.63s/it]
Loading checkpoint shards:  47%|████▋     | 9/19 [24:21<27:56, 167.69s/it]
Loading checkpoint shards:  47%|████▋     | 9/19 [24:22<27:56, 167.63s/it]
Loading checkpoint shards:  47%|████▋     | 9/19 [24:27<27:55, 167.55s/it]
Loading checkpoint shards:  53%|█████▎    | 10/19 [27:17<25:30, 170.07s/it]
Loading checkpoint shards:  53%|█████▎    | 10/19 [27:17<25:30, 170.01s/it]
Loading checkpoint shards:  53%|█████▎    | 10/19 [27:17<25:31, 170.16s/it]
Loading checkpoint shards:  53%|█████▎    | 10/19 [27:17<25:31, 170.15s/it]
Loading checkpoint shards:  53%|█████▎    | 10/19 [27:21<25:38, 170.90s/it]
Loading checkpoint shards:  53%|█████▎    | 10/19 [27:21<25:41, 171.30s/it]
Loading checkpoint shards:  53%|█████▎    | 10/19 [27:24<25:35, 170.64s/it]
Loading checkpoint shards:  58%|█████▊    | 11/19 [29:54<22:10, 166.27s/it]
Loading checkpoint shards:  58%|█████▊    | 11/19 [29:54<22:10, 166.29s/it]
Loading checkpoint shards:  58%|█████▊    | 11/19 [29:54<22:10, 166.26s/it]
Loading checkpoint shards:  58%|█████▊    | 11/19 [29:57<22:13, 166.71s/it]
Loading checkpoint shards:  58%|█████▊    | 11/19 [29:57<22:16, 167.09s/it]
Loading checkpoint shards:  58%|█████▊    | 11/19 [30:00<22:19, 167.44s/it]
Loading checkpoint shards:  58%|█████▊    | 11/19 [30:03<22:16, 167.03s/it]
[E ProcessGroupNCCL.cpp:474] [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=2, OpType=ALLREDUCE, NumelIn=1, NumelOut=1, Timeout(ms)=1800000) ran for 1800772 milliseconds before timing out.
/opt/conda/lib/python3.10/site-packages/trl/trainer/sft_trainer.py:318: UserWarning: You passed a tokenizer with `padding_side` not equal to `right` to the SFTTrainer. This might lead to some unexpected behaviour due to overflow issues when training a model in half-precision. You might consider adding `tokenizer.padding_side = 'right'` to your code.
  warnings.warn(
[E ProcessGroupNCCL.cpp:488] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[E ProcessGroupNCCL.cpp:494] To avoid data inconsistency, we are taking the entire process down.
[E ProcessGroupNCCL.cpp:915] [Rank 0] NCCL watchdog thread terminated with exception: [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=2, OpType=ALLREDUCE, NumelIn=1, NumelOut=1, Timeout(ms)=1800000) ran for 1800772 milliseconds before timing out.
terminate called after throwing an instance of 'std::runtime_error'
  what():  [Rank 0] NCCL watchdog thread terminated with exception: [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=2, OpType=ALLREDUCE, NumelIn=1, NumelOut=1, Timeout(ms)=1800000) ran for 1800772 milliseconds before timing out.

Who can help?

@philschmid @SunMarc @lewtun @sgugger @ArthurZucker @pacman100

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Same as above

Expected behavior

There should be no OOM.
