
NanoFSDP

A simple, minimal implementation of Fully Sharded Data Parallel (FSDP) for PyTorch in ~300 LOC.

Usage

To train a model with nanofsdp:

import nanofsdp
from torch.optim import AdamW

model = MyModel()
nanofsdp.fully_shard(model)

optimizer = AdamW(model.parameters())
...

You can also call fully_shard on intermediate modules to shard their parameters separately.

model = Transformer()

for layer in model.layers:
    nanofsdp.fully_shard(layer)

nanofsdp.fully_shard(model)

See mnist.py and transformer.py for end-to-end training scripts using nanofsdp.

How it works

NanoFSDP registers pre/post-hooks during the forward and backward passes to automatically manage sharded parameters and perform collective operations.

Sharding

Parameters (and gradients) are sharded along dim-0 and converted to DTensors. This lets us transparently use model.parameters() for the optimizer, since any parameter updates will be performed directly on the local DTensor slice.

Figure: DTensors sharded over dim-0 with two ranks.
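
As a rough illustration (not nanofsdp's actual code, and assuming a recent PyTorch where torch.distributed.tensor is public), converting a module's parameters to dim-0-sharded DTensors might look like the sketch below; shard_params is an illustrative name.

import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

def shard_params(module: nn.Module, mesh) -> None:
    # Replace each parameter with a DTensor sharded along dim-0; each rank
    # keeps only its local slice, but model.parameters() still works as usual.
    for name, param in list(module.named_parameters(recurse=False)):
        sharded = distribute_tensor(param.data, mesh, [Shard(0)])
        setattr(module, name, nn.Parameter(sharded))

# mesh = init_device_mesh("cuda", (world_size,))  # 1-D mesh over all ranks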

Hooks

Figure: PyTorch FSDP hooks. Credit: PyTorch team (pytorch/pytorch#114299).

The pre-forward and post-forward hooks are registered onto the module via register_forward_pre_hook and register_forward_hook. The pre-forward hook is responsible for unsharding the parameters via an all-gather, while the post-forward hook reshards the parameters back to DTensors.

The pre-forward hook also registers the post-backward hook by wrapping the inputs in an identity autograd function (PostBackwardHook). Meanwhile, the post-forward hook registers the pre-backward hook by calling register_hook on all output tensors.
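
The sketch below shows roughly how such hooks can be wired up with PyTorch's public hook APIs. It is not nanofsdp's actual implementation: unshard, reshard, and reduce_scatter_grads are hypothetical helper names standing in for the all-gather, reshard, and reduce-scatter steps.

import torch
import torch.nn as nn

# Hypothetical helpers standing in for nanofsdp internals:
#   unshard(m)              -- all-gather the full parameters
#   reshard(m)              -- drop the full parameters, keep the local shard
#   reduce_scatter_grads(m) -- reduce-scatter gradients back to shards

class PostBackwardHook(torch.autograd.Function):
    # Identity in forward; its backward only runs after the wrapped module's own
    # backward has finished, which is when gradients can be reduce-scattered.
    @staticmethod
    def forward(ctx, module, *inputs):
        ctx.module = module
        return inputs if len(inputs) > 1 else inputs[0]

    @staticmethod
    def backward(ctx, *grad_outputs):
        reduce_scatter_grads(ctx.module)
        reshard(ctx.module)
        return (None, *grad_outputs)

def _pre_forward(module, args):
    unshard(module)                               # all-gather before compute
    return PostBackwardHook.apply(module, *args)  # wrap inputs to hook the post-backward

def _post_forward(module, args, output):
    reshard(module)                               # free the full parameters after forward
    if torch.is_tensor(output) and output.requires_grad:
        # pre-backward: unshard again right before gradients reach this module
        output.register_hook(lambda grad: (unshard(module), grad)[1])
    return output

def install_hooks(module: nn.Module) -> None:
    module.register_forward_pre_hook(_pre_forward)
    module.register_forward_hook(_post_forward)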

Profiling

Figure: nsys profile showing all-gathers and reduce-scatters for simple.py.

simple.py is a short script that shards a toy model and performs a number of forward/backward passes.

With the profile above, we can verify that the sharding behaviour works as expected. The parameters are all-gathered at the start of the forward/backward passes, while the gradients are reduce-scattered at the end of the backward pass.

The gaps before the all-gathers come from the overhead of rearranging and concatenating each rank's shards.

Memory Usage

Memory usage was benchmarked over five forward/backward passes using the Transformer in transformer.py, with d_model=1024, n_layers=32, and context_length=256.

Figure: memory usage over five training passes without nanofsdp (single device).

Figure: memory usage over five training passes with nanofsdp on 8 devices.

With nanofsdp, the peak memory usage is reduced from ~20 GB to ~7 GB. The pyramid-like patterns in the second diagram reflect activation memory: allocations grow through the forward pass and are gradually released during the backward pass.

The same trend holds for fewer devices. We found that two devices achieve a peak memory of ~11.5 GB, while four devices further reduce the peak to ~8.5 GB.
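
For reference, peak-memory numbers like these can be read off with PyTorch's built-in allocator counters. The sketch below assumes a CUDA device and that model, batch, and optimizer are already set up as in transformer.py; compute_loss is a hypothetical stand-in for the forward pass and loss.

import torch

torch.cuda.reset_peak_memory_stats()
for _ in range(5):                      # five forward/backward passes
    loss = compute_loss(model, batch)   # hypothetical: forward pass + loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

print(f"peak memory: {torch.cuda.max_memory_allocated() / 1e9:.1f} GB")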

Acknowledgements

The key design decisions come from the PyTorch team – see the fully_shard docs and the FSDP RFC.
