Skip to content
12 changes: 12 additions & 0 deletions common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1232,6 +1232,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params, const std::string & value) {
const auto sampler_names = string_split<std::string>(value, ';');
params.sampling.samplers = common_sampler_types_from_names(sampler_names, true);
params.sampling.sampling_mask |= common_params_sampling::SAMPLING_MASK_BITS_SAMPLERS;
}
).set_sparam());
add_opt(common_arg(
Expand Down Expand Up @@ -1261,27 +1262,31 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params, const std::string & value) {
params.sampling.temp = std::stof(value);
params.sampling.temp = std::max(params.sampling.temp, 0.0f);
params.sampling.sampling_mask |= common_params_sampling::SAMPLING_MASK_BITS_TEMP;
}
).set_sparam());
// --top-k: store the value and flag it as user-specified so that sampler
// defaults read from the model metadata do not replace it
add_opt(common_arg(
    {"--top-k"}, "N",
    string_format("top-k sampling (default: %d, 0 = disabled)", params.sampling.top_k),
    [](common_params & params, int value) {
        auto & sampling = params.sampling;
        sampling.top_k = value;
        sampling.sampling_mask |= common_params_sampling::SAMPLING_MASK_BITS_TOP_K;
    }
).set_sparam());
// --top-p: parsed with std::stof, then flagged as user-specified
add_opt(common_arg(
    {"--top-p"}, "N",
    string_format("top-p sampling (default: %.1f, 1.0 = disabled)", (double)params.sampling.top_p),
    [](common_params & params, const std::string & value) {
        auto & sampling = params.sampling;
        sampling.top_p = std::stof(value);
        sampling.sampling_mask |= common_params_sampling::SAMPLING_MASK_BITS_TOP_P;
    }
).set_sparam());
// --min-p: parsed with std::stof, then flagged as user-specified
add_opt(common_arg(
    {"--min-p"}, "N",
    string_format("min-p sampling (default: %.1f, 0.0 = disabled)", (double)params.sampling.min_p),
    [](common_params & params, const std::string & value) {
        auto & sampling = params.sampling;
        sampling.min_p = std::stof(value);
        sampling.sampling_mask |= common_params_sampling::SAMPLING_MASK_BITS_MIN_P;
    }
).set_sparam());
add_opt(common_arg(
Expand All @@ -1296,13 +1301,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sampling.xtc_probability),
[](common_params & params, const std::string & value) {
params.sampling.xtc_probability = std::stof(value);
params.sampling.sampling_mask |= common_params_sampling::SAMPLING_MASK_BITS_XTC_PROBABILITY;
}
).set_sparam());
// --xtc-threshold: parsed with std::stof, then flagged as user-specified so
// that sampler defaults read from the model metadata do not replace it
add_opt(common_arg(
    {"--xtc-threshold"}, "N",
    string_format("xtc threshold (default: %.1f, 1.0 = disabled)", (double)params.sampling.xtc_threshold),
    [](common_params & params, const std::string & value) {
        auto & sampling = params.sampling;
        sampling.xtc_threshold = std::stof(value);
        sampling.sampling_mask |= common_params_sampling::SAMPLING_MASK_BITS_XTC_THRESHOLD;
    }
).set_sparam());
add_opt(common_arg(
Expand All @@ -1321,13 +1328,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
params.sampling.penalty_last_n = value;
params.sampling.n_prev = std::max(params.sampling.n_prev, params.sampling.penalty_last_n);
params.sampling.sampling_mask |= common_params_sampling::SAMPLING_MASK_BITS_PENALTY_LAST_N;
}
).set_sparam());
// --repeat-penalty: parsed with std::stof, then flagged as user-specified so
// that sampler defaults read from the model metadata do not replace it
add_opt(common_arg(
    {"--repeat-penalty"}, "N",
    string_format("penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)params.sampling.penalty_repeat),
    [](common_params & params, const std::string & value) {
        auto & sampling = params.sampling;
        sampling.penalty_repeat = std::stof(value);
        sampling.sampling_mask |= common_params_sampling::SAMPLING_MASK_BITS_PENALTY_REPEAT;
    }
).set_sparam());
add_opt(common_arg(
Expand Down Expand Up @@ -1425,20 +1434,23 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
"(default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)", params.sampling.mirostat),
[](common_params & params, int value) {
params.sampling.mirostat = value;
params.sampling.sampling_mask |= common_params_sampling::SAMPLING_MASK_BITS_MIROSTAT;
}
).set_sparam());
// --mirostat-lr: Mirostat eta, parsed with std::stof and flagged as user-specified
add_opt(common_arg(
    {"--mirostat-lr"}, "N",
    string_format("Mirostat learning rate, parameter eta (default: %.1f)", (double)params.sampling.mirostat_eta),
    [](common_params & params, const std::string & value) {
        auto & sampling = params.sampling;
        sampling.mirostat_eta = std::stof(value);
        sampling.sampling_mask |= common_params_sampling::SAMPLING_MASK_BITS_MIROSTAT_ETA;
    }
).set_sparam());
// --mirostat-ent: Mirostat tau, parsed with std::stof and flagged as user-specified
add_opt(common_arg(
    {"--mirostat-ent"}, "N",
    string_format("Mirostat target entropy, parameter tau (default: %.1f)", (double)params.sampling.mirostat_tau),
    [](common_params & params, const std::string & value) {
        auto & sampling = params.sampling;
        sampling.mirostat_tau = std::stof(value);
        sampling.sampling_mask |= common_params_sampling::SAMPLING_MASK_BITS_MIROSTAT_TAU;
    }
).set_sparam());
add_opt(common_arg(
Expand Down
55 changes: 55 additions & 0 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "common.h"
#include "log.h"
#include "llama.h"
#include "sampling.h"

#include <algorithm>
#include <cinttypes>
Expand Down Expand Up @@ -946,6 +947,58 @@ std::vector<common_file_info> fs_list_files(const std::string & path) {
// Model utils
//

static inline void common_init_sampler_from_model(
const llama_model * model,
common_params_sampling & sparams) {

const uint16_t mask = sparams.sampling_mask;

auto get_int32 = [&](const char * key, int32_t & dst, uint16_t user_override) {
if (mask & user_override) return;

char buf[64] = {0};
if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
char * end = nullptr;
int32_t v = strtol(buf, &end, 10);
if (end && end != buf) dst = v;
}
};

auto get_float = [&](const char * key, float & dst, uint16_t user_override) {
if (mask & user_override) return;

char buf[128] = {0};
if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
char * end = nullptr;
float v = strtof(buf, &end);
if (end && end != buf) dst = v;
}
};

// Sampler sequence
if (!(mask & common_params_sampling::SAMPLING_MASK_BITS_SAMPLERS)) {
char buf[512] = {0};
if (llama_model_meta_val_str(model, "general.sampler.sequence", buf, sizeof(buf)) > 0) {
const std::vector<std::string> sampler_names = string_split<std::string>(std::string(buf), ';');
if (!sampler_names.empty()) {
sparams.samplers = common_sampler_types_from_names(sampler_names, true);
}
}
}

get_int32("general.sampler.top_k", sparams.top_k, common_params_sampling::SAMPLING_MASK_BITS_TOP_K);
get_float("general.sampler.top_p", sparams.top_p, common_params_sampling::SAMPLING_MASK_BITS_TOP_P);
get_float("general.sampler.min_p", sparams.min_p, common_params_sampling::SAMPLING_MASK_BITS_MIN_P);
get_float("general.sampler.xtc_probability", sparams.xtc_probability, common_params_sampling::SAMPLING_MASK_BITS_XTC_PROBABILITY);
get_float("general.sampler.xtc_threshold", sparams.xtc_threshold, common_params_sampling::SAMPLING_MASK_BITS_XTC_THRESHOLD);
get_float("general.sampler.temp", sparams.temp, common_params_sampling::SAMPLING_MASK_BITS_TEMP);
get_int32("general.sampler.penalty_last_n", sparams.penalty_last_n, common_params_sampling::SAMPLING_MASK_BITS_PENALTY_LAST_N);
get_float("general.sampler.penalty_repeat", sparams.penalty_repeat, common_params_sampling::SAMPLING_MASK_BITS_PENALTY_REPEAT);
get_int32("general.sampler.mirostat", sparams.mirostat, common_params_sampling::SAMPLING_MASK_BITS_MIROSTAT);
get_float("general.sampler.mirostat_tau", sparams.mirostat_tau, common_params_sampling::SAMPLING_MASK_BITS_MIROSTAT_TAU);
get_float("general.sampler.mirostat_eta", sparams.mirostat_eta, common_params_sampling::SAMPLING_MASK_BITS_MIROSTAT_ETA);
}

struct common_init_result common_init_from_params(common_params & params) {
common_init_result iparams;
auto mparams = common_model_params_to_llama(params);
Expand All @@ -957,6 +1010,8 @@ struct common_init_result common_init_from_params(common_params & params) {
return iparams;
}

common_init_sampler_from_model(model, params.sampling);

const llama_vocab * vocab = llama_model_get_vocab(model);

auto cparams = common_context_params_to_llama(params);
Expand Down
16 changes: 16 additions & 0 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,22 @@ struct common_params_sampling {
bool no_perf = false; // disable performance metrics
bool timing_per_token = false;

// one flag per sampling parameter; a set flag means the parameter was specified by the user
enum sampling_mask_bits : uint16_t {
    SAMPLING_MASK_BITS_SAMPLERS        = 0x0001,
    SAMPLING_MASK_BITS_TOP_K           = 0x0002,
    SAMPLING_MASK_BITS_TOP_P           = 0x0004,
    SAMPLING_MASK_BITS_MIN_P           = 0x0008,
    SAMPLING_MASK_BITS_XTC_PROBABILITY = 0x0010,
    SAMPLING_MASK_BITS_XTC_THRESHOLD   = 0x0020,
    SAMPLING_MASK_BITS_TEMP            = 0x0040,
    SAMPLING_MASK_BITS_PENALTY_LAST_N  = 0x0080,
    SAMPLING_MASK_BITS_PENALTY_REPEAT  = 0x0100,
    SAMPLING_MASK_BITS_MIROSTAT        = 0x0200,
    SAMPLING_MASK_BITS_MIROSTAT_TAU    = 0x0400,
    SAMPLING_MASK_BITS_MIROSTAT_ETA    = 0x0800,
};

uint16_t sampling_mask = 0; // OR-combination of sampling_mask_bits (0 = nothing user-specified)

std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY


Expand Down
14 changes: 14 additions & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,20 @@ class General:
ALIGNMENT = "general.alignment"
FILE_TYPE = "general.file_type"

# Recommended sampler parameters: optional sampling defaults that a model file
# can carry under the "general.sampler.*" namespace.
# SAMPLER_SEQUENCE holds the sampler names as a single ';'-separated string.
SAMPLER_SEQUENCE = "general.sampler.sequence"
SAMPLER_TOP_K = "general.sampler.top_k"
SAMPLER_TOP_P = "general.sampler.top_p"
SAMPLER_MIN_P = "general.sampler.min_p"
SAMPLER_XTC_PROBABILITY = "general.sampler.xtc_probability"
SAMPLER_XTC_THRESHOLD = "general.sampler.xtc_threshold"
SAMPLER_TEMP = "general.sampler.temp"
SAMPLER_PENALTY_LAST_N = "general.sampler.penalty_last_n"
SAMPLER_PENALTY_REPEAT = "general.sampler.penalty_repeat"
SAMPLER_MIROSTAT = "general.sampler.mirostat"
SAMPLER_MIROSTAT_TAU = "general.sampler.mirostat_tau"
SAMPLER_MIROSTAT_ETA = "general.sampler.mirostat_eta"

# Authorship Metadata
NAME = "general.name"
AUTHOR = "general.author"
Expand Down
36 changes: 36 additions & 0 deletions gguf-py/gguf/gguf_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,42 @@ def add_custom_alignment(self, alignment: int) -> None:
def add_file_type(self, ftype: int) -> None:
    """Write ``ftype`` as a uint32 under the ``general.file_type`` key."""
    key = Keys.General.FILE_TYPE
    self.add_uint32(key, ftype)

def add_sampler_sequence(self, sequence: str) -> None:
    """Write the recommended sampler sequence as a string under ``general.sampler.sequence``."""
    key = Keys.General.SAMPLER_SEQUENCE
    self.add_string(key, sequence)

def add_sampler_top_k(self, top_k: int) -> None:
    """Write the recommended top-k value as an int32 under ``general.sampler.top_k``."""
    key = Keys.General.SAMPLER_TOP_K
    self.add_int32(key, top_k)

def add_sampler_top_p(self, top_p: float) -> None:
    """Write the recommended top-p value as a float32 under ``general.sampler.top_p``."""
    key = Keys.General.SAMPLER_TOP_P
    self.add_float32(key, top_p)

def add_sampler_min_p(self, min_p: float) -> None:
    """Write the recommended min-p value as a float32 under ``general.sampler.min_p``."""
    key = Keys.General.SAMPLER_MIN_P
    self.add_float32(key, min_p)

def add_sampler_xtc_probability(self, xtc_probability: float) -> None:
    """Write the recommended XTC probability as a float32 under ``general.sampler.xtc_probability``."""
    key = Keys.General.SAMPLER_XTC_PROBABILITY
    self.add_float32(key, xtc_probability)

def add_sampler_xtc_threshold(self, xtc_threshold: float) -> None:
    """Write the recommended XTC threshold as a float32 under ``general.sampler.xtc_threshold``."""
    key = Keys.General.SAMPLER_XTC_THRESHOLD
    self.add_float32(key, xtc_threshold)

def add_sampler_temp(self, temp: float) -> None:
    """Write the recommended temperature as a float32 under ``general.sampler.temp``."""
    key = Keys.General.SAMPLER_TEMP
    self.add_float32(key, temp)

def add_sampler_penalty_last_n(self, penalty_last_n: int) -> None:
    """Write the recommended penalty window as an int32 under ``general.sampler.penalty_last_n``."""
    key = Keys.General.SAMPLER_PENALTY_LAST_N
    self.add_int32(key, penalty_last_n)

def add_sampler_penalty_repeat(self, penalty_repeat: float) -> None:
    """Write the recommended repeat penalty as a float32 under ``general.sampler.penalty_repeat``."""
    key = Keys.General.SAMPLER_PENALTY_REPEAT
    self.add_float32(key, penalty_repeat)

def add_sampler_mirostat(self, mirostat: int) -> None:
    """Write the recommended Mirostat mode as an int32 under ``general.sampler.mirostat``."""
    key = Keys.General.SAMPLER_MIROSTAT
    self.add_int32(key, mirostat)

def add_sampler_mirostat_tau(self, mirostat_tau: float) -> None:
    """Write the recommended Mirostat tau as a float32 under ``general.sampler.mirostat_tau``."""
    key = Keys.General.SAMPLER_MIROSTAT_TAU
    self.add_float32(key, mirostat_tau)

def add_sampler_mirostat_eta(self, mirostat_eta: float) -> None:
    """Write the recommended Mirostat eta as a float32 under ``general.sampler.mirostat_eta``."""
    key = Keys.General.SAMPLER_MIROSTAT_ETA
    self.add_float32(key, mirostat_eta)

def add_name(self, name: str) -> None:
    """Write ``name`` as a string under the ``general.name`` key."""
    key = Keys.General.NAME
    self.add_string(key, name)

Expand Down
Loading
Loading