Description
System Info
My goal is to follow the distributed fine-tuning with FSDP blog post and test distributed fine-tuning on a larger model, the ~300B-parameter Grok-1.
For context, I have tried g5.48xlarge (8 GPUs, 192 GB total GPU memory, 768 GB CPU memory) and p4d.24xlarge (8 GPUs, 320 GB total GPU memory, 1152 GB CPU memory). I hit the two issues described below.
Transformers version: transformers==4.40.0
Issue 1
When I try to load the model with 4-bit quantization using the code below (WITHOUT FSDP, directly on a g5.48xlarge EC2 instance), the total GPU memory required should be around 150 GB (Grok-1 is a ~300B-parameter model, so 4-bit weights come to roughly 300B × 0.5 bytes ≈ 150 GB), which is below the 192 GB of GPU memory on g5.48xlarge, yet loading fails with OOM. If I instead turn on low_cpu_mem_usage=True, the model loads successfully onto CPU on the same g5.48xlarge instance. The same error occurs on p4d.24xlarge: loading with 4-bit quantization fails.
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    set_seed,
)
import torch
import os

torch_dtype = torch.bfloat16
quant_storage_dtype = torch.bfloat16

# 4-bit NF4 quantization with bf16 compute and bf16 quant storage
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch_dtype,
    bnb_4bit_quant_storage=quant_storage_dtype,
)

# This call OOMs on both g5.48xlarge and p4d.24xlarge
model = AutoModelForCausalLM.from_pretrained(
    "keyfan/grok-1-hf",
    quantization_config=quantization_config,
    torch_dtype=quant_storage_dtype,
    use_cache=False,
    trust_remote_code=True,
)
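For comparison, loading without quantization and with low_cpu_mem_usage=True does work for me on the same instance. A minimal sketch of that variant (the exact kwargs beyond low_cpu_mem_usage are my assumption):

# Loading onto CPU with low_cpu_mem_usage=True and no BitsAndBytes config
# succeeds on g5.48xlarge (768 GB CPU memory).
model_cpu = AutoModelForCausalLM.from_pretrained(
    "keyfan/grok-1-hf",
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    use_cache=False,
    trust_remote_code=True,
)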
Issue 2
Continuing from Issue 1, I think I found a path forward: load the model onto CPU by setting low_cpu_mem_usage=True. Following the blog post above, I started a SageMaker training job that loads the model with the default qlora_fsdp script shown in the blog, but with quantization disabled (since quantization loads the model onto the GPUs and fails as in Issue 1). When FSDP is enabled, the script defaults to low_cpu_mem_usage=True according to this line. However, I hit a timeout even after raising the ddp_timeout training argument to 10800 (see the sketch below).
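For reference, this is roughly how the timeout was raised (a minimal sketch; output_dir is a placeholder and the remaining arguments follow the blog post's qlora_fsdp script):

from transformers import TrainingArguments

# Raise the distributed timeout from the default 1800 s to 10800 s (3 h).
# Judging by the log below, the NCCL watchdog still fires at 1800000 ms,
# so the longer timeout does not appear to apply to the failing collective.
training_args = TrainingArguments(
    output_dir="/opt/ml/model",  # placeholder path for the SageMaker job
    ddp_timeout=10800,
    # ... other arguments as in the qlora_fsdp script ...
)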
The model checkpoints are loaded twice; the second load is far slower and eventually fails with an NCCL timeout:
return self.fget.__get__(instance, owner)()
Loading checkpoint shards: 5%|▌ | 1/19 [00:00<00:03, 5.36it/s]
Loading checkpoint shards: 11%|█ | 2/19 [00:00<00:03, 5.28it/s]
Loading checkpoint shards: 16%|█▌ | 3/19 [00:00<00:03, 5.24it/s]
Loading checkpoint shards: 21%|██ | 4/19 [00:00<00:02, 5.23it/s]
Loading checkpoint shards: 26%|██▋ | 5/19 [00:00<00:02, 5.29it/s]
Loading checkpoint shards: 32%|███▏ | 6/19 [00:01<00:02, 5.27it/s]
Loading checkpoint shards: 37%|███▋ | 7/19 [00:01<00:02, 5.25it/s]
Loading checkpoint shards: 42%|████▏ | 8/19 [00:01<00:02, 5.25it/s]
Loading checkpoint shards: 47%|████▋ | 9/19 [00:01<00:01, 5.23it/s]
Loading checkpoint shards: 53%|█████▎ | 10/19 [00:01<00:01, 5.21it/s]
Loading checkpoint shards: 58%|█████▊ | 11/19 [00:02<00:01, 5.20it/s]
Loading checkpoint shards: 63%|██████▎ | 12/19 [00:02<00:01, 5.20it/s]
Loading checkpoint shards: 68%|██████▊ | 13/19 [00:02<00:01, 5.19it/s]
Loading checkpoint shards: 74%|███████▎ | 14/19 [00:02<00:00, 5.21it/s]
Loading checkpoint shards: 79%|███████▉ | 15/19 [00:02<00:00, 5.20it/s]
Loading checkpoint shards: 84%|████████▍ | 16/19 [00:03<00:00, 5.19it/s]
Loading checkpoint shards: 89%|████████▉ | 17/19 [00:03<00:00, 5.19it/s]
Loading checkpoint shards: 95%|█████████▍| 18/19 [00:03<00:00, 5.24it/s]
Loading checkpoint shards: 100%|██████████| 19/19 [00:03<00:00, 5.27it/s]
Loading checkpoint shards: 100%|██████████| 19/19 [00:03<00:00, 5.23it/s]
Generating train split: 0 examples [00:00, ? examples/s]
Generating train split: 1 examples [00:00, 8.21 examples/s]
Generating train split: 827 examples [00:00, 1985.28 examples/s]
Generating train split: 1641 examples [00:00, 3495.17 examples/s]
Generating train split: 2496 examples [00:00, 3041.11 examples/s]
Generating train split: 3324 examples [00:01, 3366.71 examples/s]
Generating train split: 4001 examples [00:01, 3996.93 examples/s]
Generating train split: 4797 examples [00:01, 4292.10 examples/s]
Generating train split: 5698 examples [00:01, 4238.81 examples/s]
Generating train split: 6060 examples [00:01, 3625.63 examples/s]
Generating train split: 0 examples [00:00, ? examples/s]
Generating train split: 324 examples [00:00, 4303.12 examples/s]
Loading checkpoint shards: 5%|▌ | 1/19 [01:51<33:35, 111.97s/it]
Loading checkpoint shards: 5%|▌ | 1/19 [01:56<34:51, 116.18s/it]
Loading checkpoint shards: 5%|▌ | 1/19 [01:55<34:45, 115.86s/it]
Loading checkpoint shards: 5%|▌ | 1/19 [01:55<34:46, 115.93s/it]
Loading checkpoint shards: 5%|▌ | 1/19 [01:55<34:46, 115.89s/it]
Loading checkpoint shards: 5%|▌ | 1/19 [01:55<34:47, 115.98s/it]
Loading checkpoint shards: 5%|▌ | 1/19 [01:57<35:21, 117.86s/it]
Loading checkpoint shards: 11%|█ | 2/19 [04:45<42:02, 148.38s/it]
Loading checkpoint shards: 11%|█ | 2/19 [04:49<42:21, 149.50s/it]
Loading checkpoint shards: 11%|█ | 2/19 [04:48<42:19, 149.37s/it]
Loading checkpoint shards: 11%|█ | 2/19 [04:48<42:20, 149.42s/it]
Loading checkpoint shards: 11%|█ | 2/19 [04:48<42:19, 149.39s/it]
Loading checkpoint shards: 11%|█ | 2/19 [04:50<42:32, 150.16s/it]
Loading checkpoint shards: 11%|█ | 2/19 [04:51<42:45, 150.92s/it]
Loading checkpoint shards: 16%|█▌ | 3/19 [07:27<40:58, 153.63s/it]
Loading checkpoint shards: 16%|█▌ | 3/19 [07:27<41:10, 154.41s/it]
Loading checkpoint shards: 16%|█▌ | 3/19 [07:28<41:01, 153.85s/it]
Loading checkpoint shards: 16%|█▌ | 3/19 [07:27<41:00, 153.78s/it]
Loading checkpoint shards: 16%|█▌ | 3/19 [07:27<41:00, 153.78s/it]
Loading checkpoint shards: 16%|█▌ | 3/19 [07:29<41:13, 154.57s/it]
Loading checkpoint shards: 16%|█▌ | 3/19 [07:31<41:20, 155.04s/it]
Loading checkpoint shards: 21%|██ | 4/19 [10:21<40:24, 161.67s/it]
Loading checkpoint shards: 21%|██ | 4/19 [10:21<40:24, 161.63s/it]
Loading checkpoint shards: 21%|██ | 4/19 [10:22<40:28, 161.92s/it]
Loading checkpoint shards: 21%|██ | 4/19 [10:22<40:36, 162.40s/it]
Loading checkpoint shards: 21%|██ | 4/19 [10:21<40:28, 161.87s/it]
Loading checkpoint shards: 21%|██ | 4/19 [10:23<40:28, 161.92s/it]
Loading checkpoint shards: 21%|██ | 4/19 [10:26<40:43, 162.87s/it]
Loading checkpoint shards: 26%|██▋ | 5/19 [13:01<37:33, 160.99s/it]
Loading checkpoint shards: 26%|██▋ | 5/19 [13:01<37:27, 160.56s/it]
Loading checkpoint shards: 26%|██▋ | 5/19 [13:02<37:37, 161.28s/it]
Loading checkpoint shards: 26%|██▋ | 5/19 [13:02<37:37, 161.22s/it]
Loading checkpoint shards: 26%|██▋ | 5/19 [13:02<37:37, 161.26s/it]
Loading checkpoint shards: 26%|██▋ | 5/19 [13:02<37:45, 161.83s/it]
Loading checkpoint shards: 26%|██▋ | 5/19 [13:06<37:46, 161.88s/it]
Loading checkpoint shards: 32%|███▏ | 6/19 [15:56<35:54, 165.76s/it]
Loading checkpoint shards: 32%|███▏ | 6/19 [15:56<35:53, 165.62s/it]
Loading checkpoint shards: 32%|███▏ | 6/19 [15:56<35:53, 165.67s/it]
Loading checkpoint shards: 32%|███▏ | 6/19 [15:57<36:00, 166.18s/it]
Loading checkpoint shards: 32%|███▏ | 6/19 [15:57<35:56, 165.90s/it]
Loading checkpoint shards: 32%|███▏ | 6/19 [15:57<36:00, 166.17s/it]
Loading checkpoint shards: 32%|███▏ | 6/19 [16:01<36:01, 166.23s/it]
Loading checkpoint shards: 37%|███▋ | 7/19 [18:36<32:47, 164.00s/it]
Loading checkpoint shards: 37%|███▋ | 7/19 [18:38<32:54, 164.56s/it]
Loading checkpoint shards: 37%|███▋ | 7/19 [18:38<32:55, 164.67s/it]
Loading checkpoint shards: 37%|███▋ | 7/19 [18:38<32:53, 164.44s/it]
Loading checkpoint shards: 37%|███▋ | 7/19 [18:39<32:55, 164.63s/it]
Loading checkpoint shards: 37%|███▋ | 7/19 [18:40<33:00, 165.05s/it]
Loading checkpoint shards: 37%|███▋ | 7/19 [18:45<33:05, 165.47s/it]
Loading checkpoint shards: 42%|████▏ | 8/19 [21:36<30:58, 168.96s/it]
Loading checkpoint shards: 42%|████▏ | 8/19 [21:36<30:55, 168.64s/it]
Loading checkpoint shards: 42%|████▏ | 8/19 [21:36<30:56, 168.80s/it]
Loading checkpoint shards: 42%|████▏ | 8/19 [21:36<30:56, 168.78s/it]
Loading checkpoint shards: 42%|████▏ | 8/19 [21:36<30:57, 168.91s/it]
Loading checkpoint shards: 42%|████▏ | 8/19 [21:38<31:01, 169.27s/it]
Loading checkpoint shards: 42%|████▏ | 8/19 [21:44<31:09, 169.91s/it]
Loading checkpoint shards: 47%|████▋ | 9/19 [24:19<27:49, 166.91s/it]
Loading checkpoint shards: 47%|████▋ | 9/19 [24:21<27:57, 167.72s/it]
Loading checkpoint shards: 47%|████▋ | 9/19 [24:21<27:57, 167.71s/it]
Loading checkpoint shards: 47%|████▋ | 9/19 [24:21<27:56, 167.63s/it]
Loading checkpoint shards: 47%|████▋ | 9/19 [24:21<27:56, 167.69s/it]
Loading checkpoint shards: 47%|████▋ | 9/19 [24:22<27:56, 167.63s/it]
Loading checkpoint shards: 47%|████▋ | 9/19 [24:27<27:55, 167.55s/it]
Loading checkpoint shards: 53%|█████▎ | 10/19 [27:17<25:30, 170.07s/it]
Loading checkpoint shards: 53%|█████▎ | 10/19 [27:17<25:30, 170.01s/it]
Loading checkpoint shards: 53%|█████▎ | 10/19 [27:17<25:31, 170.16s/it]
Loading checkpoint shards: 53%|█████▎ | 10/19 [27:17<25:31, 170.15s/it]
Loading checkpoint shards: 53%|█████▎ | 10/19 [27:21<25:38, 170.90s/it]
Loading checkpoint shards: 53%|█████▎ | 10/19 [27:21<25:41, 171.30s/it]
Loading checkpoint shards: 53%|█████▎ | 10/19 [27:24<25:35, 170.64s/it]
Loading checkpoint shards: 58%|█████▊ | 11/19 [29:54<22:10, 166.27s/it]
Loading checkpoint shards: 58%|█████▊ | 11/19 [29:54<22:10, 166.29s/it]
Loading checkpoint shards: 58%|█████▊ | 11/19 [29:54<22:10, 166.26s/it]
Loading checkpoint shards: 58%|█████▊ | 11/19 [29:57<22:13, 166.71s/it]
Loading checkpoint shards: 58%|█████▊ | 11/19 [29:57<22:16, 167.09s/it]
Loading checkpoint shards: 58%|█████▊ | 11/19 [30:00<22:19, 167.44s/it]
Loading checkpoint shards: 58%|█████▊ | 11/19 [30:03<22:16, 167.03s/it]
[E ProcessGroupNCCL.cpp:474] [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=2, OpType=ALLREDUCE, NumelIn=1, NumelOut=1, Timeout(ms)=1800000) ran for 1800772 milliseconds before timing out.
/opt/conda/lib/python3.10/site-packages/trl/trainer/sft_trainer.py:318: UserWarning: You passed a tokenizer with `padding_side` not equal to `right` to the SFTTrainer. This might lead to some unexpected behaviour due to overflow issues when training a model in half-precision. You might consider adding `tokenizer.padding_side = 'right'` to your code.
warnings.warn(
[E ProcessGroupNCCL.cpp:488] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[E ProcessGroupNCCL.cpp:494] To avoid data inconsistency, we are taking the entire process down.
[E ProcessGroupNCCL.cpp:915] [Rank 0] NCCL watchdog thread terminated with exception: [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=2, OpType=ALLREDUCE, NumelIn=1, NumelOut=1, Timeout(ms)=1800000) ran for 1800772 milliseconds before timing out.
terminate called after throwing an instance of 'std::runtime_error'
what(): [Rank 0] NCCL watchdog thread terminated with exception: [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=2, OpType=ALLREDUCE, NumelIn=1, NumelOut=1, Timeout(ms)=1800000) ran for 1800772 milliseconds before timing out.
Who can help?
@philschmid @SunMarc @lewtun @sgugger @ArthurZucker @pacman100
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
Same as above
Expected behavior
There should be no OOM when loading the 4-bit quantized model, and the FSDP training job should not time out while loading the checkpoint shards.