[core][distributed] add ep group and all2all interface #18077
Conversation
class All2AllBase:

    def __init__(self, cpu_group, model):
For how to use cpu_group to initialize pplx, see ppl-ai/pplx-kernels#18 as an example.
Here we have access to the model instance, so we can also get the MoE configs here; we can assume all MoE layers share the same config, so no _all_to_all_cache is needed.
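For illustration, a minimal sketch of what such a constructor could do with cpu_group and the model instance (the FusedMoE class-name lookup and the stored attributes are assumptions for this example, not code from the PR):

import torch
import torch.distributed as dist

class All2AllBase:

    def __init__(self, cpu_group: dist.ProcessGroup, model: torch.nn.Module):
        self.cpu_group = cpu_group
        self.rank = dist.get_rank(group=cpu_group)
        self.world_size = dist.get_world_size(group=cpu_group)
        # All MoE layers are assumed to share one config, so it can be read
        # once from any layer instead of keeping a per-call _all_to_all_cache.
        moe_layers = [m for m in model.modules()
                      if type(m).__name__ == "FusedMoE"]  # hypothetical lookup
        self.moe_layer = moe_layers[0] if moe_layers else None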
From @nandor: the nvshmem handshake
uid = nvshmem_get_unique_id() if rank == 0 else nvshmem_alloc_empty_unique_id()
torch.distributed.broadcast(uid, src=0)
nvshmem_init(uid, rank, world_size)
is not necessary for intranode code.
We can tell whether we are on a single node with in_the_same_node_as (vllm/vllm/distributed/parallel_state.py, line 1127 at 79a1d25):
def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
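A minimal sketch of that suggestion, gating the handshake on in_the_same_node_as; the pplx_kernels import path and the helper name are assumptions:

import torch.distributed as dist
from vllm.distributed.parallel_state import in_the_same_node_as

def maybe_init_nvshmem(cpu_group: dist.ProcessGroup) -> None:
    # Every rank on one node: intranode kernels need no nvshmem handshake.
    if all(in_the_same_node_as(cpu_group, source_rank=0)):
        return
    # Import path is an assumption based on pplx-kernels.
    from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
                                      nvshmem_get_unique_id, nvshmem_init)
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    uid = (nvshmem_get_unique_id()
           if rank == 0 else nvshmem_alloc_empty_unique_id())
    dist.broadcast(uid, src=0)
    nvshmem_init(uid, rank, world_size)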
Thanks @youkaichao - the abstractions generally look good. But the dispatch/combine calls for pplx take a very different set of inputs than the naive all2all implementation. I think we can cross that bridge when we add the pplx All2All implementation to the abstraction; just adding this as a note.
Agree, right now the interface just ports code from the naive implementation. We should definitely change it when we add the pplx All2All implementation to the abstraction.
    def dispatch(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
        raise NotImplementedError

    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError
Is there a function signature for dispatch and combine that will work across all kernels?
- The naive multicast implementation
- pplx-kernels
- DeepEP
I think it would need to look like:
def dispatch(self,
             hidden_states: torch.Tensor,
             hs_scales: torch.Tensor,    # quantized scales for hidden_states
             router_logits: torch.Tensor,
             topk_weights: torch.Tensor,
             topk_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]  # out_hidden_states and out_scales
(There are other parameters needed by the pplx-kernels, but I think they can be encapsulated in a PPLXAll2All wrapper class.)
Another consideration here is that the output formats of the different A2A implementations will differ (see the shape sketch below):
- For naive All2All, the output shape would be [total_tokens_across_dp, hidden_size].
- For pplx All2All, it will be [num_experts_per_rank, max_tokens_per_expert, hidden_size], with padding along axis 1.
- For DeepEP, I am not sure what the output format will look like.
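A shape-only sketch of that difference (the concrete sizes are made up for illustration):

import torch

hidden_size = 4096
total_tokens_across_dp = 1024       # sum of tokens over all DP ranks
num_experts_per_rank = 8
max_tokens_per_expert = 128         # per-expert capacity, padded along axis 1

# Naive All2All output: one flat token batch.
naive_out = torch.empty(total_tokens_across_dp, hidden_size)

# pplx All2All output: per-expert layout with padding.
pplx_out = torch.empty(num_experts_per_rank, max_tokens_per_expert, hidden_size)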
For a generic dispatch/combine interface across the different backends (pplx, naive, deepep), I think we can have:
def dispatch(self, tensors: List[torch.Tensor]) -> List[torch.Tensor]:
def combine(self, tensors: List[torch.Tensor]) -> List[torch.Tensor]:
The backend can decide what each tensor means, and the MoE layer can use assertions to make sure the list contains the data it needs.
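As a sketch of how a MoE layer might consume that list-based interface (the [hidden_states, router_logits] convention for the naive backend is an assumption):

from typing import List
import torch

class All2AllBase:

    def dispatch(self, tensors: List[torch.Tensor]) -> List[torch.Tensor]:
        raise NotImplementedError

    def combine(self, tensors: List[torch.Tensor]) -> List[torch.Tensor]:
        raise NotImplementedError

def moe_dispatch(a2a: All2AllBase, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
    # Assumed naive-backend convention: [hidden_states, router_logits] in,
    # [hidden_states, router_logits] out.
    out = a2a.dispatch([hidden_states, router_logits])
    assert len(out) == 2, "backend must return hidden_states and router_logits"
    return out[0], out[1]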
What's the reason for having the base class in this case? It seems that the caller will have to know exactly which implementation it's using.
It's mainly to manage the lifecycle of the communication library and to separate the implementations.
It is true that the dispatch/combine APIs are quite different, but what matters in this PR is to unify the construction, backend selection, and destruction. As a byproduct, we need a generic dispatch/combine function.
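A rough sketch of that lifecycle idea, with the subclass names and the environment-variable switch as illustrative assumptions:

import os

class All2AllBase:

    def __init__(self, cpu_group, model):
        self.cpu_group = cpu_group
        self.model = model

    def destroy(self):
        # Default: nothing to tear down.
        pass

class NaiveAll2All(All2AllBase):
    pass  # plain torch.distributed collectives, no extra state to manage

class PPLXAll2All(All2AllBase):

    def __init__(self, cpu_group, model):
        super().__init__(cpu_group, model)
        # pplx/nvshmem initialization would live here, outside parallel_state.py.

    def destroy(self):
        # pplx/nvshmem finalization would live here as well.
        pass

def create_all2all(cpu_group, model) -> All2AllBase:
    backend = os.environ.get("VLLM_ALL2ALL_BACKEND", "naive")  # assumed switch
    return {"naive": NaiveAll2All, "pplx": PPLXAll2All}[backend](cpu_group, model)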
I agree with @tlrmchlsmth that it would be good if these interfaces matched the pplx AllToAll interfaces.
feel free to change the interface in the moe modularization pr!
try:
    # pytorch <= 2.6
    from torch.distributed.distributed_c10d import _shutdown_backend
    _shutdown_backend(pg)
except ImportError:
    # pytorch >= 2.7
    pg.shutdown()
I used this in the PPLX PR. I think this is a bit nicer than try/catch.
if is_torch_equal_or_newer("2.7"):
    pg.shutdown()
else:
    # Lazy import for non-CUDA backends.
    from torch.distributed.distributed_c10d import _shutdown_backend
    _shutdown_backend(pg)
feel free to change it in the moe modularization pr!
The all2all interface should be refactored in a follow-up from #15956.
This PR mainly adds the EP (expert parallel) group and an all2all base class, so that backend-dependent code like pplx's pplx_init does not leak into the common code path (vllm/distributed/parallel_state.py).