Skip to content

Commit

Permalink
scylla.go: add scyllaPortIterator
Browse files Browse the repository at this point in the history
Adds an object which, for a given shard and total shard count, will
allow to iterate over source ports which can be used to connect to that
particular shard.
  • Loading branch information
piodul authored and mmatczuk committed Nov 12, 2020
1 parent b6efa43 commit e16e9fb
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 0 deletions.
46 changes: 46 additions & 0 deletions scylla.go
Original file line number Diff line number Diff line change
Expand Up @@ -406,3 +406,49 @@ func closeConns(conns []*Conn) {
}
}
}

type scyllaPortIterator struct {
currentPort int
shardCount int
}

const (
scyllaPortBasedBalancingMin = 0x8000
scyllaPortBasedBalancingMax = 0xFFFF
)

func newScyllaPortIterator(shardID, shardCount int) *scyllaPortIterator {
if shardCount == 0 {
panic("shardCount cannot be 0")
}

// Find the smallest port p such that p >= min and p % shardCount == shardID
port := scyllaPortBasedBalancingMin - scyllaShardForSourcePort(scyllaPortBasedBalancingMin, shardCount) + shardID
if port < scyllaPortBasedBalancingMin {
port += shardCount
}

return &scyllaPortIterator{
currentPort: port,
shardCount: shardCount,
}
}

func (spi *scyllaPortIterator) nextPort() (uint16, bool) {
if spi == nil {
return 0, false
}

p := spi.currentPort

if p > scyllaPortBasedBalancingMax {
return 0, false
}

spi.currentPort += spi.shardCount
return uint16(p), true
}

func scyllaShardForSourcePort(sourcePort uint16, shardCount int) int {
return int(sourcePort) % shardCount
}
59 changes: 59 additions & 0 deletions scylla_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package gocql

import (
"fmt"
"math"
"runtime"
"sync"
Expand Down Expand Up @@ -215,3 +216,61 @@ func TestScyllaLWTExtParsing(t *testing.T) {
}
})
}

func TestScyllaPortIterator(t *testing.T) {
t.Parallel()

for _shardCount := 1; _shardCount <= 64; _shardCount++ {
shardCount := _shardCount
t.Run(fmt.Sprintf("shard count %d", shardCount), func(t *testing.T) {
t.Parallel()
for shardID := 0; shardID < shardCount; shardID++ {
// Count by brute force ports that can be used to connect to requested shard
expectedPortCount := 0
for i := scyllaPortBasedBalancingMin; i <= scyllaPortBasedBalancingMax; i++ {
if i%shardCount == shardID {
expectedPortCount++
}
}

// Enumerate all ports using the port iterator and assert various things
iterator := newScyllaPortIterator(shardID, shardCount)
actualPortCount := 0
previousPort := 0

for {
portU16, ok := iterator.nextPort()
if !ok {
break
}

port := int(portU16)

if port < scyllaPortBasedBalancingMin || port > scyllaPortBasedBalancingMax {
t.Errorf("expected port %d generated from iterator to be in range [%d..%d]",
port, scyllaPortBasedBalancingMin, scyllaPortBasedBalancingMax)
}

if port <= previousPort {
t.Errorf("expected port %d generated from iterator to be larger than the previous generated port %d",
port, previousPort)
}

actualShardOfPort := scyllaShardForSourcePort(portU16, shardCount)
if actualShardOfPort != shardID {
t.Errorf("expected port %d returned from iterator to belong to shard %d, but belongs to %d",
port, shardID, actualShardOfPort)
}

previousPort = port
actualPortCount++
}

if expectedPortCount != actualPortCount {
t.Errorf("expected port iterator to generate %d ports, but got %d",
expectedPortCount, actualPortCount)
}
}
})
}
}

0 comments on commit e16e9fb

Please sign in to comment.