@@ -66,21 +66,14 @@ def generate_2_to_4_sparse_tensor(shape, dtype=torch.float32, device='cpu'):
6666        raise  ValueError ("Last dimension must be divisible by 4 for 2:4 sparsity." )
6767
6868    full_tensor  =  torch .randn (shape , dtype = dtype , device = device )
69-     mask  =  torch .zeros_like (full_tensor , dtype = torch .bool )
70- 
7169    group_count  =  shape [- 1 ] //  4 
7270    group_shape  =  shape [:- 1 ] +  (group_count , 4 )
7371
74-     reshaped  =  full_tensor .view (* group_shape )
75- 
76-     for  idx  in  range (reshaped .numel () //  4 ):
77-         flat_idx  =  torch .randint (0 , 4 , (2 ,), dtype = torch .int64 )
78-         while  flat_idx [0 ] ==  flat_idx [1 ]:
79-             flat_idx [1 ] =  torch .randint (0 , 4 , (1 ,), dtype = torch .int64 )
80-         i  =  idx  //  group_count 
81-         j  =  idx  %  group_count 
82-         mask .view (* group_shape )[i , j , flat_idx [0 ]] =  True 
83-         mask .view (* group_shape )[i , j , flat_idx [1 ]] =  True 
72+     rand_vals  =  torch .rand (group_shape , device = device )
73+     topk_indices  =  rand_vals .topk (k = 2 , dim = - 1 ).indices 
74+     mask  =  torch .zeros (group_shape , dtype = torch .bool , device = device )
75+     mask .scatter_ (- 1 , topk_indices , True )
76+     mask  =  mask .view (shape )
8477
8578    sparse_tensor  =  full_tensor  *  mask 
8679    return  sparse_tensor 
0 commit comments