
Commit 7dc0b4a

kfirtoledo authored and vMaroon committed
refactor: Replace prefix cache structure with golang-lru
Signed-off-by: Kfir Toledo <kfir.toledo@ibm.com>
Co-authored-by: Maroon Ayoub <maroon.ayoub@ibm.com>
1 parent b62931e commit 7dc0b4a

7 files changed: +136 / -142 lines changed

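For orientation before the per-file diffs: the refactor replaces the hand-rolled container/list plus map bookkeeping with hashicorp/golang-lru/v2 at two levels, an outer LRU keyed by prefix-block hash where each entry holds a small inner LRU of candidate servers. Below is a minimal standalone sketch of that shape; the BlockHash and ServerID types and the capacities are simplified placeholders, not the package's real definitions.

package main

import (
	"fmt"

	lru "github.com/hashicorp/golang-lru/v2"
)

// Simplified stand-ins for the prefix package's BlockHash and ServerID types.
type BlockHash uint64
type ServerID string

func main() {
	// Outer LRU: prefix-block hash -> bounded set of servers that may hold it.
	outer, _ := lru.New[BlockHash, *lru.Cache[ServerID, struct{}]](4)

	addHash := func(h BlockHash, s ServerID) {
		inner, ok := outer.Get(h)
		if !ok {
			inner, _ = lru.New[ServerID, struct{}](2) // per-hash server cap
			outer.Add(h, inner)
		}
		inner.Add(s, struct{}{})
	}

	addHash(BlockHash(1), "pod-a")
	addHash(BlockHash(1), "pod-b")
	addHash(BlockHash(1), "pod-c") // evicts pod-a from the inner LRU (cap 2)

	inner, _ := outer.Get(BlockHash(1))
	fmt.Println(inner.Keys()) // [pod-b pod-c]
}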

cmd/epp/main.go

Lines changed: 1 addition & 0 deletions
@@ -120,6 +120,7 @@ func loadPrefixCacheConfig() prefix.Config {
 
 	return prefix.Config{
 		HashBlockSize:          envutil.GetEnvInt("PREFIX_CACHE_HASH_BLOCK_SIZE", prefix.DefaultHashBlockSize, baseLogger),
+		MaxNumServersToMatch:   envutil.GetEnvInt("PREFIX_CACHE_MAX_SERVER_TO_MATCH", prefix.DefaultNumServersToMatch, baseLogger),
 		MaxPrefixBlocksToMatch: envutil.GetEnvInt("PREFIX_CACHE_MAX_PREFIX_BLOCKS", prefix.DefaultMaxPrefixBlocks, baseLogger),
 		LRUIndexerCapacity:     envutil.GetEnvInt("PREFIX_CACHE_LRU_CAPACITY", prefix.DefaultLRUIndexerCapacity, baseLogger),
 	}
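The new knob is read from the PREFIX_CACHE_MAX_SERVER_TO_MATCH environment variable alongside the existing prefix-cache settings. As a rough illustration of the read-int-with-default pattern, here is a hypothetical stand-in; getEnvInt below is not the project's envutil.GetEnvInt (which also takes a logger), and the defaults simply mirror the constants added in plugin.go in this commit.

package main

import (
	"fmt"
	"os"
	"strconv"
)

// getEnvInt is a hypothetical helper: read an integer environment variable,
// falling back to a default when the variable is unset or malformed.
func getEnvInt(key string, def int) int {
	if raw, ok := os.LookupEnv(key); ok {
		if v, err := strconv.Atoi(raw); err == nil {
			return v
		}
	}
	return def
}

func main() {
	fmt.Println(getEnvInt("PREFIX_CACHE_HASH_BLOCK_SIZE", 64))
	fmt.Println(getEnvInt("PREFIX_CACHE_MAX_SERVER_TO_MATCH", 16))
	fmt.Println(getEnvInt("PREFIX_CACHE_MAX_PREFIX_BLOCKS", 256))
	fmt.Println(getEnvInt("PREFIX_CACHE_LRU_CAPACITY", 130000))
}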

go.mod

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@ require (
 	github.com/go-logr/logr v1.4.3
 	github.com/google/go-cmp v0.7.0
 	github.com/google/uuid v1.6.0
+	github.com/hashicorp/golang-lru/v2 v2.0.7
 	github.com/onsi/ginkgo/v2 v2.23.4
 	github.com/onsi/gomega v1.37.0
 	github.com/prometheus/client_golang v1.22.0

go.sum

Lines changed: 2 additions & 0 deletions
@@ -95,6 +95,8 @@ github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 h1:JeSE6pjso5T
 github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674/go.mod h1:r4w70xmWCQKmi1ONH4KIaBptdivuRPyosB9RmPlGEwA=
 github.com/grpc-ecosystem/grpc-gateway/v2 v2.24.0 h1:TmHmbvxPmaegwhDubVz0lICL0J5Ka2vwTzhoePEXsGE=
 github.com/grpc-ecosystem/grpc-gateway/v2 v2.24.0/go.mod h1:qztMSjm835F2bXf+5HKAPIS5qsmQDqZna/PgVt4rWtI=
+github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
+github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
 github.com/huandu/xstrings v1.3.3 h1:/Gcsuc1x8JVbJ9/rlye4xZnVAbEkGauT8lbebqcQws4=
 github.com/huandu/xstrings v1.3.3/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
 github.com/imdario/mergo v0.3.16 h1:wwQJbIsHYGMUyLSPrEq1CT16AhnhNJQ51+4fdHUnCl4=

pkg/epp/scheduling/framework/plugins/multi/prefix/indexer.go

Lines changed: 53 additions & 115 deletions
@@ -20,154 +20,92 @@ import (
 	"context"
 	"sync"
 	"time"
-	"unsafe"
-
-	"container/list"
 
+	lru "github.com/hashicorp/golang-lru/v2"
 	"sigs.k8s.io/controller-runtime/pkg/log"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
 	logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
 )
 
-func newIndexer(maxCacheSize int) *indexer {
-	t := &indexer{
-		maxCacheSize: maxCacheSize,
-		table:        make(map[BlockHash]map[ServerID]*list.Element),
-		ll:           list.New(),
-	}
-	go t.ReportCacheSize(time.Second)
-	return t
+// block holds an LRU cache of servers that may have a specific prefix hash.
+type block struct {
+	Pods *lru.Cache[ServerID, struct{}] // Can be extended with metadata (e.g., timestamp).
 }
 
 // An indexer maintains an LRU cache of prompt prefix hashes and the server(s) that might have that
 // prefix cached.
 type indexer struct {
-	mu           sync.RWMutex
-	maxCacheSize int
-	table        map[BlockHash]map[ServerID]*list.Element // from any prefix cache to the cache entry to find the server
-	ll           *list.List                                // LinkedList to keep track of the order of entries
+	mu                sync.RWMutex
+	cache             *lru.Cache[BlockHash, *block]
+	maxCacheSize      int
+	maxServersToMatch int
 }
 
-// value is the value stored in the linked list.
-type value struct {
-	server ServerID
-	hash   BlockHash
+// newIndexer initializes an indexer with size limits and starts cache size reporting.
+func newIndexer(maxCacheSize, maxServersToMatch int) *indexer {
+	c, err := lru.New[BlockHash, *block](maxCacheSize)
+	if err != nil {
+		panic(err)
+	}
+	ix := &indexer{
+		cache:             c,
+		maxCacheSize:      maxCacheSize,
+		maxServersToMatch: maxServersToMatch,
+	}
+	go ix.ReportCacheSize(time.Second)
+	return ix
 }
 
-// Get returns the set of servers that have the given prefix hash cached.
-func (i *indexer) Get(hash BlockHash) map[ServerID]bool {
-	i.mu.RLock()
-	defer i.mu.RUnlock()
-	res := map[ServerID]bool{}
-	for server := range i.table[hash] {
-		res[server] = true
+// Add adds a list of prefix hashes to the cache, tied to the server.
+func (i *indexer) Add(hashes []BlockHash, pod ServerID) {
+	if len(hashes) == 0 || pod.Name == "" {
+		return
 	}
-	return res
-}
 
-// Add adds a list of prefix hashes of a single request to the server the request was sent to.
-// The intuition is that this server is likely to have the prefix cached, so next time a request
-// sharing the longest prefix should be sent to the same server to take advantage of the cache hit.
-func (i *indexer) Add(hashes []BlockHash, server ServerID) {
 	i.mu.Lock()
 	defer i.mu.Unlock()
-	for _, hash := range hashes {
-		i.add(hash, server)
-	}
-}
-
-func (i *indexer) check(hash BlockHash, server ServerID) (*list.Element, bool) {
-	servers, ok := i.table[hash]
-	if !ok {
-		return nil, false
-	}
-	e, ok := servers[server]
-	return e, ok
-}
 
-func (i *indexer) add(hash BlockHash, server ServerID) {
-	e, exists := i.check(hash, server)
-	if exists {
-		i.ll.MoveToBack(e)
-	} else {
-		i.create(hash, server)
+	for _, hash := range hashes {
+		b, ok := i.cache.Get(hash)
+		if !ok {
+			// Create block with new LRU
+			podLRU, _ := lru.New[ServerID, struct{}](i.maxServersToMatch)
+			b = &block{Pods: podLRU}
+			i.cache.Add(hash, b)
+		}
+
+		b.Pods.Add(pod, struct{}{})
 	}
 }
 
-func (i *indexer) create(hash BlockHash, server ServerID) {
-	for i.ll.Len() >= i.maxCacheSize {
-		// Evict the least recently used entry if we've exceeded the max cache size
-		i.evict()
-	}
-
-	if _, ok := i.table[hash]; !ok {
-		i.table[hash] = make(map[ServerID]*list.Element)
-	}
-	v := &value{
-		server: server,
-		hash:   hash,
-	}
-	e := i.ll.PushBack(v)
-	i.table[hash][server] = e
-}
+// Get returns a set of servers that have the given prefix hash cached.
+func (i *indexer) Get(hash BlockHash) map[ServerID]bool {
+	i.mu.RLock()
+	defer i.mu.RUnlock()
 
-// evict removes the least recently used entry from the cache
-func (i *indexer) evict() {
-	oldestNode := i.ll.Front()
-	if oldestNode == nil {
-		return
+	res := map[ServerID]bool{}
+	block, ok := i.cache.Get(hash)
+	if !ok {
+		return res
 	}
-	i.ll.Remove(oldestNode)
-
-	v := oldestNode.Value.(*value)
-	hash := v.hash
-	server := v.server
-	// Remove from the hash map
-	serverMap := i.table[hash]
-	delete(serverMap, server)
-
-	// If this was the last server for this hash, remove the hash entry entirely
-	if len(serverMap) == 0 {
-		delete(i.table, hash)
+	for _, pod := range block.Pods.Keys() {
+		res[pod] = true
 	}
-
-	log.FromContext(context.TODO()).V(logutil.TRACE).Info("Evicted LRU entry", "hash", hash, "server", server)
+	return res
 }
 
-// ReportCacheSize starts a goroutine that periodically reports the cache size metric
+// ReportCacheSize starts a goroutine that periodically reports the cache size metric.
 func (i *indexer) ReportCacheSize(interval time.Duration) {
 	ticker := time.NewTicker(interval)
 	defer ticker.Stop()
 	for range ticker.C {
 		i.mu.RLock()
-		metrics.RecordPrefixCacheSize(int64(i.ll.Len()))
-		log.FromContext(context.TODO()).V(logutil.TRACE).Info("LRU", "# entries", i.ll.Len(), "estimated size MB", i.ll.Len()*i.estimateEntrySize()/1000000)
+		size := i.cache.Len()
+		metrics.RecordPrefixCacheSize(int64(size))
+		log.FromContext(context.TODO()).V(logutil.TRACE).Info("LRU",
+			"# entries", size,
+			"prefix cache utilization [%]", float64(size)*100/float64(i.maxCacheSize),
+		)
 		i.mu.RUnlock()
 	}
 }
-
-// estimateEntrySize estimates the memory size of a cache entry in bytes.
-func (i *indexer) estimateEntrySize() int {
-	size := 0
-
-	// Estimate the size of a node in the linked list.
-	// First get the size of the node struct via unsafe.Sizeof.
-	// The prev and next pointers are 8 bytes each on a 64-bit system.
-	// The BlockHash is a uint64, which is 8 bytes.
-	// The ServerID is a NamespacedName, which contains two strings (Name and Namespace).
-	// The headers for the strings are 16 bytes each (8 bytes for the pointer and 8 bytes for the length).
-	// So unsafe.Sizeof(node{}) should return 2*8 + 8 + 2*16 = 48 bytes.
-	size += int(unsafe.Sizeof(value{}))
-	// Size of the Name and Namespace strings in ServerID, assuming 63 bytes each (max length for Kubernetes NamespacedName).
-	size += 2 * 63
-
-	// Estimate the size of an entry in the hash map. Note the overhead of the map headers and buckets are ignored.
-	size += 8      // Size of the BlockHash (uint64).
-	size += 2 * 16 // Size of the ServerID string headers (NamespacedName).
-	size += 2 * 63 // Size of the Name and Namespace strings in ServerID.
-	size += 8      // Size of the pointer to the node in the hash map.
-
-	// Based on the above estimates, the estimated size of an entry is:
-	// (48 + 2*63) + (8 + 2*16 + 2*63 + 8) = 348 bytes.
-	return size
-}
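A behavioral note on the new indexer: eviction now happens at two granularities. The outer cache evicts an entire hash entry (a block together with its server LRU) once maxCacheSize hashes are tracked, and each block's inner LRU independently caps the number of remembered servers at maxServersToMatch. A runnable sketch of the outer-level eviction, assuming capacity 2 as in the test below:

package main

import (
	"fmt"

	lru "github.com/hashicorp/golang-lru/v2"
)

type BlockHash uint64

func main() {
	// Outer cache capped at 2 entries, like newIndexer(2, ...) in the test.
	cache, err := lru.New[BlockHash, string](2)
	if err != nil {
		panic(err)
	}

	cache.Add(BlockHash(1), "block-1")
	cache.Add(BlockHash(2), "block-2")
	cache.Add(BlockHash(3), "block-3") // evicts BlockHash(1), the least recently used

	_, ok := cache.Get(BlockHash(1))
	fmt.Println(ok, cache.Len()) // false 2
}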

pkg/epp/scheduling/framework/plugins/multi/prefix/indexer_test.go

Lines changed: 4 additions & 4 deletions
@@ -22,7 +22,7 @@ import (
 )
 
 func TestIndexer_AddAndGet(t *testing.T) {
-	cache := newIndexer(2)
+	cache := newIndexer(2, 2)
 
 	hash1 := BlockHash(1)
 	server := ServerID{Namespace: "default", Name: "server1"}
@@ -31,15 +31,15 @@ func TestIndexer_AddAndGet(t *testing.T) {
 	cache.Add([]BlockHash{hash1}, server)
 
 	// Retrieve the entry
-	assert.Equal(t, 1, cache.ll.Len(), "Cache size should be 1 after adding an entry")
+	assert.Equal(t, 1, cache.cache.Len(), "Cache size should be 1 after adding an entry")
 	servers := cache.Get(hash1)
 	assert.Contains(t, servers, server, "Cache should contain the added server")
 
 	// Add another entry to the cache, the cache size should be incremented to 2.
 	cache.Add([]BlockHash{BlockHash(2)}, server)
-	assert.Equal(t, 2, cache.ll.Len(), "Cache size should be 2 after adding an entry")
+	assert.Equal(t, 2, cache.cache.Len(), "Cache size should be 2 after adding an entry")
 
 	// Add another entry to the cache, which should evict the first one due to max size.
 	cache.Add([]BlockHash{BlockHash(3)}, server)
-	assert.Equal(t, 2, cache.ll.Len(), "Cache size should still be 2 after adding an entry")
+	assert.Equal(t, 2, cache.cache.Len(), "Cache size should still be 2 after adding an entry")
 }

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go

Lines changed: 18 additions & 22 deletions
@@ -36,28 +36,25 @@ const (
 	// Why not just return the server with longest prefix match?
 	// It may not be the optimal choice, e.g., it may have a high queue depth.
 	// We optimistically search more than one to give more candidates for the scheduler to choose.
-	DefaultNumServersToMatch = 2
+	DefaultNumServersToMatch = 16
 	// vLLM default token block size is 16, and a good guess of average characters per token is 4.
 	DefaultHashBlockSize = 64
 	// The maximum number of blocks to match. Two long requests with the same prefix up to this
 	// limit will be indistinguishable.
 	// This parameter provides a trade-off between cache size, prefix matching speed and matching
 	// accuracy. Use a small value if most requests are short to reduce cache size and speed up the
 	// matching process. Use a large value if most requests are long to increase the matching accuracy.
-	DefaultMaxPrefixBlocks = 128
+	DefaultMaxPrefixBlocks = 256
 	// The indexer is an approximation to the actual prefix cache state on the model servers.
 	// A small capacity ensures a high accuracy of cache hit on the model server, but it will
 	// increase the chance of false negatives. A high capacity does the opposite.
 	// To properly size this, consider the sum of the total number of cache entries on all model
-	// servers. Consider the llama3 8B model on 3 H100 80GB GPUs. The size of the model weight is
+	// servers. Consider the llama3 8B model on 8 H100 80GB GPUs. The size of the model weight is
 	// about 16GB. Assume 50% of the remaining HBM is used for caching prefixes, we have 32GB. Each
 	// token is about 128KB in size, so we can cache 250K tokens. Using the default block size of 16
-	// in vLLM, we will have 250K / 16 = 15.6K blocks. In total we have 15.6K * 3 = 46.8K blocks, or
-	// roughly 50K.
-	// How much memory space does it require to hold the 50K block hashes?
-	// According to the estimates in indexer.estimateEntrySize(), the size of each entry is
-	// approximately 348 bytes. So in total we have 50K * 348 = 17.4MB.
-	DefaultLRUIndexerCapacity = 50000
+	// in vLLM, we will have 250K / 16 = 15.6K blocks. In total we have 15.6K * 8 = 124.8K blocks, or
+	// roughly 130K.
+	DefaultLRUIndexerCapacity = 130000
 )
 
 type Config struct {
@@ -67,6 +64,8 @@ type Config struct {
 	// MaxPrefixBlocksToMatch is the maximum number of prefix blocks to match. Input beyond this limit will
 	// be ignored.
 	MaxPrefixBlocksToMatch int
+	// MaxNumServersToMatch is the maximum number of servers that can match per BlockHash.
+	MaxNumServersToMatch int
 	// Max (approximate) size of the LRU indexer in number of entries.
 	LRUIndexerCapacity int
 }
@@ -123,7 +122,7 @@ var _ framework.PostCycle = &Plugin{}
 func New(config Config) *Plugin {
 	m := &Plugin{
 		Config:  config,
-		indexer: newIndexer(config.LRUIndexerCapacity),
+		indexer: newIndexer(config.LRUIndexerCapacity, config.MaxNumServersToMatch),
 	}
 	return m
 }
@@ -138,14 +137,11 @@ func (m *Plugin) Score(ctx context.Context, request *types.LLMRequest, cycleStat
 	loggerTrace := log.FromContext(ctx).V(logutil.TRACE)
 	// pre score step, hashing prompt and find longest prefix match.
 	hashes := hashPrompt(ctx, request, m.HashBlockSize, m.MaxPrefixBlocksToMatch)
-	numServers := DefaultNumServersToMatch
-	if numServers > len(pods) {
-		numServers = len(pods)
-	}
 	state := &schedulingContextState{
 		PrefixHashes:       hashes,
-		PrefixCacheServers: m.matchLongestPrefix(ctx, hashes, numServers),
+		PrefixCacheServers: m.matchLongestPrefix(ctx, hashes),
 	}
+
 	cycleState.Write(types.StateKey(m.Name()), state)
 	loggerTrace.Info(fmt.Sprintf("cached servers: %+v", state.PrefixCacheServers), "hashes", state.PrefixHashes)
 	// calculate the scores of pods
@@ -181,22 +177,22 @@ func (m *Plugin) PostCycle(ctx context.Context, cycleState *types.CycleState, re
 }
 
 // matchLongestPrefix returns a map of servers and length of prefix that each server caches.
-func (m *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash, numServers int) map[ServerID]int {
+func (m *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map[ServerID]int {
 	loggerTrace := log.FromContext(ctx).V(logutil.TRACE)
 	res := make(map[ServerID]int)
 	// Use a greedy strategy to search from the longest prefix.
 	// NOTE: It's possible to further optimize this with a binary search.
-	for i := len(hashes) - 1; i >= 0 && len(res) < numServers; i-- {
+	for i := 0; i < len(hashes); i++ {
 		hash := hashes[i]
 		cachedServers := m.indexer.Get(hash)
-		if len(cachedServers) > 0 {
+		if len(cachedServers) == 0 {
+			break
+		} else {
 			loggerTrace.Info("Found cached servers", "cachedServers", cachedServers, "total # blocks", len(hashes), "longest prefix", i)
 			for server := range cachedServers {
 				// Update servers with their longest prefix match.
-				// If we already found this server with longer prefix match, don't update it.
-				if _, ok := res[server]; !ok {
-					res[server] = i + 1
-				}
+				res[server]++
+
 			}
 		}
 	}
