support min_p sampling parameter (#2420)
* support pytorch backend min_p

* support turbomind backend min_p

* fix comments

* remove unused header

* remove inplace

* use _filter_minp_sorted_

* remove end_ids from sampling

* use larger grid size in invokeTopPSortInitialize

* skip softmax for topk request

* use const

* use nullptr

* use eps

* fix pr test
irexyc authored Sep 9, 2024
1 parent 659a6b0 commit 8e478d4
Showing 36 changed files with 2,002 additions and 4,926 deletions.
3 changes: 0 additions & 3 deletions CMakeLists.txt
@@ -330,12 +330,9 @@ endif()
########################################

add_library(transformer-shared SHARED
$<TARGET_OBJECTS:BaseSamplingLayer>
$<TARGET_OBJECTS:DynamicDecodeLayer>
$<TARGET_OBJECTS:Llama>
$<TARGET_OBJECTS:LlamaTritonBackend>
$<TARGET_OBJECTS:TopKSamplingLayer>
$<TARGET_OBJECTS:TopPSamplingLayer>
$<TARGET_OBJECTS:TransformerTritonBackend>
$<TARGET_OBJECTS:activation_kernels>
$<TARGET_OBJECTS:ban_bad_words>
8 changes: 8 additions & 0 deletions lmdeploy/messages.py
@@ -30,6 +30,11 @@ class GenerationConfig:
tokens with top_p probability mass
top_k (int): An alternative to sampling with temperature, where
the model considers the top_k tokens with the highest probability
min_p (float): Minimum token probability, which will be scaled by the
probability of the most likely token. It must be a value between
0 and 1. Typical values are in the 0.01-0.2 range; this is comparably
selective to setting `top_p` in the 0.99-0.8 range (note that these
values run opposite to normal `top_p` values)
temperature (float): Sampling temperature
repetition_penalty (float): Penalty to prevent the model from
generating repeated words or phrases. A value larger than
@@ -59,6 +64,7 @@ class GenerationConfig:
do_sample: bool = False
top_p: float = 1.0
top_k: int = 50
min_p: float = 0.0
temperature: float = 0.8
repetition_penalty: float = 1.0
ignore_eos: bool = False
@@ -102,6 +108,8 @@ def __post_init__(self):
assert self.top_p > 0 and self.top_p <= 1 # (0, 1]
assert self.top_k >= 0, 'top_k can not be a negative integer'
assert self.temperature >= 0 and self.temperature <= 2 # [0,2]
assert 0 <= self.min_p <= 1, \
f'min_p should be in range [0, 1], but found {self.min_p}'


@pydantic_dataclass
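To make the new parameter concrete: with min_p = 0.1 and a most-likely-token probability of 0.5, every candidate whose probability falls below 0.1 * 0.5 = 0.05 is dropped. A minimal, self-contained sketch (illustrative only, not code from this commit):

import torch

probs = torch.tensor([0.50, 0.20, 0.15, 0.08, 0.04, 0.03])
min_p = 0.1
threshold = min_p * probs.max()         # 0.1 * 0.50 = 0.05
kept = probs >= threshold               # keeps 0.50, 0.20, 0.15, 0.08
print(probs[kept] / probs[kept].sum())  # renormalized distribution to sample from
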
31 changes: 31 additions & 0 deletions lmdeploy/pytorch/engine/logits_process.py
@@ -67,6 +67,18 @@ def _filter_topp_sorted_(scores: torch.Tensor,
return scores


def _filter_minp_sorted_(scores: torch.Tensor,
minp: torch.Tensor,
filter_value: float = -float('inf')):
"""filter minp on sorted scores."""
softmax_scores = scores.softmax(-1)
top_probs, _ = softmax_scores.max(dim=-1, keepdim=True)
scaled_min_p = minp.unsqueeze(dim=1) * top_probs
mask = softmax_scores < scaled_min_p
scores.masked_fill_(mask, filter_value)
return scores


def _multinomial_sampling(scores: torch.Tensor,
seeds: torch.LongTensor,
offsets: torch.LongTensor,
@@ -118,6 +130,7 @@ class SamplingInputs:
repetition_penalty: torch.Tensor = None
top_k: torch.LongTensor = None
top_p: torch.Tensor = None
min_p: torch.Tensor = None
random_seeds: int = None
random_offsets: int = None
max_top_k: int = 1
@@ -133,6 +146,7 @@ def from_sampling_params(cls, seqs: List[SchedulerSequence]):
repetition_penalty = [None] * batch_size
top_k = [None] * batch_size
top_p = [None] * batch_size
min_p = [None] * batch_size
bad_words = [None] * batch_size
stop_words = [None] * batch_size
random_seeds = [torch.seed() & 0xffffffff] * batch_size
@@ -148,6 +162,7 @@ def __gather_params():
repetition_penalty[idx] = param.repetition_penalty
top_k[idx] = param.top_k
top_p[idx] = param.top_p
min_p[idx] = param.min_p
random_offsets[idx] = seq.random_offsets
response_formats[idx] = param.response_format
if param.random_seed is not None:
@@ -171,6 +186,15 @@ def __get_topp(top_p):
top_p = torch.tensor(top_p)
return top_p, min_top_p

def __get_minp(min_p):
"""get minp."""
max_min_p = max(min_p)
if max_min_p == 0.0:
min_p = None
else:
min_p = torch.Tensor(min_p)
return min_p

def __get_bad_words(bad_words):
"""get bad words."""
max_bw_len = max(len(bw) for bw in bad_words)
@@ -205,11 +229,13 @@ def __get_bad_words(bad_words):
if max_top_k == 1:
top_k = None
top_p, min_top_p = None, 1.0
min_p = None
random_seeds = None
random_offsets = None
else:
top_k = torch.tensor(top_k)
top_p, min_top_p = __get_topp(top_p)
min_p = __get_minp(min_p)
random_seeds = torch.tensor(random_seeds)
random_offsets = torch.tensor(random_offsets)

@@ -220,6 +246,7 @@ def __get_bad_words(bad_words):
repetition_penalty=repetition_penalty,
top_k=top_k,
top_p=top_p,
min_p=min_p,
random_seeds=random_seeds,
random_offsets=random_offsets,
response_formats=tuple(response_formats),
@@ -331,6 +358,10 @@ def __random_sampling(scores: torch.Tensor, indices: torch.LongTensor):
if top_p is not None:
scores = _filter_topp_sorted_(scores, top_p)

min_p = sampling_inputs.min_p
if min_p is not None:
scores = _filter_minp_sorted_(scores, min_p)

softmax_scores = scores.softmax(1)

softmax_scores = softmax_scores[:, :max_topk]
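The filter slots into __random_sampling after the top-k/top-p filters and before the final softmax. Below is a standalone sketch of that path; the filter body mirrors _filter_minp_sorted_ above, while the driver (tensor shapes, the per-request min_p values, and the multinomial draw) is an assumption for illustration:

import torch

def filter_minp_sorted(scores, minp, filter_value=-float('inf')):
    softmax_scores = scores.softmax(-1)
    top_probs, _ = softmax_scores.max(dim=-1, keepdim=True)
    scaled_min_p = minp.unsqueeze(dim=1) * top_probs
    return scores.masked_fill(softmax_scores < scaled_min_p, filter_value)

logits = torch.randn(2, 32000)                     # [batch, vocab]
sorted_logits, indices = logits.sort(-1, descending=True)
min_p = torch.tensor([0.05, 0.0])                  # per-request min_p; 0.0 disables
filtered = filter_minp_sorted(sorted_logits, min_p)
probs = filtered.softmax(-1)
sampled = torch.multinomial(probs, 1)              # positions in sorted order
next_token = indices.gather(-1, sampled)           # map back to vocab ids
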
7 changes: 7 additions & 0 deletions lmdeploy/pytorch/messages.py
@@ -38,6 +38,7 @@ class SamplingParam:
"""Sampling parameter."""
top_p: float = 1.0
top_k: int = 1
min_p: float = 0.0
temperature: float = 0.8
repetition_penalty: float = 1.0
ignore_eos: bool = False
@@ -62,6 +63,7 @@ def from_gen_config(self, gen_config: GenerationConfig):

top_k = gen_config.top_k
top_p = gen_config.top_p
min_p = gen_config.min_p
temperature = gen_config.temperature
repetition_penalty = gen_config.repetition_penalty
max_new_tokens = gen_config.max_new_tokens
@@ -71,6 +73,10 @@ def from_gen_config(self, gen_config: GenerationConfig):
logger.warning('`top_p` has to be a float > 0 and < 1'
f' but is {top_p}')
top_p = 1.0
if min_p < 0 or min_p > 1.0:
logger.warning('`min_p` has to be a float in range [0, 1]'
f' but is {min_p}')
min_p = 0.0
if temperature == 0:
logger.warning('`temperature` is 0, set to 1e-6')
temperature = 1e-6
@@ -93,6 +99,7 @@ def from_gen_config(self, gen_config: GenerationConfig):
min_new_tokens = 0
return SamplingParam(top_p=top_p,
top_k=top_k,
min_p=min_p,
temperature=temperature,
repetition_penalty=repetition_penalty,
ignore_eos=gen_config.ignore_eos,
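End to end, the new field travels from GenerationConfig into SamplingParam via the validation above. A usage sketch (the import path is lmdeploy's public API; the rest reflects this diff):

from lmdeploy import GenerationConfig

# 0 <= min_p <= 1 passes the new __post_init__ assert and is forwarded
# to SamplingParam.from_gen_config unchanged.
gen_config = GenerationConfig(do_sample=True, top_k=50, top_p=0.95, min_p=0.05)

# A value outside [0, 1] would already fail the assert in GenerationConfig;
# if one reached from_gen_config anyway, the warning above fires and min_p
# falls back to 0.0, disabling the filter.
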
5 changes: 4 additions & 1 deletion lmdeploy/turbomind/turbomind.py
@@ -529,6 +529,7 @@ def _broadcast_np(data, dtype, shape=(batch_size, )):
dtype=np.uint32),
runtime_top_k=_broadcast_np(gen_config.top_k, np.uint32),
runtime_top_p=_broadcast_np(gen_config.top_p, np.float32),
runtime_min_p=_broadcast_np(gen_config.min_p, np.float32),
temperature=_broadcast_np(gen_config.temperature, np.float32),
repetition_penalty=_broadcast_np(gen_config.repetition_penalty,
np.float32),
@@ -564,7 +565,9 @@ def _broadcast_np(data, dtype, shape=(batch_size, )):
stop_words = None
bad_words.append(self.eos_id)
else:
stop_words = gen_config.stop_token_ids
stop_words = gen_config.stop_token_ids or []
if self.eos_id not in stop_words:
stop_words.append(self.eos_id)
stop_words = _construct_stop_or_bad_words(stop_words)
bad_words = _construct_stop_or_bad_words(bad_words)

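For context, _broadcast_np expands a scalar generation parameter to one value per request before it is handed to the TurboMind runtime. Its body is not part of this diff; a sketch inferred from the call sites:

import numpy as np

def _broadcast_np(data, dtype, shape=(4,)):       # batch_size = 4 in this sketch
    return np.full(shape, data, dtype=dtype)

runtime_min_p = _broadcast_np(0.05, np.float32)   # four 0.05 entries, one per request
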
4 changes: 4 additions & 0 deletions src/turbomind/kernels/CMakeLists.txt
@@ -58,6 +58,10 @@ add_library(custom_ar_kernels STATIC custom_ar_kernels.cu)
set_property(TARGET custom_ar_kernels PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET custom_ar_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)

add_library(sampling_kernels STATIC sampling_kernels.cu)
set_property(TARGET sampling_kernels PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET sampling_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)

if (BUILD_TEST)
add_subdirectory(flash_attention)
endif ()
102 changes: 102 additions & 0 deletions src/turbomind/kernels/sampling_kernels.cu
@@ -0,0 +1,102 @@
#ifndef CUDART_VERSION
#error CUDART_VERSION Undefined!
#elif (CUDART_VERSION >= 11000)
#include <cub/cub.cuh>
#else
#include "3rdparty/cub/cub.cuh"
#endif
#include "src/turbomind/kernels/sampling_kernels.h"
#include "src/turbomind/kernels/sampling_topp_kernels.h"
#include "src/turbomind/utils/constant.h"

namespace turbomind {

template<typename T, int BLOCK_SIZE>
__global__ void sampling(const T* logits,
const int stride,
const int* indices,
const int* kept,
curandState_t* curandstate,
int* output_ids,
int* sequence_length,
float* sampled_logprobs,
uint32_t* sampled_indexes,
uint32_t* sampled_nums)
{
int tid = threadIdx.x;
int batch_id = blockIdx.x;
int n = kept[batch_id];

logits += stride * batch_id;
indices += stride * batch_id;

__shared__ float rand_num_s;
__shared__ int selected;
if (tid == 0) {
rand_num_s = curand_uniform(curandstate + batch_id);
}
__syncthreads();

typedef cub::BlockScan<float, BLOCK_SIZE> BlockScan;
__shared__ typename BlockScan::TempStorage temp_storage;

float local_rand = rand_num_s;
float prefix_sum = 0.f;
BlockPrefixCallbackOp prefix_op{0};
int end = (n + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE;
for (int i = tid; i < end; i += BLOCK_SIZE) {
float thread_logit = (i < n) ? static_cast<float>(logits[i]) : 0.f;
BlockScan(temp_storage).InclusiveSum(thread_logit, prefix_sum, prefix_op);
auto count = __syncthreads_count(prefix_sum > local_rand);
if (count != 0 || (i + BLOCK_SIZE) >= end) {
if (tid == min(BLOCK_SIZE - count, BLOCK_SIZE - 1)) {
selected = min(i, n - 1);
output_ids[batch_id] = indices[selected];

if (sequence_length != nullptr) {
sequence_length[batch_id] += 1;
}
}
break;
}
}

if (sampled_logprobs != nullptr && sampled_indexes != nullptr && sampled_nums != nullptr) {
__syncthreads();
sampled_logprobs += batch_id * kMaxLogProb;
sampled_indexes += batch_id * kMaxLogProb;
int end = min(n, kMaxLogProb);
for (int i = tid; i < end; i += BLOCK_SIZE) {
sampled_logprobs[i] = logf(logits[i]);
sampled_indexes[i] = indices[i];
}
if (n > kMaxLogProb && selected >= kMaxLogProb) {
if ((kMaxLogProb - 1 + BLOCK_SIZE - tid) % BLOCK_SIZE == 0) {
sampled_logprobs[kMaxLogProb - 1] = logf(logits[selected]);
sampled_indexes[kMaxLogProb - 1] = indices[selected];
}
}
sampled_nums[batch_id] = min(n, kMaxLogProb);
}
}

template<typename T>
void invokeSampling(SamplingParams& params, cudaStream_t stream)
{
const int grid = params.batch_size;
const int block = 256;
sampling<T, block><<<grid, block, 0, stream>>>((T*)params.logits,
params.stride,
params.indices,
params.kept,
params.curandstate,
params.output_ids,
params.sequence_length,
params.sampled_logprobs,
params.sampled_indexes,
params.sampled_nums);
}

template void invokeSampling<float>(SamplingParams& params, cudaStream_t stream);

} // namespace turbomind
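Restated in PyTorch for readability: the kernel draws one token per request by an inclusive prefix sum (cub::BlockScan) over the surviving candidates' probabilities, selecting the first position whose prefix sum exceeds a uniform draw. This is a sketch of the math only, not the kernel's host API; probs, indices and kept stand in for params.logits, params.indices and params.kept:

import torch

def sample_like_kernel(probs, indices, kept):
    """probs/indices: [batch, stride]; kept[b] = live candidates for request b."""
    out = torch.empty(len(kept), dtype=torch.long)
    for b, n in enumerate(kept):
        cdf = probs[b, :n].cumsum(0)                       # InclusiveSum in the kernel
        r = torch.rand(1)                                  # curand_uniform per request
        sel = int(torch.searchsorted(cdf, r, right=True))  # first prefix sum > r
        sel = min(sel, n - 1)                              # clamp, as the kernel does
        out[b] = indices[b, sel]
    return out
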
42 changes: 42 additions & 0 deletions src/turbomind/kernels/sampling_kernels.h
@@ -0,0 +1,42 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cuda_runtime.h>
#include <curand_kernel.h>
#include <stdint.h>

namespace turbomind {

struct SamplingParams {
void* logits;
int stride;
int* indices;
int* kept;
curandState_t* curandstate;
size_t batch_size;
int* output_ids;
int* sequence_length;
float* sampled_logprobs;
uint32_t* sampled_indexes;
uint32_t* sampled_nums;
};

template<typename T>
void invokeSampling(SamplingParams& params, cudaStream_t stream);

} // namespace turbomind