[core][distributed] add ep group and all2all interface #18077


Merged: 11 commits merged into vllm-project:main on May 14, 2025

Conversation

@youkaichao (Member) commented on May 13, 2025:

The all2all interface should be refactored in a follow-up to #15956.

This PR mainly adds the EP (expert parallel) group and an all2all base class, so that backend-dependent code such as pplx's pplx_init does not leak into the common code path (vllm/distributed/parallel_state.py).


class All2AllBase:

    def __init__(self, cpu_group, model):
@youkaichao (Member Author):

For how to use cpu_group to initialize pplx, see ppl-ai/pplx-kernels#18 as an example.

Here we have access to the model instance, so we can also get the MoE configs here, and we can assume all MoE layers share the same config, without needing _all_to_all_cache.
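A minimal sketch of what that could look like, assuming the constructor simply scans the model's modules for MoE layers; the FusedMoE class name and its attributes are assumptions for illustration, not necessarily the PR's actual code:

import torch

class All2AllBase:

    def __init__(self, cpu_group, model: torch.nn.Module):
        self.cpu_group = cpu_group
        # Hypothetical: collect the MoE layers once and assume they all
        # share the same config, so no per-shape cache is needed.
        moe_layers = [m for m in model.modules()
                      if m.__class__.__name__ == "FusedMoE"]
        if moe_layers:
            self.num_experts = getattr(moe_layers[0], "global_num_experts", None)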

@youkaichao (Member Author):

From @nandor:

uid = nvshmem_get_unique_id() if rank == 0 else nvshmem_alloc_empty_unique_id()
torch.distributed.broadcast(uid, src=0)
nvshmem_init(uid, rank, world_size)

is not necessary for intranode code.

We can tell whether we are on a single node with:

def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
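Putting the two together, a rough sketch of how the nvshmem setup could be skipped on a single node; the import path for the nvshmem helpers and the exact in_the_same_node_as call are assumptions for illustration:

import torch.distributed as dist
from vllm.distributed.parallel_state import in_the_same_node_as
# Assumed import path for the pplx-kernels nvshmem bindings.
from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
                                  nvshmem_get_unique_id, nvshmem_init)

def maybe_init_nvshmem(cpu_group, rank: int, world_size: int) -> None:
    # Intra-node all2all can go through shared memory / NVLink, so the
    # unique-id broadcast and nvshmem_init are only needed across nodes.
    if all(in_the_same_node_as(cpu_group, source_rank=0)):
        return
    uid = (nvshmem_get_unique_id() if rank == 0
           else nvshmem_alloc_empty_unique_id())
    dist.broadcast(uid, src=0, group=cpu_group)
    nvshmem_init(uid, rank, world_size)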

youkaichao marked this pull request as ready for review on May 13, 2025, 13:29.
@varun-sundar-rabindranath (Contributor) commented:

Thanks @youkaichao - the abstractions generally look good. But the dispatch-combine calls for pplx input a very different set of information than the naive all2all implementation. I think we can cross that bridge when we add the pplx All2All implementation to the abstraction - but just adding it as a note.

@youkaichao (Member Author) replied:

> But the dispatch-combine calls for pplx input a very different set of information than the naive all2all implementation.

Agree, right now the interface is just porting code from the naive implementation. We should definitely change it when we add the pplx All2All implementation to the abstraction.

Comment on lines +30 to +35
def dispatch(self, hidden_states: torch.Tensor,
             router_logits: torch.Tensor):
    raise NotImplementedError

def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
    raise NotImplementedError
Collaborator:

Is there a function signature for dispatch and combine that will work across all kernels?

  • The naive multicast implementation
  • pplx-kernels
  • DeepEP

I think it would need to look like:

def dispatch(self,
             hidden_states: torch.Tensor,
             hs_scales: torch.Tensor,     # quantized scales for hidden_states
             router_logits: torch.Tensor,
             topk_weights: torch.Tensor,
             topk_ids: torch.Tensor,
             ) -> Tuple[torch.Tensor, torch.Tensor]:  # (out_hidden_states, out_scales)
    ...

(There are other parameters to the pplx-kernels that are needed but I think they can be encapsulated in a PPLXAll2All wrapper class)

@varun-sundar-rabindranath @bnellnm WDYT?

Collaborator:

Another consideration here is that the output format of the different A2A implementations will be different.

  • For naive All2All, the output shape would be [total_tokens_across_dp, hidden_size]
  • For pplx All2All, it will be [num_experts_per_rank, max_tokens_per_expert, hidden_size], with padding along axis 1
  • For DeepEP I am not sure what the output format will look like
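A toy illustration of the first two layouts (the sizes below are made up for the example):

import torch

hidden_size = 8
total_tokens_across_dp = 5                            # naive: one flat token dimension
num_experts_per_rank, max_tokens_per_expert = 2, 4    # pplx: padded per-expert layout

naive_out = torch.empty(total_tokens_across_dp, hidden_size)
pplx_out = torch.empty(num_experts_per_rank, max_tokens_per_expert, hidden_size)
print(naive_out.shape)  # torch.Size([5, 8])
print(pplx_out.shape)   # torch.Size([2, 4, 8])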

@youkaichao (Member Author):

For a generic dispatch/combine interface across the different backends (pplx, naive, DeepEP), I think we can have:

def dispatch(self, tensors: List[torch.Tensor]) -> List[torch.Tensor]: ...
def combine(self, tensors: List[torch.Tensor]) -> List[torch.Tensor]: ...

The backend can decide what each tensor means, and the MoE layer can use assertions to make sure the list contains the data it needs.
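A rough sketch of how an MoE layer might consume such a list-based interface; the tensor ordering convention here is invented for illustration and is not defined by the PR:

from typing import List
import torch

def moe_forward(all2all, hidden_states: torch.Tensor,
                router_logits: torch.Tensor) -> torch.Tensor:
    # Hypothetical convention: [hidden_states, router_logits] in, same order out.
    dispatched: List[torch.Tensor] = all2all.dispatch([hidden_states, router_logits])
    assert len(dispatched) == 2, "backend must return hidden_states and router_logits"
    hs, logits = dispatched
    # ... run the local experts on hs, using logits for routing ...
    combined = all2all.combine([hs])
    return combined[0]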

Collaborator:

What's the reason for having the base class in this case? It seems that the caller will have to know exactly which implementation it's using.

@youkaichao (Member Author):

It's mainly to manage the lifecycle of the communication library, and separate the implementations.

It is true that the dispatch/combine APIs are quite different, but what matters in this PR is to unify the construction, backend selection, and destruction. As a byproduct, we need a generic dispatch/combine function.
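A minimal sketch of what that unified lifecycle could look like; the environment variable and the subclass names are illustrative assumptions, not the PR's actual wiring:

import os

class All2AllBase:

    def __init__(self, cpu_group, model):
        self.cpu_group = cpu_group
        self.model = model

    def destroy(self):
        # Each backend tears down its own communication library here.
        pass

class NaiveAll2All(All2AllBase):   # hypothetical subclass name
    pass

class PPLXAll2All(All2AllBase):    # hypothetical subclass name
    pass

def create_all2all(cpu_group, model) -> All2AllBase:
    # Hypothetical selection knob; the PR may pick the backend differently.
    backend = os.environ.get("VLLM_ALL2ALL_BACKEND", "naive")
    if backend == "pplx":
        return PPLXAll2All(cpu_group, model)
    return NaiveAll2All(cpu_group, model)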

Contributor:

I agree with @tlrmchlsmth that it would be good if these interfaces matched the pplx AllToAll interfaces.

@youkaichao (Member Author):

Feel free to change the interface in the MoE modularization PR!

Comment on lines +364 to +370
try:
    # pytorch <= 2.6
    from torch.distributed.distributed_c10d import _shutdown_backend
    _shutdown_backend(pg)
except ImportError:
    # pytorch >= 2.7
    pg.shutdown()
Contributor:

I used this in the PPLX PR. I think this is a bit nicer than try/catch.

    if is_torch_equal_or_newer("2.7"):
        pg.shutdown()
    else:
        # Lazy import for non-CUDA backends.
        from torch.distributed.distributed_c10d import _shutdown_backend
        _shutdown_backend(pg)
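For context, such a version guard could be implemented roughly as follows; this is only a sketch, and vLLM's own is_torch_equal_or_newer helper may differ:

from packaging.version import Version
import torch

def is_torch_equal_or_newer(target: str) -> bool:
    # Compare the installed torch version (ignoring any local build suffix)
    # against the requested minimum version.
    return Version(torch.__version__.split("+")[0]) >= Version(target)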

@youkaichao (Member Author):

Feel free to change it in the MoE modularization PR!

youkaichao merged commit 6266c57 into vllm-project:main on May 14, 2025 (29 of 32 checks passed).
zzzyq pushed a commit to zzzyq/vllm referencing this pull request on May 24, 2025.
minpeter pushed a commit to minpeter/vllm referencing this pull request on Jun 24, 2025.