Conversation

@Niccolo-Ajroldi commented on Nov 3, 2025

Currently, all AllGather calls in the data-parallel Muon implementation are synchronous: after orthogonalizing a gradient and updating its corresponding parameter, each GPU must wait for every other GPU to finish processing its own parameter before moving on to the next one. We can make this faster by overlapping computation with communication and synchronizing only once, at the end of the optimization step.

The modification is very simple. Replace this:

for base_i in ...:
    # blocking collective: every rank waits here on every iteration
    dist.all_gather(...)

with:

handles = []
for base_i in ...:
    # async_op=True returns a Work handle instead of blocking
    handle = dist.all_gather(..., async_op=True)
    handles.append(handle)

# single synchronization point at the end of the optimization step
for handle in handles:
    handle.wait()
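
For reference, here is a minimal, self-contained sketch of the pattern in isolation (not the PR's actual code): each all_gather is launched with async_op=True, the returned Work handles are collected, and the step synchronizes only once at the end. The helper name overlapped_all_gather and the assumption that all ranks pass same-shaped tensors are illustrative, not taken from the repository.

    import torch
    import torch.distributed as dist

    def overlapped_all_gather(local_chunks):
        """Gather a list of same-shaped tensors from all ranks, overlapping the collectives."""
        world_size = dist.get_world_size()
        handles, outputs = [], []
        for chunk in local_chunks:
            bufs = [torch.empty_like(chunk) for _ in range(world_size)]
            # async_op=True returns a Work handle instead of blocking here,
            # so the next iteration's work is issued while NCCL moves data.
            handles.append(dist.all_gather(bufs, chunk, async_op=True))
            outputs.append(bufs)
        # Single synchronization point once all collectives have been launched.
        for handle in handles:
            handle.wait()
        return outputs

In the Muon optimizer step, the loop body would additionally orthogonalize the gradient and update the parameter owned by the current rank, as in the snippet above; only the communication pattern changes.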

Speed-up

I tested this on a 1B transformer model trained on 8xA100-80GB with DDP and observed a 20% speed-up in the optimization step when using the asynchronous version.

The speed-up will be even larger on models where the number of layers is not a multiple of the number of GPUs, since the per-rank work is then imbalanced and each blocking collective would otherwise force every GPU to wait on the busiest one.

@Niccolo-Ajroldi changed the title from "Make AllGather asynchronous in" to "Asynchronous AllGather" on Nov 3, 2025
