llama : minor sampling refactor (2) (#9386)
slaren authored Sep 9, 2024
1 parent 38ca6f6 commit 5fb5e24
Showing 12 changed files with 104 additions and 102 deletions.
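
The refactor moves the explicit llama_sampler_accept call into llama_sampler_sample itself, so the examples below simply drop the accept line after sampling. A minimal before/after sketch of the affected loop body (variable names taken from the batched example; this snippet is an illustration, not part of the diff):

    // before: the caller had to accept the sampled token explicitly
    //     const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_batch[i]);
    //     llama_sampler_accept(smpl, new_token_id);

    // after: llama_sampler_sample applies the sampler chain and accepts the token internally
    const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_batch[i]);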
2 changes: 0 additions & 2 deletions examples/batched.swift/Sources/main.swift
@@ -140,8 +140,6 @@ while n_cur <= n_len {
 
 let new_token_id = llama_sampler_sample(smpl, context, i_batch[i])
 
-llama_sampler_accept(smpl, new_token_id)
-
 // is it an end of stream? -> mark the stream as finished
 if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
 i_batch[i] = -1

2 changes: 0 additions & 2 deletions examples/batched/batched.cpp
@@ -172,8 +172,6 @@ int main(int argc, char ** argv) {
 
 const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_batch[i]);
 
-llama_sampler_accept(smpl, new_token_id);
-
 // is it an end of generation? -> mark the stream as finished
 if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
 i_batch[i] = -1;

1 change: 0 additions & 1 deletion examples/gritlm/gritlm.cpp
@@ -121,7 +121,6 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
 llama_decode(ctx, bat);
 
 llama_token token = llama_sampler_sample(smpl, ctx, bat.n_tokens - 1);
-llama_sampler_accept(smpl, token);
 
 if (token == eos_token) {
 break;

2 changes: 0 additions & 2 deletions examples/llama.android/llama/src/main/cpp/llama-android.cpp
@@ -414,8 +414,6 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
 // sample the most likely token
 const auto new_token_id = llama_sampler_sample(sampler, context, -1);
 
-llama_sampler_accept(sampler, new_token_id);
-
 const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
 if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
 return nullptr;

2 changes: 0 additions & 2 deletions examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
@@ -152,8 +152,6 @@ actor LlamaContext {
 
 new_token_id = llama_sampler_sample(sampling, context, batch.n_tokens - 1)
 
-llama_sampler_accept(sampling, new_token_id)
-
 if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
 print("\n")
 is_done = true

2 changes: 0 additions & 2 deletions examples/passkey/passkey.cpp
@@ -220,8 +220,6 @@ int main(int argc, char ** argv) {
 {
 const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1);
 
-llama_sampler_accept(smpl, new_token_id);
-
 // is it an end of generation?
 if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
 LOG_TEE("\n");

6 changes: 0 additions & 6 deletions examples/save-load-state/save-load-state.cpp
@@ -74,8 +74,6 @@ int main(int argc, char ** argv) {
 auto next_token     = llama_sampler_sample(smpl, ctx, -1);
 auto next_token_str = llama_token_to_piece(ctx, next_token);
 
-llama_sampler_accept(smpl, next_token);
-
 printf("%s", next_token_str.c_str());
 result0 += next_token_str;
 
@@ -132,8 +130,6 @@ int main(int argc, char ** argv) {
 auto next_token     = llama_sampler_sample(smpl2, ctx2, -1);
 auto next_token_str = llama_token_to_piece(ctx2, next_token);
 
-llama_sampler_accept(smpl2, next_token);
-
 printf("%s", next_token_str.c_str());
 result1 += next_token_str;
 
@@ -222,8 +218,6 @@ int main(int argc, char ** argv) {
 auto next_token     = llama_sampler_sample(smpl3, ctx3, -1);
 auto next_token_str = llama_token_to_piece(ctx3, next_token);
 
-llama_sampler_accept(smpl3, next_token);
-
 printf("%s", next_token_str.c_str());
 result2 += next_token_str;
 

2 changes: 1 addition & 1 deletion examples/server/server.cpp
@@ -613,7 +613,7 @@ struct server_context {
 
 gpt_params params;
 
-llama_batch batch;
+llama_batch batch = {};
 
 bool clean_kv_cache = true;
 bool add_bos_token  = true;

2 changes: 0 additions & 2 deletions examples/simple/simple.cpp
@@ -118,8 +118,6 @@ int main(int argc, char ** argv) {
 {
 const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1);
 
-llama_sampler_accept(smpl, new_token_id);
-
 // is it an end of generation?
 if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
 LOG_TEE("\n");

11 changes: 6 additions & 5 deletions include/llama.h
@@ -1127,15 +1127,16 @@ extern "C" {
 int32_t n_logit_bias,
 const llama_logit_bias * logit_bias);
 
-// Shorthand for:
+/// @details Sample and accept a token from the idx-th output of the last evaluation
 //
+// Shorthand for:
 //    const auto * logits = llama_get_logits_ith(ctx, idx);
 //    llama_token_data_array cur_p = { ... init from logits ... };
 //    llama_sampler_apply(smpl, &cur_p);
-//    return cur_p.data[cur_p.selected].id;
-//
-// At this point, this is mostly a convenience function.
-//
+//    auto token = cur_p.data[cur_p.selected].id;
+//    llama_sampler_accept(smpl, token);
+//    return token;
+// Returns the sampled token
 LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx);
 
 // TODO: extend in the future

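The updated comment describes what llama_sampler_sample now does end to end. As a minimal sketch, assuming the llama_token_data_array layout { data, size, selected, sorted } from this version of the header, the shorthand expands to roughly the following (the helper name and candidate buffer are illustrative, not library API):

    #include <vector>
    #include "llama.h"

    // Illustrative expansion of llama_sampler_sample(smpl, ctx, idx)
    static llama_token sample_and_accept(llama_sampler * smpl, llama_context * ctx, int32_t idx) {
        const float * logits  = llama_get_logits_ith(ctx, idx);
        const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx));

        // build the candidate array from the raw logits
        std::vector<llama_token_data> candidates;
        candidates.reserve(n_vocab);
        for (llama_token id = 0; id < n_vocab; ++id) {
            candidates.push_back({ id, logits[id], 0.0f });
        }

        llama_token_data_array cur_p = { candidates.data(), candidates.size(), -1, false };

        // run the sampler chain and pick the selected candidate
        llama_sampler_apply(smpl, &cur_p);
        const llama_token token = cur_p.data[cur_p.selected].id;

        // accept step: after this commit, callers no longer do this themselves
        llama_sampler_accept(smpl, token);

        return token;
    }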