@@ -132,6 +132,40 @@ __device__ __forceinline__ T BlockYReduce(T val, ReduceOp reducer) {
  return shared_memory[threadIdx.x];
}

+// Swap data
+template <typename T>
+__device__ __forceinline__ void Swap(T* first_value, T* second_value) {
+  T t_value;
+  t_value = (*first_value);
+  (*first_value) = (*second_value);
+  (*second_value) = t_value;
+}
+
+// swap with monotonic_type
+template <typename T>
+__device__ __forceinline__ void Comparator(T* first_value,
+                                           T* second_value,
+                                           int monotonic_type) {
+  if (((*first_value) > (*second_value)) == monotonic_type) {
+    Swap<T>(first_value, second_value);
+  }
+}
+
+template <typename T, typename IndexType>
+__device__ __forceinline__ void ComparatorWithIndex(T* first_value,
+                                                    T* second_value,
+                                                    IndexType* first_index,
+                                                    IndexType* second_index,
+                                                    int monotonic_type) {
+  if (((*first_value) > (*second_value)) == monotonic_type) {
+    // swap value
+    Swap<T>(first_value, second_value);
+    // swap index
+    Swap<IndexType>(first_index, second_index);
+  }
+}
+
} // namespace details

/**
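For intuition, `Comparator` is the compare-exchange step of a sorting network: with `monotonic_type = 1` it leaves the pair in increasing order, with `0` in decreasing order. A minimal standalone sketch of that behavior (the `pair_demo` kernel and the local helper copies are illustrative, not part of this PR):

```cuda
#include <cstdio>

// Standalone copies of the helpers above, so the sketch compiles on its own.
template <typename T>
__device__ __forceinline__ void Swap(T* a, T* b) {
  T t = *a;
  *a = *b;
  *b = t;
}

template <typename T>
__device__ __forceinline__ void Comparator(T* a, T* b, int monotonic_type) {
  // monotonic_type == 1: force (a, b) increasing; 0: force it decreasing
  if (((*a) > (*b)) == monotonic_type) Swap<T>(a, b);
}

__global__ void pair_demo(int* data) {
  Comparator<int>(&data[0], &data[1], /*monotonic_type=*/1);  // -> increasing
  Comparator<int>(&data[2], &data[3], /*monotonic_type=*/0);  // -> decreasing
}

int main() {
  int h[4] = {9, 3, 3, 9};
  int* d;
  cudaMalloc(&d, sizeof(h));
  cudaMemcpy(d, h, sizeof(h), cudaMemcpyHostToDevice);
  pair_demo<<<1, 1>>>(d);
  cudaMemcpy(h, d, sizeof(h), cudaMemcpyDeviceToHost);
  printf("%d %d | %d %d\n", h[0], h[1], h[2], h[3]);  // prints: 3 9 | 9 3
  cudaFree(d);
  return 0;
}
```

The first pair is forced increasing and the second decreasing, which is exactly the per-pair invariant the bitonic stages below rely on.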
@@ -481,5 +515,94 @@ __device__ __forceinline__ void Cumsum(OutT* out,
      static_cast<OutT>(temp[tidx + shared_size + (tidx + shared_size) / 32]);
}

+#define SHARED_SIZE_LIMIT \
+  1024  // each thread loads two elements from global memory, so
+        // SHARED_SIZE_LIMIT must be at least blockDim.x * 2
+// if monotonic_type = 1 the result is increasing
+// if gridDim.x > 1, set monotonic_type = blockIdx.x & 1 so that blocks with
+// blockIdx.x % 2 == 1 sort in increasing order
+template <typename T>
+__device__ __forceinline__ void Sort(T* dst,
+                                     const T* src_data,
+                                     int num,
+                                     int monotonic_type) {
+  // TODO: set num = Pow2(num)
+  // shared memory for the values; blockDim.x must not exceed
+  // SHARED_SIZE_LIMIT / 2
+  __shared__ T value[SHARED_SIZE_LIMIT];  // must hold at least
+                                          // blockDim.x * 2 elements
+  // copy each thread's two values from src_data
+  value[threadIdx.x] = src_data[0];
+  value[threadIdx.x + (SHARED_SIZE_LIMIT / 2)] = src_data[1];
+  // build bitonic sequences of growing size
+  for (int size = 2; size < num; size <<= 1) {
+    int bitonic_type = (threadIdx.x & (size / 2)) != 0;
+    for (int stride = size / 2; stride > 0; stride >>= 1) {
+      __syncthreads();
+      int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
+      details::Comparator<T>(&value[pos], &value[pos + stride], bitonic_type);
+    }
+  }
+  // final merge: when monotonic_type = 1 the result is increasing
+  for (int stride = SHARED_SIZE_LIMIT / 2; stride > 0; stride >>= 1) {
+    __syncthreads();
+    int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
+    details::Comparator<T>(&value[pos], &value[pos + stride], monotonic_type);
+  }
+  __syncthreads();
+  dst[0] = value[threadIdx.x];
+  dst[1] = value[threadIdx.x + (SHARED_SIZE_LIMIT / 2)];
+}
+
+template <typename T, typename IndexType>
+__device__ __forceinline__ void Sort(T* dst,
+                                     IndexType* dst_index,
+                                     const T* src_data,
+                                     IndexType* src_index,
+                                     int num,
+                                     int monotonic_type) {
+  // TODO: set num = Pow2(num)
+  // shared memory for the values and indices; blockDim.x must not exceed
+  // SHARED_SIZE_LIMIT / 2
+  __shared__ T value[SHARED_SIZE_LIMIT];  // must hold at least
+                                          // blockDim.x * 2 elements
+  __shared__ IndexType index[SHARED_SIZE_LIMIT];
+  // copy each thread's two values and indices from src_data and src_index
+  value[threadIdx.x] = src_data[0];
+  value[threadIdx.x + (SHARED_SIZE_LIMIT / 2)] = src_data[1];
+  index[threadIdx.x] = src_index[0];
+  index[threadIdx.x + (SHARED_SIZE_LIMIT / 2)] = src_index[1];
+  // build bitonic sequences of growing size
+  for (int size = 2; size < num; size <<= 1) {
+    int bitonic_type = (threadIdx.x & (size / 2)) != 0;
+    for (int stride = size / 2; stride > 0; stride >>= 1) {
+      __syncthreads();
+      int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
+      details::ComparatorWithIndex<T, IndexType>(&value[pos],
+                                                 &value[pos + stride],
+                                                 &index[pos],
+                                                 &index[pos + stride],
+                                                 bitonic_type);
+    }
+  }
+
+  // final merge: when monotonic_type = 1 the result is increasing
+  for (int stride = SHARED_SIZE_LIMIT / 2; stride > 0; stride >>= 1) {
+    __syncthreads();
+    int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
+    details::ComparatorWithIndex<T, IndexType>(&value[pos],
+                                               &value[pos + stride],
+                                               &index[pos],
+                                               &index[pos + stride],
+                                               monotonic_type);
+  }
+
+  __syncthreads();
+  dst[0] = value[threadIdx.x];
+  dst[1] = value[threadIdx.x + (SHARED_SIZE_LIMIT / 2)];
+  dst_index[0] = index[threadIdx.x];
+  dst_index[1] = index[threadIdx.x + (SHARED_SIZE_LIMIT / 2)];
+}
+
} // namespace kps
} // namespace phi
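To see how the value-only `Sort` is meant to be driven: every thread hands in the two elements it owns, and the write-back pairs position `t` with `t + SHARED_SIZE_LIMIT / 2`. A hedged harness sketch (the `BlockSortKernel` name and launch shape are assumptions, not part of this PR; it presumes this header is included and the block has exactly `SHARED_SIZE_LIMIT / 2 = 512` threads, so every shared-memory slot is filled):

```cuda
// Hypothetical harness (not part of this PR). Assumes the block is exactly
// SHARED_SIZE_LIMIT / 2 = 512 threads and `in` holds 1024 elements.
template <typename T>
__global__ void BlockSortKernel(T* out, const T* in) {
  // Each thread owns two elements: in[tid] and in[tid + blockDim.x].
  T thread_src[2];
  thread_src[0] = in[threadIdx.x];
  thread_src[1] = in[threadIdx.x + blockDim.x];
  T thread_dst[2];
  // monotonic_type = 1 -> increasing order
  phi::kps::Sort<T>(thread_dst, thread_src, 2 * blockDim.x, 1);
  // The sorted sequence occupies shared positions [0, 1024); thread t
  // receives positions t and t + 512, written back in the same layout.
  out[threadIdx.x] = thread_dst[0];
  out[threadIdx.x + blockDim.x] = thread_dst[1];
}
```

Launched as `BlockSortKernel<<<1, 512>>>(d_out, d_in)` on a 1024-element buffer, `d_out` comes back fully sorted. With fewer threads the final merge would read uninitialized shared slots, which is why the size comment on `SHARED_SIZE_LIMIT` matters.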
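The key/index overload carries each value's original position through every swap, which is the building block for argsort- and top-k-style kernels. A companion sketch under the same assumptions (hypothetical `BlockArgSortKernel`, 512-thread block, header included):

```cuda
// Hypothetical harness for the key/index overload (illustrative only).
// Same assumption: blockDim.x == SHARED_SIZE_LIMIT / 2 == 512.
template <typename T, typename IndexType>
__global__ void BlockArgSortKernel(T* out, IndexType* out_idx, const T* in) {
  T v[2];
  IndexType idx[2];
  v[0] = in[threadIdx.x];
  v[1] = in[threadIdx.x + blockDim.x];
  // Seed each element with its original position so the sort tracks it.
  idx[0] = static_cast<IndexType>(threadIdx.x);
  idx[1] = static_cast<IndexType>(threadIdx.x + blockDim.x);
  T sorted_v[2];
  IndexType sorted_idx[2];
  phi::kps::Sort<T, IndexType>(
      sorted_v, sorted_idx, v, idx, 2 * blockDim.x, 1);
  out[threadIdx.x] = sorted_v[0];
  out[threadIdx.x + blockDim.x] = sorted_v[1];
  out_idx[threadIdx.x] = sorted_idx[0];
  out_idx[threadIdx.x + blockDim.x] = sorted_idx[1];
}
```

On return, `out_idx[i]` is the original position of the i-th smallest input element, the usual argsort contract.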