5 changes: 5 additions & 0 deletions cache.go
@@ -26,3 +26,8 @@ func (c *NoCache[K, V]) Delete(context.Context, K) bool { return false }
 
 // Clear is a NOOP
 func (c *NoCache[K, V]) Clear() { return }
+
+type DataCache[K comparable, V any] interface {
+	Get(context.Context, K) (V, bool)
+	Set(context.Context, K, V)
+}
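For readers who want to try the new interface out: a minimal sketch of an in-memory `DataCache` implementation. Only the two-method interface above comes from this diff; the type name, constructor, and mutex-guarded map are illustrative assumptions, not part of this PR.

```go
import (
	"context"
	"sync"
)

// mapDataCache is a hypothetical in-memory DataCache backed by a mutex-guarded map.
type mapDataCache[K comparable, V any] struct {
	mu sync.Mutex
	m  map[K]V
}

func newMapDataCache[K comparable, V any]() *mapDataCache[K, V] {
	return &mapDataCache[K, V]{m: make(map[K]V)}
}

// Get reports the cached value for key and whether it was present.
func (c *mapDataCache[K, V]) Get(_ context.Context, key K) (V, bool) {
	c.mu.Lock()
	defer c.mu.Unlock()
	v, ok := c.m[key]
	return v, ok
}

// Set stores the resolved value for key.
func (c *mapDataCache[K, V]) Set(_ context.Context, key K, value V) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.m[key] = value
}
```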
68 changes: 59 additions & 9 deletions dataloader.go
@@ -32,7 +32,7 @@ type Interface[K comparable, V any] interface {
 // It's important that the length of the input keys matches the length of the output results.
 //
 // The keys passed to this function are guaranteed to be unique
-type BatchFunc[K comparable, V any] func(context.Context, []K) []*Result[V]
+type BatchFunc[K comparable, V any] func(context.Context, Keys[K]) []*Result[V]
 
 // Result is the data structure that a BatchFunc returns.
 // It contains the resolved data, and any errors that may have occurred while fetching the data.
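This hunk changes `BatchFunc` to take `Keys[K]` instead of `[]K`. `Keys`, `ContextKey`, and `ContextKeys` are defined elsewhere in this PR; from their usage in this diff (`ContextKey(ctx, key)`, `keys[i].Context()`, `keys[i].Raw()`), they presumably look roughly like the sketch below. The exact shape is an inference, not the PR's code.

```go
import "context"

// Key pairs a raw key with the context of the Load call that produced it
// (shape inferred from usage in this diff).
type Key[K comparable] struct {
	ctx context.Context
	raw K
}

func (k Key[K]) Context() context.Context { return k.ctx }
func (k Key[K]) Raw() K                   { return k.raw }

// Keys is the batch of per-call keys handed to a BatchFunc.
type Keys[K comparable] []Key[K]

// ContextKey builds a Key from a caller's context and raw key.
func ContextKey[K comparable](ctx context.Context, key K) Key[K] {
	return Key[K]{ctx: ctx, raw: key}
}

// ContextKeys wraps a slice of raw keys, all sharing one context.
func ContextKeys[K comparable](ctx context.Context, keys []K) Keys[K] {
	out := make(Keys[K], 0, len(keys))
	for _, k := range keys {
		out = append(out, ContextKey(ctx, k))
	}
	return out
}
```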
@@ -72,6 +72,8 @@ type Loader[K comparable, V any] struct {
 	// implementation could be used as long as it implements the `Cache` interface.
 	cacheLock sync.Mutex
 	cache     Cache[K, V]
+
+	dataCache DataCache[K, V]
 	// should we clear the cache on each batch?
 	// this would allow batching but no long term caching
 	clearCacheOnBatch bool
@@ -112,6 +114,7 @@ type ThunkMany[V any] func() ([]V, []error)
 
 // type used on the input channel
 type batchRequest[K comparable, V any] struct {
+	ctx     context.Context
 	key     K
 	channel chan *Result[V]
 }
@@ -126,6 +129,13 @@ func WithCache[K comparable, V any](c Cache[K, V]) Option[K, V] {
 	}
 }
 
+// WithDataCache sets the Loader's cache for resolved data (as opposed to the thunk cache).
+func WithDataCache[K comparable, V any](c DataCache[K, V]) Option[K, V] {
+	return func(l *Loader[K, V]) {
+		l.dataCache = c
+	}
+}
+
 // WithBatchCapacity sets the batch capacity. Default is 0 (unbounded).
 func WithBatchCapacity[K comparable, V any](c int) Option[K, V] {
 	return func(l *Loader[K, V]) {
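A hedged end-to-end sketch of the new option, reusing the hypothetical `mapDataCache` from above. `NewBatchedLoader`, `Load`, and `Result` are existing API in this library; the import path and `Thunk` being `func() (V, error)` are assumptions.

```go
import (
	"context"

	dataloader "github.com/graph-gophers/dataloader/v7" // assumed import path
)

func example(ctx context.Context) (string, error) {
	batchFn := func(ctx context.Context, keys dataloader.Keys[string]) []*dataloader.Result[string] {
		results := make([]*dataloader.Result[string], len(keys))
		for i, k := range keys {
			// pretend this is one round trip to a database for all keys
			results[i] = &dataloader.Result[string]{Data: "value for " + k.Raw()}
		}
		return results
	}

	loader := dataloader.NewBatchedLoader(
		batchFn,
		dataloader.WithDataCache[string, string](newMapDataCache[string, string]()),
	)

	thunk := loader.Load(ctx, "a") // batched with other Loads in the same window
	return thunk()                 // later hits for "a" are served from the data cache
}
```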
@@ -201,7 +211,7 @@ func NewBatchedLoader[K comparable, V any](batchFn BatchFunc[K, V], opts ...Opti
 // The first context passed to this function within a given batch window will be provided to
 // the registered BatchFunc.
 func (l *Loader[K, V]) Load(originalContext context.Context, key K) Thunk[V] {
-	ctx, finish := l.tracer.TraceLoad(originalContext, key)
+	ctx, finish := l.tracer.TraceLoad(originalContext, ContextKey(originalContext, key))
 
 	c := make(chan *Result[V], 1)
 	var result struct {
@@ -244,12 +254,12 @@ func (l *Loader[K, V]) Load(originalContext context.Context, key K) Thunk[V] {
 
 	// this is sent to batch fn. It contains the key and the channel to return
 	// the result on
-	req := &batchRequest[K, V]{key, c}
+	req := &batchRequest[K, V]{ctx, key, c}
 
 	l.batchLock.Lock()
 	// start the batch window if it hasn't already started.
 	if l.curBatcher == nil {
-		l.curBatcher = l.newBatcher(l.silent, l.tracer)
+		l.curBatcher = l.newBatcher(l.silent, l.tracer, l.dataCache)
 		// start the current batcher batch function
 		go l.curBatcher.batch(originalContext)
 		// start a sleeper for the current batcher
@@ -281,7 +291,7 @@ func (l *Loader[K, V]) Load(originalContext context.Context, key K) Thunk[V] {
 
 // LoadMany loads multiple keys, returning a thunk (type: ThunkMany) that will resolve the keys passed in.
 func (l *Loader[K, V]) LoadMany(originalContext context.Context, keys []K) ThunkMany[V] {
-	ctx, finish := l.tracer.TraceLoadMany(originalContext, keys)
+	ctx, finish := l.tracer.TraceLoadMany(originalContext, ContextKeys(originalContext, keys))
 
 	var (
 		length = len(keys)
@@ -391,16 +401,18 @@ type batcher[K comparable, V any] struct {
 	finished bool
 	silent   bool
 	tracer   Tracer[K, V]
+	cache    DataCache[K, V]
 }
 
 // newBatcher returns a batcher for the current requests
 // all the batcher methods must be protected by a global batchLock
-func (l *Loader[K, V]) newBatcher(silent bool, tracer Tracer[K, V]) *batcher[K, V] {
+func (l *Loader[K, V]) newBatcher(silent bool, tracer Tracer[K, V], cache DataCache[K, V]) *batcher[K, V] {
 	return &batcher[K, V]{
 		input:   make(chan *batchRequest[K, V], l.inputCap),
 		batchFn: l.batchFn,
 		silent:  silent,
 		tracer:  tracer,
+		cache:   cache,
 	}
 }
 
@@ -412,17 +424,51 @@ func (b *batcher[K, V]) end() {
 	}
 }
 
+func batchWithCache[K comparable, V any](ctx context.Context, batchfn BatchFunc[K, V], keys Keys[K], cache DataCache[K, V]) []*Result[V] {
+	result := make([]*Result[V], len(keys))
+	reqKeys := make(Keys[K], 0, len(keys))
+	keyPosition := make(map[int]int, len(keys))
+
+	for i := range keys {
+		val, ok := cache.Get(keys[i].Context(), keys[i].Raw())
+		if ok {
+			result[i] = &Result[V]{Data: val}
+			continue
+		}
+		reqKeys = append(reqKeys, keys[i])
+		keyPosition[len(reqKeys)-1] = i
+	}
+
+	items := batchfn(ctx, reqKeys)
+	for i := range items {
+		reali, ok := keyPosition[i]
+		if !ok {
+			// the batch function returned more items than requested; append the
+			// extras at the end so the length-mismatch error surfaces later
+			result = append(result, items[i])
+			continue
+		}
+
+		result[reali] = items[i]
+
+		if items[i].Error == nil {
+			cache.Set(keys[reali].Context(), keys[reali].Raw(), items[i].Data)
+		}
+	}
+
+	return result
+}
+
 // execute the batch of all items in queue
 func (b *batcher[K, V]) batch(originalContext context.Context) {
 	var (
-		keys     = make([]K, 0)
+		keys     = make(Keys[K], 0)
 		reqs     = make([]*batchRequest[K, V], 0)
 		items    = make([]*Result[V], 0)
 		panicErr interface{}
 	)
 
 	for item := range b.input {
-		keys = append(keys, item.key)
+		keys = append(keys, ContextKey(item.ctx, item.key))
 		reqs = append(reqs, item)
 	}
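To make `batchWithCache`'s index bookkeeping concrete, here is a standalone toy (plain strings, no library types) that runs the same partition-and-restitch pattern: cache hits fill `result` slots up front, `keyPosition` maps positions in the trimmed request slice back to their original slots, and fetched items are stitched back in.

```go
package main

import "fmt"

func main() {
	keys := []string{"a", "b", "c"}
	cached := map[string]string{"b": "vb"} // pretend "b" is a data-cache hit

	result := make([]string, len(keys))
	reqKeys := make([]string, 0, len(keys))
	keyPosition := make(map[int]int, len(keys)) // reqKeys index -> result index

	for i, k := range keys {
		if v, ok := cached[k]; ok {
			result[i] = v // cache hit: fill the slot immediately
			continue
		}
		reqKeys = append(reqKeys, k)
		keyPosition[len(reqKeys)-1] = i
	}

	// stand-in for the real batch function: fetch only the misses
	items := make([]string, len(reqKeys))
	for i, k := range reqKeys {
		items[i] = "v" + k
	}

	// stitch fetched items back into their original positions
	for i := range items {
		result[keyPosition[i]] = items[i]
	}

	fmt.Println(result) // [va vb vc]
}
```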

@@ -442,7 +488,11 @@ func (b *batcher[K, V]) batch(originalContext context.Context) {
 				log.Printf("Dataloader: Panic received in batch function: %v\n%s", panicErr, buf)
 			}
 		}()
-		items = b.batchFn(ctx, keys)
+		if b.cache != nil {
+			items = batchWithCache(ctx, b.batchFn, keys, b.cache)
+		} else {
+			items = b.batchFn(ctx, keys)
+		}
 	}()
 
 	if panicErr != nil {
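Finally, a rough test sketch for the new dispatch path, reusing the assumed import alias and `mapDataCache` from earlier. It checks that a pre-populated data cache short-circuits the batch function (which, with this implementation, still runs once with an empty key list). The helper names and `Thunk` signature are assumptions.

```go
import (
	"context"
	"testing"

	dataloader "github.com/graph-gophers/dataloader/v7" // assumed import path
)

func TestDataCacheShortCircuit(t *testing.T) {
	ctx := context.Background()
	cache := newMapDataCache[string, string]()
	cache.Set(ctx, "a", "cached-a")

	var fetched []string // raw keys that actually reached the batch function
	batchFn := func(ctx context.Context, keys dataloader.Keys[string]) []*dataloader.Result[string] {
		results := make([]*dataloader.Result[string], len(keys))
		for i, k := range keys {
			fetched = append(fetched, k.Raw())
			results[i] = &dataloader.Result[string]{Data: "fetched-" + k.Raw()}
		}
		return results
	}

	loader := dataloader.NewBatchedLoader(batchFn, dataloader.WithDataCache[string, string](cache))

	v, err := loader.Load(ctx, "a")()
	if err != nil || v != "cached-a" {
		t.Fatalf("Load: got (%q, %v), want (%q, nil)", v, err, "cached-a")
	}
	if len(fetched) != 0 {
		t.Fatalf("batch function fetched %v, want no keys", fetched)
	}
}
```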