Skip to content

Commit ec250ae

Browse files
committed
Porting saood06's ik_llama.cpp utility
To make it easier to compare performance across forks
1 parent df0c0c7 commit ec250ae

File tree

5 files changed

+380
-0
lines changed

5 files changed

+380
-0
lines changed

examples/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ else()
3131
add_subdirectory(simple-chat)
3232
add_subdirectory(speculative)
3333
add_subdirectory(speculative-simple)
34+
add_subdirectory(sweep-bench)
35+
add_subdirectory(tokenize)
36+
add_subdirectory(tts)
3437
add_subdirectory(gen-docs)
3538
add_subdirectory(training)
3639
if (NOT GGML_BACKEND_DL)
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
set(TARGET llama-sweep-bench)
2+
add_executable(${TARGET} sweep-bench.cpp)
3+
install(TARGETS ${TARGET} RUNTIME)
4+
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
5+
target_compile_features(${TARGET} PRIVATE cxx_std_17)

examples/sweep-bench/README.md

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# ik_llama.cpp/example/sweep-bench
2+
3+
Benchmark the prompt processing and token generation performance of `ik_llama.cpp`
4+
by doing a sweep over a whole context size and gathering performance metrics
5+
in each ubatch-sized window. Only a single token sequence is used.
6+
7+
The benchmark steps are:
8+
9+
for each ubatch-sized window in context:
10+
11+
1. generate ubatch/4 tokens (not the whole window to save some time)
12+
2. measure generation performance
13+
3. remove generated tokens from KV cache
14+
4. prepare a ubatch-sized batch of random tokens
15+
4. process prepated batch
16+
5. measure prompt processing performance
17+
18+
The purpose of the benchmark is to visualize how the performance changes with
19+
the context size without averaging the metrics values over the whole context.
20+
21+
## Usage
22+
23+
./llama-sweep-bench -c 8704 -ub 512 -m models/Meta-Llama-3.2-3B-Instruct-Q8_0.gguf
24+
25+
## Sample results
26+
27+
- `PP` - prompt tokens per ubatch
28+
- `TG` - generated tokens per ubatch
29+
- `N_KV` - current KV cache size
30+
- `T_PP` - prompt processing time (i.e. time to first token)
31+
- `S_PP` - prompt processing speed (`(B*PP)/T_PP` or `PP/T_PP`)
32+
- `T_TG` - time to generate all batches
33+
- `S_TG` - text generation speed (`(B*TG)/T_TG`)
34+
35+
| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
36+
|-------|--------|--------|----------|----------|----------|----------|
37+
| 512 | 128 | 0 | 1.100 | 465.51 | 2.311 | 55.38 |
38+
| 512 | 128 | 512 | 1.183 | 432.97 | 1.895 | 67.55 |
39+
| 512 | 128 | 1024 | 1.305 | 392.38 | 2.071 | 61.81 |
40+
| 512 | 128 | 1536 | 1.279 | 400.42 | 2.164 | 59.14 |
41+
| 512 | 128 | 2048 | 1.571 | 325.96 | 2.280 | 56.14 |
42+
| 512 | 128 | 2560 | 1.431 | 357.87 | 2.418 | 52.94 |
43+
| 512 | 128 | 3072 | 1.515 | 337.93 | 2.566 | 49.88 |
44+
| 512 | 128 | 3584 | 1.588 | 322.34 | 2.722 | 47.03 |
45+
| 512 | 128 | 4096 | 1.675 | 305.70 | 2.864 | 44.69 |
46+
| 512 | 128 | 4608 | 1.769 | 289.50 | 2.999 | 42.68 |
47+
| 512 | 128 | 5120 | 1.845 | 277.48 | 3.102 | 41.26 |
48+
| 512 | 128 | 5632 | 1.893 | 270.46 | 3.219 | 39.76 |
49+
| 512 | 128 | 6144 | 1.953 | 262.20 | 3.348 | 38.23 |
50+
| 512 | 128 | 6656 | 2.018 | 253.71 | 3.474 | 36.84 |
51+
| 512 | 128 | 7168 | 2.078 | 246.34 | 3.589 | 35.66 |
52+
| 512 | 128 | 7680 | 2.140 | 239.22 | 3.717 | 34.43 |
53+
| 512 | 128 | 8192 | 2.196 | 233.15 | 3.854 | 33.21 |
54+
55+
### JSONL output
56+
57+
Pass `--output-format jsonl` to output JSONL instead of Markdown, á la
58+
59+
```json lines
60+
{"n_kv_max": 8704, "n_batch": 2048, "n_ubatch": 512, "flash_attn": 0, "n_gpu_layers": -1, "n_threads": 32, "n_threads_batch": 32, "pp": 512, "tg": 128, "n_kv": 0, "t_pp": 1.093814, "speed_pp": 468.086884, "t_tg": 1.780312, "speed_tg": 71.897514 }
61+
{"n_kv_max": 8704, "n_batch": 2048, "n_ubatch": 512, "flash_attn": 0, "n_gpu_layers": -1, "n_threads": 32, "n_threads_batch": 32, "pp": 512, "tg": 128, "n_kv": 512, "t_pp": 1.169302, "speed_pp": 437.868073, "t_tg": 1.897474, "speed_tg": 67.458099 }
62+
{"n_kv_max": 8704, "n_batch": 2048, "n_ubatch": 512, "flash_attn": 0, "n_gpu_layers": -1, "n_threads": 32, "n_threads_batch": 32, "pp": 512, "tg": 128, "n_kv": 1024, "t_pp": 1.183700, "speed_pp": 432.542053, "t_tg": 2.059179, "speed_tg": 62.160694 }
63+
{"n_kv_max": 8704, "n_batch": 2048, "n_ubatch": 512, "flash_attn": 0, "n_gpu_layers": -1, "n_threads": 32, "n_threads_batch": 32, "pp": 512, "tg": 128, "n_kv": 1536, "t_pp": 1.428625, "speed_pp": 358.386566, "t_tg": 2.160639, "speed_tg": 59.241734 }
64+
{"n_kv_max": 8704, "n_batch": 2048, "n_ubatch": 512, "flash_attn": 0, "n_gpu_layers": -1, "n_threads": 32, "n_threads_batch": 32, "pp": 512, "tg": 128, "n_kv": 2048, "t_pp": 1.360647, "speed_pp": 376.291595, "t_tg": 2.274003, "speed_tg": 56.288403 }
65+
```
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
import pandas as pd
2+
import matplotlib.pyplot as plt
3+
import numpy as np
4+
import argparse
5+
6+
parser = argparse.ArgumentParser()
7+
parser.add_argument('file', nargs='+')
8+
args = parser.parse_args()
9+
10+
df = None
11+
12+
#for jsonl_file in args.file:
13+
# # Read JSONL file into DataFrame
14+
# df_part = pd.read_json(jsonl_file, lines=True)
15+
# df_part['label'] = jsonl_file
16+
# if df is None:
17+
# df = df_part
18+
# else:
19+
# df = pd.concat([df, df_part])
20+
#
21+
22+
23+
24+
for md_file in args.file:
25+
# Read markdown table file into DataFrame
26+
df_part = pd.read_csv(md_file, sep=r'\s*\|\s*', engine='python',
27+
header=0, skiprows=[1])
28+
29+
# Clean up columns (remove empty columns from markdown formatting)
30+
df_part = df_part.iloc[:, 1:-1]
31+
df_part.columns = [col.strip() for col in df_part.columns]
32+
33+
# Rename columns to match expected names
34+
df_part = df_part.rename(columns={
35+
'N_KV': 'n_kv',
36+
'S_PP t/s': 'speed_pp',
37+
'S_TG t/s': 'speed_tg'
38+
})
39+
40+
# Convert to numeric types
41+
df_part['n_kv'] = pd.to_numeric(df_part['n_kv'])
42+
df_part['speed_pp'] = pd.to_numeric(df_part['speed_pp'])
43+
df_part['speed_tg'] = pd.to_numeric(df_part['speed_tg'])
44+
45+
# Add label and append to main DataFrame
46+
df_part['label'] = md_file
47+
df = pd.concat([df, df_part]) if df is not None else df_part
48+
49+
# Group by label and n_kv, calculate mean and std for both speed metrics
50+
df_grouped = df.groupby(['label', 'n_kv']).agg({
51+
'speed_pp': ['mean', 'std'],
52+
'speed_tg': ['mean', 'std']
53+
}).reset_index()
54+
55+
# Flatten multi-index columns
56+
df_grouped.columns = ['label', 'n_kv', 'speed_pp_mean', 'speed_pp_std',
57+
'speed_tg_mean', 'speed_tg_std']
58+
59+
# Replace NaN with 0 (std for a single sample is NaN)
60+
df_grouped['speed_pp_std'] = df_grouped['speed_pp_std'].fillna(0)
61+
df_grouped['speed_tg_std'] = df_grouped['speed_tg_std'].fillna(0)
62+
63+
# Prepare ticks values for X axis (prune for readability)
64+
x_ticks = df['n_kv'].unique()
65+
while len(x_ticks) > 16:
66+
x_ticks = x_ticks[::2]
67+
68+
# Get unique labels and color map
69+
labels = df_grouped['label'].unique()
70+
colors = plt.cm.rainbow(np.linspace(0, 1, len(labels)))
71+
72+
# Create prompt processing plot
73+
plt.figure(figsize=(10, 6))
74+
ax1 = plt.gca()
75+
plt.grid()
76+
ax1.set_xticks(x_ticks)
77+
78+
# Plot each label's data
79+
for label, color in zip(labels, colors):
80+
label_data = df_grouped[df_grouped['label'] == label].sort_values('n_kv')
81+
pp = ax1.errorbar(label_data['n_kv'], label_data['speed_pp_mean'],
82+
yerr=label_data['speed_pp_std'], color=color,
83+
marker='o', linestyle='-', label=label)
84+
85+
# Add labels and title
86+
ax1.set_xlabel('Context Length (tokens)')
87+
ax1.set_ylabel('Prompt Processing Rate (t/s)')
88+
plt.title('Prompt Processing Performance Comparison')
89+
ax1.legend(loc='upper right')
90+
91+
# Adjust layout and save
92+
plt.tight_layout()
93+
plt.savefig('performance_comparison_pp.png', bbox_inches='tight')
94+
plt.close()
95+
96+
# Create token generation plot
97+
plt.figure(figsize=(10, 6))
98+
ax1 = plt.gca()
99+
plt.grid()
100+
ax1.set_xticks(x_ticks)
101+
102+
# Plot each model's data
103+
for label, color in zip(labels, colors):
104+
label_data = df_grouped[df_grouped['label'] == label].sort_values('n_kv')
105+
tg = ax1.errorbar(label_data['n_kv'], label_data['speed_tg_mean'],
106+
yerr=label_data['speed_tg_std'], color=color,
107+
marker='s', linestyle='-', label=label)
108+
109+
# Add labels and title
110+
ax1.set_xlabel('Context Length (n_kv)')
111+
ax1.set_ylabel('Token Generation Rate (t/s)')
112+
plt.title('Token Generation Performance Comparison')
113+
ax1.legend(loc='upper right')
114+
115+
# Adjust layout and save
116+
plt.tight_layout()
117+
plt.savefig('performance_comparison_tg.png', bbox_inches='tight')
118+
plt.close()
Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
#include "ggml.h"
2+
#include "llama.h"
3+
#include "common.h"
4+
#include "llama-vocab.h"
5+
6+
#ifdef _WIN32
7+
#define WIN32_LEAN_AND_MEAN
8+
#ifndef NOMINMAX
9+
# define NOMINMAX
10+
#endif
11+
#include <windows.h>
12+
#endif
13+
14+
#include <algorithm>
15+
#include <cstdlib>
16+
#include <cstdio>
17+
#include <string>
18+
#include <vector>
19+
20+
static void print_usage(int, char ** argv) {
21+
LOG_TEE("\nexample usage:\n");
22+
LOG_TEE("\n %s -m model.gguf -c 8192 -b 2048 -ub 512\n", argv[0]);
23+
LOG_TEE("\n");
24+
}
25+
26+
int main(int argc, char ** argv) {
27+
28+
gpt_params params;
29+
30+
if (!gpt_params_parse(argc, argv, params)) {
31+
print_usage(argc, argv);
32+
return 1;
33+
}
34+
35+
// init LLM
36+
37+
llama_backend_init();
38+
llama_numa_init(params.numa);
39+
40+
// initialize the model
41+
42+
llama_model_params model_params = llama_model_params_from_gpt_params(params);
43+
44+
llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
45+
46+
if (model == NULL) {
47+
fprintf(stderr , "%s: error: unable to load model\n" , __func__);
48+
return 1;
49+
}
50+
51+
llama_context_params ctx_params = llama_context_params_from_gpt_params(params);
52+
53+
llama_context * ctx = llama_new_context_with_model(model, ctx_params);
54+
55+
if (ctx == NULL) {
56+
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
57+
return 1;
58+
}
59+
60+
const unsigned int n_kv_max = llama_n_ctx(ctx);
61+
62+
63+
const llama_vocab * vocab = llama_get_vocab(ctx);
64+
llama_token bos = llama_token_bos_impl(*vocab);
65+
//llama_token eos = llama_token_eos_impl(*vocab);
66+
67+
const unsigned int n_vocab = llama_n_vocab(model);
68+
69+
// decode in batches of ctx_params.n_batch tokens
70+
auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
71+
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
72+
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
73+
74+
llama_batch batch_view = {
75+
n_tokens,
76+
batch.token + i,
77+
nullptr,
78+
batch.pos + i,
79+
batch.n_seq_id + i,
80+
batch.seq_id + i,
81+
batch.logits + i,
82+
};
83+
84+
const int ret = llama_decode(ctx, batch_view);
85+
if (ret != 0) {
86+
LOG_TEE("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
87+
return false;
88+
}
89+
90+
llama_synchronize(ctx);
91+
}
92+
93+
return true;
94+
};
95+
96+
const unsigned int pp = params.n_ubatch;
97+
const unsigned int tg = params.n_ubatch / 4;
98+
99+
if (!params.sweep_bench_output_jsonl) {
100+
LOG_TEE("\n");
101+
LOG_TEE("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, flash_attn = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, params.n_batch, params.n_ubatch, params.flash_attn, params.n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
102+
LOG_TEE("\n");
103+
LOG_TEE("|%6s | %6s | %6s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s");
104+
LOG_TEE("|%6s-|-%6s-|-%6s-|-%8s-|-%8s-|-%8s-|-%8s-|\n", "------", "------", "------", "--------", "--------", "--------", "--------");
105+
}
106+
107+
llama_batch batch = llama_batch_init(n_kv_max, 0, 1);
108+
109+
// warm up
110+
{
111+
llama_batch_add(batch, bos, 0, { 0 }, false);
112+
113+
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
114+
LOG_TEE("%s: llama_decode() failed\n", __func__);
115+
return 1;
116+
}
117+
}
118+
119+
llama_batch_clear(batch);
120+
llama_kv_cache_clear(ctx);
121+
122+
for (unsigned int n_kv = 0; n_kv < n_kv_max; n_kv += params.n_ubatch) {
123+
// clean up KV cache before generation
124+
llama_kv_cache_seq_rm(ctx, 0, n_kv, -1);
125+
126+
// first measure token generation performance at this context size
127+
const auto t_tg_start = ggml_time_us();
128+
129+
for (unsigned int i = 0; i < tg; ++i) {
130+
llama_batch_clear(batch);
131+
llama_batch_add(batch, std::rand() % n_vocab, n_kv + i, { 0 }, true);
132+
133+
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
134+
LOG_TEE("%s: llama_decode() failed\n", __func__);
135+
return 1;
136+
}
137+
}
138+
139+
const auto t_tg_end = ggml_time_us();
140+
141+
// clean up KV cache after generation
142+
llama_kv_cache_seq_rm(ctx, 0, n_kv, -1);
143+
144+
// prepare batch of pp size for prompt processing performance measurement
145+
llama_batch_clear(batch);
146+
147+
for (unsigned int i = 0; i < pp; ++i) {
148+
llama_batch_add(batch, std::rand() % n_vocab, n_kv + i, { 0 }, false);
149+
}
150+
batch.logits[batch.n_tokens - 1] = true;
151+
152+
// measure prompt processing performance
153+
const auto t_pp_start = ggml_time_us();
154+
155+
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
156+
LOG_TEE("%s: llama_decode() failed\n", __func__);
157+
return 1;
158+
}
159+
160+
const auto t_pp_end = ggml_time_us();
161+
162+
// calculate and print metrics
163+
const float t_pp = (t_pp_end - t_pp_start) / 1000000.0f;
164+
const float t_tg = (t_tg_end - t_tg_start) / 1000000.0f;
165+
166+
const float speed_pp = pp / t_pp;
167+
const float speed_tg = tg / t_tg;
168+
169+
if(params.sweep_bench_output_jsonl) {
170+
LOG_TEE(
171+
"{\"n_kv_max\": %d, \"n_batch\": %d, \"n_ubatch\": %d, \"flash_attn\": %d, \"n_gpu_layers\": %d, \"n_threads\": %u, \"n_threads_batch\": %u, "
172+
"\"pp\": %d, \"tg\": %d, \"n_kv\": %d, \"t_pp\": %f, \"speed_pp\": %f, \"t_tg\": %f, \"speed_tg\": %f }\n",
173+
n_kv_max, params.n_batch, params.n_ubatch, params.flash_attn, params.n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch,
174+
pp, tg, n_kv, t_pp, speed_pp, t_tg, speed_tg
175+
);
176+
} else {
177+
LOG_TEE("|%6d | %6d | %6d | %8.3f | %8.2f | %8.3f | %8.2f |\n", pp, tg, n_kv, t_pp, speed_pp, t_tg, speed_tg);
178+
}
179+
}
180+
181+
llama_batch_free(batch);
182+
183+
llama_free(ctx);
184+
llama_free_model(model);
185+
186+
llama_backend_free();
187+
188+
return 0;
189+
}

0 commit comments

Comments
 (0)