save cpu mem by leveraging FSDP rank0 broadcasting #77

Merged · 14 commits · Aug 11, 2023

Conversation

@lchu-ibm (Contributor) commented Aug 1, 2023

What does this PR do?

In FSDP mode, this saves CPU memory by loading only one CPU copy of the model (on rank 0). This is especially useful for Llama 70B: the current code would otherwise consume 2+ TB of CPU memory (70B params × 4 bytes × 8 ranks), which causes CPU OOM.
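
A minimal sketch of the idea (not the exact PR diff; the checkpoint name, process-group setup, and param_init_fn below are illustrative assumptions):

import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import LlamaConfig, LlamaForCausalLM

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))

model_name = "meta-llama/Llama-2-70b-hf"  # placeholder checkpoint

if rank == 0:
    # only rank 0 materializes the full weights, so there is a single CPU copy
    model = LlamaForCausalLM.from_pretrained(model_name)
else:
    # every other rank builds the same architecture on the meta device (no real storage)
    config = LlamaConfig.from_pretrained(model_name)
    with torch.device("meta"):
        model = LlamaForCausalLM(config)

model = FSDP(
    model,
    device_id=torch.cuda.current_device(),
    sync_module_states=True,  # broadcast rank 0's weights to every rank
    # non-zero ranks still hold meta tensors; tell FSDP how to materialize them
    param_init_fn=(None if rank == 0
                   else lambda module: module.to_empty(device=torch.device("cuda"))),
)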

Notes

  1. This requires the latest PyTorch nightlies. I vaguely remember hitting various issues with sync_module_states + param_init_fn in the past, which only went away with the nightlies from the most recent months.
  2. I wasn't sure what the best general-purpose param_init_fn is for the current version, given the fast-evolving PRs around meta-device init. Maybe @awgu can comment.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a Github issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

cc @HamidShojanazeri

@HamidShojanazeri (Contributor) left a comment:

Thanks @lchu-ibm for the PR! Please refer to the inline comment.

@HamidShojanazeri (Contributor) commented Aug 3, 2023

Thanks @lchu-ibm for the updates. I would appreciate it if we could add similar comments from the code about this feature here, and here as well.

@chauhang (Contributor) left a comment:

@lchu-ibm Thanks for this PR to address the CPU OOM issues for the 70B model. The code has some changes in the usage of "rank" and "local_rank". It would be good to test this on both single-host multi-GPU and multi-host multi-GPU setups to verify that things work correctly in both cases. It would be great if you could run the tests and attach the logs as well.

@lchu-ibm (Contributor, Author) commented Aug 6, 2023

@chauhang Thanks for the suggestions! Please see my response in your inline comment regarding the rank fix. Also, I have just done a quick code cleanup by optimizing the imports, as the original code also had a bunch of unused imports.

llama_finetuning.py (excerpt from the diff):

        raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
                        "please install latest nightly.")
    if rank == 0:
        model = LlamaForCausalLM.from_pretrained(
@rohan-varma (Contributor) commented on the diff:

Can we figure out why torch.device("meta") init doesn't work here?

@lchu-ibm (Contributor, Author) replied:

@rohan-varma for non-0 ranks, we are using torch.device("meta") init.

@pacman100 commented:

Hello everyone, FYI: PRs huggingface/transformers#25107 and huggingface/accelerate#1777 will make model loading with transformers memory-efficient, avoiding CPU OOMs when using FSDP, without any code changes on the user side. I am currently testing that out with 70B on single-node and multi-node setups.

@HamidShojanazeri (Contributor) commented Aug 10, 2023

@pacman100 Thanks for the update, that would be very helpful; we will give it a try. Can you please elaborate a bit on the usage as well?

@pacman100 replied:

Hello, just using AutoModelForCausalLM.from_pretrained() should work as long as one is using the accelerate launcher with FSDP enabled. Basically, when FSDP is enabled with Accelerate, it sets the env variable ACCELERATE_USE_FSDP to true, and I am using that in the from_pretrained method:

import os
from distutils.util import strtobool  # assumed import; transformers ships its own strtobool helper
def is_fsdp_enabled():
    return strtobool(os.environ.get("ACCELERATE_USE_FSDP", "False")) == 1

So, if you don't want to use the accelerate launcher, you can simply run export ACCELERATE_USE_FSDP=true and then have your own training loop wherein you properly use the FSDP class with sync_module_states=True.

How it works (a rough sketch follows the list):

  1. Load the model on the meta device on all ranks.
  2. Load the state dict only on rank == 0 and move the parameter values from meta to CPU on that rank.
  3. On all other ranks, do torch.empty(*param.size(), dtype=dtype) for every parameter on the meta device.
  4. As a result, rank == 0 has the model loaded with the correct state dict, while all other ranks have random/zero weights.
  5. Set sync_module_states=True so that the FSDP object takes care of broadcasting rank 0's weights to all the other ranks before training starts.
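
A rough sketch of those steps without the accelerate launcher (a hypothetical standalone script, not Accelerate's actual internals; the checkpoint name is a placeholder):

import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import AutoConfig, AutoModelForCausalLM

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))
model_name = "meta-llama/Llama-2-7b-hf"  # placeholder 7B checkpoint

if rank == 0:
    # steps 1-2: only rank 0 loads the real weights, giving a single CPU copy
    model = AutoModelForCausalLM.from_pretrained(model_name)
else:
    # steps 1 and 3: build on meta, then allocate empty (uninitialized) CPU storage
    config = AutoConfig.from_pretrained(model_name)
    with torch.device("meta"):
        model = AutoModelForCausalLM.from_config(config)
    model = model.to_empty(device="cpu")

# steps 4-5: rank 0 holds correct weights, the others hold garbage;
# sync_module_states=True makes FSDP broadcast rank 0's parameters and buffers
# to all ranks before training starts
model = FSDP(
    model,
    device_id=torch.cuda.current_device(),
    sync_module_states=True,
)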

Here is the output on a 7B model:

accelerator.process_index=0 GPU Memory before entering the loading : 0
accelerator.process_index=0 GPU Memory consumed at the end of the loading (end-begin): 0
accelerator.process_index=0 GPU Peak Memory consumed during the loading (max-begin): 0
accelerator.process_index=0 GPU Total Peak Memory consumed during the loading (max): 0
accelerator.process_index=0 CPU Memory before entering the loading : 926
accelerator.process_index=0 CPU Memory consumed at the end of the loading (end-begin): 26415
accelerator.process_index=0 CPU Peak Memory consumed during the loading (max-begin): 31818
accelerator.process_index=0 CPU Total Peak Memory consumed during the loading (max): 32744
accelerator.process_index=0 model.lm_head.weight=Parameter containing:
tensor([[-0.0179,  0.0201, -0.0273,  ..., -0.0275, -0.0396, -0.0131],
        [-0.0510, -0.0079, -0.0383,  ..., -0.0481,  0.0581,  0.0282],
        [-0.0217, -0.0216, -0.0064,  ..., -0.0508,  0.0554, -0.0013],
        ...,
        [ 0.0425,  0.0452, -0.0131,  ...,  0.0019,  0.0476,  0.0342],
        [-0.0170, -0.0085,  0.0449,  ..., -0.0074,  0.0178,  0.0043],
        [-0.0439, -0.0859, -0.0820,  ...,  0.0130,  0.0669,  0.0884]],
       requires_grad=True)
accelerator.process_index=1 GPU Memory before entering the loading : 0
accelerator.process_index=1 GPU Memory consumed at the end of the loading (end-begin): 0
accelerator.process_index=1 GPU Peak Memory consumed during the loading (max-begin): 0
accelerator.process_index=1 GPU Total Peak Memory consumed during the loading (max): 0
accelerator.process_index=1 CPU Memory before entering the loading : 933
accelerator.process_index=1 CPU Memory consumed at the end of the loading (end-begin): 10
accelerator.process_index=1 CPU Peak Memory consumed during the loading (max-begin): 573
accelerator.process_index=1 CPU Total Peak Memory consumed during the loading (max): 1506
accelerator.process_index=1 model.lm_head.weight=Parameter containing:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], requires_grad=True)
accelerator.process_index=0 GPU Memory before entering the prepare : 0
accelerator.process_index=0 GPU Memory consumed at the end of the prepare (end-begin): 13202
accelerator.process_index=0 GPU Peak Memory consumed during the prepare (max-begin): 15458
accelerator.process_index=0 GPU Total Peak Memory consumed during the prepare (max): 15458
accelerator.process_index=0 CPU Memory before entering the prepare : 27345
accelerator.process_index=0 CPU Memory consumed at the end of the prepare (end-begin): -26394
accelerator.process_index=0 CPU Peak Memory consumed during the prepare (max-begin): 0
accelerator.process_index=0 CPU Total Peak Memory consumed during the prepare (max): 27345
FullyShardedDataParallel(
  (_fsdp_wrapped_module): RWForCausalLM(
    (transformer): RWModel(
      (word_embeddings): Embedding(65024, 4544)
      (h): ModuleList(
        (0-31): 32 x FullyShardedDataParallel(
          (_fsdp_wrapped_module): DecoderLayer(
            (input_layernorm): LayerNorm((4544,), eps=1e-05, elementwise_affine=True)
            (self_attention): Attention(
              (maybe_rotary): RotaryEmbedding()
              (query_key_value): Linear(in_features=4544, out_features=4672, bias=False)
              (dense): Linear(in_features=4544, out_features=4544, bias=False)
              (attention_dropout): Dropout(p=0.0, inplace=False)
            )
            (mlp): MLP(
              (dense_h_to_4h): Linear(in_features=4544, out_features=18176, bias=False)
              (act): GELU(approximate='none')
              (dense_4h_to_h): Linear(in_features=18176, out_features=4544, bias=False)
            )
          )
        )
      )
      (ln_f): LayerNorm((4544,), eps=1e-05, elementwise_affine=True)
    )
    (lm_head): Linear(in_features=4544, out_features=65024, bias=False)
  )
)
accelerator.process_index=1 GPU Memory before entering the prepare : 0
accelerator.process_index=1 GPU Memory consumed at the end of the prepare (end-begin): 13202
accelerator.process_index=1 GPU Peak Memory consumed during the prepare (max-begin): 15458
accelerator.process_index=1 GPU Total Peak Memory consumed during the prepare (max): 15458
accelerator.process_index=1 CPU Memory before entering the prepare : 945
accelerator.process_index=1 CPU Memory consumed at the end of the prepare (end-begin): 4
accelerator.process_index=1 CPU Peak Memory consumed during the prepare (max-begin): 4
accelerator.process_index=1 CPU Total Peak Memory consumed during the prepare (max): 949
accelerator.process_index=1 model.lm_head.weight=Parameter containing:
tensor([[-0.0179,  0.0201, -0.0273,  ..., -0.0275, -0.0396, -0.0131],
        [-0.0510, -0.0079, -0.0383,  ..., -0.0481,  0.0581,  0.0282],
        [-0.0217, -0.0216, -0.0064,  ..., -0.0508,  0.0554, -0.0013],
        ...,
        [ 0.0425,  0.0452, -0.0131,  ...,  0.0019,  0.0476,  0.0342],
        [-0.0170, -0.0085,  0.0449,  ..., -0.0074,  0.0178,  0.0043],
        [-0.0439, -0.0859, -0.0820,  ...,  0.0130,  0.0669,  0.0884]],
       device='cuda:1', requires_grad=True)
accelerator.process_index=0 model.lm_head.weight=Parameter containing:
tensor([[-0.0179,  0.0201, -0.0273,  ..., -0.0275, -0.0396, -0.0131],
        [-0.0510, -0.0079, -0.0383,  ..., -0.0481,  0.0581,  0.0282],
        [-0.0217, -0.0216, -0.0064,  ..., -0.0508,  0.0554, -0.0013],
        ...,
        [ 0.0425,  0.0452, -0.0131,  ...,  0.0019,  0.0476,  0.0342],
        [-0.0170, -0.0085,  0.0449,  ..., -0.0074,  0.0178,  0.0043],
        [-0.0439, -0.0859, -0.0820,  ...,  0.0130,  0.0669,  0.0884]],
       device='cuda:0', requires_grad=True)
