Fix deepspeed prefix-lm #107
Changes from all commits: a4e131c, 5e8299c, b0c5f10, f396c69, efcd497, 43f56be, 62fe447, e65852f, 34df6f6, 2ae4ba6, 9182bee
```diff
@@ -290,7 +290,6 @@ def forward(self, inputs, **kwargs):
         if hasattr(self._args, 'attn_mask'):
             return embeddings
         else:
-            assert False
             return embeddings, attention_mask
```

**Member (Author):** We remove this in order to allow this case.
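A minimal sketch of the control flow after this change, using a hypothetical stand-in class (`EmbeddingStageSketch` and its plain-string arguments are illustrative, not the real Megatron module): when `attn_mask` was precomputed and stored on `args`, only the embeddings are returned; otherwise the mask is now passed along with the activations instead of tripping the removed `assert False`.

```python
from types import SimpleNamespace

class EmbeddingStageSketch:
    """Hypothetical stand-in for the pipeline embedding stage's forward logic."""

    def __init__(self, args):
        self._args = args

    def forward(self, embeddings, attention_mask):
        if hasattr(self._args, 'attn_mask'):
            # Mask was precomputed and stored on args, so it does not need
            # to travel through the pipeline as an activation.
            return embeddings
        # No precomputed mask: pass it along with the activations.
        # (Before this PR, this branch hit `assert False`.)
        return embeddings, attention_mask

with_mask = EmbeddingStageSketch(SimpleNamespace(attn_mask="precomputed"))
without_mask = EmbeddingStageSketch(SimpleNamespace())
print(with_mask.forward("emb", "mask"))     # → emb
print(without_mask.forward("emb", "mask"))  # → ('emb', 'mask')
```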
```diff
@@ -48,14 +48,6 @@ def model_provider(pre_process=True, post_process=True):
             enabled=args.zero_stage == 3,
             mpu=mpu):
         if args.deepspeed:
-            model = GPTModelPipe(
-                num_tokentypes=0,
-                parallel_output=True
-            )
-            # This is a hack to give us a reference to get_batch_pipe from within training.py
-            # We need to call model.set_batch_fn after deepspeed.initialize
-            model._megatron_batch_fn = get_batch_pipe
-
             # Precompute the attention mask and store it in args. This avoids having to
             # pipeline it as an activation during training. The mask is constant, and thus
             # we can reuse it.
```
```diff
@@ -73,6 +65,14 @@ def model_provider(pre_process=True, post_process=True):
             # must be bool or the training crashes expecting bool, but getting Half
             args.attn_mask = attention_mask.to(torch.bool)
             args.attn_mask_original = attention_mask.to(torch.bool)
+
+            model = GPTModelPipe(
+                num_tokentypes=0,
+                parallel_output=True
+            )
+            # This is a hack to give us a reference to get_batch_pipe from within training.py
+            # We need to call model.set_batch_fn after deepspeed.initialize
+            model._megatron_batch_fn = get_batch_pipe
         else:
             model = GPTModel(
                 num_tokentypes=0,
```

**Member (Author)** (on lines +69 to +75): We move this part of the code after setting attention mask in …
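A toy reproduction of the ordering issue this reordering addresses, under the assumption that something during model construction consults `args.attn_mask`: if the model is built before the mask is stored on `args`, it never sees it. All names here (`build_model`, the string mask) are illustrative, not the real Megatron-DeepSpeed API.

```python
from types import SimpleNamespace

def build_model(args, set_mask_first):
    """Return the mask value visible at construction time, to show why
    GPTModelPipe must be built *after* args.attn_mask is set."""
    if set_mask_first:
        args.attn_mask = "bool_mask"  # stands in for attention_mask.to(torch.bool)
    # Stand-in for whatever reads args.attn_mask while the model is built.
    mask_seen_at_build = getattr(args, "attn_mask", None)
    if not set_mask_first:
        args.attn_mask = "bool_mask"  # old ordering: mask set only afterwards
    return mask_seen_at_build

print(build_model(SimpleNamespace(), set_mask_first=False))  # → None
print(build_model(SimpleNamespace(), set_mask_first=True))   # → bool_mask
```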
**Member:** Unfortunately we drop attention here ....