@@ -37,8 +37,8 @@ int main(int argc, char ** argv) {
     const int n_seq_dft = params.n_parallel;
 
     // TODO: make this configurable
-    const float p_accept = 0.4f;
-    const float p_split  = 0.3f;
+    const float p_accept = 0.80f;
+    const float p_split  = 0.10f;
 
 #ifndef LOG_DISABLE_LOGS
     log_set_target(log_filename_generator("speculative", "log"));
@@ -118,7 +118,7 @@ int main(int argc, char ** argv) {
     std::vector<seq_draft> drafts(n_seq_dft);
 
     params.grammar.clear(); // the draft samplers will copy the target sampler's grammar
-    params.sampling_params.temp = 1.0f; // the draft samplers use default temperature
+    params.sampling_params.temp = std::max(0.01f, params.sampling_params.temp);
 
     for (int s = 0; s < n_seq_dft; ++s) {
         drafts[s].ctx_sampling = llama_sampling_init(params);
@@ -156,7 +156,7 @@ int main(int argc, char ** argv) {
 
             llama_sampling_accept(ctx_sampling, ctx_tgt, id);
 
-            //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, last_tokens));
+            //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str());
 
             const std::string token_str = llama_token_to_piece(ctx_tgt, id);
 
@@ -202,7 +202,7 @@ int main(int argc, char ** argv) {
 
         // TODO: simplify
         {
-            LOG("keeping sequence %d\n", s_keep);
+            LOG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft);
 
             llama_kv_cache_seq_keep(ctx_dft, s_keep);
             llama_kv_cache_seq_cp  (ctx_dft, s_keep, 0, -1, -1);
@@ -277,7 +277,7 @@ int main(int argc, char ** argv) {
                 }
 
                 if (cur_p[0].p < p_accept) {
-                    LOG("stopping drafting for seq %3d, probability too low: %.3f < 2*%.3f\n", s, cur_p[0].p, cur_p[1].p);
+                    LOG("stopping drafting for seq %3d, probability too low: %.3f < %.3f\n", s, cur_p[0].p, p_accept);
                     drafts[s].drafting = false;
                     continue;
                 }
@@ -337,16 +337,14 @@ int main(int argc, char ** argv) {
 
                     llama_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true);
 
-                    // no need to evaluate the last drafted token, since we won't use the result
-                    if (batch_tgt.n_tokens > n_draft) {
-                        drafts[s].drafting = false;
-                        continue;
-                    }
-
                     // add the token to the batch for batched decoding with the draft model
                     drafts[s].i_batch_dft = batch_dft.n_tokens;
 
                     llama_batch_add(batch_dft, id, n_past_cur, { s }, true);
+
+                    if (batch_tgt.n_tokens > n_draft) {
+                        drafts[s].drafting = false;
+                    }
                 }
             }
 
@@ -365,11 +363,6 @@ int main(int argc, char ** argv) {
             }
         }
 
-        // account for the last drafted token that we didn't evaluate
-        if (batch_tgt.n_tokens > n_draft) {
-            ++n_drafted;
-        }
-
         // evaluate the target model on the drafted tokens
         {
             llama_kv_cache_seq_keep(ctx_tgt, 0);
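
Note: the two defaults changed in the first hunk gate the tree-based drafting loop. p_accept stops a draft sequence once the draft model's top candidate becomes too uncertain (the cur_p[0].p < p_accept check shown above), while p_split (not shown in this diff) controls when a runner-up candidate looks promising enough to branch into an extra draft sequence. The standalone C++ sketch below only illustrates that gating on a toy candidate list; the struct and helper names are made up for illustration and are not part of this commit.

// Illustrative sketch only -- assumes candidates are sorted by descending probability.
#include <cstdio>
#include <vector>

struct candidate {
    int   id; // token id
    float p;  // probability assigned by the draft model
};

// stop drafting when the most likely draft token is too uncertain
static bool keep_drafting(const std::vector<candidate> & cur_p, float p_accept) {
    return !cur_p.empty() && cur_p[0].p >= p_accept;
}

// branch a new draft sequence when a runner-up candidate is strong enough
static bool split_draft(const std::vector<candidate> & cur_p, size_t f, float p_split) {
    return f < cur_p.size() && cur_p[f].p > p_split;
}

int main() {
    const float p_accept = 0.80f; // new default from this patch
    const float p_split  = 0.10f; // new default from this patch

    const std::vector<candidate> cur_p = { {42, 0.85f}, {7, 0.12f}, {13, 0.02f} };

    printf("keep drafting  : %d\n", keep_drafting(cur_p, p_accept)); // 1
    printf("split on cand 1: %d\n", split_draft(cur_p, 1, p_split)); // 1
    printf("split on cand 2: %d\n", split_draft(cur_p, 2, p_split)); // 0
    return 0;
}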