@@ -136,7 +136,6 @@ void _TestVariableLengthMergeKernelCorrectness(size_t seq_len, size_t num_heads,
136136template <typename T>
137137void _TestMergeKernelCorrectness (size_t num_index_sets, size_t seq_len, size_t num_heads,
138138 size_t head_dim, bool sparse_s) {
139- EXPECT_GT (num_index_sets, 1 ) << " num_index_sets must be greater than 1" ;
140139 std::vector<T> V_host (seq_len * num_index_sets * num_heads * head_dim);
141140 std::vector<float > V_host_trans_f32 (num_index_sets * seq_len * num_heads * head_dim);
142141 std::vector<float > S_host (seq_len * num_index_sets * num_heads);
@@ -178,20 +177,25 @@ void _TestMergeKernelCorrectness(size_t num_index_sets, size_t seq_len, size_t n
178177 thrust::device_vector<T> V_merged_1_device (seq_len * num_heads * head_dim);
179178 thrust::device_vector<float > S_merged_1_device (seq_len * num_heads);
180179
181- // Method 0: use MergeState
182- MergeState (thrust::raw_pointer_cast (V_device_trans_f32.data ()),
183- thrust::raw_pointer_cast (S_device_trans.data ()),
184- thrust::raw_pointer_cast (V_device_trans_f32.data () + seq_len * num_heads * head_dim),
185- thrust::raw_pointer_cast (S_device_trans.data () + seq_len * num_heads),
186- thrust::raw_pointer_cast (V_merged_0_device.data ()),
187- thrust::raw_pointer_cast (S_merged_0_device.data ()), seq_len, num_heads, head_dim);
188- for (uint i = 2 ; i < num_index_sets; ++i) {
189- MergeStateInPlace (
190- thrust::raw_pointer_cast (V_merged_0_device.data ()),
191- thrust::raw_pointer_cast (S_merged_0_device.data ()),
192- thrust::raw_pointer_cast (V_device_trans_f32.data () + i * seq_len * num_heads * head_dim),
193- thrust::raw_pointer_cast (S_device_trans.data () + i * seq_len * num_heads), seq_len,
194- num_heads, head_dim);
180+ if (num_index_sets > 1 ) {
181+ // Method 0: use MergeState
182+ MergeState (thrust::raw_pointer_cast (V_device_trans_f32.data ()),
183+ thrust::raw_pointer_cast (S_device_trans.data ()),
184+ thrust::raw_pointer_cast (V_device_trans_f32.data () + seq_len * num_heads * head_dim),
185+ thrust::raw_pointer_cast (S_device_trans.data () + seq_len * num_heads),
186+ thrust::raw_pointer_cast (V_merged_0_device.data ()),
187+ thrust::raw_pointer_cast (S_merged_0_device.data ()), seq_len, num_heads, head_dim);
188+ for (uint i = 2 ; i < num_index_sets; ++i) {
189+ MergeStateInPlace (
190+ thrust::raw_pointer_cast (V_merged_0_device.data ()),
191+ thrust::raw_pointer_cast (S_merged_0_device.data ()),
192+ thrust::raw_pointer_cast (V_device_trans_f32.data () + i * seq_len * num_heads * head_dim),
193+ thrust::raw_pointer_cast (S_device_trans.data () + i * seq_len * num_heads), seq_len,
194+ num_heads, head_dim);
195+ }
196+ } else {
197+ V_merged_0_device = V_device;
198+ S_merged_0_device = S_device;
195199 }
196200
197201 // Method 1: use MergeStates
@@ -479,7 +483,7 @@ void _TestTwoLevelSinglePrefixCascadeAppendCorrectness(size_t batch_size,
479483
480484template <typename T>
481485void TestMergeKernelCorrectness () {
482- for (size_t num_index_sets : {2 , 9 , 81 , 513 }) {
486+ for (size_t num_index_sets : {1 , 2 , 9 , 81 , 513 }) {
483487 for (size_t seq_len : {4 , 16 , 77 }) {
484488 for (size_t num_heads : {1 , 21 , 32 }) {
485489 for (size_t head_dim : {64 , 128 , 256 }) {
0 commit comments