2 changes: 1 addition & 1 deletion cmd/epp/runner/runner.go
@@ -318,7 +318,7 @@ func loadPrefixCacheConfig() prefix.Config {
return prefix.Config{
HashBlockSize: envutil.GetEnvInt("PREFIX_CACHE_HASH_BLOCK_SIZE", prefix.DefaultHashBlockSize, baseLogger),
MaxPrefixBlocksToMatch: envutil.GetEnvInt("PREFIX_CACHE_MAX_PREFIX_BLOCKS", prefix.DefaultMaxPrefixBlocks, baseLogger),
LRUIndexerCapacity: envutil.GetEnvInt("PREFIX_CACHE_LRU_CAPACITY", prefix.DefaultLRUIndexerCapacity, baseLogger),
LRUCapacityPerServer: envutil.GetEnvInt("PREFIX_CACHE_LRU_CAPACITY_PER_SERVER", prefix.DefaultLRUCapacityPerServer, baseLogger),
}
}
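
For context, a hedged sketch of exercising the renamed knob; the test placement and the value are assumptions, not taken from this PR:

```go
// Hypothetical test inside the runner package (not part of this PR);
// the env var name and the Config field come from the diff above, the value is made up.
func TestLoadPrefixCacheConfig_PerServerCapacity(t *testing.T) {
	t.Setenv("PREFIX_CACHE_LRU_CAPACITY_PER_SERVER", "50000")
	cfg := loadPrefixCacheConfig()
	if cfg.LRUCapacityPerServer != 50000 {
		t.Fatalf("expected 50000, got %d", cfg.LRUCapacityPerServer)
	}
}
```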

1 change: 1 addition & 0 deletions go.mod
@@ -9,6 +9,7 @@ require (
github.com/go-logr/logr v1.4.3
github.com/google/go-cmp v0.7.0
github.com/google/uuid v1.6.0
github.com/hashicorp/golang-lru/v2 v2.0.7
github.com/onsi/ginkgo/v2 v2.23.4
github.com/onsi/gomega v1.37.0
github.com/prometheus/client_golang v1.22.0
2 changes: 2 additions & 0 deletions go.sum
@@ -95,6 +95,8 @@ github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 h1:JeSE6pjso5T
github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674/go.mod h1:r4w70xmWCQKmi1ONH4KIaBptdivuRPyosB9RmPlGEwA=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.24.0 h1:TmHmbvxPmaegwhDubVz0lICL0J5Ka2vwTzhoePEXsGE=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.24.0/go.mod h1:qztMSjm835F2bXf+5HKAPIS5qsmQDqZna/PgVt4rWtI=
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
github.com/huandu/xstrings v1.3.3 h1:/Gcsuc1x8JVbJ9/rlye4xZnVAbEkGauT8lbebqcQws4=
github.com/huandu/xstrings v1.3.3/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
github.com/imdario/mergo v0.3.16 h1:wwQJbIsHYGMUyLSPrEq1CT16AhnhNJQ51+4fdHUnCl4=
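
For reference, a minimal standalone sketch of the golang-lru/v2 API this new dependency provides (a generic LRU with an eviction callback); the key type and capacity below are illustrative, not taken from the PR:

```go
package main

import (
	"fmt"

	lru "github.com/hashicorp/golang-lru/v2"
)

func main() {
	// A capacity of 2 is illustrative; the PR sizes each per-pod cache from config.
	cache, err := lru.NewWithEvict[uint64, struct{}](2, func(key uint64, _ struct{}) {
		fmt.Println("evicted", key)
	})
	if err != nil {
		panic(err)
	}
	cache.Add(1, struct{}{})
	cache.Add(2, struct{}{})
	cache.Add(3, struct{}{}) // over capacity: evicts key 1, the least recently used
	fmt.Println(cache.Len()) // 2
}
```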
208 changes: 92 additions & 116 deletions pkg/epp/scheduling/framework/plugins/multi/prefix/indexer.go
@@ -20,154 +20,130 @@ import (
"context"
"sync"
"time"
"unsafe"

"container/list"

lru "github.com/hashicorp/golang-lru/v2"
"sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
)

func newIndexer(maxCacheSize int) *indexer {
t := &indexer{
maxCacheSize: maxCacheSize,
table: make(map[BlockHash]map[ServerID]*list.Element),
ll: list.New(),
}
go t.ReportCacheSize(time.Second)
return t
}

// An indexer maintains an LRU cache of prompt prefix hashes and the server(s) that might have that
// prefix cached .
// prefix cached.
type indexer struct {
mu sync.RWMutex
maxCacheSize int
table map[BlockHash]map[ServerID]*list.Element // from any prefix cache to the cache entry to find the server
ll *list.List // LinkedList to keep track of the order of entries
mu sync.RWMutex
Contributor:
nit: another optimization is to use different mutexes for hashToPods and podToLRU, but I don't think it's very important.

Contributor Author:
Also, I used it only for the hashToPods operations, except in ReportLRUSize, which we can remove if it hurts performance.

hashToPods map[BlockHash]podSet // the lookup data structure to find pods that have the BlockHash cached
podToLRU map[ServerID]*lru.Cache[BlockHash, struct{}] // key is pod namespacedName, value is an LRU cache
maxLRUSize int
}
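
Following up on the mutex-splitting suggestion in the review thread above, a rough sketch of what that variant could look like, reusing this package's types; the split is illustrative and not part of this PR:

```go
// Hypothetical alternative layout: one lock per map, so hashToPods readers
// are not serialized behind podToLRU bookkeeping. Not part of this PR.
type indexerSplitLocks struct {
	hashMu     sync.RWMutex // guards hashToPods
	lruMu      sync.RWMutex // guards podToLRU
	hashToPods map[BlockHash]podSet
	podToLRU   map[ServerID]*lru.Cache[BlockHash, struct{}]
	maxLRUSize int
}
```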

// value is the value stored in the linked list.
type value struct {
server ServerID
hash BlockHash
}

// Get returns the set of servers that have the given prefix hash cached.
func (i *indexer) Get(hash BlockHash) map[ServerID]bool {
i.mu.RLock()
defer i.mu.RUnlock()
res := map[ServerID]bool{}
for server := range i.table[hash] {
res[server] = true
// newIndexer initializes an indexer with size limits and starts cache size reporting.
func newIndexer(maxLRUSize int) *indexer {
ix := &indexer{
hashToPods: make(map[BlockHash]podSet),
podToLRU: make(map[ServerID]*lru.Cache[BlockHash, struct{}]),
maxLRUSize: maxLRUSize,
}
return res

go ix.ReportLRUSize(time.Second)
return ix
}

// Add adds a list of prefix hashes of a single request to the server the request was sent to.
// The intuition is that this server is likely to have the prefix cached, so next time a request
// sharing the longest prefix should be sent to the same server to take advantage of the cache hit.
func (i *indexer) Add(hashes []BlockHash, server ServerID) {
// Add adds a list of prefix hashes to the cache, tied to the server.
func (i *indexer) Add(hashes []BlockHash, pod ServerID) {
i.mu.Lock()
defer i.mu.Unlock()
for _, hash := range hashes {
i.add(hash, server)
// Check if an LRU cache already exists for this pod
lruForPod, exists := i.podToLRU[pod]
if !exists {
newLRU, _ := lru.NewWithEvict[BlockHash, struct{}](i.maxLRUSize, i.makeEvictionFn(pod))
i.podToLRU[pod] = newLRU
lruForPod = newLRU
}
}

func (i *indexer) check(hash BlockHash, server ServerID) (*list.Element, bool) {
servers, ok := i.table[hash]
if !ok {
return nil, false
i.mu.Unlock()

// Add to LRU (may evict)
for _, hash := range hashes {
lruForPod.Add(hash, struct{}{})
}
e, ok := servers[server]
return e, ok
}

func (i *indexer) add(hash BlockHash, server ServerID) {
e, exists := i.check(hash, server)
if exists {
i.ll.MoveToBack(e)
} else {
i.create(hash, server)
// Update hashToPods once under lock
i.mu.Lock()
for _, hash := range hashes {
pods := i.hashToPods[hash]
if pods == nil {
pods = make(podSet)
}
pods[pod] = struct{}{}
i.hashToPods[hash] = pods
}

i.mu.Unlock()
}
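
As a quick orientation, a hedged usage sketch of the new indexer API from inside this package; the pod name and capacity are made up:

```go
// Hypothetical in-package usage; only newIndexer, Add, Get, ServerID and
// BlockHash come from this file, everything else is invented.
ix := newIndexer(1024)
pod := ServerID{Namespace: "default", Name: "vllm-0"}

ix.Add([]BlockHash{BlockHash(42), BlockHash(43)}, pod)

if _, ok := ix.Get(BlockHash(42))[pod]; ok {
	// This pod likely has the prefix block cached; prefer routing here.
}
```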

func (i *indexer) create(hash BlockHash, server ServerID) {
for i.ll.Len() >= i.maxCacheSize {
// Evict the least recently used entry if we've exceeded the max cache size
i.evict()
}
// Get returns a set of servers that have the given prefix hash cached.
func (i *indexer) Get(hash BlockHash) podSet {
i.mu.RLock()
defer i.mu.RUnlock()

if _, ok := i.table[hash]; !ok {
i.table[hash] = make(map[ServerID]*list.Element)
}
v := &value{
server: server,
hash: hash,
res := podSet{}
pods, ok := i.hashToPods[hash]
if !ok {
return res
}
e := i.ll.PushBack(v)
i.table[hash][server] = e

return pods
}

// evict removes the least recently used entry from the cache
func (i *indexer) evict() {
oldestNode := i.ll.Front()
if oldestNode == nil {
return
// makeEvictionFn returns a per-pod LRU eviction callback that removes the pod from hashToPods on eviction.
func (i *indexer) makeEvictionFn(pod ServerID) func(BlockHash, struct{}) {
return func(hash BlockHash, _ struct{}) {
i.mu.Lock()
defer i.mu.Unlock()
// Remove the pod from the hash→pods map
if podSet, ok := i.hashToPods[hash]; ok {
delete(podSet, pod)
if len(podSet) == 0 {
delete(i.hashToPods, hash)
}
}
}
i.ll.Remove(oldestNode)

v := oldestNode.Value.(*value)
hash := v.hash
server := v.server
// Remove from the hash map
serverMap := i.table[hash]
delete(serverMap, server)

// If this was the last server for this hash, remove the hash entry entirely
if len(serverMap) == 0 {
delete(i.table, hash)
}

log.FromContext(context.TODO()).V(logutil.TRACE).Info("Evicted LRU entry", "hash", hash, "server", server)
}

// ReportCacheSize starts a goroutine that periodically reports the cache size metric
func (i *indexer) ReportCacheSize(interval time.Duration) {
// ReportLRUSize starts a goroutine that periodically reports the LRU cache size metric.
func (i *indexer) ReportLRUSize(interval time.Duration) {
ticker := time.NewTicker(interval)
defer ticker.Stop()
for range ticker.C {
i.mu.RLock()
metrics.RecordPrefixCacheSize(int64(i.ll.Len()))
log.FromContext(context.TODO()).V(logutil.TRACE).Info("LRU", "# entries", i.ll.Len(), "estimated size MB", i.ll.Len()*i.estimateEntrySize()/1000000)
totalEntries := 0
maxPodEntries := 0
maxPodName := ServerID{}

for pod, lruCache := range i.podToLRU {
size := lruCache.Len()
totalEntries += size
if size > maxPodEntries {
maxPodEntries = size
maxPodName = pod
}
}

numPods := len(i.podToLRU)
avg := 0.0
if numPods > 0 {
avg = float64(totalEntries) / float64(numPods)
}

metrics.RecordPrefixCacheSize(int64(totalEntries))
log.FromContext(context.TODO()).V(logutil.TRACE).Info("Prefix cache state",
"total entries", totalEntries,
"# pods", numPods,
"avg entries per pod", avg,
"pod with max cache", maxPodName,
"max pod size", maxPodEntries,
"global max LRU cache capacity per pod", i.maxLRUSize,
)

i.mu.RUnlock()
}
}

// estimateEntrySize estimates the memory size of a cache entry in bytes.
func (i *indexer) estimateEntrySize() int {
size := 0

// Estimate the size of a node in the linked list.
// First get the size of the node struct via unsafe.Sizeof.
// The prev and next pointers are 8 bytes each on a 64-bit system.
// The BlockHash is a uint64, which is 8 bytes.
// The ServerID is a NamespacedName, which contains two strings (Name and Namespace).
// The headers for the strings are 16 bytes each (8 bytes for the pointer and 8 bytes for the length).
// So unsafe.Sizeof(node{}) should return 2*8 + 8 + 2*16 = 48 bytes.
size += int(unsafe.Sizeof(value{}))
// Size of the Name and Namespace strings in ServerID, assuming 63 bytes each (max length for Kubernetes NamespacedName).
size += 2 * 63

// Estimate the size of an entry in the hash map. Note the overhead of the map headers and buckets are ignored.
size += 8 // Size of the BlockHash (uint64).
size += 2 * 16 // Size of the ServerID string headers (NamespacedName).
size += 2 * 63 // Size of the Name and Namespace strings in ServerID.
size += 8 // Size of the pointer to the node in the hash map.

// Based on the above estimates, the estimated size of an entry is:
// (48 + 2*63) + (8 + 2*16 + 2*63 + 8) = 348 bytes.
return size
}
@@ -22,24 +22,23 @@ import (
)

func TestIndexer_AddAndGet(t *testing.T) {
cache := newIndexer(2)
i := newIndexer(2)

hash1 := BlockHash(1)
server := ServerID{Namespace: "default", Name: "server1"}

// Add an entry to the cache
cache.Add([]BlockHash{hash1}, server)
i.Add([]BlockHash{hash1}, server)

// Retrieve the entry
assert.Equal(t, 1, cache.ll.Len(), "Cache size should be 1 after adding an entry")
servers := cache.Get(hash1)
assert.Equal(t, 1, i.podToLRU[server].Len(), "Cache size should be 1 after adding an entry")
servers := i.Get(hash1)
assert.Contains(t, servers, server, "Cache should contain the added server")

// Add another entry to the cache, the cache size should be incremented to 2.
cache.Add([]BlockHash{BlockHash(2)}, server)
assert.Equal(t, 2, cache.ll.Len(), "Cache size should be 2 after adding an entry")
i.Add([]BlockHash{BlockHash(2)}, server)
assert.Equal(t, 2, i.podToLRU[server].Len(), "Cache size should be 2 after adding an entry")

// Add another entry to the cache, which should evict the first one due to max size.
cache.Add([]BlockHash{BlockHash(3)}, server)
assert.Equal(t, 2, cache.ll.Len(), "Cache size should still be 2 after adding an entry")
i.Add([]BlockHash{BlockHash(3)}, server)
assert.Equal(t, 2, i.podToLRU[server].Len(), "Cache size should still be 2 after adding an entry")
}
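
One more check that might be worth adding (hypothetical, not part of this PR): eviction from a pod's LRU should also drop that pod from Get results once the eviction callback has run.

```go
// Hypothetical extra test; relies only on newIndexer, Add, Get and the
// eviction callback wiring shown in indexer.go above.
func TestIndexer_EvictionRemovesFromGet(t *testing.T) {
	i := newIndexer(1)
	server := ServerID{Namespace: "default", Name: "server1"}

	i.Add([]BlockHash{BlockHash(1)}, server)
	i.Add([]BlockHash{BlockHash(2)}, server) // capacity 1: evicts BlockHash(1)

	assert.Empty(t, i.Get(BlockHash(1)), "evicted hash should no longer resolve to the server")
	assert.Contains(t, i.Get(BlockHash(2)), server, "most recent hash should still resolve to the server")
}
```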