@@ -85,6 +85,7 @@ struct SamplingTempStorage {
     union {
       T value;
       Pair<T> pair;
+      T max_p;
     } block_aggregate;
   } data;
 };
@@ -447,6 +448,112 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
   }
 }

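+// Min-p sampling kernel: for each row, accept a sampled token only if its probability is at
+// least min_p * max(probs); uses up to max_min_p_rounds rounds of rejection sampling.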
+template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
+          BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, bool DETERMINISTIC,
+          typename DType, typename IdType>
+__global__ void MinPSamplingFromProbKernel(DType* probs, DType* uniform_samples, DType* min_p,
+                                           IdType* output, bool* success, uint32_t d,
+                                           uint32_t max_min_p_rounds) {
+  const uint32_t batch_size = gridDim.x;
+  const uint32_t bx = blockIdx.x, tx = threadIdx.x;
+  DType p = min_p[bx];
+
+  extern __shared__ __align__(
+      alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
+      uint8_t smem_sampling[];
+  auto& temp_storage = reinterpret_cast<
+      SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);
+
+  vec_t<DType, VEC_SIZE> probs_vec;
+  DType aggregate;
+  DType q = DType(1);
+  DType pivot = DType(0);
+
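+  // First pass: block-reduce over the row to find its maximum probability.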
+  DType max_p = 0;
+  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+    probs_vec.fill(DType(0));
+    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+      probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
+    }
+    DType probs_[VEC_SIZE];
+#pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+      probs_[j] = probs_vec[j];
+    }
+    max_p = max(max_p, BlockReduce<DType, BLOCK_THREADS>(temp_storage.block_prim.reduce)
+                           .Reduce<VEC_SIZE>(probs_, cub::Max()));
+    __syncthreads();
+  }
+  if (tx == 0) {
+    temp_storage.data.block_aggregate.max_p = max_p;
+  }
+  __syncthreads();
+  DType scaled_p = temp_storage.data.block_aggregate.max_p * p;
+
+  IdType sampled_id;
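+  // Rejection-sampling rounds: draw a token from the mass above the current pivot and accept it
+  // once the pivot (largest probability sampled so far) clears the min-p threshold.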
+  for (uint32_t round = 0; round < max_min_p_rounds; ++round) {
+    temp_storage.data.sampled_id = d - 1;
+    __syncthreads();
+    DType u = uniform_samples[round * batch_size + bx] * q;
+    aggregate = DType(0);
+    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+      probs_vec.fill(DType(0));
+      if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+        probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
+      }
+
+      DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM,
+                             DETERMINISTIC, DType>(i, d, pivot, u, probs_vec, aggregate,
+                                                   &temp_storage);
+      if (aggregate > u) {
+        break;
+      }
+    }
+    __syncthreads();
+    sampled_id = temp_storage.data.sampled_id;
+    pivot = max(pivot, probs[bx * d + sampled_id]);
+    if (pivot >= scaled_p) {
+      break;
+    }
+
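+    // Recompute the probability mass strictly above the new pivot; it rescales the uniform
+    // sample (q) in the next round.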
+    DType aggregate_gt_pivot = DType(0);
+    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+      probs_vec.fill(DType(0));
+      if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+        probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
+      }
+
+      DType probs_gt_pivot[VEC_SIZE];
+#pragma unroll
+      for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+        probs_gt_pivot[j] = (probs_vec[j] > pivot) ? probs_vec[j] : DType(0);
+      }
+
+      aggregate_gt_pivot += BlockReduce<DType, BLOCK_THREADS>(temp_storage.block_prim.reduce)
+                                .Sum<VEC_SIZE>(probs_gt_pivot);
+      if (tx == 0) {
+        temp_storage.data.block_aggregate.value = aggregate_gt_pivot;
+      }
+      __syncthreads();
+    }
+    q = temp_storage.data.block_aggregate.value;
+  }
+  __syncthreads();
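+  // Thread 0 writes the result; if no token cleared the threshold within max_min_p_rounds,
+  // report failure via the success flag.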
+  if (tx == 0) {
+    if (pivot < scaled_p) {
+      // failed to sample within MAX_ROUNDS
+      if (success != nullptr) {
+        success[bx] = false;
+      }
+    } else {
+      output[bx] = sampled_id;
+      if (success != nullptr) {
+        success[bx] = true;
+      }
+    }
+  }
+}
+
 template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
           BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, bool DETERMINISTIC,
           typename DType, typename IdType>
@@ -636,6 +743,30 @@ cudaError_t TopPSamplingFromProb(T* probs, T* uniform_samples, IdType* output, b
   return cudaSuccess;
 }

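+// Host-side launcher for min-p sampling: picks the vector width from the element size and d,
+// dispatches on determinism, and launches one thread block per batch row.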
+template <typename T, typename IdType>
+cudaError_t MinPSamplingFromProb(T* probs, T* uniform_samples, T* min_p, IdType* output,
+                                 bool* success, uint32_t batch_size, uint32_t d,
+                                 uint32_t max_rounds, bool deterministic, cudaStream_t stream = 0) {
+  constexpr uint32_t BLOCK_THREADS = 1024;
+  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
+
+  const uint32_t smem_size = sizeof(SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
+  dim3 nblks(batch_size);
+  dim3 nthrs(BLOCK_THREADS);
+  void* args[] = {&probs, &uniform_samples, &min_p, &output, &success, &d, &max_rounds};
+
+  DISPATCH_ALIGNED_VEC_SIZE(
+      vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
+        auto kernel = MinPSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO, VEC_SIZE,
+                                                 DETERMINISTIC, T, IdType>;
+        FLASHINFER_CUDA_CALL(
+            cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+        FLASHINFER_CUDA_CALL(
+            cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
+      })});
+  return cudaSuccess;
+}
+
 template <typename T, typename IdType>
 cudaError_t TopKTopPSamplingFromProb(T* probs, T* uniform_samples, IdType* top_k, T* top_p,
                                      IdType* output, bool* success, uint32_t batch_size, uint32_t d,