Skip to content

Commit

Permalink
fix: make checkgroup sequential
Browse files Browse the repository at this point in the history
The checkgroup now behaves sequentially by blocking `Add` if there is
already a subcheck running.
  • Loading branch information
hperl committed Aug 17, 2022
1 parent 8a38df9 commit 1d2a49a
Show file tree
Hide file tree
Showing 11 changed files with 253 additions and 77 deletions.
126 changes: 126 additions & 0 deletions internal/check/bench_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package check_test

import (
"context"
"fmt"
"testing"

"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/ory/keto/internal/check"
"github.com/ory/keto/internal/check/checkgroup"
"github.com/ory/keto/internal/driver/config"
"github.com/ory/keto/internal/namespace"
"github.com/ory/keto/internal/namespace/ast"
)

func wideNamespace(width int) *namespace.Namespace {
wideNS := &namespace.Namespace{
Name: fmt.Sprintf("%d_wide", width),
Relations: []ast.Relation{{Name: "editor"}},
}
viewerRelation := &ast.Relation{
Name: "viewer",
SubjectSetRewrite: &ast.SubjectSetRewrite{
Operation: ast.OperatorOr,
Children: ast.Children{},
},
}
for i := 0; i < width; i++ {
relation := fmt.Sprintf("relation-%d", i)
viewerRelation.SubjectSetRewrite.Children = append(
viewerRelation.SubjectSetRewrite.Children,
&ast.ComputedSubjectSet{Relation: relation},
)
wideNS.Relations = append(wideNS.Relations, ast.Relation{Name: relation})
}
viewerRelation.SubjectSetRewrite.Children = append(
viewerRelation.SubjectSetRewrite.Children,
&ast.ComputedSubjectSet{Relation: "editor"},
)
wideNS.Relations = append(wideNS.Relations, *viewerRelation)

return wideNS
}

