save cpu mem by leveraging FSDP rank0 broadcasting #77
Conversation
Thanks @lchu-ibm for the PR! Please refer to the inline comment.
@lchu-ibm Thanks for this PR to address the CPU OOM issues for the 70B model. The code has some changes in the usage of "rank" and "local_rank". It would be good to test this on both single-host multi-GPU and multi-host multi-GPU setups to verify things work correctly in both cases. It would also be great if you could run the tests and attach the logs.
@chauhang Thanks for the suggestions! Please see my response to your inline comment on the rank fix. Also, I have just done a quick code cleanup by optimizing the imports, as the original code also has a bunch of unused imports.
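On the rank vs. local_rank point, here is a minimal sketch of how the two values are usually obtained and what each is for (assumptions: a torchrun launcher and a NCCL backend; this is not the PR's exact code):

```python
import os
import torch
import torch.distributed as dist

# torchrun exports RANK (global, unique across all hosts) and LOCAL_RANK
# (the process's position within its own host) for every worker.
dist.init_process_group("nccl")
rank = int(os.environ["RANK"])
local_rank = int(os.environ["LOCAL_RANK"])

# GPU selection must use local_rank, since each host numbers its GPUs from 0 ...
torch.cuda.set_device(local_rank)

# ... while "load the full model only once" must key off the global rank, so a
# multi-host job still keeps a single full CPU copy of the weights.
should_load_full_model = (rank == 0)
```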
    raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
                    "please install latest nightly.")
if rank == 0:
    model = LlamaForCausalLM.from_pretrained(
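A hedged sketch of how this branch typically continues (the model name and config handling below are illustrative, not copied from the PR): rank 0 materializes the real weights on CPU, while every other rank builds the same architecture on the meta device so that no parameter storage is allocated.

```python
import torch
import torch.distributed as dist
from transformers import LlamaConfig, LlamaForCausalLM

model_name = "meta-llama/Llama-2-70b-hf"  # assumption: any HF LLaMA checkpoint id works here
rank = dist.get_rank()                    # assumes the process group is already initialized

if rank == 0:
    # Only one process per job loads the full checkpoint into CPU RAM.
    model = LlamaForCausalLM.from_pretrained(model_name)
else:
    # All other ranks create the model on the meta device: correct shapes and
    # dtypes, but no storage, so CPU memory stays flat.
    llama_config = LlamaConfig.from_pretrained(model_name)
    with torch.device("meta"):
        model = LlamaForCausalLM(llama_config)
```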
Can we figure out why torch.device("meta") init doesn't work here?
@rohan-varma For non-0 ranks, we are using torch.device("meta") init.
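For context, a small standalone illustration (not from the PR) of what meta-device init gives you, and why rank 0 cannot rely on it: meta tensors carry shapes and dtypes but no data, so the one rank that must supply real pretrained weights has to load them normally.

```python
import torch
from torch import nn

# Modules created under the meta device allocate no parameter storage at all.
with torch.device("meta"):
    layer = nn.Linear(4096, 4096)

print(layer.weight.device)   # meta
print(layer.weight.is_meta)  # True
# There is nothing here to read or broadcast; these parameters must later be
# materialized (e.g. via to_empty()) and then overwritten with rank 0's real weights.
```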
Hello everyone, FYI: PR huggingface/transformers#25107 and huggingface/accelerate#1777 will make model loading with transformers memory-efficient, avoiding CPU OOMs when using FSDP, without any code changes on the user side. Currently testing that out on 70B in single-node and multi-node setups.
@pacman100 Thanks for the update, that would be very helpful; will give it a try. Can you please elaborate a bit on the usage as well?
Hello, just using the latest transformers and accelerate with the above PRs is enough; no change to the training script is required.
So, if you don't want to use the accelerate launcher, you can simply run your usual launch command. How it works: the checkpoint is fully loaded on rank 0 only, all other ranks instantiate the model with empty (meta) weights, and FSDP then broadcasts the rank-0 parameters to the other ranks via `sync_module_states`. Here is the output on a 7B model:
What does this PR do?
In FSDP mode, this saves CPU memory by loading only one CPU copy of the model. This is especially useful for Llama 70B, where the current code would consume 2+ TB of CPU memory (70 * 4 * 8), causing CPU OOM.
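A quick back-of-the-envelope check of that figure (assuming fp32 checkpoint weights and one process per GPU on an 8-GPU host):

```python
params = 70e9        # LLaMA 70B parameter count
bytes_per_param = 4  # fp32 weights held in CPU RAM
ranks_per_node = 8   # one training process per GPU

one_copy_gb = params * bytes_per_param / 1e9  # ~280 GB for a single copy
naive_gb = one_copy_gb * ranks_per_node       # ~2240 GB if every rank loads it
print(f"one copy: {one_copy_gb:.0f} GB, naive per node: {naive_gb:.0f} GB")
# With rank-0-only loading, peak CPU usage stays near one copy instead of 2+ TB.
```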
Notes
This relies on `sync_module_states` + `param_init_fn`, which had issues in the past until the nightlies of the most recent months (hence the nightly requirement above). I wasn't sure what's the best in-general `param_init_fn` we should use in the current version, given the fast-evolving PRs around meta device init; maybe @awgu can comment.
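For readers following along, a minimal sketch of the `sync_module_states` + `param_init_fn` combination described above (the wrap policy and device handling are assumptions for illustration, not the PR's verbatim code):

```python
import functools
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

rank = dist.get_rank()
wrap_policy = functools.partial(
    transformer_auto_wrap_policy, transformer_layer_cls={LlamaDecoderLayer}
)

model = FSDP(
    model,                                 # rank 0: real weights; other ranks: meta-device module
    auto_wrap_policy=wrap_policy,
    device_id=torch.cuda.current_device(),
    sync_module_states=True,               # broadcast rank 0's weights to every rank
    param_init_fn=(
        None
        if rank == 0
        # Non-zero ranks first give their meta parameters empty GPU storage,
        # which the broadcast above then overwrites with the real values.
        else (lambda module: module.to_empty(device=torch.device("cuda"), recurse=False))
    ),
)
```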
cc @HamidShojanazeri