Skip to content

Commit

Permalink
split into multiple header files
Browse files Browse the repository at this point in the history
easier to read and manage
  • Loading branch information
ankan-ban committed Sep 14, 2023
1 parent fffcec6 commit dfb2428
Show file tree
Hide file tree
Showing 6 changed files with 814 additions and 790 deletions.
71 changes: 71 additions & 0 deletions common.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#pragma once

#include <stdint.h>
#include <cuda_fp16.h>

// Compile-time upper bound on the sequence length (8192 tokens); it sizes the
// fixed token buffer declared in SharedData.
constexpr int MAX_SEQ_LEN = 8 * 1024;

// Hyper-parameters describing the shape of the transformer model.
// Plain aggregate of seven ints; member order is part of the layout.
struct Config {
    // transformer dimension
    int dim;
    // hidden dimension used by the ffn layers
    int hidden_dim;
    // number of layers
    int n_layers;
    // number of query heads
    int n_heads;
    // number of key/value heads (can be < query heads because of multiquery)
    int n_kv_heads;
    // vocabulary size, usually 256 (byte-level)
    int vocab_size;
    // max sequence length
    int seq_len;
};

// Three pointers that together describe one weight tensor.
// NOTE(review): the name suggests a quantized representation (packed weights
// plus zero points and scales) -- confirm packing/group layout with the loader.
struct QWeight {
    uint32_t* weight;  // weight data, addressed as 32-bit words
    uint32_t* zeros;   // "zeros" data, addressed as 32-bit words (presumably zero points -- confirm)
    half*     scales;  // half-precision scales
};

// Weights that belong to a single transformer layer.
struct PerLayerWeight {
    // rmsnorm weights; shapes are annotated as (layer, dim) in the original notes
    half* rms_att_weight;
    half* rms_ffn_weight;

    // Seven QWeight tensors. Declaration order is preserved because it defines
    // the struct layout.
    // NOTE(review): names suggest the attention q/k/v/o projections followed by
    // the ffn gate/up/down projections -- confirm against the weight loader.
    QWeight wq_q, wq_k, wq_v, wq_o;
    QWeight wq_gate, wq_up, wq_down;
};

// Top-level container for all model weights: the tensors that are not tied to
// a specific layer, plus a pointer to the per-layer weights.
struct TransformerWeights {
    half* token_embedding_table;  // token embedding table, (vocab_size, dim)
    half* wcls;                   // classifier weights for the logits, on the last layer
    half* rms_final_weight;       // final rmsnorm, (dim,)

    PerLayerWeight* layers;       // per layer weights
    int num_layers;               // NOTE(review): presumably the element count of `layers` -- confirm
};

// State shared between the CPU and the GPU; allocated in host memory.
struct SharedData {
    // Index of the current token. Declared volatile, so every access goes to
    // memory rather than a cached register copy.
    volatile int pos;

    // Tokens processed/generated so far (up to seq_len entries are used).
    // Lives in host memory so that the CPU can read it.
    int tokens[MAX_SEQ_LEN];
};

// Working buffers used while running the model: the current activations, the
// kv cache, position tracking and logits bookkeeping.
struct RunState {
    // ---- current wave of activations ----
    // activation at current time stamp (dim,)
    half* x;
    // same, but inside a residual branch (dim,)
    half* xb;
    // buffer for hidden dimension in the ffn (hidden_dim,)
    half* hb;
    // second buffer for hidden dimension in the ffn (hidden_dim,)
    half* hb2;
    // query (dim,)
    half* q;
    // buffer for scores/attention values (n_heads, seq_len)
    half* att;
    // output logits
    half* logits;

    // ---- kv cache ----
    // (layer, seq_len, dim)
    half* key_cache;
    // (layer, seq_len, dim)
    half* value_cache;

    // GPU copy of the current position (just 1 element)
    int* pos;
    // block shared between CPU and GPU (see SharedData)
    SharedData* shared_data;

    // array of output logits used to compute perplexity (seq_len, vocab_size)
    float* logits_array;
};
Loading

0 comments on commit dfb2428

Please sign in to comment.