Skip to content

Commit f4e7488

Browse files
Add Sort API for Kernel Primitive API (#39734)
* Add Sort API for Kernel Primitive API * update & -> ptr
1 parent de760d2 commit f4e7488

File tree

1 file changed

+123
-0
lines changed

1 file changed

+123
-0
lines changed

paddle/phi/kernels/primitive/compute_primitives.h

+123
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,40 @@ __device__ __forceinline__ T BlockYReduce(T val, ReduceOp reducer) {
132132
return shared_memory[threadIdx.x];
133133
}
134134

135+
// Exchanges the two objects referenced by first_value and second_value.
// Both pointers are dereferenced unconditionally, so they must be valid.
template <typename T>
__device__ __forceinline__ void Swap(T* first_value, T* second_value) {
  // Keep a copy of the first operand while it is being overwritten.
  T saved = (*first_value);
  (*first_value) = (*second_value);
  (*second_value) = saved;
}
143+
144+
// Compare-exchange step of a sorting network.
//
// The two values are swapped when the result of (*first_value > *second_value)
// equals monotonic_type:
//   monotonic_type == 1: swap when first >  second (pair ends up ascending)
//   monotonic_type == 0: swap when first <= second (pair ends up descending;
//                        equal values are exchanged as well)
// Any other monotonic_type never matches the comparison result, so no swap
// happens.
template <typename T>
__device__ __forceinline__ void Comparator(T* first_value,
                                           T* second_value,
                                           int monotonic_type) {
  const bool should_swap =
      (((*first_value) > (*second_value)) == monotonic_type);
  if (should_swap) {
    Swap<T>(first_value, second_value);
  }
}
153+
154+
// Compare-exchange step that keeps a companion index attached to each value.
//
// The decision is taken on the values only: when the result of
// (*first_value > *second_value) equals monotonic_type, both the values and
// their indices are swapped, so each index keeps following its value.
//   monotonic_type == 1: swap when first >  second (pair ends up ascending)
//   monotonic_type == 0: swap when first <= second (pair ends up descending)
template <typename T, typename IndexType>
__device__ __forceinline__ void ComparatorWithIndex(T* first_value,
                                                    T* second_value,
                                                    IndexType* first_index,
                                                    IndexType* second_index,
                                                    int monotonic_type) {
  const bool should_swap =
      ((*first_value > (*second_value)) == monotonic_type);
  if (should_swap) {
    // Move the payload first, then the index that travels with it.
    Swap<T>(first_value, second_value);
    Swap<IndexType>(first_index, second_index);
  }
}
168+
135169
} // namespace details
136170

137171
/**
@@ -481,5 +515,94 @@ __device__ __forceinline__ void Cumsum(OutT* out,
481515
static_cast<OutT>(temp[tidx + shared_size + (tidx + shared_size) / 32]);
482516
}
483517

518+
// Capacity, in elements, of the statically sized shared-memory scratch buffers
// used by Sort. Every thread stages two elements, so blockDim.x * 2 must not
// exceed SHARED_SIZE_LIMIT.
#define SHARED_SIZE_LIMIT 1024

/**
 * @brief Block-wide bitonic sort of the 2 * blockDim.x values staged by the
 * threads of one block (two values per thread).
 *
 * Each thread reads src_data[0] and src_data[1], places them in shared-memory
 * slots threadIdx.x and threadIdx.x + blockDim.x, cooperates in the sorting
 * network, and finally writes the values found in those same two slots to
 * dst[0] and dst[1].
 *
 * Requirements:
 *   - Must be called by every thread of the block (it uses __syncthreads()).
 *   - blockDim.x * 2 <= SHARED_SIZE_LIMIT, otherwise the shared buffer is
 *     indexed out of bounds.
 *   - NOTE(review): the network assumes blockDim.x is a power of two and that
 *     num == 2 * blockDim.x; confirm against callers.
 *
 * @param dst            Per-thread output; dst[0] and dst[1] are written.
 * @param src_data       Per-thread input; src_data[0] and src_data[1] are read.
 * @param num            Number of elements being sorted; bounds the
 *                       bitonic-sequence building phase.
 * @param monotonic_type Direction of the final merge: 1 sorts in increasing
 *                       order, 0 in decreasing order. If gridDim.x > 1, set
 *                       monotonic_type = blockIdx.x & 1 (odd blocks increase).
 */
