Skip to content

Commit 1c2f36b

Browse files
parallelism goes brrr (#37877)
* accept custom device_mesh * fix device_map * assert that num_heads % tp_size == 0 * todo. * ReplicateParallel * handle tied weights * handle dtensor in save_pretrained with safe_serialization * tp test works * doesnt work * fix shard_and_distribute_module's rank should be local_rank * tp=4 is correct * dp+tp is broken * todo allreduce with dtensors on another dim is annoying * workaround to sync dp grads when using dtensors * loading a checkpoint works * wandb and compare losses with different tp/dp * cleaning * cleaning * . * . * logs * CP2 DP2 no mask works after commenting attn_mask and is_causal from scaled_dot_product_attention * DP=2 TP=2 now works even with tied embeddings * model.parameters() and model.module.parameters() are empty.. * reformat sanity_check_tensor_sync * set atol=1e-4 for CP to pass * try populate _parameters from named_modules * refactors TP2 DP2 works CP2 DP2 works * is_causal=True and pack sequences, no attn mask, and preshuffle dataset * fix packing * CP=4 doesn't work * fix labels and position_ids for CP * DP CP works with transformers 🥳🥳🥳 * refactor * add example cp * fixup * revert sdpa changes * example cleared * add CP, DP to the mesh init * nit * clean * use `ALL_PARALLEL_STYLES` * style * FSDP works * log on 1 rank * . * fix? * FSDP1 also has .parameters() bug * reported gradnorm when using FSDP1 is wrong, but loss is correct so it's okay * . * style and fixup * move stuff around * fix tests * style * let's make it a check * warning should be an info --------- Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
1 parent b591d92 commit 1c2f36b

File tree

6 files changed

+1515
-136
lines changed

6 files changed

+1515
-136
lines changed

0 commit comments

Comments
 (0)