Skip to content

Commit

Permalink
feat: add sequential checkgroup
Browse files Browse the repository at this point in the history
  • Loading branch information
hperl committed May 31, 2022
1 parent 47bdd2d commit 0567cfe
Show file tree
Hide file tree
Showing 4 changed files with 204 additions and 78 deletions.
135 changes: 82 additions & 53 deletions internal/check/checkgroup/checkgroup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,53 @@ import (
"github.com/ory/keto/internal/check/checkgroup"
)

var neverFinishesCheckFn checkgroup.Func = func(context.Context, chan<- checkgroup.Result) {}
var neverFinishesCheckFn checkgroup.Func = func(ctx context.Context, resultCh chan<- checkgroup.Result) {
<-ctx.Done()
resultCh <- checkgroup.Result{Err: ctx.Err()}
}

var checkgroups = []struct {
name string
new checkgroup.Factory
}{
{name: "sequential", new: checkgroup.NewSequential},
{name: "concurrent", new: checkgroup.NewConcurrent},
}

func runWithCheckgroup(t *testing.T, test func(t *testing.T, new checkgroup.Factory)) {
for _, group := range checkgroups {
group := group
t.Run(group.name, func(t *testing.T) {
t.Parallel()
test(t, group.new)
})
}
}

func TestCheckgroup_cancels(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithCancel(context.Background())
g := checkgroup.New(ctx)
g.Add(neverFinishesCheckFn)
cancel()
assert.Equal(t, checkgroup.Result{Err: context.Canceled}, g.Result())
runWithCheckgroup(t, func(t *testing.T, new checkgroup.Factory) {
ctx, cancel := context.WithCancel(context.Background())
g := new(ctx)
g.Add(neverFinishesCheckFn)
cancel()
assert.Equal(t, checkgroup.Result{Err: context.Canceled}, g.Result())
})
}

func TestCheckgroup_reports_first_result(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
runWithCheckgroup(t, func(t *testing.T, new checkgroup.Factory) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

g := checkgroup.New(ctx)
g.Add(neverFinishesCheckFn)
g.Add(checkgroup.IsMemberFunc)
assert.Equal(t, checkgroup.Result{Membership: checkgroup.IsMember}, g.Result())
g := new(ctx)
g.Add(neverFinishesCheckFn)
g.Add(checkgroup.IsMemberFunc)
assert.Equal(t, checkgroup.Result{Membership: checkgroup.IsMember}, g.Result())
})
}

