@@ -20,154 +20,92 @@ import (
 	"context"
 	"sync"
 	"time"
-	"unsafe"
-
-	"container/list"
 
+	lru "github.com/hashicorp/golang-lru/v2"
 	"sigs.k8s.io/controller-runtime/pkg/log"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
 	logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
 )
 
-func newIndexer(maxCacheSize int) *indexer {
-	t := &indexer{
-		maxCacheSize: maxCacheSize,
-		table:        make(map[BlockHash]map[ServerID]*list.Element),
-		ll:           list.New(),
-	}
-	go t.ReportCacheSize(time.Second)
-	return t
+// block holds an LRU cache of servers that may have a specific prefix hash.
+type block struct {
+	Pods *lru.Cache[ServerID, struct{}] // Can be extended with metadata (e.g., timestamp).
 }
 
 // An indexer maintains an LRU cache of prompt prefix hashes and the server(s) that might have that
 // prefix cached.
 type indexer struct {
-	mu           sync.RWMutex
-	maxCacheSize int
-	table        map[BlockHash]map[ServerID]*list.Element // from any prefix cache to the cache entry to find the server
-	ll           *list.List                               // LinkedList to keep track of the order of entries
+	mu                sync.RWMutex
+	cache             *lru.Cache[BlockHash, *block]
+	maxCacheSize      int
+	maxServersToMatch int
 }
 
-// value is the value stored in the linked list.
-type value struct {
-	server ServerID
-	hash   BlockHash
+// newIndexer initializes an indexer with size limits and starts cache size reporting.
+func newIndexer(maxCacheSize, maxServersToMatch int) *indexer {
+	c, err := lru.New[BlockHash, *block](maxCacheSize)
+	if err != nil {
+		panic(err)
+	}
+	ix := &indexer{
+		cache:             c,
+		maxCacheSize:      maxCacheSize,
+		maxServersToMatch: maxServersToMatch,
+	}
+	go ix.ReportCacheSize(time.Second)
+	return ix
 }
 
-// Get returns the set of servers that have the given prefix hash cached.
-func (i *indexer) Get(hash BlockHash) map[ServerID]bool {
-	i.mu.RLock()
-	defer i.mu.RUnlock()
-	res := map[ServerID]bool{}
-	for server := range i.table[hash] {
-		res[server] = true
+// Add adds a list of prefix hashes to the cache, tied to the server.
+func (i *indexer) Add(hashes []BlockHash, pod ServerID) {
+	if len(hashes) == 0 || pod.Name == "" {
+		return
 	}
-	return res
-}
 
-// Add adds a list of prefix hashes of a single request to the server the request was sent to.
-// The intuition is that this server is likely to have the prefix cached, so next time a request
-// sharing the longest prefix should be sent to the same server to take advantage of the cache hit.
-func (i *indexer) Add(hashes []BlockHash, server ServerID) {
 	i.mu.Lock()
 	defer i.mu.Unlock()
-	for _, hash := range hashes {
-		i.add(hash, server)
-	}
-}
-
-func (i *indexer) check(hash BlockHash, server ServerID) (*list.Element, bool) {
-	servers, ok := i.table[hash]
-	if !ok {
-		return nil, false
-	}
-	e, ok := servers[server]
-	return e, ok
-}
 
-func (i *indexer) add(hash BlockHash, server ServerID) {
-	e, exists := i.check(hash, server)
-	if exists {
-		i.ll.MoveToBack(e)
-	} else {
-		i.create(hash, server)
+	for _, hash := range hashes {
+		b, ok := i.cache.Get(hash)
+		if !ok {
+			// Create a block with a fresh server LRU for this hash.
+			podLRU, _ := lru.New[ServerID, struct{}](i.maxServersToMatch)
+			b = &block{Pods: podLRU}
+			i.cache.Add(hash, b)
+		}
+
+		b.Pods.Add(pod, struct{}{})
 	}
 }
 
-func (i *indexer) create(hash BlockHash, server ServerID) {
-	for i.ll.Len() >= i.maxCacheSize {
-		// Evict the least recently used entry if we've exceeded the max cache size
-		i.evict()
-	}
-
-	if _, ok := i.table[hash]; !ok {
-		i.table[hash] = make(map[ServerID]*list.Element)
-	}
-	v := &value{
-		server: server,
-		hash:   hash,
-	}
-	e := i.ll.PushBack(v)
-	i.table[hash][server] = e
-}
+// Get returns a set of servers that have the given prefix hash cached.
+func (i *indexer) Get(hash BlockHash) map[ServerID]bool {
+	i.mu.RLock()
+	defer i.mu.RUnlock()
 
-// evict removes the least recently used entry from the cache
-func (i *indexer) evict() {
-	oldestNode := i.ll.Front()
-	if oldestNode == nil {
-		return
+	res := map[ServerID]bool{}
+	block, ok := i.cache.Get(hash)
+	if !ok {
+		return res
 	}
-	i.ll.Remove(oldestNode)
-
-	v := oldestNode.Value.(*value)
-	hash := v.hash
-	server := v.server
-	// Remove from the hash map
-	serverMap := i.table[hash]
-	delete(serverMap, server)
-
-	// If this was the last server for this hash, remove the hash entry entirely
-	if len(serverMap) == 0 {
-		delete(i.table, hash)
+	for _, pod := range block.Pods.Keys() {
+		res[pod] = true
 	}
-
-	log.FromContext(context.TODO()).V(logutil.TRACE).Info("Evicted LRU entry", "hash", hash, "server", server)
+	return res
 }
 
-// ReportCacheSize starts a goroutine that periodically reports the cache size metric
+// ReportCacheSize starts a goroutine that periodically reports the cache size metric.
 func (i *indexer) ReportCacheSize(interval time.Duration) {
 	ticker := time.NewTicker(interval)
 	defer ticker.Stop()
 	for range ticker.C {
 		i.mu.RLock()
-		metrics.RecordPrefixCacheSize(int64(i.ll.Len()))
-		log.FromContext(context.TODO()).V(logutil.TRACE).Info("LRU", "# entries", i.ll.Len(), "estimated size MB", i.ll.Len()*i.estimateEntrySize()/1000000)
+		size := i.cache.Len()
+		metrics.RecordPrefixCacheSize(int64(size))
+		log.FromContext(context.TODO()).V(logutil.TRACE).Info("LRU",
+			"# entries", size,
+			"prefix cache utilization [%]", float64(size)*100/float64(i.maxCacheSize),
+		)
 		i.mu.RUnlock()
 	}
 }
-
-// estimateEntrySize estimates the memory size of a cache entry in bytes.
-func (i *indexer) estimateEntrySize() int {
-	size := 0
-
-	// Estimate the size of a node in the linked list.
-	// First get the size of the node struct via unsafe.Sizeof.
-	// The prev and next pointers are 8 bytes each on a 64-bit system.
-	// The BlockHash is a uint64, which is 8 bytes.
-	// The ServerID is a NamespacedName, which contains two strings (Name and Namespace).
-	// The headers for the strings are 16 bytes each (8 bytes for the pointer and 8 bytes for the length).
-	// So unsafe.Sizeof(value{}) should return 2*8 + 8 + 2*16 = 48 bytes.
-	size += int(unsafe.Sizeof(value{}))
-	// Size of the Name and Namespace strings in ServerID, assuming 63 bytes each (max length for Kubernetes NamespacedName).
-	size += 2 * 63
-
-	// Estimate the size of an entry in the hash map. Note the overhead of the map headers and buckets is ignored.
-	size += 8      // Size of the BlockHash (uint64).
-	size += 2 * 16 // Size of the ServerID string headers (NamespacedName).
-	size += 2 * 63 // Size of the Name and Namespace strings in ServerID.
-	size += 8      // Size of the pointer to the node in the hash map.
-
-	// Based on the above estimates, the estimated size of an entry is:
-	// (48 + 2*63) + (8 + 2*16 + 2*63 + 8) = 348 bytes.
-	return size
-}
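
To make the new shape concrete, here is a standalone sketch of the same two-level pattern, written directly against `hashicorp/golang-lru/v2`: an outer LRU keyed by prefix hash, whose values each hold a small inner LRU of server IDs. This is illustration only, not code from the PR; `BlockHash` and `ServerID` are simplified stand-ins (the real `ServerID` is a NamespacedName-style struct), and the tiny capacities play the roles of `maxCacheSize` and `maxServersToMatch`.

```go
// Standalone sketch (not part of the PR): the two-level LRU pattern the
// indexer now uses, with simplified stand-in types.
package main

import (
	"fmt"

	lru "github.com/hashicorp/golang-lru/v2"
)

type BlockHash uint64
type ServerID string

func main() {
	// Outer LRU: prefix hash -> inner LRU of servers.
	// Capacity 2 stands in for maxCacheSize.
	outer, _ := lru.New[BlockHash, *lru.Cache[ServerID, struct{}]](2)

	// add mirrors indexer.Add for a single hash: get-or-create the inner
	// LRU (capacity standing in for maxServersToMatch), then record the server.
	add := func(h BlockHash, s ServerID) {
		inner, ok := outer.Get(h)
		if !ok {
			inner, _ = lru.New[ServerID, struct{}](2)
			outer.Add(h, inner)
		}
		inner.Add(s, struct{}{})
	}

	add(1, "pod-a")
	add(2, "pod-b")
	add(1, "pod-c") // hash 1 becomes most recently used
	add(3, "pod-a") // outer LRU is full; least recently used hash 2 is evicted

	_, ok := outer.Get(2)
	fmt.Println("hash 2 still indexed:", ok) // false: auto-evicted on Add

	if inner, ok := outer.Get(1); ok {
		fmt.Println("servers for hash 1:", inner.Keys()) // [pod-a pod-c]
	}
}
```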
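One semantic shift worth keeping in mind when reviewing, as I read it: the old structure counted hash-server pairs against `maxCacheSize` and only evicted inside `create()`, while `lru.Cache` evicts automatically on `Add` and the capacity now bounds distinct prefix hashes, each remembering at most `maxServersToMatch` servers. That is what lets `ReportCacheSize` report utilization as a simple `cache.Len()/maxCacheSize` ratio and drop the old per-entry byte estimate along with the `unsafe` import. Also note that golang-lru's `Cache` is internally synchronized, so the indexer's own `RWMutex` mostly serves to keep the get-or-create sequence in `Add` atomic across the two cache calls.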