speculative : add tree-based sampling example #3624

Merged: 18 commits merged on Oct 18, 2023

Changes from 1 commit
examples : fix build after sampling refactoring
ggml-ci
ggerganov committed Oct 15, 2023
commit 7e48e21b1f61fb23000d0220c03c1afe192005ac
2 changes: 1 addition & 1 deletion Makefile
@@ -545,7 +545,7 @@ llama.o: llama.cpp ggml.h ggml-alloc.h ggml-backend.h ggml-cuda.h ggml-metal.h l
$(CXX) $(CXXFLAGS) -c $< -o $@

COMMON_H_DEPS = common/common.h common/sampling.h build-info.h common/log.h
-COMMON_DEPS = $(COMMON_H_DEPS) common.o sampling.o
+COMMON_DEPS = $(COMMON_H_DEPS) common.o sampling.o grammar-parser.o

common.o: common/common.cpp $(COMMON_H_DEPS)
$(CXX) $(CXXFLAGS) -c $< -o $@
138 changes: 69 additions & 69 deletions common/log.h
@@ -579,75 +579,75 @@ inline std::string log_var_to_string_impl(const std::vector<int> & var)
return buf.str();
}

-#define LOG_TOKENS_TOSTR_PRETTY(ctx, tokens) \
-[&tokens, &ctx]() \
-{ \
-std::stringstream buf; \
-buf << "[ "; \
-\
-bool first = true; \
-for (const auto &token : tokens) \
-{ \
-if (!first) \
-buf << ", "; \
-else \
-first = false; \
-\
-auto detokenized = llama_token_to_piece(ctx, token); \
-\
-detokenized.erase( \
-std::remove_if( \
-detokenized.begin(), \
-detokenized.end(), \
-[](const unsigned char c) { return !std::isprint(c); }), \
-detokenized.end()); \
-\
-buf \
-<< "'" << detokenized << "'" \
-<< ":" << std::to_string(token); \
-} \
-buf << " ]"; \
-\
-return buf.str(); \
-}() \
-.c_str()
-
-#define LOG_BATCH_TOSTR_PRETTY(ctx, batch) \
-[&batch, &ctx]() \
-{ \
-std::stringstream buf; \
-buf << "[ "; \
-\
-bool first = true; \
-for (int i = 0; i < batch.n_tokens; ++i) \
-{ \
-if (!first) \
-buf << ", "; \
-else \
-first = false; \
-\
-auto detokenized = llama_token_to_piece(ctx, batch.token[i]); \
-\
-detokenized.erase( \
-std::remove_if( \
-detokenized.begin(), \
-detokenized.end(), \
-[](const unsigned char c) { return !std::isprint(c); }), \
-detokenized.end()); \
-\
-buf \
-<< "\n" << std::to_string(i) \
-<< ":token '" << detokenized << "'" \
-<< ":pos " << std::to_string(batch.pos[i]) \
-<< ":n_seq_id " << std::to_string(batch.n_seq_id[i]) \
-<< ":seq_id " << std::to_string(batch.seq_id[i][0]) \
-<< ":logits " << std::to_string(batch.logits[i]); \
-} \
-buf << " ]"; \
-\
-return buf.str(); \
-}() \
-.c_str()
+template <typename C, typename T>
+inline std::string LOG_TOKENS_TOSTR_PRETTY(const C & ctx, const T & tokens)
+{
+std::stringstream buf;
+buf << "[ ";
+
+bool first = true;
+for (const auto &token : tokens)
+{
+if (!first) {
+buf << ", ";
+} else {
+first = false;
+}
+
+auto detokenized = llama_token_to_piece(ctx, token);
+
+detokenized.erase(
+std::remove_if(
+detokenized.begin(),
+detokenized.end(),
+[](const unsigned char c) { return !std::isprint(c); }),
+detokenized.end());
+
+buf
+<< "'" << detokenized << "'"
+<< ":" << std::to_string(token);
+}
+buf << " ]";
+
+return buf.str();
+}
+
+template <typename C, typename B>
+inline std::string LOG_BATCH_TOSTR_PRETTY(const C & ctx, const B & batch)
+{
+std::stringstream buf;
+buf << "[ ";
+
+bool first = true;
+for (int i = 0; i < batch.n_tokens; ++i)
+{
+if (!first) {
+buf << ", ";
+} else {
+first = false;
+}
+
+auto detokenized = llama_token_to_piece(ctx, batch.token[i]);
+
+detokenized.erase(
+std::remove_if(
+detokenized.begin(),
+detokenized.end(),
+[](const unsigned char c) { return !std::isprint(c); }),
+detokenized.end());
+
+buf
+<< "\n" << std::to_string(i)
+<< ":token '" << detokenized << "'"
+<< ":pos " << std::to_string(batch.pos[i])
+<< ":n_seq_id " << std::to_string(batch.n_seq_id[i])
+<< ":seq_id " << std::to_string(batch.seq_id[i][0])
+<< ":logits " << std::to_string(batch.logits[i]);
+}
+buf << " ]";
+
+return buf.str();
+}

#ifdef LOG_DISABLE_LOGS

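Note on this hunk: the old macros ended in .c_str() and therefore yielded a const char *, while the new template functions return a std::string. Printf-style call sites now have to do that conversion themselves, which is exactly what the example changes below apply. A minimal before/after at one call site from this commit:

// before: the macro's trailing .c_str() already produced a const char *
LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp));
// after: the template returns std::string, so the caller converts explicitly
LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());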
1 change: 1 addition & 0 deletions common/sampling.h
@@ -50,6 +50,7 @@ struct llama_sampling_context {
// internal
grammar_parser::parse_state parsed_grammar;

+// TODO: replace with ring-buffer
std::vector<llama_token> prev;
std::vector<llama_token_data> cur;
};
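The new TODO refers to prev, the sampling history that callers currently maintain by erasing from the front of a std::vector on every token, an O(n) shift per step (visible in the infill.cpp hunks below). A minimal sketch of the kind of fixed-capacity ring buffer the TODO suggests; this token_ring helper is hypothetical and not part of this PR:

#include <cstddef>
#include <vector>

// Hypothetical ring buffer for llama_token history: push() is O(1)
// instead of the O(n) erase-from-front on a std::vector.
struct token_ring {
    std::vector<llama_token> buf;
    size_t head = 0; // index of the next slot to overwrite

    explicit token_ring(size_t n) : buf(n, 0) {}

    void push(llama_token t) {
        buf[head] = t;                  // overwrite the oldest entry
        head = (head + 1) % buf.size();
    }

    llama_token back() const {          // most recently pushed token
        return buf[(head + buf.size() - 1) % buf.size()];
    }
};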
44 changes: 18 additions & 26 deletions examples/infill/infill.cpp
@@ -257,12 +257,12 @@ int main(int argc, char ** argv) {

LOG("prefix: \"%s\"\n", log_tostr(params.input_prefix));
LOG("suffix: \"%s\"\n", log_tostr(params.input_suffix));
LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp));
LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());

// Should not run without any tokens
if (embd_inp.empty()) {
embd_inp.push_back(llama_token_bos(ctx));
LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp));
LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
}

// Tokenize negative prompt
@@ -273,10 +273,10 @@ int main(int argc, char ** argv) {
LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt));

guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, add_bos);
LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp));
LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp).c_str());

std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos);
LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp));
LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());

original_prompt_len = original_inp.size();
guidance_offset = (int)guidance_inp.size() - original_prompt_len;
@@ -294,8 +294,8 @@ int main(int argc, char ** argv) {
params.n_keep = (int)embd_inp.size();
}

LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx));
LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx));
LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx).c_str());
LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx).c_str());


// enable interactive mode if interactive start is specified
@@ -388,9 +388,6 @@ int main(int argc, char ** argv) {
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
}

-// TODO: replace with ring-buffer
-std::vector<llama_token> last_tokens(n_ctx);
-std::fill(last_tokens.begin(), last_tokens.end(), 0);
LOG_TEE("\n##### Infill mode #####\n\n");
if (params.infill) {
printf("\n************\n");
@@ -433,11 +430,7 @@ int main(int argc, char ** argv) {
std::vector<llama_token> embd;
std::vector<llama_token> embd_guidance;

-const int n_vocab = llama_n_vocab(model);
-
-llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar);
-std::vector<llama_token_data> candidates;
-candidates.reserve(n_vocab);
+struct llama_sampling_context * ctx_sampling = llama_sampling_init(params);

while (n_remain != 0 || params.interactive) {
// predict
@@ -484,7 +477,7 @@ int main(int argc, char ** argv) {

LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);

LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());

}

@@ -512,7 +505,7 @@ int main(int argc, char ** argv) {
input_buf = embd_guidance.data();
input_size = embd_guidance.size();

LOG("guidance context: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_guidance));
LOG("guidance context: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_guidance).c_str());
} else {
input_buf = embd.data();
input_size = embd.size();
@@ -535,7 +528,7 @@ int main(int argc, char ** argv) {
n_eval = params.n_batch;
}

LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());

if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) {
LOG_TEE("%s : failed to eval\n", __func__);
@@ -554,12 +547,11 @@ int main(int argc, char ** argv) {

if ((int) embd_inp.size() <= n_consumed && !is_interacting) {

-const llama_token id = llama_sampling_sample(ctx, ctx_guidance, ctx_sampling, last_tokens, candidates);
+const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);

-last_tokens.erase(last_tokens.begin());
-last_tokens.push_back(id);
+llama_sampling_accept(ctx_sampling, ctx, id);

LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, last_tokens));
LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());

embd.push_back(id);

@@ -575,8 +567,8 @@ int main(int argc, char ** argv) {
LOG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed);
while ((int) embd_inp.size() > n_consumed) {
embd.push_back(embd_inp[n_consumed]);
-last_tokens.erase(last_tokens.begin());
-last_tokens.push_back(embd_inp[n_consumed]);
+ctx_sampling->prev.erase(ctx_sampling->prev.begin());
+ctx_sampling->prev.push_back(embd_inp[n_consumed]);
++n_consumed;
if ((int) embd.size() >= params.n_batch) {
break;
@@ -608,7 +600,7 @@ int main(int argc, char ** argv) {
if ((int) embd_inp.size() <= n_consumed) {

// deal with eot token in infill mode
-if ((last_tokens.back() == llama_token_eot(ctx) || is_interacting) && params.interactive){
+if ((ctx_sampling->prev.back() == llama_token_eot(ctx) || is_interacting) && params.interactive){
if(is_interacting && !params.interactive_first) {
// print an eot token
printf("%s", llama_token_to_piece(ctx, llama_token_eot(ctx)).c_str());
@@ -675,7 +667,7 @@ int main(int argc, char ** argv) {
is_interacting = false;
}
// deal with end of text token in interactive mode
-else if (last_tokens.back() == llama_token_eos(ctx)) {
+else if (ctx_sampling->prev.back() == llama_token_eos(ctx)) {
LOG("found EOS token\n");

if (params.interactive) {
@@ -727,7 +719,7 @@ int main(int argc, char ** argv) {
const size_t original_size = embd_inp.size();

const auto line_inp = ::llama_tokenize(ctx, buffer, false);
LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp));
LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str());

embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());

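Taken together, the infill.cpp changes replace the hand-rolled last_tokens/candidates state with the shared sampling context. Assembled only from calls that appear in this commit, the per-token pattern now looks like the following sketch (loop conditions and error handling elided):

struct llama_sampling_context * ctx_sampling = llama_sampling_init(params);

// inside the generation loop:
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);

// record the choice so repetition penalties and grammar state stay in sync
llama_sampling_accept(ctx_sampling, ctx, id);

// ctx_sampling->prev now holds the recent token history that the
// local last_tokens vector used to track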