@@ -317,6 +317,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.n_chunks = std::stoi(argv[i]);
+        } else if (arg == "-np" || arg == "--parallel") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.n_parallel = std::stoi(argv[i]);
+        } else if (arg == "-ns" || arg == "--sequences") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.n_sequences = std::stoi(argv[i]);
         } else if (arg == "-m" || arg == "--model") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -360,6 +372,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.multiline_input = true;
         } else if (arg == "--simple-io") {
             params.simple_io = true;
+        } else if (arg == "-cb" || arg == "--cont-batching") {
+            params.cont_batching = true;
         } else if (arg == "--color") {
             params.use_color = true;
         } else if (arg == "--mlock") {
@@ -436,8 +450,6 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.use_mmap = false;
         } else if (arg == "--numa") {
             params.numa = true;
-        } else if (arg == "--export") {
-            params.export_cgraph = true;
         } else if (arg == "--verbose-prompt") {
             params.verbose_prompt = true;
         } else if (arg == "-r" || arg == "--reverse-prompt") {
@@ -456,8 +468,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             if (params.logdir.back() != DIRECTORY_SEPARATOR) {
                 params.logdir += DIRECTORY_SEPARATOR;
             }
-        } else if (arg == "--perplexity") {
-            params.perplexity = true;
+        } else if (arg == "--perplexity" || arg == "--all-logits") {
+            params.logits_all = true;
         } else if (arg == "--ppl-stride") {
            if (++i >= argc) {
                invalid_param = true;
@@ -655,12 +667,15 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  --memory-f32          use f32 instead of f16 for memory key+value (default: disabled)\n");
     printf("                        not recommended: doubles context memory required and no measurable increase in quality\n");
     printf("  --temp N              temperature (default: %.1f)\n", (double)params.temp);
-    printf("  --perplexity          compute perplexity over each ctx window of the prompt\n");
+    printf("  --logits-all          return logits for all tokens in the batch (default: disabled)\n");
     printf("  --hellaswag           compute HellaSwag score over random tasks from datafile supplied with -f\n");
     printf("  --hellaswag-tasks N   number of tasks to use when computing the HellaSwag score (default: %zu)\n", params.hellaswag_tasks);
     printf("  --keep N              number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
     printf("  --draft N             number of tokens to draft for speculative decoding (default: %d)\n", params.n_draft);
     printf("  --chunks N            max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);
+    printf("  -np N, --parallel N   number of parallel sequences to decode (default: %d)\n", params.n_parallel);
+    printf("  -ns N, --sequences N  number of sequences to decode (default: %d)\n", params.n_sequences);
+    printf("  -cb, --cont-batching  enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
     if (llama_mlock_supported()) {
         printf("  --mlock               force system to keep model in RAM rather than swapping or compressing\n");
     }
@@ -685,7 +700,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("                        Not recommended since this is both slower and uses more VRAM.\n");
 #endif // GGML_USE_CUBLAS
 #endif
-    printf("  --export              export the computation graph to 'llama.ggml'\n");
     printf("  --verbose-prompt      print prompt before generation\n");
     fprintf(stderr, "  --simple-io           use basic IO for better compatibility in subprocesses and limited consoles\n");
     printf("  --lora FNAME          apply LoRA adapter (implies --no-mmap)\n");
@@ -738,7 +752,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     lparams.f16_kv          = params.memory_f16;
     lparams.use_mmap        = params.use_mmap;
     lparams.use_mlock       = params.use_mlock;
-    lparams.logits_all      = params.perplexity;
+    lparams.logits_all      = params.logits_all;
     lparams.embedding       = params.embedding;
     lparams.rope_freq_base  = params.rope_freq_base;
     lparams.rope_freq_scale = params.rope_freq_scale;
@@ -782,8 +796,9 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
     {
         LOG("warming up the model with an empty run\n");

-        const std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), };
-        llama_eval(lctx, tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, params.n_threads);
+        std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), };
+        llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0), params.n_threads);
+        llama_kv_cache_tokens_rm(lctx, -1, -1);
         llama_reset_timings(lctx);
     }

@@ -890,7 +905,7 @@ llama_token llama_sample_token(

     llama_token id = 0;

-    float * logits = llama_get_logits(ctx) + idx * n_vocab;
+    float * logits = llama_get_logits_ith(ctx, idx);

     // Apply params.logit_bias map
     for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
@@ -941,19 +956,19 @@
         if (mirostat == 1) {
             static float mirostat_mu = 2.0f * mirostat_tau;
             const int mirostat_m = 100;
-            llama_sample_temperature(ctx, &cur_p, temp);
+            llama_sample_temp(ctx, &cur_p, temp);
             id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
         } else if (mirostat == 2) {
             static float mirostat_mu = 2.0f * mirostat_tau;
-            llama_sample_temperature(ctx, &cur_p, temp);
+            llama_sample_temp(ctx, &cur_p, temp);
             id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu);
         } else {
             // Temperature sampling
             llama_sample_top_k      (ctx, &cur_p, top_k, 1);
             llama_sample_tail_free  (ctx, &cur_p, tfs_z, 1);
             llama_sample_typical    (ctx, &cur_p, typical_p, 1);
             llama_sample_top_p      (ctx, &cur_p, top_p, 1);
-            llama_sample_temperature(ctx, &cur_p, temp);
+            llama_sample_temp       (ctx, &cur_p, temp);

             {
                 const int n_top = 10;
@@ -1182,7 +1197,6 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false");
     fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);
     fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false");
-    fprintf(stream, "export: %s # default: false\n", params.export_cgraph ? "true" : "false");
     fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n");
     fprintf(stream, "frequency_penalty: %f # default: 0.0\n", params.frequency_penalty);
     dump_string_yaml_multiline(stream, "grammar", params.grammar.c_str());
@@ -1256,6 +1270,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "rope_freq_scale: %f # default: 1.0\n", params.rope_freq_scale);
     fprintf(stream, "seed: %d # default: -1 (random seed)\n", params.seed);
     fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false");
+    fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false");
     fprintf(stream, "temp: %f # default: 0.8\n", params.temp);

     const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + LLAMA_MAX_DEVICES);
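For reference, below is a minimal usage sketch (not part of this commit) of how the calls changed above fit together after this diff: the prompt is evaluated with llama_decode() on a batch built by llama_batch_get_one() instead of the old llama_eval(), per-token logits are read through the new llama_get_logits_ith() accessor, and the KV cache is cleared with llama_kv_cache_tokens_rm(). Only functions that appear in this diff plus the pre-existing common/llama helpers (gpt_params_parse, llama_init_from_gpt_params, llama_tokenize, llama_backend_init) are used; error handling and sampling are omitted.

```cpp
// Usage sketch only: a hypothetical standalone example, not code from this commit.
#include "common.h"
#include "llama.h"

#include <cstdio>
#include <tuple>
#include <vector>

int main(int argc, char ** argv) {
    gpt_params params;
    if (!gpt_params_parse(argc, argv, params)) {
        return 1;
    }

    llama_backend_init(params.numa);

    llama_model   * model;
    llama_context * ctx;
    std::tie(model, ctx) = llama_init_from_gpt_params(params);
    if (model == NULL) {
        return 1;
    }

    // tokenize the prompt and evaluate it as one batch on sequence 0 starting
    // at position 0; llama_decode() replaces the old llama_eval() call
    std::vector<llama_token> tokens = llama_tokenize(ctx, params.prompt, true);
    if (llama_decode(ctx, llama_batch_get_one(tokens.data(), (int) tokens.size(), 0, 0), params.n_threads) != 0) {
        fprintf(stderr, "%s: llama_decode() failed\n", __func__);
        return 1;
    }

    // logits of the last prompt token, fetched by batch index via the new accessor
    const float * logits = llama_get_logits_ith(ctx, (int) tokens.size() - 1);
    (void) logits; // sampling would go here (see llama_sample_token above)

    // remove all tokens from all sequences in the KV cache before tearing down
    llama_kv_cache_tokens_rm(ctx, -1, -1);

    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();

    return 0;
}
```

The -np/--parallel, -ns/--sequences and -cb/--cont-batching flags added above only populate gpt_params (n_parallel, n_sequences, cont_batching); whether and how they are consumed depends on the example binary that reads those fields.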