Skip to content

Commit ddd6d25

Browse files
Simplify sampler interface (#3026)
1 parent 0928176 commit ddd6d25

24 files changed

+110
-110
lines changed

network/ip_tracker.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -400,8 +400,8 @@ func (i *ipTracker) GetGossipableIPs(
400400

401401
uniform.Initialize(uint64(len(i.gossipableIPs)))
402402
for len(ips) < maxNumIPs {
403-
index, err := uniform.Next()
404-
if err != nil {
403+
index, hasNext := uniform.Next()
404+
if !hasNext {
405405
return ips
406406
}
407407

network/p2p/validators.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,8 @@ func (v *Validators) Sample(ctx context.Context, limit int) []ids.NodeID {
125125

126126
uniform.Initialize(uint64(len(v.validatorList)))
127127
for len(sampled) < limit {
128-
i, err := uniform.Next()
129-
if err != nil {
128+
i, hasNext := uniform.Next()
129+
if !hasNext {
130130
break
131131
}
132132

network/peer/set.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,8 @@ func (s *peerSet) Sample(n int, precondition func(Peer) bool) []Peer {
124124

125125
peers := make([]Peer, 0, n)
126126
for len(peers) < n {
127-
index, err := sampler.Next()
128-
if err != nil {
127+
index, hasNext := sampler.Next()
128+
if !hasNext {
129129
// We have run out of peers to attempt to sample.
130130
break
131131
}

snow/consensus/snowman/bootstrapper/sampler.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,15 @@
44
package bootstrapper
55

66
import (
7+
"errors"
8+
79
"github.com/ava-labs/avalanchego/utils/math"
810
"github.com/ava-labs/avalanchego/utils/sampler"
911
"github.com/ava-labs/avalanchego/utils/set"
1012
)
1113

14+
var errUnexpectedSamplerFailure = errors.New("unexpected sampler failure")
15+
1216
// Sample keys from [elements] uniformly by weight without replacement. The
1317
// returned set will have size less than or equal to [maxSize]. This function
1418
// will error if the sum of all weights overflows.
@@ -36,9 +40,9 @@ func Sample[T comparable](elements map[T]uint64, maxSize int) (set.Set[T], error
3640
}
3741

3842
maxSize = int(min(uint64(maxSize), totalWeight))
39-
indices, err := sampler.Sample(maxSize)
40-
if err != nil {
41-
return nil, err
43+
indices, ok := sampler.Sample(maxSize)
44+
if !ok {
45+
return nil, errUnexpectedSamplerFailure
4246
}
4347

4448
sampledElements := set.NewSet[T](maxSize)

snow/validators/manager_test.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ import (
1111

1212
"github.com/ava-labs/avalanchego/ids"
1313
"github.com/ava-labs/avalanchego/utils/crypto/bls"
14-
"github.com/ava-labs/avalanchego/utils/sampler"
1514
"github.com/ava-labs/avalanchego/utils/set"
1615

1716
safemath "github.com/ava-labs/avalanchego/utils/math"
@@ -396,7 +395,7 @@ func TestSample(t *testing.T) {
396395
require.Equal([]ids.NodeID{nodeID0}, sampled)
397396

398397
_, err = m.Sample(subnetID, 2)
399-
require.ErrorIs(err, sampler.ErrOutOfRange)
398+
require.ErrorIs(err, errInsufficientWeight)
400399

401400
nodeID1 := ids.GenerateTestNodeID()
402401
require.NoError(m.AddStaker(subnetID, nodeID1, nil, ids.Empty, math.MaxInt64-1))

snow/validators/set.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ var (
2323
errDuplicateValidator = errors.New("duplicate validator")
2424
errMissingValidator = errors.New("missing validator")
2525
errTotalWeightNotUint64 = errors.New("total weight is not a uint64")
26+
errInsufficientWeight = errors.New("insufficient weight")
2627
)
2728

2829
// newSet returns a new, empty set of validators.
@@ -257,9 +258,9 @@ func (s *vdrSet) sample(size int) ([]ids.NodeID, error) {
257258
s.samplerInitialized = true
258259
}
259260

260-
indices, err := s.sampler.Sample(size)
261-
if err != nil {
262-
return nil, err
261+
indices, ok := s.sampler.Sample(size)
262+
if !ok {
263+
return nil, errInsufficientWeight
263264
}
264265

265266
list := make([]ids.NodeID, size)

snow/validators/set_test.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ import (
1111

1212
"github.com/ava-labs/avalanchego/ids"
1313
"github.com/ava-labs/avalanchego/utils/crypto/bls"
14-
"github.com/ava-labs/avalanchego/utils/sampler"
1514
"github.com/ava-labs/avalanchego/utils/set"
1615

1716
safemath "github.com/ava-labs/avalanchego/utils/math"
@@ -343,7 +342,7 @@ func TestSetSample(t *testing.T) {
343342
require.Equal([]ids.NodeID{nodeID0}, sampled)
344343

345344
_, err = s.Sample(2)
346-
require.ErrorIs(err, sampler.ErrOutOfRange)
345+
require.ErrorIs(err, errInsufficientWeight)
347346

348347
nodeID1 := ids.GenerateTestNodeID()
349348
require.NoError(s.Add(nodeID1, nil, ids.Empty, math.MaxInt64-1))

utils/sampler/uniform.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@ package sampler
77
type Uniform interface {
88
Initialize(sampleRange uint64)
99
// Sample returns length numbers in the range [0,sampleRange). If there
10-
// aren't enough numbers in the range, an error is returned. If length is
10+
// aren't enough numbers in the range, false is returned. If length is
1111
// negative the implementation may panic.
12-
Sample(length int) ([]uint64, error)
12+
Sample(length int) ([]uint64, bool)
1313

14+
Next() (uint64, bool)
1415
Reset()
15-
Next() (uint64, error)
1616
}
1717

1818
// NewUniform returns a new sampler

utils/sampler/uniform_best.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ samplerLoop:
5656

5757
start := s.clock.Time()
5858
for i := 0; i < s.benchmarkIterations; i++ {
59-
if _, err := sampler.Sample(sampleSize); err != nil {
59+
if _, ok := sampler.Sample(sampleSize); !ok {
6060
continue samplerLoop
6161
}
6262
}

utils/sampler/uniform_replacer.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,34 +36,34 @@ func (s *uniformReplacer) Initialize(length uint64) {
3636
s.drawsCount = 0
3737
}
3838

39-
func (s *uniformReplacer) Sample(count int) ([]uint64, error) {
39+
func (s *uniformReplacer) Sample(count int) ([]uint64, bool) {
4040
s.Reset()
4141

4242
results := make([]uint64, count)
4343
for i := 0; i < count; i++ {
44-
ret, err := s.Next()
45-
if err != nil {
46-
return nil, err
44+
ret, hasNext := s.Next()
45+
if !hasNext {
46+
return nil, false
4747
}
4848
results[i] = ret
4949
}
50-
return results, nil
50+
return results, true
5151
}
5252

5353
func (s *uniformReplacer) Reset() {
5454
clear(s.drawn)
5555
s.drawsCount = 0
5656
}
5757

58-
func (s *uniformReplacer) Next() (uint64, error) {
58+
func (s *uniformReplacer) Next() (uint64, bool) {
5959
if s.drawsCount >= s.length {
60-
return 0, ErrOutOfRange
60+
return 0, false
6161
}
6262

6363
draw := s.rng.Uint64Inclusive(s.length-1-s.drawsCount) + s.drawsCount
6464
ret := s.drawn.get(draw, draw)
6565
s.drawn[draw] = s.drawn.get(s.drawsCount, s.drawsCount)
6666
s.drawsCount++
6767

68-
return ret, nil
68+
return ret, true
6969
}

0 commit comments

Comments
 (0)