Dual Epyc Genoa/Turin token generation performance bottleneck #11733
-
Here's a continuation of my investigation regarding the performance bottleneck on dual CPU systems.

Tested Platforms

Platform P0
This is my Epyc Genoa workstation. I used it with the NUMA NPS2 BIOS setting to emulate a dual CPU system with a very fast interconnect between CPUs.

Platform P1
Thanks to u/TastesLikeOwlbear for providing access to the system.

Platform P2
Thanks to u/SuperSecureHuman for providing access to the system.

Software

I used the following software:
Test Methodology
Assuming perfect scaling with the number of CPUs, a dual CPU llama-bench run should report twice the performance of a single CPU run, so 200% is the theoretical maximum here.

Test Results

Memory Bandwidth

First let's see how much of the theoretical maximum memory bandwidth we can use. For this purpose I measured the read bandwidth with the likwid-bench load kernel.
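For readers unfamiliar with the tool: the load kernel simply streams through a working set much larger than the caches and reports the achieved read bandwidth. A naive sketch of such a measurement is shown below; this is an illustration of the idea, not likwid-bench, and a simple loop like this will usually report lower numbers than a tuned kernel:

```cpp
// Naive multi-threaded read-bandwidth sketch (illustrative only, not likwid-bench).
// Each thread sums a private slice of a large array; bandwidth = bytes read / elapsed time.
#include <chrono>
#include <cstdio>
#include <numeric>
#include <thread>
#include <vector>

int main() {
    const size_t n_bytes   = 4ull * 1024 * 1024 * 1024;        // 4 GiB working set, larger than L3
    const int    n_threads = std::thread::hardware_concurrency();

    std::vector<double> data(n_bytes / sizeof(double), 1.0);
    std::vector<double> sums(n_threads, 0.0);
    const size_t chunk = data.size() / n_threads;

    const auto t0 = std::chrono::steady_clock::now();
    std::vector<std::thread> workers;
    for (int t = 0; t < n_threads; ++t) {
        workers.emplace_back([&, t] {
            const double * p = data.data() + (size_t) t * chunk;
            double s = 0.0;
            for (size_t i = 0; i < chunk; ++i) s += p[i];      // pure streaming reads
            sums[t] = s;
        });
    }
    for (auto & w : workers) w.join();
    const double secs = std::chrono::duration<double>(std::chrono::steady_clock::now() - t0).count();

    const double bytes_read = (double) n_threads * chunk * sizeof(double);
    printf("checksum %.1f, read bandwidth: %.1f GB/s\n",
           std::accumulate(sums.begin(), sums.end(), 0.0), bytes_read / secs / 1e9);
}
```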
Turin clearly wins over Genoa in terms of the ability to use the available memory bandwidth.

LLM Inference - dense models

For LLM inference tests of dense models I used a classic Llama-3.1 70B with f16 weights.
We can observe pretty good performance scaling figures. It's definitely worth using a dual CPU platform for large dense models.

LLM inference - MoE models

For LLM inference tests of MoE models I used Mixtral 8x22B v0.1 with Q8_0 weights.
We can observe that prompt processing performance scales very well on dual-CPU systems, but for some reason the token generation performance exhibits only moderate scaling.

LLM inference - DeepSeek V3

For LLM inference tests of models based on the DeepSeek V3 architecture I used DeepSeek R1 with Q4_K_S and Q8_0 (where possible) weights.
Prompt processing shows moderate scaling on the dual Genoa system and good scaling on the dual Turin system, but the token generation performance scales very badly on both - there are barely any gains compared to using only a single socket.

Possible Causes

I thought about the possible causes for this, and my current working hypothesis is that the observed differences in scaling are caused by the sizes of the multiplied matrices. If we look at the FFN matrices of the tested models, they have the following sizes:
Also note that I tested Llama-3.1 70B in f16, Mixtral in Q8_0 and DeepSeek R1 in Q4_K_S quantization. So there are 448MB of matrix data in Llama-3.1, 96MB in Mixtral and only around 7MB in DeepSeek. Moreover, DeepSeek R1 in Q8_0 scaled better than in Q4_K_S. The smaller the multiplied matrix, the higher (relatively) the synchronization and communication overhead resulting from dual CPU usage. Note that the problem does not manifest on a single-socket system with the NPS2 NUMA setting.

Next Steps

To verify my hypothesis it would be enough to check how inference of a very small dense model like Llama-3.2 1B scales on a dual CPU system. Its FFN matrix size is very similar to the size of the matrix for a single expert in DeepSeek R1. If I'm right, the scaling will be horrible. I posted the steps here in case anyone wants to try.
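As a sanity check on the figures quoted above, here is a rough per-matrix size calculation. The FFN shapes (8192×28672 for Llama-3.1 70B, 6144×16384 for Mixtral 8x22B, 7168×2048 per DeepSeek expert) and the approximate bytes-per-weight values (2 for f16, ~1 for Q8_0, ~0.5 for Q4_K_S) are my own assumptions, not taken from the post:

```cpp
// Back-of-the-envelope FFN matrix sizes; shapes and bytes-per-weight are approximations.
#include <cstdio>

int main() {
    struct { const char * model; long long rows, cols; double bytes_per_weight; } m[] = {
        { "Llama-3.1 70B (f16)",  8192, 28672, 2.0 },  // -> ~448 MiB per FFN matrix
        { "Mixtral 8x22B (Q8_0)", 6144, 16384, 1.0 },  // -> ~96 MiB per FFN matrix
        { "DeepSeek R1 (Q4_K_S)", 7168,  2048, 0.5 },  // -> ~7 MiB per expert FFN matrix
    };
    for (const auto & e : m) {
        const double mib = e.rows * e.cols * e.bytes_per_weight / (1024.0 * 1024.0);
        printf("%-24s %5lld x %5lld -> %7.1f MiB\n", e.model, e.rows, e.cols, mib);
    }
}
```

The results reproduce the 448MB / 96MB / ~7MB figures, so the ratio between the per-matrix working sets of these models is roughly 64 : 14 : 1.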
-
What happens if you allow the use of sgemm for token generation?

llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp, lines 2373 to 2374 in d04e716

=> if (n<1)
-
Is it ultimately calling this function: llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp, line 3970 in 0d55958?

If so, then it's almost certainly this loop:

```cpp
// compute each matrix multiplication in sequence
for (int cur_a = 0; cur_a < n_as; ++cur_a) {
    const int64_t cne1 = matrix_row_counts[cur_a];

    if (cne1 == 0) {
        continue;
    }

    auto src0_cur = (const char *) src0->data + cur_a*nb02;

    //const int64_t nr0 = ne01; // src0 rows
    const int64_t nr1 = cne1; // src1 rows

    int64_t src0_cur_start = (ith * ne01) / nth;
    int64_t src0_cur_end   = ((ith + 1) * ne01) / nth;

    src0_cur_start = (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
    src0_cur_end   = (src0_cur_end   % NB_COLS) ? src0_cur_end   + NB_COLS - (src0_cur_end   % NB_COLS) : src0_cur_end;

    if (src0_cur_start >= src0_cur_end) return;

    for (int ir1 = 0; ir1 < nr1; ir1++) {
        struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
        const int id = row_mapping.i1; // selected expert index

        const int64_t i11 = id % ne11;
        const int64_t i12 = row_mapping.i2; // row index in src1

        const int64_t i1 = id;  // selected expert index
        const int64_t i2 = i12; // row

        auto src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2);

        gemv<BLOC_TYPE, INTER_SIZE, NB_COLS>(
            ne00, (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start,
            ne01, src0_cur + src0_cur_start * nb01,
            src1_col, 1, src0_cur_end - src0_cur_start);
    }
}
```

Each 7168×2048 matrix is only 28MB in 16-bit float format (compared to 192MB for your …).

Something like this:

```cpp
#pragma omp parallel for schedule(dynamic, 1)
for (int cur_a = 0; cur_a < n_as; ++cur_a) {
```

and writing into a temporary instead of … It needs to use "dynamic" as the loop has the early exit … I've no idea how …

EDIT: Actually, I just saw you replied to my post above and it may be a different function - I'll leave this here anyway and go and have a look at that :)
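To make the proposed restructuring concrete, here is a rough standalone sketch of parallelising over the experts with OpenMP instead of splitting each expert's rows across all threads. The helper name gemv_expert, the data layout, and the expert-usage check are made up for illustration; this is not the ggml code:

```cpp
// Sketch: each OpenMP task handles one whole expert, writing into that expert's
// own output slice. Compile with -fopenmp. Illustrative only.
#include <cstdio>
#include <vector>

// trivial stand-in for a per-expert GEMV: dst[r] = sum_c W[r][c] * x[c]
static void gemv_expert(const float * W, const float * x, float * dst, int rows, int cols) {
    for (int r = 0; r < rows; ++r) {
        float acc = 0.0f;
        for (int c = 0; c < cols; ++c) acc += W[r * cols + c] * x[c];
        dst[r] = acc;
    }
}

int main() {
    const int n_as = 8, rows = 2048, cols = 7168;              // DeepSeek-like expert shape
    std::vector<float> W((size_t) n_as * rows * cols, 0.01f);  // all expert weights
    std::vector<float> x(cols, 1.0f);
    std::vector<float> out((size_t) n_as * rows, 0.0f);        // per-expert output temporaries

    // "dynamic" scheduling because some experts are skipped (no rows routed to them)
    #pragma omp parallel for schedule(dynamic, 1)
    for (int cur_a = 0; cur_a < n_as; ++cur_a) {
        const bool expert_used = (cur_a % 2 == 0);             // stand-in for matrix_row_counts[cur_a] != 0
        if (!expert_used) continue;
        gemv_expert(&W[(size_t) cur_a * rows * cols], x.data(), &out[(size_t) cur_a * rows], rows, cols);
    }

    printf("out[0] = %f\n", out[0]);
    return 0;
}
```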
-
```c
    ggml_barrier(params->threadpool);

    for (int cur_a = 0; cur_a < n_as; ++cur_a) {
        const int64_t cne1 = matrix_row_counts[cur_a];

        if (cne1 == 0) {
            continue;
        }

        const char * src0_cur = (const char *) src0->data + cur_a * nb02;
        const void * wdata    = (src1->type == vec_dot_type) ? src1->data : params->wdata;
        const size_t row_size = ggml_row_size(vec_dot_type, ne10);

        const int64_t nr0 = ne01;
        const int64_t nr1 = cne1;

        int chunk_size = 16;
        if (nr0 == 1 || nr1 == 1) {
            chunk_size = 64;
        }

#if defined(__aarch64__)
        // disable for ARM
        const bool disable_chunking = true;
#else
        // disable for NUMA
        const bool disable_chunking = ggml_is_numa();
#endif // defined(__aarch64__)

        int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;
        int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;

        if (nchunk0 * nchunk1 < nth * 4 || disable_chunking) {
            nchunk0 = nr0 > nr1 ? nth : 1;
            nchunk1 = nr0 > nr1 ? 1 : nth;
        }

        const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
        const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;

        int current_chunk = ith;

        atomic_int * current_chunk_ctr = (atomic_int *)(atomic_current_chunk + cur_a);

        while (current_chunk < nchunk0 * nchunk1) {
            const int64_t ith0 = current_chunk % nchunk0;
            const int64_t ith1 = current_chunk / nchunk0;

            const int64_t ir0_start = dr0 * ith0;
            const int64_t ir0_end   = MIN(ir0_start + dr0, nr0);

            const int64_t ir1_start = dr1 * ith1;
            const int64_t ir1_end   = MIN(ir1_start + dr1, nr1);

            ggml_compute_forward_mul_mat_id_one_chunk(
                dst, src0, src1, ids, cur_a,
                ir0_start, ir0_end, ir1_start, ir1_end,
                src0_cur, matrix_rows, row_size, src1_cont, wdata
            );

            if (nth >= nchunk0 * nchunk1) {
                break;
            }

            current_chunk = atomic_fetch_add_explicit(current_chunk_ctr, 1, memory_order_relaxed);
        }
    }
}
```

I can see the …

```c
static void ggml_compute_forward_mul_mat_id_one_chunk(
    struct ggml_tensor * dst,
    const struct ggml_tensor * src0,
    const struct ggml_tensor * src1,
    const struct ggml_tensor * ids,
    const int64_t cur_a,
    const int64_t ir0_start,
    const int64_t ir0_end,
    const int64_t ir1_start,
    const int64_t ir1_end,
    const char * src0_cur,
    const struct mmid_row_mapping * matrix_rows,
    const size_t row_size,
    const bool src1_cont,
    const void * wdata) {

    GGML_TENSOR_BINARY_OP_LOCALS

    const enum ggml_type type = src0->type;

    ggml_vec_dot_t const vec_dot      = type_traits_cpu[type].vec_dot;
    enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;

    const int64_t blck_0 = 16;
    const int64_t blck_1 = 16;

    float tmp[16];

    for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
        for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
            for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ++ir1) {
                const int64_t _i12 = ir1; // logical row index for this expert

                struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12);
                const int id = row_mapping.i1; // selected expert index

                const int64_t i11 = id % ne11;
                const int64_t i12 = row_mapping.i2; // row index in src1

                const int64_t i1 = id;  // selected expert index
                const int64_t i2 = i12; // row

                // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
                //       if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
                //       the original src1 data pointer, so we should index using the indices directly
                // TODO: this is a bit of a hack, we should probably have a better way to handle this
                const char * src1_col = (const char *) wdata +
                    (src1_cont || src1->type != vec_dot_type
                        ? (i11      + i12*ne11)*row_size
                        : (i11*nb11 + i12*nb12));

                float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2));

                for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) {
                    vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1);
                }

                memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir0_end) - iir0)*sizeof(float));
            }
        }
    }
}
```

Surely it's not just in …?
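For reference, here is a stripped-down illustration of the chunked work distribution used in the snippet above: a shared atomic counter hands out 2-D (row, token) chunk indices to threads. It's a standalone toy with invented shapes, not the ggml code:

```cpp
// Toy version of the chunk-stealing loop: threads grab chunk indices from a shared
// atomic counter and map them to (ir0, ir1) ranges. Build with -pthread.
#include <algorithm>
#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

int main() {
    const int64_t nr0 = 2048, nr1 = 8;        // e.g. expert rows x routed tokens
    const int64_t chunk_size = 64;
    const int     nth = 8;

    const int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;
    const int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;
    const int64_t dr0     = (nr0 + nchunk0 - 1) / nchunk0;
    const int64_t dr1     = (nr1 + nchunk1 - 1) / nchunk1;

    std::atomic<int64_t> current_chunk{nth};  // first nth chunks are pre-assigned, one per thread
    std::vector<int64_t> work_done(nth, 0);

    std::vector<std::thread> workers;
    for (int ith = 0; ith < nth; ++ith) {
        workers.emplace_back([&, ith] {
            int64_t chunk = ith;
            while (chunk < nchunk0 * nchunk1) {
                const int64_t ith0 = chunk % nchunk0;
                const int64_t ith1 = chunk / nchunk0;

                const int64_t ir0_start = dr0 * ith0;
                const int64_t ir0_end   = std::min(ir0_start + dr0, nr0);
                const int64_t ir1_start = dr1 * ith1;
                const int64_t ir1_end   = std::min(ir1_start + dr1, nr1);

                // a real implementation would run the matmul for this chunk here
                work_done[ith] += (ir0_end - ir0_start) * (ir1_end - ir1_start);

                chunk = current_chunk.fetch_add(1, std::memory_order_relaxed);
            }
        });
    }
    for (auto & w : workers) w.join();

    for (int ith = 0; ith < nth; ++ith) {
        printf("thread %d processed %lld (row, token) pairs\n", ith, (long long) work_done[ith]);
    }
}
```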
-
I created a NUMA-aware matrix-vector multiplication benchmark (modified from the existing old one); it's in the numa-matmul-bench branch of my llama.cpp fork: https://github.com/fairydreaming/llama.cpp/tree/numa-matmul-bench

The benchmark needs libnuma to compile. Now I need a brave volunteer to try it on a dual-CPU machine. For example, I ran it using the FFN tensor dimensions from Llama-3.1 70B on my workstation (an emulated dual CPU system via the BIOS NUMA NPS2 setting) and got the following:

Using a single CPU:
Using two CPUs:
As you can see, there is almost perfect performance scaling, since 1062.94 / 565.63 = 187.9%. So with two CPUs it works almost twice as fast as with one.

Commands to try for smaller matrices (sized like the DeepSeek R1 experts):

To run on a single CPU:
To run on two CPUs:
Parameter -t is the number of threads, -i is the number of benchmark iterations, and -l is the number of benchmark computation graph "layers". Each layer is a single matrix-vector multiplication and a tensor addition. The number of layers is tuned to make sure that the weights of the multiplied matrices won't be cached in L3. You can also try swapping the x and y values in both commands to see if it affects the performance. On my machine I get average results of 870.30 and 535.76, so the scaling ratio is 162.4% - a bit worse than for the large matrix. But note that this is an "emulated" dual CPU machine; a real one will perform worse due to limited interconnect bandwidth and increased latency (I wonder how much worse).
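The actual benchmark code lives in the branch linked above; as a rough sketch of the general idea behind NUMA-aware placement (each node computes only against weights allocated in its own local memory), assuming libnuma and invented shapes:

```cpp
// Minimal sketch of a NUMA-aware matrix-vector multiply: each node owns a slice of
// the weight rows, allocated on that node. Illustrative only; link with -lnuma -pthread.
#include <numa.h>
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <thread>
#include <vector>

int main() {
    if (numa_available() < 0) { fprintf(stderr, "no NUMA support\n"); return 1; }

    const int    n_nodes = numa_num_configured_nodes();
    const size_t rows = 2048, cols = 7168;                    // DeepSeek-style expert shape
    const size_t rows_per_node = (rows + n_nodes - 1) / n_nodes;

    std::vector<float>   x(cols, 1.0f);
    std::vector<float>   y(rows, 0.0f);
    std::vector<float *> w_slices(n_nodes);

    // allocate each node's slice of the weight matrix in that node's local memory
    for (int node = 0; node < n_nodes; ++node) {
        w_slices[node] = (float *) numa_alloc_onnode(rows_per_node * cols * sizeof(float), node);
        memset(w_slices[node], 0, rows_per_node * cols * sizeof(float));
    }

    // one worker per node, pinned to that node, computing only its local rows
    std::vector<std::thread> workers;
    for (int node = 0; node < n_nodes; ++node) {
        workers.emplace_back([&, node] {
            numa_run_on_node(node);                           // keep the computation next to the memory
            const size_t r0 = node * rows_per_node;
            const size_t r1 = std::min(rows, r0 + rows_per_node);
            for (size_t r = r0; r < r1; ++r) {
                const float * wrow = w_slices[node] + (r - r0) * cols;
                float acc = 0.0f;
                for (size_t c = 0; c < cols; ++c) acc += wrow[c] * x[c];
                y[r] = acc;
            }
        });
    }
    for (auto & t : workers) t.join();

    for (int node = 0; node < n_nodes; ++node) {
        numa_free(w_slices[node], rows_per_node * cols * sizeof(float));
    }
    printf("y[0] = %f\n", y[0]);
}
```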
-
Part 1 - The Problem
I have temporary access to a dual CPU Epyc Turin system. I did some initial performance tests with llama.cpp running on a single CPU:
For CPU 0:
For CPU 1:
Unfortunately, when I run llama.cpp on both CPUs at once with --numa distribute, the prompt processing performance doubles, while the token generation performance stays at the same level as with a single CPU (actually it's even a bit worse):

Part 2 - The Workaround
I did some more tests and found something weird. If I run llama-bench with both prompt processing and token generation tests with --numa distribute on a dual-CPU system, the result is:

but when I dropped the caches and ran ONLY the generation test, it magically became faster:
So my current hypothesis is that the placement of tensors in memory resulting from the prompt processing is for some reason sub-optimal for the token generation. This is definitely something to investigate further.
But loading the model during generation instead of prompt processing can be a viable workaround for the problem. I mean, if running a generation benchmark results in optimal placement of tensors in memory, then just run it first and you are done. The generation performance stays high after this, even when running the combined benchmark:
I described this workaround in #11744 so that people can try it.
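One plausible mechanism behind this placement effect is Linux's default local ("first touch") allocation: pages end up on the NUMA node of whichever thread faults them in first, so the thread layout of the first pass over the weights determines where they live afterwards. A small sketch (assuming libnuma; this is illustrative, not llama.cpp code) of how one can check where pages actually landed:

```cpp
// First-touch demonstration: touch a buffer from a thread bound to node 0 and then
// query which node each page ended up on. Link with -lnuma.
#include <numa.h>
#include <numaif.h>
#include <unistd.h>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <vector>

int main() {
    if (numa_available() < 0) { fprintf(stderr, "no NUMA support\n"); return 1; }

    const size_t page    = sysconf(_SC_PAGESIZE);
    const size_t n_pages = 64;
    char * buf = (char *) aligned_alloc(page, n_pages * page);

    numa_run_on_node(0);                 // bind the touching thread to node 0
    memset(buf, 1, n_pages * page);      // first touch: pages should land on node 0

    std::vector<void *> pages(n_pages);
    std::vector<int>    status(n_pages);
    for (size_t i = 0; i < n_pages; ++i) pages[i] = buf + i * page;

    // with nodes == NULL, move_pages only reports the node of each page
    if (move_pages(0, n_pages, pages.data(), NULL, status.data(), 0) == 0) {
        int on_node0 = 0;
        for (size_t i = 0; i < n_pages; ++i) on_node0 += (status[i] == 0);
        printf("%d of %zu pages are on node 0\n", on_node0, n_pages);
    }
    free(buf);
}
```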
Part 3 - The Cause
I did some more investigation into what causes this, and it seems to be related to GGML_USE_LLAMAFILE and llamafile_sgemm() calls. If the model weights are loaded with these calls, the token generation performance is reduced. Example:

When I disable GGML_USE_LLAMAFILE the token generation rate is not reduced (but prompt processing is much slower; llamafile_sgemm() gives it a huge performance boost):

When I use the trick described in #11744 (with GGML_USE_LLAMAFILE enabled), it's possible to keep both the prompt processing and token generation rates fast:
Regarding the exact cause, it's a large number of remote NUMA node memory accesses during token generation if the model weights were loaded with llamafile_sgemm() calls. Measured with numatop during "slow" generation:

while during "fast" generation we have:
It seems that llamafile_sgemm() places the model weights in disk cache memory in such a way that a large number of remote NUMA node memory accesses is needed when using the weights during token generation.

Part 4 - The Solution
The simplest solution for this problem would be to warm up the model with token generation instead of prompt processing, so that llamafile_sgemm() calls are not used to load the model weights. I tested it by commenting out the EOS token in the creation of the warm-up batch (so that there's only a single token in this batch) and it seems to work. I verified it by running llama-cli and then llama-bench to measure the token generation rate. With a single token in the warm-up batch I have:

When there are two tokens (BOS and EOS) I have:
The disadvantage of this simple workaround is that disabling the warm-up with a command line option would still cause a reduction in token generation performance.
A proper fix for this problem would be a NUMA-aware matrix multiplication implementation which:
Another possible solution is to implement Megatron-LM-style tensor parallelism. In this case each NUMA node would use only its associated part of the model weights and would keep them in local memory; see the sketch below.
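To illustrate the tensor-parallel idea (this is a toy sketch, not an implementation proposal for llama.cpp): split the FFN up-projection by output columns and the down-projection by input rows, so each worker only ever reads its own shard of the weights, and sum the small partial output vectors at the end (the "all-reduce"). Names, shapes, and the two plain threads standing in for NUMA nodes are all invented:

```cpp
// Megatron-LM-style split of a 2-layer FFN across two workers (stand-ins for NUMA nodes).
// Only the d_model-sized partial outputs would cross the interconnect. Build with -pthread.
#include <cstdio>
#include <thread>
#include <vector>

int main() {
    const size_t d_model = 1024, d_ff = 4096, n_parts = 2;
    const size_t d_ff_part = d_ff / n_parts;

    // each "node" owns one shard of the weights (filled with constants here)
    std::vector<std::vector<float>> W_up  (n_parts, std::vector<float>(d_ff_part * d_model, 0.001f));
    std::vector<std::vector<float>> W_down(n_parts, std::vector<float>(d_model * d_ff_part, 0.001f));

    std::vector<float> x(d_model, 1.0f);
    std::vector<std::vector<float>> partial(n_parts, std::vector<float>(d_model, 0.0f));

    std::vector<std::thread> workers;
    for (size_t p = 0; p < n_parts; ++p) {
        workers.emplace_back([&, p] {
            // up-projection: this shard computes only its d_ff_part hidden activations
            std::vector<float> h(d_ff_part, 0.0f);
            for (size_t r = 0; r < d_ff_part; ++r)
                for (size_t c = 0; c < d_model; ++c)
                    h[r] += W_up[p][r * d_model + c] * x[c];

            // down-projection: this shard's hidden rows produce a partial output vector
            for (size_t r = 0; r < d_model; ++r)
                for (size_t c = 0; c < d_ff_part; ++c)
                    partial[p][r] += W_down[p][r * d_ff_part + c] * h[c];
        });
    }
    for (auto & w : workers) w.join();

    // "all-reduce": sum the partial outputs from both shards
    std::vector<float> y(d_model, 0.0f);
    for (size_t p = 0; p < n_parts; ++p)
        for (size_t r = 0; r < d_model; ++r)
            y[r] += partial[p][r];

    printf("y[0] = %f\n", y[0]);
}
```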