template <typename T>
__device__ __forceinline__ void Sort(T* dst,
                                     const T* src_data,
                                     int num,
                                     int monotonic_type) {
  // todo: set num = Pow2(num)
  __shared__ T value[SHARED_SIZE_LIMIT];

  // Offset of each thread's second element and first stride of the final
  // merge. Deriving it from blockDim.x (instead of the fixed
  // SHARED_SIZE_LIMIT / 2) keeps the staged data contiguous in
  // [0, 2 * blockDim.x), so blocks smaller than SHARED_SIZE_LIMIT / 2 threads
  // no longer compare against uninitialized shared-memory slots.
  const int half_size = blockDim.x;

  // Stage this thread's two inputs.
  value[threadIdx.x] = src_data[0];
  value[threadIdx.x + half_size] = src_data[1];

  // Build bitonic sequences of growing length.
  for (int size = 2; size < num; size <<= 1) {
    // Neighbouring groups of size / 2 threads sort in opposite directions.
    int bitonic_type = (threadIdx.x & (size / 2)) != 0;
    for (int stride = size / 2; stride > 0; stride >>= 1) {
      // Make the previous step's exchanges visible to every thread.
      __syncthreads();
      int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
      details::Comparator<T>(&value[pos], &value[pos + stride], bitonic_type);
    }
  }

  // Final merge of the whole buffer in the requested direction.
  for (int stride = half_size; stride > 0; stride >>= 1) {
    __syncthreads();
    int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
    // monotonic_type == 1 yields increasing order.
    details::Comparator<T>(&value[pos], &value[pos + stride], monotonic_type);
  }

  // Wait for the last exchange before reading the results back.
  __syncthreads();
  dst[0] = value[threadIdx.x];
  dst[1] = value[threadIdx.x + half_size];
}
556+
557+
/**
 * @brief Block-wide bitonic sort of 2 * blockDim.x (value, index) pairs, two
 * pairs per thread. Ordering is decided by the values only; every index is
 * moved together with its value.
 *
 * Each thread reads src_data[0..1] and src_index[0..1], places them in
 * shared-memory slots threadIdx.x and threadIdx.x + blockDim.x, cooperates in
 * the sorting network, and finally writes the pairs found in those same two
 * slots to dst[0..1] and dst_index[0..1].
 *
 * Requirements:
 *   - Must be called by every thread of the block (it uses __syncthreads()).
 *   - blockDim.x * 2 <= SHARED_SIZE_LIMIT, otherwise the shared buffers are
 *     indexed out of bounds.
 *   - NOTE(review): the network assumes blockDim.x is a power of two and that
 *     num == 2 * blockDim.x; confirm against callers.
 *
 * @param dst            Per-thread value output; dst[0] and dst[1] are written.
 * @param dst_index      Per-thread index output; dst_index[0] and dst_index[1]
 *                       are written.
 * @param src_data       Per-thread value input; two elements are read.
 * @param src_index      Per-thread index input; two elements are read.
 * @param num            Number of elements being sorted; bounds the
 *                       bitonic-sequence building phase.
 * @param monotonic_type Direction of the final merge: 1 sorts in increasing
 *                       order, 0 in decreasing order.
 */
template <typename T, typename IndexType>
__device__ __forceinline__ void Sort(T* dst,
                                     IndexType* dst_index,
                                     const T* src_data,
                                     IndexType* src_index,
                                     int num,
                                     int monotonic_type) {
  // todo: set num = Pow2(num)
  __shared__ T value[SHARED_SIZE_LIMIT];
  __shared__ IndexType index[SHARED_SIZE_LIMIT];

  // Offset of each thread's second pair and first stride of the final merge.
  // Deriving it from blockDim.x (instead of the fixed SHARED_SIZE_LIMIT / 2)
  // keeps the staged data contiguous in [0, 2 * blockDim.x), so blocks smaller
  // than SHARED_SIZE_LIMIT / 2 threads no longer compare against
  // uninitialized shared-memory slots.
  const int half_size = blockDim.x;

  // Stage this thread's two values and their indices.
  value[threadIdx.x] = src_data[0];
  value[threadIdx.x + half_size] = src_data[1];
  index[threadIdx.x] = src_index[0];
  index[threadIdx.x + half_size] = src_index[1];

  // Build bitonic sequences of growing length.
  for (int size = 2; size < num; size <<= 1) {
    // Neighbouring groups of size / 2 threads sort in opposite directions.
    int bitonic_type = (threadIdx.x & (size / 2)) != 0;
    for (int stride = size / 2; stride > 0; stride >>= 1) {
      // Make the previous step's exchanges visible to every thread.
      __syncthreads();
      int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
      details::ComparatorWithIndex<T, IndexType>(&value[pos],
                                                 &value[pos + stride],
                                                 &index[pos],
                                                 &index[pos + stride],
                                                 bitonic_type);
    }
  }

  // Final merge of the whole buffer in the requested direction.
  for (int stride = half_size; stride > 0; stride >>= 1) {
    __syncthreads();
    int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
    // monotonic_type == 1 yields increasing order.
    details::ComparatorWithIndex<T, IndexType>(&value[pos],
                                               &value[pos + stride],
                                               &index[pos],
                                               &index[pos + stride],
                                               monotonic_type);
  }

  // Wait for the last exchange before reading the results back.
  __syncthreads();
  dst[0] = value[threadIdx.x];
  dst[1] = value[threadIdx.x + half_size];
  dst_index[0] = index[threadIdx.x];
  dst_index[1] = index[threadIdx.x + half_size];
}
606+
484607
} // namespace kps
485608
} // namespace phi

0 commit comments

Comments
 (0)