Skip to content

Commit bab57f2

Browse files
authored
[CI] Speed up sparse tensor core test via vectorized generating sparse data (#1009)
1 parent 340bfc5 commit bab57f2

File tree

1 file changed

+5
-12
lines changed

1 file changed

+5
-12
lines changed

examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -66,21 +66,14 @@ def generate_2_to_4_sparse_tensor(shape, dtype=torch.float32, device='cpu'):
6666
raise ValueError("Last dimension must be divisible by 4 for 2:4 sparsity.")
6767

6868
full_tensor = torch.randn(shape, dtype=dtype, device=device)
69-
mask = torch.zeros_like(full_tensor, dtype=torch.bool)
70-
7169
group_count = shape[-1] // 4
7270
group_shape = shape[:-1] + (group_count, 4)
7371

74-
reshaped = full_tensor.view(*group_shape)
75-
76-
for idx in range(reshaped.numel() // 4):
77-
flat_idx = torch.randint(0, 4, (2,), dtype=torch.int64)
78-
while flat_idx[0] == flat_idx[1]:
79-
flat_idx[1] = torch.randint(0, 4, (1,), dtype=torch.int64)
80-
i = idx // group_count
81-
j = idx % group_count
82-
mask.view(*group_shape)[i, j, flat_idx[0]] = True
83-
mask.view(*group_shape)[i, j, flat_idx[1]] = True
72+
rand_vals = torch.rand(group_shape, device=device)
73+
topk_indices = rand_vals.topk(k=2, dim=-1).indices
74+
mask = torch.zeros(group_shape, dtype=torch.bool, device=device)
75+
mask.scatter_(-1, topk_indices, True)
76+
mask = mask.view(shape)
8477

8578
sparse_tensor = full_tensor * mask
8679
return sparse_tensor

0 commit comments

Comments
 (0)