sync to Latest tokenizer code from llama2.c
ankan-ban committed Sep 12, 2023
1 parent 7a456ba commit 3486e56
Showing 1 changed file with 198 additions and 64 deletions.
llama2_q4_opt.cu (262 changes: 198 additions & 64 deletions)
@@ -580,6 +580,8 @@ int checkpoint_init_weights(TransformerWeights* w, Config* p, FILE* f) {
scratch_size *= sizeof(half);
void* scratchCpu = malloc(scratch_size);

printf("\nLoading Weights... ");

readWeight(w->token_embedding_table, f, p->vocab_size * p->dim * sizeof(half), scratchCpu);
readWeight(w->wcls, f, p->vocab_size * p->dim * sizeof(half), scratchCpu);
readWeight(w->rms_final_weight, f, p->dim * sizeof(half), scratchCpu);
@@ -599,7 +601,7 @@ int checkpoint_init_weights(TransformerWeights* w, Config* p, FILE* f) {
readWeight(w->layers[i].rms_ffn_weight, f, p->dim * sizeof(half), scratchCpu);
}

printf("\nloaded weights\n");
printf("done!\n");
free(scratchCpu);
return 0;
}
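readWeight is defined elsewhere in this file; judging from these call sites, it streams each tensor through the scratchCpu staging buffer on its way to the GPU. A hypothetical sketch of that shape (the real helper also has to handle dtype conversion, so treat this as illustration only):

#include <stdio.h>
#include <stdlib.h>
#include <cuda_runtime.h>

// hypothetical reconstruction, not the actual readWeight in this repo:
// fread into host scratch memory, then upload to the device allocation
void readWeight_sketch(void* devPtr, FILE* f, size_t size, void* scratchCpu) {
    if (fread(scratchCpu, 1, size, f) != size) {
        fprintf(stderr, "failed to read weight\n");
        exit(EXIT_FAILURE);
    }
    cudaMemcpy(devPtr, scratchCpu, size, cudaMemcpyHostToDevice);
}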
@@ -760,31 +762,182 @@ void transformer(bool gen_token, Config* p, RunState* s, TransformerWeights* w)
}

// ----------------------------------------------------------------------------
- // byte pair encoding (BPE) tokenizer, encodes strings into tokens so we can prompt
// hardcoded for llama models
constexpr int bos_token = 1;
constexpr int eos_token = 2;


// ----------------------------------------------------------------------------
// The Byte Pair Encoding (BPE) Tokenizer that translates strings <-> tokens

typedef struct {
char* str;
int id;
} TokenIndex;

- int str_lookup(char *str, char **vocab, int vocab_size) {
- // find the first perfect match for str in vocab, return its index or -1 if not found
typedef struct {
char** vocab;
float* vocab_scores;
TokenIndex* sorted_vocab;
int vocab_size;
unsigned int max_token_length;
unsigned char byte_pieces[512]; // stores all single-byte strings
} Tokenizer;

int compare_tokens(const void* a, const void* b) {
return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str);
}

void build_tokenizer(Tokenizer* t, char* tokenizer_path, int vocab_size) {
// I should have written the vocab_size into the tokenizer file... sigh
t->vocab_size = vocab_size;
// malloc space to hold the scores and the strings
t->vocab = (char**)malloc(vocab_size * sizeof(char*));
t->vocab_scores = (float*)malloc(vocab_size * sizeof(float));
t->sorted_vocab = NULL; // initialized lazily
for (int i = 0; i < 256; i++) {
t->byte_pieces[i * 2] = (unsigned char)i;
t->byte_pieces[i * 2 + 1] = '\0';
}
// read in the file
FILE* file = fopen(tokenizer_path, "rb");
if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer_path); exit(EXIT_FAILURE); }
if (fread(&t->max_token_length, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
int len;
for (int i = 0; i < vocab_size; i++) {
- if (strcmp(str, vocab[i]) == 0) {
- return i;
if (fread(t->vocab_scores + i, sizeof(float), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
if (fread(&len, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
t->vocab[i] = (char*)malloc(len + 1);
if (fread(t->vocab[i], len, 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
t->vocab[i][len] = '\0'; // add the string terminating token
}
fclose(file);
}
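For reference, the layout build_tokenizer expects is: one int (max_token_length), then per token a float score, an int byte length, and the raw bytes. Note that vocab_size itself is not stored, which is what the "sigh" comment above laments. A minimal sketch that emits a compatible toy file (the three-token vocab and the name toy_tokenizer.bin are made up for illustration):

#include <stdio.h>
#include <string.h>

int main() {
    const char* toks[] = { "<unk>", "a", "ab" };   // toy vocab
    float scores[] = { 0.0f, -1.0f, -2.0f };
    FILE* f = fopen("toy_tokenizer.bin", "wb");
    if (!f) return 1;
    int max_token_length = 5;                      // longest token, in bytes
    fwrite(&max_token_length, sizeof(int), 1, f);  // header: one int
    for (int i = 0; i < 3; i++) {                  // per token: score, len, bytes
        int len = (int)strlen(toks[i]);
        fwrite(&scores[i], sizeof(float), 1, f);
        fwrite(&len, sizeof(int), 1, f);
        fwrite(toks[i], len, 1, f);
    }
    fclose(f);
    return 0;
}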

void free_tokenizer(Tokenizer* t) {
for (int i = 0; i < t->vocab_size; i++) { free(t->vocab[i]); }
free(t->vocab);
free(t->vocab_scores);
free(t->sorted_vocab);
}

char* decode(Tokenizer* t, int prev_token, int token) {
char* piece = t->vocab[token];
// following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89)
if (prev_token == bos_token && piece[0] == ' ') { piece++; }
// careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
// parse this and convert and return the actual byte
unsigned char byte_val;
if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) {
piece = (char*)t->byte_pieces + byte_val * 2;
}
return piece;
}
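Tokens that stand for raw bytes live in the vocab as literal strings like "<0x0A>"; the sscanf pattern above recovers the byte they encode. The same parse in isolation:

#include <stdio.h>

int main() {
    unsigned char byte_val;
    const char* piece = "<0x0A>";
    // same pattern as decode() above: %02hhX parses two hex digits into an unsigned char
    if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) {
        printf("raw byte value: %d\n", byte_val); // prints 10, i.e. '\n'
    }
    return 0;
}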

void safe_printf(char* piece) {
// piece might be a raw byte token, and we only want to print printable chars or whitespace
// because some of the other bytes can be various control codes, backspace, etc.
if (piece == NULL) { return; }
if (piece[0] == '\0') { return; }
if (piece[1] == '\0') {
unsigned char byte_val = piece[0];
if (!(isprint(byte_val) || isspace(byte_val))) {
return; // bad byte, don't print it
}
}
- return -1;
printf("%s", piece);
}

- void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, unsigned int max_token_length, int *tokens, int *n_tokens) {

- // a temporary buffer to merge two consecutive tokens
- char* str_buffer = (char*) malloc((max_token_length*2+1) * sizeof(char)); // *2 for concat, +1 for null terminator
int str_lookup(char* str, TokenIndex* sorted_vocab, int vocab_size) {
// efficiently find the perfect match for str in vocab, return its index or -1 if not found
TokenIndex tok = { str }; // acts as the key to search for
TokenIndex* res = (TokenIndex*) bsearch(&tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens);
return res != NULL ? res->id : -1;
}
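The rewritten str_lookup swaps the old O(vocab_size) linear scan for a binary search over a lexicographically sorted copy of the vocab (built lazily in encode below). The same qsort/bsearch pairing in a standalone sketch:

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

typedef struct { const char* str; int id; } Tok; // mirrors TokenIndex above

int cmp_tok(const void* a, const void* b) {
    return strcmp(((const Tok*)a)->str, ((const Tok*)b)->str);
}

int main() {
    Tok vocab[] = { { "the", 3 }, { "a", 1 }, { "he", 2 } }; // deliberately unsorted
    qsort(vocab, 3, sizeof(Tok), cmp_tok);                   // one-time sort
    Tok key = { "he", 0 };                                   // str acts as the search key
    Tok* res = (Tok*)bsearch(&key, vocab, 3, sizeof(Tok), cmp_tok);
    printf("'he' -> id %d\n", res ? res->id : -1);           // prints: 'he' -> id 2
    return 0;
}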

- // first encode every individual byte in the input string
- *n_tokens = 0; // the number of tokens
- for (char *c = text; *c != '\0'; c++) {
- sprintf(str_buffer, "%c", *c);
- int id = str_lookup(str_buffer, vocab, vocab_size);
- if (id == -1) { printf("not good\n"); exit(1);}
- tokens[*n_tokens] = id;
- (*n_tokens)++;
void encode(Tokenizer* t, char* text, int8_t bos, int8_t eos, int* tokens, int* n_tokens) {
// encode the string text (input) into an upper-bound preallocated tokens[] array
// bos != 0 means prepend the BOS token (=1), eos != 0 means append the EOS token (=2)
if (text == NULL) { fprintf(stderr, "cannot encode NULL text\n"); exit(EXIT_FAILURE); }

if (t->sorted_vocab == NULL) {
// lazily malloc and sort the vocabulary
t->sorted_vocab = (TokenIndex*) malloc(t->vocab_size * sizeof(TokenIndex));
for (int i = 0; i < t->vocab_size; i++) {
t->sorted_vocab[i].str = t->vocab[i];
t->sorted_vocab[i].id = i;
}
qsort(t->sorted_vocab, t->vocab_size, sizeof(TokenIndex), compare_tokens);
}

// create a temporary buffer that will store merge candidates of always two consecutive tokens
// *2 for concat, +1 for null terminator, +2 for UTF-8 (in case max_token_length is 1)
char* str_buffer = (char *) malloc((t->max_token_length * 2 + 1 + 2) * sizeof(char));
size_t str_len = 0;

// start at 0 tokens
*n_tokens = 0;

// add optional BOS (=1) token, if desired
if (bos) tokens[(*n_tokens)++] = bos_token;

// add_dummy_prefix is true by default
// so prepend a dummy prefix token to the input string, but only if text != ""
// TODO: pretty sure this isn't correct in the general case but I don't have the
// energy to read more of the sentencepiece code to figure out what it's doing
if (text[0] != '\0') {
int dummy_prefix = str_lookup(" ", t->sorted_vocab, t->vocab_size);
tokens[(*n_tokens)++] = dummy_prefix;
}

// Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
// Code point ↔ UTF-8 conversion
// First code point Last code point Byte 1 Byte 2 Byte 3 Byte 4
// U+0000 U+007F 0xxxxxxx
// U+0080 U+07FF 110xxxxx 10xxxxxx
// U+0800 U+FFFF 1110xxxx 10xxxxxx 10xxxxxx
// U+10000 U+10FFFF 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx

// process the raw (UTF-8) byte sequence of the input string
for (char* c = text; *c != '\0'; c++) {

// reset buffer if the current byte is ASCII or a leading byte
// 0xC0 is 11000000, so (*c & 0xC0) keeps the first 2 bits and zeros the rest
// 0x80 is 10000000
// in UTF-8, all continuation bytes start with "10" in first two bits
// so in English this is: "if this byte is not a continuation byte"
if ((*c & 0xC0) != 0x80) {
// this byte must be either a leading byte (11...) or an ASCII char (0x...)
// => reset our location, as we're starting a new UTF-8 codepoint
str_len = 0;
}

// append the current byte to the buffer
str_buffer[str_len++] = *c; // ++ is post-increment, incremented after this line
str_buffer[str_len] = '\0';

// while the next character is a continuation byte, continue appending
// but if there are too many of them, just stop to avoid overrunning the str_buffer.
if ((*(c + 1) & 0xC0) == 0x80 && str_len < 4) {
continue;
}

// ok c+1 is not a continuation byte, so we've read in a full codepoint
int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);

if (id != -1) {
// we found this codepoint in vocab, add it as a token
tokens[(*n_tokens)++] = id;
}
else {
// byte_fallback encoding: just encode each byte as a token
// +3 is here because the first 3 vocab elements are <unk>, <s>, </s>
// so the individual bytes only start at index 3
for (int i = 0; i < str_len; i++) {
tokens[(*n_tokens)++] = (unsigned char)str_buffer[i] + 3;
}
}
str_len = 0; // protect against a sequence of stray UTF8 continuation bytes
}

// merge the best consecutive pair each iteration, according to the scores in vocab_scores
@@ -793,13 +946,13 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u
int best_id = -1;
int best_idx = -1;

- for (int i=0; i < (*n_tokens-1); i++) {
for (int i = 0; i < (*n_tokens - 1); i++) {
// check if we can merge the pair (tokens[i], tokens[i+1])
- sprintf(str_buffer, "%s%s", vocab[tokens[i]], vocab[tokens[i+1]]);
- int id = str_lookup(str_buffer, vocab, vocab_size);
- if (id != -1 && vocab_scores[id] > best_score) {
sprintf(str_buffer, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i + 1]]);
int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
if (id != -1 && t->vocab_scores[id] > best_score) {
// this merge pair exists in vocab! record its score and position
- best_score = vocab_scores[id];
best_score = t->vocab_scores[id];
best_id = id;
best_idx = i;
}
@@ -812,12 +965,15 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u
// merge the consecutive pair (best_idx, best_idx+1) into new token best_id
tokens[best_idx] = best_id;
// delete token at position best_idx+1, shift the entire sequence back 1
- for (int i = best_idx+1; i < (*n_tokens-1); i++) {
- tokens[i] = tokens[i+1];
for (int i = best_idx + 1; i < (*n_tokens - 1); i++) {
tokens[i] = tokens[i + 1];
}
(*n_tokens)--; // token length decreased
}

// add optional EOS (=2) token, if desired
if (eos) tokens[(*n_tokens)++] = eos_token;

free(str_buffer);
}
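Two pieces of encode are worth seeing in isolation. First, the continuation-byte test: a byte whose top two bits are not "10" starts a new codepoint, so the scanner groups the raw bytes into codepoint-sized chunks. A self-contained sketch of that grouping:

#include <stdio.h>

int main() {
    // "héllo": 'é' is the two-byte sequence 0xC3 0xA9
    const char text[] = { 'h', (char)0xC3, (char)0xA9, 'l', 'l', 'o', '\0' };
    char buf[5];
    int len = 0;
    for (const char* c = text; *c != '\0'; c++) {
        if ((*c & 0xC0) != 0x80) len = 0;  // not a continuation byte: a new codepoint starts
        buf[len++] = *c;
        buf[len] = '\0';
        if ((*(c + 1) & 0xC0) == 0x80 && len < 4) continue; // next byte continues this codepoint
        printf("chunk: '%s' (%d byte%s)\n", buf, len, len > 1 ? "s" : "");
        len = 0;
    }
    return 0;
}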
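Second, the merge phase: each pass finds the adjacent pair whose concatenation exists in the vocab with the best score and merges it, repeating until no pair matches. A toy run with a made-up five-entry vocab and illustrative scores:

#include <stdio.h>
#include <string.h>

const char* vocab[] = { "l", "o", "lo", "ll", "llo" }; // toy vocab
float scores[]      = { 0.f, 0.f, -1.f, -2.f, -0.5f }; // higher score wins a merge
enum { V = 5 };

int lookup(const char* s) {
    for (int i = 0; i < V; i++) if (strcmp(vocab[i], s) == 0) return i;
    return -1;
}

int main() {
    int tokens[] = { 0, 0, 1 }; // "l" "l" "o"
    int n = 3;
    char buf[16];
    while (1) {
        float best_score = -1e10f;
        int best_id = -1, best_idx = -1;
        for (int i = 0; i < n - 1; i++) {
            sprintf(buf, "%s%s", vocab[tokens[i]], vocab[tokens[i + 1]]);
            int id = lookup(buf);
            if (id != -1 && scores[id] > best_score) {
                best_score = scores[id]; best_id = id; best_idx = i;
            }
        }
        if (best_idx == -1) break;                 // nothing left to merge
        tokens[best_idx] = best_id;                // pass 1: "l"+"o" -> "lo" (-1 beats "ll" at -2)
        for (int i = best_idx + 1; i < n - 1; i++) // pass 2: "l"+"lo" -> "llo"
            tokens[i] = tokens[i + 1];
        n--;
    }
    for (int i = 0; i < n; i++) printf("%s ", vocab[tokens[i]]); // prints: llo
    printf("\n");
    return 0;
}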

@@ -876,77 +1032,56 @@ int main(int argc, char *argv[]) {
// right now we cannot run for more than config.seq_len steps
if (steps <= 0 || steps > config.seq_len) { steps = config.seq_len; }

- // read in the tokenizer.bin file
- char** vocab = (char**)malloc(config.vocab_size * sizeof(char*));
- float* vocab_scores = (float*)malloc(config.vocab_size * sizeof(float));
- unsigned int max_token_length;
- {
- FILE *file = fopen("tokenizer.bin", "rb");
- if (!file) { printf("couldn't load tokenizer.bin\n"); return 1; }
- if (fread(&max_token_length, sizeof(int), 1, file) != 1) { printf("failed read\n"); return 1; }
- int len;
- for (int i = 0; i < config.vocab_size; i++) {
- if (fread(vocab_scores + i, sizeof(float), 1, file) != 1) { printf("failed read\n"); return 1;}
- if (fread(&len, sizeof(int), 1, file) != 1) { printf("failed read\n"); return 1; }
- vocab[i] = (char *)malloc(len + 1);
- if (fread(vocab[i], len, 1, file) != 1) { printf("failed read\n"); return 1; }
- vocab[i][len] = '\0'; // add the string terminating token
- }
- fclose(file);
- }
// create and init the tokenizer
Tokenizer tokenizer;
build_tokenizer(&tokenizer, "tokenizer.bin", config.vocab_size);

// create and init the application RunState
RunState state;
malloc_run_state(&state, &config);
cudaStreamCreate(&stream);

// process the prompt, if any
- int *prompt_tokens = NULL;
int *prompt_tokens = (int*)malloc(config.seq_len * sizeof(int));
int num_prompt_tokens = 0;
- prompt_tokens = (int*)malloc(config.seq_len * sizeof(int));

char input_message[2048];
strcpy(input_message, prompt);

while (1)
{
if (input_message != NULL) {
- bpe_encode(input_message, vocab, vocab_scores, config.vocab_size, max_token_length, prompt_tokens, &num_prompt_tokens);
encode(&tokenizer, input_message, 1, 0, prompt_tokens, &num_prompt_tokens);
//printf("\nPrompt tokens: %d - \n", num_prompt_tokens);
//for (int i = 0; i < num_prompt_tokens; i++) printf("%d ", prompt_tokens[i]);
//printf("\n");
}


// start the main loop
long start = time_in_ms(); // used to time our code
int next; // will store the next token in the sequence
- int token = 1; // init with token 1 (=BOS), as done in Llama-2 sentencepiece tokenizer
int token = bos_token; // init with token 1 (=BOS), as done in Llama-2 sentencepiece tokenizer
int pos = 0; // position in the sequence

// copy the prompt tokens into the shared list of tokens (so that the GPU can access them).
// init state
cudaMemset(state.pos, 0, sizeof(int));
state.shared_data->pos = 0;
state.shared_data->tokens[0] = token; // BOS
- memcpy(&state.shared_data->tokens[1], prompt_tokens, sizeof(int) * num_prompt_tokens);
memcpy(&state.shared_data->tokens, prompt_tokens, sizeof(int) * num_prompt_tokens);

printf("<s>\n"); // explicit print the initial BOS token for stylistic symmetry reasons
while (pos < steps) {
// wait for GPU work for previous iteration to complete
// the idea is to keep the GPU working in parallel with any CPU work (e.g., printing tokens to the console).
cudaStreamSynchronize(stream);
// Perf note: don't put CPU work here "before" calling transformer as it won't overlap with GPU execution.
- transformer(pos >= num_prompt_tokens, &config, &state, &weights); // forward the transformer to get next token
transformer(pos >= num_prompt_tokens - 1, &config, &state, &weights); // forward the transformer to get next token

if (pos > 0)
{
next = state.shared_data->tokens[pos]; // Note: this is output token from previous iteration

- // following BOS token (1), sentencepiece decoder strips any leading whitespace (see PR #89)
- char* token_str = (token == 1 && vocab[next][0] == ' ') ? vocab[next] + 1 : vocab[next];
- printf("%s", token_str);
- //printf(" [%d - %s] ", next, token_str);
fflush(stdout);

- if (next == 2) break; // break if EOS token is reached
char* piece = decode(&tokenizer, token, next);
safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
if (next == eos_token) break; // break if EOS token is reached

// advance forward
token = next;
Expand All @@ -962,6 +1097,7 @@ int main(int argc, char *argv[]) {

printf("enter next prompt: ");
fgets(input_message, sizeof(input_message), stdin);
input_message[strlen(input_message) - 1] = 0; // strip the new-line
}

// memory cleanup
Expand All @@ -972,9 +1108,7 @@ int main(int argc, char *argv[]) {
if (graphCaptured[i]) cudaGraphExecDestroy(cudaGraphInstance[i]);
#endif

- for (int i = 0; i < config.vocab_size; i++) { free(vocab[i]); }
- free(vocab);
- free(vocab_scores);
free_tokenizer(&tokenizer);
if (prompt_tokens != NULL) free(prompt_tokens);
return 0;
}
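Putting the new API together, a hypothetical caller of the tokenizer functions above (assuming vocab_size comes from the model config, as main does):

// hypothetical helper, reusing the definitions above in this file
void demo_tokenizer_roundtrip(int vocab_size) {
    Tokenizer tokenizer;
    build_tokenizer(&tokenizer, "tokenizer.bin", vocab_size);

    int tokens[512];
    int n_tokens = 0;
    encode(&tokenizer, (char*)"Hello world", /*bos*/ 1, /*eos*/ 0, tokens, &n_tokens);

    int prev = bos_token;
    for (int i = 1; i < n_tokens; i++) {   // slot 0 holds the BOS token itself
        safe_printf(decode(&tokenizer, prev, tokens[i]));
        prev = tokens[i];
    }
    printf("\n");
    free_tokenizer(&tokenizer);
}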
