diff --git a/event/bus.go b/event/bus.go index ac0334f4..d038cea0 100644 --- a/event/bus.go +++ b/event/bus.go @@ -6,13 +6,13 @@ import "context" // Bus is the pub-sub client for Events. type Bus interface { - // Publish should send the Event to subscribers who subscribed to Events - // whose name is evt.Name(). - Publish(ctx context.Context, evt Event) error + // Publish should send each Event evt in events to subscribers who + // subscribed to Events with a name of evt.Name(). + Publish(ctx context.Context, events ...Event) error // Subscribe returns a channel of Events. For every published Event evt // where evt.Name() is one of names, that Event should be received from the - // returned Events channel. When ctx is canceled, events channel should be + // returned Event channel. When ctx is canceled, events channel should be // closed by the implementing Bus. Subscribe(ctx context.Context, names ...string) (<-chan Event, error) } diff --git a/event/eventbus/chanbus/chan.go b/event/eventbus/chanbus/bus.go similarity index 50% rename from event/eventbus/chanbus/chan.go rename to event/eventbus/chanbus/bus.go index 6e886210..3ca4fef1 100644 --- a/event/eventbus/chanbus/chan.go +++ b/event/eventbus/chanbus/bus.go @@ -9,51 +9,71 @@ import ( type eventBus struct { mux sync.RWMutex - subs map[string]map[chan event.Event]bool + subs map[string]map[subscriber]bool queue chan event.Event } +type subscriber struct { + ctx context.Context + events chan event.Event +} + // New returns a Bus that communicates over channels. func New() event.Bus { bus := &eventBus{ - subs: make(map[string]map[chan event.Event]bool), + subs: make(map[string]map[subscriber]bool), queue: make(chan event.Event), } go bus.run() return bus } -// Publish sends the Event to the Event channels that have been returned by -// previous calls to bus.Subscribe() with evt.Name() as the Event name. If ctx -// is canceled, ctx.Err() is returned. -func (bus *eventBus) Publish(ctx context.Context, evt event.Event) error { - select { - case <-ctx.Done(): - return ctx.Err() - case bus.queue <- evt: - return nil +// Publish sends events to the channels that have been returned by previous +// calls to bus.Subscribe() where the subscribed Event name matches the +// evt.Name() for an Event in events. If ctx is canceled before every Event has +// been queued, ctx.Err() is returned. +func (bus *eventBus) Publish(ctx context.Context, events ...event.Event) error { + for _, evt := range events { + select { + case <-ctx.Done(): + return ctx.Err() + case bus.queue <- evt: + } } + return nil } // Subscribe returns a channel of Events. For every published Event evt where // evt.Name() is one of names, that Event will be received from the returned -// Events channel. When ctx is canceled, events will be closed. +// Events channel. When ctx is canceled, events won't accept any new Events and +// will be closed. func (bus *eventBus) Subscribe(ctx context.Context, names ...string) (<-chan event.Event, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + events := make(chan event.Event, 1) + sub := subscriber{ + ctx: ctx, + events: events, + } + bus.mux.Lock() defer bus.mux.Unlock() for _, name := range names { if bus.subs[name] == nil { - bus.subs[name] = make(map[chan event.Event]bool) + bus.subs[name] = make(map[subscriber]bool) } - bus.subs[name][events] = true + bus.subs[name][sub] = true } // unsubscribe when ctx canceled go func() { defer close(events) <-ctx.Done() - bus.unsubscribe(events, names...) + bus.unsubscribe(sub, names...) }() return events, nil @@ -67,23 +87,27 @@ func (bus *eventBus) run() { func (bus *eventBus) publish(evt event.Event) { subs := bus.subscribers(evt.Name()) - for _, events := range subs { - events := events - go func() { events <- evt }() + for _, sub := range subs { + go func(sub subscriber) { + select { + case <-sub.ctx.Done(): + case sub.events <- evt: + } + }(sub) } } -func (bus *eventBus) subscribers(name string) []chan event.Event { +func (bus *eventBus) subscribers(name string) []subscriber { bus.mux.RLock() defer bus.mux.RUnlock() - subs := make([]chan event.Event, 0, len(bus.subs[name])) + subs := make([]subscriber, 0, len(bus.subs[name])) for sub := range bus.subs[name] { subs = append(subs, sub) } return subs } -func (bus *eventBus) unsubscribe(sub chan event.Event, names ...string) { +func (bus *eventBus) unsubscribe(sub subscriber, names ...string) { bus.mux.Lock() defer bus.mux.Unlock() for _, name := range names { diff --git a/event/eventbus/chanbus/chan_test.go b/event/eventbus/chanbus/bus_test.go similarity index 56% rename from event/eventbus/chanbus/chan_test.go rename to event/eventbus/chanbus/bus_test.go index 13d3a1f0..7849c0bd 100644 --- a/event/eventbus/chanbus/chan_test.go +++ b/event/eventbus/chanbus/bus_test.go @@ -2,7 +2,9 @@ package chanbus_test import ( "context" + "errors" "fmt" + "reflect" "testing" "time" @@ -102,18 +104,6 @@ func TestEventBus_Subscribe_multipleNames(t *testing.T) { } } -func expectEvent(name string, events <-chan event.Event) error { - select { - case <-time.After(100 * time.Millisecond): - return fmt.Errorf(`didn't receive "%s" event after 100ms`, name) - case evt := <-events: - if evt.Name() != name { - return fmt.Errorf(`expected "%s" event; got "%s"`, name, evt.Name()) - } - } - return nil -} - func TestEventBus_Subscribe_cancel(t *testing.T) { bus := chanbus.New() @@ -135,11 +125,119 @@ func TestEventBus_Subscribe_cancel(t *testing.T) { // events should be closed select { - case _, ok := <-events: + case evt, ok := <-events: if ok { - t.Fatal("events channel should be closed") + t.Fatal(fmt.Errorf("event channel should be closed; got %v", evt)) } case <-time.After(10 * time.Millisecond): t.Fatal("didn't receive from events channel after 10ms") } } + +func TestEventBus_Subscribe_canceledContext(t *testing.T) { + bus := chanbus.New() + + // given a canceled context + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + // when subscribing to "foo" events + events, err := bus.Subscribe(ctx, "foo") + + // it should fail with context.Canceled + if !errors.Is(err, context.Canceled) { + t.Error(fmt.Errorf("err should be context.Canceled; got %v", err)) + } + + // events should be nil + if events != nil { + t.Error(fmt.Errorf("events should be nil")) + } +} + +func TestEventBus_Publish_multipleEvents(t *testing.T) { + foo := event.New("foo", eventData{A: "foo"}) + bar := event.New("bar", eventData{A: "bar"}) + baz := event.New("baz", eventData{A: "baz"}) + + tests := []struct { + name string + subscribe []string + publish []event.Event + want []event.Event + }{ + { + name: "subscribed to 1 event", + subscribe: []string{"foo"}, + publish: []event.Event{foo, bar}, + want: []event.Event{foo}, + }, + { + name: "subscribed to all events", + subscribe: []string{"foo", "bar"}, + publish: []event.Event{foo, bar}, + want: []event.Event{foo, bar}, + }, + { + name: "subscribed to even more events", + subscribe: []string{"foo", "bar", "baz"}, + publish: []event.Event{foo, bar}, + want: []event.Event{foo, bar}, + }, + { + name: "subscribed to no events", + subscribe: nil, + publish: []event.Event{foo, bar, baz}, + want: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + bus := chanbus.New() + ctx := context.Background() + + events, err := bus.Subscribe(ctx, tt.subscribe...) + if err != nil { + t.Fatal(fmt.Errorf("subscribe to %v: %w", tt.subscribe, err)) + } + + if err = bus.Publish(ctx, tt.publish...); err != nil { + t.Fatal(fmt.Errorf("publish: %w", err)) + } + + var received []event.Event + for len(received) < len(tt.want) { + select { + case <-time.After(100 * time.Millisecond): + t.Fatal(fmt.Errorf("didn't receive event after 100ms")) + case evt := <-events: + received = append(received, evt) + } + } + + // check that events channel has no extra events + select { + case evt := <-events: + t.Fatal(fmt.Errorf("shouldn't have received another event; got %v", evt)) + default: + } + + if !reflect.DeepEqual(received, tt.want) { + t.Fatal(fmt.Errorf("expected events %v; got %v", tt.want, received)) + } + }) + } +} + +func expectEvent(name string, events <-chan event.Event) error { + select { + case <-time.After(100 * time.Millisecond): + return fmt.Errorf(`didn't receive "%s" event after 100ms`, name) + case evt := <-events: + if evt.Name() != name { + return fmt.Errorf(`expected "%s" event; got "%s"`, name, evt.Name()) + } + } + return nil +} diff --git a/event/mocks/bus.go b/event/mocks/bus.go index df16bc12..6061c3bc 100644 --- a/event/mocks/bus.go +++ b/event/mocks/bus.go @@ -35,17 +35,22 @@ func (m *MockBus) EXPECT() *MockBusMockRecorder { } // Publish mocks base method -func (m *MockBus) Publish(ctx context.Context, evt event.Event) error { +func (m *MockBus) Publish(ctx context.Context, events ...event.Event) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Publish", ctx, evt) + varargs := []interface{}{ctx} + for _, a := range events { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Publish", varargs...) ret0, _ := ret[0].(error) return ret0 } // Publish indicates an expected call of Publish -func (mr *MockBusMockRecorder) Publish(ctx, evt interface{}) *gomock.Call { +func (mr *MockBusMockRecorder) Publish(ctx interface{}, events ...interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Publish", reflect.TypeOf((*MockBus)(nil).Publish), ctx, evt) + varargs := append([]interface{}{ctx}, events...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Publish", reflect.TypeOf((*MockBus)(nil).Publish), varargs...) } // Subscribe mocks base method