Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 130 additions & 0 deletions tavern/portals/stream/reader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
package stream

import (
"errors"
"fmt"
"time"

"realm.pub/tavern/portals/portalpb"
)

// ReceiverFunc is a callback that reads a Mote from a source.
// This allows OrderedReader to wrap any gRPC stream method.
// It should block until a Mote is available and return a non-nil error
// when the underlying stream ends or fails.
type ReceiverFunc func() (*portalpb.Mote, error)

// OrderedReader receives Motes and ensures they are read in order.
// It is not safe for concurrent use; callers must serialize calls to Read.
type OrderedReader struct {
	nextSeqID       uint64                    // next sequence ID Read expects to return
	buffer          map[uint64]*portalpb.Mote // out-of-order Motes keyed by sequence ID
	maxBuffer       int                       // cap on buffered out-of-order Motes before erroring
	staleTimeout    time.Duration             // max time to wait on a sequence gap while Motes are buffered
	firstBufferedAt time.Time                 // when the buffer last became non-empty; zero while empty
	receiver        ReceiverFunc              // source of incoming Motes
}

// NewOrderedReader creates a new OrderedReader.
// maxBuffer limits the number of out-of-order messages to buffer; non-positive
// values fall back to a sensible default.
// staleTimeout is the duration to wait for the next expected sequence ID before
// erroring if other messages are arriving; non-positive values fall back to a
// sensible default (otherwise any buffered gap would be reported stale
// immediately, since time.Since is always > 0).
func NewOrderedReader(receiver ReceiverFunc, maxBuffer int, staleTimeout time.Duration) *OrderedReader {
	if maxBuffer <= 0 {
		maxBuffer = 100 // Default sensible limit
	}
	if staleTimeout <= 0 {
		staleTimeout = 30 * time.Second // Default sensible timeout, mirroring maxBuffer
	}
	return &OrderedReader{
		buffer:       make(map[uint64]*portalpb.Mote),
		maxBuffer:    maxBuffer,
		staleTimeout: staleTimeout,
		receiver:     receiver,
	}
}

// Read returns the next ordered Mote.
// It blocks until the Mote with the next expected sequence ID is available or
// an error occurs. Motes that arrive ahead of the expected sequence ID are
// buffered (up to maxBuffer) and served by later calls; Motes with an already
// consumed sequence ID are silently dropped as duplicates. If the expected
// Mote does not arrive within staleTimeout while later Motes sit buffered,
// Read fails with a stale-stream error.
func (r *OrderedReader) Read() (*portalpb.Mote, error) {
	// Serve the next message from the buffer if it already arrived out of order.
	if mote, ok := r.buffer[r.nextSeqID]; ok {
		delete(r.buffer, r.nextSeqID)
		r.nextSeqID++
		// No gaps outstanding once the buffer drains; clear the stale timer.
		if len(r.buffer) == 0 {
			r.firstBufferedAt = time.Time{}
		}
		return mote, nil
	}

	// Fail fast if we have already been waiting on this gap for too long.
	if len(r.buffer) > 0 && time.Since(r.firstBufferedAt) > r.staleTimeout {
		return nil, fmt.Errorf("stale stream: timeout waiting for seqID %d", r.nextSeqID)
	}

	// Read loop: pull from the receiver until the expected sequence ID shows up.
	for {
		mote, err := r.receiver()
		if err != nil {
			return nil, err
		}

		switch {
		case mote.SeqId == r.nextSeqID:
			r.nextSeqID++
			// We made progress, so the stream is alive. If future Motes are
			// still buffered we are waiting on a fresh gap: restart the stale
			// timer so the next gap gets a full staleTimeout. Otherwise clear it.
			if len(r.buffer) > 0 {
				r.firstBufferedAt = time.Now()
			} else {
				r.firstBufferedAt = time.Time{}
			}
			return mote, nil

		case mote.SeqId > r.nextSeqID:
			// Gap detected: buffer this future Mote.
			if len(r.buffer) == 0 {
				// Buffer transitions empty -> non-empty; start the stale timer.
				r.firstBufferedAt = time.Now()
			}
			// First arrival wins; drop duplicate buffered sequence IDs.
			if _, exists := r.buffer[mote.SeqId]; !exists {
				r.buffer[mote.SeqId] = mote
			}
			if len(r.buffer) > r.maxBuffer {
				return nil, errors.New("stale stream: buffer limit exceeded")
			}
			// Check the timeout inside the loop too: many out-of-order packets
			// may arrive quickly without the expected one ever showing up.
			if time.Since(r.firstBufferedAt) > r.staleTimeout {
				return nil, fmt.Errorf("stale stream: timeout waiting for seqID %d", r.nextSeqID)
			}

		default:
			// mote.SeqId < r.nextSeqID: duplicate or old packet. Ignore.
		}
	}
}
69 changes: 69 additions & 0 deletions tavern/portals/stream/sequencer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package stream

