
Commit d8d6659

feat: Add per server LRU capacity
Signed-off-by: Kfir Toledo <kfir.toledo@ibm.com>
1 parent 616c670 commit d8d6659

6 files changed: +113 −75 lines

cmd/epp/main.go

Lines changed: 1 addition & 2 deletions

@@ -122,9 +122,8 @@ func loadPrefixCacheConfig() prefix.Config {
 
 	return prefix.Config{
 		HashBlockSize:          envutil.GetEnvInt("PREFIX_CACHE_HASH_BLOCK_SIZE", prefix.DefaultHashBlockSize, baseLogger),
-		MaxPodsPerPrefix:       envutil.GetEnvInt("PREFIX_MAX_PODS_PER_PREFIX", prefix.DefaultMaxPodsPerPrefix, baseLogger),
 		MaxPrefixBlocksToMatch: envutil.GetEnvInt("PREFIX_CACHE_MAX_PREFIX_BLOCKS", prefix.DefaultMaxPrefixBlocks, baseLogger),
-		LRUIndexerCapacity:     envutil.GetEnvInt("PREFIX_CACHE_LRU_CAPACITY", prefix.DefaultLRUIndexerCapacity, baseLogger),
+		LRUCapacityPerServer:   envutil.GetEnvInt("PREFIX_CACHE_LRU_CAPACITY_PER_SERVER", prefix.DefaultLRUCapacityPerServer, baseLogger),
 	}
 }
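
The only change here is the renamed field and environment variable. For illustration, a minimal sketch of building the prefix scorer with the same configuration directly (the standalone `main` and the hard-coded capacity are illustrative assumptions; in the EPP binary these values come from the `PREFIX_CACHE_*` environment variables shown above):

```go
package main

import (
	"fmt"

	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/prefix"
)

func main() {
	// Illustrative only: construct the prefix scorer with the package defaults,
	// setting the new per-server LRU capacity the same way
	// PREFIX_CACHE_LRU_CAPACITY_PER_SERVER would.
	cfg := prefix.Config{
		HashBlockSize:          prefix.DefaultHashBlockSize,
		MaxPrefixBlocksToMatch: prefix.DefaultMaxPrefixBlocks,
		LRUCapacityPerServer:   31250, // e.g. the per-server estimate from the guide further below
	}
	plugin := prefix.New(cfg)
	fmt.Println("prefix plugin:", plugin.Name(), "capacity per server:", cfg.LRUCapacityPerServer)
}
```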

pkg/epp/scheduling/framework/plugins/multi/prefix/indexer.go

Lines changed: 89 additions & 42 deletions

@@ -18,6 +18,7 @@ package prefix
 
 import (
 	"context"
+	"fmt"
 	"sync"
 	"time"
 
@@ -27,32 +28,23 @@ import (
 	logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
 )
 
-// podSet holds an LRU cache of servers that may have a specific prefix hash.
-type podSet struct {
-	enteries *lru.Cache[ServerID, struct{}] // Can be extended with metadata (e.g., timestamp).
-}
-
 // An indexer maintains an LRU cache of prompt prefix hashes and the server(s) that might have that
-// prefix cached .
+// prefix cached.
 type indexer struct {
-	mu                sync.RWMutex
-	cache             *lru.Cache[BlockHash, *podSet]
-	maxCacheSize      int
-	maxServersToMatch int
+	mu         sync.RWMutex
+	hashToPods map[BlockHash]podSet                       // the lookup data structure to find pods that have the BlockHash cached
+	podToLRU   map[string]*lru.Cache[BlockHash, struct{}] // key is pod namespacedName, value is an LRU cache
+	maxLRUSize int
 }
 
 // newIndexer initializes an indexer with size limits and starts cache size reporting.
-func newIndexer(maxCacheSize, maxServersToMatch int) *indexer {
-	c, err := lru.New[BlockHash, *podSet](maxCacheSize)
-	if err != nil {
-		panic(err)
-	}
+func newIndexer(maxLRUSize int) *indexer {
 	ix := &indexer{
-		cache:             c,
-		maxCacheSize:      maxCacheSize,
-		maxServersToMatch: maxServersToMatch,
+		hashToPods: make(map[BlockHash]podSet),
+		podToLRU:   make(map[string]*lru.Cache[BlockHash, struct{}]),
+		maxLRUSize: maxLRUSize,
 	}
-	go ix.ReportCacheSize(time.Second)
+	go ix.ReportLRUSize(time.Second)
 	return ix
 }
 
@@ -61,51 +53,106 @@ func (i *indexer) Add(hashes []BlockHash, pod ServerID) {
 	if pod.Name == "" {
 		return
 	}
-
 	i.mu.Lock()
-	defer i.mu.Unlock()
+	// Check if an LRU cache already exists for this pod.
+	podName := pod.String()
+	lruForPod, exists := i.podToLRU[podName]
+	if !exists {
+		newLRU, _ := lru.NewWithEvict[BlockHash, struct{}](i.maxLRUSize, i.makeEvictionFn(pod))
+		i.podToLRU[podName] = newLRU
+		lruForPod = newLRU
+	}
+	i.mu.Unlock()
 
+	// Add to LRU (may evict)
 	for _, hash := range hashes {
-		p, ok := i.cache.Get(hash)
-		if !ok {
-			// Create podSet with new LRU
-			podLRU, _ := lru.New[ServerID, struct{}](i.maxServersToMatch)
-			p = &podSet{enteries: podLRU}
-			i.cache.Add(hash, p)
-		}
+		lruForPod.Add(hash, struct{}{})
+	}
 
-		p.enteries.Add(pod, struct{}{})
+	// Update hashToPods once under lock
+	i.mu.Lock()
+	for _, hash := range hashes {
+		pods := i.hashToPods[hash]
+		if pods == nil {
+			pods = make(podSet)
+		}
+		pods[pod] = struct{}{}
+		i.hashToPods[hash] = pods
 	}
+	i.mu.Unlock()
+
 }
 
 // Get returns a set of servers that have the given prefix hash cached.
-func (i *indexer) Get(hash BlockHash) map[ServerID]bool {
+func (i *indexer) Get(hash BlockHash) podSet {
 	i.mu.RLock()
 	defer i.mu.RUnlock()
 
-	res := map[ServerID]bool{}
-	pods, ok := i.cache.Get(hash)
+	res := podSet{}
+	pods, ok := i.hashToPods[hash]
 	if !ok {
 		return res
 	}
-	for _, pod := range pods.enteries.Keys() {
-		res[pod] = true
+
+	return pods
+}
+
+// makeEvictionFn returns a per-pod LRU eviction callback that removes the pod from hashToPods on eviction.
+func (i *indexer) makeEvictionFn(pod ServerID) func(BlockHash, struct{}) {
+	return func(hash BlockHash, _ struct{}) {
+		fmt.Printf("Evicted hash %v from pod %s\n", hash, pod)
+
+		i.mu.Lock()
+		defer i.mu.Unlock()
+		print("enter eviction")
+		// Remove the pod from the hash→pods map
+		if podSet, ok := i.hashToPods[hash]; ok {
+			delete(podSet, pod)
+			if len(podSet) == 0 {
+				delete(i.hashToPods, hash)
+			} else {
+				i.hashToPods[hash] = podSet
+			}
+		}
+		print("After eviction")
 	}
-	return res
 }
 
-// ReportCacheSize starts a goroutine that periodically reports the cache size metric.
-func (i *indexer) ReportCacheSize(interval time.Duration) {
+// ReportLRUSize starts a goroutine that periodically reports the LRU cache size metric.
+func (i *indexer) ReportLRUSize(interval time.Duration) {
 	ticker := time.NewTicker(interval)
 	defer ticker.Stop()
 	for range ticker.C {
 		i.mu.RLock()
-		size := i.cache.Len()
-		metrics.RecordPrefixCacheSize(int64(size))
-		log.FromContext(context.TODO()).V(logutil.TRACE).Info("LRU",
-			"# entries", size,
-			"prefix cache utilization [%]", float64(size)*100/float64(i.maxCacheSize),
+		totalEntries := 0
+		maxPodEntries := 0
+		maxPodName := ""
+
+		for podName, lruCache := range i.podToLRU {
+			size := lruCache.Len()
+			totalEntries += size
+			if size > maxPodEntries {
+				maxPodEntries = size
+				maxPodName = podName
+			}
+		}
+
+		numPods := len(i.podToLRU)
+		avg := 0.0
+		if numPods > 0 {
+			avg = float64(totalEntries) / float64(numPods)
+		}
+
+		metrics.RecordPrefixCacheSize(int64(totalEntries))
+		log.FromContext(context.TODO()).V(logutil.TRACE).Info("Prefix cache state",
+			"total entries", totalEntries,
+			"# pods", numPods,
+			"avg entries per pod", avg,
+			"pod with max cache", maxPodName,
+			"max pod size", maxPodEntries,
+			"global max LRU cache capacity per pod", i.maxLRUSize,
 		)
+
 		i.mu.RUnlock()
 	}
 }
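
In short, the indexer now keeps one bounded LRU of block hashes per pod plus a reverse map from hash to pods, and the per-pod eviction callback keeps the two structures in sync. A minimal, single-goroutine sketch of that pattern (locking and the real BlockHash/ServerID types are elided; the toy names and int/string keys are stand-ins, and it assumes the same hashicorp/golang-lru/v2 package the indexer imports):

```go
package main

import (
	"fmt"

	lru "github.com/hashicorp/golang-lru/v2"
)

// toyIndexer mirrors the new layout: one capped LRU of hashes per pod,
// plus a reverse map from hash to the pods that still hold it.
type toyIndexer struct {
	hashToPods map[int]map[string]struct{}
	podToLRU   map[string]*lru.Cache[int, struct{}]
	capacity   int
}

func newToyIndexer(capacity int) *toyIndexer {
	return &toyIndexer{
		hashToPods: make(map[int]map[string]struct{}),
		podToLRU:   make(map[string]*lru.Cache[int, struct{}]),
		capacity:   capacity,
	}
}

func (t *toyIndexer) add(pod string, hashes ...int) {
	c, ok := t.podToLRU[pod]
	if !ok {
		// The eviction callback keeps the reverse index consistent: when a hash
		// falls out of this pod's LRU, the pod is also dropped from hashToPods.
		c, _ = lru.NewWithEvict[int, struct{}](t.capacity, func(h int, _ struct{}) {
			delete(t.hashToPods[h], pod)
			if len(t.hashToPods[h]) == 0 {
				delete(t.hashToPods, h)
			}
		})
		t.podToLRU[pod] = c
	}
	for _, h := range hashes {
		c.Add(h, struct{}{}) // may trigger the eviction callback above
		if t.hashToPods[h] == nil {
			t.hashToPods[h] = make(map[string]struct{})
		}
		t.hashToPods[h][pod] = struct{}{}
	}
}

func main() {
	ix := newToyIndexer(2)
	ix.add("pod-a", 1, 2, 3)   // capacity 2: hash 1 is evicted from pod-a
	fmt.Println(ix.hashToPods) // map[2:map[pod-a:{}] 3:map[pod-a:{}]]
}
```

The real Add releases the indexer mutex before calling the per-pod LRU's Add for the same reason the sketch hints at: the eviction callback re-acquires that mutex, so holding it across Add would deadlock.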

pkg/epp/scheduling/framework/plugins/multi/prefix/indexer_test.go

Lines changed: 7 additions & 6 deletions

@@ -22,24 +22,25 @@ import (
 )
 
 func TestIndexer_AddAndGet(t *testing.T) {
-	i := newIndexer(2, 2)
+	i := newIndexer(2)
 
 	hash1 := BlockHash(1)
 	server := ServerID{Namespace: "default", Name: "server1"}
-
+	serverName := server.String()
 	// Add an entry to the cache
 	i.Add([]BlockHash{hash1}, server)
-
 	// Retrieve the entry
-	assert.Equal(t, 1, i.cache.Len(), "Cache size should be 1 after adding an entry")
+	assert.Equal(t, 1, i.podToLRU[serverName].Len(), "Cache size should be 1 after adding an entry")
 	servers := i.Get(hash1)
 	assert.Contains(t, servers, server, "Cache should contain the added server")
 
 	// Add another entry to the cache, the cache size should be incremented to 2.
 	i.Add([]BlockHash{BlockHash(2)}, server)
-	assert.Equal(t, 2, i.cache.Len(), "Cache size should be 2 after adding an entry")
+	assert.Equal(t, 2, i.podToLRU[serverName].Len(), "Cache size should be 2 after adding an entry")
 
 	// Add another entry to the cache, which should evict the first one due to max size.
+	print("before Add")
 	i.Add([]BlockHash{BlockHash(3)}, server)
-	assert.Equal(t, 2, i.cache.Len(), "Cache size should still be 2 after adding an entry")
+	print("after ADD")
+	assert.Equal(t, 2, i.podToLRU[serverName].Len(), "Cache size should still be 2 after adding an entry")
 }

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go

Lines changed: 11 additions & 15 deletions

@@ -32,10 +32,6 @@ import (
 
 const (
 	DefaultScorerWeight = 1
-	// DefaultMaxPodsPerPrefix defines the maximum number of pods (servers) to track per prefix hash in the LRU indexer.
-	// This limits the number of recent pods associated with a given prefix to reduce memory usage
-	// and ensure faster lookup. When the limit is reached, the least recently used pod is evicted.
-	DefaultMaxPodsPerPrefix = 4
 	// vLLM default token block size is 16, and a good guess of average characters per token is 4.
 	DefaultHashBlockSize = 64
 	// The maximum number of blocks to match. Two long requests with the same prefix up to this
@@ -44,16 +40,15 @@ const (
 	// accuracy. Use a small value if most requests are short to reduce cache size and speed up the
 	// matching process. Use a large value if most requests are long to increase the matching accuracy.
 	DefaultMaxPrefixBlocks = 256
-	// The indexer is an approximation to the actual prefix cache state on the model servers.
+	// The indexer is an approximation of the actual prefix LRU cache state on each model server (pod).
 	// A small capacity ensures a high accuracy of cache hit on the model server, but it will
 	// increase the chance of false negatives. A high capacity does the opposite.
 	// To properly size this, consider the sum of the total number of cache entries on all model
 	// servers. Consider the llama3 8B model on 8 H100 80GB GPUs. The size of the model weight is
 	// about 16GB. Assume 50% of the remaining HBM is used for caching prefixes, we have 32GB. Each
 	// token is about 128KB in size, so we can cache 250K tokens. Using the default block size of 16
-	// in vLLM, we will have 250K / 16 = 15.6K blocks. In total we have 15.6K * 8 = 124.8K blocks, or
-	// roughly 130K.
-	DefaultLRUIndexerCapacity = 130000
+	// in vLLM, we will have 250K / 16 = 15.6K blocks.
+	DefaultLRUCapacityPerServer = 15000
 )
 
 type Config struct {
@@ -63,19 +58,20 @@ type Config struct {
 	// MaxPrefixBlocksToMatch is the maximum number of prefix blocks to match. Input beyond this limit will
 	// be ignored.
 	MaxPrefixBlocksToMatch int
-	// MaxPodsPerPrefix defines the maximum number of pods (servers) to track per prefix hash in the LRU indexer.
-	MaxPodsPerPrefix int
-	// Max (approximate) size of the LRU indexer in number of entries.
-	LRUIndexerCapacity int
+	// Max (approximate) size of the LRU indexer in number of entries per server (pod).
+	LRUCapacityPerServer int
 }
 
 type Plugin struct {
 	Config
 	indexer Indexer
 }
 
+// podSet holds the set of pods (servers) that may have a specific prefix hash.
+type podSet map[ServerID]struct{}
+
 type Indexer interface {
-	Get(hash BlockHash) map[ServerID]bool
+	Get(hash BlockHash) podSet
 	Add(hashes []BlockHash, server ServerID)
 }
 
@@ -121,7 +117,7 @@ var _ framework.PostCycle = &Plugin{}
 func New(config Config) *Plugin {
 	m := &Plugin{
 		Config:  config,
-		indexer: newIndexer(config.LRUIndexerCapacity, config.MaxPodsPerPrefix),
+		indexer: newIndexer(config.LRUCapacityPerServer),
 	}
 	return m
 }
@@ -135,7 +131,7 @@ func (m *Plugin) Name() string {
 func (m *Plugin) Score(ctx context.Context, request *types.LLMRequest, cycleState *types.CycleState, pods []types.Pod) map[types.Pod]float64 {
 	loggerTrace := log.FromContext(ctx).V(logutil.TRACE)
 	// pre score step, hashing prompt and find longest prefix match.
-	hashes := hashPrompt(ctx, request, m.HashBlockSize, m.MaxPodsPerPrefix)
+	hashes := hashPrompt(ctx, request, m.HashBlockSize, m.MaxPrefixBlocksToMatch)
 	state := &schedulingContextState{
 		PrefixHashes:       hashes,
 		PrefixCacheServers: m.matchLongestPrefix(ctx, hashes),
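
Restating the sizing comment above with the capacity now expressed per server rather than summed across servers (same llama3 8B / H100 assumptions as the comment; the numbers are the comment's own):

```
32GB HBM reserved for prefix cache / 128KB per token ≈ 250K cached tokens per server
250K tokens / 16 tokens per block ≈ 15.6K blocks per server → DefaultLRUCapacityPerServer = 15000
(the old global default, DefaultLRUIndexerCapacity = 130000, was roughly 15.6K * 8 servers)
```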

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go

Lines changed: 2 additions & 4 deletions

@@ -35,8 +35,7 @@ func TestPrefixPlugin(t *testing.T) {
 	config := Config{
 		HashBlockSize:          4,
 		MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
-		LRUIndexerCapacity:     DefaultLRUIndexerCapacity,
-		MaxPodsPerPrefix:       DefaultMaxPodsPerPrefix,
+		LRUCapacityPerServer:   DefaultLRUCapacityPerServer,
 	}
 	plugin := New(config)
 
@@ -150,8 +149,7 @@ func BenchmarkPrefixPluginStress(b *testing.B) {
 	config := Config{
 		HashBlockSize:          blockSize,
 		MaxPrefixBlocksToMatch: maxPrefixBlocks,
-		LRUIndexerCapacity:     DefaultLRUIndexerCapacity,
-		MaxPodsPerPrefix:       DefaultMaxPodsPerPrefix,
+		LRUCapacityPerServer:   DefaultLRUCapacityPerServer,
 	}
 
 	plugin := New(config)

site-src/guides/epp-configuration/prefix-aware.md

Lines changed: 3 additions & 6 deletions

@@ -32,11 +32,9 @@ extremely long inputs.
   128 (or 128*64=8192 characters, or roughly 2048 tokens). This is useful to tradeoff prefix match accuracy
   for performance.
 
-* `PREFIX_CACHE_LRU_CAPACITY`: Maximum capacity the prefix LRU indexer in number of block hashes. Below
+* `PREFIX_CACHE_LRU_CAPACITY_PER_SERVER`: Maximum capacity of the prefix LRU cache, in number of block hashes, per server (pod). Below
   shows a detailed analysis on how to estimate this.
-* `PREFIX_MAX_PODS_PER_PREFIX`: Defines the maximum number of pods (servers) tracked per prefix hash in the internal LRU cache.
-This setting helps optimize memory usage by retaining only the hottest (most recently active) pods for each prefix.
-When the limit is reached, older pods are evicted based on least-recently-used (LRU) order.
+
 
 
 
@@ -68,7 +66,6 @@ When the limit is reached, older pods are evicted based on least-recently-used (LRU) order.
 # assume avg_chars_per_token = 4, prefix_indexer_hash_block_size = 64 (default)
 # each entry is about 358KB, so the memory footrpint is abut 11 MB per server
 lru_indexer_capacity_per_server = 500,000*4/64 = 31250
-lru_indexer_capacity_total = 3 * 31250 = 93750
 ```
 
 See the [Use Helm section](#helm) to install an inferencepool with the environment variables.
@@ -87,7 +84,7 @@ $ helm install triton-llama3-8b-instruct \
   --set provider.name=[none|gke] \
   --set inferenceExtension.env.EXPERIMENTAL_USE_SCHEDULER_V2=true \
   --set inferenceExtension.env.ENABLE_PREFIX_CACHE_SCHEDULING=true \
-  --set inferenceExtension.env.PREFIX_CACHE_LRU_CAPACITY=93750 \
+  --set inferenceExtension.env.PREFIX_CACHE_LRU_CAPACITY_PER_SERVER=31250 \
   --set inferenceExtension.env.PREFIX_CACHE_MAX_PREFIX_BLOCKS=1024 \
   oci://us-central1-docker.pkg.dev/k8s-staging-images/gateway-api-inference-extension/charts/inferencepool --version v0
 ```
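
Because the capacity is now per server, the guide's estimate no longer needs to be multiplied by the replica count. Using the guide's own numbers (and the 3-replica example the old total was based on):

```
# old, global:      PREFIX_CACHE_LRU_CAPACITY            = 3 * 31250     = 93750
# new, per server:  PREFIX_CACHE_LRU_CAPACITY_PER_SERVER = 500,000*4/64  = 31250
```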
