Commit 8a38df9
feat: add a worker pool to limit parallel checks
hperl committed Aug 11, 2022
1 parent 451fbd0

Showing 11 changed files with 211 additions and 23 deletions.
8 changes: 8 additions & 0 deletions .schema/config.schema.json
@@ -312,6 +312,14 @@
         "description": "The global maximum depth on all read operations. Note that this does not affect how deeply nested the tuples can be. This value can be decreased for a request by a value specified on the request, only if the request-specific value is greater than 1 and less than the global maximum depth.",
         "minimum": 1,
         "maximum": 65535
+      },
+      "max_parallel_checks": {
+        "type": "integer",
+        "default": 100,
+        "title": "Global maximum number of parallel checks",
+        "description": "This is the maximum number of checks that can be in flight in parallel.",
+        "minimum": 1,
+        "maximum": 65535
       }
     },
     "additionalProperties": false
8 changes: 8 additions & 0 deletions embedx/config.schema.json
@@ -312,6 +312,14 @@
         "description": "The global maximum depth on all read operations. Note that this does not affect how deeply nested the tuples can be. This value can be decreased for a request by a value specified on the request, only if the request-specific value is greater than 1 and less than the global maximum depth.",
         "minimum": 1,
         "maximum": 65535
+      },
+      "max_parallel_checks": {
+        "type": "integer",
+        "default": 100,
+        "title": "Global maximum number of parallel checks",
+        "description": "This is the maximum number of checks that can be in flight in parallel.",
+        "minimum": 1,
+        "maximum": 65535
       }
     },
     "additionalProperties": false
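For orientation, a sketch of how this option would appear in a Keto configuration file. The fragment is hypothetical (the file name and the max_read_depth value are illustrative); the key path limit.max_parallel_checks is taken from the provider.go change below, and 100 is the default declared by the schema:

    # keto.yaml (hypothetical fragment)
    limit:
      max_read_depth: 5         # existing option, shown for context
      max_parallel_checks: 100  # added here: checks in flight, 1-65535, default 100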
8 changes: 6 additions & 2 deletions internal/check/binop.go
@@ -21,8 +21,10 @@ func or(ctx context.Context, checks []checkgroup.CheckFunc) checkgroup.Result {
     childCtx, cancelFn := context.WithCancel(ctx)
     defer cancelFn()
 
+    pool := checkgroup.PoolFromContext(ctx)
     for _, check := range checks {
-        go check(childCtx, resultCh)
+        check := check
+        pool.Add(func() { check(childCtx, resultCh) })
     }
 
     for i := 0; i < len(checks); i++ {
@@ -49,8 +51,10 @@ func and(ctx context.Context, checks []checkgroup.CheckFunc) checkgroup.Result {
     childCtx, cancelFn := context.WithCancel(ctx)
     defer cancelFn()
 
+    pool := checkgroup.PoolFromContext(ctx)
     for _, check := range checks {
-        go check(childCtx, resultCh)
+        check := check
+        pool.Add(func() { check(childCtx, resultCh) })
     }
 
     tree := &ketoapi.Tree[*relationtuple.RelationTuple]{
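A note on the check := check lines above: this commit predates Go 1.22, where a for ... range loop reused a single loop variable, so without the rebind every closure handed to pool.Add would capture the same variable and would likely all run the last check. A self-contained sketch of the pitfall (illustrative, not code from this commit):

    package main

    import "fmt"

    func main() {
        fns := make([]func(), 0, 3)
        for _, s := range []string{"a", "b", "c"} {
            s := s // rebind so each closure captures its own copy (pre-Go 1.22)
            fns = append(fns, func() { fmt.Println(s) })
        }
        for _, f := range fns {
            f() // prints a, b, c; without the rebind: c, c, c
        }
    }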
41 changes: 25 additions & 16 deletions internal/check/checkgroup/checkgroup_test.go
Original file line number Diff line number Diff line change
@@ -2,6 +2,7 @@ package checkgroup_test
 
 import (
     "context"
+    "fmt"
     "testing"
     "time"
 
@@ -90,27 +91,35 @@ func TestCheckgroup_cancels_all_other_subchecks(t *testing.T) {
 func TestCheckgroup_returns_first_successful_is_member(t *testing.T) {
     t.Parallel()
 
-    ctx := context.Background()
+    for i := 1; i < 5; i++ {
+        t.Run(fmt.Sprintf("workers=%d", i), func(t *testing.T) {
+            ctx, cancel := context.WithCancel(context.Background())
+            defer cancel()
+            ctx = checkgroup.WithPool(ctx, checkgroup.NewPool(
+                checkgroup.WithWorkers(i),
+                checkgroup.WithContext(ctx)))
 
-    g := checkgroup.New(ctx)
-    g.Add(checkgroup.NotMemberFunc)
-    g.Add(checkgroup.NotMemberFunc)
-    time.Sleep(1 * time.Millisecond)
+            g := checkgroup.New(ctx)
+            g.Add(checkgroup.NotMemberFunc)
+            g.Add(checkgroup.NotMemberFunc)
+            time.Sleep(1 * time.Millisecond)
 
-    assert.False(t, g.Done())
+            assert.False(t, g.Done())
 
-    g.Add(func(_ context.Context, resultCh chan<- checkgroup.Result) {
-        resultCh <- checkgroup.ResultIsMember
-    })
+            g.Add(func(_ context.Context, resultCh chan<- checkgroup.Result) {
+                resultCh <- checkgroup.ResultIsMember
+            })
 
-    resultCh := make(chan checkgroup.Result)
-    go g.CheckFunc()(ctx, resultCh)
+            resultCh := make(chan checkgroup.Result)
+            go g.CheckFunc()(ctx, resultCh)
 
-    assert.Equal(t, checkgroup.ResultIsMember, g.Result())
-    assert.Equal(t, checkgroup.ResultIsMember, g.Result())
-    assert.Equal(t, checkgroup.ResultIsMember, g.Result())
-    assert.Equal(t, checkgroup.ResultIsMember, <-resultCh)
-    assert.True(t, g.Done())
+            assert.Equal(t, checkgroup.ResultIsMember, g.Result())
+            assert.Equal(t, checkgroup.ResultIsMember, g.Result())
+            assert.Equal(t, checkgroup.ResultIsMember, g.Result())
+            assert.Equal(t, checkgroup.ResultIsMember, <-resultCh)
+            assert.True(t, g.Done())
+        })
+    }
 }
 
 func TestCheckgroup_returns_immediately_if_nothing_to_check(t *testing.T) {
7 changes: 6 additions & 1 deletion internal/check/checkgroup/concurrent_checkgroup.go
@@ -12,6 +12,10 @@ type concurrentCheckgroup struct {
     // ctx.Err()}.
     ctx context.Context
 
+    // pool is the worker pool (or nil if we want unbounded parallel checks),
+    // derived from the context.
+    pool Pool
+
     // subcheckCtx is the context used for the subchecks.
     subcheckCtx context.Context
 
@@ -40,6 +44,7 @@ type concurrentCheckgroup struct {
 func NewConcurrent(ctx context.Context) Checkgroup {
     g := &concurrentCheckgroup{
         ctx:        ctx,
+        pool:       PoolFromContext(ctx),
         finalizeCh: make(chan struct{}),
         doneCh:     make(chan struct{}),
         addCheckCh: make(chan CheckFunc),
@@ -84,7 +89,7 @@ func (g *concurrentCheckgroup) startConsumer() {
                 continue
             }
             totalChecks++
-            go f(g.subcheckCtx, subcheckCh)
+            g.pool.Add(func() { f(g.subcheckCtx, subcheckCh) })
 
         case <-g.finalizeCh:
             if finalizing {
25 changes: 25 additions & 0 deletions internal/check/checkgroup/definitions.go
@@ -11,13 +11,38 @@ import (
 
 type (
     Checkgroup interface {
+        // Done returns true if a result is available.
         Done() bool
+
+        // Add adds the CheckFunc to the checkgroup and starts running it.
         Add(check CheckFunc)
+
+        // SetIsMember makes the checkgroup emit "IsMember" directly.
         SetIsMember()
+
+        // Result returns the result, possibly blocking.
         Result() Result
+
+        // CheckFunc returns a CheckFunc that writes the result to the result
+        // channel.
         CheckFunc() CheckFunc
     }
 
+    Pool interface {
+        // Add adds the function to the pool and schedules it. The function will
+        // only be run if there is a free worker available in the pool, thus
+        // limiting the concurrent workloads in flight.
+        Add(check func())
+    }
+
+    workerPool struct {
+        ctx        context.Context
+        numWorkers int
+        jobs       chan func()
+    }
+
+    limitlessPool struct{}
+
     Factory = func(ctx context.Context) Checkgroup
 
     CheckFunc = func(ctx context.Context, resultCh chan<- Result)
76 changes: 76 additions & 0 deletions internal/check/checkgroup/workerpool.go
@@ -0,0 +1,76 @@
+package checkgroup
+
+import "context"
+
+type (
+    PoolOption func(*workerPool)
+    ctxKey     string
+)
+
+const poolCtxKey ctxKey = "pool"
+
+// WithPool returns a new context that contains the pool. The pool will be used by the checkgroup and the binary operators (or, and) when spawning subchecks.
+func WithPool(ctx context.Context, pool Pool) context.Context {
+    return context.WithValue(ctx, poolCtxKey, pool)
+}
+
+// PoolFromContext returns the pool from the context, or a pool that does not
+// limit the number of parallel jobs if none found.
+func PoolFromContext(ctx context.Context) Pool {
+    if p, ok := ctx.Value(poolCtxKey).(*workerPool); !ok {
+        return new(limitlessPool)
+    } else {
+        return p
+    }
+}
+
+// NewPool creates a new worker pool. With no options, this yields a pool with
+// exactly one worker, meaning that all tasks that are added will run
+// sequentially.
+func NewPool(opts ...PoolOption) Pool {
+    pool := &workerPool{
+        numWorkers: 1,
+    }
+    for _, opt := range opts {
+        opt(pool)
+    }
+
+    pool.jobs = make(chan func(), pool.numWorkers)
+    for i := 0; i < pool.numWorkers; i++ {
+        go worker(pool.jobs)
+    }
+
+    if pool.ctx != nil {
+        go func() {
+            <-pool.ctx.Done()
+            close(pool.jobs)
+        }()
+    }
+
+    return pool
+}
+
+func worker(jobs <-chan func()) {
+    for job := range jobs {
+        job()
+    }
+}
+
+func WithWorkers(count int) PoolOption {
+    return func(p *workerPool) { p.numWorkers = count }
+}
+func WithContext(ctx context.Context) PoolOption {
+    return func(p *workerPool) { p.ctx = ctx }
+}
+
+// Add adds the function to the pool and schedules it. The function will only be
+// run if there is a free worker available in the pool, thus limiting the
+// concurrent workloads in flight.
+func (p *workerPool) Add(check func()) {
+    p.jobs <- check
+}
+
+// Add on a limitless pool just runs the function in a go routine.
+func (p *limitlessPool) Add(check func()) {
+    go check()
+}
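Putting this file's API together: construct a pool, attach it to a context, and let downstream code fetch it. A minimal usage sketch against the functions added above (the function name and worker count are illustrative; assumes the surrounding file imports context and the checkgroup package):

    // exampleBoundedChecks bounds all jobs spawned under ctx to 10 workers.
    func exampleBoundedChecks() {
        ctx, cancel := context.WithCancel(context.Background())
        defer cancel()

        pool := checkgroup.NewPool(
            checkgroup.WithWorkers(10),
            checkgroup.WithContext(ctx), // workers exit once ctx is done
        )
        ctx = checkgroup.WithPool(ctx, pool)

        // Callers further down need no knowledge of the pool: if none was
        // attached, PoolFromContext returns a limitless pool that simply
        // runs each job in its own goroutine.
        checkgroup.PoolFromContext(ctx).Add(func() {
            // bounded work here
        })
    }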
40 changes: 40 additions & 0 deletions internal/check/checkgroup/workerpool_test.go
@@ -0,0 +1,40 @@
+package checkgroup_test
+
+import (
+    "context"
+    "sync"
+    "sync/atomic"
+    "testing"
+    "time"
+
+    "github.com/ory/keto/internal/check/checkgroup"
+)
+
+func TestPool(t *testing.T) {
+    ctx, cancel := context.WithCancel(context.Background())
+    defer cancel()
+
+    numWorkers := 5
+    p := checkgroup.NewPool(
+        checkgroup.WithWorkers(numWorkers),
+        checkgroup.WithContext(ctx),
+    )
+
+    var (
+        jobsCount int32
+        wg        sync.WaitGroup
+    )
+
+    wg.Add(1000)
+    for i := 0; i < 1000; i++ {
+        p.Add(func() {
+            defer wg.Done()
+            if jobs := atomic.AddInt32(&jobsCount, 1); jobs > int32(numWorkers) {
+                t.Errorf("%d jobs in flight, more than %d", jobs, numWorkers)
+            }
+            time.Sleep(1 * time.Millisecond)
+            atomic.AddInt32(&jobsCount, -1)
+        })
+    }
+    wg.Wait()
+}
8 changes: 7 additions & 1 deletion internal/check/engine.go
@@ -21,7 +21,8 @@ type (
         PermissionEngine() *Engine
     }
     Engine struct {
-        d EngineDependencies
+        d    EngineDependencies
+        pool checkgroup.Pool
     }
     EngineDependencies interface {
         relationtuple.ManagerProvider
@@ -39,6 +40,9 @@ const WildcardRelation = "..."
 func NewEngine(d EngineDependencies) *Engine {
     return &Engine{
         d: d,
+        pool: checkgroup.NewPool(
+            checkgroup.WithWorkers(d.Config(context.Background()).MaxParallelChecks()),
+        ),
     }
 }
 
@@ -63,6 +67,8 @@ func (e *Engine) CheckRelationTuple(ctx context.Context, r *relationTuple, restD
         restDepth = globalMaxDepth
     }
 
+    ctx = checkgroup.WithPool(ctx, e.pool)
+
     resultCh := make(chan checkgroup.Result)
     go e.checkIsAllowed(ctx, r, restDepth)(ctx, resultCh)
     select {
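Design note: the pool is created once per Engine and sized from limit.max_parallel_checks at construction time, then attached to the context of every CheckRelationTuple call. The limit is therefore shared globally across all concurrent check requests, matching the "global maximum" wording of the schema description, rather than being applied per request.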
2 changes: 2 additions & 0 deletions internal/check/testmain_test.go
@@ -9,6 +9,8 @@ import (
 func TestMain(m *testing.M) {
     goleak.VerifyTestMain(m,
         goleak.IgnoreCurrent(),
+        // fixed-size worker pool:
+        goleak.IgnoreTopFunction("github.com/ory/keto/internal/check/checkgroup.worker"),
         goleak.IgnoreTopFunction("net/http.(*persistConn).readLoop"),
         goleak.IgnoreTopFunction("net/http.(*persistConn).writeLoop"),
     )
11 changes: 8 additions & 3 deletions internal/driver/config/provider.go
@@ -25,9 +25,10 @@ import (
 const (
     KeyDSN = "dsn"
 
-    KeyLimitMaxReadDepth = "limit.max_read_depth"
-    KeyReadAPIHost       = "serve.read.host"
-    KeyReadAPIPort       = "serve.read.port"
+    KeyLimitMaxReadDepth      = "limit.max_read_depth"
+    KeyLimitMaxParallelChecks = "limit.max_parallel_checks"
+    KeyReadAPIHost            = "serve.read.host"
+    KeyReadAPIPort            = "serve.read.port"
 
     KeyWriteAPIHost = "serve.write.host"
     KeyWriteAPIPort = "serve.write.port"
@@ -161,6 +162,10 @@ func (k *Config) MaxReadDepth() int {
     return k.p.Int(KeyLimitMaxReadDepth)
 }
 
+func (k *Config) MaxParallelChecks() int {
+    return k.p.Int(KeyLimitMaxParallelChecks)
+}
+
 func (k *Config) WriteAPIListenOn() string {
     return fmt.Sprintf(
         "%s:%d",
