@@ -164,6 +164,7 @@ struct cmd_params {
     std::vector<int> n_prompt;
     std::vector<int> n_gen;
     std::vector<int> n_batch;
+    std::vector<int> n_ubatch;
     std::vector<ggml_type> type_k;
     std::vector<ggml_type> type_v;
     std::vector<int> n_threads;
@@ -183,7 +184,8 @@ static const cmd_params cmd_params_defaults = {
     /* model         */ {"models/7B/ggml-model-q4_0.gguf"},
     /* n_prompt      */ {512},
     /* n_gen         */ {128},
-    /* n_batch       */ {512},
+    /* n_batch       */ {2048},
+    /* n_ubatch      */ {512},
     /* type_k        */ {GGML_TYPE_F16},
     /* type_v        */ {GGML_TYPE_F16},
     /* n_threads     */ {get_num_physical_cores()},
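This defaults change reflects the split of the former single batch size into a logical batch size (`n_batch`, the most tokens a single `llama_decode` call may submit, now 2048 by default) and a physical micro-batch size (`n_ubatch`, the most tokens actually computed in one graph, keeping the old default of 512).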
@@ -208,6 +210,7 @@ static void print_usage(int /* argc */, char ** argv) {
     printf("  -p, --n-prompt <n>                  (default: %s)\n", join(cmd_params_defaults.n_prompt, ",").c_str());
     printf("  -n, --n-gen <n>                     (default: %s)\n", join(cmd_params_defaults.n_gen, ",").c_str());
     printf("  -b, --batch-size <n>                (default: %s)\n", join(cmd_params_defaults.n_batch, ",").c_str());
+    printf("  -ub N, --ubatch-size <n>            (default: %s)\n", join(cmd_params_defaults.n_ubatch, ",").c_str());
     printf("  -ctk <t>, --cache-type-k <t>        (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_k, ggml_type_name), ",").c_str());
     printf("  -ctv <t>, --cache-type-v <t>        (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_v, ggml_type_name), ",").c_str());
     printf("  -t, --threads <n>                   (default: %s)\n", join(cmd_params_defaults.n_threads, ",").c_str());
@@ -217,7 +220,7 @@ static void print_usage(int /* argc */, char ** argv) {
     printf("  -nkvo, --no-kv-offload <0|1>        (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str());
     printf("  -mmp, --mmap <0|1>                  (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str());
     printf("  -embd, --embeddings <0|1>           (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str());
-    printf("  -ts, --tensor_split <ts0/ts1/..>    (default: 0)\n");
+    printf("  -ts, --tensor-split <ts0/ts1/..>    (default: 0)\n");
     printf("  -r, --repetitions <n>               (default: %d)\n", cmd_params_defaults.reps);
     printf("  -o, --output <csv|json|md|sql>      (default: %s)\n", output_format_str(cmd_params_defaults.output_format));
     printf("  -v, --verbose                       (default: %s)\n", cmd_params_defaults.verbose ? "1" : "0");
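With the new flag, an invocation such as `llama-bench -m models/7B/ggml-model-q4_0.gguf -b 2048 -ub 128,256,512` should benchmark three micro-batch sizes against one logical batch size; like the other list-valued options, `-ub` takes comma-separated values.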
@@ -297,6 +300,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
             }
             auto p = split<int>(argv[i], split_delim);
             params.n_batch.insert(params.n_batch.end(), p.begin(), p.end());
+        } else if (arg == "-ub" || arg == "--ubatch-size") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            auto p = split<int>(argv[i], split_delim);
+            params.n_ubatch.insert(params.n_ubatch.end(), p.begin(), p.end());
         } else if (arg == "-ctk" || arg == "--cache-type-k") {
             if (++i >= argc) {
                 invalid_param = true;
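The `-ub` branch mirrors the `-b` branch above: it consumes the next argument and splits it on `split_delim`. The `split<int>` helper is defined earlier in llama-bench and not shown in this diff; a minimal sketch of the behavior it presumably has for `T = int`:

```cpp
#include <sstream>
#include <string>
#include <vector>

// Hypothetical stand-in for llama-bench's split<T>() with T = int:
// splits a delimiter-separated string into a vector of ints.
static std::vector<int> split_ints(const std::string & s, char delim = ',') {
    std::vector<int> out;
    std::stringstream ss(s);
    std::string item;
    while (std::getline(ss, item, delim)) {
        out.push_back(std::stoi(item));
    }
    return out;
}
// e.g. split_ints("256,512,1024") -> {256, 512, 1024}
```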
@@ -455,6 +465,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
     if (params.n_prompt.empty())     { params.n_prompt = cmd_params_defaults.n_prompt; }
     if (params.n_gen.empty())        { params.n_gen = cmd_params_defaults.n_gen; }
     if (params.n_batch.empty())      { params.n_batch = cmd_params_defaults.n_batch; }
+    if (params.n_ubatch.empty())     { params.n_ubatch = cmd_params_defaults.n_ubatch; }
     if (params.type_k.empty())       { params.type_k = cmd_params_defaults.type_k; }
     if (params.type_v.empty())       { params.type_v = cmd_params_defaults.type_v; }
     if (params.n_gpu_layers.empty()) { params.n_gpu_layers = cmd_params_defaults.n_gpu_layers; }
@@ -474,6 +485,7 @@ struct cmd_params_instance {
     int n_prompt;
     int n_gen;
     int n_batch;
+    int n_ubatch;
     ggml_type type_k;
     ggml_type type_v;
     int n_threads;
@@ -511,6 +523,7 @@ struct cmd_params_instance {

         cparams.n_ctx = n_prompt + n_gen;
         cparams.n_batch = n_batch;
+        cparams.n_ubatch = n_ubatch;
         cparams.type_k = type_k;
         cparams.type_v = type_v;
         cparams.offload_kqv = !no_kv_offload;
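These assignments fill the `llama_context_params` each benchmark instance uses to create its context. A minimal sketch of the resulting configuration, using the new defaults as hypothetical values (the real values come from the instance):

```cpp
// Sketch with hypothetical values. n_batch is the logical batch: the most
// tokens one llama_decode() call may submit. n_ubatch is the physical
// micro-batch the backend actually computes in one graph.
llama_context_params cparams = llama_context_default_params();
cparams.n_ctx    = 512 + 128; // n_prompt + n_gen, mirroring the line above
cparams.n_batch  = 2048;      // new default for -b
cparams.n_ubatch = 512;       // new default for -ub
```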
@@ -532,6 +545,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
     for (const auto & mmp : params.use_mmap)
     for (const auto & embd : params.embeddings)
     for (const auto & nb : params.n_batch)
+    for (const auto & nub : params.n_ubatch)
     for (const auto & tk : params.type_k)
     for (const auto & tv : params.type_v)
     for (const auto & nkvo : params.no_kv_offload)
@@ -545,6 +559,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .n_prompt     = */ n_prompt,
                 /* .n_gen        = */ 0,
                 /* .n_batch      = */ nb,
+                /* .n_ubatch     = */ nub,
                 /* .type_k       = */ tk,
                 /* .type_v       = */ tv,
                 /* .n_threads    = */ nt,
@@ -568,6 +583,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .n_prompt     = */ 0,
                 /* .n_gen        = */ n_gen,
                 /* .n_batch      = */ nb,
+                /* .n_ubatch     = */ nub,
                 /* .type_k       = */ tk,
                 /* .type_v       = */ tv,
                 /* .n_threads    = */ nt,
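The brace-less `for` statements in `get_cmd_params_instances` nest, so the instance-building body runs once per combination of every parameter vector; adding the `n_ubatch` loop multiplies the instance count by the number of `-ub` values. A standalone sketch of the pattern:

```cpp
#include <cstdio>
#include <vector>

int main() {
    std::vector<int> n_batch  = {2048};
    std::vector<int> n_ubatch = {256, 512}; // as if from "-ub 256,512"
    // brace-less range-fors nest: the body runs once per combination
    for (const auto & nb : n_batch)
    for (const auto & nub : n_ubatch)
        printf("instance: n_batch=%d n_ubatch=%d\n", nb, nub);
    // prints 1 * 2 = 2 instances
}
```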
@@ -604,6 +620,7 @@ struct test {
     uint64_t model_size;
     uint64_t model_n_params;
     int n_batch;
+    int n_ubatch;
     int n_threads;
     ggml_type type_k;
     ggml_type type_v;
@@ -627,6 +644,7 @@ struct test {
         model_size = llama_model_size(lmodel);
         model_n_params = llama_model_n_params(lmodel);
         n_batch = inst.n_batch;
+        n_ubatch = inst.n_ubatch;
         n_threads = inst.n_threads;
         type_k = inst.type_k;
         type_v = inst.type_v;
@@ -705,7 +723,8 @@ struct test {
             "cuda", "opencl", "vulkan", "kompute", "metal", "sycl", "gpu_blas", "blas",
             "cpu_info", "gpu_info",
             "model_filename", "model_type", "model_size", "model_n_params",
-            "n_batch", "n_threads", "type_k", "type_v",
+            "n_batch", "n_ubatch",
+            "n_threads", "type_k", "type_v",
             "n_gpu_layers", "split_mode",
             "main_gpu", "no_kv_offload",
             "tensor_split", "use_mmap", "embeddings",
@@ -719,7 +738,8 @@ struct test {
     enum field_type {STRING, BOOL, INT, FLOAT};

     static field_type get_field_type(const std::string & field) {
-        if (field == "build_number" || field == "n_batch" || field == "n_threads" ||
+        if (field == "build_number" || field == "n_batch" || field == "n_ubatch" ||
+            field == "n_threads" ||
             field == "model_size" || field == "model_n_params" ||
             field == "n_gpu_layers" || field == "main_gpu" ||
             field == "n_prompt" || field == "n_gen" ||
@@ -759,7 +779,8 @@ struct test {
             std::to_string(metal), std::to_string(sycl), std::to_string(gpu_blas), std::to_string(blas),
             cpu_info, gpu_info,
             model_filename, model_type, std::to_string(model_size), std::to_string(model_n_params),
-            std::to_string(n_batch), std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v),
+            std::to_string(n_batch), std::to_string(n_ubatch),
+            std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v),
             std::to_string(n_gpu_layers), split_mode_str(split_mode),
             std::to_string(main_gpu), std::to_string(no_kv_offload),
             tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings),
@@ -957,6 +978,9 @@ struct markdown_printer : public printer {
         if (params.n_batch.size() > 1 || params.n_batch != cmd_params_defaults.n_batch) {
             fields.emplace_back("n_batch");
         }
+        if (params.n_ubatch.size() > 1 || params.n_ubatch != cmd_params_defaults.n_ubatch) {
+            fields.emplace_back("n_ubatch");
+        }
         if (params.type_k.size() > 1 || params.type_k != cmd_params_defaults.type_k) {
             fields.emplace_back("type_k");
         }
@@ -1096,25 +1120,32 @@ struct sql_printer : public printer {
 };

 static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) {
+    llama_set_n_threads(ctx, n_threads, n_threads);
+
+    // std::vector<llama_token> tokens(n_prompt, llama_token_bos(llama_get_model(ctx)));
+    // llama_decode(ctx, llama_batch_get_one(tokens.data(), n_prompt, n_past, 0));
+    // GGML_UNUSED(n_batch);
+
     std::vector<llama_token> tokens(n_batch, llama_token_bos(llama_get_model(ctx)));
     int n_processed = 0;

-    llama_set_n_threads(ctx, n_threads, n_threads);
-
     while (n_processed < n_prompt) {
         int n_tokens = std::min(n_prompt - n_processed, n_batch);
         llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens, n_past + n_processed, 0));
         n_processed += n_tokens;
     }
+
+    llama_synchronize(ctx);
 }

 static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
-    llama_token token = llama_token_bos(llama_get_model(ctx));
-
     llama_set_n_threads(ctx, n_threads, n_threads);

+    llama_token token = llama_token_bos(llama_get_model(ctx));
+
     for (int i = 0; i < n_gen; i++) {
         llama_decode(ctx, llama_batch_get_one(&token, 1, n_past + i, 0));
+        llama_synchronize(ctx);
     }
 }

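`llama_decode` can return before the backend has finished the submitted work (notably with the pipeline parallelism this change accompanies), so both helpers now call `llama_synchronize` to ensure the wall-clock measurement covers the actual compute. A sketch of the measurement pattern, not a complete program, assuming a loaded `ctx` and prepared `batch` and reusing llama-bench's `get_time_ns()` helper:

```cpp
// Without llama_synchronize(), the timer could stop while the backend is
// still executing the graph asynchronously.
static uint64_t timed_decode_ns(llama_context * ctx, llama_batch batch) {
    uint64_t t0 = get_time_ns();
    llama_decode(ctx, batch); // may return before compute completes
    llama_synchronize(ctx);   // wait for all queued work before reading the clock
    return get_time_ns() - t0;
}
```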
@@ -1203,7 +1234,8 @@ int main(int argc, char ** argv) {

         // warmup run
         if (t.n_prompt > 0) {
-            test_prompt(ctx, std::min(2, t.n_batch), 0, t.n_batch, t.n_threads);
+            // test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads);
+            test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
         }
         if (t.n_gen > 0) {
             test_gen(ctx, 1, 0, t.n_threads);
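The warmup previously decoded only `std::min(2, t.n_batch)` tokens; running the full prompt once instead should ensure that the graphs and buffers for the real batch and micro-batch sizes are allocated before the timed repetitions begin (the commented-out line preserves a cheaper capped variant).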
@@ -1219,6 +1251,7 @@ int main(int argc, char ** argv) {
             if (t.n_gen > 0) {
                 test_gen(ctx, t.n_gen, t.n_prompt, t.n_threads);
             }
+
             uint64_t t_ns = get_time_ns() - t_start;
             t.samples_ns.push_back(t_ns);
         }