From f04db64865153ade602a58bcee199748c6afee78 Mon Sep 17 00:00:00 2001 From: Kirill Sysoev Date: Thu, 13 Jun 2024 22:09:21 +0800 Subject: [PATCH] chore: Add compression mode and threshold to channel configuration --- channel/channel.go | 22 +++++++++++++++++++--- channel/channel_test.go | 29 +++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 3 deletions(-) diff --git a/channel/channel.go b/channel/channel.go index 77ced9e..c3fce2f 100644 --- a/channel/channel.go +++ b/channel/channel.go @@ -18,7 +18,9 @@ type Channel struct { } type channelConfig struct { - originPatterns []string + originPatterns []string + compressionMode websocket.CompressionMode + compressionThreshold int } type Option func(*channelConfig) @@ -36,7 +38,9 @@ func NewChannel( opts ...Option, ) *Channel { config := channelConfig{ - originPatterns: []string{"*"}, + originPatterns: []string{"*"}, + compressionMode: websocket.CompressionDisabled, + compressionThreshold: 0, } for _, opt := range opts { @@ -73,7 +77,9 @@ func (c *Channel) wsConnectionHandler() http.Handler { } ws, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - OriginPatterns: c.config.originPatterns, + OriginPatterns: c.config.originPatterns, + CompressionMode: c.config.compressionMode, + CompressionThreshold: c.config.compressionThreshold, }) if err != nil { @@ -113,3 +119,13 @@ func WithOriginPatterns(patterns ...string) Option { c.originPatterns = patterns } } + +// WithCompressionMode sets the compression mode and threshold for a channel configuration. +// The compression mode determines how the channel data will be compressed, and the threshold +// specifies the minimum size of the payload required for compression to be applied. +func WithCompressionMode(mode websocket.CompressionMode, threshold int) Option { + return func(c *channelConfig) { + c.compressionMode = mode + c.compressionThreshold = threshold + } +} diff --git a/channel/channel_test.go b/channel/channel_test.go index ebe34af..59b690a 100644 --- a/channel/channel_test.go +++ b/channel/channel_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/ksysoev/wasabi/mocks" + "nhooyr.io/websocket" ) func TestNewChannel(t *testing.T) { @@ -204,3 +205,31 @@ func TestChannel_wsConnectionHandler_CanAcceptNewConnection(t *testing.T) { t.Errorf("Unexpected status code: got %d, expected %d", res.StatusCode, http.StatusUpgradeRequired) } } +func TestChannel_WithCompressionMode(t *testing.T) { + path := "/test/path" + dispatcher := mocks.NewMockDispatcher(t) + + channel := NewChannel(path, dispatcher, NewConnectionRegistry()) + + // Assert that the default compression mode and threshold are set correctly + if channel.config.compressionMode != websocket.CompressionDisabled { + t.Errorf("Unexpected compression mode: got %v, expected %v", channel.config.compressionMode, websocket.CompressionNoContextTakeover) + } + + if channel.config.compressionThreshold != 0 { + t.Errorf("Unexpected compression threshold: got %d, expected %d", channel.config.compressionThreshold, 0) + } + + compressionMode := websocket.CompressionNoContextTakeover + compressionThreshold := 1024 + channel = NewChannel(path, dispatcher, NewConnectionRegistry(), WithCompressionMode(compressionMode, compressionThreshold)) + + // Assert that the compression mode and threshold are set correctly + if channel.config.compressionMode != compressionMode { + t.Errorf("Unexpected compression mode: got %v, expected %v", channel.config.compressionMode, compressionMode) + } + + if channel.config.compressionThreshold != compressionThreshold { + t.Errorf("Unexpected compression threshold: got %d, expected %d", channel.config.compressionThreshold, compressionThreshold) + } +}