
Commit d52f2da

xslingcn and yzh119 authored
sampling: support min_p sampling (#422)
This PR supports min_p sampling by adding the `sampling.min_p_sampling_from_probs` API.

- [x] Implement kernel
- [x] Add tests

Ref: [Min P Sampling](https://arxiv.org/abs/2407.01082).

Co-authored-by: Zihao Ye <expye@outlook.com>
1 parent 8e482d9 commit d52f2da

7 files changed: +278 −0 lines changed
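For reference, min-p keeps only the tokens whose probability is at least `min_p` times the largest probability in the row, renormalizes the survivors, and samples from them; the fused kernel added here reaches the same support without sorting, via rejection sampling. A minimal PyTorch sketch of these semantics (an unfused illustration, not the kernel in this commit):

```python
import torch

def min_p_reference(probs: torch.Tensor, min_p: torch.Tensor) -> torch.Tensor:
    """Unfused reference of min-p sampling semantics.

    probs: (batch_size, vocab_size), rows sum to 1; min_p: (batch_size,) in [0, 1].
    """
    top_prob = probs.max(dim=-1, keepdim=True).values           # p_max per row
    keep = probs >= min_p.unsqueeze(-1) * top_prob               # min-p cut: p_i >= p_base * p_max
    masked = torch.where(keep, probs, torch.zeros_like(probs))   # zero out filtered tokens
    masked = masked / masked.sum(dim=-1, keepdim=True)           # renormalize survivors
    return torch.multinomial(masked, num_samples=1).squeeze(-1)
```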

docs/api/python/sampling.rst

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ Kernels for LLM sampling.
    sampling_from_probs
    top_p_sampling_from_probs
    top_k_sampling_from_probs
+   min_p_sampling_from_probs
    top_k_top_p_sampling_from_probs
    top_p_renorm_prob
    top_k_renorm_prob

include/flashinfer/sampling.cuh

Lines changed: 131 additions & 0 deletions
@@ -85,6 +85,7 @@ struct SamplingTempStorage {
     union {
       T value;
       Pair<T> pair;
+      T max_p;
     } block_aggregate;
   } data;
 };
@@ -447,6 +448,112 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
   }
 }

+template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
+          BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, bool DETERMINISTIC,
+          typename DType, typename IdType>
+__global__ void MinPSamplingFromProbKernel(DType* probs, DType* uniform_samples, DType* min_p,
+                                           IdType* output, bool* success, uint32_t d,
+                                           uint32_t max_min_p_rounds) {
+  const uint32_t batch_size = gridDim.x;
+  const uint32_t bx = blockIdx.x, tx = threadIdx.x;
+  DType p = min_p[bx];
+
+  extern __shared__ __align__(
+      alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
+      uint8_t smem_sampling[];
+  auto& temp_storage = reinterpret_cast<
+      SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);
+
+  vec_t<DType, VEC_SIZE> probs_vec;
+  DType aggregate;
+  DType q = DType(1);
+  DType pivot = DType(0);
+
+  DType max_p = 0;
+  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+    probs_vec.fill(DType(0));
+    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+      probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
+    }
+    DType probs_[VEC_SIZE];
+#pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+      probs_[j] = probs_vec[j];
+    }
+    max_p = max(max_p, BlockReduce<DType, BLOCK_THREADS>(temp_storage.block_prim.reduce)
+                           .Reduce<VEC_SIZE>(probs_, cub::Max()));
+    __syncthreads();
+  }
+  if (tx == 0) {
+    temp_storage.data.block_aggregate.max_p = max_p;
+  }
+  __syncthreads();
+  DType scaled_p = temp_storage.data.block_aggregate.max_p * p;
+
+  IdType sampled_id;
+  for (uint32_t round = 0; round < max_min_p_rounds; ++round) {
+    temp_storage.data.sampled_id = d - 1;
+    __syncthreads();
+    DType u = uniform_samples[round * batch_size + bx] * q;
+    aggregate = DType(0);
+    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+      probs_vec.fill(DType(0));
+      if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+        probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
+      }
+
+      DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM,
+                             DETERMINISTIC, DType>(i, d, pivot, u, probs_vec, aggregate,
+                                                   &temp_storage);
+      if (aggregate > u) {
+        break;
+      }
+    }
+    __syncthreads();
+    sampled_id = temp_storage.data.sampled_id;
+    pivot = max(pivot, probs[bx * d + sampled_id]);
+    if (pivot >= scaled_p) {
+      break;
+    }
+
+    DType aggregate_gt_pivot = DType(0);
+    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+      probs_vec.fill(DType(0));
+      if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+        probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
+      }
+
+      DType probs_gt_pivot[VEC_SIZE];
+#pragma unroll
+      for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+        probs_gt_pivot[j] = (probs_vec[j] > pivot) ? probs_vec[j] : DType(0);
+      }
+
+      aggregate_gt_pivot += BlockReduce<DType, BLOCK_THREADS>(temp_storage.block_prim.reduce)
+                                .Sum<VEC_SIZE>(probs_gt_pivot);
+      if (tx == 0) {
+        temp_storage.data.block_aggregate.value = aggregate_gt_pivot;
+      }
+      __syncthreads();
+    }
+    q = temp_storage.data.block_aggregate.value;
+  }
+  __syncthreads();
+  if (tx == 0) {
+    if (pivot < scaled_p) {
+      // failed to sample within MAX_ROUNDS
+      if (success != nullptr) {
+        success[bx] = false;
+      }
+    } else {
+      output[bx] = sampled_id;
+      if (success != nullptr) {
+        success[bx] = true;
+      }
+    }
+  }
+}
+
 template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
           BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, bool DETERMINISTIC,
           typename DType, typename IdType>
@@ -636,6 +743,30 @@ cudaError_t TopPSamplingFromProb(T* probs, T* uniform_samples, IdType* output, b
   return cudaSuccess;
 }

+template <typename T, typename IdType>
+cudaError_t MinPSamplingFromProb(T* probs, T* uniform_samples, T* min_p, IdType* output,
+                                 bool* success, uint32_t batch_size, uint32_t d,
+                                 uint32_t max_rounds, bool deterministic, cudaStream_t stream = 0) {
+  constexpr uint32_t BLOCK_THREADS = 1024;
+  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
+
+  const uint32_t smem_size = sizeof(SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
+  dim3 nblks(batch_size);
+  dim3 nthrs(BLOCK_THREADS);
+  void* args[] = {&probs, &uniform_samples, &min_p, &output, &success, &d, &max_rounds};
+
+  DISPATCH_ALIGNED_VEC_SIZE(
+      vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
+        auto kernel = MinPSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO, VEC_SIZE,
+                                                 DETERMINISTIC, T, IdType>;
+        FLASHINFER_CUDA_CALL(
+            cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+        FLASHINFER_CUDA_CALL(
+            cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
+      })});
+  return cudaSuccess;
+}
+
 template <typename T, typename IdType>
 cudaError_t TopKTopPSamplingFromProb(T* probs, T* uniform_samples, IdType* top_k, T* top_p,
                                      IdType* output, bool* success, uint32_t batch_size, uint32_t d,
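For intuition about the kernel above: rather than sorting, it runs up to `max_min_p_rounds` of rejection sampling per row, tracking `pivot` (the largest probability sampled so far) and `q` (the probability mass still above the pivot), and accepts as soon as the sampled token's probability reaches `scaled_p = min_p * max_p`. A rough per-row, host-side sketch in PyTorch (a simplification for illustration only; the kernel fuses these passes with block reductions and a shared-memory scan):

```python
import torch

def min_p_reject_sample_row(probs_row, uniform_row, min_p, max_rounds):
    # probs_row: (vocab_size,) normalized probabilities; uniform_row: (max_rounds,) in [0, 1)
    scaled_p = float(min_p * probs_row.max())  # acceptance threshold p_base * p_max
    pivot, q = 0.0, 1.0                        # q = probability mass of tokens above the pivot
    sampled_id, ok = probs_row.numel() - 1, False
    for r in range(max_rounds):
        u = float(uniform_row[r]) * q
        masked = torch.where(probs_row > pivot, probs_row, torch.zeros_like(probs_row))
        cdf = torch.cumsum(masked, dim=0)
        sampled_id = int(torch.searchsorted(cdf, cdf.new_tensor(u)).clamp(max=probs_row.numel() - 1))
        pivot = max(pivot, float(probs_row[sampled_id]))
        if pivot >= scaled_p:                  # the sampled token already satisfies the min-p cut
            ok = True
            break
        q = float(torch.where(probs_row > pivot, probs_row, torch.zeros_like(probs_row)).sum())
    return sampled_id, ok                      # ok=False mirrors success[bx] = false after max_rounds
```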

python/csrc/flashinfer_ops.cu

Lines changed: 2 additions & 0 deletions
@@ -26,6 +26,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("sampling_from_probs", &sampling_from_probs, "Sample from probabilities");
   m.def("top_k_sampling_from_probs", &top_k_sampling_from_probs,
         "Top-k sampling from probabilities");
+  m.def("min_p_sampling_from_probs", &min_p_sampling_from_probs,
+        "Min-p sampling from probabilities");
   m.def("top_p_sampling_from_probs", &top_p_sampling_from_probs,
         "Top-p sampling from probabilities");
   m.def("top_k_top_p_sampling_from_probs", &top_k_top_p_sampling_from_probs,

python/csrc/flashinfer_ops.h

Lines changed: 4 additions & 0 deletions
@@ -46,6 +46,10 @@ std::vector<torch::Tensor> top_k_sampling_from_probs(torch::Tensor probs,
                                                      torch::Tensor uniform_samples,
                                                      unsigned int top_k, bool deterministic);

+std::vector<torch::Tensor> min_p_sampling_from_probs(torch::Tensor probs,
+                                                     torch::Tensor uniform_samples,
+                                                     torch::Tensor min_p, bool deterministic);
+
 std::vector<torch::Tensor> top_k_top_p_sampling_from_probs(torch::Tensor probs,
                                                            torch::Tensor uniform_samples,
                                                            torch::Tensor top_k, torch::Tensor top_p,

python/csrc/sampling.cu

Lines changed: 36 additions & 0 deletions
@@ -106,6 +106,42 @@ std::vector<torch::Tensor> top_k_sampling_from_probs(torch::Tensor probs,
   return {samples, success};
 }

+std::vector<torch::Tensor> min_p_sampling_from_probs(torch::Tensor probs,
+                                                     torch::Tensor uniform_samples,
+                                                     torch::Tensor min_p, bool deterministic) {
+  CHECK_INPUT(probs);
+  CHECK_INPUT(uniform_samples);
+  CHECK_INPUT(min_p);
+  auto device = probs.device();
+  CHECK_EQ(uniform_samples.device(), device);
+  CHECK_EQ(min_p.device(), device);
+  CHECK_DIM(2, probs);            // probs: (batch_size, vocab_size)
+  CHECK_DIM(2, uniform_samples);  // uniform_samples: (max_rounds, batch_size)
+  CHECK_DIM(1, min_p);            // min_p: (batch_size,)
+  unsigned int batch_size = probs.size(0);
+  unsigned int vocab_size = probs.size(1);
+  unsigned int max_rounds = uniform_samples.size(0);
+  CHECK_EQ(uniform_samples.size(1), batch_size);
+  CHECK_EQ(min_p.size(0), batch_size);
+  probs = probs.to(torch::kFloat32);
+  uniform_samples = uniform_samples.to(torch::kFloat32);
+  min_p = min_p.to(torch::kFloat32);
+
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
+  auto samples = torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device));
+  auto success = torch::empty({batch_size}, torch::dtype(torch::kBool).device(device));
+
+  cudaError_t status = sampling::MinPSamplingFromProb<float, int>(
+      static_cast<float*>(probs.data_ptr()), static_cast<float*>(uniform_samples.data_ptr()),
+      static_cast<float*>(min_p.data_ptr()), static_cast<int*>(samples.data_ptr()),
+      static_cast<bool*>(success.data_ptr()), batch_size, vocab_size, max_rounds, deterministic,
+      torch_current_stream);
+  TORCH_CHECK(status == cudaSuccess, "MinPSamplingFromProb failed with error code " +
+                                         std::string(cudaGetErrorString(status)));
+
+  return {samples, success};
+}
+
 std::vector<torch::Tensor> top_k_top_p_sampling_from_probs(torch::Tensor probs,
                                                            torch::Tensor uniform_samples,
                                                            torch::Tensor top_k, torch::Tensor top_p,

python/flashinfer/sampling.py

Lines changed: 70 additions & 0 deletions
@@ -214,6 +214,76 @@ def top_k_sampling_from_probs(
     )


+def min_p_sampling_from_probs(
+    probs: torch.Tensor,
+    uniform_samples: torch.Tensor,
+    min_p: torch.Tensor,
+    deterministic: bool = True,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    r"""Fused GPU kernel for `min_p sampling <https://arxiv.org/abs/2407.01082>`_ from probabilities.
+
+    This operator implements GPU-based rejection sampling without explicit sorting.
+
+    The multiple rounds of rejection sampling are implemented in a single CUDA kernel,
+    which is more efficient than a naive implementation that launches a series of kernels.
+
+    Parameters
+    ----------
+    probs: torch.Tensor
+        Probabilities, shape ``(batch_size, num_classes)``.
+    uniform_samples: torch.Tensor
+        The uniform samples used as needle for sampling, shape ``(max_rounds, batch_size)``,
+        where the first dimension is the maximum number of rounds for rejection sampling.
+        Expected to be uniformly distributed in ``[0, 1)``.
+    min_p: torch.Tensor
+        The :math:`p_{\text{base}}` in min_p sampling for each request, shape ``(batch_size,)``.
+    deterministic: bool
+        Whether to use the deterministic kernel implementation, default is ``True``.
+
+    Returns
+    -------
+    samples: torch.Tensor
+        Sampled categories, shape ``(batch_size,)``.
+    success: torch.Tensor
+        Whether the sampling is successful within ``max_rounds`` rounds,
+        shape ``(batch_size,)``.
+
+    Examples
+    --------
+
+    >>> import torch
+    >>> import flashinfer
+    >>> torch.manual_seed(42)
+    <torch._C.Generator object at 0x7f8b3db06df0>
+    >>> batch_size = 4
+    >>> vocab_size = 5
+    >>> max_rounds = 3
+    >>> min_p = torch.full((batch_size,), 0.05).to(0)
+    >>> pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
+    >>> norm_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
+    >>> norm_prob
+    tensor([[0.2499, 0.2592, 0.1085, 0.2718, 0.1106],
+            [0.2205, 0.0942, 0.2912, 0.3452, 0.0489],
+            [0.2522, 0.1602, 0.2346, 0.1532, 0.2000],
+            [0.1543, 0.3182, 0.2062, 0.0958, 0.2255]], device='cuda:0')
+    >>> uniform_samples = torch.rand(max_rounds, batch_size).to(0)
+    >>> samples, success = flashinfer.sampling.min_p_sampling_from_probs(norm_prob, uniform_samples, min_p)
+    >>> samples
+    tensor([1, 2, 1, 4], device='cuda:0', dtype=torch.int32)
+    >>> success
+    tensor([True, True, True, True], device='cuda:0')
+
+    Notes
+    -----
+    This function expects float32 inputs, and the output is int32.
+    We encourage users to set ``max_rounds`` to a reasonable value, e.g., 32. The actual
+    implementation usually uses far fewer rounds for rejection sampling because of early stopping.
+    """
+    return _kernels.min_p_sampling_from_probs(
+        probs, uniform_samples, min_p, deterministic
+    )
+
+
 def top_k_top_p_sampling_from_probs(
     probs: torch.Tensor,
     uniform_samples: torch.Tensor,
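Because the kernel reports per-row failure once ``max_rounds`` is exhausted, callers may want a fallback path. One possible pattern (a hypothetical, unfused fallback, not part of this PR): apply the min-p filter directly to the failed rows and resample them with ``torch.multinomial``.

```python
import torch
import flashinfer

def sample_min_p_with_fallback(probs, min_p, max_rounds=32):
    # probs: (batch_size, vocab_size) float32 on GPU, rows sum to 1; min_p: (batch_size,) float32 on GPU
    uniform_samples = torch.rand(max_rounds, probs.size(0), device=probs.device)
    samples, success = flashinfer.sampling.min_p_sampling_from_probs(
        probs, uniform_samples, min_p
    )
    if not bool(success.all()):
        # Hypothetical fallback: filter with min-p directly and resample only the failed rows.
        failed = (~success).nonzero(as_tuple=True)[0]
        p = probs[failed]
        keep = p >= min_p[failed].unsqueeze(-1) * p.max(dim=-1, keepdim=True).values
        p = torch.where(keep, p, torch.zeros_like(p))
        p = p / p.sum(dim=-1, keepdim=True)
        samples[failed] = torch.multinomial(p, num_samples=1).squeeze(-1).to(samples.dtype)
    return samples
```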

python/tests/test_sampling.py

Lines changed: 34 additions & 0 deletions
@@ -95,6 +95,40 @@ def test_top_k_sampling(batch_size, vocab_size, k):
     ]


+@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
+@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
+@pytest.mark.parametrize("p", [0.05, 0.1, 0.2, 0.7, 1])
+def test_min_p_sampling(batch_size, vocab_size, p):
+    torch.manual_seed(42)
+    max_min_p_trails = 32
+    pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
+    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
+    sorted_prob, indices = torch.sort(normalized_prob, descending=False)
+    # scale min-p
+    top_probs = sorted_prob[:, -1].unsqueeze(-1)
+    scaled_p = p * top_probs
+    # min-p mask
+    mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0)
+    mask.scatter_add_(1, indices, (sorted_prob >= scaled_p).int())
+
+    uniform_samples = torch.empty(max_min_p_trails, batch_size, dtype=torch.float32).to(
+        0
+    )
+    min_p_tensor = torch.full((batch_size,), p).to(0)
+
+    num_trails = 1000
+    for _ in range(num_trails):
+        uniform_samples.uniform_()
+        samples, success = flashinfer.sampling.min_p_sampling_from_probs(
+            normalized_prob, uniform_samples, min_p_tensor
+        )
+        assert torch.all(success)
+        assert torch.all(samples < vocab_size) and torch.all(samples >= 0)
+        assert torch.all(mask[torch.arange(batch_size), samples] == 1), normalized_prob[
+            torch.arange(batch_size), samples
+        ]
+
+
 @pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
 @pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
 @pytest.mark.parametrize("p", [0.1, 0.5])