Root Cause
The root cause is a recent transformers update intended to resolve high CPU RAM usage when loading large quantized models.
- When loading the state dict, the PR puts the quantized model on the meta device on ranks > 0, and then shards the weights during the preparation step. However, this requires torch.distributed memory calls during Accelerator.prepare_model, which we observe getting stuck for QLoRA Mistral 7B (a minimal sketch of this loading scheme is included below).
- This is most probably an NCCL issue. The PR also notes that similar issues were encountered on their test platform (AWS) but were overcome by upgrading to the most recent NCCL; unfortunately, that fix did not work on our infrastructure.
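For reference, here is a minimal sketch of how we read the loading scheme above. This is an illustration only, not the PR's actual code; the model name and quantization config are just examples. Rank 0 holds real quantized weights, while all other ranks hold only meta tensors and depend on collective calls inside Accelerator.prepare_model to obtain their shards later.

```python
# Sketch of the rank-dependent loading scheme (assumption: illustrative only).
import os

import torch
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM, BitsAndBytesConfig

model_name = "mistralai/Mistral-7B-v0.1"
rank = int(os.environ.get("RANK", "0"))

if rank == 0:
    # Real (quantized) weights are materialized only on rank 0 at this point.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True, bnb_4bit_quant_type="nf4"
        ),
    )
else:
    # No storage is allocated here: every parameter sits on the meta device,
    # which is why later attempts to cast them to cuda fail without a value.
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(model_name))
```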
What was observed
Running experiments to test the new Granite models (e.g. ibm/PowerLM-3b) available in transformers==4.45.0.dev0, we encountered the following issues:
- Hanging inside trainer.train(), leading to an eventual distributed timeout error for FSDP-QLoRA experiments, despite only using standard HF libraries in our baseline experiments. The watchdog log follows; a sketch of the diagnostic knobs it suggests is included after it.
[rank0]:[E906 21:06:59.080356778 ProcessGroupNCCL.cpp:1375] [PG 0 (default_pg) Rank 0] First PG on this rank that detected no heartbeat of its watchdog.
[rank0]:[E906 21:06:59.080547547 ProcessGroupNCCL.cpp:1413] [PG 0 (default_pg) Rank 0] Heartbeat monitor timed out! Process will be terminated after dumping debug info. workMetaList_.size()=8
[rank0]:[F906 21:16:59.081481788 ProcessGroupNCCL.cpp:1224] [PG 0 (default_pg) Rank 0] [PG 0 (default_pg) Rank 0] ProcessGroupNCCL's watchdog got stuck for 600 seconds without making progress in monitoring enqueued collectives. This typically indicates a NCCL/CUDA API hang blocking the watchdog, and could be triggered by another thread holding the GIL inside a CUDA api, or other deadlock-prone behaviors.If you suspect the watchdog is not actually stuck and a longer timeout would help, you can either increase the timeout (TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC) to a larger value or disable the heartbeat monitor (TORCH_NCCL_ENABLE_MONITORING=0).If either of aforementioned helps, feel free to file an issue to PyTorch about the short timeout or false positive abort; otherwise, please attempt to debug the hang. workMetaList_.size() = 8
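This is not a fix for the hang, but for completeness, the knobs that the watchdog message itself points at (plus NCCL's own debug logging) can be set before the process group is created; a sketch:

```python
# Diagnostic knobs suggested by the watchdog message above (a sketch, not a fix).
# They must be in the environment before the NCCL process group is created,
# e.g. exported in the launch environment or set at the top of the entrypoint.
import os

os.environ.setdefault("TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC", "1800")  # lengthen the 600 s watchdog
os.environ.setdefault("TORCH_NCCL_ENABLE_MONITORING", "0")         # or disable the monitor entirely
os.environ.setdefault("NCCL_DEBUG", "INFO")                        # surface NCCL-level logs for the hang
```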
- Issue with failing to install the FOAK plugin for FSDP-QLoRA. During registration of the DDP gradient-reduction hooks for the LoRA adapters, weights cannot be cast to cuda on non-zero ranks, as the tensors on meta carry no actual values; this is due to the efficient-cpu-ram-mode fix that now puts all weights on non-zero ranks on the meta device. A hypothetical guard around the failing call is sketched after the traceback below.
ERROR:sft_trainer.py:Traceback (most recent call last):
File "/data/aaron/experimental/fms-acceleration/.tox/run-benches/lib/python3.10/site-packages/tuning/sft_trainer.py", line 585, in main
trainer = train(
File "/data/aaron/experimental/fms-acceleration/.tox/run-benches/lib/python3.10/site-packages/tuning/sft_trainer.py", line 367, in train
for x in framework.get_callbacks_and_ready_for_train(model, accelerator):
File "/data/aaron/experimental/fms-acceleration/plugins/framework/src/fms_acceleration/framework.py", line 260, in get_callbacks_and_ready_for_train
cbks.extend(plugin.get_callbacks_and_ready_for_train(model, accelerator))
File "/data/aaron/experimental/fms-acceleration/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py", line 164, in get_callbacks_and_ready_for_train
lora_adapters_switch_ddp_from_fsdp(
File "/data/aaron/experimental/fms-acceleration/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py", line 58, in lora_adapters_switch_ddp_from_fsdp
set_module_tensor_to_device(A, "weight", "cuda")
File "/data/aaron/experimental/fms-acceleration/.tox/run-benches/lib/python3.10/site-packages/accelerate/utils/modeling.py", line 364, in set_module_tensor_to_device
raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.")
ValueError: weight is on the meta device, we need a `value` to put in on cuda.
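For illustration only, here is a hypothetical guard around the failing set_module_tensor_to_device call. This is not the plugin's fix, and move_adapter_weight_to_cuda is a made-up helper; whether an empty placeholder is sufficient depends on how the adapter weights are synchronized afterwards.

```python
# Hypothetical guard (not the plugin's actual fix) around the failing call in
# lora_adapters_switch_ddp_from_fsdp: cast normally when the weight has real
# storage, and materialize an empty tensor when it is still on meta.
import torch
from accelerate.utils import set_module_tensor_to_device


def move_adapter_weight_to_cuda(module: torch.nn.Module) -> None:
    weight = module.weight
    if weight.device.type == "meta":
        # set_module_tensor_to_device raises the ValueError above for meta
        # tensors unless a concrete value is supplied.
        set_module_tensor_to_device(
            module, "weight", "cuda", value=torch.empty_like(weight, device="cuda")
        )
    else:
        set_module_tensor_to_device(module, "weight", "cuda")
```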
- Update: we have reported a variant of this problem, whereby Accelerator.prepare_model gets stuck when low_cpu_mem_mode is enabled. This is reported in Update Benchmarks and Documentation for GraniteCausalLM #86. The problem seems to be observed only for selected models: for GraniteForCausalLM it is observed, but for other models like Mistral 7B it is not.
- Update: this problem seems to come from tied weights and how FSDP1 handles them for quantized models. The key difference is that only the Granite models have tied weights, whereas Mistral and Llama do not (a quick check is sketched below).
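A quick way to confirm the tied-weights difference mentioned above (a sketch; it assumes the standard tie_word_embeddings config field and access to the Hub):

```python
# Compare the tied-embeddings setting of the models discussed above. Per the
# note on FSDP1 and tied weights, the Granite models are expected to report
# True while Mistral/Llama report False.
from transformers import AutoConfig

for name in ("ibm/PowerLM-3b", "mistralai/Mistral-7B-v0.1"):
    cfg = AutoConfig.from_pretrained(name)
    print(name, "tie_word_embeddings =", getattr(cfg, "tie_word_embeddings", None))
```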
Reproduce
- Issue 1
export ACCELERATION_FRAMEWORK_CONFIG_FILE=sample-configurations/baseline-peft-bnb-nf4-sample-configuration.yaml
accelerate launch --config_file scripts/benchmarks/accelerate.yaml --num_processes=2 --main_process_port=29500 -m tuning.sft_trainer --model_name_or_path mistralai/Mistral-7B-v0.1 --packing True --max_seq_len 4096 --training_data_path benchmark_outputs/data/cache_all.json --use_flash_attn True --response_template '
### Response:' --dataset_text_field output --include_tokens_per_second True --num_train_epochs 1 --gradient_accumulation_steps 1 --gradient_checkpointing True --evaluation_strategy no --save_strategy no --weight_decay 0.01 --warmup_steps 10 --adam_epsilon 1e-4 --lr_scheduler_type linear --logging_strategy steps --logging_steps 10 --max_steps 100 --bf16 True --learning_rate 2e-4 --torch_dtype bfloat16 --peft_method lora --r 16 --lora_alpha 16 --lora_dropout 0.1 --target_modules q_proj k_proj v_proj o_proj --per_device_train_batch_size 2 --output_dir benchmark_outputs/exp_33/hf --skip_memory_metrics False
- Issue 2
export ACCELERATION_FRAMEWORK_CONFIG_FILE=sample-configurations/accelerated-peft-bnb-nf4-foak-sample-configuration.yaml
accelerate launch --config_file scripts/benchmarks/accelerate.yaml --num_processes=2 --main_process_port=29500 -m tuning.sft_trainer --model_name_or_path mistralai/Mistral-7B-v0.1 --packing True --max_seq_len 4096 --training_data_path benchmark_outputs/data/cache_all.json --use_flash_attn True --response_template '
### Response:' --dataset_text_field output --include_tokens_per_second True --num_train_epochs 1 --gradient_accumulation_steps 1 --gradient_checkpointing True --evaluation_strategy no --save_strategy no --weight_decay 0.01 --warmup_steps 10 --adam_epsilon 1e-4 --lr_scheduler_type linear --logging_strategy steps --logging_steps 10 --max_steps 100 --bf16 True --learning_rate 2e-4 --torch_dtype bfloat16 --peft_method lora --r 16 --lora_alpha 16 --lora_dropout 0.1 --target_modules q_proj k_proj v_proj o_proj --per_device_train_batch_size 2 --output_dir benchmark_outputs/exp_57/hf --skip_memory_metrics False
- Issue 3
Ensure the following versions:
pip install accelerate==0.34.1
pip install transformers==4.45
pip install git+https://github.com/huggingface/trl.git@9b80f3d50ccb98ceee94bab4145a36e7e58aa4eb
Also please remove the workaround introduced in #86 that disables low_cpu_mem_mode.
export ACCELERATION_FRAMEWORK_CONFIG_FILE=/workspace/fms-acceleration/scripts/benchmarks/../../sample-configurations/baseline-peft-bnb-nf4-sample-configuration.yaml
accelerate launch --config_file scripts/benchmarks/accelerate.yaml --num_processes=2 --main_process_port=29511 -m tuning.sft_trainer --model_name_or_path ibm/PowerLM-3b --packing True --max_seq_len 4096 --training_data_path benchmark_outputs_debug/data/cache_all.json --use_flash_attn True --response_template '
### Response:' --dataset_text_field output --include_tokens_per_second True --num_train_epochs 1 --gradient_accumulation_steps 1 --gradient_checkpointing True --evaluation_strategy no --save_strategy no --weight_decay 0.01 --warmup_steps 10 --adam_epsilon 1e-4 --lr_scheduler_type linear --logging_strategy steps --logging_steps 10 --max_steps 100 --bf16 True --learning_rate 2e-4 --torch_dtype bfloat16 --peft_method lora --r 16 --lora_alpha 16 --lora_dropout 0.1 --target_modules q_proj k_proj v_proj o_proj --per_device_train_batch_size 2 --output_dir benchmark_outputs_debug/exp_1/hf --skip_memory_metrics False
Dependencies
transformers @ git+https://github.com/huggingface/transformers.git@9230d78e76611cfa38c845213021aeb185362d10
trl==0.9.6
accelerate==0.33.0
torch==2.4.0
triton==3.0.0