Description
Recently I got interested in how to run Llama3 inference with large context lengths like 128K. For context, Llama2 had a max sequence length of 4096. One solution that always works is to go distributed with techniques like Ring Attention, where you split the sequence over multiple devices, but I'm instead interested in how to run large context windows on a single GPU.
For larger sequence lengths the primary VRAM bottleneck is not the model parameters but the size of the KV cache, which has an analytical formula of: 2 * layers * attention_heads * head_dim * bytes_per_element * batch_size * sequence_length, while the model parameters have a simpler formula: number_of_params * bytes_per_element
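As a quick illustration, here is a minimal Python sketch of those two formulas. The helper names and the example config plugged in at the end (a Llama3-8B-like model at fp16 and a 128K context) are my assumptions for the example, not values taken from GPT-fast.

```python
# Minimal sketch of the two formulas above. Helper names and the example config
# are assumptions for illustration, not taken from GPT-fast.

GB = 1e9

def kv_cache_bytes(n_layers, n_heads, head_dim, bytes_per_element, batch_size, seq_len):
    # 2x accounts for storing both K and V
    return 2 * n_layers * n_heads * head_dim * bytes_per_element * batch_size * seq_len

def model_bytes(n_params, bytes_per_element):
    return n_params * bytes_per_element

# Example: a Llama3-8B-like config (32 layers, 32 attention heads, head_dim 128,
# 8B params) at fp16 (2 bytes/element) and a 128K context. Note: with GQA the
# heads factor drops (Llama3 8B stores only 8 KV heads), shrinking the cache 4x.
kv = kv_cache_bytes(n_layers=32, n_heads=32, head_dim=128,
                    bytes_per_element=2, batch_size=1, seq_len=128 * 1024)
weights = model_bytes(n_params=8e9, bytes_per_element=2)
print(f"KV cache: {kv / GB:.1f} GB, weights: {weights / GB:.1f} GB, "
      f"total: {(kv + weights) / GB:.1f} GB")
```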
So what I plotted below is the total of model params + KV cache size as I increased the sequence length.

A few things jump out:
- We can run a context length of 128K with an out-of-the-box GPT-fast implementation at fp16, but it needs an 80GB+ GPU
- Int8 KV cache quantization is likely the most important problem we can prioritize because it should allow us to support Llama3 8B inference even on a consumer GPU like a 3090 or 4090 with 24GB of VRAM (see the sketch after this list)
- Keep in mind how the curve starts bending at a context length of around 16K; below that, the KV cache size is not a huge worry
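Since int8 KV cache quantization keeps coming up, here is a hedged sketch of the basic idea: store K/V as int8 with a per-(token, head) fp16 scale and dequantize on read. This is illustrative only, not GPT-fast's or torchao's actual implementation, and all function names are made up.

```python
# Hedged sketch of int8 KV cache quantization: keep K/V in int8 plus a small
# fp16 scale per (token, head), and dequantize when the cache is read.
import torch

def quantize_int8(x: torch.Tensor, dim: int = -1):
    # Symmetric quantization along `dim` (here: the head_dim axis),
    # so each (batch, head, token) slice gets its own scale.
    scale = x.abs().amax(dim=dim, keepdim=True).clamp(min=1e-6) / 127.0
    q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return q, scale.to(torch.float16)

def dequantize_int8(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return q.to(torch.float16) * scale

# k has shape (batch, n_heads, seq_len, head_dim); quantize over head_dim.
k = torch.randn(1, 32, 4096, 128, dtype=torch.float16)
k_q, k_scale = quantize_int8(k)
k_deq = dequantize_int8(k_q, k_scale)
print("max abs error:", (k - k_deq).abs().max().item())
```

Storage drops from 2 bytes/element to 1 byte/element plus a small scale overhead; cleverer schemes like KVQuant are largely about making the 4-bit and 2-bit versions of this viable.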
The second plot is the exact same thing for Llama3 70B.

At this size:
- fp16 is out of the question
- int8 is feasible on an 80GB GPU at smaller sequence lengths
- int4 is feasible on a 40GB GPU at smaller sequence lengths
- int2 is feasible on a 24GB GPU at smaller sequence lengths
- Running 128K sequence length is only feasible with int2 and an 80GB GPU
- With int8 and below we can do single-node inference, since a single node has up to 8x 40GB GPUs -> 320GB > ~250GB, so again int8 is a nice sweet spot for KV cache quantization (a quick sanity check of these numbers follows below)
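To sanity-check these bullets, here is the same arithmetic swept over bit-widths and two sequence lengths. The Llama-70B-like config (80 layers, 64 attention heads, head_dim 128, 70B params) and the assumption that weights and KV cache share one bit-width are mine; the actual plots may have used different settings (e.g. a GQA-reduced KV cache).

```python
# Hedged sanity check of the 70B bullets above, reusing the same formulas.
# Config values and the shared bit-width for weights + KV cache are assumptions.
GB = 1e9

def total_gb(bits, seq_len, n_layers=80, n_heads=64, head_dim=128,
             n_params=70e9, batch=1):
    bytes_per_elem = bits / 8
    kv = 2 * n_layers * n_heads * head_dim * bytes_per_elem * batch * seq_len
    weights = n_params * bytes_per_elem
    return (kv + weights) / GB

for bits in (16, 8, 4, 2):
    for seq_len in (8 * 1024, 128 * 1024):
        print(f"{bits:>2}-bit, seq {seq_len:>6}: {total_gb(bits, seq_len):7.1f} GB")
```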
A few important caveats:
- These are analytical results, so we will need to run real experiments to see whether this holds up. For now the results above ignore accidentally large intermediate tensors
- NVIDIA is unlikely to increase VRAM on consumer GPUs. NVIDIA is also disabling peer-to-peer access on their consumer GPUs, so it's unlikely consumer GPUs will be a thriving market for distributed inference
- Int4/Int2 quantization will likely come with massive perplexity loss unless the quantization is done in a clever way, similar to KVQuant, or leverages QAT; int8 is unlikely to require as much cleverness
- There are other techniques to reduce KV cache sizes, such as MQA or cross-layer KV cache sharing (a quick back-of-the-envelope follows below), but we can't (AFAIK) flexibly change attention patterns from training to inference and expect things to work, so some co-design with the training teams will be beneficial long term
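For intuition on how much that last point could help: sharing K/V across query heads (MQA/GQA) divides the heads factor in the KV cache formula by n_heads / n_kv_heads. The config values below are illustrative assumptions for an 8B-like model at fp16 and a 128K context.

```python
# Illustrative only: KV cache size at a 128K context for an 8B-like config
# (32 layers, head_dim 128, fp16) as the number of KV heads shrinks.
def kv_gb(n_kv_heads, n_layers=32, head_dim=128, bytes_per_element=2,
          seq_len=128 * 1024):
    return 2 * n_layers * n_kv_heads * head_dim * bytes_per_element * seq_len / 1e9

print(f"MHA, 32 KV heads: {kv_gb(32):5.1f} GB")  # every query head keeps its own K/V
print(f"GQA,  8 KV heads: {kv_gb(8):5.1f} GB")   # 4x smaller (what Llama3 uses)
print(f"MQA,  1 KV head:  {kv_gb(1):5.1f} GB")   # 32x smaller
```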