Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions tavern/internal/http/stream/circular_buffer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package stream

import (
"sync"
)

// CircularBuffer is a fixed-size byte buffer that overwrites old data when full.
// It is safe for concurrent use.
type CircularBuffer struct {
mu sync.Mutex
data []byte
size int
start int
length int
}

// NewCircularBuffer creates a new circular buffer with the given size.
func NewCircularBuffer(size int) *CircularBuffer {
return &CircularBuffer{
data: make([]byte, size),
size: size,
start: 0,
length: 0,
}
}

// Write appends data to the buffer.
func (cb *CircularBuffer) Write(p []byte) {
cb.mu.Lock()
defer cb.mu.Unlock()

n := len(p)
if n == 0 {
return
}

// If the data being written is larger than the buffer size,
// we only care about the last `size` bytes.
if n >= cb.size {
copy(cb.data, p[n-cb.size:])
cb.start = 0
cb.length = cb.size
return
}

// We are writing n bytes.
// We write starting at (start + length) % size.
writeStart := (cb.start + cb.length) % cb.size

// Check if the write wraps around the end of the buffer
if writeStart+n <= cb.size {
// Contiguous write
copy(cb.data[writeStart:], p)
} else {
// Wrapped write
chunk1 := cb.size - writeStart
copy(cb.data[writeStart:], p[:chunk1])
copy(cb.data[0:], p[chunk1:])
}

// Update length and start
if cb.length+n <= cb.size {
cb.length += n
} else {
// Buffer overflowed
overflow := (cb.length + n) - cb.size
cb.start = (cb.start + overflow) % cb.size
cb.length = cb.size
}
}

// Bytes returns the current content of the buffer.
func (cb *CircularBuffer) Bytes() []byte {
cb.mu.Lock()
defer cb.mu.Unlock()

out := make([]byte, cb.length)
if cb.length == 0 {
return out
}

// If the data is contiguous
if cb.start+cb.length <= cb.size {
copy(out, cb.data[cb.start:cb.start+cb.length])
} else {
// Data wraps around
chunk1 := cb.size - cb.start
copy(out, cb.data[cb.start:])
copy(out[chunk1:], cb.data[:cb.length-chunk1])
}
return out
}
33 changes: 33 additions & 0 deletions tavern/internal/http/stream/circular_buffer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package stream

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestCircularBuffer(t *testing.T) {
cb := NewCircularBuffer(10)

// Test basic write
cb.Write([]byte("hello"))
assert.Equal(t, []byte("hello"), cb.Bytes())

// Test append within size
cb2 := NewCircularBuffer(10)
cb2.Write([]byte("hello "))
cb2.Write([]byte("world"))
assert.Equal(t, []byte("ello world"), cb2.Bytes())

// Test write larger than size
cb3 := NewCircularBuffer(5)
cb3.Write([]byte("1234567890"))
assert.Equal(t, []byte("67890"), cb3.Bytes())

// Test write exact size
cb4 := NewCircularBuffer(5)
cb4.Write([]byte("12345"))
assert.Equal(t, []byte("12345"), cb4.Bytes())
cb4.Write([]byte("6"))
assert.Equal(t, []byte("23456"), cb4.Bytes())
}
195 changes: 133 additions & 62 deletions tavern/internal/http/stream/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ const (
// maxRegistrationBufSize defines the maximum receivers that can be buffered in the registration / unregistration channel
// before new calls to `mux.Register()` and `mux.Unregister()` will block.
maxRegistrationBufSize = 256
// defaultHistorySize is the default size of the circular buffer for stream history.
defaultHistorySize = 1024
)

var upgrader = websocket.Upgrader{
Expand All @@ -22,28 +24,44 @@ var upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true },
}

type historyState struct {
buffer *CircularBuffer
sessions map[string]*sessionBuffer
}

// A Mux enables multiplexing subscription messages to multiple Streams.
// Streams will only receive a Message if their configured ID matches the incoming metadata of a Message.
// Additionally, new messages may be published using the Mux.
type Mux struct {
pub *pubsub.Topic
sub *pubsub.Subscription
register chan *Stream
unregister chan *Stream
streams map[*Stream]bool
pub *pubsub.Topic
sub *pubsub.Subscription
register chan *Stream
unregister chan *Stream
streams map[*Stream]bool
history map[string]*historyState
historySize int
}

// A MuxOption is used to provide further configuration to the Mux.
type MuxOption func(*Mux)

// WithHistorySize sets the size of the circular buffer for stream history.
func WithHistorySize(size int) MuxOption {
return func(m *Mux) {
m.historySize = size
}
}

