
Commit d588cd2

[Bugfix] fix custom op test (#25429)
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
1 parent 45d7d85 commit d588cd2

1 file changed: +12 -10 lines changed

tests/model_executor/test_enabled_custom_ops.py

Lines changed: 12 additions & 10 deletions
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Optional
 
 import pytest
 import torch
@@ -34,15 +35,15 @@ class Relu3(ReLUSquaredActivation):
     [
         # Default values based on compile level
         # - All by default (no Inductor compilation)
-        ("", 0, False, [True] * 4, True),
-        ("", 1, True, [True] * 4, True),
-        ("", 2, False, [True] * 4, True),
+        (None, 0, False, [True] * 4, True),
+        (None, 1, True, [True] * 4, True),
+        (None, 2, False, [True] * 4, True),
         # - None by default (with Inductor)
-        ("", 3, True, [False] * 4, False),
-        ("", 4, True, [False] * 4, False),
+        (None, 3, True, [False] * 4, False),
+        (None, 4, True, [False] * 4, False),
         # - All by default (without Inductor)
-        ("", 3, False, [True] * 4, True),
-        ("", 4, False, [True] * 4, True),
+        (None, 3, False, [True] * 4, True),
+        (None, 4, False, [True] * 4, True),
         # Explicitly enabling/disabling
         #
         # Default: all
@@ -54,7 +55,7 @@ class Relu3(ReLUSquaredActivation):
         # All but SiluAndMul
         ("all,-silu_and_mul", 2, True, [1, 0, 1, 1], True),
         # All but ReLU3 (even if ReLU2 is on)
-        ("-relu3,relu2", 3, False, [1, 1, 1, 0], True),
+        ("-relu3,+relu2", 3, False, [1, 1, 1, 0], True),
         # RMSNorm and SiluAndMul
         ("none,-relu3,+rms_norm,+silu_and_mul", 4, False, [1, 1, 0, 0], False),
         # All but RMSNorm
@@ -67,12 +68,13 @@ class Relu3(ReLUSquaredActivation):
         # All but RMSNorm
         ("all,-rms_norm", 4, True, [0, 1, 1, 1], True),
     ])
-def test_enabled_ops(env: str, torch_level: int, use_inductor: bool,
+def test_enabled_ops(env: Optional[str], torch_level: int, use_inductor: bool,
                      ops_enabled: list[int], default_on: bool):
+    custom_ops = env.split(',') if env else []
     vllm_config = VllmConfig(
         compilation_config=CompilationConfig(use_inductor=bool(use_inductor),
                                              level=torch_level,
-                                             custom_ops=env.split(",")))
+                                             custom_ops=custom_ops))
     with set_current_vllm_config(vllm_config):
         assert CustomOp.default_on() == default_on
 
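For context on why the parametrization changed from "" to None: in Python, splitting an empty string still yields one empty token, so the old test passed [""] as custom_ops rather than an empty override list. The sketch below shows the difference and the guarded parsing pattern used in the fixed test; the parse_custom_ops helper name is illustrative only, not part of vLLM.

# Minimal sketch of the parsing issue the fix addresses.
# Splitting an empty string does not produce an empty list:
assert "".split(",") == [""]

# Guarded parsing, mirroring the pattern in the updated test
# (the `parse_custom_ops` name is hypothetical, for illustration).
def parse_custom_ops(env):
    """Turn an optional comma-separated op spec into a list of tokens."""
    return env.split(",") if env else []

assert parse_custom_ops(None) == []
assert parse_custom_ops("") == []
assert parse_custom_ops("all,-rms_norm") == ["all", "-rms_norm"]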
