
Commit f866cb9

llama : move sampling rngs from common to llama

ggml-ci

1 parent 938943c

22 files changed: +342, -344 lines

common/sampling.cpp (10 additions, 8 deletions)

@@ -1,11 +1,13 @@
-#define LLAMA_API_INTERNAL
 #include "sampling.h"
+
 #include <random>
 
-struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
+struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, struct llama_context * ctx, llama_seq_id seq_id) {
     struct llama_sampling_context * result = new llama_sampling_context();
 
     result->params = params;
+    result->seq_id = seq_id;
+    result->ctx = ctx;
     result->grammar = nullptr;
 
     // if there is a grammar, parse it
@@ -81,7 +83,7 @@ void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) {
     if (seed == LLAMA_DEFAULT_SEED) {
         seed = std::random_device{}();
     }
-    ctx->rng.seed(seed);
+    llama_set_rng_seed_seq(ctx->ctx, seed, ctx->seq_id);
 }
 
 void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) {
@@ -271,10 +273,10 @@ static llama_token llama_sampling_sample_impl(
         bool is_resampling) {
     const llama_sampling_params & params = ctx_sampling->params;
 
-    const float temp = params.temp;
-    const int mirostat = params.mirostat;
-    const float mirostat_tau = params.mirostat_tau;
-    const float mirostat_eta = params.mirostat_eta;
+    const float temp         = params.temp;
+    const int   mirostat     = params.mirostat;
+    const float mirostat_tau = params.mirostat_tau;
+    const float mirostat_eta = params.mirostat_eta;
 
     std::vector<float> original_logits;
     auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits);
@@ -304,7 +306,7 @@ static llama_token llama_sampling_sample_impl(
 
     sampler_queue(ctx_main, params, cur_p, min_keep);
 
-    id = llama_sample_token_with_rng(ctx_main, &cur_p, ctx_sampling->rng);
+    id = llama_sample_token_seq(ctx_main, &cur_p, ctx_sampling->seq_id);
 
     //{
     //    const int n_top = 10;
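
Taken together, the common/sampling.cpp changes bind a sampling context to a llama_context and a sequence id at creation time, and route both seeding and token drawing through llama's per-sequence RNG (llama_set_rng_seed_seq and llama_sample_token_seq). A minimal sketch of the resulting call pattern, assuming llama_sampling_sample keeps its usual (ctx_sampling, ctx_main, ctx_cfg) signature; the seed value and surrounding setup are illustrative, not part of this commit:

    // Sketch only: model/context setup and the decode loop are elided.
    llama_sampling_params sparams; // defaults from common/sampling.h

    // bind the sampler to a context and a sequence at creation time
    struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams, ctx, /* seq_id */ 0);

    // seeding no longer touches a common-side std::mt19937; it forwards to
    // llama_set_rng_seed_seq(ctx, seed, seq_id) inside llama
    llama_sampling_set_rng_seed(ctx_sampling, 42);

    // sampling draws from the per-sequence RNG via llama_sample_token_seq
    const llama_token id = llama_sampling_sample(ctx_sampling, ctx, /* ctx_cfg */ nullptr);

    llama_sampling_free(ctx_sampling);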

common/sampling.h (5 additions, 3 deletions)

@@ -70,9 +70,12 @@ struct llama_sampling_context {
     // parameters that will be used for sampling
     llama_sampling_params params;
 
+    llama_seq_id seq_id;
+
     // mirostat sampler state
     float mirostat_mu;
 
+    llama_context * ctx; // TMP
     llama_grammar * grammar;
 
     // internal
@@ -81,15 +84,14 @@ struct llama_sampling_context {
     // TODO: replace with ring-buffer
     std::vector<llama_token> prev;
     std::vector<llama_token_data> cur;
-    size_t n_valid; // Number of correct top tokens with correct probabilities.
 
-    std::mt19937 rng;
+    size_t n_valid; // Number of correct top tokens with correct probabilities.
 };
 
 #include "common.h"
 
 // Create a new sampling context instance.
-struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params);
+struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, struct llama_context * ctx, llama_seq_id seq_id);
 
 void llama_sampling_free(struct llama_sampling_context * ctx);
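
With std::mt19937 rng gone from the struct, the RNG state now lives inside llama, keyed by sequence id, behind llama_set_rng_seed_seq and llama_sample_token_seq. This diff does not show the llama-side storage; a hypothetical sketch of the idea, with names and layout assumed rather than taken from the commit:

    // Hypothetical per-sequence RNG table; the commit's actual internal
    // layout inside llama.cpp may differ.
    #include <cstdint>
    #include <random>
    #include <unordered_map>

    typedef int32_t llama_seq_id; // as in llama.h

    struct rng_table {
        std::unordered_map<llama_seq_id, std::mt19937> rngs; // one stream per sequence

        // llama_set_rng_seed_seq would reduce to something like this:
        void seed(llama_seq_id seq_id, uint32_t s) { rngs[seq_id].seed(s); }

        // llama_sample_token_seq would draw from the matching stream:
        std::mt19937 & get(llama_seq_id seq_id) { return rngs[seq_id]; } // default-constructed on first use
    };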

examples/gbnf-validator/gbnf-validator.cpp (1 addition, 2 deletions)

@@ -1,8 +1,7 @@
-#define LLAMA_API_INTERNAL
-
 #include "grammar-parser.h"
 #include "ggml.h"
 #include "llama.h"
+#include "llama-impl.h"
 #include "unicode.h"
 
 #include <cstdio>
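
The gbnf-validator and quantize-stats changes follow the same pattern: tools that previously opted into internal symbols by defining LLAMA_API_INTERNAL before including llama.h now include the dedicated llama-impl.h header instead. Side by side (illustrative, condensed from the two diffs):

    // before: internal declarations gated behind a macro
    // #define LLAMA_API_INTERNAL
    // #include "llama.h"

    // after: internal declarations live in their own header
    #include "llama.h"
    #include "llama-impl.h"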

examples/infill/infill.cpp (1 addition, 1 deletion)

@@ -346,7 +346,7 @@ int main(int argc, char ** argv) {
 
     std::vector<llama_token> embd;
 
-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams, ctx, 0);
 
     while (n_remain != 0 || params.interactive) {
         // predict

examples/llava/llava-cli.cpp (1 addition, 1 deletion)

@@ -191,7 +191,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_embed * image_embed, gpt_params * params, const std::string & prompt) {
 
     LOG_TEE("\n");
 
-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams);
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams, ctx_llava->ctx_llama, 0);
     if (!ctx_sampling) {
         fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
         exit(1);

examples/lookahead/lookahead.cpp (1 addition, 1 deletion)

@@ -118,7 +118,7 @@ int main(int argc, char ** argv) {
     llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1);
 
     // target model sampling context
-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams, ctx, 0);
 
     // verification n-grams
     std::vector<ngram_data> ngrams_cur(G);

examples/lookup/lookup.cpp (1 addition, 1 deletion)

@@ -106,7 +106,7 @@ int main(int argc, char ** argv){
 
     bool has_eos = false;
 
-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams, ctx, 0);
 
     std::vector<llama_token> draft;
 
examples/main/main.cpp (1 addition, 1 deletion)

@@ -527,7 +527,7 @@ int main(int argc, char ** argv) {
         antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true));
     }
 
-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams, ctx, 0);
     if (!ctx_sampling) {
         fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
         exit(1);

examples/parallel/parallel.cpp (1 addition, 1 deletion)

@@ -161,7 +161,7 @@ int main(int argc, char ** argv) {
     for (size_t i = 0; i < clients.size(); ++i) {
         auto & client = clients[i];
         client.id = i;
-        client.ctx_sampling = llama_sampling_init(params.sparams);
+        client.ctx_sampling = llama_sampling_init(params.sparams, ctx, i);
     }
 
     std::vector<llama_token> tokens_system;
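
parallel.cpp is the one call site that passes a nonzero sequence id: each client binds its sampler to its own sequence, so concurrent streams draw from independent RNG states inside llama. A hedged sketch of how a caller could use that; the per-client reseeding below is illustrative and not part of this commit:

    // illustrative: give each client a distinct, reproducible RNG stream;
    // reseeding client i does not perturb any other client's stream
    for (size_t i = 0; i < clients.size(); ++i) {
        auto & client = clients[i];
        client.id = i;
        client.ctx_sampling = llama_sampling_init(params.sparams, ctx, i);
        llama_sampling_set_rng_seed(client.ctx_sampling, 1000 + (uint32_t) i);
    }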

examples/quantize-stats/quantize-stats.cpp (1 addition, 1 deletion)

@@ -1,7 +1,7 @@
-#define LLAMA_API_INTERNAL
 #include "common.h"
 #include "ggml.h"
 #include "llama.h"
+#include "llama-impl.h"
 
 #include <algorithm>
 #include <cassert>
