Skip to content

[FEAT] [ROCm]: AITER Fused MOE V1 Support #16752

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Apr 25, 2025

Conversation

vllmellm
Copy link
Contributor

@vllmellm vllmellm commented Apr 17, 2025

Description

This PR integrates enables Aiter's fused Mixture-of-Experts ops, found here, to be used with v1.

Implementation

The following ops have been added/modified and registered as custom ops:

  1. rocm_aiter_ck_moe
  2. rocm_aiter_fmoe_fp8_blockscale_g1u1
  3. rocm_aiter_asm_moe
  4. rocm_aiter_topk_softmax
  5. rocm_aiter_shuffle_weight
  6. rocm_aiter_asm_moe_tkw1

Testing

The integration has been verified through:

  1. High-level integration tests with various models.
  2. Accuracy Test using Lmeval.

Accuracy Test GSM8K

The following command has been used to run Lmeval on the following models:

  • Llama-4-Maverick-17B-128E-Instruct
  • Llama-4-Maverick-17B-128E-Instruct-FP8
  • DeepSeek-V3
  • Mixtral-8x7B-Instruct-v0.1
  • Mixtral-8x7B-Instruct-v0.1(FP8)
VLLM_USE_TRITON_FLASH_ATTN=1 \
VLLM_WORKER_MULTIPROC_METHOD=spawn \
VLLM_ROCM_USE_AITER=0 \
VLLM_ROCM_USE_AITER_RMSNORM=0 \
VLLM_ROCM_USE_AITER_LINEAR=0 \
SAFETENSORS_FAST_GPU=1 \
lm_eval \
--model vllm \
--model_args pretrained=model_name,tensor_parallel_size=8,enforce_eager=False,max_model_len=4096 \
--trust_remote_code \
--tasks gsm8k \
--num_fewshot 5 \
--batch_size auto 

Additionally we set some addiational vars/args for some models as specified below:

Llama-4-Maverick-17B-128E-Instruct:

  • VLLM_USE_V1=1

Llama-4-Maverick-17B-128E-Instruct-FP8:

  • VLLM_USE_V1=1

DeepSeek-V3:

  • VLLM_USE_V1=0

Mixtral-8x7B-Instruct-v0.1:

  • VLLM_USE_V1=1

Mixtral-8x7B-Instruct-v0.1(FP8):

  • VLLM_USE_V1=1
  • --quantization fp8

We provide the table below to show the lm_eval results :

Model vLLM version Tasks Version Filter n-shot Metric   Value   Stderr
Llama-4-Maverick-17B-128E-Instruct-BF16 V1 gsm8k 3 flexible-extract 5 exact_match 0.9272 ± 0.0072
    strict-match 5 exact_match 0.9272 ± 0.0072
Llama-4-Maverick-17B-128E-Instruct-FP8 V1 gsm8k 3 flexible-extract 5 exact_match 0.9234 ± 0.0073
    strict-match 5 exact_match 0.9272 ± 0.0072
DeepSeek-V3 V0 gsm8k 3 flexible-extract 5 exact_match 0.9454 ± 0.063
    strict-match 5 exact_match 0.9454 ± 0.063
Mixtral-8x7B-Instruct-v0.1 V1 gsm8k 3 flexible-extract 5 exact_match 0.6452 ± 0.0132
    strict-match 5 exact_match 0.6429 ± 0.0132
Mixtral-8x7B-Instruct-v0.1 (FP8) V1 gsm8k 3 flexible-extract 5 exact_match 0.5413 ± 0.0137
    strict-match 5 exact_match 0.5398 ± 0.0137

This PR is part of a larger effort to integrate AITER kernels into vLLM for improved performance on the ROCm platform.

Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
@vllmellm vllmellm marked this pull request as ready for review April 23, 2025 11:07
@hongxiayang
Copy link
Collaborator

cc @houseroad This enables AITER kennel Cudagraph mode for llama4 models in V1 for performance.

