From fbd473b88442dae130d67fc2abb37fe82713bfdf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A9vin=20Dunglas?= Date: Wed, 6 Nov 2024 16:42:24 +0100 Subject: [PATCH] refactor!: split subscribers and local subscribers --- bolt.go | 8 +-- bolt_test.go | 28 ++++---- local.go | 6 +- local_bench_test.go | 2 +- local_test.go | 28 ++++---- localsubscriber.go | 136 +++++++++++++++++++++++++++++++++++++++ metrics.go | 14 ++-- metrics_test.go | 8 +-- publish_test.go | 4 +- subscribe.go | 20 +++--- subscribe_test.go | 10 +-- subscriber.go | 131 ++----------------------------------- subscriber_bench_test.go | 2 +- subscriber_test.go | 10 +-- subscriberlist.go | 14 ++-- subscriberlist_test.go | 2 +- subscription_test.go | 12 ++-- transport.go | 8 +-- 18 files changed, 231 insertions(+), 212 deletions(-) create mode 100644 localsubscriber.go diff --git a/bolt.go b/bolt.go index 0843eb64..03f871da 100644 --- a/bolt.go +++ b/bolt.go @@ -202,7 +202,7 @@ func (t *BoltTransport) persist(updateID string, updateJSON []byte) error { } // AddSubscriber adds a new subscriber to the transport. -func (t *BoltTransport) AddSubscriber(s *Subscriber) error { +func (t *BoltTransport) AddSubscriber(s *LocalSubscriber) error { select { case <-t.closed: return ErrClosedTransport @@ -226,7 +226,7 @@ func (t *BoltTransport) AddSubscriber(s *Subscriber) error { } // RemoveSubscriber removes a new subscriber from the transport. -func (t *BoltTransport) RemoveSubscriber(s *Subscriber) error { +func (t *BoltTransport) RemoveSubscriber(s *LocalSubscriber) error { select { case <-t.closed: return ErrClosedTransport @@ -249,7 +249,7 @@ func (t *BoltTransport) GetSubscribers() (string, []*Subscriber, error) { } //nolint:gocognit -func (t *BoltTransport) dispatchHistory(s *Subscriber, toSeq uint64) error { +func (t *BoltTransport) dispatchHistory(s *LocalSubscriber, toSeq uint64) error { err := t.db.View(func(tx *bolt.Tx) error { b := tx.Bucket([]byte(t.bucketName)) if b == nil { @@ -311,7 +311,7 @@ func (t *BoltTransport) Close() (err error) { t.Lock() defer t.Unlock() - t.subscribers.Walk(0, func(s *Subscriber) bool { + t.subscribers.Walk(0, func(s *LocalSubscriber) bool { s.Disconnect() return true diff --git a/bolt_test.go b/bolt_test.go index ba78612c..4d17f6d0 100644 --- a/bolt_test.go +++ b/bolt_test.go @@ -35,7 +35,7 @@ func TestBoltTransportHistory(t *testing.T) { }) } - s := NewSubscriber("8", transport.logger, &TopicSelectorStore{}) + s := NewLocalSubscriber("8", transport.logger, &TopicSelectorStore{}) s.SetTopics(topics, nil) require.NoError(t, transport.AddSubscriber(s)) @@ -68,7 +68,7 @@ func TestBoltTransportLogsBogusLastEventID(t *testing.T) { Topics: topics, }) - s := NewSubscriber("711131", logger, &TopicSelectorStore{}) + s := NewLocalSubscriber("711131", logger, &TopicSelectorStore{}) s.SetTopics(topics, nil) require.NoError(t, transport.AddSubscriber(s)) @@ -87,7 +87,7 @@ func TestBoltTopicSelectorHistory(t *testing.T) { transport.Dispatch(&Update{Topics: []string{"http://example.com/subscribed-public-only"}, Private: true, Event: Event{ID: "3"}}) transport.Dispatch(&Update{Topics: []string{"http://example.com/subscribed-public-only"}, Event: Event{ID: "4"}}) - s := NewSubscriber(EarliestLastEventID, transport.logger, &TopicSelectorStore{}) + s := NewLocalSubscriber(EarliestLastEventID, transport.logger, &TopicSelectorStore{}) s.SetTopics([]string{"http://example.com/subscribed", "http://example.com/subscribed-public-only"}, []string{"http://example.com/subscribed"}) require.NoError(t, transport.AddSubscriber(s)) @@ -109,7 +109,7 @@ func TestBoltTransportRetrieveAllHistory(t *testing.T) { }) } - s := NewSubscriber(EarliestLastEventID, transport.logger, &TopicSelectorStore{}) + s := NewLocalSubscriber(EarliestLastEventID, transport.logger, &TopicSelectorStore{}) s.SetTopics(topics, nil) require.NoError(t, transport.AddSubscriber(s)) @@ -139,7 +139,7 @@ func TestBoltTransportHistoryAndLive(t *testing.T) { }) } - s := NewSubscriber("8", transport.logger, &TopicSelectorStore{}) + s := NewLocalSubscriber("8", transport.logger, &TopicSelectorStore{}) s.SetTopics(topics, nil) require.NoError(t, transport.AddSubscriber(s)) @@ -221,7 +221,7 @@ func TestBoltTransportDoNotDispatchUntilListen(t *testing.T) { defer os.Remove("test.db") assert.Implements(t, (*Transport)(nil), transport) - s := NewSubscriber("", transport.logger, &TopicSelectorStore{}) + s := NewLocalSubscriber("", transport.logger, &TopicSelectorStore{}) require.NoError(t, transport.AddSubscriber(s)) var wg sync.WaitGroup @@ -245,7 +245,7 @@ func TestBoltTransportDispatch(t *testing.T) { defer os.Remove("test.db") assert.Implements(t, (*Transport)(nil), transport) - s := NewSubscriber("", transport.logger, &TopicSelectorStore{}) + s := NewLocalSubscriber("", transport.logger, &TopicSelectorStore{}) s.SetTopics([]string{"https://example.com/foo", "https://example.com/private"}, []string{"https://example.com/private"}) require.NoError(t, transport.AddSubscriber(s)) @@ -274,7 +274,7 @@ func TestBoltTransportClosed(t *testing.T) { defer os.Remove("test.db") assert.Implements(t, (*Transport)(nil), transport) - s := NewSubscriber("", transport.logger, &TopicSelectorStore{}) + s := NewLocalSubscriber("", transport.logger, &TopicSelectorStore{}) s.SetTopics([]string{"https://example.com/foo"}, nil) require.NoError(t, transport.AddSubscriber(s)) @@ -293,11 +293,11 @@ func TestBoltCleanDisconnectedSubscribers(t *testing.T) { defer transport.Close() defer os.Remove("test.db") - s1 := NewSubscriber("", transport.logger, &TopicSelectorStore{}) + s1 := NewLocalSubscriber("", transport.logger, &TopicSelectorStore{}) s1.SetTopics([]string{"foo"}, []string{}) require.NoError(t, transport.AddSubscriber(s1)) - s2 := NewSubscriber("", transport.logger, &TopicSelectorStore{}) + s2 := NewLocalSubscriber("", transport.logger, &TopicSelectorStore{}) s2.SetTopics([]string{"foo"}, []string{}) require.NoError(t, transport.AddSubscriber(s2)) @@ -318,10 +318,10 @@ func TestBoltGetSubscribers(t *testing.T) { defer transport.Close() defer os.Remove("test.db") - s1 := NewSubscriber("", transport.logger, &TopicSelectorStore{}) + s1 := NewLocalSubscriber("", transport.logger, &TopicSelectorStore{}) require.NoError(t, transport.AddSubscriber(s1)) - s2 := NewSubscriber("", transport.logger, &TopicSelectorStore{}) + s2 := NewLocalSubscriber("", transport.logger, &TopicSelectorStore{}) require.NoError(t, transport.AddSubscriber(s2)) lastEventID, subscribers, err := transport.GetSubscribers() @@ -329,8 +329,8 @@ func TestBoltGetSubscribers(t *testing.T) { assert.Equal(t, EarliestLastEventID, lastEventID) assert.Len(t, subscribers, 2) - assert.Contains(t, subscribers, s1) - assert.Contains(t, subscribers, s2) + assert.Contains(t, subscribers, &s1.Subscriber) + assert.Contains(t, subscribers, &s2.Subscriber) } func TestBoltLastEventID(t *testing.T) { diff --git a/local.go b/local.go index 53ce1aa9..b8e160e0 100644 --- a/local.go +++ b/local.go @@ -54,7 +54,7 @@ func (t *LocalTransport) Dispatch(update *Update) error { } // AddSubscriber adds a new subscriber to the transport. -func (t *LocalTransport) AddSubscriber(s *Subscriber) error { +func (t *LocalTransport) AddSubscriber(s *LocalSubscriber) error { select { case <-t.closed: return ErrClosedTransport @@ -74,7 +74,7 @@ func (t *LocalTransport) AddSubscriber(s *Subscriber) error { } // RemoveSubscriber removes a subscriber from the transport. -func (t *LocalTransport) RemoveSubscriber(s *Subscriber) error { +func (t *LocalTransport) RemoveSubscriber(s *LocalSubscriber) error { select { case <-t.closed: return ErrClosedTransport @@ -102,7 +102,7 @@ func (t *LocalTransport) Close() (err error) { t.Lock() defer t.Unlock() close(t.closed) - t.subscribers.Walk(0, func(s *Subscriber) bool { + t.subscribers.Walk(0, func(s *LocalSubscriber) bool { s.Disconnect() return true diff --git a/local_bench_test.go b/local_bench_test.go index c00eabbe..58287295 100644 --- a/local_bench_test.go +++ b/local_bench_test.go @@ -41,7 +41,7 @@ func subBenchLocalTransport(b *testing.B, topics, concurrency, matchPct int, tes out := make(chan *Update, 50000) tss := &TopicSelectorStore{} for i := 0; i < concurrency; i++ { - s := NewSubscriber("", zap.NewNop(), tss) + s := NewLocalSubscriber("", zap.NewNop(), tss) if i%100 < matchPct { s.SetTopics(tsMatch, nil) } else { diff --git a/local_test.go b/local_test.go index 964c5ec7..33e89add 100644 --- a/local_test.go +++ b/local_test.go @@ -20,7 +20,7 @@ func TestLocalTransportDoNotDispatchUntilListen(t *testing.T) { err := transport.Dispatch(u) require.NoError(t, err) - s := NewSubscriber("", logger, &TopicSelectorStore{}) + s := NewLocalSubscriber("", logger, &TopicSelectorStore{}) s.SetTopics(u.Topics, nil) require.NoError(t, transport.AddSubscriber(s)) @@ -43,7 +43,7 @@ func TestLocalTransportDispatch(t *testing.T) { defer transport.Close() assert.Implements(t, (*Transport)(nil), transport) - s := NewSubscriber("", logger, &TopicSelectorStore{}) + s := NewLocalSubscriber("", logger, &TopicSelectorStore{}) s.SetTopics([]string{"http://example.com/foo"}, nil) require.NoError(t, transport.AddSubscriber(s)) @@ -60,10 +60,10 @@ func TestLocalTransportClosed(t *testing.T) { tss := &TopicSelectorStore{} - s := NewSubscriber("", logger, tss) + s := NewLocalSubscriber("", logger, tss) require.NoError(t, transport.AddSubscriber(s)) require.NoError(t, transport.Close()) - assert.Equal(t, transport.AddSubscriber(NewSubscriber("", logger, tss)), ErrClosedTransport) + assert.Equal(t, transport.AddSubscriber(NewLocalSubscriber("", logger, tss)), ErrClosedTransport) assert.Equal(t, transport.Dispatch(&Update{}), ErrClosedTransport) _, ok := <-s.out @@ -78,10 +78,10 @@ func TestLiveCleanDisconnectedSubscribers(t *testing.T) { tss := &TopicSelectorStore{} - s1 := NewSubscriber("", logger, tss) + s1 := NewLocalSubscriber("", logger, tss) require.NoError(t, transport.AddSubscriber(s1)) - s2 := NewSubscriber("", logger, tss) + s2 := NewLocalSubscriber("", logger, tss) require.NoError(t, transport.AddSubscriber(s2)) assert.Equal(t, 2, transport.subscribers.Len()) @@ -101,7 +101,7 @@ func TestLiveReading(t *testing.T) { defer transport.Close() assert.Implements(t, (*Transport)(nil), transport) - s := NewSubscriber("", logger, &TopicSelectorStore{}) + s := NewLocalSubscriber("", logger, &TopicSelectorStore{}) s.SetTopics([]string{"https://example.com"}, nil) require.NoError(t, transport.AddSubscriber(s)) @@ -113,23 +113,23 @@ func TestLiveReading(t *testing.T) { } func TestLocalTransportGetSubscribers(t *testing.T) { - logger := zap.NewNop() - transport, _ := DeprecatedNewLocalTransport(&url.URL{Scheme: "local"}, logger) + transport := NewLocalTransport() defer transport.Close() require.NotNil(t, transport) + logger := zap.NewNop() tss := &TopicSelectorStore{} - s1 := NewSubscriber("", logger, tss) + s1 := NewLocalSubscriber("", logger, tss) require.NoError(t, transport.AddSubscriber(s1)) - s2 := NewSubscriber("", logger, tss) + s2 := NewLocalSubscriber("", logger, tss) require.NoError(t, transport.AddSubscriber(s2)) - lastEventID, subscribers, err := transport.(TransportSubscribers).GetSubscribers() + lastEventID, subscribers, err := transport.GetSubscribers() require.NoError(t, err) assert.Equal(t, EarliestLastEventID, lastEventID) assert.Len(t, subscribers, 2) - assert.Contains(t, subscribers, s1) - assert.Contains(t, subscribers, s2) + assert.Contains(t, subscribers, &s1.Subscriber) + assert.Contains(t, subscribers, &s2.Subscriber) } diff --git a/localsubscriber.go b/localsubscriber.go new file mode 100644 index 00000000..b821cb67 --- /dev/null +++ b/localsubscriber.go @@ -0,0 +1,136 @@ +package mercure + +import ( + "net/url" + "sync" + "sync/atomic" + + "github.com/gofrs/uuid" + "go.uber.org/zap" +) + +// LocalSubscriber represents a client subscribed to a list of topics on the current hub. +type LocalSubscriber struct { + Subscriber + + disconnected int32 + out chan *Update + outMutex sync.RWMutex + responseLastEventID chan string + ready int32 + liveQueue []*Update + liveMutex sync.RWMutex +} + +const outBufferLength = 1000 + +// NewLocalSubscriber creates a new subscriber. +func NewLocalSubscriber(lastEventID string, logger Logger, topicSelectorStore *TopicSelectorStore) *LocalSubscriber { + id := "urn:uuid:" + uuid.Must(uuid.NewV4()).String() + s := &LocalSubscriber{ + Subscriber: *NewSubscriber(logger, topicSelectorStore), + responseLastEventID: make(chan string, 1), + out: make(chan *Update, outBufferLength), + } + + s.ID = id + s.EscapedID = url.QueryEscape(id) + s.RequestLastEventID = lastEventID + + return s +} + +// Dispatch an update to the subscriber. +// Security checks must (topics matching) be done before calling Dispatch, +// for instance by calling Match. +func (s *LocalSubscriber) Dispatch(u *Update, fromHistory bool) bool { + if atomic.LoadInt32(&s.disconnected) > 0 { + return false + } + + if !fromHistory && atomic.LoadInt32(&s.ready) < 1 { + s.liveMutex.Lock() + if s.ready < 1 { + s.liveQueue = append(s.liveQueue, u) + s.liveMutex.Unlock() + + return true + } + s.liveMutex.Unlock() + } + + s.outMutex.Lock() + if atomic.LoadInt32(&s.disconnected) > 0 { + s.outMutex.Unlock() + + return false + } + + select { + case s.out <- u: + s.outMutex.Unlock() + default: + s.handleFullChan() + + return false + } + + return true +} + +// Ready flips the ready flag to true and flushes queued live updates returning number of events flushed. +func (s *LocalSubscriber) Ready() (n int) { + s.liveMutex.Lock() + s.outMutex.Lock() + + for _, u := range s.liveQueue { + select { + case s.out <- u: + n++ + default: + s.handleFullChan() + s.liveMutex.Unlock() + + return n + } + } + atomic.StoreInt32(&s.ready, 1) + + s.outMutex.Unlock() + s.liveMutex.Unlock() + + return n +} + +// Receive returns a chan when incoming updates are dispatched. +func (s *LocalSubscriber) Receive() <-chan *Update { + return s.out +} + +// HistoryDispatched must be called when all messages coming from the history have been dispatched. +func (s *LocalSubscriber) HistoryDispatched(responseLastEventID string) { + s.responseLastEventID <- responseLastEventID +} + +// Disconnect disconnects the subscriber. +func (s *LocalSubscriber) Disconnect() { + if atomic.LoadInt32(&s.disconnected) > 0 { + return + } + + s.outMutex.Lock() + defer s.outMutex.Unlock() + + atomic.StoreInt32(&s.disconnected, 1) + close(s.out) +} + +// handleFullChan disconnects the subscriber when the out channel is full. +func (s *LocalSubscriber) handleFullChan() { + atomic.StoreInt32(&s.disconnected, 1) + s.outMutex.Unlock() + + if c := s.logger.Check(zap.ErrorLevel, "subscriber unable to receive updates fast enough"); c != nil { + c.Write(zap.Object("subscriber", s)) + } +} diff --git a/metrics.go b/metrics.go index 8f4dc97e..e5398bfe 100644 --- a/metrics.go +++ b/metrics.go @@ -13,18 +13,18 @@ const metricsPath = "/metrics" type Metrics interface { // SubscriberConnected collects metrics about subscriber connections. - SubscriberConnected(s *Subscriber) + SubscriberConnected(s *LocalSubscriber) // SubscriberDisconnected collects metrics about subscriber disconnections. - SubscriberDisconnected(s *Subscriber) + SubscriberDisconnected(s *LocalSubscriber) // UpdatePublished collects metrics about update publications. UpdatePublished(u *Update) } type NopMetrics struct{} -func (NopMetrics) SubscriberConnected(_ *Subscriber) {} -func (NopMetrics) SubscriberDisconnected(_ *Subscriber) {} -func (NopMetrics) UpdatePublished(_ *Update) {} +func (NopMetrics) SubscriberConnected(_ *LocalSubscriber) {} +func (NopMetrics) SubscriberDisconnected(_ *LocalSubscriber) {} +func (NopMetrics) UpdatePublished(_ *Update) {} // PrometheusMetrics store Hub collected metrics. type PrometheusMetrics struct { @@ -84,12 +84,12 @@ func (m *PrometheusMetrics) Register(r *mux.Router) { r.Handle(metricsPath, promhttp.HandlerFor(m.registry.(*prometheus.Registry), promhttp.HandlerOpts{})).Methods(http.MethodGet) } -func (m *PrometheusMetrics) SubscriberConnected(_ *Subscriber) { +func (m *PrometheusMetrics) SubscriberConnected(_ *LocalSubscriber) { m.subscribersTotal.Inc() m.subscribers.Inc() } -func (m *PrometheusMetrics) SubscriberDisconnected(_ *Subscriber) { +func (m *PrometheusMetrics) SubscriberDisconnected(_ *LocalSubscriber) { m.subscribers.Dec() } diff --git a/metrics_test.go b/metrics_test.go index 88e82b8c..5af78744 100644 --- a/metrics_test.go +++ b/metrics_test.go @@ -15,12 +15,12 @@ func TestNumberOfRunningSubscribers(t *testing.T) { logger := zap.NewNop() tss := &TopicSelectorStore{} - s1 := NewSubscriber("", logger, tss) + s1 := NewLocalSubscriber("", logger, tss) s1.SetTopics([]string{"topic1", "topic2"}, nil) m.SubscriberConnected(s1) assertGaugeValue(t, 1.0, m.subscribers) - s2 := NewSubscriber("", logger, tss) + s2 := NewLocalSubscriber("", logger, tss) s2.SetTopics([]string{"topic2"}, nil) m.SubscriberConnected(s2) assertGaugeValue(t, 2.0, m.subscribers) @@ -38,12 +38,12 @@ func TestTotalNumberOfHandledSubscribers(t *testing.T) { logger := zap.NewNop() tss := &TopicSelectorStore{} - s1 := NewSubscriber("", logger, tss) + s1 := NewLocalSubscriber("", logger, tss) s1.SetTopics([]string{"topic1", "topic2"}, nil) m.SubscriberConnected(s1) assertCounterValue(t, 1.0, m.subscribersTotal) - s2 := NewSubscriber("", logger, tss) + s2 := NewLocalSubscriber("", logger, tss) s2.SetTopics([]string{"topic2"}, nil) m.SubscriberConnected(s2) assertCounterValue(t, 2.0, m.subscribersTotal) diff --git a/publish_test.go b/publish_test.go index 92dc2d16..b3b2164e 100644 --- a/publish_test.go +++ b/publish_test.go @@ -174,7 +174,7 @@ func TestPublishOK(t *testing.T) { hub := createDummy() topics := []string{"http://example.com/books/1"} - s := NewSubscriber("", zap.NewNop(), &TopicSelectorStore{}) + s := NewLocalSubscriber("", zap.NewNop(), &TopicSelectorStore{}) s.SetTopics(topics, topics) s.Claims = &claims{Mercure: mercureClaim{Subscribe: topics}} @@ -238,7 +238,7 @@ func TestPublishNoData(t *testing.T) { func TestPublishGenerateUUID(t *testing.T) { h := createDummy() - s := NewSubscriber("", zap.NewNop(), &TopicSelectorStore{}) + s := NewLocalSubscriber("", zap.NewNop(), &TopicSelectorStore{}) s.SetTopics([]string{"http://example.com/books/1"}, s.SubscribedTopics) require.NoError(t, h.transport.AddSubscriber(s)) diff --git a/subscribe.go b/subscribe.go index 293a8902..86499b61 100644 --- a/subscribe.go +++ b/subscribe.go @@ -17,7 +17,7 @@ type responseController struct { // writeDeadline is the JWT expiration date or time.Now() + hub.writeTimeout writeDeadline time.Time hub *Hub - subscriber *Subscriber + subscriber *LocalSubscriber } func (rc *responseController) setDispatchWriteDeadline() bool { @@ -73,13 +73,13 @@ func (rc *responseController) flush() bool { return true } -func (h *Hub) newResponseController(w http.ResponseWriter, s *Subscriber) *responseController { +func (h *Hub) newResponseController(w http.ResponseWriter, s *LocalSubscriber) *responseController { wd := h.getWriteDeadline(s) return &responseController{*http.NewResponseController(w), w, wd.Add(-h.dispatchTimeout), wd, h, s} // nolint:bodyclose } -func (h *Hub) getWriteDeadline(s *Subscriber) (deadline time.Time) { +func (h *Hub) getWriteDeadline(s *LocalSubscriber) (deadline time.Time) { if h.writeTimeout != 0 { deadline = time.Now().Add(h.writeTimeout) } @@ -155,8 +155,8 @@ func (h *Hub) SubscribeHandler(w http.ResponseWriter, r *http.Request) { } // registerSubscriber initializes the connection. -func (h *Hub) registerSubscriber(w http.ResponseWriter, r *http.Request) (*Subscriber, *responseController) { - s := NewSubscriber(retrieveLastEventID(r, h.opt, h.logger), h.logger, h.topicSelectorStore) +func (h *Hub) registerSubscriber(w http.ResponseWriter, r *http.Request) (*LocalSubscriber, *responseController) { + s := NewLocalSubscriber(retrieveLastEventID(r, h.opt, h.logger), h.logger, h.topicSelectorStore) s.RemoteAddr = r.RemoteAddr var privateTopics []string var claims *claims @@ -170,7 +170,7 @@ func (h *Hub) registerSubscriber(w http.ResponseWriter, r *http.Request) (*Subsc } if err != nil || (claims == nil && !h.anonymous) { http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) - if c := h.logger.Check(zap.DebugLevel, "Subscriber unauthorized"); c != nil { + if c := h.logger.Check(zap.DebugLevel, "LocalSubscriber unauthorized"); c != nil { c.Write(zap.Object("subscriber", s), zap.Error(err)) } @@ -215,7 +215,7 @@ func (h *Hub) registerSubscriber(w http.ResponseWriter, r *http.Request) (*Subsc } // sendHeaders sends correct HTTP headers to create a keep-alive connection. -func (h *Hub) sendHeaders(w http.ResponseWriter, s *Subscriber) { +func (h *Hub) sendHeaders(w http.ResponseWriter, s *LocalSubscriber) { // Keep alive, useful only for HTTP 1 clients https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Keep-Alive w.Header().Set("Connection", "keep-alive") @@ -287,7 +287,7 @@ func (h *Hub) write(rc *responseController, data string) bool { return rc.flush() && rc.setDefaultWriteDeadline() } -func (h *Hub) shutdown(s *Subscriber) { +func (h *Hub) shutdown(s *LocalSubscriber) { // Notify that the client is closing the connection s.Disconnect() if err := h.transport.RemoveSubscriber(s); err != nil { @@ -297,13 +297,13 @@ func (h *Hub) shutdown(s *Subscriber) { } h.dispatchSubscriptionUpdate(s, false) - if c := h.logger.Check(zap.InfoLevel, "Subscriber disconnected"); c != nil { + if c := h.logger.Check(zap.InfoLevel, "LocalSubscriber disconnected"); c != nil { c.Write(zap.Object("subscriber", s)) } h.metrics.SubscriberDisconnected(s) } -func (h *Hub) dispatchSubscriptionUpdate(s *Subscriber, active bool) { +func (h *Hub) dispatchSubscriptionUpdate(s *LocalSubscriber, active bool) { if !h.subscriptions { return } diff --git a/subscribe_test.go b/subscribe_test.go index f3550709..81f7be21 100644 --- a/subscribe_test.go +++ b/subscribe_test.go @@ -238,16 +238,16 @@ func (*addSubscriberErrorTransport) Dispatch(*Update) error { return nil } -func (*addSubscriberErrorTransport) AddSubscriber(*Subscriber) error { +func (*addSubscriberErrorTransport) AddSubscriber(*LocalSubscriber) error { return errFailedToAddSubscriber } -func (*addSubscriberErrorTransport) RemoveSubscriber(*Subscriber) error { +func (*addSubscriberErrorTransport) RemoveSubscriber(*LocalSubscriber) error { return nil } -func (*addSubscriberErrorTransport) GetSubscribers() (string, []*Subscriber, error) { - return "", []*Subscriber{}, nil +func (*addSubscriberErrorTransport) GetSubscribers() (string, []*LocalSubscriber, error) { + return "", []*LocalSubscriber{}, nil } func (*addSubscriberErrorTransport) Close() error { @@ -412,7 +412,7 @@ func TestUnsubscribe(t *testing.T) { req := httptest.NewRequest(http.MethodGet, defaultHubURL+"?topic=http://example.com/books/1", nil).WithContext(ctx) hub.SubscribeHandler(newSubscribeRecorder(), req) assert.Equal(t, 0, s.subscribers.Len()) - s.subscribers.Walk(0, func(s *Subscriber) bool { + s.subscribers.Walk(0, func(s *LocalSubscriber) bool { _, ok := <-s.out assert.False(t, ok) diff --git a/subscriber.go b/subscriber.go index 3e68ff0b..44212e15 100644 --- a/subscriber.go +++ b/subscriber.go @@ -4,15 +4,11 @@ import ( "fmt" "net/url" "regexp" - "sync" - "sync/atomic" - "github.com/gofrs/uuid" - "go.uber.org/zap" "go.uber.org/zap/zapcore" ) -// Subscriber represents a client subscribed to a list of topics. +// Subscriber represents a client subscribed to a list of topics on a remote or on the current hub. type Subscriber struct { ID string EscapedID string @@ -25,118 +21,15 @@ type Subscriber struct { AllowedPrivateTopics []string AllowedPrivateRegexps []*regexp.Regexp - disconnected int32 - out chan *Update - outMutex sync.RWMutex - responseLastEventID chan string - logger Logger - ready int32 - liveQueue []*Update - liveMutex sync.RWMutex - topicSelectorStore *TopicSelectorStore + logger Logger + topicSelectorStore *TopicSelectorStore } -const outBufferLength = 1000 - -// NewSubscriber creates a new subscriber. -func NewSubscriber(lastEventID string, logger Logger, topicSelectorStore *TopicSelectorStore) *Subscriber { - id := "urn:uuid:" + uuid.Must(uuid.NewV4()).String() - s := &Subscriber{ - ID: id, - EscapedID: url.QueryEscape(id), - RequestLastEventID: lastEventID, - responseLastEventID: make(chan string, 1), - out: make(chan *Update, outBufferLength), - logger: logger, - topicSelectorStore: topicSelectorStore, +func NewSubscriber(logger Logger, topicSelectorStore *TopicSelectorStore) *Subscriber { + return &Subscriber{ + logger: logger, + topicSelectorStore: topicSelectorStore, } - - return s -} - -// Dispatch an update to the subscriber. -// Security checks must (topics matching) be done before calling Dispatch, -// for instance by calling Match. -func (s *Subscriber) Dispatch(u *Update, fromHistory bool) bool { - if atomic.LoadInt32(&s.disconnected) > 0 { - return false - } - - if !fromHistory && atomic.LoadInt32(&s.ready) < 1 { - s.liveMutex.Lock() - if s.ready < 1 { - s.liveQueue = append(s.liveQueue, u) - s.liveMutex.Unlock() - - return true - } - s.liveMutex.Unlock() - } - - s.outMutex.Lock() - if atomic.LoadInt32(&s.disconnected) > 0 { - s.outMutex.Unlock() - - return false - } - - select { - case s.out <- u: - s.outMutex.Unlock() - default: - s.handleFullChan() - - return false - } - - return true -} - -// Ready flips the ready flag to true and flushes queued live updates returning number of events flushed. -func (s *Subscriber) Ready() (n int) { - s.liveMutex.Lock() - s.outMutex.Lock() - - for _, u := range s.liveQueue { - select { - case s.out <- u: - n++ - default: - s.handleFullChan() - s.liveMutex.Unlock() - - return n - } - } - atomic.StoreInt32(&s.ready, 1) - - s.outMutex.Unlock() - s.liveMutex.Unlock() - - return n -} - -// Receive returns a chan when incoming updates are dispatched. -func (s *Subscriber) Receive() <-chan *Update { - return s.out -} - -// HistoryDispatched must be called when all messages coming from the history have been dispatched. -func (s *Subscriber) HistoryDispatched(responseLastEventID string) { - s.responseLastEventID <- responseLastEventID -} - -// Disconnect disconnects the subscriber. -func (s *Subscriber) Disconnect() { - if atomic.LoadInt32(&s.disconnected) > 0 { - return - } - - s.outMutex.Lock() - defer s.outMutex.Unlock() - - atomic.StoreInt32(&s.disconnected, 1) - close(s.out) } // SetTopics compiles topic selector regexps. @@ -237,13 +130,3 @@ func (s *Subscriber) MarshalLogObject(enc zapcore.ObjectEncoder) error { return nil } - -// handleFullChan disconnects the subscriber when the out channel is full. -func (s *Subscriber) handleFullChan() { - atomic.StoreInt32(&s.disconnected, 1) - s.outMutex.Unlock() - - if c := s.logger.Check(zap.ErrorLevel, "subscriber unable to receive updates fast enough"); c != nil { - c.Write(zap.Object("subscriber", s)) - } -} diff --git a/subscriber_bench_test.go b/subscriber_bench_test.go index 3b5e5ba0..d78c5f54 100644 --- a/subscriber_bench_test.go +++ b/subscriber_bench_test.go @@ -85,7 +85,7 @@ func strInt(s string) int { func subBenchSubscriber(b *testing.B, topics, concurrency, matchPct int, testName string) { b.Helper() - s := NewSubscriber("0e249241-6432-4ce1-b9b9-5d170163c253", zap.NewNop(), &TopicSelectorStore{}) + s := NewLocalSubscriber("0e249241-6432-4ce1-b9b9-5d170163c253", zap.NewNop(), &TopicSelectorStore{}) ts := make([]string, topics) tsMatch := make([]string, topics) tsNoMatch := make([]string, topics) diff --git a/subscriber_test.go b/subscriber_test.go index 7c324b5f..b3fe3fff 100644 --- a/subscriber_test.go +++ b/subscriber_test.go @@ -9,7 +9,7 @@ import ( ) func TestDispatch(t *testing.T) { - s := NewSubscriber("1", zap.NewNop(), &TopicSelectorStore{}) + s := NewLocalSubscriber("1", zap.NewNop(), &TopicSelectorStore{}) s.SubscribedTopics = []string{"http://example.com"} s.SubscribedTopics = []string{"http://example.com"} defer s.Disconnect() @@ -32,7 +32,7 @@ func TestDispatch(t *testing.T) { } func TestDisconnect(t *testing.T) { - s := NewSubscriber("", zap.NewNop(), &TopicSelectorStore{}) + s := NewLocalSubscriber("", zap.NewNop(), &TopicSelectorStore{}) s.Disconnect() // can be called two times without crashing s.Disconnect() @@ -44,7 +44,7 @@ func TestLogSubscriber(t *testing.T) { sink, logger := newTestLogger(t) defer sink.Reset() - s := NewSubscriber("123", logger, &TopicSelectorStore{}) + s := NewLocalSubscriber("123", logger, &TopicSelectorStore{}) s.RemoteAddr = "127.0.0.1" s.SetTopics([]string{"https://example.com/bar"}, []string{"https://example.com/foo"}) @@ -59,7 +59,7 @@ func TestLogSubscriber(t *testing.T) { } func TestMatchTopic(t *testing.T) { - s := NewSubscriber("", zap.NewNop(), &TopicSelectorStore{}) + s := NewLocalSubscriber("", zap.NewNop(), &TopicSelectorStore{}) s.SetTopics([]string{"https://example.com/no-match", "https://example.com/books/{id}"}, []string{"https://example.com/users/foo/{?topic}"}) assert.False(t, s.Match(&Update{Topics: []string{"https://example.com/not-subscribed"}})) @@ -73,7 +73,7 @@ func TestMatchTopic(t *testing.T) { } func TestSubscriberDoesNotBlockWhenChanIsFull(t *testing.T) { - s := NewSubscriber("", zap.NewNop(), &TopicSelectorStore{}) + s := NewLocalSubscriber("", zap.NewNop(), &TopicSelectorStore{}) s.Ready() for i := 0; i <= outBufferLength; i++ { diff --git a/subscriberlist.go b/subscriberlist.go index 97221ab1..57ef57dc 100644 --- a/subscriberlist.go +++ b/subscriberlist.go @@ -26,7 +26,7 @@ var replacer = strings.NewReplacer( func NewSubscriberList(size int) *SubscriberList { return &SubscriberList{ skipfilter: skipfilter.New(func(s interface{}, filter interface{}) bool { - return s.(*Subscriber).MatchTopics(decode(filter.(string))) + return s.(*LocalSubscriber).MatchTopics(decode(filter.(string))) }, size), } } @@ -89,25 +89,25 @@ func decode(f string) (topics []string, private bool) { return topics, private } -func (sl *SubscriberList) MatchAny(u *Update) (res []*Subscriber) { +func (sl *SubscriberList) MatchAny(u *Update) (res []*LocalSubscriber) { for _, m := range sl.skipfilter.MatchAny(encode(u.Topics, u.Private)) { - res = append(res, m.(*Subscriber)) + res = append(res, m.(*LocalSubscriber)) } return } -func (sl *SubscriberList) Walk(start uint64, callback func(s *Subscriber) bool) uint64 { +func (sl *SubscriberList) Walk(start uint64, callback func(s *LocalSubscriber) bool) uint64 { return sl.skipfilter.Walk(start, func(val interface{}) bool { - return callback(val.(*Subscriber)) + return callback(val.(*LocalSubscriber)) }) } -func (sl *SubscriberList) Add(s *Subscriber) { +func (sl *SubscriberList) Add(s *LocalSubscriber) { sl.skipfilter.Add(s) } -func (sl *SubscriberList) Remove(s *Subscriber) { +func (sl *SubscriberList) Remove(s *LocalSubscriber) { sl.skipfilter.Remove(s) } diff --git a/subscriberlist_test.go b/subscriberlist_test.go index 0c25b5ed..5cd6a255 100644 --- a/subscriberlist_test.go +++ b/subscriberlist_test.go @@ -26,7 +26,7 @@ func BenchmarkSubscriberList(b *testing.B) { l := NewSubscriberList(100) for i := 0; i < 100; i++ { - s := NewSubscriber("", logger, tss) + s := NewLocalSubscriber("", logger, tss) t := fmt.Sprintf("https://example.com/%d", i%10) s.SetTopics([]string{"https://example.org/foo", t}, []string{"https://example.net/bar", t}) diff --git a/subscription_test.go b/subscription_test.go index 13363b8a..25334c01 100644 --- a/subscription_test.go +++ b/subscription_test.go @@ -87,11 +87,11 @@ func TestSubscriptionsHandler(t *testing.T) { hub := createDummy(WithLogger(logger)) tss := &TopicSelectorStore{} - s1 := NewSubscriber("", logger, tss) + s1 := NewLocalSubscriber("", logger, tss) s1.SetTopics([]string{"http://example.com/foo"}, nil) require.NoError(t, hub.transport.AddSubscriber(s1)) - s2 := NewSubscriber("", logger, tss) + s2 := NewLocalSubscriber("", logger, tss) s2.SetTopics([]string{"http://example.com/bar"}, nil) require.NoError(t, hub.transport.AddSubscriber(s2)) @@ -128,11 +128,11 @@ func TestSubscriptionsHandlerForTopic(t *testing.T) { hub := createDummy(WithLogger(logger)) tss := &TopicSelectorStore{} - s1 := NewSubscriber("", logger, tss) + s1 := NewLocalSubscriber("", logger, tss) s1.SetTopics([]string{"http://example.com/foo"}, nil) require.NoError(t, hub.transport.AddSubscriber(s1)) - s2 := NewSubscriber("", logger, tss) + s2 := NewLocalSubscriber("", logger, tss) s2.SetTopics([]string{"http://example.com/bar"}, nil) require.NoError(t, hub.transport.AddSubscriber(s2)) @@ -175,11 +175,11 @@ func TestSubscriptionHandler(t *testing.T) { hub := createDummy(WithLogger(logger)) tss := &TopicSelectorStore{} - otherS := NewSubscriber("", logger, tss) + otherS := NewLocalSubscriber("", logger, tss) otherS.SetTopics([]string{"http://example.com/other"}, nil) require.NoError(t, hub.transport.AddSubscriber(otherS)) - s := NewSubscriber("", logger, tss) + s := NewLocalSubscriber("", logger, tss) s.SetTopics([]string{"http://example.com/other", "http://example.com/{foo}"}, nil) require.NoError(t, hub.transport.AddSubscriber(s)) diff --git a/transport.go b/transport.go index 1bb4c477..41c3d4ca 100644 --- a/transport.go +++ b/transport.go @@ -45,10 +45,10 @@ type Transport interface { Dispatch(update *Update) error // AddSubscriber adds a new subscriber to the transport. - AddSubscriber(s *Subscriber) error + AddSubscriber(s *LocalSubscriber) error // RemoveSubscriber removes a subscriber from the transport. - RemoveSubscriber(s *Subscriber) error + RemoveSubscriber(s *LocalSubscriber) error // Close closes the Transport. Close() error @@ -96,8 +96,8 @@ func (e *TransportError) Unwrap() error { } func getSubscribers(sl *SubscriberList) (subscribers []*Subscriber) { - sl.Walk(0, func(s *Subscriber) bool { - subscribers = append(subscribers, s) + sl.Walk(0, func(s *LocalSubscriber) bool { + subscribers = append(subscribers, &s.Subscriber) return true })