func BenchmarkCheckEngine(b *testing.B) {
ctx := context.Background()
var (
depths = []int{2, 4, 8, 16, 32}
widths = []int{10, 20, 40, 80, 100}
maxDepth = depths[len(depths)-1]
)

var namespaces = []*namespace.Namespace{
{Name: "deep",
Relations: []ast.Relation{
{Name: "owner"},
{Name: "editor",
SubjectSetRewrite: &ast.SubjectSetRewrite{
Children: ast.Children{&ast.ComputedSubjectSet{
Relation: "owner"}}}},
{Name: "viewer",
SubjectSetRewrite: &ast.SubjectSetRewrite{
Children: ast.Children{
&ast.ComputedSubjectSet{
Relation: "editor"},
&ast.TupleToSubjectSet{
Relation: "parent",
ComputedSubjectSetRelation: "viewer"}}}},
}},
}

reg := newDepsProvider(b, namespaces)
reg.Logger().Logger.SetLevel(logrus.InfoLevel)

tuples := []string{
"deep:deep_file#parent@deep:folder_1#...",
}
for i := 1; i < maxDepth; i++ {
tuples = append(tuples, fmt.Sprintf("deep:folder_%d#parent@deep:folder_%d#...", i, i+1))
}
for _, d := range depths {
tuples = append(tuples, fmt.Sprintf("deep:folder_%d#owner@user_%d", d, d))
}
for _, w := range widths {
namespaces = append(namespaces, wideNamespace(w))
tuples = append(tuples, fmt.Sprintf("%d-wide:wide_file#editor@user", w))
}
insertFixtures(b, reg.RelationTupleManager(), tuples)

require.NoError(b, reg.Config(ctx).Set(config.KeyLimitMaxReadDepth, 100*maxDepth))
e := check.NewEngine(reg)

b.ResetTimer()
b.Run("case=deep tree", func(b *testing.B) {
for _, depth := range depths {
b.Run(fmt.Sprintf("depth=%03d", depth), func(b *testing.B) {
for i := 0; i < b.N; i++ {
rt := tupleFromString(b, fmt.Sprintf("deep:deep_file#viewer@user_%d", depth))
res := e.CheckRelationTuple(ctx, rt, 2*depth)
assert.NoError(b, res.Err)
if res.Membership != checkgroup.IsMember {
b.Error("user should be able to view 'deep_file'")
}
}
})
}
})

b.Run("case=wide tree", func(b *testing.B) {
for _, width := range widths {
b.Run(fmt.Sprintf("width=%03d", width), func(b *testing.B) {
for i := 0; i < b.N; i++ {
rt := tupleFromString(b, fmt.Sprintf("%d-wide:wide_file#editor@user", width))
res := e.CheckRelationTuple(ctx, rt, 2*width)
assert.NoError(b, res.Err)
if res.Membership != checkgroup.IsMember {
b.Error("user should be able to view 'wide_file'")
}
}
})
}
})
}
8 changes: 2 additions & 6 deletions internal/check/binop.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,8 @@ func or(ctx context.Context, checks []checkgroup.CheckFunc) checkgroup.Result {
childCtx, cancelFn := context.WithCancel(ctx)
defer cancelFn()

pool := checkgroup.PoolFromContext(ctx)
for _, check := range checks {
check := check
pool.Add(func() { check(childCtx, resultCh) })
go check(childCtx, resultCh)
}

for i := 0; i < len(checks); i++ {
Expand All @@ -51,10 +49,8 @@ func and(ctx context.Context, checks []checkgroup.CheckFunc) checkgroup.Result {
childCtx, cancelFn := context.WithCancel(ctx)
defer cancelFn()

pool := checkgroup.PoolFromContext(ctx)
for _, check := range checks {
check := check
pool.Add(func() { check(childCtx, resultCh) })
go check(childCtx, resultCh)
}

tree := &ketoapi.Tree[*relationtuple.RelationTuple]{
Expand Down
30 changes: 14 additions & 16 deletions internal/check/checkgroup/checkgroup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ func TestCheckgroup_cancels(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
g := checkgroup.New(ctx)
g.Add(neverFinishesCheckFunc)
g.Add(neverFinishesCheckFunc)
g.Add(neverFinishesCheckFunc)
g.Add(neverFinishesCheckFunc)
g.Add(neverFinishesCheckFunc)
go g.Add(neverFinishesCheckFunc)
go g.Add(neverFinishesCheckFunc)
go g.Add(neverFinishesCheckFunc)
go g.Add(neverFinishesCheckFunc)
cancel()
assert.Equal(t, checkgroup.Result{Err: context.Canceled}, g.Result())
}
Expand All @@ -60,31 +60,35 @@ func TestCheckgroup_reports_first_result(t *testing.T) {
defer cancel()

g := checkgroup.New(ctx)
g.Add(neverFinishesCheckFunc)
g.Add(notMemberAfterDelayFunc(1 * time.Microsecond))
g.Add(checkgroup.IsMemberFunc)
assert.Equal(t, checkgroup.Result{Membership: checkgroup.IsMember}, g.Result())
}

func TestCheckgroup_cancels_all_other_subchecks(t *testing.T) {
t.Parallel()

wasCancelled := make(chan bool)
wasCalled := false
wasCancelled := false
var mockCheckFn = func(ctx context.Context, resultCh chan<- checkgroup.Result) {
wasCalled = true
<-ctx.Done()
wasCancelled <- true
wasCancelled = true
resultCh <- checkgroup.Result{Err: ctx.Err()}
}

ctx := context.Background()

g := checkgroup.New(ctx)
g.Add(mockCheckFn)
g.Add(neverFinishesCheckFunc)
g.Add(notMemberAfterDelayFunc(1 * time.Microsecond))
g.Add(checkgroup.IsMemberFunc)
go g.Add(mockCheckFn)
result := g.Result()

assert.Equal(t, checkgroup.ResultIsMember, result)
assert.True(t, <-wasCancelled)
if wasCalled {
assert.True(t, wasCancelled)
}
assert.True(t, g.Done())
}

Expand Down Expand Up @@ -143,9 +147,6 @@ func TestCheckgroup_has_no_leaks(t *testing.T) {
checkgroup.UnknownMemberFunc,
isMemberAfterDelayFunc(5 * time.Millisecond),
notMemberAfterDelayFunc(1 * time.Millisecond),
neverFinishesCheckFunc,
neverFinishesCheckFunc,
neverFinishesCheckFunc,
},
expected: checkgroup.ResultIsMember,
},
Expand All @@ -158,9 +159,6 @@ func TestCheckgroup_has_no_leaks(t *testing.T) {
checkgroup.UnknownMemberFunc,
isMemberAfterDelayFunc(5 * time.Millisecond),
notMemberAfterDelayFunc(1 * time.Millisecond),
neverFinishesCheckFunc,
neverFinishesCheckFunc,
neverFinishesCheckFunc,
},
expected: checkgroup.ResultIsMember,
},
Expand Down
55 changes: 35 additions & 20 deletions internal/check/checkgroup/concurrent_checkgroup.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,6 @@ type concurrentCheckgroup struct {
// ctx.Err()}.
ctx context.Context

// pool is the worker pool (or nil if we want unbounded parallel checks),
// derived from the context.
pool Pool

// subcheckCtx is the context used for the subchecks.
subcheckCtx context.Context

Expand All @@ -39,15 +35,19 @@ type concurrentCheckgroup struct {
// result is only written once by the consumer, and can only be read after
// the doneCh channel is closed.
result Result

// reading from reserveCheckCh reserves the right to create a concurrent
// check.
reserveCheckCh chan struct{}
}

func NewConcurrent(ctx context.Context) Checkgroup {
g := &concurrentCheckgroup{
ctx: ctx,
pool: PoolFromContext(ctx),
finalizeCh: make(chan struct{}),
doneCh: make(chan struct{}),
addCheckCh: make(chan CheckFunc),
ctx: ctx,
finalizeCh: make(chan struct{}),
doneCh: make(chan struct{}),
addCheckCh: make(chan CheckFunc),
reserveCheckCh: make(chan struct{}, 1),
}
g.subcheckCtx, g.cancel = context.WithCancel(g.ctx)
g.startConsumer()
Expand All @@ -64,7 +64,7 @@ func (g *concurrentCheckgroup) startConsumer() {
g.startConsumerOnce.Do(func() {
go func() {
var (
subcheckCh = make(chan Result, 1)
resultCh = make(chan Result, 1)
totalChecks = 0
finishedChecks = 0
finalizing = false
Expand All @@ -79,22 +79,26 @@ func (g *concurrentCheckgroup) startConsumer() {
// `context.Canceled`), but we still want to receive these results
// so that there are no dangling goroutines.
defer func() {
go receiveRemaining(subcheckCh, totalChecks-finishedChecks)
go receiveRemaining(resultCh, totalChecks-finishedChecks)
}()

// Start with one reservation available.
g.reserveCheckCh <- struct{}{}

for {
select {
case f := <-g.addCheckCh:
case check := <-g.addCheckCh:
if finalizing {
continue
}
totalChecks++
g.pool.Add(func() { f(g.subcheckCtx, subcheckCh) })
go check(g.subcheckCtx, resultCh)

case <-g.finalizeCh:
if finalizing {
// we're already finalizing
// we don't want to accidentally set the result to ResultNotMember on a second finalize request
// we're already finalizing, so we don't want to
// accidentally set the result to ResultNotMember on a
// second finalize request
continue
}
finalizing = true
Expand All @@ -103,7 +107,7 @@ func (g *concurrentCheckgroup) startConsumer() {
return
}

case result := <-subcheckCh:
case result := <-resultCh:
finishedChecks++
if result.Err != nil || result.Membership == IsMember {
g.result = result
Expand All @@ -115,6 +119,12 @@ func (g *concurrentCheckgroup) startConsumer() {
return
}

// ready for a new check
select {
case g.reserveCheckCh <- struct{}{}:
default:
}

case <-g.subcheckCtx.Done():
g.result = Result{Err: g.ctx.Err()}
return
Expand All @@ -136,7 +146,11 @@ func (g *concurrentCheckgroup) Done() bool {
// Add adds the CheckFunc to the checkgroup and starts running it.
func (g *concurrentCheckgroup) Add(check CheckFunc) {
select {
case g.addCheckCh <- check:
case <-g.reserveCheckCh:
select {
case g.addCheckCh <- check:
case <-g.subcheckCtx.Done():
}
case <-g.subcheckCtx.Done():
}
}
Expand All @@ -146,9 +160,10 @@ func (g *concurrentCheckgroup) SetIsMember() {
g.Add(IsMemberFunc)
}

// tryFinalize tries to set the group state to finalize, i.e, signal the consumer that the result
// was requested and that no more checks will be added. If the consumer is
// already done, finalizing is not necessary anymore. This should never block.
// tryFinalize tries to set the group state to finalize, i.e, signal the
// consumer that the result was requested and that no more checks will be added.
// If the consumer is already done, finalizing is not necessary anymore. This
// should never block.
func (g *concurrentCheckgroup) tryFinalize() {
select {
case g.finalizeCh <- struct{}{}:
Expand Down
4 changes: 4 additions & 0 deletions internal/check/checkgroup/definitions.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ type (
// only be run if there is a free worker available in the pool, thus
// limiting the concurrent workloads in flight.
Add(check func())

// TryAdd tries to add the check function if the pool has capacity.
// Otherwise, it returns false and does not add the check.
TryAdd(check func()) bool
}

workerPool struct {
Expand Down
17 changes: 16 additions & 1 deletion internal/check/checkgroup/workerpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ type (

const poolCtxKey ctxKey = "pool"

// WithPool returns a new context that contains the pool. The pool will be used by the checkgroup and the binary operators (or, and) when spawning subchecks.
// WithPool returns a new context that contains the pool. The pool will be used
// by the checkgroup and the binary operators (or, and) when spawning subchecks.
func WithPool(ctx context.Context, pool Pool) context.Context {
return context.WithValue(ctx, poolCtxKey, pool)
}
Expand Down Expand Up @@ -70,7 +71,21 @@ func (p *workerPool) Add(check func()) {
p.jobs <- check
}

func (p *workerPool) TryAdd(check func()) bool {
select {
case p.jobs <- check:
return true
default:
return false
}
}

// Add on a limitless pool just runs the function in a go routine.
func (p *limitlessPool) Add(check func()) {
go check()
}

func (p *limitlessPool) TryAdd(check func()) bool {
p.Add(check)
return true
}
Loading

0 comments on commit 1d2a49a

Please sign in to comment.