import (
"sync/atomic"

"realm.pub/tavern/portals/portalpb"
)

// payloadSequencer sequences payloads with a stream ID and monotonic sequence ID.
// Sequence IDs are issued via an atomic counter, so concurrent Mote creation
// yields unique IDs.
type payloadSequencer struct {
	nextSeqID atomic.Uint64 // next sequence ID to issue; starts at 0
	streamID  string        // stream ID stamped onto every created Mote
}

// newPayloadSequencer creates a new payloadSequencer bound to streamID.
// The sequence counter starts at zero.
func newPayloadSequencer(streamID string) *payloadSequencer {
	s := new(payloadSequencer)
	s.streamID = streamID
	return s
}

// newSeqID returns the current sequence ID and increments it.
// Add returns the post-increment value, so subtract one to get the ID
// this call claimed.
func (s *payloadSequencer) newSeqID() uint64 {
	next := s.nextSeqID.Add(1)
	return next - 1
}

// NewBytesMote creates a new Mote with a BytesPayload, stamped with this
// sequencer's stream ID and the next sequence ID.
func (s *payloadSequencer) NewBytesMote(data []byte, kind portalpb.BytesPayloadKind) *portalpb.Mote {
	payload := &portalpb.BytesPayload{
		Data: data,
		Kind: kind,
	}
	return &portalpb.Mote{
		StreamId: s.streamID,
		SeqId:    s.newSeqID(),
		Payload:  &portalpb.Mote_Bytes{Bytes: payload},
	}
}

// NewTCPMote creates a new Mote with a TCPPayload, stamped with this
// sequencer's stream ID and the next sequence ID.
func (s *payloadSequencer) NewTCPMote(data []byte, dstAddr string, dstPort uint32) *portalpb.Mote {
	payload := &portalpb.TCPPayload{
		Data:    data,
		DstAddr: dstAddr,
		DstPort: dstPort,
	}
	return &portalpb.Mote{
		StreamId: s.streamID,
		SeqId:    s.newSeqID(),
		Payload:  &portalpb.Mote_Tcp{Tcp: payload},
	}
}

// NewUDPMote creates a new Mote with a UDPPayload, stamped with this
// sequencer's stream ID and the next sequence ID.
func (s *payloadSequencer) NewUDPMote(data []byte, dstAddr string, dstPort uint32) *portalpb.Mote {
	payload := &portalpb.UDPPayload{
		Data:    data,
		DstAddr: dstAddr,
		DstPort: dstPort,
	}
	return &portalpb.Mote{
		StreamId: s.streamID,
		SeqId:    s.newSeqID(),
		Payload:  &portalpb.Mote_Udp{Udp: payload},
	}
}
190 changes: 190 additions & 0 deletions tavern/portals/stream/stream_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
package stream

import (
"errors"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"realm.pub/tavern/portals/portalpb"
)

// TestSequencer exercises sequence ID issuance and Mote construction.
func TestSequencer(t *testing.T) {
	const streamID = "test-stream-id"
	sequencer := newPayloadSequencer(streamID)
	assert.Equal(t, streamID, sequencer.streamID)
	assert.Equal(t, uint64(0), sequencer.nextSeqID.Load())

	// Sequence IDs start at zero and increase by one per call.
	for want := uint64(0); want < 2; want++ {
		assert.Equal(t, want, sequencer.newSeqID())
	}

	// Mote creation stamps the stream ID and the next sequence ID.
	mote := sequencer.NewBytesMote([]byte("test"), portalpb.BytesPayloadKind_BYTES_PAYLOAD_KIND_DATA)
	assert.Equal(t, sequencer.streamID, mote.StreamId)
	assert.Equal(t, uint64(2), mote.SeqId)
	assert.IsType(t, &portalpb.Mote_Bytes{}, mote.Payload)
}

// TestOrderedWriter verifies that writes are sequenced and forwarded to the
// sender callback with the configured stream ID.
func TestOrderedWriter(t *testing.T) {
	const streamID = "test-stream-id"
	var sent []*portalpb.Mote
	sender := func(m *portalpb.Mote) error {
		sent = append(sent, m)
		return nil
	}

	w := NewOrderedWriter(streamID, sender)

	require.NoError(t, w.WriteBytes([]byte("hello"), portalpb.BytesPayloadKind_BYTES_PAYLOAD_KIND_DATA))
	require.Len(t, sent, 1)
	first := sent[0]
	assert.Equal(t, streamID, first.StreamId)
	assert.Equal(t, uint64(0), first.SeqId)
	assert.Equal(t, []byte("hello"), first.GetBytes().Data)

	require.NoError(t, w.WriteTCP([]byte("tcp"), "1.2.3.4", 80))
	require.Len(t, sent, 2)
	second := sent[1]
	assert.Equal(t, uint64(1), second.SeqId)
	assert.Equal(t, "1.2.3.4", second.GetTcp().DstAddr)
}

