Conversation

xuechendi (Contributor) commented Jun 4, 2025

Co-authored with @kzawora-intel

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results

Purpose

POC for #19161

Problem:

  • The existing CustomOp implementation doesn't use op_registry[class_name] for initialization.
  • For a vLLM-plugin custom op, only monkey patching works for updating an existing op implementation, including process_weights_after_loading and forward_oot.

Solution:

  • This PR proposes to initialize ops from op_registry[class_name] instead of the original class, as sketched below.
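
For context, here is a minimal, self-contained sketch of that dispatch idea (simplified and illustrative; the registry name and details are assumed, not the exact code in this PR):

# Simplified illustration of registry-based dispatch; not the exact
# vLLM implementation.
class CustomOp:
    # Maps an op class name to an (out-of-tree) override class.
    op_registry: dict = {}

    def __new__(cls, *args, **kwargs):
        # If an override is registered under this class's name,
        # instantiate the override (expected to be a subclass) instead.
        target = cls.op_registry.get(cls.__name__, cls)
        return super().__new__(target)

class RotaryEmbedding(CustomOp):
    pass

class HPURotaryEmbedding(RotaryEmbedding):
    pass

CustomOp.op_registry["RotaryEmbedding"] = HPURotaryEmbedding
assert type(RotaryEmbedding()) is HPURotaryEmbedding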

Test Plan

Verified with a proof-of-concept HPU plugin:

# setup.py of the PoC plugin package (excerpt)
entry_points={
    "vllm.platform_plugins": ["hpu = vllm_hpu:register"],
    "vllm.general_plugins": ["hpu_custom_ops = vllm_hpu:register_ops"],
},

# vllm_hpu/__init__.py
def register_ops():
    """Register custom ops for the HPU platform."""
    # Importing the module runs the registration decorator below.
    from vllm_hpu.ops.hpu_fused_moe import HPUUnquantizedFusedMoEMethod  # noqa: F401
# vllm_hpu/ops/hpu_fused_moe.py (PoC excerpt)
from typing import Callable, Optional

import torch

from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.fused_moe.layer import (
    FusedMoE, UnquantizedFusedMoEMethod)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)

# VllmMixtureOfExpertsOp is an HPU-specific fused-MoE op provided by the
# plugin; its import is elided in this excerpt.


@UnquantizedFusedMoEMethod.register_oot
class HPUUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
    """MoE method without quantization."""

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        super().process_weights_after_loading(layer)

        # Custom handling for HPU: build a fused MoE op over the range
        # of experts owned by this expert-parallel rank.
        num_experts = layer.local_num_experts
        ep_shift = layer.ep_rank * num_experts
        # Assumed local expert-id range (experts_min/experts_max were
        # undefined in the original snippet).
        experts_min = ep_shift
        experts_max = ep_shift + num_experts - 1
        layer.moe_op = VllmMixtureOfExpertsOp(
            num_experts,
            experts_min,
            experts_max,
        )

        for expert_id in range(layer.local_num_experts):
            layer.moe_op.w13_list[expert_id].set_weight(
                layer.w13_weight.data[expert_id])
            layer.moe_op.w2_list[expert_id].set_weight(
                layer.w2_weight.data[expert_id])

    def forward_oot(
        self,
        # ... full argument list elided in this excerpt ...
        **kwargs,
    ):
        input_shape = x.shape
        x = x.view(-1, x.shape[-1])
        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias)

        topk_ids = topk_ids.view(*x.shape[:-1], -1)
        topk_weights = topk_weights.view(*x.shape[:-1], -1)

        return layer.moe_op(
            x,
            topk_ids.to(torch.int64),
            topk_weights.to(x.dtype),
            permuted_weights=True,
            activation=activation,
        ).view(*input_shape)
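
With this wiring, installing the plugin is enough: the vllm.general_plugins entry point runs register_ops at startup, the import triggers the register_oot decorator, and later instantiations of UnquantizedFusedMoEMethod resolve to the HPU subclass through the registry, with no monkey patching.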

Test Result

Verified that HPUUnquantizedFusedMoEMethod from the HPU plugin is called:

INFO 06-05 01:12:17 [__init__.py:31] Available plugins for group vllm.platform_plugins:
INFO 06-05 01:12:17 [__init__.py:33] - hpu -> vllm_hpu:register
INFO 06-05 01:12:17 [__init__.py:36] All plugins in this group will be loaded. Set `VLLM_PLUGINS` to control which plugins to load.
INFO 06-05 01:12:17 [__init__.py:234] Platform plugin hpu is activated
WARNING 06-05 01:12:18 [_custom_ops.py:21] Failed to import from vllm._C with ModuleNotFoundError("No module named 'vllm._C'")
INFO 06-05 01:12:19 [__init__.py:31] Available plugins for group vllm.general_plugins:
INFO 06-05 01:12:19 [__init__.py:33] - lora_filesystem_resolver -> vllm.plugins.lora_resolvers.filesystem_resolver:register_filesystem_resolver
INFO 06-05 01:12:19 [__init__.py:33] - hpu_custom_ops -> vllm_hpu:register_ops
INFO 06-05 01:12:19 [__init__.py:36] All plugins in this group will be loaded. Set `VLLM_PLUGINS` to control which plugins to load.
INFO 06-05 01:12:27 [config.py:793] This model supports multiple tasks: {'embed', 'generate', 'classify', 'score', 'reward'}. Defaulting to 'generate'.
INFO 06-05 01:12:27 [arg_utils.py:1594] hpu is experimental on VLLM_USE_V1=1. Falling back to V0 Engine.
INFO 06-05 01:12:27 [config.py:1909] Disabled the custom all-reduce kernel because it is not supported on current platform.
INFO 06-05 01:12:27 [llm_engine.py:230] Initializing a V0 LLM engine (v0.9.1.dev172+gd459fae0a.d20250604) with config: model='Qwen3/Qwen3-30B-A3B/', 
Adding requests:   0%|          | 0/4 [00:00<?, ?it/s]
Adding requests: 100%|██████████| 4/4 [00:00<00:00, 1987.35it/s]

Processed prompts:   0%|          | 0/4 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]Configuration: ('prompt', 4, 128) was not warmed-up!
...

github-actions bot commented Jun 4, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, a small and essential subset of CI tests that quickly catches errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

gemini-code-assist bot left a comment

Hello @xuechendi, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

Summary of Changes

Hello team, gemini-code-assist here to provide a summary of this pull request. This PR addresses an issue where the existing CustomOp implementation didn't fully leverage the op_registry for initialization, making it difficult for plugins to properly override default operator implementations without resorting to monkey patching. The core idea is to modify the CustomOp class so that when it's instantiated, it looks up the actual implementation class in the op_registry and instantiates that instead. Additionally, the @CustomOp.register decorator is enhanced to allow plugin-provided custom ops to explicitly override existing registered ops, providing a cleaner mechanism for plugins to inject their own implementations.

Highlights

  • Custom Op Instantiation Logic: A __new__ method is added to the base CustomOp class. This method intercepts the class creation process, checks the op_registry for a class registered under the same name, and delegates the actual object creation to the registered class. This ensures that when code tries to instantiate CustomOp (or a subclass registered with a specific name), it gets the correct, potentially overridden, implementation from the registry.
  • Enhanced Registration for Overrides: The @CustomOp.register decorator is updated to accept an optional custom_op=True argument. When this flag is set and a class with the given name is already registered, the decorator replaces the existing entry in the op_registry with the newly registered class. This provides a clear and intended way for plugins to override default operator implementations (sketched below).
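
For illustration, a hedged sketch of what the override-capable decorator described above could look like (the argument name custom_op follows this summary; the PR later renamed it, see the review thread below):

from typing import Dict

class CustomOp:
    op_registry: Dict[str, type] = {}

    @classmethod
    def register(cls, name: str, custom_op: bool = False):
        def decorator(op_cls: type) -> type:
            if name not in cls.op_registry or custom_op:
                # New op, or an explicit plugin override (custom_op=True)
                # of an existing entry.
                cls.op_registry[name] = op_cls
            return op_cls
        return decorator

@CustomOp.register("RotaryEmbedding")
class RotaryEmbedding(CustomOp):
    pass

# A plugin can later replace the same entry explicitly.
@CustomOp.register("RotaryEmbedding", custom_op=True)
class HPURotaryEmbedding(RotaryEmbedding):
    pass

assert CustomOp.op_registry["RotaryEmbedding"] is HPURotaryEmbedding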

Changelog

  • vllm/model_executor/custom_op.py
    • Added __new__ method to CustomOp to enable dynamic instantiation from the op_registry (lines 19-23).
    • Modified the register class method decorator to accept custom_op=True and allow overriding existing registered ops (lines 150-161).

Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature              | Command             | Description
---------------------|---------------------|--------------------------------------------------------------
Code Review          | /gemini review      | Performs a code review for the current pull request in its current state.
Pull Request Summary | /gemini summary     | Provides a summary of the current pull request in its current state.
Comment              | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments.
Help                 | /gemini help        | Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.


A decorator, a clever wrap,
Changes how functions clap.
With @register's might,
Ops take new flight,
No more monkey patch trap.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

gemini-code-assist bot left a comment

Code Review

This pull request introduces a significant improvement to the CustomOp registration and instantiation mechanism, enabling plugins to override existing operators by leveraging the op_registry. The changes in __new__ and the register decorator are central to this enhancement.

While the direction is good, there are a few critical and high-severity issues in the register method's logic that could lead to unexpected behavior or errors, particularly concerning how overriding classes are handled and how name collisions are managed when custom_op=False. Additionally, the __new__ method could be made more robust in handling unregistered classes.

Addressing these points will greatly improve the reliability and predictability of the custom operator system.

Summary of Findings

  • Robustness of __new__ method: The __new__ method in CustomOp could lead to an unhandled AttributeError if cls.name is not defined. Improved error handling is recommended.
  • Missing name attribute setting in register: When overriding an operator using custom_op=True in the register decorator, the name attribute is not set on the overriding class, which can cause issues during its instantiation.
  • Incorrect return value from register decorator: The register decorator may return a different class than the one being decorated in cases of name collision with custom_op=False. This violates decorator conventions and can lead to subtle bugs. It should consistently return the decorated class or raise an error on invalid registration attempts. (A sketch addressing these findings follows below.)
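
For illustration only, a hedged sketch of how these findings might be addressed (names and structure assumed, not code from this PR):

from typing import Dict

class CustomOp:
    op_registry: Dict[str, type] = {}

    def __new__(cls, *args, **kwargs):
        # Finding 1: tolerate classes without a registered name rather
        # than raising AttributeError.
        op_name = getattr(cls, "name", None)
        target = cls.op_registry.get(op_name, cls) if op_name else cls
        return super().__new__(target)

    @classmethod
    def register(cls, name: str, custom_op: bool = False):
        def decorator(op_cls: type) -> type:
            if name in cls.op_registry and not custom_op:
                # Finding 3 (part 1): fail loudly on a name collision
                # instead of silently returning a different class.
                raise ValueError(
                    f"Op '{name}' is already registered; "
                    "pass custom_op=True to override it.")
            cls.op_registry[name] = op_cls
            # Finding 2: stamp the name on overriding classes too.
            op_cls.name = name
            # Finding 3 (part 2): always return the decorated class.
            return op_cls
        return decorator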

Merge Readiness

The pull request makes valuable changes to the custom operator framework. However, there are critical and high-severity issues identified in the register method's logic, along with a medium-severity robustness concern in the __new__ method. These issues should be addressed before merging to ensure the stability and predictability of the custom operator system. I am unable to approve this PR myself, and it should be reviewed and approved by others once these changes are made.

xuechendi (Author) commented:

@simon-mo @MengqingCao @aurickq , please take a review.

xuechendi (Author) commented:

@mgoin @youkaichao @wangxiyuan, please also review this PR, as part of RFC #19161.

MengqingCao (Contributor) left a comment

Thanks for your efforts!

This change is overall LGTM, but I have one small suggestion, which I've commented.

Also, could you add a test and a docstring for the usage?

xuechendi (Author) commented:

@MengqingCao, I updated the register decorator arg name and also added a unit test for oot_custom_op; please review again. Thanks.

aarnphm (Collaborator) left a comment

This is reasonable to me, thanks for this. You probably want to cc either @youkaichao or @houseroad on this?

aarnphm requested a review from MengqingCao June 13, 2025 00:00
aarnphm added the ready label (ONLY add when PR is ready to merge/full CI is needed) Jun 13, 2025
simon-mo (Collaborator) commented:

Also cc @zou3519.

@CustomOp.register("RotaryEmbedding", is_oot_custom_op=True)
class DummyRotaryEmbedding(RotaryEmbedding):
zou3519 (Collaborator) commented on this snippet:

I have some nits about the API. CustomOp.register is used for registering a new class of custom operators. What we're doing here with the is_oot_custom_op flag is that whenever a RotaryEmbedding is initialized, it is actually replaced with a DummyRotaryEmbedding.

It also sounds like you might want to replace special variants, like Llama4RotaryEmbedding with NPURotaryEmbedding. The first input to register is supposed to be the "base name" of a custom operator (like "RotaryEmbedding"), but it sounds like in your case you may want to use "Llama4RotaryEmbedding" instead.

If we're going with this design, I think there should be a separate API because it seems like something separate from CustomOp.register.

@RotaryEmbedding.monkeypatch  # or RotaryEmbedding.use_out_of_tree if you want it to be nicer.
class DummyRotaryEmbedding(RotaryEmbedding):
    pass

xuechendi (Author) replied:

But a separate API would lead to bigger changes to the existing vLLM code, right? Since most of the ops are already derived from CustomOp, the approach in this PR is a much smaller change to make things work.

The main reason is also that we want to avoid monkey patching, so we want a single entry point, which is CustomOp.

To clarify, we want to do things like:
  • replace Llama4RotaryEmbedding with HPULlama4RotaryEmbedding
  • replace RotaryEmbedding with HPURotaryEmbedding

HPULlama4RotaryEmbedding differs from HPURotaryEmbedding in implementation; the actual impl is here: https://github.com/HabanaAI/vllm-hpu-extension/tree/plugin_poc/vllm_hpu/ops

xuechendi (Author) commented Jun 18, 2025:

What would be the major impact of reusing custom_op to register an OOT op vs. the way you suggested, decorating the actual layer class?

I think reusing custom_op makes more sense for us; it prevents a massive update to vLLM.

If you think reusing CustomOp.register might bring confusion, I can add a new decorator in custom_op and use CustomOp.register_oot instead. What do you think?

xuechendi (Author) commented Jun 18, 2025:

@zou3519, I've updated the PR based on your suggestion. Now either @CustomOp.register_oot('RotaryEmbedding') or @RotaryEmbedding.register_oot will register the custom class with CustomOp, as sketched below.
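
A hedged sketch of how a single decorator could support both spellings (simplified; details may differ from the merged code):

from typing import Optional

class CustomOp:
    # Out-of-tree overrides, keyed by the in-tree op class name and
    # consulted by the __new__ dispatch sketched earlier.
    op_registry: dict = {}

    @classmethod
    def register_oot(cls, _decorated=None, name: Optional[str] = None):
        def decorator(op_cls: type) -> type:
            cls.op_registry[name or cls.__name__] = op_cls
            return op_cls

        if isinstance(_decorated, str):
            # Called form: @CustomOp.register_oot("RotaryEmbedding")
            name = _decorated
            return decorator
        if _decorated is not None:
            # Bare form on a subclass: @RotaryEmbedding.register_oot
            return decorator(_decorated)
        return decorator

class RotaryEmbedding(CustomOp):
    pass

@RotaryEmbedding.register_oot
class HPURotaryEmbedding(RotaryEmbedding):
    pass

@CustomOp.register_oot("Llama4RotaryEmbedding")
class HPULlama4RotaryEmbedding(RotaryEmbedding):
    pass

assert CustomOp.op_registry["RotaryEmbedding"] is HPURotaryEmbedding
assert CustomOp.op_registry["Llama4RotaryEmbedding"] is HPULlama4RotaryEmbedding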

mergify bot added the qwen (Related to Qwen models) label Jun 18, 2025
xuechendi requested a review from zou3519 June 18, 2025 18:57
zou3519 (Collaborator) left a comment

Thanks, this looks reasonable to me.

wangxiyuan (Contributor) left a comment

Thanks for the change. It's fine.

xuechendi (Author) commented:

@simon-mo, can you review the PR again? Thanks.

houseroad (Collaborator) commented:

@simon-mo is out this week, maybe @youkaichao or @WoosukKwon ?

WoosukKwon merged commit 7e8977f into vllm-project:main on Jun 20, 2025 (70 checks passed).
This pull request was subsequently referenced in commits pushed to downstream forks (chris-relational, yeqcharlotte, juncheoll, fhl2000, gmarinho2, xjpang, wseaton, wwl2755-google, avigny).