@@ -209,14 +209,31 @@ func (l *Loader[K, V]) Load(originalContext context.Context, key K) Thunk[V] {
209
209
value * Result [V ]
210
210
}
211
211
212
- // lock to prevent duplicate keys coming in before item has been added to cache.
212
+ // We need to lock both the batchLock and cacheLock because the batcher can
213
+ // reset the cache when either the batchCap or the wait time is reached.
214
+ //
215
+ // When we would only lock the cacheLock while doing l.cache.Get and/or
216
+ // l.cache.Set, it could be that the batcher resets the cache after those
217
+ // operations have finished but before the new request (if any) is send to the
218
+ // batcher.
219
+ //
220
+ // In that case it is no longer guaranteed that the keys passed to the BatchFunc
221
+ // function are unique as the cache has been reset so if the same key is
222
+ // requested again before the new batcher is started, the same key will be
223
+ // send to the batcher again causing unexpected behavior in the BatchFunc.
224
+ l .batchLock .Lock ()
213
225
l .cacheLock .Lock ()
226
+
214
227
if v , ok := l .cache .Get (ctx , key ); ok {
228
+ l .cacheLock .Unlock ()
229
+ l .batchLock .Unlock ()
215
230
defer finish (v )
216
- defer l .cacheLock .Unlock ()
217
231
return v
218
232
}
219
233
234
+ defer l .batchLock .Unlock ()
235
+ defer l .cacheLock .Unlock ()
236
+
220
237
thunk := func () (V , error ) {
221
238
result .mu .RLock ()
222
239
resultNotSet := result .value == nil
@@ -240,13 +257,11 @@ func (l *Loader[K, V]) Load(originalContext context.Context, key K) Thunk[V] {
240
257
defer finish (thunk )
241
258
242
259
l .cache .Set (ctx , key , thunk )
243
- l .cacheLock .Unlock ()
244
260
245
261
// this is sent to batch fn. It contains the key and the channel to return
246
262
// the result on
247
263
req := & batchRequest [K , V ]{key , c }
248
264
249
- l .batchLock .Lock ()
250
265
// start the batch window if it hasn't already started.
251
266
if l .curBatcher == nil {
252
267
l .curBatcher = l .newBatcher (l .silent , l .tracer )
@@ -274,7 +289,6 @@ func (l *Loader[K, V]) Load(originalContext context.Context, key K) Thunk[V] {
274
289
l .reset ()
275
290
}
276
291
}
277
- l .batchLock .Unlock ()
278
292
279
293
return thunk
280
294
}
0 commit comments