Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions dataloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,9 +254,6 @@ func (l *Loader[K, V]) Load(originalContext context.Context, key K) Thunk[V] {
return v
}

defer l.batchLock.Unlock()
defer l.cacheLock.Unlock()

thunk := func() (V, error) {
<-req.done
result := req.result.Load()
Expand Down Expand Up @@ -294,6 +291,12 @@ func (l *Loader[K, V]) Load(originalContext context.Context, key K) Thunk[V] {
}
}

// NOTE: It is intended that these are not unlocked with a `defer`. This is due to the `defer finish(thunk)` above.
// There is a locking bug where, if you have a tracer that calls the thunk to read the results, the dataloader runs
// into a deadlock scenario, as `finish` is called before these mutexes are free'd on the same goroutine.
l.batchLock.Unlock()
l.cacheLock.Unlock()

return thunk
}

Expand Down
62 changes: 60 additions & 2 deletions dataloader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,22 @@ func TestLoader(t *testing.T) {
}
})

t.Run("test Load method does not create a deadlock mutex condition", func(t *testing.T) {
t.Parallel()

loader, _ := IDLoader(1, WithTracer[string, string](&TracerWithThunkReading[string, string]{}))

value, err := loader.Load(context.Background(), "1")()
if err != nil {
t.Error(err.Error())
}
if value != "1" {
t.Error("load didn't return the right value")
}

// By this function completing, we confirm that there is not a deadlock condition, else the test will hang
})

t.Run("test LoadMany returns errors", func(t *testing.T) {
t.Parallel()
errorLoader, _ := ErrorLoader[string](0)
Expand Down Expand Up @@ -202,6 +218,26 @@ func TestLoader(t *testing.T) {
}
})

t.Run("test LoadMany method does not create a deadlock mutex condition", func(t *testing.T) {
t.Parallel()

loader, _ := IDLoader(1, WithTracer[string, string](&TracerWithThunkReading[string, string]{}))

values, errs := loader.LoadMany(context.Background(), []string{"1", "2", "3"})()
for _, err := range errs {
if err != nil {
t.Error(err.Error())
}
}
for _, value := range values {
if value == "" {
t.Error("unexpected empty value in LoadMany returned")
}
}

// By this function completing, we confirm that there is not a deadlock condition, else the test will hang
})

t.Run("test thunkmany does not contain race conditions", func(t *testing.T) {
t.Parallel()
identityLoader, _ := IDLoader[string](0)
Expand Down Expand Up @@ -590,7 +626,7 @@ func TestLoader(t *testing.T) {
}

// test helpers
func IDLoader[K comparable](max int) (*Loader[K, K], *[][]K) {
func IDLoader[K comparable](max int, options ...Option[K, K]) (*Loader[K, K], *[][]K) {
var mu sync.Mutex
var loadCalls [][]K
identityLoader := NewBatchedLoader(func(_ context.Context, keys []K) []*Result[K] {
Expand All @@ -602,7 +638,7 @@ func IDLoader[K comparable](max int) (*Loader[K, K], *[][]K) {
results = append(results, &Result[K]{key, nil})
}
return results
}, WithBatchCapacity[K, K](max))
}, append([]Option[K, K]{WithBatchCapacity[K, K](max)}, options...)...)
return identityLoader, &loadCalls
}
func BatchOnlyLoader[K comparable](max int) (*Loader[K, K], *[][]K) {
Expand Down Expand Up @@ -788,6 +824,28 @@ func FaultyLoader[K comparable]() (*Loader[K, K], *[][]K) {
return loader, &loadCalls
}

type TracerWithThunkReading[K comparable, V any] struct{}

var _ Tracer[string, struct{}] = (*TracerWithThunkReading[string, struct{}])(nil)

func (_ *TracerWithThunkReading[K, V]) TraceLoad(ctx context.Context, key K) (context.Context, TraceLoadFinishFunc[V]) {
return ctx, func(thunk Thunk[V]) {
_, _ = thunk()
}
}

func (_ *TracerWithThunkReading[K, V]) TraceLoadMany(ctx context.Context, keys []K) (context.Context, TraceLoadManyFinishFunc[V]) {
return ctx, func(thunks ThunkMany[V]) {
_, _ = thunks()
}
}

func (_ *TracerWithThunkReading[K, V]) TraceBatch(ctx context.Context, keys []K) (context.Context, TraceBatchFinishFunc[V]) {
return ctx, func(thunks []*Result[V]) {
//
}
}

/*
Benchmarks
*/
Expand Down