Description
+34% higher throughput?
TLDR: Seeing vLLM has been really fascinating! @oleitersdorf and I investigated whether we could further accelerate vLLM by profiling its performance with GPU counters. So far, we believe we have achieved a speed-up of 1.34x on the benchmark reported on the vLLM website. As the vLLM site claims "24x higher throughput compared to HF and up to 3.5x higher throughput than TGI", and the techniques we show below add a further 1.34x, vLLM has the potential to reach roughly 32x higher throughput compared to the baseline HF and 4.7x over TGI.
Many thanks to the authors for developing this really exciting work -- we had a great time reading your code! We are sure that you probably already thought of the improvements we show below (and maybe just didn't get to them), and would love to hear your thoughts.
Below we write out the optimizations we found and list several open directions which could hopefully speed things up even further. The goal of this issue is to encourage discussion and brainstorm potential improvements -- some parts are still a POC and need more work to reach production quality. For the part which is already production-ready, we opened this PR.
This issue has 3 sections:
- An optimization in the main attention kernel (single_query_cached_kv_attention)
- An optimization in the Python code serving the models
- Further open directions + ideas which did not work out (yet)
Benchmark
We test on the benchmark of completing 1000 randomly sampled prompts from ShareGPT with LLaMA13B. For each sequence, we create just one completion (matching the benchmark on the project website). To run the benchmark, clone vLLM, download the dataset from the project website, and run the following command.
python benchmarks/benchmark_throughput.py --backend vllm --dataset ./ShareGPT_V3_unfiltered_cleaned_split.json --model openlm-research/open_llama_13b --tokenizer hf-internal-testing/llama-tokenizer --num-prompts=1000
We begin by running the above on a clean clone of vLLM on an A100 (80GB) and get the following output.
Processed prompts: 100%|███████████████████████████████████████████████████████████████████████████| 1000/1000 [04:08<00:00, 4.02it/s]
Throughput: 4.02 requests/s, 1921.31 tokens/s
This rate of 4.02 sequences completed per second translates to 241.2 seq/min. On the project website, a throughput of 154.2 seq/min is reported for the same model, albeit on an A100 (40GB). For this issue, we use an A100 (80GB), so we set the reference point at 4.02 seq/sec. By the end of this issue, we reach 5.41 seq/sec, an improvement of 1.34x.
Analyzing single_query_cached_kv_attention
The main kernel in vLLM is single_query_cached_kv_attention, which computes the forward pass of an attention layer using the KV cache designed in vLLM. We begin by profiling this kernel with NVIDIA Nsight Compute to check for potential improvements.
A preliminary look through NVIDIA Nsight Compute reveals several points to tackle. As seen above, the kernel underutilizes the SM resources in terms of both compute and memory -- it uses only roughly 15% of the compute and 50% of the memory bandwidth.
As seen here, each SM has no warps ready to schedule 5 out of 6 times. Thus, we begin by trying to identify why the warps are stalling.
The kernel works roughly as follows (a rough sketch of the computation appears after this list). Each block is responsible for computing the entire attention mechanism for the last token of one specific sequence and one specific head of that token. Each block has 128 threads by default (4 warps).
- Each thread begins by reading its appropriate query values into its registers/local memory (Q_vec q_vecs[NUM_VECS_PER_THREAD];). The threads of the block are split into 'groups' such that each group loads the entire query. In our configuration (default configuration, running LLaMA13B on an A100 80GB), each thread group has 2 threads. That means every 2 threads in a warp read the entire query head into their own registers/local memory -- i.e., each thread holds half of the query head.
- The code then iterates over the entire sequence. Each thread group fetches a single key from the KV cache and computes the dot product between the query head it loaded and the corresponding key head. This continues until the dot product between the query head and the corresponding key head of every key in the sequence has been computed. The loop essentially loads keys from global memory into registers/local memory, computes the dot products, and aggregates within each thread group. Throughout, the values for the softmax are also computed; the logits are stored in shared memory.
- After the loop is done, each warp aggregates the results from all the thread groups inside the warp. Following this, aggregation happens between the warps in the block.
- Now, the value vectors are fetched from global memory and are summed according to the computed logits. Each thread stores in its registers/local memory an accumulator for the sums it computes.
- Finally, aggregation of the summed value-dot-logits values happens within each warp and then within the entire block.
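To make the steps above concrete, the following is a rough NumPy sketch of the result each thread block produces for one (sequence, head) pair. It illustrates only the math, not the thread-group/warp layout or the block-wise KV-cache addressing described above.

import numpy as np

def single_query_attention_reference(q, keys, values):
    """Reference for what one thread block computes: attention for the
    last token of one sequence and one attention head.

    q:      [head_size]           query head of the last token
    keys:   [seq_len, head_size]  key heads gathered from the KV cache
    values: [seq_len, head_size]  value heads gathered from the KV cache
    """
    # Dot product of the query head with every key head in the sequence
    # (the real kernel also applies the usual 1/sqrt(head_size) scaling).
    logits = keys @ q                      # [seq_len]
    # Softmax over the sequence; in the kernel this is computed incrementally
    # and the logits are kept in shared memory.
    logits = logits - logits.max()
    probs = np.exp(logits) / np.exp(logits).sum()
    # Sum the value heads weighted by the softmax probabilities.
    return probs @ values                  # [head_size]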
To find which stage is holding the warps back, we look at the assembly analysis in Nsight Compute. The warps spend a lot of time waiting on the instructions shown in this screenshot.
As we can see, there is a global load, and stalls occur there roughly 4% of the time (a value is loaded from global memory into register R78, and warps halt at the highlighted instruction because they must wait for the load into R78 to finish before executing it). Notably, further below, this code repeats 14 times in total (due to loop unrolling), and these repetitions cause most of the stalls in this kernel.
These commands are part of the first two steps above, where the threads load the query head and the key heads (note: compiling so that Nsight Compute shows which source lines the warps are stuck on could help, but it can also significantly change the emitted assembly; therefore, we work directly with the assembly instead -- if someone has a better solution, we would love to hear 🤩). Specifically, these commands load the key heads. As not much can be done about loading the key heads, we focus on the query heads, which are also loaded from global memory.
The query heads are read multiple times from global memory -- specifically, in our case (default configs, LLaMA13B, A100 80GB), every byte of the query is read by 64 different threads. Therefore, we begin by optimizing this so that each byte of the query head is read by exactly one thread and then stored in shared memory for the other threads in the block to access.
We replace this code:
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
Q_vec q_vecs[NUM_VECS_PER_THREAD];
#pragma unroll
for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
  const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
  q_vecs[i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
}
with this code (see this PR):
constexpr int NUM_THREAD_GROUPS_LOWER_BOUND = NUM_THREADS / THREAD_GROUP_SIZE;
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
if (thread_group_idx <= NUM_THREAD_GROUPS_LOWER_BOUND) {
#pragma unroll
  for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS_LOWER_BOUND) {
    const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
    q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
  }
}
__syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs
Running the benchmark gives:
Processed prompts: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [03:45<00:00, 4.43it/s]
Throughput: 4.43 requests/s, 2118.42 tokens/s
As the reference point above is 4.02 seq/sec, the result of 4.43 seq/sec that we get here is a 1.10x improvement.
At this point, we rerun the Nsight Compute analysis above.
As we can see, the kernel now reaches rather high memory-bandwidth utilization (86%). We tried several other improvements (see this section) to squeeze a bit more performance out of this kernel, yet they did not improve the overall runtime of the benchmark. Since the memory-bandwidth utilization is already high and the kernel appears to load the minimum amount of data it needs from global memory (it has to load the keys and values...), we decided to stop looking at the kernel itself and look elsewhere.
Overall Program Analysis
Observe the following report generated by using NVIDIA Nsight Systems to profile the entire program execution.
As we can see, roughly half the time the program does not use the GPU at all (observe that DRAM Bandwidth, SM Warp Occupancy, etc., are practically zero half the time). This time is spent on the CPU, running the Python code which surrounds the model. We investigate and find that the culprit is the sampling of the generated tokens. Observe the forward code of the class LlamaForCausalLM.
class LlamaForCausalLM(nn.Module):
    ...

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> Dict[int, SequenceOutputs]:
        hidden_states = self.model(input_ids, positions, kv_caches, input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head.weight, hidden_states, input_metadata)
        return next_tokens
It turns out that half the program time is spent in the above call to self.model and half in the call to self.sampler (note: this cannot be seen by timing the Python code alone, as the kernels run on the GPU asynchronously and the CPU only waits for them later on).
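As an aside, here is a minimal sketch (not vLLM code; timed is a hypothetical helper) of how the time can be attributed correctly despite the asynchronous launches -- the synchronize calls below are exactly what naive Python timing is missing.

import time
import torch

def timed(fn, *args, **kwargs):
    """Time a GPU-heavy call by waiting for all queued kernels to finish."""
    torch.cuda.synchronize()    # drain kernels launched before this call
    start = time.perf_counter()
    out = fn(*args, **kwargs)
    torch.cuda.synchronize()    # wait for the kernels this call launched
    return out, time.perf_counter() - start

# Hypothetical usage inside the forward pass above:
#   hidden_states, t_model = timed(self.model, input_ids, positions,
#                                  kv_caches, input_metadata, cache_events)
#   next_tokens, t_sampler = timed(self.sampler, self.lm_head.weight,
#                                  hidden_states, input_metadata)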
Specifically, the sampler performs the following for each sequence being completed (link).
# Sample 1 token for each sequence in the group.
next_token_ids = torch.multinomial(probs, num_samples=1, replacement=True)
That is, for each sequence, probs holds the generated probabilities for the next token. The code above focuses on a specific sequence and samples just for that sequence. We replace this with feeding the entire matrix (num_sequences x token_space) into torch.multinomial, performing the sampling for all sequences at once. We implemented a POC-level change which does this for the current benchmark (sampling just one token per sequence, no beam search or any other technique).
As the code change is rather long, we do not write it out here -- please refer to the following commit to see the change (note that the code is currently meant as a POC and not production-grade).
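To illustrate just the core idea, here is a simplified sketch (not the code from the commit) of sampling for the whole batch at once:

import torch

def sample_next_tokens(probs: torch.Tensor) -> torch.Tensor:
    """Sample one next token for every sequence with a single batched call.

    probs: [num_sequences, vocab_size] next-token probabilities.
    Returns a [num_sequences] tensor of sampled token ids.
    """
    # A single torch.multinomial over the whole batch replaces the per-sequence
    # calls above (covers only the 1-sample, no-beam-search benchmark setting).
    return torch.multinomial(probs, num_samples=1, replacement=True).squeeze(-1)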
Rerunning the benchmark gives the following.
Processed prompts: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [03:12<00:00, 5.18it/s]
Throughput: 5.18 requests/s, 2478.31 tokens/s
As the reference point above is 4.02 seq/sec, the result of 5.18 seq/sec that we get here is a 1.28x improvement so far.
We rerun nsight systems and observe the following.
Indeed, the time between GPU calls has shrunk drastically. We zoom in to see what remains.
It seems that there are many small 4 byte reads from the GPU to the CPU. The culprit is the following line, where the logprobs of the chosen tokens are read from the GPU to the CPU one-by-one.
output_logprobs[next_token_id] = logprob[j, next_token_id].item()
These many small reads have a large overhead and incur high synchronization costs. Fixing this by coalescing the reads requires some maneuvering in the code (it turns out that there is another small read in another place). See this commit for a POC.
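The core idea, as a minimal sketch (the linked commit is the actual POC; the names here are illustrative): gather the chosen logprobs on the GPU and copy them to the CPU in a single transfer instead of one .item() call per token.

import torch

def gather_chosen_logprobs(logprob: torch.Tensor, next_token_ids: torch.Tensor):
    """Fetch the logprobs of all chosen tokens in one GPU-to-CPU copy.

    logprob:        [num_sequences, vocab_size] log-probabilities (on the GPU)
    next_token_ids: [num_sequences] chosen token ids (on the GPU)
    """
    rows = torch.arange(logprob.shape[0], device=logprob.device)
    chosen = logprob[rows, next_token_ids]   # still on the GPU, no sync here
    return chosen.tolist()                   # a single device-to-host transfer

Rerunning the benchmark gives the following.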
Processed prompts: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [03:04<00:00, 5.41it/s]
Throughput: 5.41 requests/s, 2587.22 tokens/s
As the reference point above is 4.02 seq/sec, the result of 5.41 seq/sec that we get here is a 1.34x improvement.
Further potential directions & Ideas which did not pan out
Potential idea: Cache Utilization
As can be seen above, the main kernel of the program has very low L1/L2 cache utilization. Taking another look at the kernel, it seems that 4x more data is read from global memory than is needed -- each read fetching a key segment from the KV cache turns into a 64-byte read, yet the thread only uses 16 of those bytes. As the keys in the KV cache have shape [num_blocks, num_heads, head_size/x, block_size, x], these reads are very spaced apart. Similar issues happen with value fetching.
Could it be that reshaping this cache such that each thread uses all 64 bytes that it reads at once would save 4x on memory bandwidth?
We are not 100% certain that only 16 bytes of each 64-byte read are used, as the assembly seems to indicate that all 64 bytes are used, while the source code implies only 16. Further investigation is needed here (and would include rewriting some of the other kernels/Python code to reshape the cache). Therefore, we would appreciate the authors' input before we try implementing this change (we assume there is a reason this shape was originally chosen) -- i.e., are we missing something :)?
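To make the suspicion concrete, here is a small sketch of the byte offsets a thread group touches when gathering one key head for a single token under the current layout. The constants are our assumptions for the default config (fp16, head_size 128, block_size 16, so x = 16 bytes / 2 = 8 elements), not measured values.

# Key cache layout: [num_blocks, num_heads, head_size // x, block_size, x].
HEAD_SIZE, BLOCK_SIZE, X, ELEM_BYTES = 128, 16, 8, 2   # assumed defaults (fp16)

token = 0  # position of the key inside its cache block
# For a fixed (block, head, token), the x-element chunks of one key head are
# BLOCK_SIZE * X elements apart, so each 16-byte load lands in a different
# 64-byte DRAM sector.
byte_offsets = [(i * BLOCK_SIZE * X + token * X) * ELEM_BYTES
                for i in range(HEAD_SIZE // X)]
print(byte_offsets[:4])   # [0, 256, 512, 768] -> a 256-byte stride
print(X * ELEM_BYTES, "of every 64 fetched bytes are actually used")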
Overall, it seems like it potentially would be worthwhile to investigate the memory loads of this kernel. Observe the following two comments from NVIDIA Nsight Compute:
DRAM Excessive Read Sectors
Est. Speedup: 85.97%
The DRAM fetch granularity for read misses in L2 is 64 bytes, i.e. the lower or upper half of an L2 cache line. Try changing your access pattern to make use of both sectors returned by a DRAM read request for optimal usage of the DRAM throughput. For strided memory reads, avoid strides of 64 bytes or larger to avoid moving unused sectors from DRAM to L2.
Shared Load Bank Conflicts
Est. Speedup: 39.90%
The memory access pattern for shared loads might not be optimal and causes on average a 2.9-way bank conflict across all 1500640 shared load requests. This results in 1779535 bank conflicts, which represent 40.50% of the overall 4393570 wavefronts for shared loads.
Potential idea: Parallel Tokenization
There is potential for a further improvement of roughly 10% by parallelizing tokenization after sampling. Specifically, this line gets called sequentially for every sequence to convert each sampled token into text. This takes roughly 10% of the execution time -- time during which the GPU sits completely idle.
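As a rough sketch of one possible direction (assuming a Hugging Face tokenizer; vLLM's real detokenization is incremental and needs more care than this), the sequential per-sequence decode calls could at least be replaced with a single batched call, or moved off the critical path into worker threads:

from transformers import AutoTokenizer

# Assumption: the tokenizer used in the benchmark above.
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

def detokenize_step(new_token_ids):
    """Decode the token sampled for every sequence in one batched call,
    instead of one sequential decode call per sequence."""
    return tokenizer.batch_decode([[tid] for tid in new_token_ids])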
Failed idea: Batch reading from the block tables
At the start of the for-loop fetching the keys, the physical block number is read from the block table (global memory). This line stalls many threads. It turns out that all the threads in a warp read the same position in the block table (which is fine since, IIRC, only one read is sent to memory and its result is broadcast to the threads automatically). To try to reduce the stall, we can have each thread read a different value, so that we only go out to global memory once every 32 for-loop iterations.
We implemented this, yet it had no effect on the runtime. We believe this is because the kernel is memory-bound anyhow, and the stall simply moved to the key reads (which come right after).
Failed idea: atomicAdd for the last aggregation
We tried replacing the final aggregation in the kernel with atomicAdd between the 4 warps in the block. This degraded the results we observed.
Final Thoughts
vLLM is truly a thought-provoking and intriguing concept! We very much enjoyed delving into this code and are very eager to see how far this can be optimized!! Who knows how much faster it can go :)
Looking forward to hearing your thoughts!