// TestOrderedReader_Ordered verifies that already-ordered Motes pass straight
// through without buffering.
func TestOrderedReader_Ordered(t *testing.T) {
	pending := []*portalpb.Mote{
		{SeqId: 0, StreamId: "s1"},
		{SeqId: 1, StreamId: "s1"},
		{SeqId: 2, StreamId: "s1"},
	}
	receiver := func() (*portalpb.Mote, error) {
		if len(pending) == 0 {
			return nil, errors.New("EOF")
		}
		next := pending[0]
		pending = pending[1:]
		return next, nil
	}

	reader := NewOrderedReader(receiver, 10, time.Second)

	for want := uint64(0); want < 3; want++ {
		mote, err := reader.Read()
		require.NoError(t, err)
		assert.Equal(t, want, mote.SeqId)
	}
}

// TestOrderedReader_OutOfOrder verifies that Motes received out of order
// (2, 0, 1) are returned to the caller in sequence order (0, 1, 2).
func TestOrderedReader_OutOfOrder(t *testing.T) {
	// Send 2, 0, 1
	motes := []*portalpb.Mote{
		{SeqId: 2, StreamId: "s1"},
		{SeqId: 0, StreamId: "s1"},
		{SeqId: 1, StreamId: "s1"},
	}
	idx := 0
	receiver := func() (*portalpb.Mote, error) {
		if idx >= len(motes) {
			// All three Motes are consumed before the reader ever needs to
			// block, so hitting this error would indicate an ordering bug.
			return nil, errors.New("EOF")
		}
		m := motes[idx]
		idx++
		return m, nil
	}

	r := NewOrderedReader(receiver, 10, time.Second)

	// First Read should get 0 (which arrives 2nd)
	m0, err := r.Read()
	require.NoError(t, err)
	assert.Equal(t, uint64(0), m0.SeqId)

	// Second Read should get 1 (which arrives 3rd)
	m1, err := r.Read()
	require.NoError(t, err)
	assert.Equal(t, uint64(1), m1.SeqId)

	// Third Read should get 2 (which arrived 1st and was buffered)
	m2, err := r.Read()
	require.NoError(t, err)
	assert.Equal(t, uint64(2), m2.SeqId)
}

// TestOrderedReader_StaleTimeout verifies that Read errors out when the
// expected sequence ID never arrives within the stale timeout while later
// Motes keep arriving.
func TestOrderedReader_StaleTimeout(t *testing.T) {
	// Send 1 (gap, missing 0) and wait
	receiver := func() (*portalpb.Mote, error) {
		time.Sleep(10 * time.Millisecond)
		return &portalpb.Mote{SeqId: 1, StreamId: "s1"}, nil
	}

	// Timeout very short
	r := NewOrderedReader(receiver, 10, 50*time.Millisecond)

	// Every receiver call returns seqID 1 while seqID 0 never arrives: the
	// first call buffers 1 and starts the stale timer, subsequent calls are
	// ignored as duplicates of a buffered seqID, and after ~50ms the in-loop
	// timeout check fires and Read returns the stale-stream error.
	_, err := r.Read()
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "timeout")
}

// TestOrderedReader_BufferLimit verifies that Read fails once more than
// maxBuffer out-of-order Motes accumulate while seqID 0 never arrives.
func TestOrderedReader_BufferLimit(t *testing.T) {
	next := uint64(1) // seqID 0 is never sent, so every Mote is buffered
	receiver := func() (*portalpb.Mote, error) {
		mote := &portalpb.Mote{SeqId: next, StreamId: "s1"}
		next++
		return mote, nil
	}

	reader := NewOrderedReader(receiver, 2, time.Second)

	_, err := reader.Read()
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "buffer limit exceeded")
}

// TestOrderedReader_DuplicateHandling verifies that a repeated sequence ID is
// silently dropped and does not disturb ordering.
func TestOrderedReader_DuplicateHandling(t *testing.T) {
	pending := []*portalpb.Mote{
		{SeqId: 0, StreamId: "s1"},
		{SeqId: 0, StreamId: "s1"}, // Duplicate
		{SeqId: 1, StreamId: "s1"},
	}
	receiver := func() (*portalpb.Mote, error) {
		if len(pending) == 0 {
			return nil, errors.New("EOF")
		}
		mote := pending[0]
		pending = pending[1:]
		return mote, nil
	}

	reader := NewOrderedReader(receiver, 10, time.Second)

	for _, want := range []uint64{0, 1} {
		mote, err := reader.Read()
		require.NoError(t, err)
		assert.Equal(t, want, mote.SeqId)
	}
}
Loading