This repository was archived by the owner on Aug 7, 2025. It is now read-only.

Refactor babyllama example to use llama2.c as a submodule #2911

Merged · 1 commit · Jan 27, 2024

3 changes: 3 additions & 0 deletions .gitmodules
@@ -4,3 +4,6 @@
 [submodule "cpp/third-party/llama.cpp"]
 	path = cpp/third-party/llama.cpp
 	url = https://github.com/ggerganov/llama.cpp.git
+[submodule "cpp/third-party/llama2.c"]
+	path = cpp/third-party/llama2.c
+	url = https://github.com/karpathy/llama2.c
1 change: 1 addition & 0 deletions cpp/third-party/llama2.c
Submodule llama2.c added at d98620
7 changes: 4 additions & 3 deletions examples/cpp/babyllama/CMakeLists.txt
@@ -1,5 +1,6 @@
 
-add_library(babyllama_handler SHARED src/baby_llama_handler.cc)
+add_library(llama2_c STATIC ../../../cpp/third-party/llama2.c/run.c)
+target_compile_options(llama2_c PRIVATE -Wall -Wextra -Ofast -fPIC)
 
-target_link_libraries(babyllama_handler PRIVATE ts_backends_core ts_utils ${TORCH_LIBRARIES})
-target_compile_options(babyllama_handler PRIVATE -Wall -Wextra -Ofast)
+add_library(babyllama_handler SHARED src/baby_llama_handler.cc)
+target_link_libraries(babyllama_handler PRIVATE llama2_c ts_backends_core ts_utils ${TORCH_LIBRARIES})
4 changes: 3 additions & 1 deletion examples/cpp/babyllama/src/baby_llama_handler.cc
@@ -5,7 +5,9 @@
 
 #include <typeinfo>
 
-#include "llama2.c/run.c"
+extern "C" {
+#include "llama2.c/llama2.h"
+}
 
 namespace llm {
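The extern "C" wrapper is the crux of this refactor: run.c is now compiled as C into the llama2_c static library (built with -fPIC so it can be linked into the shared handler), so the handler no longer textually includes the .c file and instead declares the library's unmangled C symbols through the new llama2.h. A minimal sketch of the linkage pattern, not taken from this PR; time_in_ms is simply the smallest symbol the header exposes:

// C++ translation unit linking against the C library llama2_c.
// Without the extern "C" block the compiler would emit mangled references
// (e.g. _Z10time_in_msv) that the C objects never define, and the link
// would fail with undefined-symbol errors.
extern "C" {
#include "llama2.c/llama2.h"  // C declarations, given C linkage here
}

int main() {
  return time_in_ms() > 0 ? 0 : 1;  // resolves to the plain C symbol time_in_ms
}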
113 changes: 113 additions & 0 deletions examples/cpp/babyllama/src/llama2.c/llama2.h
@@ -0,0 +1,113 @@
#include <stdio.h>
#include <stdlib.h>
#include <ctype.h>
#include <time.h>
#include <math.h>
#include <string.h>
#include <fcntl.h>
#include <unistd.h>
#include <sys/mman.h>
// ----------------------------------------------------------------------------
// Transformer model

typedef struct {
    int dim; // transformer dimension
    int hidden_dim; // for ffn layers
    int n_layers; // number of layers
    int n_heads; // number of query heads
    int n_kv_heads; // number of key/value heads (can be < query heads because of multiquery)
    int vocab_size; // vocabulary size, usually 256 (byte-level)
    int seq_len; // max sequence length
} Config;

typedef struct {
    // token embedding table
    float* token_embedding_table; // (vocab_size, dim)
    // weights for rmsnorms
    float* rms_att_weight; // (layer, dim) rmsnorm weights
    float* rms_ffn_weight; // (layer, dim)
    // weights for matmuls. note dim == n_heads * head_size
    float* wq; // (layer, dim, n_heads * head_size)
    float* wk; // (layer, dim, n_kv_heads * head_size)
    float* wv; // (layer, dim, n_kv_heads * head_size)
    float* wo; // (layer, n_heads * head_size, dim)
    // weights for ffn
    float* w1; // (layer, hidden_dim, dim)
    float* w2; // (layer, dim, hidden_dim)
    float* w3; // (layer, hidden_dim, dim)
    // final rmsnorm
    float* rms_final_weight; // (dim,)
    // (optional) classifier weights for the logits, on the last layer
    float* wcls;
} TransformerWeights;

typedef struct {
    // current wave of activations
    float *x; // activation at current time stamp (dim,)
    float *xb; // same, but inside a residual branch (dim,)
    float *xb2; // an additional buffer just for convenience (dim,)
    float *hb; // buffer for hidden dimension in the ffn (hidden_dim,)
    float *hb2; // buffer for hidden dimension in the ffn (hidden_dim,)
    float *q; // query (dim,)
    float *k; // key (dim,)
    float *v; // value (dim,)
    float *att; // buffer for scores/attention values (n_heads, seq_len)
    float *logits; // output logits
    // kv cache
    float* key_cache; // (layer, seq_len, dim)
    float* value_cache; // (layer, seq_len, dim)
} RunState;

typedef struct {
    Config config; // the hyperparameters of the architecture (the blueprint)
    TransformerWeights weights; // the weights of the model
    RunState state; // buffers for the "wave" of activations in the forward pass
    // some more state needed to properly clean up the memory mapping (sigh)
    int fd; // file descriptor for memory mapping
    float* data; // memory mapped data pointer
    ssize_t file_size; // size of the checkpoint file in bytes
} Transformer;
// ----------------------------------------------------------------------------
// The Byte Pair Encoding (BPE) Tokenizer that translates strings <-> tokens

typedef struct {
    char *str;
    int id;
} TokenIndex;

typedef struct {
    char** vocab;
    float* vocab_scores;
    TokenIndex *sorted_vocab;
    int vocab_size;
    unsigned int max_token_length;
    unsigned char byte_pieces[512]; // stores all single-byte strings
} Tokenizer;

// ----------------------------------------------------------------------------
// The Sampler, which takes logits and returns a sampled token
// sampling can be done in a few ways: greedy argmax, sampling, top-p sampling

typedef struct {
    float prob;
    int index;
} ProbIndex; // struct used when sorting probabilities during top-p sampling

typedef struct {
    int vocab_size;
    ProbIndex* probindex; // buffer used in top-p sampling
    float temperature;
    float topp;
    unsigned long long rng_state;
} Sampler;

void build_transformer(Transformer *t, char* checkpoint_path);
void build_tokenizer(Tokenizer* t, char* tokenizer_path, int vocab_size);
void build_sampler(Sampler* sampler, int vocab_size, float temperature, float topp, unsigned long long rng_seed);
void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens);
float* forward(Transformer* transformer, int token, int pos);
int sample(Sampler* sampler, float* logits);
long time_in_ms();
char* decode(Tokenizer* t, int prev_token, int token);
void free_sampler(Sampler* sampler);
void free_tokenizer(Tokenizer* t);
void free_transformer(Transformer* t);
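For orientation, here is a minimal end-to-end sketch of how a caller such as baby_llama_handler.cc can drive the API declared above. This is not code from the PR: the checkpoint and tokenizer paths, the prompt, and the sampling settings are illustrative placeholders, and the loop mirrors the generate() routine in llama2.c's run.c.

#include <cstdio>
#include <string>
#include <vector>

extern "C" {
#include "llama2.c/llama2.h"
}

int main() {
  char checkpoint_path[] = "model.bin";   // placeholder paths, not from the PR
  char tokenizer_path[] = "tokenizer.bin";

  Transformer transformer;
  build_transformer(&transformer, checkpoint_path);

  Tokenizer tokenizer;
  build_tokenizer(&tokenizer, tokenizer_path, transformer.config.vocab_size);

  Sampler sampler;
  build_sampler(&sampler, transformer.config.vocab_size,
                /*temperature=*/1.0f, /*topp=*/0.9f, /*rng_seed=*/1234ULL);

  // Tokenize the prompt with BOS (1) and without EOS (0); the +3 headroom
  // matches the buffer allocation used in run.c.
  std::string prompt = "Once upon a time";
  std::vector<int> tokens(prompt.size() + 3);
  int n_tokens = 0;
  encode(&tokenizer, const_cast<char*>(prompt.c_str()), /*bos=*/1, /*eos=*/0,
         tokens.data(), &n_tokens);

  // Autoregressive loop: replay the prompt tokens, then sample continuations.
  int token = tokens[0];
  for (int pos = 0; pos < transformer.config.seq_len - 1; ++pos) {
    float* logits = forward(&transformer, token, pos);
    int next = (pos + 1 < n_tokens) ? tokens[pos + 1] : sample(&sampler, logits);
    if (next == 1) break;  // token 1 (BOS) ends generation, as in run.c
    std::printf("%s", decode(&tokenizer, token, next));  // run.c filters raw bytes via safe_printf
    token = next;
  }
  std::printf("\n");

  free_sampler(&sampler);
  free_tokenizer(&tokenizer);
  free_transformer(&transformer);
  return 0;
}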