Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
269 changes: 269 additions & 0 deletions vnet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ package sctp
import (
"bytes"
crand "crypto/rand"
"fmt"
"net"
"reflect"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -686,3 +688,270 @@ func TestCookieEchoRetransmission(t *testing.T) {
<-serverShutDown
log.Info("all done")
}

// Simulate an RTT switch (high -> low) by delaying early DATA, then disabling delay so
// later DATA arrives before earlier DATA. Under a RACK regression, rackMinRTT would never increases,
// causing reoWnd to be too small and marking packets sent at high RTT as spuriously lost.
func TestRACK_RTTSwitch_Reordering_NoDrop(t *testing.T) { //nolint:gocyclo,cyclop,maintidx
lim := test.TimeOut(10 * time.Second)
defer lim.Stop()

loggerFactory := logging.NewDefaultLoggerFactory()
log := loggerFactory.NewLogger("test-rack-rtt-switch")

venv, err := buildVNetEnv(t, &vNetEnvConfig{
minDelay: 0,
loggerFactory: loggerFactory,
log: log,
})
require.NoError(t, err)
require.NotNil(t, venv)

defer venv.wan.Stop() // nolint:errcheck

var delayOn atomic.Value
delayOn.Store(true)
venv.wan.AddChunkFilter(func(c vnet.Chunk) bool {
p := &packet{}
if err := p.unmarshal(true, c.UserData()); err != nil {
return true
}
v := delayOn.Load()
if val, ok := v.(bool); ok && !val {
return true
}
for i := 0; i < len(p.chunks); i++ {
if _, ok := p.chunks[i].(*chunkPayloadData); ok {
time.Sleep(100 * time.Millisecond)

break
}
}

return true
})

const (
numMessages = 40
messageSize = 256
)

makeMessages := func() [][]byte {
msgs := make([][]byte, numMessages)
for i := 0; i < numMessages; i++ {
b := bytes.Repeat([]byte{byte(i % 251)}, messageSize)
msgs[i] = b
}

return msgs
}

type statsResult struct {
fr uint64
ok bool
}

errCh := make(chan error, 16)
clientDone := make(chan struct{})
serverDone := make(chan struct{})
clientStatsCh := make(chan statsResult, 1)
serverStatsCh := make(chan statsResult, 1)

go func() {
defer close(serverDone)

fail := func(e error) {
if e != nil {
errCh <- e
}
}

conn, err := venv.net0.DialUDP("udp4",
&net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: defaultSCTPSrcDstPort},
&net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: defaultSCTPSrcDstPort},
)
if err != nil {
fail(fmt.Errorf("server DialUDP: %w", err))
serverStatsCh <- statsResult{ok: false}

return
}

defer conn.Close() // nolint:errcheck

assoc, err := Server(Config{
NetConn: conn,
LoggerFactory: loggerFactory,
})
if err != nil {
fail(fmt.Errorf("server assoc: %w", err))
serverStatsCh <- statsResult{ok: false}

return
}

defer func() {
var fr uint64
if assoc != nil {
fr = assoc.stats.getNumFastRetrans()
}
serverStatsCh <- statsResult{fr: fr, ok: assoc != nil}
_ = assoc.Close()
}()

stream, err := assoc.AcceptStream()
if err != nil {
fail(fmt.Errorf("server AcceptStream: %w", err))

return
}
defer stream.Close() // nolint:errcheck
stream.SetReliabilityParams(false, ReliabilityTypeReliable, 0)

buf := make([]byte, 1500)
for {
_ = stream.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
n, rerr := stream.Read(buf)
if rerr != nil {
return
}
if n > 0 {
_, _ = stream.Write(buf[:n])
}
}
}()

go func() {
defer close(clientDone)

fail := func(e error) {
if e != nil {
errCh <- e
}
}

conn, err := venv.net1.DialUDP("udp4",
&net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: defaultSCTPSrcDstPort},
&net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: defaultSCTPSrcDstPort},
)
if err != nil {
fail(fmt.Errorf("client DialUDP: %w", err))
clientStatsCh <- statsResult{ok: false}

return
}
defer conn.Close() // nolint:errcheck

assoc, err := Client(Config{
NetConn: conn,
LoggerFactory: loggerFactory,
})
if err != nil {
fail(fmt.Errorf("client assoc: %w", err))
clientStatsCh <- statsResult{ok: false}

return
}

defer func() {
var fr uint64
if assoc != nil {
fr = assoc.stats.getNumFastRetrans()
}
clientStatsCh <- statsResult{fr: fr, ok: assoc != nil}
_ = assoc.Close()
}()

stream, err := assoc.OpenStream(777, PayloadTypeWebRTCBinary)
if err != nil {
fail(fmt.Errorf("client OpenStream: %w", err))

return
}
defer stream.Close() // nolint:errcheck
stream.SetReliabilityParams(false, ReliabilityTypeReliable, 0)

msgs := makeMessages()

// phase 1: high-RTT emulation we send 25 messages and drop a DATA chunk for one time.
delayOn.Store(true)
venv.dropNextDataChunk(1)
for i := 0; i < 25; i++ {
if _, werr := stream.Write(msgs[i]); werr != nil {
fail(fmt.Errorf("client write phase1 i=%d: %w", i, werr))

return
}
}

// phase 2 we switch to low-RTT, newer datea should arrive before older.
delayOn.Store(false)
for i := 25; i < numMessages; i++ {
if _, werr := stream.Write(msgs[i]); werr != nil {
fail(fmt.Errorf("client write phase2 i=%d: %w", i, werr))

return
}
}

seen := make(map[byte]bool, numMessages)
buf := make([]byte, 4096)
deadline := time.Now().Add(10 * time.Second)

for len(seen) < numMessages && time.Now().Before(deadline) {
_ = stream.SetReadDeadline(time.Now().Add(250 * time.Millisecond))
n, rerr := stream.Read(buf)
if rerr != nil || n == 0 {
continue
}
if n < messageSize {
fail(fmt.Errorf("short echo read: got=%d want=%d", n, messageSize)) //nolint:err113

return
}
id := buf[0]
if seen[id] {
// dups are harmless, keep reading
continue
}

expected := bytes.Repeat([]byte{id}, messageSize)
if !bytes.Equal(buf[:messageSize], expected) {
fail(fmt.Errorf("payload mismatch for id=%d", int(id))) //nolint:err113

return
}
seen[id] = true
}

if len(seen) != numMessages {
fail(fmt.Errorf("missing echoes: got=%d want=%d", len(seen), numMessages)) //nolint:err113

return
}
}()

<-clientDone
<-serverDone

// drain and assert errors, well if any :)
close(errCh)
for e := range errCh {
assert.NoError(t, e)
}

// check FR stats reported.
// we can uncomment this to check the FR stats.
// I tested it and it works fine on the pch07/rack-sctp branch.
// cs := <-clientStatsCh
// ss := <-serverStatsCh
//
// if assert.True(t, cs.ok, "client assoc/stats unavailable") {
// assert.LessOrEqual(t, cs.fr, uint64(2),
// "client fast retransmits should be low")
// }
// if assert.True(t, ss.ok, "server assoc/stats unavailable") {
// assert.LessOrEqual(t, ss.fr, uint64(2),
// "server fast retransmits should be low")
// }
}
Loading