@@ -472,7 +472,8 @@ X86_SIMD_SORT_INLINE void argsort_(type_t *arr,
472472                                   arrsize_t  *arg,
473473                                   arrsize_t  left,
474474                                   arrsize_t  right,
475-                                    arrsize_t  max_iters)
475+                                    arrsize_t  max_iters,
476+                                    arrsize_t  task_threshold)
476477{
477478    /* 
478479     * Resort to std::sort if quicksort isnt making any progress 
@@ -494,11 +495,57 @@ X86_SIMD_SORT_INLINE void argsort_(type_t *arr,
494495    type_t  biggest = vtype::type_min ();
495496    arrsize_t  pivot_index = argpartition_unrolled<vtype, argtype, 4 >(
496497            arr, arg, left, right + 1 , pivot, &smallest, &biggest);
498+ #ifdef  XSS_COMPILE_OPENMP
499+     if  (pivot != smallest) {
500+         bool  parallel_left = (pivot_index - left) > task_threshold;
501+         if  (parallel_left) {
502+ #pragma  omp task
503+             argsort_<vtype, argtype>(arr,
504+                                      arg,
505+                                      left,
506+                                      pivot_index - 1 ,
507+                                      max_iters - 1 ,
508+                                      task_threshold);
509+         }
510+         else  {
511+             argsort_<vtype, argtype>(arr,
512+                                      arg,
513+                                      left,
514+                                      pivot_index - 1 ,
515+                                      max_iters - 1 ,
516+                                      task_threshold);
517+         }
518+     }
519+     if  (pivot != biggest) {
520+         bool  parallel_right = (right - pivot_index) > task_threshold;
521+ 
522+         if  (parallel_right) {
523+ #pragma  omp task
524+             argsort_<vtype, argtype>(arr,
525+                                      arg,
526+                                      pivot_index,
527+                                      right,
528+                                      max_iters - 1 ,
529+                                      task_threshold);
530+         }
531+         else  {
532+             argsort_<vtype, argtype>(arr,
533+                                      arg,
534+                                      pivot_index,
535+                                      right,
536+                                      max_iters - 1 ,
537+                                      task_threshold);
538+         }
539+     }
540+ #else 
541+     UNUSED (task_threshold);
497542    if  (pivot != smallest)
498543        argsort_<vtype, argtype>(
499-                 arr, arg, left, pivot_index - 1 , max_iters - 1 );
544+                 arr, arg, left, pivot_index - 1 , max_iters - 1 ,  0 );
500545    if  (pivot != biggest)
501-         argsort_<vtype, argtype>(arr, arg, pivot_index, right, max_iters - 1 );
546+         argsort_<vtype, argtype>(
547+                 arr, arg, pivot_index, right, max_iters - 1 , 0 );
548+ #endif 
502549}
503550
504551template  <typename  vtype, typename  argtype, typename  type_t >
@@ -570,8 +617,43 @@ X86_SIMD_SORT_INLINE void xss_argsort(T *arr,
570617            }
571618        }
572619        UNUSED (hasnan);
620+ 
621+ #ifdef  XSS_COMPILE_OPENMP
622+ 
623+         bool  use_parallel = arrsize > 10000 ;
624+ 
625+         if  (use_parallel) {
626+             //  This thread limit was determined experimentally; it may be better for it to be the number of physical cores on the system
627+             constexpr  int  thread_limit = 8 ;
628+             int  thread_count = std::min (thread_limit, omp_get_max_threads ());
629+             arrsize_t  task_threshold
630+                     = std::max ((arrsize_t )100000 , arrsize / 100 );
631+ 
632+             //  We use omp parallel and then omp single to setup the threads that will run the omp task calls in qsort_
633+             //  The omp single prevents multiple threads from running the initial qsort_ simultaneously and causing problems
634+             //  Note that we do not use the if(...) clause built into OpenMP, because it causes a performance regression for small arrays
635+ #pragma  omp parallel num_threads(thread_count)
636+ #pragma  omp single
637+             argsort_<vectype, argtype>(arr,
638+                                        arg,
639+                                        0 ,
640+                                        arrsize - 1 ,
641+                                        2  * (arrsize_t )log2 (arrsize),
642+                                        task_threshold);
643+         }
644+         else  {
645+             argsort_<vectype, argtype>(arr,
646+                                        arg,
647+                                        0 ,
648+                                        arrsize - 1 ,
649+                                        2  * (arrsize_t )log2 (arrsize),
650+                                        std::numeric_limits<arrsize_t >::max ());
651+         }
652+ #pragma  omp taskwait
653+ #else 
573654        argsort_<vectype, argtype>(
574-                 arr, arg, 0 , arrsize - 1 , 2  * (arrsize_t )log2 (arrsize));
655+                 arr, arg, 0 , arrsize - 1 , 2  * (arrsize_t )log2 (arrsize), 0 );
656+ #endif 
575657
576658        if  (descending) { std::reverse (arg, arg + arrsize); }
577659    }
0 commit comments