func TestCheckgroup_cancels_all_other_subchecks(t *testing.T) {
Expand All @@ -45,68 +70,72 @@ func TestCheckgroup_cancels_all_other_subchecks(t *testing.T) {

ctx := context.Background()

g := checkgroup.New(ctx)
g := checkgroup.NewConcurrent(ctx)
g.Add(neverFinishesCheckFn)
g.Add(checkgroup.IsMemberFunc)
g.Add(mockCheckFn)
g.Result()

assert.True(t, <-wasCancelled)
assert.NotNil(t, <-g.Ctx.Done())
assert.True(t, g.Done())
}

func TestCheckgroup_returns_first_successful_is_member(t *testing.T) {
t.Parallel()

ctx := context.Background()
runWithCheckgroup(t, func(t *testing.T, new checkgroup.Factory) {
ctx := context.Background()

g := new(ctx)
g.Add(neverFinishesCheckFn)
g.Add(checkgroup.NotMemberFunc)
g.Add(checkgroup.NotMemberFunc)
time.Sleep(1 * time.Millisecond)
assert.False(t, g.Done())
g.Add(func(_ context.Context, resultCh chan<- checkgroup.Result) {
time.Sleep(10 * time.Millisecond)
resultCh <- checkgroup.ResultIsMember
})

g := checkgroup.New(ctx)
g.Add(neverFinishesCheckFn)
g.Add(checkgroup.NotMemberFunc)
g.Add(checkgroup.NotMemberFunc)
time.Sleep(1 * time.Millisecond)
assert.False(t, g.Done())
g.Add(func(_ context.Context, resultCh chan<- checkgroup.Result) {
time.Sleep(10 * time.Millisecond)
resultCh <- checkgroup.ResultIsMember
assert.Equal(t, checkgroup.Result{Membership: checkgroup.IsMember}, g.Result())
assert.True(t, g.Done())
})

assert.Equal(t, checkgroup.Result{Membership: checkgroup.IsMember}, g.Result())
assert.NotNil(t, <-g.Ctx.Done())
assert.True(t, g.Done())
}

func TestCheckgroup_returns_immediately_if_nothing_to_check(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
runWithCheckgroup(t, func(t *testing.T, new checkgroup.Factory) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

g := checkgroup.New(ctx)
assert.Equal(t, checkgroup.ResultNotMember, g.Result())
g := new(ctx)
assert.Equal(t, checkgroup.ResultNotMember, g.Result())
})
}

func TestCheckgroup_propagates_not_member_results(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

g := checkgroup.New(ctx)
for i := 0; i < 100; i++ {
i := i
g.Add(func(ctx context.Context, resultCh chan<- checkgroup.Result) {
select {
case <-time.After(time.Duration(i) * time.Millisecond):
resultCh <- checkgroup.ResultNotMember
case <-ctx.Done():
resultCh <- checkgroup.Result{Err: context.Canceled}
}
})
}

resultCh := make(chan checkgroup.Result)
go g.CheckFunc()(ctx, resultCh)
result := <-resultCh

assert.Equal(t, checkgroup.ResultNotMember, result)
runWithCheckgroup(t, func(t *testing.T, new checkgroup.Factory) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

g := new(ctx)
for i := 0; i < 10; i++ {
i := i
g.Add(func(ctx context.Context, resultCh chan<- checkgroup.Result) {
select {
case <-time.After(time.Duration(i) * time.Millisecond):
resultCh <- checkgroup.ResultNotMember
case <-ctx.Done():
resultCh <- checkgroup.Result{Err: context.Canceled}
}
})
}

resultCh := make(chan checkgroup.Result)
go g.CheckFunc()(ctx, resultCh)
result := <-resultCh

assert.Equal(t, checkgroup.ResultNotMember, result)
})
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ import (
"sync"
)

// A Checkgroup is a collection of goroutines performing checks.
type Checkgroup struct {
Ctx context.Context
// A concurrentCheckgroup is a collection of goroutines performing checks.
type concurrentCheckgroup struct {
ctx context.Context

cancel context.CancelFunc
resultCh chan Result
Expand All @@ -20,32 +20,32 @@ type Checkgroup struct {
}
}

func New(ctx context.Context) *Checkgroup {
return &Checkgroup{Ctx: ctx}
func NewConcurrent(ctx context.Context) Checkgroup {
return &concurrentCheckgroup{ctx: ctx}
}

func (g *Checkgroup) incrementRunningCheckCount() {
func (g *concurrentCheckgroup) incrementRunningCheckCount() {
g.counts.Lock()
defer g.counts.Unlock()
g.counts.totalChecks++
}
func (g *Checkgroup) incrementFinishedCheckCount() {
func (g *concurrentCheckgroup) incrementFinishedCheckCount() {
g.counts.Lock()
defer g.counts.Unlock()
g.counts.finishedChecks++
}

func (g *Checkgroup) allCheckFinished() bool {
func (g *concurrentCheckgroup) allCheckFinished() bool {
g.counts.RLock()
defer g.counts.RUnlock()
return g.counts.totalChecks == g.counts.finishedChecks
}

func (g *Checkgroup) startConsumer() {
func (g *concurrentCheckgroup) startConsumer() {
g.once.Do(func() {
g.subcheckCh = make(chan Result)
g.resultCh = make(chan Result)
g.Ctx, g.cancel = context.WithCancel(g.Ctx)
g.ctx, g.cancel = context.WithCancel(g.ctx)
go func() {
for {
select {
Expand All @@ -57,7 +57,7 @@ func (g *Checkgroup) startConsumer() {
return
}

case <-g.Ctx.Done():
case <-g.ctx.Done():
g.resultCh <- Result{Err: context.Canceled}
g.cancel()
return
Expand All @@ -67,33 +67,33 @@ func (g *Checkgroup) startConsumer() {
})
}

func (g *Checkgroup) Done() bool {
func (g *concurrentCheckgroup) Done() bool {
select {
case <-g.Ctx.Done():
case <-g.ctx.Done():
return true
default:
return false
}
}

// Add adds the Func to the checkgroup and starts running it.
func (g *Checkgroup) Add(check Func) {
func (g *concurrentCheckgroup) Add(check Func) {
g.startConsumer()
g.incrementRunningCheckCount()
go check(g.Ctx, g.subcheckCh)
go check(g.ctx, g.subcheckCh)
}

// SetIsMember makes the checkgroup emit "IsMember" directly.
func (g *Checkgroup) SetIsMember() {
func (g *concurrentCheckgroup) SetIsMember() {
g.Add(IsMemberFunc)
}

func (g *Checkgroup) noChecksAdded() bool {
func (g *concurrentCheckgroup) noChecksAdded() bool {
return g.counts.totalChecks == 0
}

// Result returns the Result, possibly blocking.
func (g *Checkgroup) Result() Result {
func (g *concurrentCheckgroup) Result() Result {
g.startConsumer()
if g.noChecksAdded() {
g.cancel()
Expand All @@ -104,7 +104,7 @@ func (g *Checkgroup) Result() Result {
}

// CheckFunc returns a `Func` that writes the result to the result channel.
func (g *Checkgroup) CheckFunc() Func {
func (g *concurrentCheckgroup) CheckFunc() Func {
g.startConsumer()
if g.noChecksAdded() {
g.cancel()
Expand Down
29 changes: 23 additions & 6 deletions internal/check/checkgroup/definitions.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,25 @@ package checkgroup

import "context"

type Func func(ctx context.Context, resultCh chan<- Result)
type (
Checkgroup interface {
Done() bool
Add(check Func)
SetIsMember()
Result() Result
CheckFunc() Func
}

type Result struct {
Membership Membership
Err error
}
Factory func(ctx context.Context) Checkgroup

Func func(ctx context.Context, resultCh chan<- Result)
Result struct {
Membership Membership
Err error
}

type Membership int
Membership int
)

const (
MembershipUnknown Membership = iota
Expand All @@ -22,6 +33,12 @@ var (
ResultNotMember = Result{Membership: NotMember}
)

var DefaultFactory Factory = NewSequential

func New(ctx context.Context) Checkgroup {
return DefaultFactory(ctx)
}

func ErrorFunc(err error) Func {
return func(_ context.Context, resultCh chan<- Result) {
resultCh <- Result{Err: err}
Expand Down
Loading

0 comments on commit 0567cfe

Please sign in to comment.