// NewMux initializes and returns a new Mux with the provided pubsub info.
func NewMux(pub *pubsub.Topic, sub *pubsub.Subscription, options ...MuxOption) *Mux {
mux := &Mux{
pub: pub,
sub: sub,
register: make(chan *Stream, maxRegistrationBufSize),
unregister: make(chan *Stream, maxRegistrationBufSize),
streams: make(map[*Stream]bool),
pub: pub,
sub: sub,
register: make(chan *Stream, maxRegistrationBufSize),
unregister: make(chan *Stream, maxRegistrationBufSize),
streams: make(map[*Stream]bool),
history: make(map[string]*historyState),
historySize: defaultHistorySize,
}
for _, opt := range options {
opt(mux)
Expand All @@ -66,83 +84,110 @@ func (mux *Mux) Register(s *Stream) {
mux.register <- s
}

// registerStreams inserts all registered streams into the streams map.
func (mux *Mux) registerStreams(ctx context.Context) {
for {
select {
case s := <-mux.register:
slog.DebugContext(ctx, "mux registering new stream", "stream_id", s.id)
mux.streams[s] = true
default:
return
}
}
}

// Unregister a stream when it should no longer receive Messages from the Mux.
// Typically this is done via defer after registering a Stream.
// Unregistering a stream that is not registered will still close the stream channel.
func (mux *Mux) Unregister(s *Stream) {
mux.unregister <- s
}

// unregisterStreams deletes all unregistered streams from the streams map.
func (mux *Mux) unregisterStreams(ctx context.Context) {
for {
select {
case s := <-mux.unregister:
slog.DebugContext(ctx, "mux unregistering stream", "stream_id", s.id)
delete(mux.streams, s)
s.Close()
default:
return
}
}
}

// Start the mux, returning an error if polling ever fails.
func (mux *Mux) Start(ctx context.Context) error {
slog.DebugContext(ctx, "mux starting to manage streams and polling")
for {
// Manage Streams
mux.registerStreams(ctx)
mux.unregisterStreams(ctx)

// Poll for new messages
// Message channel to receive messages from the poller
type pollResult struct {
msg *pubsub.Message
err error
}
msgChan := make(chan pollResult)

// Start poller goroutine
go func() {
defer close(msgChan)
for {
msg, err := mux.sub.Receive(ctx)
select {
case <-ctx.Done():
return
case msgChan <- pollResult{msg: msg, err: err}:
// If context is done, stop.
if ctx.Err() != nil {
return
}
// Otherwise, loop again (retry on error).
}
}
}()

for {
select {
case <-ctx.Done():
slog.DebugContext(ctx, "mux context finished, exiting")
return ctx.Err()
default:
slog.DebugContext(ctx, "mux polling for message")
if err := mux.poll(ctx); err != nil {
slog.ErrorContext(ctx, "mux failed to poll subscription", "error", err)

case s := <-mux.register:
// Handle Registration
slog.DebugContext(ctx, "mux registering new stream", "stream_id", s.id)
mux.streams[s] = true

// Send history to the new stream
if state, ok := mux.history[s.id]; ok && state.buffer != nil {
data := state.buffer.Bytes()
if len(data) > 0 {
slog.DebugContext(ctx, "mux sending history to new stream", "stream_id", s.id, "bytes", len(data))
msg := &pubsub.Message{
Body: data,
Metadata: map[string]string{
metadataID: s.id,
MetadataMsgKind: "data",
// No order key needed for history injection
},
}
s.processOneMessage(ctx, msg)
}
}

case s := <-mux.unregister:
// Handle Unregistration
slog.DebugContext(ctx, "mux unregistering stream", "stream_id", s.id)
delete(mux.streams, s)
s.Close()

case res, ok := <-msgChan:
if !ok {
// Poller exited. If due to context cancel, return that error.
if ctx.Err() != nil {
return ctx.Err()
}
return fmt.Errorf("poller exited unexpectedly")
}
if res.err != nil {
// Log error and continue, matching original behavior (retry loop).
// Unless context is done.
if ctx.Err() != nil {
return ctx.Err()
}
slog.ErrorContext(ctx, "mux failed to poll subscription", "error", res.err)
continue
}

// Handle Message
mux.handleMessage(ctx, res.msg)
}
}
}

// poll for a new message, broadcasting to all registered streams.
// poll will also register & unregister streams after a new message is received.
func (mux *Mux) poll(ctx context.Context) error {
// Block waiting for message
msg, err := mux.sub.Receive(ctx)
if err != nil {
return fmt.Errorf("failed to poll for new message: %w", err)
}

// handleMessage processes a new message, updating history and broadcasting to streams.
func (mux *Mux) handleMessage(ctx context.Context, msg *pubsub.Message) {
// Always acknowledge the message
defer msg.Ack()

// Manage Streams
mux.registerStreams(ctx)
mux.unregisterStreams(ctx)

// Get Message Metadata
msgID, ok := msg.Metadata["id"]
if !ok {
slog.DebugContext(ctx, "mux received message without 'id' for stream, ignoring")
return nil
return
}
msgOrderKey, ok := msg.Metadata[metadataOrderKey]
if !ok {
Expand All @@ -153,6 +198,34 @@ func (mux *Mux) poll(ctx context.Context) error {
slog.DebugContext(ctx, "mux received message without msgOrderIndex")
}

// Update History
// Only buffer "data" messages (or messages with no kind specified, which default to data)
kind, hasKind := msg.Metadata[MetadataMsgKind]
if !hasKind || kind == "data" {
state, ok := mux.history[msgID]
if !ok {
state = &historyState{
buffer: NewCircularBuffer(mux.historySize),
sessions: make(map[string]*sessionBuffer),
}
mux.history[msgID] = state
}

// Use sessionBuffer to reorder messages before writing to circular buffer
key := parseOrderKey(msg)
sessBuf, ok := state.sessions[key]
if !ok {
sessBuf = &sessionBuffer{
data: make(map[uint64]*pubsub.Message, maxStreamOrderBuf),
}
state.sessions[key] = sessBuf
}

sessBuf.writeMessage(ctx, msg, func(m *pubsub.Message) {
state.buffer.Write(m.Body)
})
}

// Broadcast Message
slog.DebugContext(ctx, "mux broadcasting received message",
"msg_id", msgID,
Expand All @@ -176,6 +249,4 @@ func (mux *Mux) poll(ctx context.Context) error {
s.processOneMessage(ctx, msg)
}
}

return nil
}
Loading
Loading