Skip to content

✨ [pkg/test] Add context to ConcurrentT. #112

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 44 additions & 14 deletions pkg/test/concurrent.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package test

import (
"context"
"runtime"
"strconv"
"strings"
Expand Down Expand Up @@ -130,17 +131,26 @@ type ConcurrentT struct {
t require.TestingT
failed bool
failedCh chan struct{}
ctx context.Context

mutex sync.Mutex
stages map[string]*stage
}

// NewConcurrent creates a new concurrent testing object.
func NewConcurrent(t require.TestingT) *ConcurrentT {
return NewConcurrentCtx(t, context.Background())
}

// NewConcurrentCtx creates a new concurrent testing object controlled by a
// context. If that context expires, any ongoing stages and wait calls will
// fail.
func NewConcurrentCtx(t require.TestingT, ctx context.Context) *ConcurrentT {
return &ConcurrentT{
t: t,
stages: make(map[string]*stage),
failedCh: make(chan struct{}),
ctx: ctx,
}
}

Expand All @@ -167,8 +177,10 @@ func (t *ConcurrentT) getStage(name string) *stage {
return s
}

// Wait waits until the stages and barriers with the requested names terminate.
// If any stage or barrier fails, terminates the current goroutine or test.
// Wait waits until the stages and barriers with the requested names
// terminate or the test's context expires. If the context expires, fails the
// test. If any stage or barrier fails, terminates the current goroutine or
// test.
func (t *ConcurrentT) Wait(names ...string) {
if len(names) == 0 {
panic("Wait(): called with 0 names")
Expand All @@ -177,6 +189,11 @@ func (t *ConcurrentT) Wait(names ...string) {
for _, name := range names {
stage := t.getStage(name)
select {
case <-t.ctx.Done():
t.failNowMutex.Lock()
t.t.Errorf("Wait for stage %s: %v", name, t.ctx.Err())
t.failNowMutex.Unlock()
t.FailNow()
case <-stage.wg.WaitCh():
if stage.failed.IsSet() {
t.FailNow()
Expand Down Expand Up @@ -209,28 +226,41 @@ func (t *ConcurrentT) FailNow() {
// fn must not spawn any goroutines or pass along the T object to goroutines
// that call T.Fatal. To achieve this, make other goroutines call
// ConcurrentT.StageN() instead.
// If the test's context expires before the call returns, fails the test.
func (t *ConcurrentT) StageN(name string, goroutines int, fn func(ConcT)) {
stage := t.spawnStage(name, goroutines)

stageT := ConcT{TestingT: stage, ct: t}
abort := CheckAbort(func() {
abort, ok := CheckAbortCtx(t.ctx, func() {
fn(stageT)
})

if abort != nil {
// Fail the stage, if it had not been marked as such, yet.
if stage.failed.TrySet() {
defer stage.wg.Done()
}
// If it is a panic or Goexit from certain contexts, print stack trace.
if _, ok := abort.(*Panic); ok || shouldPrintStack(abort.Stack()) {
print("\n", abort.String())
}
if ok && abort == nil {
stage.pass()
t.Wait(name)
return
}

// Fail the stage, if it had not been marked as such, yet.
if stage.failed.TrySet() {
defer stage.wg.Done()
}

// If it did not terminate, just abort the test.
if !ok {
t.failNowMutex.Lock()
t.t.Errorf("Stage %s: %v", name, t.ctx.Err())
t.failNowMutex.Unlock()
t.FailNow()
}

stage.pass()
t.Wait(name)
// If it is a panic or Goexit from certain contexts, print stack trace.
if _, ok := abort.(*Panic); ok || shouldPrintStack(abort.Stack()) {
t.failNowMutex.Lock()
t.t.Errorf("Stage %s: %s", name, abort.String())
t.failNowMutex.Unlock()
}
t.FailNow()
}

func shouldPrintStack(stack string) bool {
Expand Down
33 changes: 24 additions & 9 deletions pkg/test/concurrent_external_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package test_test

import (
"context"
"fmt"
"strconv"
"sync"
Expand Down Expand Up @@ -50,20 +51,34 @@ func TestConcurrentT_Wait(t *testing.T) {
})
ctxtest.AssertTerminates(t, timeout, func() { ct.Wait("known") })
})

t.Run("context expiry", func(t *testing.T) {
ctxtest.AssertTerminates(t, timeout, func() {
test.AssertFatal(t, func(t test.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
test.NewConcurrentCtx(t, ctx).Stage("", func(test.ConcT) {
time.Sleep(timeout)
})
})
})
})
}

func TestConcurrentT_FailNow(t *testing.T) {
var ct *test.ConcurrentT
t.Run("idempotence", func(t *testing.T) {
var ct *test.ConcurrentT

// Test that NewConcurrent.FailNow() calls T.FailNow().
test.AssertFatal(t, func(t test.T) {
ct = test.NewConcurrent(t)
ct.FailNow()
})
// Test that NewConcurrent.FailNow() calls T.FailNow().
test.AssertFatal(t, func(t test.T) {
ct = test.NewConcurrent(t)
ct.FailNow()
})

// Test that after that, FailNow() calls runtime.Goexit().
assert.True(t, test.CheckGoexit(ct.FailNow),
"redundant FailNow() must call runtime.Goexit()")
// Test that after that, FailNow() calls runtime.Goexit().
assert.True(t, test.CheckGoexit(ct.FailNow),
"redundant FailNow() must call runtime.Goexit()")
})

t.Run("hammer", func(t *testing.T) {
const parallel = 12
Expand Down
44 changes: 28 additions & 16 deletions pkg/test/goexit.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package test

import (
"context"
"fmt"
"runtime/debug"
"strings"
Expand Down Expand Up @@ -64,10 +65,11 @@ func (g Goexit) String() string {
return "runtime.Goexit:\n\n" + g.Stack()
}

// CheckAbort tests whether a supplied function is aborted early using panic()
// or runtime.Goexit(). Returns a descriptor of the termination cause or nil if
// it terminated normally.
func CheckAbort(function func()) (abort Abort) {
// CheckAbortCtx tests whether a supplied function terminates within a context,
// and whether it is aborted early using panic() or runtime.Goexit(). Returns
// whether the function terminated before the expiry of the context and if so, a
// descriptor of the termination cause or nil if it terminated normally.
func CheckAbortCtx(ctx context.Context, function func()) (abort Abort, ok bool) {
done := make(chan struct{})

goexit := true // Whether runtime.Goexit occurred.
Expand Down Expand Up @@ -103,20 +105,30 @@ func CheckAbort(function func()) (abort Abort) {
goexit = false
}()

<-done

// Concatenate the inner call stack of the failure (which starts at the
// goroutine instantiation) with the goroutine that is calling CheckAbort.
if goexit || aborted {
base.stack += "\n" + getStack(true, 1, 0)
select {
case <-ctx.Done():
return nil, false
case <-done:
ok = true
// Concatenate the inner call stack of the failure (which starts at the
// goroutine instantiation) with the goroutine that is calling CheckAbort.
if goexit || aborted {
base.stack += "\n" + getStack(true, 1, 0)
}

if goexit {
abort = &Goexit{base}
} else if aborted {
abort = &Panic{base, recovered}
}
return
}
}

if goexit {
abort = &Goexit{base}
} else if aborted {
abort = &Panic{base, recovered}
}
return
// CheckAbort calls CheckAbortCtx with context.Background.
func CheckAbort(function func()) Abort {
abort, _ := CheckAbortCtx(context.Background(), function)
return abort
}

// getStack retrieves the current call stack as text, and optionally removes the
Expand Down