Conversation

@thomasw21 (Contributor)

Fixes necessary for the implementation of prefix LM:

  • Passing tuples of tensors between pipeline stages (see the sketch below)
  • Allowing models to pass bool tensors (this seems fixed in torch.distributed for pytorch > 1.7)
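For context, a minimal sketch of the kind of stage module this enables (illustrative names, not this PR's code): each stage receives and returns a tuple of tensors, with a bool mask riding along between stages.

```python
import torch
import torch.nn as nn

# Hypothetical pipeline-stage module: forward takes and returns a
# tuple of tensors, including a bool mask forwarded to the next stage.
class PrefixLMBlock(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, inputs):
        hidden_states, prefix_mask = inputs  # prefix_mask: torch.bool
        hidden_states = self.proj(hidden_states)
        # The bool mask rides along unchanged to the next stage.
        return hidden_states, prefix_mask

block = PrefixLMBlock(8)
x = torch.randn(2, 4, 8, requires_grad=True)
mask = torch.ones(2, 4, dtype=torch.bool)  # bool tensors never require grad
out, mask_out = block((x, mask))
assert mask_out.dtype == torch.bool and not mask_out.requires_grad
```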

Shaden Smith and others added 30 commits June 6, 2021 11:27

* …activation
  Avoid partitioning small activations
* removes repeated overflow log
* pipe_replicated
* _pipe_replicated -> ds_pipe_replicated
* Adds send/recv fallback to bcast when torch version <= 1.8
* …er (deepspeedai#1263)
  Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
* Use mpu in DeepSpeedConfig() call
* Improve argument naming
* FP16 fused and unfused grad norm query
* API for obtaining global unclipped gradient norm across parameter groups
* Use global norm not group norms
  Co-authored-by: Shaden Smith <shaden.smith@microsoft.com>
* restore fp16 params if no zero ckpts available
* formatting
  Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
@stas00 (Collaborator) left a comment

Great work, Thomas!

I left a few small suggestions.

I think the only issue here is that this new code doesn't have a test, so unless we add one it's likely to break down the road.

I'm not an expert on the DS test suite; I added a few tests following monkey-see-monkey-do, but this one might not be so easy to add. Let's first see if this is generally acceptable to the DS team, and then we can work on a test.

```diff
     group=self.grid.get_slice_parallel_group())
-    inputs = tuple([part.to_meta(), part.data()])
+    inputs = ([part.to_meta(), part.data(), *inputs_grad_tail])
```
@stas00 (Collaborator) commented Oct 5, 2021

Should this not be a tuple? It could be just fine, I was just comparing with the original and the tuple() was there...

@thomasw21 (Contributor, Author) commented

This is a tuple; in Python you don't need to specify tuple, parentheses work just fine.

@hyunwoongko (Contributor) commented Oct 6, 2021

I think it will be a list.

```python
>>> list_ = [1, 2, 3]
>>> is_tuple = (list_)
>>> type(is_tuple)
<class 'list'>
```

The tuple keyword on the original line performs the type cast; plain parentheses cannot cast a list into a tuple, so you need to keep the tuple keyword:

```python
inputs = ([part.to_meta(), part.data()])  # list
inputs = tuple([part.to_meta(), part.data()])  # tuple
```

@thomasw21 (Contributor, Author) commented Oct 6, 2021

Ah, well played, I didn't notice, which now begs the question of why it works...

EDIT: It actually doesn't matter, because either way you end up looping over the elements of an iterable (see the illustration below). But I'll make the desired changes.
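A quick illustration of that point (an editorial aside, not code from the PR): iteration treats a list and a tuple identically, which is why the missing tuple() call still worked downstream, even though type checks would differ.

```python
items_list = [1, 2, 3]
items_tuple = (1, 2, 3)

# Looping over either yields the same elements, so downstream code
# that only iterates never noticed the difference.
assert [x for x in items_list] == [x for x in items_tuple]

# But type checks distinguish them, so the explicit tuple() matters
# wherever isinstance(..., tuple) is used.
assert isinstance(items_tuple, tuple) and not isinstance(items_list, tuple)
```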

@stas00 (Collaborator) commented Oct 5, 2021

@ShadenSmith, @jeffra - could you please have a look at the proposed changes - this is currently a blocker for our adding of prefix-lm to Meg-DS. Thank you!

@stas00 (Collaborator) commented Oct 5, 2021

@thomasw21, also please run:

```bash
pre-commit run --all-files
```

to auto-format your code - as you can see, the build CI failed. If you don't have it installed:

```bash
pip install pre-commit
```

@hyunwoongko (Contributor) commented Oct 6, 2021

(screenshot attached: 2021-10-06, 2:03:06 PM)

@hyunwoongko (Contributor)

@thomasw21 I'll test your branch now. If it works, this will be great work.

@hyunwoongko (Contributor) commented Oct 6, 2021

@thomasw21 I got an error from this line:

```python
assert all([torch.is_tensor(elt) and elt.requires_grad is False for elt in outputs[1:]])
```

Why must all the output tensors except the first one not require grad when a tuple is output?
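For illustration (an editorial sketch, not the PR's code), the assertion passes when the trailing outputs are grad-free tensors such as masks, and trips otherwise:

```python
import torch

hidden = torch.randn(2, 4, requires_grad=True)
mask = torch.ones(2, 4, dtype=torch.bool)         # no grad: passes
bad_tail = torch.randn(2, 4, requires_grad=True)  # grad tensor: fails

outputs = (hidden, mask)
assert all(torch.is_tensor(elt) and elt.requires_grad is False
           for elt in outputs[1:])  # OK

outputs = (hidden, bad_tail)
# The same assertion would now raise AssertionError, because the
# trailing output requires grad.
```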

@hyunwoongko (Contributor) commented Oct 6, 2021

I think this torch.is_tensor() assertion came from the big-science branch. Why did you change it like this?

407ff0f#diff-26611f6be759237464a03bb1328cbc16555888836b3504dc3703e2e25d2a3ca3

torch.is_tensor() came from this PR.

@thomasw21 (Contributor, Author) commented Oct 6, 2021

Hey @hyunwoongko! Concerning the grad issue, I had two reasons:

  • It doesn't fit our current need. Typically we wanted something similar to the GPT2ModelPipe hacks, but in my opinion this is a more generic solution.
  • I'm not entirely sure how we'd want to handle that, typically with PartitionedTensor. There are a few tricks with tensors that require grads which I'm a bit unfamiliar with, e.g. the way grads are stored in https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/pipe/engine.py#L595-L597. Ideally I guess we'd want a similar way of handling grads to reduce the memory footprint, no? (See the conceptual sketch after this list.)
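A conceptual sketch of that memory-footprint idea (an assumption on my part, not DeepSpeed's actual implementation): back the grads of all communicated tensors with a single flat buffer, so one allocation serves them all.

```python
import torch

tensors = [torch.randn(4, 4) for _ in range(3)]

# One flat buffer backs every grad; each view aliases a slice of it.
flat_grads = torch.zeros(sum(t.numel() for t in tensors))

grad_views = []
offset = 0
for t in tensors:
    n = t.numel()
    grad_views.append(flat_grads[offset:offset + n].view_as(t))
    offset += n

# Writing through a view updates the shared buffer in place.
grad_views[0].fill_(1.0)
assert flat_grads[: tensors[0].numel()].eq(1.0).all()
```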

@hyunwoongko (Contributor) commented Oct 6, 2021

#1432

I made an issue to make this more accessible to others; someone may run into similar problems later on, as we did.

@thomasw21 thomasw21 changed the base branch from big-science to master October 6, 2021 08:11
@thomasw21 thomasw21 changed the base branch from master to big-science October 6, 2021 08:26
@thomasw21 thomasw21 changed the base branch from big-science to master October 6, 2021 08:32
@thomasw21 (Contributor, Author) commented

@stas00 Thanks! I'll wait for general approval of the approach before writing tests, I guess?

Something important to note is that passing bool tensors works because we use torch 1.7+ (there was an issue with bool tensors not working in a distributed setting using NCCL); see the sketch below.
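For instance, a hedged sketch (not code from this PR) of guarding the bool-tensor path on the torch version, falling back to a uint8 cast for older NCCL builds:

```python
import torch
from packaging import version

# Assumption: mirrors the torch >= 1.7 requirement discussed above.
TORCH_HAS_NCCL_BOOL = version.parse(torch.__version__) >= version.parse("1.7")

def prep_for_comm(t: torch.Tensor) -> torch.Tensor:
    # Older NCCL paths could not ship torch.bool; send uint8 instead
    # and cast back to bool on the receiving stage.
    if t.dtype == torch.bool and not TORCH_HAS_NCCL_BOOL:
        return t.to(torch.uint8)
    return t
```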

@stas00 (Collaborator) commented Oct 6, 2021

> @stas00 Thanks! I'll wait for general approval of the approach before writing tests, I guess?
>
> Something important to note is that passing bool tensors works because we use torch 1.7+ (there was an issue with bool tensors not working in a distributed setting using NCCL).

Perhaps then it should check that pytorch >= 1.7 - I don't know if we could simply require that for the pipeline files?

@jeffra, what do you think? Do you have internal needs to support torch<1.7 for PP?

Additionally, could one of you please confirm that this line of changes is agreeable to you?

@thomasw21 (Contributor, Author) commented

I should now have fixed all the failing tests:

```bash
pytest --forked \
  unit/test_checkpointing.py::test_checkpoint_moe \
  unit/test_checkpointing.py::test_checkpoint_zero_no_optimizer \
  unit/test_checkpointing.py::test_checkpoint_unique_tag \
  unit/test_configurable_parallel.py::TestConfigurablePP::test_gpt2_mp2_pp_2to1 \
  unit/test_configurable_parallel.py::TestConfigurablePP::test_gpt2_mp1_pp_2to1 \
  unit/test_configurable_parallel.py::TestConfigurablePP::test_pp_basic \
  unit/test_configurable_parallel.py::TestConfigurablePP::test_gpt2_pp_1to2_mp_1to2 \
  unit/test_configurable_parallel.py::TestConfigurablePP::test_gpt2_mp2_pp_1to2 \
  unit/test_configurable_parallel.py::TestConfigurablePP::test_gpt2_pp_2to1_mp_2to1 \
  unit/test_configurable_parallel.py::TestConfigurableMP::test_gpt2_mp2_no_resize
```


@stas00 (Collaborator) commented Oct 7, 2021

@jeffra, @tjruwase - could you please start CI? Thank you!

@ShadenSmith (Contributor) left a comment

This is great, thanks a ton!

@jeffra jeffra merged commit 9c67278 into deepspeedai:master Oct 7, 2021
@conglongli conglongli mentioned this pull request Oct 21, 2021