Commit 9a7e2d0

[Bugfix] Allow vllm to still work if triton is not installed. (#6786)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
1 parent 7f8d612 commit 9a7e2d0

13 files changed: +65, -37 lines

requirements-cpu.txt

Lines changed: 0 additions & 1 deletion
@@ -4,4 +4,3 @@
 # Dependencies for x86_64 CPUs
 torch == 2.4.0; platform_machine != "ppc64le"
 torchvision; platform_machine != "ppc64le"  # required for the image processor of phi3v, this must be updated alongside torch
-triton >= 2.2.0  # FIXME(woosuk): This is a hack to avoid import error.

requirements-openvino.txt

Lines changed: 0 additions & 2 deletions
@@ -5,5 +5,3 @@
 torch >= 2.1.2
 openvino ~= 2024.3.0.dev
 optimum-intel[openvino] >= 1.18.1
-
-triton >= 2.2.0  # FIXME(woosuk): This is a hack to avoid import error.

requirements-tpu.txt

Lines changed: 0 additions & 1 deletion
@@ -5,4 +5,3 @@
 # Currently, the TPU backend uses a nightly version of PyTorch XLA.
 # You can install the dependencies in Dockerfile.tpu.
 ray
-triton  # To avoid import errors

tests/kernels/test_sampler.py

Lines changed: 4 additions & 3 deletions
@@ -5,11 +5,12 @@
 import triton
 import triton.language as tl

-from vllm.model_executor.layers.ops.sample import (
-    MAX_TRITON_N_COLS, _uniform_to_exponential, get_num_triton_sampler_splits,
-    sample)
+from vllm.model_executor.layers.ops.sample import (_uniform_to_exponential,
+                                                   sample)
 from vllm.model_executor.sampling_metadata import SamplingTensors
 from vllm.model_executor.utils import set_random_seed
+from vllm.triton_utils.sample import (MAX_TRITON_N_COLS,
+                                      get_num_triton_sampler_splits)

 SINGLE_SPLIT_VOCAB_SIZE = 32000  # llama/mistral/mixtral vocab size
 MULTI_SPLIT_VOCAB_SIZE = MAX_TRITON_N_COLS + 100

vllm/attention/ops/paged_attn.py

Lines changed: 4 additions & 1 deletion
@@ -4,7 +4,10 @@
 import torch

 from vllm import _custom_ops as ops
-from vllm.attention.ops.prefix_prefill import context_attention_fwd
+from vllm.triton_utils import HAS_TRITON
+
+if HAS_TRITON:
+    from vllm.attention.ops.prefix_prefill import context_attention_fwd

 # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
 _PARTITION_SIZE = 512
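
The guard relies on a HAS_TRITON flag exported by the new vllm.triton_utils package. That package's __init__.py is part of this commit but its hunk is not shown on this page; a minimal sketch of such a capability flag, assuming it is derived from whether triton is importable, could look like:

    # Sketch of vllm/triton_utils/__init__.py (assumed, not shown in the
    # hunks above): expose a single flag saying whether triton is present.
    from importlib.util import find_spec

    # True when the triton package is installed in the current environment.
    HAS_TRITON = find_spec("triton") is not None

    __all__ = ["HAS_TRITON"]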
vllm/model_executor/layers/fused_moe/__init__.py

Lines changed: 15 additions & 7 deletions

@@ -1,14 +1,22 @@
-from vllm.model_executor.layers.fused_moe.fused_moe import (
-    fused_experts, fused_moe, fused_topk, get_config_file_name, grouped_topk)
 from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
                                                         FusedMoEMethodBase)
+from vllm.triton_utils import HAS_TRITON

 __all__ = [
-    "fused_moe",
-    "fused_topk",
-    "fused_experts",
-    "get_config_file_name",
-    "grouped_topk",
     "FusedMoE",
     "FusedMoEMethodBase",
 ]
+
+if HAS_TRITON:
+
+    from vllm.model_executor.layers.fused_moe.fused_moe import (
+        fused_experts, fused_moe, fused_topk, get_config_file_name,
+        grouped_topk)
+
+    __all__ += [
+        "fused_moe",
+        "fused_topk",
+        "fused_experts",
+        "get_config_file_name",
+        "grouped_topk",
+    ]
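
After this change the package always exports FusedMoE and FusedMoEMethodBase, and extends __all__ with the triton-backed kernels only when triton is importable. An illustrative consumer-side probe (not part of this commit) would be:

    # Illustrative only: detect whether the triton-backed MoE kernels were
    # exported, instead of importing them unconditionally.
    import vllm.model_executor.layers.fused_moe as fused_moe_pkg

    if hasattr(fused_moe_pkg, "fused_moe"):
        moe_fn = fused_moe_pkg.fused_moe  # triton is available
    else:
        raise RuntimeError("fused_moe kernels require triton to be installed")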

vllm/model_executor/layers/ops/sample.py

Lines changed: 1 addition & 13 deletions
@@ -1,26 +1,14 @@
-import math
 from typing import Optional, Tuple

 import torch
 import triton
 import triton.language as tl

 from vllm.model_executor.layers.ops.rand import seeded_uniform
+from vllm.triton_utils.sample import get_num_triton_sampler_splits

 _EPS = 1e-6

-# This is a hardcoded limit in Triton (max block size).
-MAX_TRITON_N_COLS = 131072
-
-
-def get_num_triton_sampler_splits(n_cols: int) -> int:
-    """Get the number of splits to use for Triton sampling.
-
-    Triton has a limit on the number of columns it can handle, so we need to
-    split the tensor and call the kernel multiple times if it's too large.
-    """
-    return math.ceil(n_cols / MAX_TRITON_N_COLS)
-

 def _multi_split_sample(
     probs: torch.Tensor,
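
The deleted constant and helper do not depend on triton at runtime, so the commit relocates them to vllm/triton_utils/sample.py, which callers can import even when triton is absent. Reassembling the deleted lines, the moved module is essentially:

    # vllm/triton_utils/sample.py, reconstructed from the lines deleted
    # above (the new file's own hunk is not shown on this page).
    import math

    # This is a hardcoded limit in Triton (max block size).
    MAX_TRITON_N_COLS = 131072


    def get_num_triton_sampler_splits(n_cols: int) -> int:
        """Get the number of splits to use for Triton sampling.

        Triton has a limit on the number of columns it can handle, so we
        need to split the tensor and call the kernel multiple times if it's
        too large.
        """
        return math.ceil(n_cols / MAX_TRITON_N_COLS)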

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 2 additions & 2 deletions
@@ -6,8 +6,7 @@
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
-                                                  fused_moe)
+from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization.base_config import (
@@ -404,6 +403,7 @@ def apply(self,
               num_expert_group: Optional[int] = None,
               topk_group: Optional[int] = None) -> torch.Tensor:

+        from vllm.model_executor.layers.fused_moe import fused_moe
         return fused_moe(x,
                          layer.w13_weight,
                          layer.w2_weight,
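
Moving the fused_moe import from module scope into the body of apply() means that merely importing fp8.py no longer pulls in the triton-backed kernels; the dependency is resolved only when an FP8 MoE layer actually executes. A toy illustration of the same deferred-import idiom (the module name optional_pkg is hypothetical):

    # Illustrative only: a function-local import of an optional dependency.
    def apply(x):
        # Importing this module always succeeds; only calling apply()
        # requires the optional package to be installed.
        from optional_pkg import kernel  # hypothetical optional dependency
        return kernel(x)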

vllm/model_executor/layers/sampler.py

Lines changed: 5 additions & 1 deletion
@@ -6,7 +6,11 @@
 import torch
 import torch.nn as nn

-from vllm.model_executor.layers.ops.sample import sample as sample_triton
+from vllm.triton_utils import HAS_TRITON
+
+if HAS_TRITON:
+    from vllm.model_executor.layers.ops.sample import sample as sample_triton
+
 from vllm.model_executor.sampling_metadata import (SamplingMetadata,
                                                    SamplingTensors,
                                                    SequenceGroupToSample)
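
With this guard, sample_triton is simply undefined when triton is missing, so any code path that reaches it must itself be gated. A sketch of what a call site would then have to look like (this diff does not show the call sites; the fallback name is hypothetical):

    # Sketch only: the triton sampling path may run only when HAS_TRITON
    # is True; otherwise a non-triton path must be taken.
    if HAS_TRITON and use_triton_sampler:
        out = sample_triton(probs, sampling_tensors)
    else:
        out = sample_torch(probs, sampling_tensors)  # hypothetical fallback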

vllm/model_executor/sampling_metadata.py

Lines changed: 1 addition & 1 deletion
@@ -5,9 +5,9 @@

 import torch

-from vllm.model_executor.layers.ops.sample import get_num_triton_sampler_splits
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SequenceData, SequenceGroupMetadata
+from vllm.triton_utils.sample import get_num_triton_sampler_splits
 from vllm.utils import (async_tensor_h2d, is_pin_memory_available,
                         make_tensor_with_pad, maybe_expand_dim)
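
Because get_num_triton_sampler_splits now comes from the triton-free vllm.triton_utils.sample module, sampling metadata can be built without triton installed. A quick worked check using the constants from the test file above:

    # Worked example with MAX_TRITON_N_COLS = 131072 (from the deleted
    # helper) and the vocab sizes in tests/kernels/test_sampler.py.
    import math

    MAX_TRITON_N_COLS = 131072
    # llama/mistral vocab of 32000 fits in one kernel launch:
    assert math.ceil(32000 / MAX_TRITON_N_COLS) == 1
    # MULTI_SPLIT_VOCAB_SIZE = MAX_TRITON_N_COLS + 100 forces two splits:
    assert math.ceil((MAX_TRITON_N_COLS + 100) / MAX_TRITON_N_COLS) == 2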
