
Commit 46ec893

misc: optimize group_gemm test (#493)
Reduce memory consumption to avoid OOM
1 parent 8f71591
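The change applies one pattern throughout: allocate each test tensor directly in the reduced-precision dtype and move it to the device once, instead of building a float32 tensor, scaling it, copying it to the GPU, and casting it afterwards; and compare results on-device instead of round-tripping through `.cpu().numpy()`. A minimal sketch of that pattern, assuming a CUDA device is available (the `make_operand` helper is hypothetical, not part of the test):

```python
import torch


def make_operand(rows, cols, dtype=torch.float16, device="cuda:0"):
    # Hypothetical helper showing the allocation pattern the commit adopts:
    # create the tensor in the target dtype up front, so no transient
    # float32 buffer of the same shape has to be materialized and copied.
    return torch.randn(rows, cols, dtype=dtype).to(device)


x = make_operand(128, 64)
w = make_operand(64, 32)

# torch.testing.assert_close accepts CUDA tensors directly, so the result
# can be checked on-device without allocating host-side numpy copies.
ref = (x.float() @ w.float()).to(torch.float16)  # float32 reference
torch.testing.assert_close(x @ w, ref, rtol=1e-3, atol=1e-3)
```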

File tree

1 file changed (+22, -21 lines)


python/tests/test_group_gemm.py

Lines changed: 22 additions & 21 deletions
```diff
@@ -19,53 +19,59 @@
 import torch
 
 
+DTYPES = [torch.float16]
+CUDA_DEVICES = ["cuda:0"]
+
+
 @pytest.mark.parametrize("batch_size", [1, 77, 199])
 @pytest.mark.parametrize("num_rows_per_batch", [3, 10, 99])
 @pytest.mark.parametrize("d_in", [128, 1024, 4096])
 @pytest.mark.parametrize("d_out", [128, 1024, 4096])
 @pytest.mark.parametrize("use_weight_indices", [False, True])
 @pytest.mark.parametrize("column_major", [False, True])
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
 def test_segment_gemm(
     batch_size,
     num_rows_per_batch,
     d_in,
     d_out,
     use_weight_indices,
     column_major,
+    dtype,
+    device,
 ):
     if batch_size * num_rows_per_batch > 8192:
         pytest.skip("batch_size * num_rows_per_batch too large for test.")
     torch.manual_seed(42)
-    workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0)
+    workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(device)
     segment_gemm = flashinfer.gemm.SegmentGEMMWrapper(workspace_buffer)
-    x = (
-        (torch.randn(batch_size * num_rows_per_batch, d_in) / 10)
-        .to(0)
-        .to(torch.float16)
+    x = torch.randn(batch_size * num_rows_per_batch, d_in, dtype=dtype).to(
+        device
     )
     if use_weight_indices:
         num_weights = 1024
         if column_major:
-            weight = (
-                (torch.randn(num_weights, d_out, d_in) / 10).to(0).to(torch.float16)
+            weight = torch.randn(num_weights, d_out, d_in, dtype=dtype).to(
+                device
             )
         else:
-            weight = (
-                (torch.randn(num_weights, d_in, d_out) / 10).to(0).to(torch.float16)
+            weight = torch.randn(num_weights, d_in, d_out, dtype=dtype).to(
+                device
             )
     else:
         if column_major:
-            weight = (torch.randn(batch_size, d_out, d_in) / 10).to(0).to(torch.float16)
+            weight = torch.randn(batch_size, d_out, d_in, dtype=dtype).to(device)
         else:
-            weight = (torch.randn(batch_size, d_in, d_out) / 10).to(0).to(torch.float16)
+            weight = torch.randn(batch_size, d_in, d_out, dtype=dtype).to(device)
     y = segment_gemm.run(
         x,
         weight,
         batch_size,
         weight_column_major=column_major,
         seg_lens=torch.full((batch_size,), num_rows_per_batch, dtype=torch.int64),
         weight_indices=(
-            (torch.arange(0, batch_size) % num_weights).to(0)
+            (torch.arange(0, batch_size) % num_weights).to(device)
             if use_weight_indices
             else None
         ),
@@ -74,31 +80,26 @@ def test_segment_gemm(
     if use_weight_indices:
         for i in range(batch_size):
             torch.testing.assert_close(
-                y[i * num_rows_per_batch : (i + 1) * num_rows_per_batch].cpu().numpy(),
+                y[i * num_rows_per_batch : (i + 1) * num_rows_per_batch],
                 torch.matmul(
                     x[i * num_rows_per_batch : (i + 1) * num_rows_per_batch],
                     (
                         weight[i % num_weights].T
                         if column_major
                         else weight[i % num_weights]
                     ),
-                )
-                .cpu()
-                .numpy(),
+                ),
                 rtol=1e-3,
                 atol=1e-3,
                 msg="assertion failed at batch {}".format(i),
             )
     else:
         torch.testing.assert_close(
-            y.cpu().numpy(),
+            y,
             torch.matmul(
                 x.view(batch_size, num_rows_per_batch, d_in),
                 weight.transpose(-1, -2) if column_major else weight,
-            )
-            .view(batch_size * num_rows_per_batch, d_out)
-            .cpu()
-            .numpy(),
+            ).view(batch_size * num_rows_per_batch, d_out),
             rtol=1e-3,
             atol=1e-3,
         )
```
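For context, here is a self-contained sketch of the updated test's flow, assembled only from calls that appear in this diff (`flashinfer.gemm.SegmentGEMMWrapper` and its `run` method); the sizes are illustrative and a CUDA device is assumed:

```python
import torch

import flashinfer

device = "cuda:0"
dtype = torch.float16
batch_size, num_rows_per_batch, d_in, d_out = 4, 8, 128, 128  # illustrative

# 32 MB int8 workspace, allocated once and handed to the wrapper (as in the test).
workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(device)
segment_gemm = flashinfer.gemm.SegmentGEMMWrapper(workspace_buffer)

# Operands created directly in float16 -- the pattern this commit introduces.
x = torch.randn(batch_size * num_rows_per_batch, d_in, dtype=dtype).to(device)
weight = torch.randn(batch_size, d_in, d_out, dtype=dtype).to(device)

# One GEMM per segment; every segment here has the same number of rows.
y = segment_gemm.run(
    x,
    weight,
    batch_size,
    weight_column_major=False,
    seg_lens=torch.full((batch_size,), num_rows_per_batch, dtype=torch.int64),
)

# On-device reference check, mirroring the updated assertion in the test.
ref = torch.matmul(x.view(batch_size, num_rows_per_batch, d_in), weight)
torch.testing.assert_close(
    y, ref.view(batch_size * num_rows_per_batch, d_out), rtol=1e-3, atol=1e-3
)
```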
