
zero3 checkpoint frozen params #3205

Merged: 26 commits merged into master on Apr 20, 2023

Conversation

@tjruwase (Contributor, PR author)

@stas00, FYI

@stas00 (Collaborator) commented Apr 13, 2023

I tried it out, and when the checkpoint is saved, almost all of the frozen weights are saved with size [0]:

python tools/convert_checkpoint/inspect_checkpoint.py /hf/m4-master-3/save_dir/opt_step-10/accelerator_state/pytorch_model/zero_pp_rank_0_mp_rank_00_model_states.pt
loading checkpoint file: /hf/m4-master-3/save_dir/opt_step-10/accelerator_state/pytorch_model/zero_pp_rank_0_mp_rank_00_model_states.pt
[tensor] module.lm_head.weight = torch.Size([0])
[tensor] module.lm_head.additional_fc.weight = torch.Size([0])
[tensor] module.model.decoder.embed_tokens.weight = torch.Size([0])
[...]
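
For reference, a minimal stand-in for that inspection (a rough sketch, not the actual inspect_checkpoint.py; the path is a placeholder):

import torch

ckpt_path = "zero_pp_rank_0_mp_rank_00_model_states.pt"  # placeholder path
state = torch.load(ckpt_path, map_location="cpu")

def walk(obj, prefix=""):
    # Recursively descend nested dicts and report each tensor's shape.
    if torch.is_tensor(obj):
        print(f"[tensor] {prefix} = {list(obj.shape)}")
    elif isinstance(obj, dict):
        for key, value in obj.items():
            walk(value, f"{prefix}.{key}" if prefix else str(key))

walk(state)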

I think they need to be gathered before saving.

But we probably shouldn't do that on every process, as it'd be quite slow if the model has 50% frozen weights. Since the frozen weights are the same everywhere, saving them once should be enough (at least on a shared filesystem; it won't work on a non-shared one).

The following will do the gathering:

diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index 8c31a9d6..8b91e242 100644
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -357,7 +357,8 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
         param_groups = []
         for param_group in self.optimizer.param_groups:
             frozen_params = [p for p in param_group["params"] if not p.requires_grad]
-            param_groups.append(frozen_params)
+            with deepspeed.zero.GatheredParameters(frozen_params, modifier_rank=None):
+                param_groups.append(frozen_params)
         return param_groups

     def _setup_for_real_optimizer(self):

But the saved tensors still appear to be of size 0, so that fix doesn't seem to be it.

Ah, I see: the original code can never succeed because frozen params aren't in optimizer.param_groups.
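
So a fix along these lines would need to take the frozen params from the module itself rather than the optimizer. A rough sketch of that idea (here `engine` is a hypothetical DeepSpeed engine wrapping the model; this is not the change that was eventually merged):

import deepspeed

def gather_frozen_state(engine):
    # Collect frozen params from the module (they are not in optimizer.param_groups).
    frozen_params = [p for p in engine.module.parameters() if not p.requires_grad]
    # modifier_rank=None gathers the ZeRO-3 partitions on every rank for read-only
    # access; nothing is broadcast back when the context exits.
    with deepspeed.zero.GatheredParameters(frozen_params, modifier_rank=None):
        return {
            name: param.detach().cpu().clone()
            for name, param in engine.module.named_parameters()
            if not param.requires_grad
        }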

@stas00 (Collaborator) commented Apr 13, 2023

I'm also wondering whether this would even work for a huge model with a lot of frozen params. There might not be enough memory to gather them all. Perhaps we should save their fp16 shards instead? That would also be much faster.
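
For illustration, the shard-saving idea might look roughly like this (a sketch only; `ds_tensor` is the local partition that ZeRO-3 keeps on each partitioned parameter, and the engine and output path are hypothetical):

import torch
import torch.distributed as dist

def save_frozen_shards(engine, out_dir):
    # Each rank saves only its own ZeRO-3 shard of the frozen params, avoiding a
    # full gather that might not fit in memory.
    rank = dist.get_rank()
    shards = {
        name: param.ds_tensor.detach().cpu().clone()
        for name, param in engine.module.named_parameters()
        if not param.requires_grad and hasattr(param, "ds_tensor")
    }
    torch.save(shards, f"{out_dir}/frozen_param_shards_rank_{rank}.pt")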

@shaankhosla

Hi @stas00 and @tjruwase, thanks for your work on this. I'm just checking to see if this would fix an error I'm getting using DeepSpeed and LoRA. Let me know if this isn't the place to ask.

I'm able to train "t5" using DeepSpeed Stage 3 and LoRA, however when I run the load_state_dict_from_zero_checkpoint command I get an error KeyError: '_forward_module.model.base_model.model.encoder.embed_tokens.weight'
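
The call in question, roughly (a hedged sketch; the LoRA/T5 model construction is replaced by a hypothetical build_model() helper and the checkpoint directory is a placeholder):

from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint

model = build_model()  # hypothetical helper returning the T5 model wrapped with LoRA
# Loads the consolidated fp32 weights from a ZeRO checkpoint directory into the model.
model = load_state_dict_from_zero_checkpoint(model, "path/to/checkpoint_dir")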

Thanks again for all your help!

@tjruwase (Contributor, PR author)

> Hi @stas00 and @tjruwase, thanks for your work on this. I'm just checking to see if this would fix an error I'm getting using DeepSpeed and LoRA. Let me know if this isn't the place to ask.
>
> I'm able to train "t5" using DeepSpeed Stage 3 and LoRA, however when I run the load_state_dict_from_zero_checkpoint command I get an error KeyError: '_forward_module.model.base_model.model.encoder.embed_tokens.weight'
>
> Thanks again for all your help!

@shaankhosla, thanks for your interest. Please open a new ticket for this problem. It would be very helpful to provide more details for reproducing the problem in that ticket.

@shaankhosla

Here it is: #3291 :)

@tjruwase tjruwase requested a review from ShijieZZZZ April 18, 2023 20:29
@tjruwase tjruwase enabled auto-merge (squash) April 20, 2023 18:48
@tjruwase tjruwase disabled auto-merge April 20, 2023 18:48
@tjruwase tjruwase enabled auto-merge (squash) April 20, 2023 18:49
@tjruwase tjruwase merged commit dd8df20 into master Apr 20, 2023
@stas00 (Collaborator) commented Apr 20, 2023

Thank you for solving this so quickly and merging it, Tunji and the team!

@conglongli conglongli added deepspeed-chat Related to DeepSpeed-Chat and removed deepspeed-chat Related to DeepSpeed-Chat labels Apr 30, 2023
@mrwyattii mrwyattii deleted the olruwase/issue_3090 branch July 7, 2023 02:37

Successfully merging this pull request may close these issues.

[BUG] save/load checkpoint in zero3 fails to preserve frozen weights