Skip to content

Commit

Permalink
Merge pull request #90225 from cockroachdb/blathers/backport-release-…
Browse files Browse the repository at this point in the history
…22.2-90207
  • Loading branch information
ajwerner authored Oct 21, 2022
2 parents 6cdd1f1 + b8a318a commit 021b27b
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 24 deletions.
20 changes: 15 additions & 5 deletions pkg/util/shuffle/shuffle.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@

package shuffle

import "math/rand"
import (
"math/rand"
"sync"
"sync/atomic"
)

// Interface for shuffle. When it is satisfied, a collection can be shuffled by
// the routines in this package. The methods require that the elements of the
Expand All @@ -23,10 +27,16 @@ type Interface interface {
Swap(i, j int)
}

var seedSource int64
var randSyncPool = sync.Pool{
New: func() interface{} {
return rand.New(rand.NewSource(atomic.AddInt64(&seedSource, 1)))
},
}

// Shuffle randomizes the order of the array.
func Shuffle(data Interface) {
n := data.Len()
for i := 1; i < n; i++ {
data.Swap(i, rand.Intn(i+1))
}
r := randSyncPool.Get().(*rand.Rand)
defer randSyncPool.Put(r)
r.Shuffle(data.Len(), data.Swap)
}
96 changes: 77 additions & 19 deletions pkg/util/shuffle/shuffle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@
package shuffle

import (
"fmt"
"math/rand"
"reflect"
"sync"
"testing"
"unsafe"

"github.com/cockroachdb/cockroach/pkg/util/leaktest"
)
Expand All @@ -26,9 +29,15 @@ func (ts testSlice) Swap(i, j int) { ts[i], ts[j] = ts[j], ts[i] }

func TestShuffle(t *testing.T) {
defer leaktest.AfterTest(t)()
rand.Seed(0)
old := randSyncPool.New
defer func() { randSyncPool.New = old }()
r := rand.New(rand.NewSource(0))
randSyncPool.New = func() interface{} {
return r
}

verify := func(original, expected testSlice) {
t.Helper()
Shuffle(original)
if !reflect.DeepEqual(original, expected) {
t.Errorf("expected %v, got %v", expected, original)
Expand All @@ -44,33 +53,82 @@ func TestShuffle(t *testing.T) {
verify(ts, testSlice{1})

ts = testSlice{1, 2}
verify(ts, testSlice{2, 1})
verify(ts, testSlice{1, 2})
verify(ts, testSlice{2, 1})

ts = testSlice{1, 2, 3}
verify(ts, testSlice{3, 1, 2})
verify(ts, testSlice{2, 3, 1})
verify(ts, testSlice{1, 3, 2})
verify(ts, testSlice{1, 2, 3})
verify(ts, testSlice{1, 2, 3})
verify(ts, testSlice{3, 1, 2})

ts = testSlice{1, 2, 3, 4, 5}
verify(ts, testSlice{2, 1, 3, 5, 4})
verify(ts, testSlice{4, 2, 1, 5, 3})
verify(ts, testSlice{1, 4, 2, 3, 5})
verify(ts, testSlice{2, 5, 4, 1, 3})
verify(ts, testSlice{4, 2, 3, 1, 5})
verify(ts, testSlice{5, 1, 3, 4, 2})
verify(ts, testSlice{2, 5, 3, 1, 4})
verify(ts, testSlice{3, 2, 5, 4, 1})
verify(ts, testSlice{1, 2, 4, 3, 5})
verify(ts, testSlice{3, 1, 5, 2, 4})

verify(ts[2:2], testSlice{})
verify(ts[0:0], testSlice{})
verify(ts[5:5], testSlice{})
verify(ts[3:5], testSlice{1, 5})
verify(ts[3:5], testSlice{5, 1})
verify(ts[0:2], testSlice{4, 2})
verify(ts[0:2], testSlice{2, 4})
verify(ts[1:4], testSlice{3, 5, 4})
verify(ts[1:4], testSlice{5, 4, 3})
verify(ts[0:4], testSlice{4, 5, 2, 3})
verify(ts[0:4], testSlice{2, 4, 3, 5})

verify(ts, testSlice{1, 3, 4, 2, 5})
verify(ts[3:5], testSlice{4, 2})
verify(ts[3:5], testSlice{2, 4})
verify(ts[0:2], testSlice{1, 3})
verify(ts[0:2], testSlice{1, 3})
verify(ts[1:4], testSlice{5, 2, 3})
verify(ts[1:4], testSlice{3, 5, 2})
verify(ts[0:4], testSlice{2, 3, 1, 5})
verify(ts[0:4], testSlice{5, 3, 1, 2})

verify(ts, testSlice{2, 3, 1, 5, 4})
}

type ints []int

func (i ints) Len() int { return len(i) }
func (i ints) Swap(a, b int) { i[a], i[b] = i[b], i[a] }

// BenchmarkConcurrentShuffle is used to demonstrate that the Shuffle
// function scales with cores. Once upon a time, it did not.
func BenchmarkConcurrentShuffle(b *testing.B) {
for _, concurrency := range []int{1, 4, 8} {
b.Run(fmt.Sprintf("concurrency=%d", concurrency), func(b *testing.B) {
for _, size := range []int{1 << 7, 1 << 10, 1 << 13} {
b.Run(fmt.Sprintf("size=%d", size), func(b *testing.B) {
b.SetBytes(int64(size * int(unsafe.Sizeof(0))))
bufs := make([]ints, 0, concurrency)
for i := 0; i < concurrency; i++ {
bufs = append(bufs, rand.Perm(size))
}
ns := distribute(b.N, concurrency)
var wg sync.WaitGroup
wg.Add(concurrency)
b.ResetTimer()
for i := 0; i < concurrency; i++ {
go func(buf *ints, n int) {
defer wg.Done()
for j := 0; j < n; j++ {
Shuffle(buf)
}
}(&bufs[i], ns[i])
}
wg.Wait()
})
}
})
}
}

// distribute returns a slice of <num> integers that add up to <total> and are
// within +/-1 of each other.
func distribute(total, num int) []int {
res := make([]int, num)
for i := range res {
// Use the average number of remaining connections.
div := len(res) - i
res[i] = (total + div/2) / div
total -= res[i]
}
return res
}

0 comments on commit 021b27b

Please sign in to comment.