Skip to content

Commit

Permalink
Add support for sampling (top-p and mult)
Browse files Browse the repository at this point in the history
  • Loading branch information
ankan-ban committed Oct 4, 2023
1 parent c247a35 commit 38b7e9a
Show file tree
Hide file tree
Showing 5 changed files with 211 additions and 19 deletions.
4 changes: 4 additions & 0 deletions common.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,7 @@ typedef struct {

float* logits_array; // array of output logits used to compute perplexity (seq_len, vocab_size)
} RunState;

int divUp(int a, int b) {
return (a - 1) / b + 1;
}
96 changes: 94 additions & 2 deletions gpu_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ __global__ void rmsnorm_kernel(half* o, half* x, half* weight, int size, int ele
__shared__ float shared_ss;
if (threadIdx.x == 0) {
ss /= size;
ss += 1e-6f;
ss += 1e-5f;
ss = 1.0f / sqrtf(ss);
shared_ss = ss;
}
Expand Down Expand Up @@ -353,4 +353,96 @@ __global__ void argmax_kernel(half* __restrict__ x, int size, int* result, volat
*pPos = token_pos;
*pPosGpu = token_pos;
}
}
}

// This is used for Top-P sampling. We do the following:
// 1. Divide the logits by temperature
// 2. Compute softmax
// 3. Write the indices in an array
__global__ void softmax_logits_kernel(half* __restrict__ logits, int size, float temperature, int *indices) {
int tid = threadIdx.x;
int step = blockDim.x;


for (int t = tid; t < size; t += step)
{
// first just write the indices array
indices[t] = t;

// divide by temperature
float val = (float)logits[t];
val /= temperature;
logits[t] = (half)val;
}
__syncthreads();

// Compute the softmax
using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage temp;
__shared__ float shared_val;

// find max value (for numerical stability)
float max_val = tid < size ? ((float)logits[tid]) : -FLT_MAX;
for (int i = tid + step; i < size; i += step)
if ((float)logits[i] > max_val)
max_val = logits[i];

max_val = BlockReduce(temp).Reduce(max_val, cub::Max());
if (threadIdx.x == 0)
shared_val = max_val;
__syncthreads();
max_val = shared_val;

// exp and sum
float sum = 0.0f;
for (int i = tid; i < size; i += step) {
float v = expf(float(logits[i]) - max_val);
logits[i] = (half)v;
sum += v;
}

sum = BlockReduce(temp).Sum(sum);
if (threadIdx.x == 0)
shared_val = sum;
__syncthreads();
sum = shared_val;

// normalize and write the result
for (int t = tid; t < size; t += step)
logits[t] = (half)(float(logits[t]) / sum);
}

// ----------------------------------------------------------------------------

// find the index in the array that crosses top-p threshold
__global__ void sample_top_p_kernel(half* sorted_logits_prefix_sum, int* indices, int n, float top_p_threshold, int* result, volatile int* pPos, int* pPosGpu)
{
int tid = threadIdx.x;
int step = blockDim.x;

int min_index = n - 1;

for (int t = tid; t < n; t += step) {
if ((float)(sorted_logits_prefix_sum[t]) >= top_p_threshold) {
if (t < min_index) {
min_index = t;
}
}
}

// find the min across the block
using BlockReduce = cub::BlockReduce<int, 1024>;
__shared__ typename BlockReduce::TempStorage temp;
int min_index_global = BlockReduce(temp).Reduce(min_index, cub::Min());
if (threadIdx.x == 0)
{
int token_pos = *pPos;
token_pos++;
result[token_pos] = indices[min_index_global];

// update the token indices
*pPos = token_pos;
*pPosGpu = token_pos;
}
}

36 changes: 25 additions & 11 deletions llama2_q4.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ Inference for Llama-2 Transformer model in pure Cuda.
#include "common.h"
#include "gpu_kernels.h"
#include "tokenizer.h"
#include "sampler.h"
#include "perplexity.h"

constexpr int group_size = 128; // hardcoded for this implementation
Expand Down Expand Up @@ -79,10 +80,6 @@ void free_run_state(RunState* s) {
cudaFreeHost(s->shared_data);
}

