@@ -386,7 +386,7 @@ void sample_top_a(llama_token_data_array * candidates, float a, size_t min_keep)
386386    candidates->size  = last_idx;
387387}
388388
389- void  sample_rep_pen (int  n_ctx, int  rep_pen_range, float  rep_pen, llama_token_data_array * candidates_p)
389+ void  sample_rep_pen (int  n_ctx, int  rep_pen_range, float  rep_pen, float  presence_penalty,  llama_token_data_array * candidates_p)
390390{
391391    auto  last_n_repeat = std::min (std::min ((int )last_n_tokens.size (), rep_pen_range), n_ctx);
392392
@@ -414,6 +414,8 @@ void sample_rep_pen(int n_ctx, int rep_pen_range, float rep_pen, llama_token_dat
414414        } else  {
415415            candidates->data [i].logit  /= penalty;
416416        }
417+ 
418+         candidates->data [i].logit  -= presence_penalty;
417419    }
418420
419421    candidates->sorted  = false ;
@@ -474,7 +476,7 @@ void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_ar
474476
475477}
476478
477- int  SampleLogits (const  float  * logits, int  n_ctx, int  n_vocab, int  rep_pen_range, float  rep_pen, float  top_k, float  top_a, float  top_p, float  min_p, float  typical_p, float  tfs, float  temp, std::mt19937 & rng,
479+ int  SampleLogits (const  float  * logits, int  n_ctx, int  n_vocab, int  rep_pen_range, float  rep_pen, float  presence_penalty,  float   top_k, float  top_a, float  top_p, float  min_p, float  typical_p, float  tfs, float  temp, std::mt19937 & rng,
478480int  mirostat, float  mirostat_tau, float  mirostat_eta, const  std::vector<samplers> & sampler_order, llama_grammar * grammar)
479481{
480482    int  id = 0 ;
@@ -494,7 +496,7 @@ int mirostat, float mirostat_tau, float mirostat_eta, const std::vector<samplers
494496    {
495497        static  float  mirostat_mu = 2 .0f  * mirostat_tau;
496498        const  int  mirostat_m = 100 ;
497-         sample_rep_pen (n_ctx, rep_pen_range, rep_pen, &candidates_p);
499+         sample_rep_pen (n_ctx, rep_pen_range, rep_pen, presence_penalty,  &candidates_p);
498500        sample_temperature (&candidates_p, temp);
499501        if  (mirostat == 1 )
500502        {
@@ -531,7 +533,7 @@ int mirostat, float mirostat_tau, float mirostat_eta, const std::vector<samplers
531533                    sample_temperature (&candidates_p, temp);
532534                    break ;
533535                case  KCPP_SAMPLER_REP_PEN:
534-                     sample_rep_pen (n_ctx, rep_pen_range, rep_pen, &candidates_p);
536+                     sample_rep_pen (n_ctx, rep_pen_range, rep_pen, presence_penalty,  &candidates_p);
535537                    break ;
536538                default :
537539                    printf (" \n SampleLogits: Unknown Sampler : %d" 
@@ -1442,6 +1444,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
14421444    params.temp  = inputs.temperature ;
14431445    params.repeat_last_n  = inputs.rep_pen_range ;
14441446    params.repeat_penalty  = inputs.rep_pen ;
1447+     params.presence_penalty  = inputs.presence_penalty ;
14451448    params.mirostat  = inputs.mirostat ;
14461449    params.mirostat_eta  = inputs.mirostat_eta ;
14471450    params.mirostat_tau  = inputs.mirostat_tau ;
@@ -1836,6 +1839,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
18361839            const  float  temp = params.temp ;
18371840            const  float  top_a = inputs.top_a ;
18381841            const  float  repeat_penalty = params.repeat_penalty ;
1842+             const  float  presence_penalty = params.presence_penalty ;
18391843            const  float  typical_p = params.typical_p ;
18401844            const  float  tfs_z = params.tfs_z ;
18411845
@@ -1891,7 +1895,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
18911895                }
18921896            }
18931897
1894-             id = SampleLogits (logitsPtr, nctx, n_vocab, last_n_size, repeat_penalty,
1898+             id = SampleLogits (logitsPtr, nctx, n_vocab, last_n_size, repeat_penalty, presence_penalty, 
18951899            top_k, top_a, top_p, min_p, typical_p, tfs_z, temp, rng,
18961900            params.mirostat , params.mirostat_tau , params.mirostat_eta , sampler_order, grammar);
18971901
0 commit comments