add util for ram efficient loading of model when using fsdp #25107
Conversation
The documentation is not available anymore as the PR was closed or merged.
As said internally, I would prefer for this to be done automatically in `from_pretrained` when FSDP is detected, with the options that make sense. For instance, we have several tests for DeepSpeed ZeRO-3 in `from_pretrained`, one loading the state dict only on the main process.
Looks better, thanks!
LGTM, I just have the unresolved comment on why the prefetch is removed in the trainer file.
@pacman100 Hi, I've tried to run Llama 2 with the two PRs but something seems to have gone wrong. Please check, thanks!

While copying the parameter named "model.layers.29.self_attn.v_proj.weight", whose dimensions in the model are torch.Size([4096, 4096]) and whose dimensions in the checkpoint are torch.Size([4096, 4096]), an exception occurred: 'Cannot copy out of meta tensor; no data!' (raised from copy_impl at aten/src/ATen/native/Copy.cpp:188; full C++/Python backtrace omitted).
Hello @lwmlyy, I'm able to run 70B Llama on 32 A100 80GB GPUs with it without any issues. Can you share the config, a minimal example, and the launch command?
@pacman100 I ran the code in meta-llama/llama-recipes#77 with the following command: Could you also share the script for running Llama-70b that you mentioned?
Hello @lwmlyy, follow this: meta-llama/llama-recipes#77 (comment)
@pacman100 As you mentioned, if the model is loaded with accelerate, no code change is needed. I wonder why the error shows up. Could you give some advice?
Hello, you aren't launching via …
@pacman100 Hi, I hit the same error when using the following command: … It works fine with this command: … But the loss is NaN.
Hello, you aren't using the Accelerate integration of FSDP, and you are mixing in the llama-recipes implementation, which doesn't use Accelerate. Please refer to the Accelerate docs for the proper way to use FSDP with Accelerate. Also, please raise a separate issue.
Thanks for iterating!
Hi @pacman100, I am trying to train the Llama 70B model with FSDP. I was going through your repo https://github.com/pacman100/ram_efficient_fsdp/blob/main/train.py, and the code fails when trying to import the function load_pretrained_model_only_on_rank0, with the error "ImportError: cannot import name 'load_pretrained_model_only_on_rank0' from 'transformers' (/usr/local/lib/python3.10/dist-packages/transformers/__init__.py)". I tried to find this function in the Transformers repo but couldn't find it in the main branch. Can you please help me with how I can run your code? Regards
…ace#25107)

* add util for ram efficient loading of model when using fsdp
* make fix-copies
* fixes 😅
* docs
* making it further easier to use
* rename the function
* refactor to handle fsdp ram efficiency in `from_pretrained`
* fixes
* fixes
* fixes
* update
* fixes
* revert `load_pretrained_model_only_on_rank0`
* resolve `load_from_checkpoint`
I'm currently using transformers v4.37.2 and accelerate v0.26.1 and am training on one machine with 2 GPUs. I'm seeing the Mistral 7B model being loaded into CPU RAM twice (once for each process). I don't understand why, since this fix was released in earlier versions (transformers v4.32.0 and accelerate v0.22.0) and should load the model into CPU RAM only once, independent of the number of processes. Any insight anyone has is super appreciated! These are the settings in my FSDP config file:
Hello @pkaercher, to use this feature you need to use the Accelerate config for FSDP along with the Accelerate launcher. For more details, please refer to https://huggingface.co/docs/transformers/trainer#accelerate-and-trainer
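For reference, here is a minimal sketch of what that setup can look like from the training script's side; the checkpoint and config file names are illustrative assumptions, and the FSDP settings themselves live in the Accelerate config file passed to the launcher as described in the linked docs:

```python
# Sketch only: this script is assumed to be started with the Accelerate
# launcher and an FSDP config file (names are illustrative), e.g.
#   accelerate launch --config_file fsdp_config.yaml train.py
# so that the FSDP environment is set up before the model is loaded.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "mistralai/Mistral-7B-v0.1"  # example checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)

# With the Accelerate FSDP integration active, from_pretrained should only
# materialize the full weights on the local main process; the other ranks
# receive the parameters later via FSDP's broadcast.
model = AutoModelForCausalLM.from_pretrained(model_name)
```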
Hi @pacman100,
Hello @pkaercher, thank you for the above code, this helps. You are calling … I'm updating the docs here (huggingface/accelerate#2430) with this information.
Thank you @pacman100! I did as you suggested and saw the max CPU RAM usage during loading of the model drop from 65.2 GB to 47.2 GB, so it looks like it's working now.
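For anyone who wants to reproduce that kind of measurement, a rough sketch (assuming Linux and the Mistral 7B checkpoint mentioned above; it reports the per-process peak resident set size, so run it on each rank you care about):

```python
# Rough peak-CPU-RAM check for model loading (on Linux, ru_maxrss is in kilobytes).
import resource

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
peak_gb = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1e6
print(f"Peak CPU RAM (RSS) in this process so far: {peak_gb:.1f} GB")
```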
@pacman100 for quantized models, the check below (transformers/src/transformers/modeling_utils.py, lines 4217 to 4232 in dd4654e) …

```python
if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized:
    for key, param in model_to_load.state_dict().items():
        if param.device == torch.device("meta"):
            set_module_tensor_to_device(
                model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype)
            )
```
Hi @ArthurZucker @pacman100 @sgugger, could you please help answer the question above? It seems to be a problem when loading a large model like 300B with 4-bit quantization. I have a similar issue documented in #31577. In particular, this line at …
On the other hand, by using @philschmid's blog post and code, I am able to load and train two models:
Neither …
cc @SunMarc
Hi @SunMarc @amyeroberts, any chance you can take a look at my questions and share insights? Many thanks!
What does this PR do?
This PR adds a utility for RAM-efficient loading of the model when using FSDP, building on `param_init_fn` and `sync_module_states=True` (see pytorch/pytorch#105840) for training very large models. It should be merged after "support for ram efficient loading of model with FSDP" (accelerate#1777).

Currently, when using FSDP, the model is loaded completely on CPU by each of the N processes, leading to huge CPU RAM usage. When training models like Falcon-40B with FSDP on a DGX node with 8 GPUs, CPU RAM runs out because each process loads 160 GB (40B parameters x 4 bytes (FP32)) into CPU RAM, for a total requirement of 160*8 = 1280 GB, which results in the script getting killed.
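As a quick sanity check on those numbers, a back-of-the-envelope sketch (the figures are the ones quoted above, not measurements):

```python
# Why 8 full copies of a 40B-parameter FP32 model exhaust CPU RAM on one node.
num_params = 40e9        # Falcon-40B
bytes_per_param = 4      # FP32
num_processes = 8        # one process per GPU on the node

per_process_gb = num_params * bytes_per_param / 1e9   # 160 GB per process
total_gb = per_process_gb * num_processes              # 1280 GB in total
print(f"{per_process_gb:.0f} GB per process, {total_gb:.0f} GB total")
```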
To combat this, we load the model only on rank 0 and keep it on the meta device when rank != 0. We then use a no-op `param_init_fn` together with `sync_module_states=True` so that FSDP properly initializes the weights on the other ranks and broadcasts the parameters from rank 0 to them. A minimal sketch of this recipe is shown below.
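The sketch below expresses that recipe against plain PyTorch FSDP; the checkpoint name and the exact `param_init_fn` are illustrative assumptions, not the code this PR adds to `from_pretrained` (which wires the same idea into the Trainer/Accelerate flow):

```python
# Sketch: load weights only on rank 0, keep meta tensors elsewhere, and let
# FSDP broadcast the real parameters. Assumes the process group is already
# initialized (e.g. via torchrun) and torch.cuda.set_device(local_rank) was called.
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import AutoConfig, AutoModelForCausalLM

model_name = "tiiuae/falcon-40b"  # example checkpoint

if dist.get_rank() == 0:
    # Only rank 0 materializes the full weights in CPU RAM.
    model = AutoModelForCausalLM.from_pretrained(model_name)
else:
    # Other ranks build the same architecture on the meta device (no memory used).
    config = AutoConfig.from_pretrained(model_name)
    with torch.device("meta"):
        model = AutoModelForCausalLM.from_config(config)

model = FSDP(
    model,
    device_id=torch.cuda.current_device(),
    # Broadcast module states (rank 0's real weights) to all other ranks.
    sync_module_states=True,
    # No-op init: just materialize meta tensors on the GPU; real values come
    # from the broadcast above. (`recurse=False` needs a recent PyTorch.)
    param_init_fn=lambda module: module.to_empty(
        device=torch.cuda.current_device(), recurse=False
    ),
)
```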
Usage:
No user-facing changes:
Post this PR:
So you can see that during loading, rank 1 doesn't take any additional CPU RAM, and the performance between both setups matches.
To Do: