Skip to content

Commit

Permalink
fix: make assert.CollectT concurrency safe
Browse files Browse the repository at this point in the history
  • Loading branch information
czeslavo committed Jul 28, 2023
1 parent 486eb6f commit c325f46
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 3 deletions.
21 changes: 18 additions & 3 deletions assert/assertions.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"runtime"
"runtime/debug"
"strings"
"sync"
"time"
"unicode"
"unicode/utf8"
Expand Down Expand Up @@ -1862,10 +1863,13 @@ func Eventually(t TestingT, condition func() bool, waitFor time.Duration, tick t
// CollectT implements the TestingT interface and collects all errors.
type CollectT struct {
errors []error
mu sync.RWMutex
}

// Errorf collects the error.
func (c *CollectT) Errorf(format string, args ...interface{}) {
c.mu.Lock()
defer c.mu.Unlock()
c.errors = append(c.errors, fmt.Errorf(format, args...))
}

Expand All @@ -1876,6 +1880,8 @@ func (c *CollectT) FailNow() {

// Reset clears the collected errors.
func (c *CollectT) Reset() {
c.mu.Lock()
defer c.mu.Unlock()
c.errors = nil
}

Expand All @@ -1884,11 +1890,20 @@ func (c *CollectT) Copy(t TestingT) {
if tt, ok := t.(tHelper); ok {
tt.Helper()
}
c.mu.RLock()
defer c.mu.RUnlock()
for _, err := range c.errors {
t.Errorf("%v", err)
}
}

// hasErrors returns true if any errors were collected.
func (c *CollectT) hasErrors() bool {
c.mu.RLock()
defer c.mu.RUnlock()
return len(c.errors) > 0
}

// EventuallyWithT asserts that given condition will be met in waitFor time,
// periodically checking target function each tick. In contrast to Eventually,
// it supplies a CollectT to the condition function, so that the condition
Expand Down Expand Up @@ -1931,10 +1946,10 @@ func EventuallyWithT(t TestingT, condition func(collect *CollectT), waitFor time
collect.Reset()
go func() {
condition(collect)
ch <- len(collect.errors) == 0
ch <- collect.hasErrors()
}()
case v := <-ch:
if v {
case hasErrors := <-ch:
if !hasErrors {
return true
}
tick = ticker.C
Expand Down
7 changes: 7 additions & 0 deletions assert/assertions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2786,6 +2786,13 @@ func TestEventuallyWithTTrue(t *testing.T) {
Len(t, mockT.errors, 0)
}

func TestEventuallyWithT_ConcurrencySafe(t *testing.T) {
mockT := new(CollectT)
EventuallyWithT(mockT, func(c *CollectT) {
NoError(c, AnError)
}, time.Millisecond, time.Nanosecond)
}

func TestNeverFalse(t *testing.T) {
condition := func() bool {
return false
Expand Down

0 comments on commit c325f46

Please sign in to comment.