Skip to content

[TPU] Re-enable the Pallas MoE kernel #18025

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions requirements/tpu.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ setuptools==78.1.0
--find-links https://storage.googleapis.com/libtpu-releases/index.html
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
torch==2.8.0.dev20250430
torchvision==0.22.0.dev20250430
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
torch==2.8.0.dev20250518
torchvision==0.22.0.dev20250518
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250518-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250518-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250518-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"

3 changes: 1 addition & 2 deletions vllm/model_executor/layers/fused_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@
else:
from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
if current_platform.is_tpu():
# the iterative moe implementation is used until the moe_pallas is fixed
from .moe_torch_iterative import fused_moe as fused_moe_pallas
from .moe_pallas import fused_moe as fused_moe_pallas
else:
fused_moe_pallas = None # type: ignore
logger = init_logger(__name__)
Expand Down
20 changes: 18 additions & 2 deletions vllm/model_executor/layers/fused_moe/moe_pallas.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,23 @@

import torch
import torch.nn.functional as F
from torch_xla.experimental.custom_kernel import _histogram


def _histogram(input: torch.Tensor, min: int, max: int) -> torch.Tensor:
"""
Compute the histogram of a int32 tensor. The bin edges are defined by the
min and max values, with step = 1.
"""
assert input.dtype == torch.int32, "input must be of torch.int32 dtype."
assert min <= max, "min must be less than or equal to max."

def searchsorted(sorted_sequence: torch.Tensor,
values_to_search: torch.Tensor) -> torch.Tensor:
return (sorted_sequence.unsqueeze(1) == values_to_search).sum(dim=1)

bin_edges = torch.linspace(min, max, max - min + 1,
dtype=input.dtype).to(input.device)
return searchsorted(bin_edges, input).to(torch.int32)


def fused_moe(
Expand Down Expand Up @@ -61,7 +77,7 @@ def fused_moe(
x = torch.ops.xla.gmm(x, w2, group_sizes)
x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size)

x = x * topk_weights.unsqueeze_(dim=-1)
x = x * topk_weights.unsqueeze(dim=-1)
x = x.sum(dim=-2)
x = x.reshape(orig_shape)
return x