
Commit

Merge branch 'master' into delete-baichuan
Galunid committed Nov 18, 2023
2 parents 246a5a6 + 2923f17 commit e7f29df
Showing 31 changed files with 507 additions and 144 deletions.
8 changes: 6 additions & 2 deletions CMakeLists.txt
@@ -574,8 +574,12 @@ elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$" OR "${CMAKE_GE
     endif()
 elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
     message(STATUS "PowerPC detected")
-    add_compile_options(-mcpu=native -mtune=native)
-    #TODO: Add targets for Power8/Power9 (Altivec/VSX) and Power10(MMA) and query for big endian systems (ppc64/le/be)
+    if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le")
+        add_compile_options(-mcpu=powerpc64le)
+    else()
+        add_compile_options(-mcpu=native -mtune=native)
+        #TODO: Add targets for Power8/Power9 (Altivec/VSX) and Power10(MMA) and query for big endian systems (ppc64/le/be)
+    endif()
 else()
     message(STATUS "Unknown architecture")
 endif()
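Note: on little-endian PowerPC the build now pins the generic -mcpu=powerpc64le baseline instead of probing the build host with -mcpu=native; presumably this keeps compilation working where native CPU detection is unreliable (for example cross-builds or CI), at the cost of host-specific tuning.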
8 changes: 8 additions & 0 deletions Makefile
@@ -342,6 +342,12 @@ ifneq ($(filter ppc64%,$(UNAME_M)),)
     endif
 endif

+ifneq ($(filter ppc64le%,$(UNAME_M)),)
+    MK_CFLAGS += -mcpu=powerpc64le
+    MK_CXXFLAGS += -mcpu=powerpc64le
+    CUDA_POWER_ARCH = 1
+endif
+
 else
     MK_CFLAGS += -march=rv64gcv -mabi=lp64d
     MK_CXXFLAGS += -march=rv64gcv -mabi=lp64d
@@ -392,6 +398,8 @@ else
 endif #LLAMA_CUDA_NVCC
 ifdef CUDA_DOCKER_ARCH
     NVCCFLAGS += -Wno-deprecated-gpu-targets -arch=$(CUDA_DOCKER_ARCH)
+else ifdef CUDA_POWER_ARCH
+    NVCCFLAGS +=
 else
     NVCCFLAGS += -arch=native
 endif # CUDA_DOCKER_ARCH
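Note: when CUDA_POWER_ARCH is set by the new ppc64le branch above, NVCCFLAGS deliberately receives no -arch flag at all, leaving nvcc on its default target; presumably -arch=native is not usable on POWER hosts.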
7 changes: 7 additions & 0 deletions common/common.cpp
@@ -1072,6 +1072,12 @@ std::string llama_detokenize_bpe(llama_context * ctx, const std::vector<llama_to
     return result;
 }

+bool llama_should_add_bos_token(const llama_model * model) {
+    const int add_bos = llama_add_bos_token(model);
+
+    return add_bos != -1 ? bool(add_bos) : (llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM);
+}
+
 //
 // YAML utils
 //
@@ -1188,6 +1194,7 @@ void dump_string_yaml_multiline(FILE * stream, const char * prop_name, const cha
     if (!data_str.empty() && (std::isspace(data_str[0]) || std::isspace(data_str.back()))) {
         data_str = std::regex_replace(data_str, std::regex("\n"), "\\n");
         data_str = std::regex_replace(data_str, std::regex("\""), "\\\"");
+        data_str = std::regex_replace(data_str, std::regex(R"(\\[^n"])"), R"(\$&)");
         data_str = "\"" + data_str + "\"";
         fprintf(stream, "%s: %s\n", prop_name, data_str.c_str());
         return;
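Note: the added replacement doubles any backslash that is not already part of an escape produced by the two lines above it. As a worked example (illustrative input): a raw value containing d\e is emitted as d\\e, while the \n and \" sequences generated by the first two replacements are left untouched, because the character class [^n"] excludes exactly those two follow-up characters.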
4 changes: 4 additions & 0 deletions common/common.h
@@ -200,6 +200,10 @@ std::string llama_detokenize_bpe(
         llama_context * ctx,
         const std::vector<llama_token> & tokens);

+// Uses the value from the model metadata if possible, otherwise
+// defaults to true when model type is SPM, otherwise false.
+bool llama_should_add_bos_token(const llama_model * model);
+
 //
 // YAML utils
 //
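For callers, the intended pattern is the one this commit applies in main.cpp, infill.cpp, and server.cpp; a minimal sketch (assuming an initialized model, context, and prompt string):

    const bool add_bos = llama_should_add_bos_token(model);
    const std::vector<llama_token> tokens = ::llama_tokenize(ctx, prompt, add_bos);

llama_add_bos_token(model) returns 0 or 1 when the GGUF metadata specifies the flag and -1 when it does not, so the SPM check only serves as the fallback.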
12 changes: 12 additions & 0 deletions common/train.cpp
@@ -1136,6 +1136,7 @@ void print_common_train_usage(int /*argc*/, char ** /*argv*/, const struct train
     fprintf(stderr, "  --adam-beta2 N             AdamW beta2 in interval [0,1). How much to smooth the second moment of gradients. (default %f)\n", params->adam_beta2);
     fprintf(stderr, "  --adam-gclip N             AdamW gradient clipping. Disabled when zero. (default %f)\n", params->adam_gclip);
     fprintf(stderr, "  --adam-epsf N              AdamW epsilon for convergence test. Disabled when <= zero. (default %f)\n", params->adam_eps_f);
+    fprintf(stderr, "  -ngl N, --n-gpu-layers N   Number of model layers to offload to GPU (default %d)", params->n_gpu_layers);
     fprintf(stderr, "\n");
 }

@@ -1355,6 +1356,17 @@ bool consume_common_train_arg(
             return true;
         }
         params->adam_gclip = std::stof(argv[i]);
+    } else if (arg == "-ngl" || arg == "--n-gpu-layers") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
+        params->n_gpu_layers = std::stoi(argv[i]);
+#else
+        fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
+        fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
+#endif
     } else if (arg == "-h" || arg == "--help") {
         params->print_usage = true;
         return true;
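Note: -ngl/--n-gpu-layers now lives in the shared consume_common_train_arg parser, so every training example built on it gains the option; the equivalent handler that finetune.cpp carried privately is removed further down in this commit.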
2 changes: 1 addition & 1 deletion convert-hf-to-gguf.py
@@ -193,7 +193,7 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH:
             return gguf.MODEL_ARCH.MPT
         if arch in ("BaichuanForCausalLM", "BaiChuanForCausalLM"):
             return gguf.MODEL_ARCH.BAICHUAN
-        if arch == "FalconForCausalLM":
+        if arch in ("FalconForCausalLM", "RWForCausalLM"):
             return gguf.MODEL_ARCH.FALCON
         if arch == "GPTBigCodeForCausalLM":
             return gguf.MODEL_ARCH.STARCODER
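Note: RWForCausalLM is the architecture name used by the original tiiuae Falcon ("RefinedWeb") checkpoints that shipped with custom modeling code, so accepting it here lets those older repositories convert through the same Falcon path.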
1 change: 0 additions & 1 deletion convert-llama-ggml-to-gguf.py
@@ -2,7 +2,6 @@
 from __future__ import annotations

 import argparse
-import math
 import struct
 import sys
 from enum import IntEnum
1 change: 1 addition & 0 deletions convert.py
@@ -690,6 +690,7 @@ def lazy_load_torch_file(outer_fp: IO[bytes], path: Path) -> ModelPlus:
                              data_base_path=pickle_paths[0][:-4],
                              zip_file=zf)
     model = unpickler.load()
+    if 'model' in model: model = model['model']
     as_dict = dict(model.items())
     return ModelPlus(model=as_dict, paths=[path], format='torch', vocab=None)

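Note: some PyTorch checkpoints nest the actual state dict one level down, typically because a training framework saved {"model": state_dict, ...}; the added line unwraps that layer before the tensors are enumerated.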
1 change: 1 addition & 0 deletions examples/CMakeLists.txt
@@ -24,6 +24,7 @@ else()
     add_subdirectory(llama-bench)
     add_subdirectory(llava)
     add_subdirectory(main)
+    add_subdirectory(tokenize)
     add_subdirectory(parallel)
     add_subdirectory(perplexity)
     add_subdirectory(quantize)
2 changes: 0 additions & 2 deletions examples/finetune/convert-finetune-checkpoint-to-gguf.py
@@ -3,9 +3,7 @@

 import argparse
 import gguf
-import os
 import struct
-import sys
 import numpy as np
 from pathlib import Path

35 changes: 12 additions & 23 deletions examples/finetune/finetune.cpp
@@ -548,35 +548,35 @@ static void randomize_lora(struct my_llama_lora * lora, int seed, float mean, fl
     struct random_normal_distribution * rnd = init_random_normal_distribution(seed, mean, std, min, max);

     randomize_tensor_normal(lora->tok_embeddings_a, rnd);
-    randomize_tensor_normal(lora->tok_embeddings_b, rnd);
+    ggml_set_zero(lora->tok_embeddings_b);
     randomize_tensor_normal(lora->norm_a, rnd);
-    randomize_tensor_normal(lora->norm_b, rnd);
+    ggml_set_zero(lora->norm_b);
     randomize_tensor_normal(lora->output_a, rnd);
-    randomize_tensor_normal(lora->output_b, rnd);
+    ggml_set_zero(lora->output_b);

     for (uint32_t i = 0; i < n_layer; ++i) {
         auto & layer = lora->layers[i];
         randomize_tensor_normal(layer.attention_norm_a, rnd);
-        randomize_tensor_normal(layer.attention_norm_b, rnd);
+        ggml_set_zero(layer.attention_norm_b);

         randomize_tensor_normal(layer.wq_a, rnd);
-        randomize_tensor_normal(layer.wq_b, rnd);
+        ggml_set_zero(layer.wq_b);
         randomize_tensor_normal(layer.wk_a, rnd);
-        randomize_tensor_normal(layer.wk_b, rnd);
+        ggml_set_zero(layer.wk_b);
         randomize_tensor_normal(layer.wv_a, rnd);
-        randomize_tensor_normal(layer.wv_b, rnd);
+        ggml_set_zero(layer.wv_b);
         randomize_tensor_normal(layer.wo_a, rnd);
-        randomize_tensor_normal(layer.wo_b, rnd);
+        ggml_set_zero(layer.wo_b);

         randomize_tensor_normal(layer.ffn_norm_a, rnd);
-        randomize_tensor_normal(layer.ffn_norm_b, rnd);
+        ggml_set_zero(layer.ffn_norm_b);

         randomize_tensor_normal(layer.w1_a, rnd);
-        randomize_tensor_normal(layer.w1_b, rnd);
+        ggml_set_zero(layer.w1_b);
         randomize_tensor_normal(layer.w2_a, rnd);
-        randomize_tensor_normal(layer.w2_b, rnd);
+        ggml_set_zero(layer.w2_b);
         randomize_tensor_normal(layer.w3_a, rnd);
-        randomize_tensor_normal(layer.w3_b, rnd);
+        ggml_set_zero(layer.w3_b);
     }

     free_random_normal_distribution(rnd);
@@ -1460,17 +1460,6 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par
         }
         params->n_rank_w3 = std::stoi(argv[i]);
         params->custom_n_rank_w3 = true;
-    } else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") {
-        if (++i >= argc) {
-            invalid_param = true;
-            break;
-        }
-#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
-        params->common.n_gpu_layers = std::stoi(argv[i]);
-#else
-        fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
-        fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
-#endif
     } else {
         fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
         train_print_usage(argc, argv, &default_params);
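Note: zeroing the B half of every LoRA pair while keeping the random initialization of the A half makes each low-rank product start out as zero, the standard LoRA initialization: the adapters initially contribute nothing, so fine-tuning starts exactly from the base model's behavior.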
2 changes: 1 addition & 1 deletion examples/infill/infill.cpp
@@ -230,7 +230,7 @@ int main(int argc, char ** argv) {
         LOG_TEE("\n");
         LOG_TEE("%s\n", get_system_info(params).c_str());
     }
-    const bool add_bos = llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM;
+    const bool add_bos = llama_should_add_bos_token(model);
     LOG("add_bos: %d\n", add_bos);

     bool suff_rm_leading_spc = params.escape;
3 changes: 2 additions & 1 deletion examples/llava/llava-cli.cpp
@@ -208,9 +208,10 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
     int n_past = 0;

     const int max_tgt_len = params->n_predict < 0 ? 256 : params->n_predict;
+    const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx_llava->ctx_llama));

     // llava chat format is "<system_prompt>\nUSER:<image_embeddings>\n<textual_prompt>\nASSISTANT:"
-    eval_string(ctx_llava->ctx_llama, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER:", params->n_batch, &n_past, true);
+    eval_string(ctx_llava->ctx_llama, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER:", params->n_batch, &n_past, add_bos);
     llava_eval_image_embed(ctx_llava->ctx_llama, image_embed, params->n_batch, &n_past);
     eval_string(ctx_llava->ctx_llama, (prompt + "\nASSISTANT:").c_str(), params->n_batch, &n_past, false);

9 changes: 8 additions & 1 deletion examples/llava/llava.cpp
@@ -127,7 +127,14 @@ static bool load_file_to_bytes(const char* path, unsigned char** bytesOut, long
         fclose(file);
         return false;
     }
-    fread(buffer, 1, fileSize, file); // Read the file into the buffer
+    errno = 0;
+    size_t ret = fread(buffer, 1, fileSize, file); // Read the file into the buffer
+    if (ferror(file)) {
+        die_fmt("read error: %s", strerror(errno));
+    }
+    if (ret != (size_t) fileSize) {
+        die("unexpectedly reached end of file");
+    }
     fclose(file); // Close the file

     *bytesOut = buffer;
2 changes: 1 addition & 1 deletion examples/main/main.cpp
@@ -229,7 +229,7 @@ int main(int argc, char ** argv) {
         }
     }

-    const bool add_bos = llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM;
+    const bool add_bos = llama_should_add_bos_token(model);
     LOG("add_bos: %d\n", add_bos);

     std::vector<llama_token> embd_inp;
8 changes: 3 additions & 5 deletions examples/perplexity/perplexity.cpp
@@ -149,8 +149,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
     // Output: `perplexity: 13.5106 [114/114]`
     // BOS tokens will be added for each chunk before eval

-    const bool is_spm = llama_vocab_type(llama_get_model(ctx)) == LLAMA_VOCAB_TYPE_SPM;
-    const bool add_bos = is_spm;
+    const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));

     fprintf(stderr, "%s: tokenizing the input ..\n", __func__);

@@ -288,8 +287,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     // Output: `perplexity: 13.5106 [114/114]`
     // BOS tokens will be added for each chunk before eval

-    const bool is_spm = llama_vocab_type(llama_get_model(ctx)) == LLAMA_VOCAB_TYPE_SPM;
-    const bool add_bos = is_spm;
+    const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
     const int n_ctx = llama_n_ctx(ctx);

     auto tim1 = std::chrono::high_resolution_clock::now();
@@ -481,7 +479,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     fprintf(stderr, "================================= is_spm = %d\n", is_spm);

     // This is needed as usual for LLaMA models
-    const bool add_bos = is_spm;
+    const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));

     // Number of tasks to use when computing the score
     if ( params.hellaswag_tasks < hs_task_count  ) {
9 changes: 6 additions & 3 deletions examples/server/server.cpp
@@ -501,6 +501,7 @@ struct llama_server_context
     bool multimodal = false;
     bool clean_kv_cache = true;
     bool all_slots_are_idle = false;
+    bool add_bos_token = true;

     int32_t id_gen;
     int32_t n_ctx; // total context for all clients / slots
@@ -573,6 +574,8 @@

         n_ctx = llama_n_ctx(ctx);

+        add_bos_token = llama_should_add_bos_token(model);
+
         return true;
     }

@@ -864,7 +867,7 @@
     }

     void update_system_prompt() {
-        system_tokens = ::llama_tokenize(ctx, system_prompt, true);
+        system_tokens = ::llama_tokenize(ctx, system_prompt, add_bos_token);

         llama_batch_clear(batch);

@@ -1552,7 +1555,7 @@
             }
             else
             {
-                prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
+                prompt_tokens = tokenize(slot.prompt, system_prompt.empty() && add_bos_token); // add BOS if there isn't system prompt
             }

             slot.num_prompt_tokens = prompt_tokens.size();
@@ -1629,7 +1632,7 @@
             const bool has_images = process_images(slot);

             // process the prefix of first image
-            std::vector<llama_token> prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, true) : prompt_tokens;
+            std::vector<llama_token> prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, add_bos_token) : prompt_tokens;
             for (; slot.n_past < (int) prefix_tokens.size(); ++slot.n_past)
             {
                 llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot.n_past, { slot.id }, false);
5 changes: 5 additions & 0 deletions examples/tokenize/CMakeLists.txt
@@ -0,0 +1,5 @@
+set(TARGET tokenize)
+add_executable(${TARGET} tokenize.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
44 changes: 44 additions & 0 deletions examples/tokenize/tokenize.cpp
@@ -0,0 +1,44 @@
#include "common.h"
#include "llama.h"

#include <cmath>
#include <cstdio>
#include <string>
#include <vector>

int main(int argc, char ** argv) {
if (argc < 3 || argv[1][0] == '-') {
printf("usage: %s MODEL_PATH PROMPT [--ids]\n" , argv[0]);
return 1;
}

const char * model_path = argv[1];
const char * prompt = argv[2];

const bool printing_ids = argc > 3 && std::string(argv[3]) == "--ids";

llama_backend_init(false);

llama_model_params model_params = llama_model_default_params();
model_params.vocab_only = true;
llama_model * model = llama_load_model_from_file(model_path, model_params);

llama_context_params ctx_params = llama_context_default_params();
llama_context * ctx = llama_new_context_with_model(model, ctx_params);

const bool add_bos = true;

std::vector<llama_token> tokens;

tokens = ::llama_tokenize(model, prompt, add_bos, true);

for (int i = 0; i < (int) tokens.size(); i++) {
if (printing_ids) {
printf("%d\n", tokens[i]);
} else {
printf("%6d -> '%s'\n", tokens[i], llama_token_to_piece(ctx, tokens[i]).c_str());
}
}

return 0;
}
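For reference, a hypothetical invocation of the new tool (the model path is invented for illustration):

    ./tokenize ./models/model.gguf "Hello world"

prints one "id -> 'piece'" line per token, while appending --ids prints bare token ids, one per line. Because the model is loaded with vocab_only = true, only the vocabulary and metadata are read and no tensor data has to fit in memory.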