llm : add Falcon support #2717

Merged
merged 38 commits into master from falcon on Aug 23, 2023

Changes from all commits (38 commits)
4ed3469
llama : refactor GGUF constants into static maps
ggerganov Aug 22, 2023
8bd7f06
llama : check if model architecture is known
ggerganov Aug 22, 2023
3057d6a
llama : refactor llama_model_load_internal()
ggerganov Aug 22, 2023
3c025a6
gguf : add KV constant maps
ggerganov Aug 22, 2023
b19c6e4
Merge branch 'master' into falcon
ggerganov Aug 22, 2023
9f28f73
llm : read arch-specific KVs
ggerganov Aug 22, 2023
d1b3b95
convert : add dummy scores + types
ggerganov Aug 22, 2023
2f3c80a
falcon : load tensor data (CPU only)
ggerganov Aug 22, 2023
5c5413d
llama : fix loading progress bar
ggerganov Aug 22, 2023
085228e
llama : add arch member to llama_model
ggerganov Aug 22, 2023
3c7c325
falcon : CPU inference working
ggerganov Aug 22, 2023
2d58444
falcon : support non-40B models
ggerganov Aug 22, 2023
0ec27ad
falcon : minor
ggerganov Aug 22, 2023
7bbbf38
llama : minor updates
ggerganov Aug 22, 2023
9853f2c
convert-falcon-hf-to-gguf.py : fix special token mapping
klosax Aug 22, 2023
ffa5099
llama.cpp : llama default UNK token = id 0
klosax Aug 22, 2023
a95ae75
llama.cpp : fix bpe tokenizer
klosax Aug 22, 2023
d561b7f
llama.cpp : fix the fix of bpe tokenizer
klosax Aug 22, 2023
e3c52bd
ggml : pass eps to ggml_norm
ggerganov Aug 23, 2023
99bb260
metal : implement RoPE (mode = 2) + avoid ggml_repeat
ggerganov Aug 23, 2023
af4bbcc
ggml : ggml_repeat always creates new tensor
ggerganov Aug 23, 2023
b34ab74
falcon : copy-paste self-attention from LLaMA
ggerganov Aug 23, 2023
a0dc47a
metal : print extra compute pipeline info
ggerganov Aug 23, 2023
e2d23be
falcon : minor changes (still chasing the Metal problem)
ggerganov Aug 23, 2023
b693000
llama.cpp : fix linefeed token
klosax Aug 23, 2023
0a85ae7
metal : fix GELU kernel numerical stability by using precise::tanh
ggerganov Aug 23, 2023
854ae5d
metal : temporary workaround for the concurrency optimization bug
ggerganov Aug 23, 2023
e729965
falcon : add CUDA offloading (#2739)
slaren Aug 23, 2023
176ea71
llama : better model naming and size reporting
ggerganov Aug 23, 2023
6938c5f
Merge branch 'master' into falcon
ggerganov Aug 23, 2023
c3f8a6e
llama : prep new tokenizer support
ggerganov Aug 23, 2023
3bfb720
llama : advanced BPE tokenizer based on ggllm.cpp implementation
ggerganov Aug 23, 2023
2424e1d
llama : remove obsolete comment
ggerganov Aug 23, 2023
596e109
common : remove obsolete BPE API + disable test-tokenizer-1
ggerganov Aug 23, 2023
f8ee54b
llama : revert BPE special-case in llama_byte_to_token()
ggerganov Aug 23, 2023
8c6d393
cuda : add TODOs for RoPE NeoX implementation
ggerganov Aug 23, 2023
630d8b4
llama : default special tokens based on vocab type
ggerganov Aug 23, 2023
fae8faa
perplexity : add log for start of tokenization
ggerganov Aug 23, 2023
32 changes: 0 additions & 32 deletions common/common.cpp
@@ -744,35 +744,3 @@ std::string llama_token_to_str(const struct llama_context * ctx, llama_token tok

return std::string(result.data(), result.size());
}

std::vector<llama_token> llama_tokenize_bpe(
struct llama_context * ctx,
const std::string & text,
bool add_bos) {
int n_tokens = text.length() + add_bos;
std::vector<llama_token> result(n_tokens);
n_tokens = llama_tokenize_bpe(ctx, text.c_str(), result.data(), result.size(), add_bos);
if (n_tokens < 0) {
result.resize(-n_tokens);
int check = llama_tokenize_bpe(ctx, text.c_str(), result.data(), result.size(), add_bos);
GGML_ASSERT(check == -n_tokens);
} else {
result.resize(n_tokens);
}
return result;
}

std::string llama_token_to_str_bpe(const struct llama_context * ctx, llama_token token) {
std::vector<char> result(8, 0);
const int n_tokens = llama_token_to_str_bpe(ctx, token, result.data(), result.size());
if (n_tokens < 0) {
result.resize(-n_tokens);
const int check = llama_token_to_str_bpe(ctx, token, result.data(), result.size());
GGML_ASSERT(check == -n_tokens);
} else {
result.resize(n_tokens);
}

return std::string(result.data(), result.size());
}
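
Note: the deleted BPE-specific wrappers used the same two-pass sizing convention as the helpers that remain — call the C API once with a best-guess buffer, and if the return value is negative its magnitude is the required size, so resize and call again. A condensed sketch of that convention through the single llama_tokenize wrapper this PR keeps (illustrative only; it assumes the C API keeps the same negative-return contract):

#include <string>
#include <vector>

#include "llama.h"

// Two-pass tokenization: a negative return value from the C API reports how many
// tokens are needed, so the buffer is resized and the call repeated.
static std::vector<llama_token> tokenize_two_pass(llama_context * ctx, const std::string & text, bool add_bos) {
    std::vector<llama_token> result(text.length() + add_bos);
    int n_tokens = llama_tokenize(ctx, text.c_str(), result.data(), result.size(), add_bos);
    if (n_tokens < 0) {
        result.resize(-n_tokens);
        n_tokens = llama_tokenize(ctx, text.c_str(), result.data(), result.size(), add_bos);
    }
    result.resize(n_tokens);
    return result;
}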

9 changes: 0 additions & 9 deletions common/common.h
@@ -120,15 +120,6 @@ std::vector<llama_token> llama_tokenize(
const std::string & text,
bool add_bos);

std::vector<llama_token> llama_tokenize_bpe(
struct llama_context * ctx,
const std::string & text,
bool add_bos);

std::string llama_token_to_str(
const struct llama_context * ctx,
llama_token token);

std::string llama_token_to_str_bpe(
const struct llama_context * ctx,
llama_token token);
55 changes: 25 additions & 30 deletions convert-falcon-hf-to-gguf.py
@@ -94,21 +94,26 @@ def count_model_parts(dir_model: str) -> int:

block_count = hparams["n_layer"]

gguf_writer.add_name(last_dir)
gguf_writer.add_name("Falcon")
gguf_writer.add_context_length(2048) # not in config.json
gguf_writer.add_tensor_data_layout("jploski") # qkv tensor transform
gguf_writer.add_embedding_length(hparams["hidden_size"])
gguf_writer.add_feed_forward_length(4 * hparams["hidden_size"])
gguf_writer.add_block_count(block_count)
gguf_writer.add_head_count(hparams["n_head"])
if "n_head_kv" in hparams: gguf_writer.add_head_count_kv(hparams["n_head_kv"])
if "n_head_kv" in hparams:
gguf_writer.add_head_count_kv(hparams["n_head_kv"])
else:
gguf_writer.add_head_count_kv(1)
gguf_writer.add_layer_norm_eps(hparams["layer_norm_epsilon"])

# TOKENIZATION

print("gguf: get tokenizer metadata")

tokens: List[str] = []
scores: List[float] = []
toktypes: List[int] = []
merges: List[str] = []


@@ -152,50 +152,40 @@ def count_model_parts(dir_model: str) -> int:
text = bytearray(pad_token)

tokens.append(text)
scores.append(0.0) # dummy
toktypes.append(gguf.TokenType.NORMAL) # dummy

gguf_writer.add_token_list(tokens)
gguf_writer.add_token_scores(scores)
gguf_writer.add_token_types(toktypes)

if "added_tokens" in tokenizer_json and Path(dir_model + "/tokenizer_config.json").is_file():
print("gguf: get special token ids")

with open(dir_model + "/tokenizer_config.json", "r", encoding="utf-8") as f:
tokenizer_config = json.load(f)
print("gguf: get special token ids")
# Look for special tokens in config.json

# find special token ids
if "bos_token_id" in hparams and hparams["bos_token_id"] != None:
gguf_writer.add_bos_token_id(hparams["bos_token_id"])

if "bos_token" in tokenizer_config:
for key in tokenizer_json["added_tokens"]:
if key["content"] == tokenizer_config["bos_token"]:
gguf_writer.add_bos_token_id(key["id"])
if "eos_token_id" in hparams and hparams["eos_token_id"] != None:
gguf_writer.add_eos_token_id(hparams["eos_token_id"])

if "eos_token" in tokenizer_config:
for key in tokenizer_json["added_tokens"]:
if key["content"] == tokenizer_config["eos_token"]:
gguf_writer.add_eos_token_id(key["id"])
if "unk_token_id" in hparams and hparams["unk_token_id"] != None:
gguf_writer.add_unk_token_id(hparams["unk_token_id"])

if "unk_token" in tokenizer_config:
for key in tokenizer_json["added_tokens"]:
if key["content"] == tokenizer_config["unk_token"]:
gguf_writer.add_unk_token_id(key["id"])
if "sep_token_id" in hparams and hparams["sep_token_id"] != None:
gguf_writer.add_sep_token_id(hparams["sep_token_id"])

if "sep_token" in tokenizer_config:
for key in tokenizer_json["added_tokens"]:
if key["content"] == tokenizer_config["sep_token"]:
gguf_writer.add_sep_token_id(key["id"])

if "pad_token" in tokenizer_config:
for key in tokenizer_json["added_tokens"]:
if key["content"] == tokenizer_config["pad_token"]:
gguf_writer.add_pad_token_id(key["id"])
if "pad_token_id" in hparams and hparams["pad_token_id"] != None:
gguf_writer.add_pad_token_id(hparams["pad_token_id"])


# TENSORS

tensor_map = gguf.get_tensor_name_map(ARCH,block_count)

# params for qkv transform
n_head = hparams["n_head"]
n_head = hparams["n_head"]
n_head_kv = hparams["n_head_kv"] if "n_head_kv" in hparams else 1

head_dim = hparams["hidden_size"] // n_head

# tensor info
6 changes: 5 additions & 1 deletion convert.py
@@ -733,7 +733,11 @@ def __init__(self, fname_out: Path) -> None:
self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH])

def add_meta_arch(self, params: Params) -> None:
self.gguf.add_name ("LLaMA")
ver = None
if (params.n_ctx == 4096):
ver = "v2"

self.gguf.add_name ("LLaMA" if ver == None else "LLaMA " + ver)
self.gguf.add_context_length (params.n_ctx)
self.gguf.add_embedding_length (params.n_embd)
self.gguf.add_block_count (params.n_layer)
14 changes: 8 additions & 6 deletions examples/main/main.cpp
@@ -43,7 +43,7 @@ static bool is_interacting = false;
void sigint_handler(int signo) {
if (signo == SIGINT) {
if (!is_interacting) {
is_interacting=true;
is_interacting = true;
} else {
console::cleanup();
printf("\n");
@@ -189,10 +189,12 @@ int main(int argc, char ** argv) {
}
}

const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;

// tokenize the prompt
std::vector<llama_token> embd_inp;
if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
embd_inp = ::llama_tokenize(ctx, params.prompt, true);
embd_inp = ::llama_tokenize(ctx, params.prompt, is_spm);
} else {
embd_inp = session_tokens;
}
@@ -203,9 +205,9 @@ int main(int argc, char ** argv) {
int original_prompt_len = 0;
if (ctx_guidance) {
params.cfg_negative_prompt.insert(0, 1, ' ');
guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, true);
guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, is_spm);

std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true);
std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, is_spm);
original_prompt_len = original_inp.size();
guidance_offset = (int)guidance_inp.size() - original_prompt_len;
}
@@ -252,8 +254,8 @@ int main(int argc, char ** argv) {
}

// prefix & suffix for instruct mode
const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", true);
const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n", false);
const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", is_spm);
const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n", false);

// in instruct mode, we inject a prefix and a suffix to each input by the user
if (params.instruct) {
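
The hard-coded true for BOS insertion is gone: whether a BOS token is prepended now follows the vocab type, so SentencePiece models (LLaMA) get one and Falcon's BPE vocab does not. A minimal helper capturing the convention (an illustration, not code from the PR):

#include <string>
#include <vector>

#include "common.h"
#include "llama.h"

// BOS is prepended only for SentencePiece vocabs; BPE vocabs (Falcon) get none.
static std::vector<llama_token> tokenize_for_vocab(llama_context * ctx, const std::string & text) {
    const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
    return ::llama_tokenize(ctx, text, /*add_bos=*/is_spm);
}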
31 changes: 21 additions & 10 deletions examples/perplexity/perplexity.cpp
@@ -28,7 +28,6 @@ std::vector<float> softmax(const std::vector<float>& logits) {
}

void perplexity_v2(llama_context * ctx, const gpt_params & params) {

// Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
// Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
// Output: `perplexity: 13.5106 [114/114]`
@@ -38,7 +37,13 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) {
fprintf(stderr, "%s: stride is %d but must be greater than zero!\n",__func__,params.ppl_stride);
return;
}
auto tokens = ::llama_tokenize(ctx, params.prompt, true);

const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
const bool add_bos = is_spm;

fprintf(stderr, "%s: tokenizing the input ..\n", __func__);

auto tokens = ::llama_tokenize(ctx, params.prompt, add_bos);

const int calc_chunk = params.n_ctx;

@@ -86,7 +91,7 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) {
const auto token_org = tokens[batch_start];

// add BOS token for the first batch of each chunk
if (j == 0) {
if (add_bos && j == 0) {
tokens[batch_start] = llama_token_bos(ctx);
}

@@ -136,7 +141,6 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) {
}

void perplexity(llama_context * ctx, const gpt_params & params) {

if (params.ppl_stride > 0) {
perplexity_v2(ctx, params);
return;
@@ -146,7 +150,13 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
// Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
// Output: `perplexity: 13.5106 [114/114]`
// BOS tokens will be added for each chunk before eval
auto tokens = ::llama_tokenize(ctx, params.prompt, true);

const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
const bool add_bos = is_spm;

fprintf(stderr, "%s: tokenizing the input ..\n", __func__);

auto tokens = ::llama_tokenize(ctx, params.prompt, add_bos);

const int n_chunk_max = tokens.size() / params.n_ctx;

@@ -177,7 +187,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
const auto token_org = tokens[batch_start];

// add BOS token for the first batch of each chunk
if (j == 0) {
if (add_bos && j == 0) {
tokens[batch_start] = llama_token_bos(ctx);
}

@@ -295,8 +305,10 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
size_t hs_task_count = prompt_lines.size()/6;
fprintf(stderr, "%s : loaded %zu tasks from prompt.\n", __func__, hs_task_count);

const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;

// This is needed as usual for LLaMA models
bool prepend_bos = true;
const bool add_bos = is_spm;

// Number of tasks to use when computing the score
if ( params.hellaswag_tasks < hs_task_count ) {
@@ -352,14 +364,13 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
std::vector<float> tok_logits(n_vocab);

for (size_t task_idx = 0; task_idx < hs_task_count; task_idx++) {

// Tokenize the context to count tokens
std::vector<int> context_embd = ::llama_tokenize(ctx, hs_data[task_idx].context, prepend_bos);
std::vector<int> context_embd = ::llama_tokenize(ctx, hs_data[task_idx].context, add_bos);
size_t context_size = context_embd.size();

// Do the 1st ending
// In this case we include the context when evaluating
auto query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[0], prepend_bos);
auto query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[0], add_bos);
auto query_size = query_embd.size();
//printf("First query: %d\n",(int)query_size);

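
As in main.cpp, BOS handling in the perplexity and HellaSwag paths is gated on the vocab type, and the first token of each chunk is temporarily swapped for BOS before evaluation. A simplified sketch of the chunk loop (the restore of the overwritten token is assumed from the surrounding file and is not visible in this diff; llama_eval signature as of this PR):

#include <algorithm>
#include <cstdio>
#include <vector>

#include "llama.h"

// Evaluate one chunk of tokens in n_batch-sized pieces, optionally forcing a BOS
// at the start of the chunk (SentencePiece vocabs only). Sketch, not PR code.
static bool eval_chunk(llama_context * ctx, std::vector<llama_token> & tokens,
                       int start, int end, int n_batch, int n_threads, bool add_bos) {
    const int num_batches = (end - start + n_batch - 1) / n_batch;

    for (int j = 0; j < num_batches; ++j) {
        const int batch_start = start + j*n_batch;
        const int batch_size  = std::min(end - batch_start, n_batch);

        // remember the original token, then add BOS for the first batch of the chunk
        const llama_token token_org = tokens[batch_start];
        if (add_bos && j == 0) {
            tokens[batch_start] = llama_token_bos(ctx);
        }

        if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j*n_batch, n_threads)) {
            fprintf(stderr, "%s : failed to eval\n", __func__);
            return false;
        }

        // restore the token that BOS overwrote
        tokens[batch_start] = token_org;
    }

    return true;
}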
4 changes: 2 additions & 2 deletions ggml-alloc.c
@@ -238,7 +238,7 @@ static void ggml_allocator_free_tensor(struct ggml_allocr * alloc, struct ggml_t
alloc->n_free_blocks++;
}

void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, int * list, int n) {
void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, const int * list, int n) {
int pos = 0;
for (int i = 0; i < n; i++) {
if (list[i] != -1) {
@@ -547,7 +547,7 @@ static size_t ggml_allocator_alloc_graph_tensors_n(
struct ggml_tensor * view_src = get_view_source(parent);
struct hash_node * view_src_hn = hash_get(ht, view_src);
view_src_hn->n_views -= 1;
AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src->n_children, view_src->n_views);
AT_PRINTF("view_src %s\n", view_src->name);
if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src->data != node->data) {
ggml_allocator_free_tensor(alloc, view_src);
}
2 changes: 1 addition & 1 deletion ggml-alloc.h
@@ -12,7 +12,7 @@ GGML_API struct ggml_allocr * ggml_allocr_new_measure(size_t alignment);

// tell the allocator to parse nodes following the order described in the list
// you should call this if your graph are optimized to execute out-of-order
GGML_API void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, int * list, int n);
GGML_API void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, const int * list, int n);

GGML_API void ggml_allocr_free(struct ggml_allocr * alloc);
GGML_API bool ggml_allocr_is_measure(struct ggml_allocr * alloc);
29 changes: 28 additions & 1 deletion ggml-cuda.cu
@@ -3907,6 +3907,29 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
dst[i + 1] = x0*sin_theta + x1*cos_theta;
}

// TODO: this implementation is wrong!
//static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const float p0,
// const float p_delta, const int p_delta_rows, const float theta_scale) {
// const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
//
// if (col >= ncols) {
// return;
// }
//
// const int row = blockDim.x*blockIdx.x + threadIdx.x;
// const int i = row*ncols + col/2;
//
// const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2);
// const float sin_theta = sinf(theta);
// const float cos_theta = cosf(theta);
//
// const float x0 = x[i + 0];
// const float x1 = x[i + ncols/2];
//
// dst[i + 0] = x0*cos_theta - x1*sin_theta;
// dst[i + ncols/2] = x0*sin_theta + x1*cos_theta;
//}

static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float p, const float block_p, const float theta_scale) {
const int col = blockDim.x*blockIdx.x + threadIdx.x;
const int half_n_dims = ncols/4;
@@ -5515,14 +5538,18 @@ inline void ggml_cuda_op_rope(

const float theta_scale = powf(freq_base, -2.0f/n_dims);

const bool is_glm = mode & 4;
const bool is_neox = mode & 2;
const bool is_glm = mode & 4;

// compute
if (is_glm) {
const float p = (((mode & 1) == 0 ? n_past + i02 : i02)) * freq_scale;
const float id_p = min(p, n_ctx - 2.f);
const float block_p = max(p - (n_ctx - 2.f), 0.f);
rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, id_p, block_p, theta_scale, cudaStream_main);
} else if (is_neox) {
GGML_ASSERT(false && "RoPE NeoX not implemented yet");
#pragma message("TODO: implement RoPE NeoX for CUDA")
} else {
const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main);
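
The commented-out rope_neox_f32 kernel above is flagged as wrong and the NeoX branch currently asserts. For reference, the NeoX-style rotation ("mode = 2", the mode the Metal backend implements in this PR) pairs element j of the first half of a row with element j + ncols/2, instead of the adjacent pairs used by the original kernel. A single-row CPU reference sketch, assuming the full width is rotated (n_dims == ncols); this is not the missing CUDA kernel:

#include <cmath>

// NeoX-style RoPE on one row: element j pairs with element j + ncols/2 and each
// pair is rotated by theta = p * theta_scale^j. Reference sketch only.
static void rope_neox_row_ref(const float * x, float * dst, int ncols, float p, float theta_scale) {
    for (int j = 0; j < ncols/2; ++j) {
        const float theta     = p * powf(theta_scale, (float) j);
        const float cos_theta = cosf(theta);
        const float sin_theta = sinf(theta);

        const float x0 = x[j];
        const float x1 = x[j + ncols/2];

        dst[j]           = x0*cos_theta - x1*sin_theta;
        dst[j + ncols/2] = x0*sin_theta + x1*cos_theta;
    }
}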