initial support for code-llama models
- Tested with the codellama-13b-instruct model. Seems to work well (with its own tokenizer.bin).
- As a hack, we currently decide the value of rope_theta based on vocab_size. This needs fixing (put it in the config/bin file header); see the note below the changed-files summary.
ankan-ban committed Oct 5, 2023
1 parent 3ebbf8c commit ec8e1a3
Showing 3 changed files with 20 additions and 14 deletions.
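Note on the rope_theta hack mentioned in the commit message: RoPE computes each rotation frequency as freq = 1.0f / powf(rope_theta, head_dim / (float)head_size) (see the gpu_kernels.h hunk below). CodeLlama checkpoints were trained with a much larger base, rope_theta = 1000000.0f, which stretches the rotation wavelengths to cover their longer 16k-token context; llama2 models use 10000.0f. Since the .bin header carries no field for this, the commit keys the choice off vocab_size: 32016 for the codellama-13b-instruct export tested here vs. 32000 for llama2 models.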
common.h: 4 changes (3 additions, 1 deletion)
@@ -11,7 +11,7 @@ typedef struct {
int n_layers; // number of layers
int n_heads; // number of query heads
int n_kv_heads; // number of key/value heads (can be < query heads because of multiquery)
- int vocab_size; // vocabulary size, usually 256 (byte-level)
+ int vocab_size; // vocabulary size, usually 32000 for llama2 models.
int seq_len; // max sequence length
} Config;

@@ -68,6 +68,8 @@ typedef struct {
SharedData* shared_data;

float* logits_array; // array of output logits used to compute perplexity (seq_len, vocab_size)

+ float rope_theta; // theta for the rope rotational embedding. TODO: This really should be part of Config!
} RunState;

int divUp(int a, int b) {
gpu_kernels.h: 4 changes (2 additions, 2 deletions)
@@ -230,15 +230,15 @@ __global__ void vec_mat_kernel(half* op, const half* __restrict__ ip, const half
}

// Each block processes a single head
-__global__ void RoPERotation_kernel(half* sq, half* sk_base, int num_heads, int head_size, int* pPos, int loff) {
+__global__ void RoPERotation_kernel(half* sq, half* sk_base, int num_heads, int head_size, int* pPos, int loff, float rope_theta) {
int pos = *pPos;
half* sk = sk_base + loff + pos * num_heads * head_size;
int h = blockIdx.x;
half* q = sq + h * head_size;
half* k = sk + h * head_size;
int i = threadIdx.x;
int head_dim = (i * 2) % head_size;
- float freq = 1.0f / powf(10000.0f, head_dim / (float)head_size);
+ float freq = 1.0f / powf(rope_theta, head_dim / (float)head_size);
float val = pos * freq;
float fcr = cosf(val);
float fci = sinf(val);
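For clarity, a host-side reference of the rotation this kernel applies. The kernel body below this hunk is collapsed, so the pairing is an assumption based on the standard llama2.c RoPE (each thread i rotates the (2i, 2i+1) element pair of q and k by pos * freq radians):

#include <math.h>

// Reference rotation for one head; q and k each hold head_size values.
void rope_rotate_ref(float* q, float* k, int pos, int head_size, float rope_theta) {
    for (int i = 0; i < head_size / 2; i++) {      // one loop iteration per kernel thread
        int head_dim = (i * 2) % head_size;
        float freq = 1.0f / powf(rope_theta, head_dim / (float)head_size);
        float val = pos * freq;
        float fcr = cosf(val);
        float fci = sinf(val);
        float q0 = q[i * 2], q1 = q[i * 2 + 1];
        q[i * 2]     = q0 * fcr - q1 * fci;        // standard 2D rotation of the pair
        q[i * 2 + 1] = q0 * fci + q1 * fcr;
        float k0 = k[i * 2], k1 = k[i * 2 + 1];
        k[i * 2]     = k0 * fcr - k1 * fci;
        k[i * 2 + 1] = k0 * fci + k1 * fcr;
    }
}

With rope_theta = 10000.0f this matches the old hard-coded behavior exactly; the new parameter only changes the frequency base.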
llama2_q4.cu: 26 changes (15 additions, 11 deletions)
@@ -43,8 +43,8 @@ void malloc_run_state(RunState* s, Config* p, bool allocLogitsArray) {
cudaMalloc((void**)&s->q, p->dim * sizeof(half));
cudaMalloc((void**)&s->att, p->n_heads * p->dim * sizeof(half));
cudaMalloc((void**)&s->logits, p->vocab_size * sizeof(half));
- cudaMalloc((void**)&s->key_cache, p->n_layers * p->seq_len * p->dim * sizeof(half)); // potentially huge allocs
- cudaMalloc((void**)&s->value_cache, p->n_layers * p->seq_len * p->dim * sizeof(half));
+ cudaMalloc((void**)&s->key_cache, sizeof(half) * p->n_layers * p->seq_len * p->dim); // potentially huge allocs
+ cudaMalloc((void**)&s->value_cache, sizeof(half) * p->n_layers * p->seq_len * p->dim);

cudaMalloc((void**)&s->pos, sizeof(int));
cudaMallocHost((void**)&s->shared_data, sizeof(SharedData));
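A side note on why the cudaMalloc sizes above were reordered: sizeof(half) has type size_t, so placing it first makes the whole left-to-right product evaluate in 64-bit arithmetic. With the old int-first order, a CodeLlama-13B-sized cache (illustrative numbers: n_layers = 40, seq_len = 16384, dim = 5120) would overflow, since 40 * 16384 * 5120 already exceeds INT_MAX before the final multiply:

size_t bytes = sizeof(half) * p->n_layers * p->seq_len * p->dim; // size_t * int -> size_t at every step; no overflow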
@@ -58,12 +58,16 @@ void malloc_run_state(RunState* s, Config* p, bool allocLogitsArray) {
}

if (allocLogitsArray) {
- cudaMalloc((void**)&s->logits_array, p->seq_len * p->vocab_size * sizeof(float));
+ cudaMalloc((void**)&s->logits_array, sizeof(float) * p->seq_len * p->vocab_size);
if (!s->logits_array) {
printf("malloc failed for allocating logits_array!\n");
exit(EXIT_FAILURE);
}
}

+ // HACK: set rope_theta based on which model we are trying to run, keyed off vocab_size.
+ // This works around rope_theta missing from the config header; adding it there would require changing the weight files.
+ s->rope_theta = p->vocab_size == 32016 ? 1000000.0f : 10000.0f; // codellama models use 1000000.0f, while llama2 models use 10000.0f
}
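The TODO above points at the proper fix: serialize rope_theta in the checkpoint header instead of inferring it. A hypothetical shape of that change, assuming the llama2.c Config layout (field order illustrative; it would require re-exporting the .bin weight files):

typedef struct {
    int dim;          // transformer dimension
    int hidden_dim;   // for ffn layers
    int n_layers;     // number of layers
    int n_heads;      // number of query heads
    int n_kv_heads;   // number of key/value heads
    int vocab_size;   // vocabulary size
    int seq_len;      // max sequence length
    float rope_theta; // RoPE base, read from the file instead of guessed
} Config;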

void free_run_state(RunState* s) {
@@ -232,8 +236,8 @@ void matmul(half* xout, half* x, QWeight &w, int inpSize, int opSize, bool accum
mat_vec_kernel_int4 <<<grid_dim, block_dim, 0, stream >>> (xout, x, w.weight, w.zeros, w.scales, inpSize, opSize, packed_zeros_height, scales_height, packed_wt_height, accum, loff, pPos);
}

-void RoPERotation(half *q, half *k, int num_heads, int head_size, int* pPos, int loff) {
-    RoPERotation_kernel <<<num_heads, head_size / 2, 0, stream >>> (q, k, num_heads, head_size, pPos, loff);
+void RoPERotation(half *q, half *k, int num_heads, int head_size, int* pPos, int loff, float rope_theta) {
+    RoPERotation_kernel <<<num_heads, head_size / 2, 0, stream >>> (q, k, num_heads, head_size, pPos, loff, rope_theta);
}

void MultiHeadAttention(half *output, half *q, half *key_cache, half * value_cache, half *att, int num_heads, int head_size, int max_seq_len, int *pPos) {
@@ -280,7 +284,7 @@ void run_llama_network(int *pPos, Config* p, RunState* s, TransformerWeights* w,

// apply RoPE rotation to the q and k vectors for each head
// also save the output (key, value) at this time step (pos) to our kv cache
- RoPERotation(s->q, s->key_cache, p->n_heads, head_size, pPos, loff);
+ RoPERotation(s->q, s->key_cache, p->n_heads, head_size, pPos, loff, s->rope_theta);

// apply MHA using the query and the key-value cache
MultiHeadAttention(s->xb, s->q, s->key_cache + loff, s->value_cache + loff, s->att, p->n_heads, head_size, seq_len_bin, pPos);
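For reference, loff here is the per-layer offset into the key/value caches, which were allocated above as (n_layers, seq_len, dim) half arrays; assuming the llama2.c convention this port follows, it would be computed per layer l as:

int loff = l * p->seq_len * p->dim; // start of layer l's (seq_len, dim) slice in key_cache / value_cache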
@@ -379,8 +383,8 @@ void error_usage(char *argv[]) {
fprintf(stderr, "Options:\n");
fprintf(stderr, " -n <int> max number of steps to run for, default = max_seq_len\n");
fprintf(stderr, " -i <string> input prompt\n");
- fprintf(stderr, " -t <float> temperature in [0,inf], default 1.0\n");
- fprintf(stderr, " -p <float> p value in top-p (nucleus) sampling in [0,1] default 0.6\n");
+ fprintf(stderr, " -t <float> temperature in [0,inf], default 0.5\n");
+ fprintf(stderr, " -p <float> p value in top-p (nucleus) sampling in [0,1] default 0.9\n");
fprintf(stderr, " -s <int> random seed, default time(NULL)\n");
fprintf(stderr, " -z <string> optional path to custom tokenizer\n");
fprintf(stderr, " -q <string> compute perplexity on the given dataset file\n");
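For example, a CodeLlama run matching the commit message's test would pass the model's own tokenizer via -z (the binary and file names here are illustrative):

./llama2_q4 codellama-13b-instruct.bin -z tokenizer.bin -n 256 -i "Write a C function that reverses a string in place."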
@@ -397,8 +401,8 @@ int main(int argc, char *argv[]) {
int steps = 0; // number of steps to run for
char* prompt = nullptr; // prompt string
bool perplexity = false;
- float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
- float topp = 0.6f; // top-p in nucleus sampling. 1.0 = off. 0.6 works well, but slower
+ float temperature = 0.5f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
+ float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
unsigned long long rng_seed = 0; // seed rng with time by default

// poor man's C argparse
@@ -429,7 +433,7 @@
// parameter validation/overrides
if (rng_seed <= 0) rng_seed = (unsigned int)time(NULL);
if (temperature < 0.0) temperature = 0.0;
- if (topp < 0.0 || 1.0 < topp) topp = 0.6;
+ if (topp < 0.0 || 1.0 < topp) topp = 0.9;
if (steps < 0) steps = 0;

// read in the model.bin file
