Distributed Training Problems for QLoRA models with Transformers pre-release 4.45 #83

Closed
achew010 opened this issue Sep 11, 2024 · 2 comments · Fixed by #90
Comments

achew010 (Collaborator) commented Sep 11, 2024

Root Cause

The root cause is a recent transformers update made to reduce high CPU RAM usage when loading large quantized models.

  • When loading the state dict, the PR puts the quantized model on the meta device on ranks > 0, and then shards the weights during the preparation step (see the sketch below).
  • However, this requires torch.distributed collective calls during Accelerator.prepare_model, which we observe getting stuck for QLoRA Mistral 7B.
  • This is most probably an NCCL issue. The PR notes that similar issues were encountered on their test platform (AWS) but were overcome by upgrading to the most recent NCCL; unfortunately, this fix did not work on our infrastructure.
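For illustration only, here is a minimal sketch (not the actual transformers implementation) of the rank-dependent loading pattern described above, assuming a standard accelerate/transformers setup: rank 0 materializes the quantized weights in CPU RAM, while the other ranks only build the module structure on the meta device, so real data only reaches them through the collectives issued later in Accelerator.prepare_model.

import torch.distributed as dist
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

def load_for_fsdp_qlora(model_name: str, bnb_config):
    # Sketch only: the real change keeps the quantized layer structure on all ranks.
    rank = dist.get_rank() if dist.is_initialized() else 0
    if rank == 0:
        # rank 0 loads the real (quantized) weights into CPU RAM
        model = AutoModelForCausalLM.from_pretrained(
            model_name, quantization_config=bnb_config
        )
    else:
        # ranks > 0 hold only meta tensors; no CPU RAM is consumed for weights
        config = AutoConfig.from_pretrained(model_name)
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(config)
    # Accelerator.prepare_model must later shard/broadcast rank 0's weights to
    # the other ranks -- these are the torch.distributed calls that hang here.
    return model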

What was observed

We ran experiments to test the new Granite models (e.g. ibm/PowerLM-3b) available in Transformers==4.45.0.dev0 and encountered the following issues:

  1. Hanging inside trainer.train(), leading to an eventual distributed timeout error for FSDP-QLoRA experiments, despite only using standard HF libraries in our baseline experiments:
[rank0]:[E906 21:06:59.080356778 ProcessGroupNCCL.cpp:1375] [PG 0 (default_pg) Rank 0] First PG on this rank that detected no heartbeat of its watchdog.
[rank0]:[E906 21:06:59.080547547 ProcessGroupNCCL.cpp:1413] [PG 0 (default_pg) Rank 0] Heartbeat monitor timed out! Process will be terminated after dumping debug info. workMetaList_.size()=8
[rank0]:[F906 21:16:59.081481788 ProcessGroupNCCL.cpp:1224] [PG 0 (default_pg) Rank 0] [PG 0 (default_pg) Rank 0] ProcessGroupNCCL's watchdog got stuck for 600 seconds without making progress in monitoring enqueued collectives. This typically indicates a NCCL/CUDA API hang blocking the watchdog, and could be triggered by another thread holding the GIL inside a CUDA api, or other deadlock-prone behaviors.If you suspect the watchdog is not actually stuck and a longer timeout would help, you can either increase the timeout (TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC) to a larger value or disable the heartbeat monitor (TORCH_NCCL_ENABLE_MONITORING=0).If either of aforementioned helps, feel free to file an issue to PyTorch about the short timeout or false positive abort; otherwise, please attempt to debug the hang. workMetaList_.size() = 8
  2. Failure to install the FOAK plugin for FSDP-QLoRA. During registration of the DDP gradient-reduction hooks for the LoRA adapters, the weights cannot be cast to cuda on non-zero-ranked devices because there are no actual weights behind the meta tensors; this is due to the efficient-cpu-ram-mode fix that now puts all weights on non-zero-ranked devices on the meta device (see the sketch after this list):
ERROR:sft_trainer.py:Traceback (most recent call last):
  File "/data/aaron/experimental/fms-acceleration/.tox/run-benches/lib/python3.10/site-packages/tuning/sft_trainer.py", line 585, in main
    trainer = train(
  File "/data/aaron/experimental/fms-acceleration/.tox/run-benches/lib/python3.10/site-packages/tuning/sft_trainer.py", line 367, in train
    for x in framework.get_callbacks_and_ready_for_train(model, accelerator):
  File "/data/aaron/experimental/fms-acceleration/plugins/framework/src/fms_acceleration/framework.py", line 260, in get_callbacks_and_ready_for_train
    cbks.extend(plugin.get_callbacks_and_ready_for_train(model, accelerator))
  File "/data/aaron/experimental/fms-acceleration/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py", line 164, in get_callbacks_and_ready_for_train
    lora_adapters_switch_ddp_from_fsdp(
  File "/data/aaron/experimental/fms-acceleration/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py", line 58, in lora_adapters_switch_ddp_from_fsdp
    set_module_tensor_to_device(A, "weight", "cuda")
  File "/data/aaron/experimental/fms-acceleration/.tox/run-benches/lib/python3.10/site-packages/accelerate/utils/modeling.py", line 364, in set_module_tensor_to_device
    raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.")
ValueError: weight is on the meta device, we need a `value` to put in on cuda.
  3. Update: we have reported a variant of this problem, whereby accelerate.prepare_model gets stuck when low_cpu_mem_mode is enabled. This is reported in Update Benchmarks and Documentation for GraniteCausalLM #86.
    • This problem is only observed for selected models: it occurs for GraniteCausalForLM, but not for other models like Mistral 7B.
    • Update: this problem seems to come from tied weights and how FSDP1 handles them for the quantized models. The key difference is that only the Granite models have tied weights, whereas Mistral and Llama do not.
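The sketch below is hedged and is not the FOAK plugin's actual fix; it only illustrates how the failing set_module_tensor_to_device call in lora_adapters_switch_ddp_from_fsdp could be guarded when an adapter weight is still a meta tensor (issue 2), and how the tied-embedding difference suspected in issue 3 can be checked. The helper names are illustrative.

import torch
import torch.distributed as dist
from accelerate.utils import set_module_tensor_to_device

def move_adapter_weight_to_cuda(module: torch.nn.Module):
    """Illustrative guard, not the plugin's actual code."""
    weight = module.weight
    if weight.is_meta:
        # ranks > 0 under the new loading: no real data behind the tensor, so a
        # plain cast fails; materialize it from a broadcast of rank 0's copy.
        buf = torch.empty(weight.shape, dtype=weight.dtype, device="cuda")
        dist.broadcast(buf, src=0)
        set_module_tensor_to_device(module, "weight", "cuda", value=buf)
    else:
        # rank 0 (or single-process runs): the cast works; then participate in
        # the broadcast so the collective matches across all ranks.
        set_module_tensor_to_device(module, "weight", "cuda")
        if dist.is_initialized():
            dist.broadcast(module.weight.data, src=0)

def has_tied_embeddings(model) -> bool:
    # Only the Granite models tie the input/output embeddings, which is the
    # difference suspected in issue 3.
    out = model.get_output_embeddings()
    return out is not None and out.weight is model.get_input_embeddings().weight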

Reproduce

  • Issue 1
export ACCELERATION_FRAMEWORK_CONFIG_FILE=sample-configurations/baseline-peft-bnb-nf4-sample-configuration.yaml
accelerate launch --config_file scripts/benchmarks/accelerate.yaml --num_processes=2 --main_process_port=29500 -m tuning.sft_trainer --model_name_or_path mistralai/Mistral-7B-v0.1 --packing True --max_seq_len 4096 --training_data_path benchmark_outputs/data/cache_all.json --use_flash_attn True --response_template '
### Response:' --dataset_text_field output --include_tokens_per_second True --num_train_epochs 1 --gradient_accumulation_steps 1 --gradient_checkpointing True --evaluation_strategy no --save_strategy no --weight_decay 0.01 --warmup_steps 10 --adam_epsilon 1e-4 --lr_scheduler_type linear --logging_strategy steps --logging_steps 10 --max_steps 100 --bf16 True --learning_rate 2e-4 --torch_dtype bfloat16 --peft_method lora --r 16 --lora_alpha 16 --lora_dropout 0.1 --target_modules q_proj k_proj v_proj o_proj --per_device_train_batch_size 2 --output_dir benchmark_outputs/exp_33/hf --skip_memory_metrics False
  • Issue 2
export ACCELERATION_FRAMEWORK_CONFIG_FILE=sample-configurations/accelerated-peft-bnb-nf4-foak-sample-configuration.yaml
accelerate launch --config_file scripts/benchmarks/accelerate.yaml --num_processes=2 --main_process_port=29500 -m tuning.sft_trainer --model_name_or_path mistralai/Mistral-7B-v0.1 --packing True --max_seq_len 4096 --training_data_path benchmark_outputs/data/cache_all.json --use_flash_attn True --response_template '
### Response:' --dataset_text_field output --include_tokens_per_second True --num_train_epochs 1 --gradient_accumulation_steps 1 --gradient_checkpointing True --evaluation_strategy no --save_strategy no --weight_decay 0.01 --warmup_steps 10 --adam_epsilon 1e-4 --lr_scheduler_type linear --logging_strategy steps --logging_steps 10 --max_steps 100 --bf16 True --learning_rate 2e-4 --torch_dtype bfloat16 --peft_method lora --r 16 --lora_alpha 16 --lora_dropout 0.1 --target_modules q_proj k_proj v_proj o_proj --per_device_train_batch_size 2 --output_dir benchmark_outputs/exp_57/hf --skip_memory_metrics False
  • Issue 3

Ensure the following versions:

pip install accelerate==0.34.1
pip install transformers==4.45
pip install git+https://github.com/huggingface/trl.git@9b80f3d50ccb98ceee94bab4145a36e7e58aa4eb

Also remove the workaround from #86 that disables low_cpu_mem_mode.

export ACCELERATION_FRAMEWORK_CONFIG_FILE=/workspace/fms-acceleration/scripts/benchmarks/../../sample-configurations/baseline-peft-bnb-nf4-sample-configuration.yaml
accelerate launch --config_file scripts/benchmarks/accelerate.yaml --num_processes=2 --main_process_port=29511 -m tuning.sft_trainer --model_name_or_path ibm/PowerLM-3b --packing True --max_seq_len 4096 --training_data_path benchmark_outputs_debug/data/cache_all.json --use_flash_attn True --response_template '
### Response:' --dataset_text_field output --include_tokens_per_second True --num_train_epochs 1 --gradient_accumulation_steps 1 --gradient_checkpointing True --evaluation_strategy no --save_strategy no --weight_decay 0.01 --warmup_steps 10 --adam_epsilon 1e-4 --lr_scheduler_type linear --logging_strategy steps --logging_steps 10 --max_steps 100 --bf16 True --learning_rate 2e-4 --torch_dtype bfloat16 --peft_method lora --r 16 --lora_alpha 16 --lora_dropout 0.1 --target_modules q_proj k_proj v_proj o_proj --per_device_train_batch_size 2 --output_dir benchmark_outputs_debug/exp_1/hf --skip_memory_metrics False
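When reproducing issue 1, the following optional environment variables may help surface where the collective hangs. The timeout-related variables come from the watchdog message above; NCCL_DEBUG and TORCH_DISTRIBUTED_DEBUG are standard NCCL/PyTorch settings; the values shown are only illustrative.

export NCCL_DEBUG=INFO                         # verbose NCCL logs
export TORCH_DISTRIBUTED_DEBUG=DETAIL          # extra torch.distributed checks
export TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC=1200   # raise the watchdog timeout
# export TORCH_NCCL_ENABLE_MONITORING=0        # or disable the heartbeat monitor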

Dependencies

transformers @ git+https://github.com/huggingface/transformers.git@9230d78e76611cfa38c845213021aeb185362d10
trl==0.9.6
accelerate==0.33.0
torch==2.4.0
triton==3.0.0
fabianlim changed the title from "QLoRA experiments hanging with Transformers==4.45.0.dev0" to "Distributed Training Problems for QLoRA models with Transformers pre-release 4.45" on Sep 11, 2024
fabianlim (Contributor) commented
@achew010 @wynterl I made some progress with this. If we comment out

https://github.com/huggingface/trl/blob/c3143832cb305139b2551af2e00f008b4d64a981/trl/trainer/sft_trainer.py#L211-L275

and replace with

import fms_acceleration_peft
from fms_acceleration_peft.framework_plugin_bnb import _prepare_model_for_kbit_training
from peft import get_peft_model  # import added for completeness

model = _prepare_model_for_kbit_training(
    model,
    use_gradient_checkpointing=True,
    gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs,
)

model = get_peft_model(model, peft_config)

then the hang no longer occurs, which suggests that one of the commented-out lines is causing the issue.

fabianlim (Contributor) commented
Update: problem 1 occurs because, with the new fix, the assumption at https://github.com/huggingface/trl/blob/c3143832cb305139b2551af2e00f008b4d64a981/trl/trainer/sft_trainer.py#L231 no longer holds.
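For reference, here is a generic per-rank diagnostic (not part of TRL or the fms-acceleration plugins) to confirm the new loading behaviour: run it before Accelerator.prepare_model, and with the pre-release transformers, ranks > 0 should report all parameters on the meta device.

import torch.distributed as dist

def report_meta_params(model):
    # generic diagnostic, not part of TRL or fms-acceleration
    rank = dist.get_rank() if dist.is_initialized() else 0
    params = list(model.parameters())
    meta = sum(p.is_meta for p in params)
    print(f"[rank {rank}] {meta}/{len(params)} parameters on the meta device")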

fabianlim added a commit that referenced this issue on Oct 11, 2024
fabianlim added a commit that referenced this issue on Oct 14, 2024
fabianlim linked a pull request on Oct 14, 2024 that will close this issue
fabianlim added a commit that referenced this issue on Oct 14, 2024
fabianlim added a commit that referenced this issue on Oct 17, 2024, with the following changes:

* address issue 2 in #83
* properly handle broadcast of adapters
* handle param_init_fn_tied_param
* trl version error
* tied weights fix and meta fix for autogptq
* update readme
* fmt + lint
* upgrade granite benches

Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>