1919import torch
2020
2121
# Element dtypes the parametrized tests are run with.
DTYPES = [torch.float16]
# Devices the parametrized tests are run on; only the first CUDA device is listed.
CUDA_DEVICES = ["cuda:0"]
24+
25+
@pytest.mark.parametrize("batch_size", [1, 77, 199])
@pytest.mark.parametrize("num_rows_per_batch", [3, 10, 99])
@pytest.mark.parametrize("d_in", [128, 1024, 4096])
@pytest.mark.parametrize("d_out", [128, 1024, 4096])
@pytest.mark.parametrize("use_weight_indices", [False, True])
@pytest.mark.parametrize("column_major", [False, True])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_segment_gemm(
    batch_size,
    num_rows_per_batch,
    d_in,
    d_out,
    use_weight_indices,
    column_major,
    dtype,
    device,
):
    """Compare ``SegmentGEMMWrapper.run`` against a ``torch.matmul`` reference.

    ``x`` holds ``batch_size`` consecutive groups of ``num_rows_per_batch``
    rows, each of width ``d_in``. Every group is expected to equal the product
    of its rows with one ``(d_in, d_out)`` weight matrix (stored transposed
    when ``column_major`` is true).

    With ``use_weight_indices`` the weights come from a pool of 1024 matrices
    and group ``i`` is checked against matrix ``i % 1024``; otherwise there is
    one weight matrix per group. Cases with more than 8192 total rows are
    skipped.
    """
    total_rows = batch_size * num_rows_per_batch
    if total_rows > 8192:
        pytest.skip("batch_size * num_rows_per_batch too large for test.")
    torch.manual_seed(42)

    # 32 * 1024 * 1024 int8 elements handed to the wrapper as its workspace.
    workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(device)
    segment_gemm = flashinfer.gemm.SegmentGEMMWrapper(workspace_buffer)

    # Inputs are drawn on the host first and then moved, x before weight,
    # so the seeded random stream is consumed in a fixed order.
    x = torch.randn(total_rows, d_in, dtype=dtype).to(device)

    weight_count = 1024 if use_weight_indices else batch_size
    matrix_shape = (d_out, d_in) if column_major else (d_in, d_out)
    weight = torch.randn(weight_count, *matrix_shape, dtype=dtype).to(device)

    if use_weight_indices:
        weight_indices = (torch.arange(0, batch_size) % weight_count).to(device)
    else:
        weight_indices = None

    y = segment_gemm.run(
        x,
        weight,
        batch_size,
        weight_column_major=column_major,
        seg_lens=torch.full((batch_size,), num_rows_per_batch, dtype=torch.int64),
        weight_indices=weight_indices,
    )

    if use_weight_indices:
        # Check each group separately against the matrix its index selects.
        for i in range(batch_size):
            rows = slice(i * num_rows_per_batch, (i + 1) * num_rows_per_batch)
            selected = weight[i % weight_count]
            torch.testing.assert_close(
                y[rows],
                torch.matmul(x[rows], selected.T if column_major else selected),
                rtol=1e-3,
                atol=1e-3,
                msg="assertion failed at batch {}".format(i),
            )
    else:
        # One batched matmul covers every group, then flatten back to 2-D.
        rhs = weight.transpose(-1, -2) if column_major else weight
        expected = torch.matmul(
            x.view(batch_size, num_rows_per_batch, d_in),
            rhs,
        ).view(total_rows, d_out)
        torch.testing.assert_close(
            y,
            expected,
            rtol=1e-3,
            atol=1e-3,
        )
0 commit comments