This repository was archived by the owner on Oct 11, 2024. It is now read-only.

Commit 1c552d5 ("format.sh")
1 parent: eca5d2f

File tree: 4 files changed, +44 -38 lines


csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu

Lines changed: 12 additions & 12 deletions
@@ -120,9 +120,9 @@ struct cutlass_2x_gemm {
 
 template <typename Gemm>
 void cutlass_scaled_mm_dispatcher(torch::Tensor& out, torch::Tensor const& a,
-                                  torch::Tensor const& b,
-                                  torch::Tensor const& a_scales,
-                                  torch::Tensor const& b_scales) {
+                                  torch::Tensor const& b,
+                                  torch::Tensor const& a_scales,
+                                  torch::Tensor const& b_scales) {
   using ElementAB = typename Gemm::ElementAB;
   using ElementD = typename Gemm::ElementD;
 
@@ -195,9 +195,9 @@ void cutlass_scaled_mm_dispatcher(torch::Tensor& out, torch::Tensor const& a,
 } // namespace
 
 void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
-                            torch::Tensor const& b,
-                            torch::Tensor const& a_scales,
-                            torch::Tensor const& b_scales) {
+                            torch::Tensor const& b,
+                            torch::Tensor const& a_scales,
+                            torch::Tensor const& b_scales) {
   TORCH_CHECK(a.dtype() == torch::kInt8);
   TORCH_CHECK(b.dtype() == torch::kInt8);
   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
@@ -227,9 +227,9 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
 }
 
 void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
-                            torch::Tensor const& b,
-                            torch::Tensor const& a_scales,
-                            torch::Tensor const& b_scales) {
+                            torch::Tensor const& b,
+                            torch::Tensor const& a_scales,
+                            torch::Tensor const& b_scales) {
   TORCH_CHECK(a.dtype() == torch::kInt8);
   TORCH_CHECK(b.dtype() == torch::kInt8);
   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
@@ -259,9 +259,9 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
 }
 
 void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
-                            torch::Tensor const& b,
-                            torch::Tensor const& a_scales,
-                            torch::Tensor const& b_scales) {
+                            torch::Tensor const& b,
+                            torch::Tensor const& a_scales,
+                            torch::Tensor const& b_scales) {
   using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
   using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
   using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;

csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu

Lines changed: 14 additions & 14 deletions
@@ -4,30 +4,30 @@
 #include <torch/extension.h>
 
 void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a,
-                            torch::Tensor const& b,
-                            torch::Tensor const& a_scales,
-                            torch::Tensor const& b_scales);
+                            torch::Tensor const& b,
+                            torch::Tensor const& a_scales,
+                            torch::Tensor const& b_scales);
 
 void cutlass_scaled_mm_sm80(torch::Tensor& c, torch::Tensor const& a,
-                            torch::Tensor const& b,
-                            torch::Tensor const& a_scales,
-                            torch::Tensor const& b_scales);
+                            torch::Tensor const& b,
+                            torch::Tensor const& a_scales,
+                            torch::Tensor const& b_scales);
 
 void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
-                            torch::Tensor const& b,
-                            torch::Tensor const& a_scales,
-                            torch::Tensor const& b_scales);
+                            torch::Tensor const& b,
+                            torch::Tensor const& a_scales,
+                            torch::Tensor const& b_scales);
 
 #if defined CUDA_VERSION && CUDA_VERSION >= 12000
 void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
-                            torch::Tensor const& b,
-                            torch::Tensor const& a_scales,
-                            torch::Tensor const& b_scales);
+                            torch::Tensor const& b,
+                            torch::Tensor const& a_scales,
+                            torch::Tensor const& b_scales);
 #endif
 
 void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
-                       torch::Tensor const& b, torch::Tensor const& a_scales,
-                       torch::Tensor const& b_scales) {
+                       torch::Tensor const& b, torch::Tensor const& a_scales,
+                       torch::Tensor const& b_scales) {
   int32_t major_capability;
  int32_t minor_capability;
  cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
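
The hunk above is cut off just after the compute-capability query inside cutlass_scaled_mm. For orientation, here is a rough Python sketch of that kind of capability-based routing; the helper name, the exact capability cutoffs, and the use of the current device are illustrative assumptions, not code from this commit.

import torch

def pick_scaled_mm_kernel() -> str:
    # Illustrative sketch only: the real routing lives in the C++ entry point,
    # which reads cudaDevAttrComputeCapabilityMajor/Minor as shown above.
    # The cutoffs below are assumptions based on the declared sm75/sm80/sm89/sm90
    # variants (sm90 is additionally gated on CUDA >= 12.0 in the header).
    major, minor = torch.cuda.get_device_capability()
    version = major * 10 + minor
    if version >= 90:
        return "cutlass_scaled_mm_sm90"
    if version >= 89:
        return "cutlass_scaled_mm_sm89"
    if version >= 80:
        return "cutlass_scaled_mm_sm80"
    return "cutlass_scaled_mm_sm75"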

tests/kernels/test_cutlass.py

Lines changed: 14 additions & 9 deletions
@@ -52,7 +52,10 @@ def cutlass_fp8_gemm_helper(m: int,
                         scale_b * b.to(dtype=torch.float32)).to(out_dtype)
 
     # Convert outputs to fp32, since allclose is not implemented for fp8_e4m3
-    assert torch.allclose(out.to(torch.float32), baseline.to(torch.float32), rtol=1e-2, atol=1e-1)
+    assert torch.allclose(out.to(torch.float32),
+                          baseline.to(torch.float32),
+                          rtol=1e-2,
+                          atol=1e-1)
 
 
 def cutlass_int8_gemm_helper(m: int,
@@ -79,7 +82,7 @@ def cutlass_int8_gemm_helper(m: int,
     baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
                         scale_b *
                         b.to(dtype=torch.float32)).to(dtype=out_dtype)
-
+
     rtol = 1e0 if out_dtype is torch.int8 else 1e-1
     assert torch.allclose(out, baseline, rtol=rtol, atol=1e0)
 
@@ -108,7 +111,8 @@ def test_cutlass_int8_gemm(m: int, n: int, k: int, per_act_token: bool,
 
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
-@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16, torch.int8])
+@pytest.mark.parametrize("out_dtype",
+                         [torch.bfloat16, torch.float16, torch.int8])
 def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
                                         out_dtype: Type[torch.dtype]):
     cutlass_int8_gemm_helper(512, 512, 512, per_act_token, per_out_ch,
@@ -117,7 +121,8 @@ def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
 
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
-@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16, torch.float8_e4m3fn])
+@pytest.mark.parametrize("out_dtype",
+                         [torch.bfloat16, torch.float16, torch.float8_e4m3fn])
 @pytest.mark.skipif(capability < 89,
                     reason="FP8 is not supported on this GPU type.")
 def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
@@ -183,10 +188,10 @@ def test_cutlass_subset():
     scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
 
     out = ops.cutlass_scaled_mm(a,
-                                b,
-                                scale_a,
-                                scale_b,
-                                out_dtype=torch.bfloat16)
+                                b,
+                                scale_a,
+                                scale_b,
+                                out_dtype=torch.bfloat16)
     baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
                         scale_b *
                         b.to(dtype=torch.float32)).to(dtype=torch.bfloat16)
@@ -206,7 +211,7 @@ def __init__(self, b, scale_a, scale_b, out_dtype):
 
     def forward(self, a):
         return ops.cutlass_scaled_mm(a, self.b, self.scale_a, self.scale_b,
-                                     self.out_dtype)
+                                     self.out_dtype)
 
 
 @pytest.mark.parametrize("per_act_token", [True, False])

vllm/_custom_ops.py

Lines changed: 4 additions & 3 deletions
@@ -176,11 +176,12 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
 
 
 # cutlass
-def cutlass_scaled_mm(a: torch.Tensor, b: torch.Tensor,
-                      a_scales: torch.Tensor, b_scales: torch.Tensor,
+def cutlass_scaled_mm(a: torch.Tensor, b: torch.Tensor, a_scales: torch.Tensor,
+                      b_scales: torch.Tensor,
                       out_dtype: Type[torch.dtype]) -> torch.Tensor:
     assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
-    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16 or out_dtype is a.dtype)
+    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16
+            or out_dtype is a.dtype)
 
     m = a.shape[0]
     n = b.shape[1]
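
For reference, the calling pattern exercised by tests/kernels/test_cutlass.py in this same commit looks roughly like the sketch below. The shapes, the randint-based int8 inputs, and the import path are illustrative assumptions; the cutlass_scaled_mm signature, the scale construction, the baseline formula, and the tolerances follow the diffs above.

import torch
from vllm import _custom_ops as ops  # import path assumed to match the test module

# Hypothetical int8 operands with fp32 scales; b's dimensions are multiples of 16,
# as required by the assert in cutlass_scaled_mm above.
m, n, k = 512, 512, 512
a = torch.randint(-128, 127, (m, k), device="cuda", dtype=torch.int8)
b = torch.randint(-128, 127, (k, n), device="cuda", dtype=torch.int8)
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10

out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)

# Reference result, mirroring the baseline used in the tests.
baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
                    scale_b * b.to(dtype=torch.float32)).to(dtype=torch.bfloat16)
assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)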