int divUp(int a, int b) {
return (a - 1) / b + 1;
}

size_t getPackedWeightHeight(size_t height)
{
// Each uint32 element in the packed weight matrix contain 8 elements from the original matrix.
Expand Down Expand Up @@ -314,7 +311,7 @@ constexpr int MAX_GRAPHS = 8;
cudaGraphExec_t cudaGraphInstance[MAX_GRAPHS];
bool graphCaptured[MAX_GRAPHS];

void transformer(bool gen_token, Config* p, RunState* s, TransformerWeights* w, bool copyLogits) {
void transformer(bool gen_token, Config* p, RunState* s, TransformerWeights* w, bool copyLogits, Sampler *pSampler) {
#if DUMP_PER_TOKEN_TIMINGS == 1
cudaEvent_t start, stop;
cudaEventCreate(&start);
Expand Down Expand Up @@ -349,11 +346,10 @@ void transformer(bool gen_token, Config* p, RunState* s, TransformerWeights* w,
// copy to the right slot in logits_array (and convert to FP32)
// we compute perplexity on the CPU later.
float* pOutput = s->logits_array + p->vocab_size * s->shared_data->pos;
convert_fp16_to_fp32 << < divUp(p->vocab_size, 128), 128 >> > (pOutput, s->logits, p->vocab_size);
convert_fp16_to_fp32 << < divUp(p->vocab_size, 128), 128, 0, stream >> > (pOutput, s->logits, p->vocab_size);
}

// sample the next token using greedy argmax sampling: take the token with the highest probability (not included in the graph because of gen_token variable)
argmax_kernel <<<1, 1024, 0, stream>>> (s->logits, p->vocab_size, &(s->shared_data->tokens[0]), &(s->shared_data->pos), s->pos, gen_token);
sample(pSampler, s, gen_token, stream);

#if DUMP_PER_TOKEN_TIMINGS == 1
cudaEventRecord(stop, stream);
Expand Down Expand Up @@ -383,6 +379,9 @@ void error_usage(char *argv[]) {
fprintf(stderr, "Options:\n");
fprintf(stderr, " -n <int> max number of steps to run for, default = max_seq_len\n");
fprintf(stderr, " -i <string> input prompt\n");
fprintf(stderr, " -t <float> temperature in [0,inf], default 1.0\n");
fprintf(stderr, " -p <float> p value in top-p (nucleus) sampling in [0,1] default 0.6\n");
fprintf(stderr, " -s <int> random seed, default time(NULL)\n");
fprintf(stderr, " -z <string> optional path to custom tokenizer\n");
fprintf(stderr, " -q <string> compute perplexity on the given dataset file\n");
exit(EXIT_FAILURE);
Expand All @@ -398,7 +397,9 @@ int main(int argc, char *argv[]) {
int steps = 0; // number of steps to run for
char* prompt = nullptr; // prompt string
bool perplexity = false;

float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
float topp = 0.6f; // top-p in nucleus sampling. 1.0 = off. 0.6 works well, but slower
unsigned long long rng_seed = 0; // seed rng with time by default

// poor man's C argparse
if (argc >= 2) { checkpoint_path = argv[1]; } else { error_usage(argv); }
Expand All @@ -413,6 +414,9 @@ int main(int argc, char *argv[]) {
case 'n': steps = atoi(argv[i + 1]); break;
case 'i': prompt = argv[i + 1]; break;
case 'z': tokenizer_path = argv[i + 1]; break;
case 't': temperature = atof(argv[i + 1]); break;
case 'p': topp = atof(argv[i + 1]); break;
case 's': rng_seed = atoi(argv[i + 1]); break;
case 'q': {
dataset_path = argv[i + 1];
perplexity = true;
Expand All @@ -422,6 +426,12 @@ int main(int argc, char *argv[]) {
}
}

// parameter validation/overrides
if (rng_seed <= 0) rng_seed = (unsigned int)time(NULL);
if (temperature < 0.0) temperature = 0.0;
if (topp < 0.0 || 1.0 < topp) topp = 0.6;
if (steps < 0) steps = 0;

// read in the model.bin file
Config config = {};
TransformerWeights weights;
Expand All @@ -445,6 +455,10 @@ int main(int argc, char *argv[]) {
Tokenizer tokenizer;
build_tokenizer(&tokenizer, tokenizer_path, config.vocab_size);

// build the Sampler
Sampler sampler;
build_sampler(&sampler, config.vocab_size, temperature, topp, rng_seed);

// create and init the application RunState
RunState state;
malloc_run_state(&state, &config, perplexity);
Expand All @@ -458,7 +472,7 @@ int main(int argc, char *argv[]) {
else input_message[0] = 0;

if (perplexity) {
parseDataSetAndComputePreplexity(dataset_path, &tokenizer, &config, &state, &weights);
parseDataSetAndComputePreplexity(dataset_path, &tokenizer, &config, &state, &weights, &sampler);
}
else
while (1) {
Expand All @@ -485,7 +499,7 @@ int main(int argc, char *argv[]) {
// the idea is to keep GPU working in parallel with any CPU work (e.g, printing tokens to console).
cudaStreamSynchronize(stream);
// Perf note: don't put CPU work here "before" calling transformer as it won't overlap with GPU execution.
transformer(pos >= num_prompt_tokens - 1, &config, &state, &weights); // forward the transformer to get next token
transformer(pos >= num_prompt_tokens - 1, &config, &state, &weights, false, &sampler); // forward the transformer to get next token

if (pos > 0)
{
Expand Down
12 changes: 6 additions & 6 deletions perplexity.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ float compute_perplexity(int* tokens, float* logits, int num_tokens, int vocab_s
}


void transformer(bool gen_token, Config* p, RunState* s, TransformerWeights* w, bool copyLogits = false);
void transformer(bool gen_token, Config* p, RunState* s, TransformerWeights* w, bool copyLogits, Sampler* pSampler);

// ----------------------------------------------------------------------------
float get_dataset_perplexity(char* dataset, Tokenizer* tokenizer, Config* config, RunState* state, TransformerWeights* weights) {
float get_dataset_perplexity(char* dataset, Tokenizer* tokenizer, Config* config, RunState* state, TransformerWeights* weights, Sampler *pSampler) {
int bytes = strlen(dataset);
int* datasetTokens = (int*)malloc(bytes * sizeof(int));

Expand All @@ -78,7 +78,7 @@ float get_dataset_perplexity(char* dataset, Tokenizer* tokenizer, Config* config
state->shared_data->tokens[0] = bos_token;
memcpy(&(state->shared_data->tokens[1]), datasetTokens, sizeof(int) * numTokens);
for (int pos = 0; pos < numTokens; pos++) {
transformer(false, config, state, weights, true);
transformer(false, config, state, weights, true, pSampler);
cudaDeviceSynchronize();
}
printf("done!\n");
Expand All @@ -98,7 +98,7 @@ float get_dataset_perplexity(char* dataset, Tokenizer* tokenizer, Config* config
}


void parseDataSetAndComputePreplexity(char* textFileName, Tokenizer* tokenizer, Config* config, RunState* state, TransformerWeights* weights)
void parseDataSetAndComputePreplexity(char* textFileName, Tokenizer* tokenizer, Config* config, RunState* state, TransformerWeights* weights, Sampler *pSampler)
{
FILE* fp = fopen(textFileName, "rb+");
printf("\nLoading Dataset...");
Expand All @@ -125,12 +125,12 @@ void parseDataSetAndComputePreplexity(char* textFileName, Tokenizer* tokenizer,
if (nextseq = strstr(currentSeq, "<|endoftext|>")) {
*nextseq = 0;
nextseq += 13;
pplx_product *= get_dataset_perplexity(currentSeq, tokenizer, config, state, weights);
pplx_product *= get_dataset_perplexity(currentSeq, tokenizer, config, state, weights, pSampler);
count++;
currentSeq = nextseq;
}
else {
pplx_product *= get_dataset_perplexity(currentSeq, tokenizer, config, state, weights);
pplx_product *= get_dataset_perplexity(currentSeq, tokenizer, config, state, weights, pSampler);
count++;
break;
}
Expand Down
82 changes: 82 additions & 0 deletions sampler.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#pragma once

typedef struct {
int vocab_size = 0;
int* indices = nullptr;
void* tempStorage_scan = nullptr;
void* tempStorage_sort = nullptr;
size_t temp_storage_bytes_scan = 0;
size_t temp_storage_bytes_sort = 0;
float temperature = 0;
float topp = 0;
unsigned long long rng_state = 0;
} Sampler;

void build_sampler(Sampler* sampler, int vocab_size, float temperature, float topp, unsigned long long rng_seed) {
sampler->vocab_size = vocab_size;
sampler->temperature = temperature;
sampler->topp = topp;
sampler->rng_state = rng_seed;

// buffer only used with nucleus sampling
cudaMalloc((void**) & sampler->indices, vocab_size * sizeof(int));
}

void destroy_sampler(Sampler* sampler) {
cudaFree(sampler->indices);
cudaFree(sampler->tempStorage_sort);
cudaFree(sampler->tempStorage_scan);
}

unsigned int random_u32(unsigned long long* state) {
// xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A
*state ^= *state >> 12;
*state ^= *state << 25;
*state ^= *state >> 27;
return (*state * 0x2545F4914F6CDD1Dull) >> 32;
}
float random_f32(unsigned long long* state) { // random float32 in [0,1)
return (random_u32(state) >> 8) / 16777216.0f;
}

// sample the token given the logits and some hyperparameters
void sample(Sampler* sampler, RunState* s, bool gen_token, cudaStream_t stream) {
// flip a (float) coin (this is our source of entropy for sampling)
float coin = random_f32(&sampler->rng_state);

if (sampler->temperature == 0.0f || !gen_token) {
// greedy argmax sampling: take the token with the highest probability
argmax_kernel << <1, 1024, 0, stream >> > (s->logits, sampler->vocab_size, &(s->shared_data->tokens[0]), &(s->shared_data->pos), s->pos, gen_token);
}
else {
// apply the temperature to the logits, and then perform softmax
softmax_logits_kernel <<<1, 1024, 0, stream >>> (s->logits, sampler->vocab_size, sampler->temperature, sampler->indices);

float threshold = 0.0f;
// we sample from this distribution to get the next token
if (sampler->topp <= 0 || sampler->topp >= 1) {
threshold = coin;
}
else {
// top-p (nucleus) sampling, clamping the least likely tokens to zero
if (sampler->temp_storage_bytes_sort == 0) {
cub::DeviceRadixSort::SortPairsDescending(sampler->tempStorage_sort, sampler->temp_storage_bytes_sort, s->logits, s->logits, sampler->indices, sampler->indices,
sampler->vocab_size, 0, sizeof(half) * 8, stream);
cudaMalloc(&sampler->tempStorage_sort, sampler->temp_storage_bytes_sort);
}

cub::DeviceRadixSort::SortPairsDescending(sampler->tempStorage_sort, sampler->temp_storage_bytes_sort, s->logits, s->logits, sampler->indices, sampler->indices,
sampler->vocab_size, 0, sizeof(half) * 8, stream);
threshold = coin * sampler->topp;
}

// Sample from the predicted probability distribution
if (sampler->temp_storage_bytes_scan == 0) {
cub::DeviceScan::InclusiveSum(sampler->tempStorage_scan, sampler->temp_storage_bytes_scan, s->logits, s->logits, sampler->vocab_size, stream);
cudaMalloc(&sampler->tempStorage_scan, sampler->temp_storage_bytes_scan);
}
cub::DeviceScan::InclusiveSum(sampler->tempStorage_scan, sampler->temp_storage_bytes_scan, s->logits, s->logits, sampler->vocab_size, stream);

sample_top_p_kernel << <1, 1024, 0, stream >> > (s->logits, sampler->indices, sampler->vocab_size, threshold, &(s->shared_data->tokens[0]), &(s->shared_data->pos), s->pos);
}
}

0 comments on commit 38b7e9a

Please sign in to comment.