Skip to content

Commit 1906118

Browse files
[Infer/Fix] Fix Dependency in test - RMSNorm kernel (#5399)
fix dependency in pytest
1 parent bc1da87 commit 1906118

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

tests/test_infer/test_ops/triton/test_rmsnorm_triton.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,12 @@
33
import triton
44
from packaging import version
55
from transformers.models.llama.modeling_llama import LlamaRMSNorm
6-
from vllm.model_executor.layers.layernorm import RMSNorm
76

87
from colossalai.kernel.triton import rms_layernorm
98
from colossalai.testing.utils import parameterize
109

1110
try:
12-
pass
11+
import triton # noqa
1312

1413
HAS_TRITON = True
1514
except ImportError:
@@ -85,6 +84,11 @@ def benchmark_rms_layernorm(
8584
SEQUENCE_TOTAL: int,
8685
HIDDEN_SIZE: int,
8786
):
87+
try:
88+
from vllm.model_executor.layers.layernorm import RMSNorm
89+
except ImportError:
90+
raise ImportError("Please install vllm from https://github.com/vllm-project/vllm")
91+
8892
warmup = 10
8993
rep = 1000
9094

0 commit comments

Comments
 (0)