diff --git a/.schema/config.schema.json b/.schema/config.schema.json
index 927d75e27..e31011a84 100644
--- a/.schema/config.schema.json
+++ b/.schema/config.schema.json
@@ -312,6 +312,14 @@
         "description": "The global maximum depth on all read operations. Note that this does not affect how deeply nested the tuples can be. This value can be decreased for a request by a value specified on the request, only if the request-specific value is greater than 1 and less than the global maximum depth.",
         "minimum": 1,
         "maximum": 65535
+      },
+      "max_parallel_checks": {
+        "type": "integer",
+        "default": 100,
+        "title": "Global maximum number of parallel checks",
+        "description": "This is the maximum number of checks that can be in flight in parallel.",
+        "minimum": 1,
+        "maximum": 65535
       }
     },
     "additionalProperties": false
diff --git a/embedx/config.schema.json b/embedx/config.schema.json
index 80677441a..63012454e 100644
--- a/embedx/config.schema.json
+++ b/embedx/config.schema.json
@@ -312,6 +312,14 @@
         "description": "The global maximum depth on all read operations. Note that this does not affect how deeply nested the tuples can be. This value can be decreased for a request by a value specified on the request, only if the request-specific value is greater than 1 and less than the global maximum depth.",
         "minimum": 1,
         "maximum": 65535
+      },
+      "max_parallel_checks": {
+        "type": "integer",
+        "default": 100,
+        "title": "Global maximum number of parallel checks",
+        "description": "This is the maximum number of checks that can be in flight in parallel.",
+        "minimum": 1,
+        "maximum": 65535
       }
     },
     "additionalProperties": false
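For reference, the new key slots into the existing `limit` section of the configuration. Assuming a JSON configuration file (the values below are arbitrary examples; omitting `max_parallel_checks` keeps the default of 100):

```json
{
  "limit": {
    "max_read_depth": 5,
    "max_parallel_checks": 200
  }
}
```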
diff --git a/internal/check/binop.go b/internal/check/binop.go
index 53e79c6b6..60b672edc 100644
--- a/internal/check/binop.go
+++ b/internal/check/binop.go
@@ -21,8 +21,10 @@ func or(ctx context.Context, checks []checkgroup.CheckFunc) checkgroup.Result {
 	childCtx, cancelFn := context.WithCancel(ctx)
 	defer cancelFn()
 
+	pool := checkgroup.PoolFromContext(ctx)
 	for _, check := range checks {
-		go check(childCtx, resultCh)
+		check := check // pin the loop variable for the closure below
+		pool.Add(func() { check(childCtx, resultCh) })
 	}
 
 	for i := 0; i < len(checks); i++ {
@@ -49,8 +51,10 @@ func and(ctx context.Context, checks []checkgroup.CheckFunc) checkgroup.Result {
 	childCtx, cancelFn := context.WithCancel(ctx)
 	defer cancelFn()
 
+	pool := checkgroup.PoolFromContext(ctx)
 	for _, check := range checks {
-		go check(childCtx, resultCh)
+		check := check // pin the loop variable for the closure below
+		pool.Add(func() { check(childCtx, resultCh) })
 	}
 
 	tree := &ketoapi.Tree[*relationtuple.RelationTuple]{
diff --git a/internal/check/checkgroup/checkgroup_test.go b/internal/check/checkgroup/checkgroup_test.go
index df8e78cb4..538b081bf 100644
--- a/internal/check/checkgroup/checkgroup_test.go
+++ b/internal/check/checkgroup/checkgroup_test.go
@@ -2,6 +2,7 @@ package checkgroup_test
 
 import (
 	"context"
+	"fmt"
 	"testing"
 	"time"
 
@@ -90,27 +91,35 @@ func TestCheckgroup_cancels_all_other_subchecks(t *testing.T) {
 func TestCheckgroup_returns_first_successful_is_member(t *testing.T) {
 	t.Parallel()
 
-	ctx := context.Background()
+	for i := 1; i < 5; i++ {
+		t.Run(fmt.Sprintf("workers=%d", i), func(t *testing.T) {
+			ctx, cancel := context.WithCancel(context.Background())
+			defer cancel()
+			ctx = checkgroup.WithPool(ctx, checkgroup.NewPool(
+				checkgroup.WithWorkers(i),
+				checkgroup.WithContext(ctx)))
 
-	g := checkgroup.New(ctx)
-	g.Add(checkgroup.NotMemberFunc)
-	g.Add(checkgroup.NotMemberFunc)
-	time.Sleep(1 * time.Millisecond)
+			g := checkgroup.New(ctx)
+			g.Add(checkgroup.NotMemberFunc)
+			g.Add(checkgroup.NotMemberFunc)
+			time.Sleep(1 * time.Millisecond)
 
-	assert.False(t, g.Done())
+			assert.False(t, g.Done())
 
-	g.Add(func(_ context.Context, resultCh chan<- checkgroup.Result) {
-		resultCh <- checkgroup.ResultIsMember
-	})
+			g.Add(func(_ context.Context, resultCh chan<- checkgroup.Result) {
+				resultCh <- checkgroup.ResultIsMember
+			})
 
-	resultCh := make(chan checkgroup.Result)
-	go g.CheckFunc()(ctx, resultCh)
+			resultCh := make(chan checkgroup.Result)
+			go g.CheckFunc()(ctx, resultCh)
 
-	assert.Equal(t, checkgroup.ResultIsMember, g.Result())
-	assert.Equal(t, checkgroup.ResultIsMember, g.Result())
-	assert.Equal(t, checkgroup.ResultIsMember, g.Result())
-	assert.Equal(t, checkgroup.ResultIsMember, <-resultCh)
-	assert.True(t, g.Done())
+			assert.Equal(t, checkgroup.ResultIsMember, g.Result())
+			assert.Equal(t, checkgroup.ResultIsMember, g.Result())
+			assert.Equal(t, checkgroup.ResultIsMember, g.Result())
+			assert.Equal(t, checkgroup.ResultIsMember, <-resultCh)
+			assert.True(t, g.Done())
+		})
+	}
 }
 
 func TestCheckgroup_returns_immediately_if_nothing_to_check(t *testing.T) {
diff --git a/internal/check/checkgroup/concurrent_checkgroup.go b/internal/check/checkgroup/concurrent_checkgroup.go
index f44bd9695..3f6dcda3f 100644
--- a/internal/check/checkgroup/concurrent_checkgroup.go
+++ b/internal/check/checkgroup/concurrent_checkgroup.go
@@ -12,6 +12,10 @@ type concurrentCheckgroup struct {
 	// ctx.Err()}.
 	ctx context.Context
 
+	// pool is the worker pool derived from the context. If the context has no
+	// pool attached, this is a limitless pool and subchecks are not bounded.
+	pool Pool
+
 	// subcheckCtx is the context used for the subchecks.
 	subcheckCtx context.Context
 
@@ -40,6 +44,7 @@ type concurrentCheckgroup struct {
 func NewConcurrent(ctx context.Context) Checkgroup {
 	g := &concurrentCheckgroup{
 		ctx:        ctx,
+		pool:       PoolFromContext(ctx),
 		finalizeCh: make(chan struct{}),
 		doneCh:     make(chan struct{}),
 		addCheckCh: make(chan CheckFunc),
@@ -84,7 +89,7 @@ func (g *concurrentCheckgroup) startConsumer() {
 				continue
 			}
 			totalChecks++
-			go f(g.subcheckCtx, subcheckCh)
+			g.pool.Add(func() { f(g.subcheckCtx, subcheckCh) })
 
 		case <-g.finalizeCh:
 			if finalizing {
diff --git a/internal/check/checkgroup/definitions.go b/internal/check/checkgroup/definitions.go
index c7dba76e4..da8940b6e 100644
--- a/internal/check/checkgroup/definitions.go
+++ b/internal/check/checkgroup/definitions.go
@@ -11,13 +11,38 @@ import (
 
 type (
 	Checkgroup interface {
+		// Done returns true if a result is available.
		Done() bool
+
+		// Add adds the CheckFunc to the checkgroup and starts running it.
 		Add(check CheckFunc)
+
+		// SetIsMember makes the checkgroup emit "IsMember" directly.
 		SetIsMember()
+
+		// Result returns the result, possibly blocking.
 		Result() Result
+
+		// CheckFunc returns a CheckFunc that writes the result to the result
+		// channel.
 		CheckFunc() CheckFunc
 	}
 
+	Pool interface {
+		// Add adds the function to the pool and schedules it. The function is
+		// only run once a worker in the pool is free, thus limiting the
+		// concurrent workloads in flight.
+		Add(check func())
+	}
+
+	workerPool struct {
+		ctx        context.Context
+		numWorkers int
+		jobs       chan func()
+	}
+
+	limitlessPool struct{}
+
 	Factory = func(ctx context.Context) Checkgroup
 
 	CheckFunc = func(ctx context.Context, resultCh chan<- Result)
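To make the `Pool` contract concrete, here is a minimal sketch (not part of the change; the printed messages are illustrative only). A context without a pool degrades to plain goroutines, while a context carrying a bounded pool caps how many added functions run at once:

```go
package main

import (
	"context"
	"fmt"
	"time"

	"github.com/ory/keto/internal/check/checkgroup"
)

func main() {
	// Without WithPool, PoolFromContext falls back to the limitless pool,
	// so Add is equivalent to spawning a goroutine.
	unbounded := checkgroup.PoolFromContext(context.Background())
	unbounded.Add(func() { fmt.Println("runs on its own goroutine") })

	// With WithPool, the same call site is bounded: at most two of the
	// four jobs below are in flight at any moment.
	ctx := checkgroup.WithPool(context.Background(),
		checkgroup.NewPool(checkgroup.WithWorkers(2)))
	for i := 0; i < 4; i++ {
		i := i // pin the loop variable for the closure
		checkgroup.PoolFromContext(ctx).Add(func() {
			fmt.Println("job", i, "on one of 2 workers")
		})
	}

	time.Sleep(10 * time.Millisecond) // crude wait; fine for a sketch
}
```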
diff --git a/internal/check/checkgroup/workerpool.go b/internal/check/checkgroup/workerpool.go
new file mode 100644
index 000000000..9dd7e8553
--- /dev/null
+++ b/internal/check/checkgroup/workerpool.go
@@ -0,0 +1,80 @@
+package checkgroup
+
+import "context"
+
+type (
+	PoolOption func(*workerPool)
+	ctxKey     string
+)
+
+const poolCtxKey ctxKey = "pool"
+
+// WithPool returns a new context that contains the pool. The pool will be used
+// by the checkgroup and the binary operators (or, and) when spawning subchecks.
+func WithPool(ctx context.Context, pool Pool) context.Context {
+	return context.WithValue(ctx, poolCtxKey, pool)
+}
+
+// PoolFromContext returns the pool from the context, or a pool that does not
+// limit the number of parallel jobs if none is found.
+func PoolFromContext(ctx context.Context) Pool {
+	if p, ok := ctx.Value(poolCtxKey).(Pool); ok {
+		return p
+	}
+	return new(limitlessPool)
+}
+
+// NewPool creates a new worker pool. With no options, this yields a pool with
+// exactly one worker, meaning that all tasks that are added will run
+// sequentially.
+func NewPool(opts ...PoolOption) Pool {
+	pool := &workerPool{
+		numWorkers: 1,
+	}
+	for _, opt := range opts {
+		opt(pool)
+	}
+
+	pool.jobs = make(chan func(), pool.numWorkers)
+	for i := 0; i < pool.numWorkers; i++ {
+		go worker(pool.jobs)
+	}
+
+	if pool.ctx != nil {
+		go func() {
+			<-pool.ctx.Done()
+			close(pool.jobs)
+		}()
+	}
+
+	return pool
+}
+
+// worker runs queued jobs until the jobs channel is closed.
+func worker(jobs <-chan func()) {
+	for job := range jobs {
+		job()
+	}
+}
+
+// WithWorkers sets the number of workers in the pool.
+func WithWorkers(count int) PoolOption {
+	return func(p *workerPool) { p.numWorkers = count }
+}
+
+// WithContext binds the lifetime of the pool's workers to ctx.
+func WithContext(ctx context.Context) PoolOption {
+	return func(p *workerPool) { p.ctx = ctx }
+}
+
+// Add adds the function to the pool and schedules it. The function is only
+// run once a worker in the pool is free, thus limiting the concurrent
+// workloads in flight.
+func (p *workerPool) Add(check func()) {
+	p.jobs <- check
+}
+
+// Add on a limitless pool just runs the function in a goroutine.
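Two properties of this implementation are worth spelling out. First, the default pool has exactly one worker, so tasks added to it run sequentially in FIFO order; a minimal sketch (not part of the change):

```go
package main

import (
	"fmt"
	"sync"

	"github.com/ory/keto/internal/check/checkgroup"
)

func main() {
	var wg sync.WaitGroup

	// NewPool without options: one worker, strictly sequential execution.
	p := checkgroup.NewPool()
	for i := 0; i < 3; i++ {
		i := i // pin the loop variable for the closure
		wg.Add(1)
		p.Add(func() {
			defer wg.Done()
			fmt.Println("task", i) // prints task 0, task 1, task 2 in order
		})
	}
	wg.Wait()
}
```

Second, once the context passed via `WithContext` is canceled, the jobs channel is closed and the workers exit; a later `Add` on that pool would panic with a send on a closed channel. Callers are therefore expected to scope the pool's context to outlive all producers, which is why `NewEngine` below builds its pool without a context and the goleak exception in `testmain_test.go` tolerates the long-lived workers.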
+func (p *limitlessPool) Add(check func()) {
+	go check()
+}
diff --git a/internal/check/checkgroup/workerpool_test.go b/internal/check/checkgroup/workerpool_test.go
new file mode 100644
index 000000000..0d518efb2
--- /dev/null
+++ b/internal/check/checkgroup/workerpool_test.go
@@ -0,0 +1,40 @@
+package checkgroup_test
+
+import (
+	"context"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/ory/keto/internal/check/checkgroup"
+)
+
+func TestPool(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	numWorkers := 5
+	p := checkgroup.NewPool(
+		checkgroup.WithWorkers(numWorkers),
+		checkgroup.WithContext(ctx),
+	)
+
+	var (
+		jobsCount int32
+		wg        sync.WaitGroup
+	)
+
+	wg.Add(1000)
+	for i := 0; i < 1000; i++ {
+		p.Add(func() {
+			defer wg.Done()
+			if jobs := atomic.AddInt32(&jobsCount, 1); jobs > int32(numWorkers) {
+				t.Errorf("%d jobs in flight, more than %d", jobs, numWorkers)
+			}
+			time.Sleep(1 * time.Millisecond)
+			atomic.AddInt32(&jobsCount, -1)
+		})
+	}
+	wg.Wait()
+}
diff --git a/internal/check/engine.go b/internal/check/engine.go
index f1e3de51c..37c31e41b 100644
--- a/internal/check/engine.go
+++ b/internal/check/engine.go
@@ -21,7 +21,8 @@ type (
 		PermissionEngine() *Engine
 	}
 	Engine struct {
-		d EngineDependencies
+		d    EngineDependencies
+		pool checkgroup.Pool
 	}
 	EngineDependencies interface {
 		relationtuple.ManagerProvider
@@ -39,6 +40,9 @@ const WildcardRelation = "..."
 func NewEngine(d EngineDependencies) *Engine {
 	return &Engine{
 		d: d,
+		pool: checkgroup.NewPool(
+			checkgroup.WithWorkers(d.Config(context.Background()).MaxParallelChecks()),
+		),
 	}
 }
 
@@ -63,6 +67,8 @@ func (e *Engine) CheckRelationTuple(ctx context.Context, r *relationTuple, restD
 		restDepth = globalMaxDepth
 	}
 
+	ctx = checkgroup.WithPool(ctx, e.pool)
+
 	resultCh := make(chan checkgroup.Result)
 	go e.checkIsAllowed(ctx, r, restDepth)(ctx, resultCh)
 	select {
diff --git a/internal/check/testmain_test.go b/internal/check/testmain_test.go
index 9f9f271d5..dd314464e 100644
--- a/internal/check/testmain_test.go
+++ b/internal/check/testmain_test.go
@@ -9,6 +9,8 @@ import (
 func TestMain(m *testing.M) {
 	goleak.VerifyTestMain(m,
 		goleak.IgnoreCurrent(),
+		// the workers of the fixed-size worker pool live for the lifetime of the process:
+		goleak.IgnoreTopFunction("github.com/ory/keto/internal/check/checkgroup.worker"),
 		goleak.IgnoreTopFunction("net/http.(*persistConn).readLoop"),
 		goleak.IgnoreTopFunction("net/http.(*persistConn).writeLoop"),
 	)
diff --git a/internal/driver/config/provider.go b/internal/driver/config/provider.go
index 886a96be4..536a72608 100644
--- a/internal/driver/config/provider.go
+++ b/internal/driver/config/provider.go
@@ -25,9 +25,10 @@ import (
 const (
 	KeyDSN = "dsn"
 
-	KeyLimitMaxReadDepth = "limit.max_read_depth"
-	KeyReadAPIHost       = "serve.read.host"
-	KeyReadAPIPort       = "serve.read.port"
+	KeyLimitMaxReadDepth      = "limit.max_read_depth"
+	KeyLimitMaxParallelChecks = "limit.max_parallel_checks"
+	KeyReadAPIHost            = "serve.read.host"
+	KeyReadAPIPort            = "serve.read.port"
 
 	KeyWriteAPIHost = "serve.write.host"
 	KeyWriteAPIPort = "serve.write.port"
@@ -161,6 +162,10 @@ func (k *Config) MaxReadDepth() int {
 	return k.p.Int(KeyLimitMaxReadDepth)
 }
 
+func (k *Config) MaxParallelChecks() int {
+	return k.p.Int(KeyLimitMaxParallelChecks)
+}
+
 func (k *Config) WriteAPIListenOn() string {
 	return fmt.Sprintf(
 		"%s:%d",
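Taken together: `NewEngine` sizes a single pool from `limit.max_parallel_checks` at construction time (the value is read once, so a changed limit only applies to a newly constructed engine), and `CheckRelationTuple` injects that shared pool into every request context. The limit is thus global across all concurrent check requests rather than per request. A condensed sketch of the flow, using the names from this diff; `myCheck` is a hypothetical stand-in for a real subcheck:

```go
package main

import (
	"context"
	"fmt"

	"github.com/ory/keto/internal/check/checkgroup"
)

// myCheck is a hypothetical CheckFunc used only for illustration.
func myCheck(_ context.Context, resultCh chan<- checkgroup.Result) {
	resultCh <- checkgroup.ResultIsMember
}

func main() {
	// One pool for the whole process, as NewEngine does via
	// MaxParallelChecks(); hard-coded to 100 here.
	pool := checkgroup.NewPool(checkgroup.WithWorkers(100))

	// Per request, as CheckRelationTuple does: attach the shared pool to
	// the request context, then run subchecks through a checkgroup.
	ctx := checkgroup.WithPool(context.Background(), pool)
	g := checkgroup.New(ctx)
	g.Add(myCheck)
	g.Add(myCheck)
	fmt.Println(g.Result()) // blocks until a result is available
}
```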