@hongxiayang hongxiayang added rocm Related to AMD ROCm v1 labels Apr 23, 2025
def rocm_aiter_ck_moe_fake(hidden_states: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor) -> torch.Tensor:
return torch.empty((topk_ids.size(0), hidden_states.size(1)),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torch.empty_like(hidden_states)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pointing this out. We have updated the code accordingly.

Comment on lines +55 to +72
def rocm_aiter_asm_moe_tkw1_fake(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
fc1_scale: Optional[torch.Tensor] = None,
fc2_scale: Optional[torch.Tensor] = None,
fc1_smooth_scale: Optional[torch.Tensor] = None,
fc2_smooth_scale: Optional[torch.Tensor] = None,
a16: bool = False,
per_tensor_quant_scale: Optional[torch.Tensor] = None,
expert_mask: Optional[torch.Tensor] = None,
activation_str: str = "silu") -> torch.Tensor:
return torch.empty_like(hidden_states)


def rocm_aiter_ck_moe_impl(hidden_states: torch.Tensor, w1: torch.Tensor,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ideally, we should have some comments to tell the use case for each kernel, like

  • asm_moe_tkw1: for w8a8
  • ck_moe: for w16a16
    what do you think?

a1_scale: torch.Tensor,
block_shape: List[int],
smooth_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
from aiter.fused_moe_bf16_asm import moe_sorting_ck
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sry, is it possible to just return torch.empty_like(a1, dtype=torch.bf16)? any reason we need to call the moe_sorting_ck in the fake impl?


def rocm_aiter_shuffle_weight_impl(tensor: torch.Tensor) -> torch.Tensor:
from aiter.ops.shuffle import shuffle_weight
return shuffle_weight(tensor)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shuffle_weight is not a pybind kernel just a normal pytorch func, do we still need to register it as a custom op? : D

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shuffle_weight is not a pybind kernel just a normal pytorch func, do we still need to register it as a custom op? : D

good question

a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
allow_deep_gemm: bool = False) -> torch.Tensor:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is allow_deep_gemm actually used?

Copy link
Contributor Author

@vllmellm vllmellm Apr 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just added for mypy.

return tensor


if current_platform.is_rocm():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we only register these custom_ops when VLLM_USE_V1=1 for V0 compatibility and performance reasons?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The ops register under direct_register_custom_op are also compatible with V0.

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Copy link
Collaborator

@hongxiayang hongxiayang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have verified the code end to end with llama4 fp8 E128 model. Looks good.

Approving this with comments.

@hongxiayang hongxiayang added the ready ONLY add when PR is ready to merge/full CI is needed label Apr 24, 2025
@DarkLight1337 DarkLight1337 merged commit eef3647 into vllm-project:main Apr 25, 2025
65 checks passed
gshtras added a commit to ROCm/vllm that referenced this pull request Apr 25, 2025
* [BugFix] Remove default multiproc executor `collective_rpc` timeout (vllm-project#17000)

Signed-off-by: Nick Hill <nhill@redhat.com>

* [Core][V1][TPU] Enable structured decoding on TPU V1 (vllm-project#16499)

Signed-off-by: Chenyaaang <chenyangli@google.com>

* [Bugfix] validate urls object for multimodal content parts (vllm-project#16990)

Signed-off-by: Guillaume Calmettes <gcalmettes@scaleway.com>

* add Dockerfile build vllm against torch nightly (vllm-project#16936)

Signed-off-by: Yang Wang <elainewy@meta.com>

* [Kernel][ROCM] Upstream prefix prefill speed up for vLLM V1 (vllm-project#13305)

Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: root <root@banff-cyxtera-s73-5.ctr.dcgpu>
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Signed-off-by: root <root@banff-cyxtera-s65-4.amd.com>
Signed-off-by: maleksan85 <maleksan@amd.com>
Signed-off-by: <>
Co-authored-by: Sage Moore <sage@neuralmagic.com>
Co-authored-by: root <root@banff-cyxtera-s73-5.ctr.dcgpu>
Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: qli88 <qiang.li2@amd.com>
Co-authored-by: root <root@banff-cyxtera-s65-4.amd.com>

* [V1][DP] More robust DP/EP dummy request coordination (vllm-project#16277)

Signed-off-by: Nick Hill <nhill@redhat.com>

* [BugFix] Revert ROCm Custom Paged Attention Env Flag Check (vllm-project#17022)

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>

* Revert "[Misc] Add S3 environment variables for better support of MinIO." (vllm-project#17021)

* [misc] tune some env vars for GB200 (vllm-project#16992)

Signed-off-by: youkaichao <youkaichao@gmail.com>

* [INTEL-HPU][v0] Port delayed sampling to upstream (vllm-project#16949)

Signed-off-by: Michal Adamczyk <michal.adamczyk@intel.com>
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
Co-authored-by: Michal Adamczyk <madamczyk@habana.ai>

* [doc] add download path tips (vllm-project#17013)

Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>

* [Bugfix] Triton FA function takes no keyword arguments (vllm-project#16902)

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>

* [V1] Avoid socket errors during shutdown when requests are in in-flight (vllm-project#16807)

Signed-off-by: Nick Hill <nhill@redhat.com>

* [BugFix] llama4 fa3 fix - RuntimeError: scheduler_metadata must have shape (metadata_size) (vllm-project#16998)

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>

* [Misc] Improve readability of get_open_port function. (vllm-project#17024)

Signed-off-by: gitover22 <qidizou88@gmail.com>

* [Bugfix] Fix AssertionError: skip_special_tokens=False is not supported for Mistral tokenizers (vllm-project#16964)

Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>

* [CI] Run v1/test_serial_utils.py in CI (vllm-project#16996)

Signed-off-by: Russell Bryant <rbryant@redhat.com>

* Mistral-format support for compressed-tensors (vllm-project#16803)

Signed-off-by: mgoin <mgoin64@gmail.com>

* Categorize `tests/kernels/` based on kernel type (vllm-project#16799)

Signed-off-by: mgoin <mgoin64@gmail.com>

* [Doc] Add top anchor and a note to quantization/bitblas.md (vllm-project#17042)

Signed-off-by: windsonsea <haifeng.yao@daocloud.io>

* Ensure that `pid` passed to `kill_process_tree` is `int` for `mypy` (vllm-project#17051)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* [CI] Update structured-output label automation (vllm-project#17055)

Signed-off-by: Russell Bryant <rbryant@redhat.com>

* Improve Transformers backend model loading QoL (vllm-project#17039)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* `CacheConfig.block_size` should always be `int` when used (vllm-project#17052)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Use `@property` and private field for `data_parallel_rank_local` (vllm-project#17053)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* [Frontend] Support guidance:no-additional-properties for compatibility with xgrammar (vllm-project#15949)

Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>

* [BugFix][V1] Fix int32 token index overflow when preparing input ids (vllm-project#16806)

* [V1][Spec Decode] Always use argmax for sampling draft tokens  (vllm-project#16899)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>

* [CI/Build] workaround for CI build failure (vllm-project#17070)

Signed-off-by: csy1204 <josang1204@gmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>

* [Quantization]add prefix for commandA quantized model (vllm-project#17017)

* [Minor] Use larger batch sizes for A100/B100/B200/MI300x (vllm-project#17073)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>

* [Bugfix] Enable V1 usage stats (vllm-project#16986)

Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>

* More informative error when using Transformers backend (vllm-project#16988)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Addendum Fix to support FIPS enabled machines with MD5 hashing (vllm-project#17043)

Signed-off-by: sydarb <areebsyed237@gmail.com>

* [Bugfix][Core] add seq_id_to_seq_group clearing to avoid memory leak when s… (vllm-project#16472)

Signed-off-by: 开哲 <kaizhe.zy@alibaba-inc.com>
Co-authored-by: 开哲 <kaizhe.zy@alibaba-inc.com>

* [V1] Update structured output (vllm-project#16812)

Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>

* [doc] update to hyperlink (vllm-project#17096)

Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>

* Add docs for runai_streamer_sharded (vllm-project#17093)

Signed-off-by: Omer Dayan (SW-GPU) <omer@run.ai>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>

* [Chore] Remove Sampler from Model Code (vllm-project#17084)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>

* Disable enforce_eager for V1 TPU sampler and structured output tests (vllm-project#17016)

Signed-off-by: mgoin <mgoin64@gmail.com>

* Simplify `TokenizerGroup` (vllm-project#16790)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Fix OOT registration test (vllm-project#17099)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* [V1][PP] Optimization: continue scheduling prefill chunks (vllm-project#17080)

Signed-off-by: Rui Qiao <ruisearch42@gmail.com>

* [Misc] Remove OLMo2 config copy (vllm-project#17066)

Signed-off-by: Isotr0py <2037008807@qq.com>

* Improve static type checking in `LoRAModelRunnerMixin` (vllm-project#17104)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* [V1][Structured Output] Clear xgrammar compiler object when engine core shut down to avoid nanobind leaked warning (vllm-project#16954)

Signed-off-by: shen-shanshan <467638484@qq.com>

* [Frontend] Using matryoshka_dimensions control the allowed output dimensions. (vllm-project#16970)

* Add missing rocm_skinny_gemms kernel test to CI (vllm-project#17060)

Signed-off-by: mgoin <mgoin64@gmail.com>

* [Misc] refactor example series - structured outputs (vllm-project#17040)

Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>

* [V1][Spec Decoding] Add num_drafts and num_accepted_tokens_per_position metrics (vllm-project#16665)

Signed-off-by: Mark McLoughlin <markmc@redhat.com>

* [CI] Add automation for the `tool-calling` github label (vllm-project#17118)

Signed-off-by: Russell Bryant <rbryant@redhat.com>

* Updating builkite job for IBM Power  (vllm-project#17111)

Signed-off-by: Aaruni Aggarwal <aaruniagg@gmail.com>

* existing torch installation pip command fix for docs (vllm-project#17059)

* Molmo Requirements (vllm-project#17026)

Signed-off-by: Eyshika Agarwal <eyshikaengineer@gmail.com>
Signed-off-by: eyshika <eyshikaengineer@gmail.com>

* Add `:markdownhelp:` to `EngineArgs` docs so markdown docstrings render properly (vllm-project#17124)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Improve configs - `LoRAConfig` + `PromptAdapterConfig` (vllm-project#16980)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* [Docs] Generate correct github links for decorated functions (vllm-project#17125)

Signed-off-by: Russell Bryant <rbryant@redhat.com>

* Add collective_rpc to llm engine (vllm-project#16999)

Signed-off-by: Yinghai Lu <yinghai@thinkingmachines.ai>

* Add chat template for Llama 4 models (vllm-project#16428)

Signed-off-by: Max de Bayser <mbayser@br.ibm.com>

* [Misc] Add example to run DeepSeek with Ray Serve LLM (vllm-project#17134)

Signed-off-by: Rui Qiao <ruisearch42@gmail.com>

* Better error message for missing mistral params.json (vllm-project#17132)

Signed-off-by: mgoin <mgoin64@gmail.com>

* Use custom address for listening socket (vllm-project#15988)

Signed-off-by: Jens Glaser <glaserj@ornl.gov>

* [FEAT] [ROCm]: AITER Fused MOE V1 Support (vllm-project#16752)

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>

* [Attention] FA3 decode perf improvement - single mma warp group support for head dim 128 (vllm-project#16864)

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>

* fix float16 support for kimi-vl (vllm-project#17156)

Co-authored-by: zhouzaida <zhouzaida@msh.team>

* [Doc] V1 : Update LoRA status (vllm-project#17133)

Signed-off-by: varun sundar rabindranath <vsundarr@redhat.com>
Co-authored-by: varun sundar rabindranath <vsundarr@redhat.com>

* [Docs] Fix True->true in supported_models.md (vllm-project#17141)

* Move missed `SchedulerConfig` args into scheduler config group in `EngineArgs` (vllm-project#17131)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* [Misc] Clean up redundant code in uniproc_executor.py (vllm-project#16762)

Signed-off-by: Lifu Huang <lifu.hlf@gmail.com>

* [Bugfix][Misc] Use TritonPlaceholderModule to defensively import triton (vllm-project#15099)

Signed-off-by: Mengqing Cao <cmq0113@163.com>

* [Misc] Benchmark Serving Script Support Appending Results (vllm-project#17028)

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>

* [Perf]Optimize rotary_emb implementation to use Triton operator for improved inference performance (vllm-project#16457)

Signed-off-by: cynthieye <yexin93@qq.com>
Co-authored-by: MagnetoWang <magnetowang@outlook.com>

* [Bugfix] remove fallback in guided_json (int range, patterns) (vllm-project#16725)

Signed-off-by: csy1204 <josang1204@gmail.com>
Co-authored-by: 조상연[플레이스 AI] <sang-yeon.cho@navercorp.com>

* [Quantization][FP8] Add support for FP8 models with input_scale for output projection and QK quantization (vllm-project#15734)

Signed-off-by: Randall Smith <Randall.Smith@amd.com>
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>

* [Doc] Add headings to improve gptqmodel.md (vllm-project#17164)

Signed-off-by: windsonsea <haifeng.yao@daocloud.io>

* Only turn on FastIncrementalDetokenizer when tokenizers >= 0.21.1 (vllm-project#17158)

* [Doc] Add two links to disagg_prefill.md (vllm-project#17168)

Signed-off-by: windsonsea <haifeng.yao@daocloud.io>

* [Doc] Move todo out of beam search docstring (vllm-project#17183)

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

* [Bugfix] Fix mistral model tests (vllm-project#17181)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

* [Bugfix] Fix Mistral ChatCompletionRequest Body Exception (vllm-project#16769)

Signed-off-by: Jasmond Loh <Jasmond.Loh@hotmail.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>

* Fix API typo and remove FP8 on V1 restriction

---------

Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: Chenyaaang <chenyangli@google.com>
Signed-off-by: Guillaume Calmettes <gcalmettes@scaleway.com>
Signed-off-by: Yang Wang <elainewy@meta.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: root <root@banff-cyxtera-s73-5.ctr.dcgpu>
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Signed-off-by: root <root@banff-cyxtera-s65-4.amd.com>
Signed-off-by: maleksan85 <maleksan@amd.com>
Signed-off-by: <>
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: Michal Adamczyk <michal.adamczyk@intel.com>
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
Signed-off-by: reidliu41 <reid201711@gmail.com>
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: gitover22 <qidizou88@gmail.com>
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
Signed-off-by: Russell Bryant <rbryant@redhat.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: windsonsea <haifeng.yao@daocloud.io>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: csy1204 <josang1204@gmail.com>
Signed-off-by: sydarb <areebsyed237@gmail.com>
Signed-off-by: 开哲 <kaizhe.zy@alibaba-inc.com>
Signed-off-by: Omer Dayan (SW-GPU) <omer@run.ai>
Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
Signed-off-by: Aaruni Aggarwal <aaruniagg@gmail.com>
Signed-off-by: Eyshika Agarwal <eyshikaengineer@gmail.com>
Signed-off-by: eyshika <eyshikaengineer@gmail.com>
Signed-off-by: Yinghai Lu <yinghai@thinkingmachines.ai>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Jens Glaser <glaserj@ornl.gov>
Signed-off-by: varun sundar rabindranath <vsundarr@redhat.com>
Signed-off-by: Lifu Huang <lifu.hlf@gmail.com>
Signed-off-by: Mengqing Cao <cmq0113@163.com>
Signed-off-by: cynthieye <yexin93@qq.com>
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Jasmond Loh <Jasmond.Loh@hotmail.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Chenyaaang <42742451+Chenyaaang@users.noreply.github.com>
Co-authored-by: Guillaume Calmettes <gcalmettes@scaleway.com>
Co-authored-by: Yang Wang <elainewy@meta.com>
Co-authored-by: Aleksandr Malyshev <164964928+maleksan85@users.noreply.github.com>
Co-authored-by: Sage Moore <sage@neuralmagic.com>
Co-authored-by: root <root@banff-cyxtera-s73-5.ctr.dcgpu>
Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: qli88 <qiang.li2@amd.com>
Co-authored-by: root <root@banff-cyxtera-s65-4.amd.com>
Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: Chendi.Xue <chendi.xue@intel.com>
Co-authored-by: Michal Adamczyk <madamczyk@habana.ai>
Co-authored-by: Reid <61492567+reidliu41@users.noreply.github.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: huafeng <qidizou88@gmail.com>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Michael Yao <haifeng.yao@daocloud.io>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Travis Johnson <tsjohnso@us.ibm.com>
Co-authored-by: Yong Hoon Shin <48474650+sarckk@users.noreply.github.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Sangyeon Cho <josang1204@gmail.com>
Co-authored-by: Chen Xia <cxia0209@gmail.com>
Co-authored-by: Areeb Syed <areebsyed237@gmail.com>
Co-authored-by: 张宇 <zhangyuygss@outlook.com>
Co-authored-by: 开哲 <kaizhe.zy@alibaba-inc.com>
Co-authored-by: omer-dayan <omdayan@nvidia.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Rui Qiao <161574667+ruisearch42@users.noreply.github.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Shanshan Shen <467638484@qq.com>
Co-authored-by: wang.yuqi <noooop@126.com>
Co-authored-by: Mark McLoughlin <markmc@redhat.com>
Co-authored-by: Aaruni Aggarwal <47731267+AaruniAggarwal@users.noreply.github.com>
Co-authored-by: Atilla <48064466+atilla00@users.noreply.github.com>
Co-authored-by: Eyshika Agarwal <eyshikaengineer@gmail.com>
Co-authored-by: Yinghai Lu <yinghai@thinkingmachines.ai>
Co-authored-by: Maximilien de Bayser <mbayser@br.ibm.com>
Co-authored-by: jglaser <glaserj@ornl.gov>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: zhouzaida <zhouzaida@msh.team>
Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com>
Co-authored-by: varun sundar rabindranath <vsundarr@redhat.com>
Co-authored-by: Lifu Huang <lifu.hlf@gmail.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
Co-authored-by: yexin(叶鑫) <yexin93@qq.com>
Co-authored-by: MagnetoWang <magnetowang@outlook.com>
Co-authored-by: 조상연[플레이스 AI] <sang-yeon.cho@navercorp.com>
Co-authored-by: rasmith <Randall.Smith@amd.com>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com>
Co-authored-by: Alex Brooks <alex.brooks@ibm.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Co-authored-by: Jasmond L <120363110+JasmondL@users.noreply.github.com>
jikunshang pushed a commit to jikunshang/vllm that referenced this pull request Apr 29, 2025
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Apr 29, 2025
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
adobrzyn pushed a commit to HabanaAI/vllm-fork that referenced this pull request Apr 30, 2025
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: Agata Dobrzyniewicz <adobrzyniewicz@habana.ai>
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>
minpeter pushed a commit to minpeter/vllm that referenced this pull request Jun 24, 2025
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: minpeter <kali2005611@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed rocm Related to AMD ROCm v1
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants