Commit 62508c5

review comments

1 parent 2d25a9a commit 62508c5

6 files changed (+45, -15 lines)

csrc/permute_cols.cu

Lines changed: 8 additions & 6 deletions
@@ -10,6 +10,7 @@ static constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
 
 // For a given "a" of size [M,K] performs a permutation of the K columns based
 // on the given "perm" indices.
+// Currently only supports 16bit types (since we permute halfs)
 __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
                                     int const* __restrict__ perm_int_ptr,
                                     int4* __restrict__ out_int4_ptr, int size_m,
@@ -61,26 +62,27 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
   }
 }
 
-// More efficient version of A[:, perm]
+// More efficient version of A[..., perm]
 // taken from gptq_marlin.cu
 torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm) {
   const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
   auto dev = A.get_device();
   auto stream = at::cuda::getCurrentCUDAStream(dev);
 
   TORCH_CHECK(A.scalar_type() == at::kHalf || A.scalar_type() == at::kBFloat16,
-              "Only half and bfloat16 are supported");
+              "Currently only 16bit types are supported");
   TORCH_CHECK(A.is_contiguous(), "A must be contiguous");
-  TORCH_CHECK(A.size(1) % 8 == 0,
+  TORCH_CHECK(A.size(-1) % 8 == 0,
               "A columns must be a multiple of 8 (128bits)");
+  auto A_2d = A.view({-1, A.size(-1)});
 
   torch::Tensor D = torch::empty_like(A);
   int sms;
   cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
-  int block_rows = div_ceil(A.size(0), sms);
+  int block_rows = div_ceil(A_2d.size(0), sms);
   permute_cols_kernel<<<sms, default_threads, 0, stream>>>(
-      reinterpret_cast<int4 const*>(A.const_data_ptr()),
+      reinterpret_cast<int4 const*>(A_2d.const_data_ptr()),
       perm.const_data_ptr<int>(), reinterpret_cast<int4*>(D.mutable_data_ptr()),
-      A.size(0), A.size(1), block_rows);
+      A_2d.size(0), A_2d.size(1), block_rows);
   return D;
 }
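The reworked entry point now accepts tensors with arbitrary leading dims by flattening them into the row dimension before launching the kernel. For illustration, a minimal PyTorch sketch of the intended semantics (a reference implementation only, not the CUDA path; the name permute_cols_reference is made up):

import torch

def permute_cols_reference(A: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
    # Same preconditions as the CUDA entry point: 16-bit dtype, contiguous,
    # last dim a multiple of 8 elements (128 bits).
    assert A.dtype in (torch.float16, torch.bfloat16)
    assert A.is_contiguous() and A.size(-1) % 8 == 0
    A_2d = A.view(-1, A.size(-1))   # flatten leading dims, as the commit does
    D_2d = A_2d[:, perm]            # permute columns of the 2D view
    return D_2d.view(A.shape)       # output keeps the original shape

For any leading batch dims this matches plain fancy indexing, i.e. A[..., perm]:

    A = torch.randn(2, 4, 16, dtype=torch.float16)
    perm = torch.randperm(16)
    torch.testing.assert_close(permute_cols_reference(A, perm), A[..., perm])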

csrc/quantization/machete/generate.py

Lines changed: 1 addition & 0 deletions
@@ -335,6 +335,7 @@ def generate():
     )
 
     # For now we use the same heuristic for all types
+    # Heuristic is currently tuned for H100s
     default_heuristic = [
         #### M = 257+
         (

csrc/quantization/machete/machete_mm_kernel.cuh

Lines changed: 2 additions & 2 deletions
@@ -152,8 +152,8 @@ struct MacheteKernelTemplate {
 
     int M = size<0>(layout_A), N = size<1>(layout_D), K = size<1>(layout_A);
 
-    int group_size = maybe_group_size.value_or(K);
-    group_size = (group_size == -1) ? K : group_size;
+    int const group_size =
+        maybe_group_size == -1 ? K : maybe_group_size.value_or(K);
     int const scale_k = (K + group_size - 1) / group_size;
 
     TORCH_CHECK(size<0>(layout_A) == M && size<1>(layout_A) == K);
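The rewritten expression folds two conventions into one: an absent group size (nullopt) and the sentinel -1 both mean a single quantization group spanning all of K. A minimal Python sketch of the same normalization, for illustration only (the helper name resolve_group_size is made up):

from typing import Optional, Tuple

def resolve_group_size(maybe_group_size: Optional[int], K: int) -> Tuple[int, int]:
    # None and -1 both mean "one group covering the whole K dimension".
    group_size = K if maybe_group_size in (None, -1) else maybe_group_size
    scale_k = (K + group_size - 1) // group_size  # ceil-div: scale groups along K
    return group_size, scale_k

Examples: resolve_group_size(None, 4096) -> (4096, 1), resolve_group_size(-1, 4096) -> (4096, 1), resolve_group_size(128, 4096) -> (128, 32).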

vllm/model_executor/layers/quantization/kernels/__init__.py

Lines changed: 24 additions & 4 deletions
@@ -1,12 +1,14 @@
 import os
 from typing import List, Optional, Type
 
+from vllm.model_executor.layers.quantization.kernels.machete import (
+    MacheteLinearKernel)
+from vllm.model_executor.layers.quantization.kernels.marlin import (
+    MarlinLinearKernel)
+from vllm.model_executor.layers.quantization.kernels.MPLinearKernel import (
+    MPLinearKernel, MPLinearLayerConfig)
 from vllm.platforms import current_platform
 
-from .MacheteLinearKernel import MacheteLinearKernel
-from .MarlinLinearKernel import MarlinLinearKernel
-from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
-
 # in priority/performance order (when available)
 _POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [
     MacheteLinearKernel,
@@ -17,6 +19,24 @@
 def choose_mp_linear_kernel(
         config: MPLinearLayerConfig,
         compute_capability: Optional[int] = None) -> Type[MPLinearKernel]:
+    """
+    Choose an MPLinearKernel that can implement the given config for the given
+    compute capability. Attempts to choose the best kernel in terms of
+    performance.
+
+    Args:
+        config (MPLinearLayerConfig): Description of the linear layer to be
+            implemented.
+        compute_capability (Optional[int], optional): The compute capability of
+            the target device; if None, uses `current_platform` to get the
+            compute capability. Defaults to None.
+
+    Raises:
+        ValueError: If no kernel can implement the given config.
+
+    Returns:
+        Type[MPLinearKernel]: Chosen kernel.
+    """
     if compute_capability is None:
         if current_platform is None:
             raise ValueError("Cannot determine compute capability")
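The new docstring describes a priority-ordered search over _POSSIBLE_KERNELS. As a rough sketch of how such a selection typically proceeds (the can_implement classmethod and its (bool, reason) return shape are assumptions for illustration; only the priority list, the docstring, and the ValueError appear in this diff):

from typing import List

def _choose_kernel_sketch(config, kernels: List[type]) -> type:
    # Walk the kernels in priority/performance order and return the first one
    # that reports it can implement the given config.
    failure_reasons: List[str] = []
    for kernel in kernels:
        can_use, reason = kernel.can_implement(config)  # hypothetical interface
        if can_use:
            return kernel
        failure_reasons.append(f"  {kernel.__name__}: {reason}")
    raise ValueError("No kernel can implement the given config:\n" +
                     "\n".join(failure_reasons))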

vllm/model_executor/layers/quantization/kernels/MacheteLinearKernel.py renamed to vllm/model_executor/layers/quantization/kernels/machete.py

Lines changed: 4 additions & 1 deletion
@@ -1,4 +1,7 @@
 from functools import partial
+from typing import Optional, Tuple
+
+import torch
 
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.utils.machete_utils import (
@@ -9,7 +12,7 @@
 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            permute_param_layout_)
 
-from .MPLinearKernel import *
+from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
 
 
 class MacheteLinearKernel(MPLinearKernel):

vllm/model_executor/layers/quantization/kernels/MarlinLinearKernel.py renamed to vllm/model_executor/layers/quantization/kernels/marlin.py

Lines changed: 6 additions & 2 deletions
@@ -1,3 +1,7 @@
+from typing import Optional, Tuple
+
+import torch
+
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear,
@@ -7,7 +11,7 @@
 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            permute_param_layout_)
 
-from .MPLinearKernel import *
+from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
 
 
 class MarlinLinearKernel(MPLinearKernel):
@@ -111,7 +115,7 @@ def apply_weights(self,
         c = self.config
         w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer)
 
-        # `process_weights_after_loading`` will ensure w_zp and w_gidx are not
+        # `process_weights_after_loading` will ensure w_zp and w_gidx are not
         # None for marlin
         return apply_gptq_marlin_linear(
             input=x,
