Distributed Training Problems for QLoRA models with Transformers pre-release 4.45 #83
fabianlim changed the title from "QLoRA experiments hanging with Transformers==4.45.0.dev0" to "Distributed Training Problems for QLoRA models with Transformers pre-release 4.45" on Sep 11, 2024
@achew010 @wynterl i made some progress with this. If we comment out and replace with:

```python
import fms_acceleration_peft
from fms_acceleration_peft.framework_plugin_bnb import _prepare_model_for_kbit_training
from peft import get_peft_model

# prepare the quantized model and attach the LoRA adapters
model = _prepare_model_for_kbit_training(
    model,
    use_gradient_checkpointing=True,
    gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs,
)
model = get_peft_model(model, peft_config)
```

the problem does not occur, which suggests one of those lines that were commented out is causing the issue.
Update: problem 1) occurs because, with the new fix, this no longer holds: https://github.com/huggingface/trl/blob/c3143832cb305139b2551af2e00f008b4d64a981/trl/trainer/sft_trainer.py#L231
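To illustrate why such a check can break (this is a hypothetical guard, not the actual TRL code at that line): any logic that inspects whether parameters have real storage before preparation now sees only `meta` tensors on ranks > 0, so the branches taken can diverge across ranks.

```python
import torch


def weights_are_materialized(model: torch.nn.Module) -> bool:
    """Hypothetical guard illustrating the broken assumption.

    With the new transformers behaviour, parameters on ranks > 0 live on the
    `meta` device, so this returns False there while returning True on rank 0;
    preparation logic gated on it then differs between ranks, which is a
    typical cause of distributed hangs.
    """
    return all(not p.is_meta for p in model.parameters())
```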
fabianlim added a commit that referenced this issue on Oct 14, 2024 (Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>)
fabianlim added a commit that referenced this issue on Oct 17, 2024:
* address issue 2 in #83
* properly handle broadcast of adapters
* handle param_init_fn_tied_param
* trl version error
* tied weights fix and meta fix for autogptq
* update readme
* fmt + lint
* upgrade granite benches
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
Root Cause
The root cause is a recent transformers update to resolve high CPU usage for large quantized models. The fix keeps the weights on the `meta` device on ranks > 0, and then shards the weights during the preparation step; this involves `torch.distributed` memory calls during `Accelerator.prepare_model`, which we observe getting stuck for QLoRA `Mistral 7B`.
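For context, below is a minimal sketch of the low-CPU-RAM loading pattern described above, assuming a standard `torchrun` launch and a generic FSDP wrap (the model name and wrapping details are illustrative, not the exact transformers/accelerate internals): rank 0 materializes real weights while the other ranks build the model on `meta`, and FSDP then broadcasts and shards from rank 0 during preparation.

```python
import torch
import torch.distributed as dist
from accelerate import init_empty_weights
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import AutoConfig, AutoModelForCausalLM

# assume this script is launched with torchrun
dist.init_process_group("nccl")
rank = dist.get_rank()
model_name = "mistralai/Mistral-7B-v0.1"  # illustrative

if rank == 0:
    # only rank 0 pays the CPU RAM cost of loading the real weights
    model = AutoModelForCausalLM.from_pretrained(model_name)
else:
    # other ranks hold parameters on the meta device: no storage, no data
    config = AutoConfig.from_pretrained(model_name)
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(config)

# FSDP broadcasts rank 0's weights and shards them during preparation;
# param_init_fn gives the meta tensors real (empty) storage before the sync
model = FSDP(
    model,
    sync_module_states=True,
    param_init_fn=lambda m: m.to_empty(device=torch.device("cuda"), recurse=False),
    device_id=torch.cuda.current_device(),
)
```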
What was observed
Running experiments to test new Granite models (e.g. `ibm/PowerLM-3b`) available on `Transformers==4.45.0.dev0`, we encountered the following issues:

1. `trainer.train()` hangs, leading to an eventual distributed timeout error for FSDP-QLoRA experiments, despite only using standard HF libraries in our baseline experiments. Weights cannot be moved to `cuda` on non-zero ranked devices as there are no actual weights on `meta`; this is due to the `efficient-cpu-ram-mode` fix that now puts all weights of non-zero ranked devices on the `meta` device (see the diagnostic snippet after this list).
2. `accelerate.prepare_model` gets stuck when `low_cpu_mem_mode` is enabled. This is reported in "Update Benchmarks and Documentation for GraniteCausalLM" #86. This problem seems to be observed only for selected models: for `GraniteCausalForLM` it is observed, but for other models like `Mistral7B` it is not.
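The following is a small per-rank diagnostic, assuming a standard `torchrun` launch and an illustrative model/quantization config (the script name and model are assumptions, not from the original report), that prints how many parameters hold real storage versus sit on the `meta` device after loading:

```python
# check_meta_params.py - run with: torchrun --nproc_per_node=2 check_meta_params.py
import torch.distributed as dist
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

dist.init_process_group("nccl")
rank = dist.get_rank()

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",  # illustrative model
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)

meta = sum(p.is_meta for p in model.parameters())
real = sum(not p.is_meta for p in model.parameters())
# with the efficient-cpu-ram-mode behaviour we expect meta > 0 on ranks > 0
print(f"rank {rank}: {real} real params, {meta} params on meta device")
```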
Reproduce
Ensure versions
Also please remove the workaround to disable `low_cpu_mem_mode` in #86.
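For convenience, a quick sketch for checking the installed versions before reproducing (the package list is assumed from the discussion above; adjust as needed):

```python
import importlib.metadata as md

# print the installed versions of the packages relevant to this issue
for pkg in ("transformers", "trl", "peft", "accelerate", "bitsandbytes", "torch"):
    try:
        print(f"{pkg}=={md.version(pkg)}")
    except md.PackageNotFoundError:
        print(f"{pkg} is not installed")
```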
Dependencies