Backport of events: Remove subscriptions on timeout and cancel into release/1.13.x #19204

Merged 2 commits into release/1.13.x on Feb 15, 2023
64 changes: 58 additions & 6 deletions vault/eventbus/bus.go
@@ -5,9 +5,11 @@ import (
"errors"
"net/url"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/armon/go-metrics"
"github.com/hashicorp/eventlogger"
"github.com/hashicorp/eventlogger/formatter_filters/cloudevents"
"github.com/hashicorp/go-hclog"
@@ -17,9 +19,13 @@ import (
"google.golang.org/protobuf/types/known/timestamppb"
)

-var ErrNotStarted = errors.New("event broker has not been started")
+const defaultTimeout = 60 * time.Second

-var cloudEventsFormatterFilter *cloudevents.FormatterFilter
+var (
+	ErrNotStarted = errors.New("event broker has not been started")
+	cloudEventsFormatterFilter *cloudevents.FormatterFilter
+	subscriptions atomic.Int64 // keeps track of event subscription count in all event buses
+)

// EventBus contains the main logic of running an event broker for Vault.
// Start() must be called before the EventBus will accept events for sending.
@@ -28,6 +34,7 @@ type EventBus struct {
broker *eventlogger.Broker
started atomic.Bool
formatterNodeID eventlogger.NodeID
timeout time.Duration
}

type pluginEventBus struct {
@@ -42,6 +49,13 @@ type asyncChanNode struct {
ch chan *logical.EventReceived
namespace *namespace.Namespace
logger hclog.Logger

// used to close the connection
closeOnce sync.Once
cancelFunc context.CancelFunc
pipelineID eventlogger.PipelineID
eventType eventlogger.EventType
broker *eventlogger.Broker
}

var (
@@ -79,6 +93,10 @@ func (bus *EventBus) SendInternal(ctx context.Context, ns *namespace.Namespace,
Timestamp: timestamppb.New(time.Now()),
}
bus.logger.Info("Sending event", "event", eventReceived)

// We can't easily know when the Send is complete, so we can't call the cancel function.
// But it is called automatically after bus.timeout, so there won't be a leak as long as bus.timeout is not too long.
ctx, _ = context.WithTimeout(ctx, bus.timeout)
_, err := bus.broker.Send(ctx, eventlogger.EventType(eventType), eventReceived)
if err != nil {
// if no listeners for this event type are registered, that's okay, the event
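Not part of the diff: a minimal sketch of the context behavior the comment above relies on — a context created with context.WithTimeout cancels itself once the deadline passes, so discarding the CancelFunc only postpones cleanup until bus.timeout at worst.

package main

import (
	"context"
	"fmt"
	"time"
)

func main() {
	// Discard the CancelFunc, as SendInternal does; the context still
	// cancels itself when the 50ms deadline fires.
	ctx, _ := context.WithTimeout(context.Background(), 50*time.Millisecond)

	<-ctx.Done()           // unblocks at the deadline, with no explicit cancel
	fmt.Println(ctx.Err()) // context deadline exceeded
}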
@@ -142,6 +160,7 @@ func NewEventBus(logger hclog.Logger) (*EventBus, error) {
logger: logger,
broker: broker,
formatterNodeID: formatterNodeID,
timeout: defaultTimeout,
}, nil
}

@@ -178,7 +197,18 @@ func (bus *EventBus) Subscribe(ctx context.Context, ns *namespace.Namespace, eve
defer cancel()
return nil, nil, err
}
-return asyncNode.ch, cancel, nil
+addSubscriptions(1)
+// add info needed to cancel the subscription
+asyncNode.pipelineID = eventlogger.PipelineID(pipelineID)
+asyncNode.eventType = eventlogger.EventType(eventType)
+asyncNode.cancelFunc = cancel
+return asyncNode.ch, asyncNode.Close, nil
}

// SetSendTimeout sets the timeout for sending events. If an event is not accepted by the
// underlying channel before this timeout, then the channel is closed.
func (bus *EventBus) SetSendTimeout(timeout time.Duration) {
bus.timeout = timeout
}
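A hypothetical caller of the new API (the event type name, timeout value, and exampleSubscribe wrapper are made up for illustration), assuming the eventbus package's usual imports:

func exampleSubscribe() error {
	bus, err := NewEventBus(nil)
	if err != nil {
		return err
	}
	bus.SetSendTimeout(100 * time.Millisecond) // sends blocking longer than this close the subscription
	bus.Start()

	ch, cancel, err := bus.Subscribe(context.Background(), namespace.RootNamespace, logical.EventType("example"))
	if err != nil {
		return err
	}
	defer cancel() // explicit unsubscribe; safe even if a send timeout already closed it

	// Receive events as they arrive; if a receive lags a send by more than
	// 100ms, the bus drops this subscription instead of blocking the sender.
	event := <-ch
	_ = event
	return nil
}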

func newAsyncNode(ctx context.Context, namespace *namespace.Namespace, logger hclog.Logger) *asyncChanNode {
@@ -190,8 +220,21 @@ func newAsyncNode(ctx context.Context, namespace *namespace.Namespace, logger hc
}
}

// Close tells the bus to stop sending us events.
func (node *asyncChanNode) Close() {
node.closeOnce.Do(func() {
defer node.cancelFunc()
if node.broker != nil {
err := node.broker.RemovePipeline(node.eventType, node.pipelineID)
if err != nil {
node.logger.Warn("Error removing pipeline for closing node", "error", err)
}
}
addSubscriptions(-1)
})
}
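Close can be invoked twice concurrently — by the subscriber's cancel function and by the timeout path in Process below — and the sync.Once guard makes that race harmless. A stripped-down sketch of the pattern, separate from the Vault code:

package main

import (
	"fmt"
	"sync"
)

type conn struct {
	closeOnce sync.Once
}

func (c *conn) Close() {
	c.closeOnce.Do(func() {
		fmt.Println("released (runs exactly once)")
	})
}

func main() {
	c := &conn{}
	c.Close() // explicit cancel by the subscriber
	c.Close() // concurrent timeout path: a no-op the second time
}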

func (node *asyncChanNode) Process(ctx context.Context, e *eventlogger.Event) (*eventlogger.Event, error) {
-// TODO: add timeout on sending to node.ch
// sends to the channel async in another goroutine
go func() {
eventRecv := e.Payload.(*logical.EventReceived)
@@ -200,12 +243,17 @@ func (node *asyncChanNode) Process(ctx context.Context, e *eventlogger.Event) (*
if eventRecv.Namespace != node.namespace.Path {
return
}
var timeout bool
select {
case node.ch <- eventRecv:
case <-ctx.Done():
-return
+timeout = errors.Is(ctx.Err(), context.DeadlineExceeded)
case <-node.ctx.Done():
-return
+timeout = errors.Is(node.ctx.Err(), context.DeadlineExceeded)
}
if timeout {
node.logger.Info("Subscriber took too long to process event, closing", "ID", eventRecv.Event.ID())
node.Close()
}
}()
return e, nil
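Only a missed deadline, not an explicit cancel, should close the subscription; a small standalone sketch of how errors.Is separates the two cases:

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

func main() {
	// Deadline case: ctx.Err() is context.DeadlineExceeded.
	d, cancelD := context.WithTimeout(context.Background(), time.Millisecond)
	defer cancelD()
	<-d.Done()
	fmt.Println(errors.Is(d.Err(), context.DeadlineExceeded)) // true

	// Explicit cancel: ctx.Err() is context.Canceled, so timeout stays false.
	c, cancelC := context.WithCancel(context.Background())
	cancelC()
	<-c.Done()
	fmt.Println(errors.Is(c.Err(), context.DeadlineExceeded)) // false
}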
Expand All @@ -218,3 +266,7 @@ func (node *asyncChanNode) Reopen() error {
func (node *asyncChanNode) Type() eventlogger.NodeType {
return eventlogger.NodeTypeSink
}

func addSubscriptions(delta int64) {
metrics.SetGauge([]string{"events", "subscriptions"}, float32(subscriptions.Add(delta)))
}
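The counter/gauge pairing, sketched in isolation: an atomic counter holds the authoritative count and every change republishes it as a gauge. This assumes go-metrics' in-memory sink and an arbitrary service name:

package main

import (
	"sync/atomic"
	"time"

	"github.com/armon/go-metrics"
)

var subs atomic.Int64

// add mirrors addSubscriptions: bump the counter, then publish the
// new value as a gauge.
func add(delta int64) {
	metrics.SetGauge([]string{"events", "subscriptions"}, float32(subs.Add(delta)))
}

func main() {
	// An in-memory sink is enough to exercise the gauge locally.
	sink := metrics.NewInmemSink(10*time.Second, time.Minute)
	metrics.NewGlobal(metrics.DefaultConfig("example"), sink)

	add(1)  // subscribe
	add(-1) // unsubscribe
}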
117 changes: 117 additions & 0 deletions vault/eventbus/bus_test.go
@@ -2,13 +2,16 @@ package eventbus

import (
"context"
"fmt"
"sync/atomic"
"testing"
"time"

"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/sdk/logical"
)

// TestBusBasics tests that basic event sending and subscribing function.
func TestBusBasics(t *testing.T) {
bus, err := NewEventBus(nil)
if err != nil {
@@ -62,6 +65,7 @@ }
}
}

// TestNamespaceFiltering verifies that events for other namespaces are filtered out by the bus.
func TestNamespaceFiltering(t *testing.T) {
bus, err := NewEventBus(nil)
if err != nil {
@@ -121,6 +125,7 @@ func TestNamespaceFiltering(t *testing.T) {
}
}

// TestBus2Subscriptions verifies that events of different types are successfully routed to the correct subscribers.
func TestBus2Subscriptions(t *testing.T) {
bus, err := NewEventBus(nil)
if err != nil {
@@ -180,3 +185,115 @@ func TestBus2Subscriptions(t *testing.T) {
t.Error("Timeout waiting for event2")
}
}

// TestBusSubscriptionsCancel verifies that canceled subscriptions are cleaned up.
func TestBusSubscriptionsCancel(t *testing.T) {
testCases := []struct {
cancel bool
}{
{cancel: true},
{cancel: false},
}

for _, tc := range testCases {
t.Run(fmt.Sprintf("cancel=%v", tc.cancel), func(t *testing.T) {
subscriptions.Store(0)
bus, err := NewEventBus(nil)
if err != nil {
t.Fatal(err)
}
ctx := context.Background()
if !tc.cancel {
// set the timeout very short to make the test faster if we aren't canceling explicitly
bus.SetSendTimeout(100 * time.Millisecond)
}
bus.Start()

// create and stop a bunch of subscriptions
const create = 100
const stop = 50

eventType := logical.EventType("someType")

var channels []<-chan *logical.EventReceived
var cancels []context.CancelFunc
stopped := atomic.Int32{}

received := atomic.Int32{}

for i := 0; i < create; i++ {
ch, cancelFunc, err := bus.Subscribe(ctx, namespace.RootNamespace, eventType)
if err != nil {
t.Fatal(err)
}
t.Cleanup(cancelFunc)
channels = append(channels, ch)
cancels = append(cancels, cancelFunc)

go func(i int32) {
<-ch // always receive one message
received.Add(1)
// continue receiving messages as long as we are not stopped
for i < int32(stop) {
<-ch
received.Add(1)
}
if tc.cancel {
cancelFunc() // stop explicitly to unsubscribe
}
stopped.Add(1)
}(int32(i))
}

// check that all channels receive a message
event, err := logical.NewEvent()
if err != nil {
t.Fatal(err)
}
err = bus.SendInternal(ctx, namespace.RootNamespace, nil, eventType, event)
if err != nil {
t.Error(err)
}
waitFor(t, 1*time.Second, func() bool { return received.Load() == int32(create) })
waitFor(t, 1*time.Second, func() bool { return stopped.Load() == int32(stop) })

// send another message, but half should stop receiving
event, err = logical.NewEvent()
if err != nil {
t.Fatal(err)
}
err = bus.SendInternal(ctx, namespace.RootNamespace, nil, eventType, event)
if err != nil {
t.Error(err)
}
waitFor(t, 1*time.Second, func() bool { return received.Load() == int32(create*2-stop) })
// the sends should time out and the subscriptions should drop when cancelFunc is called or the context cancels
waitFor(t, 1*time.Second, func() bool { return subscriptions.Load() == int64(create-stop) })
})
}
}

// waitFor waits for a condition to be true, up to the maximum timeout.
// It waits with a capped exponential backoff starting at 1ms.
// It is guaranteed to try f() at least once.
func waitFor(t *testing.T, maxWait time.Duration, f func() bool) {
t.Helper()
start := time.Now()

if f() {
return
}
sleepAmount := 1 * time.Millisecond
for time.Now().Sub(start) <= maxWait {
left := time.Now().Sub(start)
sleepAmount = sleepAmount * 2
if sleepAmount > left {
sleepAmount = left
}
time.Sleep(sleepAmount)
if f() {
return
}
}
t.Error("Timeout waiting for condition")
}
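A hypothetical test (not in the PR) showing the intended use of waitFor: poll a condition another goroutine eventually satisfies, failing only after the full budget elapses.

func TestWaitForExample(t *testing.T) {
	var done atomic.Int32
	go func() {
		time.Sleep(10 * time.Millisecond)
		done.Store(1) // flips the condition waitFor polls
	}()
	waitFor(t, 1*time.Second, func() bool { return done.Load() == 1 })
}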