Implement classifier-free guidance #2135
Changes from 9 commits
d09d5ed
4786300
8ba5b13
8f91b52
114d4c5
422a7ff
66eb048
8e66e59
325fc88
abf164d
@@ -109,10 +109,16 @@ int main(int argc, char ** argv) {

    llama_model * model;
    llama_context * ctx;
    llama_context * guidance_ctx = NULL;
Review comment: Rename to …

    g_ctx = &ctx;

    // load the model and apply lora adapter, if any
    std::tie(model, ctx) = llama_init_from_gpt_params(params);
    if (params.cfg_scale > 1.f) {
        struct llama_context_params lparams = llama_get_context_params_from_gpt_params(params);
        guidance_ctx = llama_new_context_with_model(model, lparams);
    }

    if (model == NULL) {
        fprintf(stderr, "%s: error: unable to load model\n", __func__);
        return 1;
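For readability, the setup above can be condensed into a small helper that mirrors the diff's logic: the second context is created from the same model only when params.cfg_scale exceeds 1, and it must later be freed alongside the main context, as the teardown hunk at the end of main() does. This is a sketch, not code from the PR; the helper name make_guidance_ctx is invented here, and it assumes the llama.h / examples common.h declarations that the calls in this diff come from.

```cpp
#include "llama.h"
#include "common.h"  // gpt_params and llama_get_context_params_from_gpt_params(), as used in this PR

// Sketch only: returns a second context over the same weights when CFG is
// enabled, or NULL otherwise. The second context gets its own KV cache, so the
// negative prompt can be evaluated independently of the main context.
// The caller is responsible for llama_free()-ing the result.
static llama_context * make_guidance_ctx(gpt_params & params, llama_model * model) {
    if (params.cfg_scale <= 1.f) {
        return NULL;
    }
    struct llama_context_params lparams = llama_get_context_params_from_gpt_params(params);
    return llama_new_context_with_model(model, lparams);
}
```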
@@ -183,15 +189,28 @@ int main(int argc, char ** argv) {
    // tokenize the prompt
    std::vector<llama_token> embd_inp;

    if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
        // Add a space in front of the first character to match OG llama tokenizer behavior
        params.prompt.insert(0, 1, ' ');
    // Add a space in front of the first character to match OG llama tokenizer behavior
    params.prompt.insert(0, 1, ' ');

    if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
        embd_inp = ::llama_tokenize(ctx, params.prompt, true);
    } else {
        embd_inp = session_tokens;
    }

    // Tokenize negative prompt
    std::vector<llama_token> guidance_inp;
    int guidance_offset = 0;
    int original_prompt_len = 0;
    if (guidance_ctx) {
        params.cfg_negative_prompt.insert(0, 1, ' ');
        guidance_inp = ::llama_tokenize(guidance_ctx, params.cfg_negative_prompt, true);

        std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true);
        original_prompt_len = original_inp.size();
        guidance_offset = (int)guidance_inp.size() - original_prompt_len;
    }

    const int n_ctx = llama_n_ctx(ctx);

    if ((int) embd_inp.size() > n_ctx - 4) {
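To make the bookkeeping above concrete: original_prompt_len is the tokenized length of the positive prompt, and guidance_offset is how much longer the negative prompt is. A tiny self-contained illustration, with token counts invented for this example only:

```cpp
#include <cstdio>

int main() {
    // Hypothetical token counts, invented for illustration only.
    int original_prompt_len = 10;  // tokens in the (positive) prompt, incl. BOS
    int guidance_inp_size   = 14;  // tokens in the negative prompt, incl. BOS
    int guidance_offset     = guidance_inp_size - original_prompt_len;

    // Every position in the guidance context runs this many tokens ahead of
    // the corresponding position in the main context once generation starts.
    printf("guidance_offset = %d\n", guidance_offset);  // prints 4
    return 0;
}
```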
@@ -258,6 +277,16 @@ int main(int argc, char ** argv) {
        for (int i = 0; i < (int) embd_inp.size(); i++) {
            fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i]));
        }

        if (guidance_ctx) {
            fprintf(stderr, "\n");
            fprintf(stderr, "%s: negative prompt: '%s'\n", __func__, params.cfg_negative_prompt.c_str());
            fprintf(stderr, "%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size());
            for (int i = 0; i < (int) guidance_inp.size(); i++) {
                fprintf(stderr, "%6d -> '%s'\n", guidance_inp[i], llama_token_to_str(ctx, guidance_inp[i]));
            }
        }

        if (params.n_keep > 0) {
            fprintf(stderr, "%s: static prompt based on n_keep: '", __func__);
            for (int i = 0; i < params.n_keep; i++) {
@@ -334,11 +363,13 @@ int main(int argc, char ** argv) {
    int n_remain = params.n_predict;
    int n_consumed = 0;
    int n_session_consumed = 0;
    int guidance_n_past = 0;

    // the first thing we will do is to output the prompt, so set color accordingly
    console_set_color(con_st, CONSOLE_COLOR_PROMPT);

    std::vector<llama_token> embd;
    std::vector<llama_token> guidance_embd;

    // do one empty run to warm up the model
    {
@@ -367,11 +398,12 @@ int main(int argc, char ** argv) {
        // if we run out of context:
        // - take the n_keep first tokens from the original prompt (via n_past)
        // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
        if (n_past + (int) embd.size() > n_ctx) {
        if (n_past + (int) embd.size() + std::max<int>(0, guidance_offset) > n_ctx) {
            const int n_left = n_past - params.n_keep;

            // always keep the first token - BOS
            n_past = std::max(1, params.n_keep);
            guidance_n_past = std::max(1, params.n_keep + guidance_offset);

            // insert n_left/2 tokens at the start of embd from last_n_tokens
            embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_n_tokens.end() - embd.size());
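The widened overflow check above reserves extra slack because the guidance context runs guidance_offset tokens ahead of the main one. A small numeric sanity check of the same condition, with values invented for illustration:

```cpp
#include <algorithm>
#include <cstdio>

int main() {
    // Invented numbers, for illustration only.
    int n_ctx = 512, n_past = 500, embd_size = 8, guidance_offset = 6;

    // Mirrors the condition above: the guidance context is guidance_offset
    // tokens ahead, so that slack is added before comparing against n_ctx.
    bool overflow = n_past + embd_size + std::max(0, guidance_offset) > n_ctx;
    printf("overflow: %s\n", overflow ? "yes" : "no");  // 500 + 8 + 6 = 514 > 512 -> yes
    return 0;
}
```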
@@ -412,6 +444,48 @@ int main(int argc, char ** argv) {

        // evaluate tokens in batches
        // embd is typically prepared beforehand to fit within a batch, but not always

        if (guidance_ctx) {
            int input_size = 0;
            llama_token* input_buf = NULL;

            if (guidance_n_past < (int) guidance_inp.size()) {
                // Guidance context should have the same data with these modifications:
                //
                // * Replace the initial prompt
                // * Shift everything by guidance_offset
                guidance_embd = guidance_inp;
                if (embd.begin() + original_prompt_len < embd.end()) {
                    guidance_embd.insert(
                        guidance_embd.end(),
                        embd.begin() + original_prompt_len,
                        embd.end()
                    );
                }

                input_buf = guidance_embd.data();
                input_size = guidance_embd.size();
                //fprintf(stderr, "\n---------------------\n");
                //for (int i = 0; i < (int) guidance_embd.size(); i++) {
                //fprintf(stderr, "%s", llama_token_to_str(ctx, guidance_embd[i]));
                //}
                //fprintf(stderr, "\n---------------------\n");
            } else {
                input_buf = embd.data();
                input_size = embd.size();
            }

            for (int i = 0; i < input_size; i += params.n_batch) {
                int n_eval = std::min(input_size - i, params.n_batch);
                if (llama_eval(guidance_ctx, input_buf + i, n_eval, guidance_n_past, params.n_threads)) {
                    fprintf(stderr, "%s : failed to eval\n", __func__);
                    return 1;
                }

                guidance_n_past += n_eval;
            }
        }

        for (int i = 0; i < (int) embd.size(); i += params.n_batch) {
            int n_eval = (int) embd.size() - i;
            if (n_eval > params.n_batch) {
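The trickiest part of the block above is how the guidance batch is assembled: the negative-prompt tokens stand in for the original prompt, and anything generated after the original prompt is appended unchanged. A toy, self-contained version of that splice, with all token IDs invented for illustration:

```cpp
#include <cstdio>
#include <vector>

int main() {
    // Invented token IDs, for illustration only.
    std::vector<int> guidance_inp = {1, 50, 51, 52, 53}; // BOS + negative prompt (5 tokens)
    std::vector<int> embd         = {1, 10, 11, 90, 91}; // BOS + prompt (3 tokens) + 2 generated
    int original_prompt_len = 3;                         // length of the positive prompt, incl. BOS

    // Same splice as in the hunk above: start from the negative prompt, then
    // append whatever follows the original prompt in the main stream.
    std::vector<int> guidance_embd = guidance_inp;
    if (embd.begin() + original_prompt_len < embd.end()) {
        guidance_embd.insert(guidance_embd.end(),
                             embd.begin() + original_prompt_len, embd.end());
    }

    for (int t : guidance_embd) printf("%d ", t);  // 1 50 51 52 53 90 91
    printf("\n");
    return 0;
}
```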
@@ -431,6 +505,7 @@ int main(int argc, char ** argv) {
        }

        embd.clear();
        guidance_embd.clear();

        if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
            // out of user input, sample next token
@@ -473,6 +548,10 @@ int main(int argc, char ** argv) {

            llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

            if (guidance_ctx) {
                llama_sample_classifier_free_guidance(ctx, &candidates_p, guidance_ctx, params.cfg_scale, params.cfg_smooth_factor);
            }

            // Apply penalties
            float nl_logit = logits[llama_token_nl()];
            auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
@@ -668,6 +747,7 @@ int main(int argc, char ** argv) {
    }

    llama_print_timings(ctx);
    if (guidance_ctx) { llama_free(guidance_ctx); }
    llama_free(ctx);
    llama_free_model(model);
@@ -2141,6 +2141,75 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, l
    }
}

template<typename T, typename LogitAccessor>
void llama_log_softmax(T * array, int size, LogitAccessor logit_accessor) {
    T* element = std::max_element(
        array, array + size,
        [&logit_accessor](T& lhs, T& rhs) {
            return logit_accessor(lhs) < logit_accessor(rhs);
        }
    );

    float max_l = logit_accessor(*element);
    float sum = 0.f;
    for (int i = 0; i < size; ++i) {
        float& logit = logit_accessor(array[i]);
        float p = expf(logit - max_l);
        sum += p;
        logit = p;
    }

    for (int i = 0; i < size; ++i) {
        float& logit = logit_accessor(array[i]);
        logit = logf(logit / sum);
    }
}
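The accessor parameter exists because the same routine has to run both over raw float logits (from the guidance context) and over llama_token_data entries (the candidates array). Below is a condensed, self-contained variant of the same pattern with a toy stand-in for llama_token_data; it is not the PR's code, only an illustration of the idea:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>

// Toy stand-in for llama_token_data, just to show the accessor pattern.
struct token_data { int id; float logit; };

template<typename T, typename LogitAccessor>
void log_softmax(T * array, int size, LogitAccessor logit_accessor) {
    float max_l = logit_accessor(*std::max_element(array, array + size,
        [&](T & lhs, T & rhs) { return logit_accessor(lhs) < logit_accessor(rhs); }));
    float sum = 0.f;
    for (int i = 0; i < size; ++i) {
        sum += std::exp(logit_accessor(array[i]) - max_l);
    }
    for (int i = 0; i < size; ++i) {
        float & l = logit_accessor(array[i]);
        l = (l - max_l) - std::log(sum);   // equivalent to log(p / sum)
    }
}

int main() {
    float raw[3] = {1.f, 2.f, 3.f};
    token_data td[3] = {{0, 1.f}, {1, 2.f}, {2, 3.f}};

    // One routine, two accessors: identity for floats, member access for structs.
    log_softmax(raw, 3, [](float & x) -> float & { return x; });
    log_softmax(td,  3, [](token_data & d) -> float & { return d.logit; });

    for (int i = 0; i < 3; ++i) {
        printf("%.6f  %.6f\n", raw[i], td[i].logit);  // the two columns match
    }
    return 0;
}
```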
void llama_sample_classifier_free_guidance(
        struct llama_context * ctx,
        llama_token_data_array * candidates,
        struct llama_context * guidance_ctx,
        float scale,
        float smooth_factor) {
    int64_t t_start_sample_us = ggml_time_us();

    assert(ctx);
    auto n_vocab = llama_n_vocab(ctx);
    assert(n_vocab == (int)candidates->size);
    assert(!candidates->sorted);

    auto logit_from_token_data = [](llama_token_data& data) -> float& {
        return data.logit;
    };

    auto logit_from_float = [](float& item) -> float& {
        return item;
    };

    llama_log_softmax(candidates->data, candidates->size, logit_from_token_data);

    auto* guidance_logits = llama_get_logits(guidance_ctx);
    llama_log_softmax(guidance_logits, n_vocab, logit_from_float);

    for (int i = 0; i < n_vocab; ++i) {
        float guidance_logit = guidance_logits[i];
        float base_logit = candidates->data[i].logit;

        guidance_logits[i] = scale * (base_logit - guidance_logit) + guidance_logit;
    }

    llama_log_softmax(guidance_logits, n_vocab, logit_from_float);

    for (int i = 0; i < n_vocab; ++i) {
        float base_logit = candidates->data[i].logit;
        float guidance_logit = guidance_logits[i];

        candidates->data[i].logit = smooth_factor * guidance_logit + (1.f - smooth_factor) * base_logit;
    }

    if (ctx) {
        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
    }
}
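For intuition, here is a small, self-contained numerical walk-through of the arithmetic above: both distributions are moved into log-space, each guided logit becomes scale * (base - guidance) + guidance, the result is renormalized, and smooth_factor blends it back with the original distribution. All logits and parameter values are invented for illustration.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Log-softmax over a plain vector, used to mirror the arithmetic above.
static void log_softmax(std::vector<float> & v) {
    float max_l = *std::max_element(v.begin(), v.end());
    float sum = 0.f;
    for (float x : v) sum += std::exp(x - max_l);
    for (float & x : v) x = (x - max_l) - std::log(sum);
}

int main() {
    // Invented logits for a 3-token "vocabulary".
    std::vector<float> base     = {2.0f, 0.5f, -1.0f};  // from the positive context
    std::vector<float> guidance = {2.5f, 0.0f, -1.0f};  // from the negative context
    const float scale = 1.5f;
    const float smooth_factor = 0.75f;

    log_softmax(base);
    log_softmax(guidance);

    std::vector<float> guided(base.size());
    for (size_t i = 0; i < base.size(); ++i) {
        // Push away from what the negative context predicts.
        guided[i] = scale * (base[i] - guidance[i]) + guidance[i];
    }
    log_softmax(guided);  // renormalize the combined distribution

    for (size_t i = 0; i < base.size(); ++i) {
        // smooth_factor = 1 keeps the guided logits; smaller values mix the
        // original distribution back in.
        float out = smooth_factor * guided[i] + (1.f - smooth_factor) * base[i];
        printf("token %zu: %.4f\n", i, out);
    }
    return 0;
}
```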

llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu) {
    assert(ctx);
Review comment: I guess llama_context_params_from_gpt_params() should fit better. We tend to use get and set to access properties, while here we construct context_params.