Implement classifier-free guidance #2135

Merged: 10 commits, Jul 11, 2023
30 changes: 29 additions & 1 deletion examples/common.cpp
@@ -236,6 +236,24 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break;
}
params.mirostat_tau = std::stof(argv[i]);
} else if (arg == "--cfg-negative-prompt") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.cfg_negative_prompt = argv[i];
} else if (arg == "--cfg-scale") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.cfg_scale = std::stof(argv[i]);
} else if (arg == "--cfg-smooth-factor") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.cfg_smooth_factor = std::stof(argv[i]);
} else if (arg == "-b" || arg == "--batch-size") {
if (++i >= argc) {
invalid_param = true;
@@ -468,6 +486,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " modifies the likelihood of token appearing in the completion,\n");
fprintf(stderr, " i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n");
fprintf(stderr, " or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n");
fprintf(stderr, " --cfg-negative-prompt PROMPT \n");
fprintf(stderr, " negative prompt to use for guidance. (default: empty)\n");
fprintf(stderr, " --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale);
fprintf(stderr, " --cfg-smooth-factor N smooth factor between old and new logits (default: %f, 1.0 = no smoothing)\n", params.cfg_smooth_factor);
fprintf(stderr, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
fprintf(stderr, " --no-penalize-nl do not penalize newline token\n");
@@ -534,7 +556,7 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
return res;
}

std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(const gpt_params & params) {
struct llama_context_params llama_get_context_params_from_gpt_params(const gpt_params & params) {
Owner review comment: I guess llama_context_params_from_gpt_params() should fit better. We tend to use get and set to access properties, while here we construct context_params.

auto lparams = llama_context_default_params();

lparams.n_ctx = params.n_ctx;
@@ -550,6 +572,12 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
lparams.logits_all = params.perplexity;
lparams.embedding = params.embedding;

return lparams;
}

std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(const gpt_params & params) {
auto lparams = llama_get_context_params_from_gpt_params(params);

llama_model * model = llama_load_model_from_file(params.model.c_str(), lparams);
if (model == NULL) {
fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
7 changes: 7 additions & 0 deletions examples/common.h
@@ -48,6 +48,12 @@ struct gpt_params {
float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate

// Classifier-Free Guidance
// https://arxiv.org/abs/2306.17806
std::string cfg_negative_prompt; // negative prompt used for guidance
float cfg_scale = 1.f; // guidance strength (1.0f = disabled)
float cfg_smooth_factor = 1.f; // smoothing factor between old and new logits (1.0f = no smoothing)

std::string model = "models/7B/ggml-model.bin"; // model path
std::string model_alias = "unknown"; // model alias
std::string prompt = "";
@@ -99,6 +105,7 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
//

std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(const gpt_params & params);
struct llama_context_params llama_get_context_params_from_gpt_params(const gpt_params & params);

//
// Console utils
88 changes: 84 additions & 4 deletions examples/main/main.cpp
@@ -109,10 +109,16 @@ int main(int argc, char ** argv) {

llama_model * model;
llama_context * ctx;
llama_context * guidance_ctx = NULL;
Owner review comment: Rename to ctx_guidance.

g_ctx = &ctx;

// load the model and apply lora adapter, if any
std::tie(model, ctx) = llama_init_from_gpt_params(params);
if (params.cfg_scale > 1.f) {
struct llama_context_params lparams = llama_get_context_params_from_gpt_params(params);
guidance_ctx = llama_new_context_with_model(model, lparams);
}

if (model == NULL) {
fprintf(stderr, "%s: error: unable to load model\n", __func__);
return 1;
@@ -183,15 +189,28 @@ int main(int argc, char ** argv) {
// tokenize the prompt
std::vector<llama_token> embd_inp;

if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
// Add a space in front of the first character to match OG llama tokenizer behavior
params.prompt.insert(0, 1, ' ');
// Add a space in front of the first character to match OG llama tokenizer behavior
params.prompt.insert(0, 1, ' ');

if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
embd_inp = ::llama_tokenize(ctx, params.prompt, true);
} else {
embd_inp = session_tokens;
}

// Tokenize negative prompt
std::vector<llama_token> guidance_inp;
int guidance_offset = 0;
int original_prompt_len = 0;
if (guidance_ctx) {
params.cfg_negative_prompt.insert(0, 1, ' ');
guidance_inp = ::llama_tokenize(guidance_ctx, params.cfg_negative_prompt, true);

std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true);
original_prompt_len = original_inp.size();
guidance_offset = (int)guidance_inp.size() - original_prompt_len;
}

const int n_ctx = llama_n_ctx(ctx);

if ((int) embd_inp.size() > n_ctx - 4) {
@@ -258,6 +277,16 @@ int main(int argc, char ** argv) {
for (int i = 0; i < (int) embd_inp.size(); i++) {
fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i]));
}

if (guidance_ctx) {
fprintf(stderr, "\n");
fprintf(stderr, "%s: negative prompt: '%s'\n", __func__, params.cfg_negative_prompt.c_str());
fprintf(stderr, "%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size());
for (int i = 0; i < (int) guidance_inp.size(); i++) {
fprintf(stderr, "%6d -> '%s'\n", guidance_inp[i], llama_token_to_str(ctx, guidance_inp[i]));
}
}

if (params.n_keep > 0) {
fprintf(stderr, "%s: static prompt based on n_keep: '", __func__);
for (int i = 0; i < params.n_keep; i++) {
@@ -334,11 +363,13 @@ int main(int argc, char ** argv) {
int n_remain = params.n_predict;
int n_consumed = 0;
int n_session_consumed = 0;
int guidance_n_past = 0;
Owner review comment, suggested change:
- int guidance_n_past = 0;
+ int n_past_guidance = 0;


// the first thing we will do is to output the prompt, so set color accordingly
console_set_color(con_st, CONSOLE_COLOR_PROMPT);

std::vector<llama_token> embd;
std::vector<llama_token> guidance_embd;
Owner review comment, suggested change:
- std::vector<llama_token> guidance_embd;
+ std::vector<llama_token> embd_guidance;


// do one empty run to warm up the model
{
@@ -367,11 +398,12 @@ int main(int argc, char ** argv) {
// if we run out of context:
// - take the n_keep first tokens from the original prompt (via n_past)
// - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
if (n_past + (int) embd.size() > n_ctx) {
if (n_past + (int) embd.size() + std::max<int>(0, guidance_offset) > n_ctx) {
const int n_left = n_past - params.n_keep;

// always keep the first token - BOS
n_past = std::max(1, params.n_keep);
guidance_n_past = std::max(1, params.n_keep + guidance_offset);

// insert n_left/2 tokens at the start of embd from last_n_tokens
embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_n_tokens.end() - embd.size());
@@ -412,6 +444,48 @@ int main(int argc, char ** argv) {

// evaluate tokens in batches
// embd is typically prepared beforehand to fit within a batch, but not always

if (guidance_ctx) {
int input_size = 0;
llama_token* input_buf = NULL;

if (guidance_n_past < (int) guidance_inp.size()) {
// Guidance context should have the same data with these modifications:
//
// * Replace the initial prompt
// * Shift everything by guidance_offset
guidance_embd = guidance_inp;
if (embd.begin() + original_prompt_len < embd.end()) {
guidance_embd.insert(
guidance_embd.end(),
embd.begin() + original_prompt_len,
embd.end()
);
}

input_buf = guidance_embd.data();
input_size = guidance_embd.size();
//fprintf(stderr, "\n---------------------\n");
//for (int i = 0; i < (int) guidance_embd.size(); i++) {
//fprintf(stderr, "%s", llama_token_to_str(ctx, guidance_embd[i]));
//}
//fprintf(stderr, "\n---------------------\n");
} else {
input_buf = embd.data();
input_size = embd.size();
}

for (int i = 0; i < input_size; i += params.n_batch) {
int n_eval = std::min(input_size - i, params.n_batch);
if (llama_eval(guidance_ctx, input_buf + i, n_eval, guidance_n_past, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return 1;
}

guidance_n_past += n_eval;
}
}

for (int i = 0; i < (int) embd.size(); i += params.n_batch) {
int n_eval = (int) embd.size() - i;
if (n_eval > params.n_batch) {
@@ -431,6 +505,7 @@ int main(int argc, char ** argv) {
}

embd.clear();
guidance_embd.clear();

if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
// out of user input, sample next token
@@ -473,6 +548,10 @@ int main(int argc, char ** argv) {

llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

if (guidance_ctx) {
llama_sample_classifier_free_guidance(ctx, &candidates_p, guidance_ctx, params.cfg_scale, params.cfg_smooth_factor);
}

// Apply penalties
float nl_logit = logits[llama_token_nl()];
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
@@ -668,6 +747,7 @@ int main(int argc, char ** argv) {
}

llama_print_timings(ctx);
if (guidance_ctx) { llama_free(guidance_ctx); }
llama_free(ctx);
llama_free_model(model);

69 changes: 69 additions & 0 deletions llama.cpp
@@ -2141,6 +2141,75 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, l
}
}

template<typename T, typename LogitAccessor>
void llama_log_softmax(T * array, int size, LogitAccessor logit_accessor) {
T* element = std::max_element(
array, array + size,
[&logit_accessor](T& lhs, T& rhs) {
return logit_accessor(lhs) < logit_accessor(rhs);
}
);

float max_l = logit_accessor(*element);
float sum = 0.f;
for (int i = 0; i < size; ++i) {
float& logit = logit_accessor(array[i]);
float p = expf(logit - max_l);
sum += p;
logit = p;
}

for (int i = 0; i < size; ++i) {
float& logit = logit_accessor(array[i]);
logit = logf(logit / sum);
}
}

Owner review comment: Avoid the template. You can copy the logits into a std::vector<float> and use the float * array implementation in both cases.
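A minimal sketch of the non-template variant suggested above (a sketch under assumptions, not the PR's code; it relies on <algorithm>, <cmath> and <vector>, which llama.cpp already includes): the candidate logits are copied into a contiguous std::vector<float> so a single float * implementation serves both the candidate array and the guidance logits.

static void llama_log_softmax(float * array, size_t size) {
    // subtract the max logit for numerical stability
    float max_l = *std::max_element(array, array + size);
    float sum = 0.f;
    for (size_t i = 0; i < size; ++i) {
        float p = expf(array[i] - max_l);
        sum += p;
        array[i] = p;
    }
    for (size_t i = 0; i < size; ++i) {
        array[i] = logf(array[i] / sum);
    }
}

// caller side (illustrative): copy the token-data logits into a float buffer first
std::vector<float> logits_base(candidates->size);
for (size_t i = 0; i < candidates->size; ++i) {
    logits_base[i] = candidates->data[i].logit;
}
llama_log_softmax(logits_base.data(), logits_base.size());
llama_log_softmax(llama_get_logits(guidance_ctx), llama_n_vocab(ctx));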

void llama_sample_classifier_free_guidance(
struct llama_context * ctx,
llama_token_data_array * candidates,
struct llama_context * guidance_ctx,
float scale,
float smooth_factor) {
int64_t t_start_sample_us = ggml_time_us();

assert(ctx);
auto n_vocab = llama_n_vocab(ctx);
assert(n_vocab == (int)candidates->size);
assert(!candidates->sorted);

auto logit_from_token_data = [](llama_token_data& data) -> float& {
return data.logit;
};

auto logit_from_float = [](float& item) -> float& {
return item;
};

llama_log_softmax(candidates->data, candidates->size, logit_from_token_data);

auto* guidance_logits = llama_get_logits(guidance_ctx);
llama_log_softmax(guidance_logits, n_vocab, logit_from_float);

for (int i = 0; i < n_vocab; ++i) {
float guidance_logit = guidance_logits[i];
float base_logit = candidates->data[i].logit;
Owner review comment, suggested change:
- float guidance_logit = guidance_logits[i];
- float base_logit = candidates->data[i].logit;
+ float logit_guidance = guidance_logits[i];
+ float logit_base = candidates->data[i].logit;

guidance_logits[i] = scale * (base_logit - guidance_logit) + guidance_logit;
}

llama_log_softmax(guidance_logits, n_vocab, logit_from_float);

for (int i = 0; i < n_vocab; ++i) {
float base_logit = candidates->data[i].logit;
float guidance_logit = guidance_logits[i];

candidates->data[i].logit = smooth_factor * guidance_logit + (1.f - smooth_factor) * base_logit;
}

if (ctx) {
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
}
}
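In equation form, with log-softmaxed base logits $l_b$, log-softmaxed guidance logits $l_g$, scale $s$ and smooth factor $\alpha$, the function above computes

$$\tilde{l} = l_g + s\,(l_b - l_g), \qquad l_{\text{out}} = \alpha\,\operatorname{log\_softmax}(\tilde{l}) + (1 - \alpha)\,l_b,$$

so $s = 1$ leaves the base distribution unchanged and $\alpha = 1$ keeps only the guided logits.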

llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu) {
assert(ctx);
12 changes: 12 additions & 0 deletions llama.h
@@ -307,6 +307,18 @@ extern "C" {
/// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence);

/// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted.
/// @param guidance_ctx A separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
/// @param scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
/// @param smooth_factor Smooth factor between guidance logits and original logits. 1.0f means only use guidance logits. 0.0f means only original logits.
LLAMA_API void llama_sample_classifier_free_guidance(
struct llama_context * ctx,
llama_token_data_array * candidates,
struct llama_context * guidance_ctx,
float scale,
float smooth_factor);
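A minimal calling sketch (illustrative only, not part of the API; it assumes <vector> and "llama.h" are included, that a model and main context are already loaded, and that a second context named ctx_guidance was created from the same model and fed the negative prompt plus the same generated tokens):

// ctx          : main context, evaluated up to the current position
// ctx_guidance : second context holding the negative prompt plus the mirrored generated tokens
const int n_vocab = llama_n_vocab(ctx);
float * logits = llama_get_logits(ctx);

std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
    candidates.push_back(llama_token_data{ token_id, logits[token_id], 0.0f });
}
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

// apply guidance before any other sampler that sorts the candidates
llama_sample_classifier_free_guidance(ctx, &candidates_p, ctx_guidance, /*scale=*/1.5f, /*smooth_factor=*/1.0f);
llama_token id = llama_sample_token_greedy(ctx, &candidates_p);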

/// @details Sorts candidate tokens by their logits in descending order and calculates probabilities based on logits.
LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates);
