@@ -541,8 +541,11 @@ X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr,
541541
542542/* argsort methods for 32-bit and 64-bit dtypes */
543543template <typename T>
544- X86_SIMD_SORT_INLINE void
545- avx512_argsort (T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false )
544+ X86_SIMD_SORT_INLINE void avx512_argsort (T *arr,
545+ arrsize_t *arg,
546+ arrsize_t arrsize,
547+ bool hasnan = false ,
548+ bool descending = false )
546549{
547550 /* TODO optimization: on 32-bit, use zmm_vector for 32-bit dtype */
548551 using vectype = typename std::conditional<sizeof (T) == sizeof (int32_t ),
@@ -558,29 +561,37 @@ avx512_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false)
558561 if constexpr (std::is_floating_point_v<T>) {
559562 if ((hasnan) && (array_has_nan<vectype>(arr, arrsize))) {
560563 std_argsort_withnan (arr, arg, 0 , arrsize);
564+
565+ if (descending) { std::reverse (arg, arg + arrsize); }
566+
561567 return ;
562568 }
563569 }
564570 UNUSED (hasnan);
565571 argsort_64bit_<vectype, argtype>(
566572 arr, arg, 0 , arrsize - 1 , 2 * (arrsize_t )log2 (arrsize));
573+
574+ if (descending) { std::reverse (arg, arg + arrsize); }
567575 }
568576}
569577
570578template <typename T>
571- X86_SIMD_SORT_INLINE std::vector<arrsize_t >
572- avx512_argsort ( T *arr, arrsize_t arrsize, bool hasnan = false )
579+ X86_SIMD_SORT_INLINE std::vector<arrsize_t > avx512_argsort (
580+ T *arr, arrsize_t arrsize, bool hasnan = false , bool descending = false )
573581{
574582 std::vector<arrsize_t > indices (arrsize);
575583 std::iota (indices.begin (), indices.end (), 0 );
576- avx512_argsort<T>(arr, indices.data (), arrsize, hasnan);
584+ avx512_argsort<T>(arr, indices.data (), arrsize, hasnan, descending );
577585 return indices;
578586}
579587
580588/* argsort methods for 32-bit and 64-bit dtypes */
581589template <typename T>
582- X86_SIMD_SORT_INLINE void
583- avx2_argsort (T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false )
590+ X86_SIMD_SORT_INLINE void avx2_argsort (T *arr,
591+ arrsize_t *arg,
592+ arrsize_t arrsize,
593+ bool hasnan = false ,
594+ bool descending = false )
584595{
585596 using vectype = typename std::conditional<sizeof (T) == sizeof (int32_t ),
586597 avx2_half_vector<T>,
@@ -594,22 +605,27 @@ avx2_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false)
594605 if constexpr (std::is_floating_point_v<T>) {
595606 if ((hasnan) && (array_has_nan<vectype>(arr, arrsize))) {
596607 std_argsort_withnan (arr, arg, 0 , arrsize);
608+
609+ if (descending) { std::reverse (arg, arg + arrsize); }
610+
597611 return ;
598612 }
599613 }
600614 UNUSED (hasnan);
601615 argsort_64bit_<vectype, argtype>(
602616 arr, arg, 0 , arrsize - 1 , 2 * (arrsize_t )log2 (arrsize));
617+
618+ if (descending) { std::reverse (arg, arg + arrsize); }
603619 }
604620}
605621
606622template <typename T>
607- X86_SIMD_SORT_INLINE std::vector<arrsize_t >
608- avx2_argsort ( T *arr, arrsize_t arrsize, bool hasnan = false )
623+ X86_SIMD_SORT_INLINE std::vector<arrsize_t > avx2_argsort (
624+ T *arr, arrsize_t arrsize, bool hasnan = false , bool descending = false )
609625{
610626 std::vector<arrsize_t > indices (arrsize);
611627 std::iota (indices.begin (), indices.end (), 0 );
612- avx2_argsort<T>(arr, indices.data (), arrsize, hasnan);
628+ avx2_argsort<T>(arr, indices.data (), arrsize, hasnan, descending );
613629 return indices;
614630}
615631
0 commit comments