Description
We will need to hack https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/pipe/module.py#L378-L384 to support a partition_method like
type:embed:2|transformer:1
- or something like that - so that the embed layers get 2x the partitioning weight; the embedding will then get its own stage and all stages will be more balanced.
For context please see: #166 (comment)
It's actually not complicated at all. It's just a simple weighting scheme.
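A minimal sketch of what the parsing could look like, assuming the weighted syntax above. parse_weighted_types and weighted_layer_weights are hypothetical helper names (not DeepSpeed APIs), and the layer class names in the demo are made up; the real change would live in the type: branch of the module.py code linked above:

```python
import re

def parse_weighted_types(method):
    """Turn e.g. 'type:embed:2|transformer:1' into {'embed': 2, 'transformer': 1}.
    A missing weight defaults to 1, so the existing 'type:transformer' syntax
    keeps working unchanged."""
    assert method.startswith('type:')
    type_weights = {}
    for spec in method[len('type:'):].split('|'):
        fields = spec.split(':')
        type_weights[fields[0]] = int(fields[1]) if len(fields) > 1 else 1
    return type_weights

def weighted_layer_weights(layer_class_names, type_weights):
    """Replace the current binary 0/1 weighting: each layer gets the weight of
    the first type pattern matching its class name, or 0 if none matches."""
    return [next((w for pattern, w in type_weights.items()
                  if re.search(pattern, name, re.IGNORECASE)), 0)
            for name in layer_class_names]

# Hypothetical pipeline with 2 embeddings and 4 transformer layers:
names = ['InputLayer', 'EmbeddingPipe', 'Dropout',
         'TransformerLayer', 'TransformerLayer', 'TransformerLayer', 'TransformerLayer',
         'Norm', 'Dropout', 'EmbeddingPipe', 'LossLayer']
print(weighted_layer_weights(names, parse_weighted_types('type:embed:2|transformer:1')))
# -> [0, 2, 0, 1, 1, 1, 1, 0, 0, 2, 0]
```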
Let's look at how partitioning weights are assigned by the code I quoted in the first paragraph:
With 4 layers and 4 GPUs:
type:transformer
[0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0]
gets partitioned as [0, 0, 0, 1], [1], [1], [1, 0, 0, 0, 0]
type:embed|transformer
[0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0]
gets partitioned as [0, 1, 0, 1], [1], [1], [1, 0, 0, 1, 0]
(or something similar - I haven't validated),
but what we want is this:
the initial weights should be: [0, 2, 0, 1, 1, 1, 1, 0, 0, 2, 0]
which should now get partitioned as [0, 2], [0, 1, 1], [1, 1], [0, 0, 2, 0]
(note: I'm not exactly sure where the 0's belong, it should be easy to see with print debug or debugger)
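To see how those weight vectors drive the splits, here is a naive greedy contiguous partitioner - not DeepSpeed's actual balanced partitioning, just an illustration - applied to the binary and the weighted vectors above:

```python
def greedy_partition(weights, num_parts):
    """Naive contiguous partitioner, for illustration only (DeepSpeed has its
    own balanced partitioning): close the current part once it has reached its
    fair share of the remaining weight."""
    parts, i, remaining = [], 0, sum(weights)
    for p in range(num_parts):
        if p == num_parts - 1:              # last stage takes whatever is left
            parts.append(weights[i:])
            break
        target = remaining / (num_parts - p)
        part = []
        # keep at least one layer available for each of the remaining stages
        while i < len(weights) - (num_parts - p - 1) and sum(part) < target:
            part.append(weights[i])
            i += 1
        remaining -= sum(part)
        parts.append(part)
    return parts

# type:transformer (binary weights): the embedding stays glued to a transformer stage
print(greedy_partition([0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0], 4))
# -> [[0, 0, 0, 1], [1], [1], [1, 0, 0, 0, 0]]

# type:embed:2|transformer:1 (weighted): the heavy embeddings get their own room
print(greedy_partition([0, 2, 0, 1, 1, 1, 1, 0, 0, 2, 0], 4))
# -> [[0, 2], [0, 1, 1], [1, 1], [0, 0, 2, 0]]
```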
For context: the 250k vocab for mt5 has a huge embedding. It's 2x bigger than a single transformer layer (in 104B), which is why we want the partitioning such that an embedding gets its own stage and then every 2 layers use another stage.
That's how it works out in the case of 60 layers, 2 embeddings, and 32 pipe stages.
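A quick sanity check of that arithmetic, assuming embedding weight 2 and transformer weight 1 as proposed:

```python
# 60 transformer layers (weight 1 each) + 2 embeddings (weight 2 each), 32 stages
total_weight = 60 * 1 + 2 * 2     # = 64
per_stage = total_weight / 32     # = 2.0
# each embedding alone already fills a stage (weight 2), so 2 stages hold the
# embeddings and the remaining 30 stages hold 2 transformer layers each
assert 2 + 60 // 2 == 32
```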
And once we are happy with it, we can contribute this to DeepSpeed.
p.s. we need to think about the best syntax to use - probably weighted_type:embed:2|transformer:1