A simple, minimal implementation of Fully Sharded Data Parallel (FSDP) for PyTorch in ~300 LOC.
To train a model with nanofsdp:

```python
import nanofsdp

model = MyModel()
nanofsdp.fully_shard(model)
optimizer = AdamW(model.parameters())
...
```

You can also call fully_shard on intermediate modules to shard the parameters separately:
```python
model = Transformer()
for layer in model.layers:
    nanofsdp.fully_shard(layer)
nanofsdp.fully_shard(model)
```

See mnist.py and transformer.py for end-to-end training scripts using nanofsdp.
NanoFSDP registers pre/post-hooks during the forward and backward passes to automatically manage sharded parameters and perform collective operations.
Parameters (and gradients) are sharded along dim-0 and converted to DTensors. This lets us transparently use model.parameters() for the optimizer, since any parameter updates will be performed directly on the local DTensor slice.
Example: DTensors sharded over dim-0 with two ranks.
Credit: PyTorch team (pytorch/pytorch#114299).
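As a minimal sketch of this representation (using the public DTensor APIs directly, which assumes a recent PyTorch where torch.distributed.tensor is public; nanofsdp's fully_shard does the equivalent automatically), a single weight can be converted into a dim-0-sharded DTensor like so:

```python
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

# Assumes one GPU per rank and an initialized default process group (e.g. via torchrun).
mesh = init_device_mesh("cuda", (dist.get_world_size(),))

linear = nn.Linear(1024, 1024, device="cuda")
# Shard the weight along dim-0: each rank stores a [1024 // world_size, 1024] slice.
sharded_weight = distribute_tensor(linear.weight.detach(), mesh, placements=[Shard(0)])
linear.weight = nn.Parameter(sharded_weight)

# linear.parameters() now yields DTensors, so an optimizer such as AdamW
# updates only the local slice held on each rank.
print(type(linear.weight), linear.weight.to_local().shape)
```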
The pre-forward and post-forward hooks are registered onto the module via register_forward_pre_hook and register_forward_hook. The pre-forward hook is responsible for unsharding the parameters via an all-gather, while the post-forward hook reshards the parameters back to DTensors.
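A rough sketch of that wiring, where unshard, reshard, and install_forward_hooks are hypothetical helper names standing in for nanofsdp's actual internals:

```python
import torch.nn as nn

def unshard(module: nn.Module) -> None:
    # Hypothetical helper: all-gather each dim-0 shard into a full parameter.
    pass

def reshard(module: nn.Module) -> None:
    # Hypothetical helper: free the gathered copies and restore the DTensor shards.
    pass

def _pre_forward(module, args):
    unshard(module)    # parameters must be full (unsharded) for the forward compute
    return args

def _post_forward(module, args, output):
    reshard(module)    # release the unsharded copies as soon as the forward finishes
    return output

def install_forward_hooks(module: nn.Module) -> None:
    module.register_forward_pre_hook(_pre_forward)
    module.register_forward_hook(_post_forward)
```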
The pre-forward hook also registers the post-backward hook by wrapping the inputs in an identity autograd function (PostBackwardHook). Meanwhile, the post-forward hook registers the pre-backward hook by calling register_hook on each output tensor.
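A minimal sketch of these two backward-side hooks. The source names PostBackwardHook; the body below follows the common identity-function pattern rather than nanofsdp's exact code, and attach_pre_backward is a hypothetical helper with placeholder callbacks:

```python
import torch

class PostBackwardHook(torch.autograd.Function):
    """Identity in forward; fires a callback once gradients have flowed back
    through the wrapped inputs, i.e. after the module's backward has run."""

    @staticmethod
    def forward(ctx, post_backward, *inputs):
        ctx.post_backward = post_backward
        return inputs

    @staticmethod
    def backward(ctx, *grads):
        ctx.post_backward()          # e.g. reshard params and reduce-scatter grads
        return (None, *grads)        # no gradient for the callback argument

def attach_pre_backward(outputs, pre_backward):
    # Fires when the gradient for an output is computed, i.e. right before this
    # module's backward: time to all-gather the parameters again. A real
    # implementation would guard against firing more than once per module.
    def _hook(grad):
        pre_backward()
        return grad
    for out in outputs:
        if isinstance(out, torch.Tensor) and out.requires_grad:
            out.register_hook(_hook)

# Inside the pre-forward hook, wrapping the inputs looks roughly like:
#   args = PostBackwardHook.apply(post_backward_callback, *args)
```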
nsys profile: all-gathers and reduce-scatters for simple.py.
simple.py is a short script that shards a toy model and performs a number of forward/backward passes.
With the profile above, we can verify that the sharding behaviour works as expected. The parameters are all-gathered at the start of the forward/backward passes, while the gradients are reduce-scattered at the end of the backward pass.
The intervals before the all-gathers come from the overhead of rearranging and concatenating each rank's shards.
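For intuition, a simplified standalone unshard of a single dim-0 shard might look like the following; nanofsdp's actual path works on DTensors and may combine several parameters, so its bookkeeping is heavier:

```python
import torch
import torch.distributed as dist

def unshard_dim0(local_shard: torch.Tensor) -> torch.Tensor:
    """All-gather each rank's dim-0 shard and concatenate into the full tensor.

    Hypothetical helper: assumes equal-sized contiguous dim-0 slices and an
    initialized default process group.
    """
    world_size = dist.get_world_size()
    # Rearranging into a contiguous send buffer is part of the pre-all-gather cost.
    local_shard = local_shard.contiguous()
    full = torch.empty(
        (world_size * local_shard.shape[0], *local_shard.shape[1:]),
        dtype=local_shard.dtype,
        device=local_shard.device,
    )
    # Stacks rank 0's shard, then rank 1's, ... along dim-0 of `full`,
    # which for dim-0 sharding is exactly the concatenation we need.
    dist.all_gather_into_tensor(full, local_shard)
    return full
```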
Memory usage was benchmarked over five forward/backward passes using the Transformer in transformer.py, with d_model=1024, n_layers=32, and context_length=256.
Figure: memory usage over five training passes without nanofsdp.
Figure: memory usage over five training passes with nanofsdp on 8 devices.
With nanofsdp, the peak memory usage is reduced from ~20 GB to ~7 GB. The pyramid-like patterns in the second diagram reflect activation memory: allocations grow through the forward pass and are gradually released during the backward pass.
The same trend holds for fewer devices. We found that two devices achieve a peak memory of ~11.5 GB, while four devices further reduce the peak to ~8.5 GB.
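As a rough sketch of how peak memory and an allocation timeline like the figures above can be captured (placeholder model and data, not the actual benchmark script; the _record_memory_history and _dump_snapshot calls are private APIs in recent PyTorch):

```python
import torch
import torch.nn as nn

device = "cuda"
model = nn.Linear(1024, 1024, device=device)        # placeholder for the real Transformer
optimizer = torch.optim.AdamW(model.parameters())
batch = torch.randn(8, 1024, device=device)

torch.cuda.reset_peak_memory_stats()
torch.cuda.memory._record_memory_history(max_entries=100_000)  # private API, recent PyTorch

for _ in range(5):                                   # five forward/backward passes
    loss = model(batch).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)

print(f"peak memory: {torch.cuda.max_memory_allocated() / 1e9:.1f} GB")
torch.cuda.memory._dump_snapshot("memory_trace.pickle")  # inspect at pytorch.org/memory_viz
```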
The key design decisions come from the PyTorch team – see the fully_shard docs and the FSDP RFC.