Skip to content

Commit

Permalink
Add basic plugin system for clients
Browse files Browse the repository at this point in the history
Here, follow up #429 to add a basic plugin system for River clients
which allows a driver to add maintenance and non-maintenance services to
a client before it starts up. The plugin interface is implemented by the
drivers themselves, and looks like this:

    type driverPlugin[TTx any] interface {
        // PluginInit initializes a plugin with an archetype and client. It's
        // invoked on Client.NewClient.
        PluginInit(archetype *baseservice.Archetype, client *Client[TTx])

        // PluginMaintenanceServices returns additional maintenance services (will
        // only run on an elected leader) for a River client.
        PluginMaintenanceServices() []startstop.Service

        // PluginServices returns additional non-maintenance services (will run on
        // all clients) for a River client.
        PluginServices() []startstop.Service
    }

The change is fairly straightforward, and we make sure to bring in some
test cases verifying the plugin services were indeed added correctly.
  • Loading branch information
brandur committed Jul 8, 2024
1 parent 4f32c24 commit ef13ffc
Show file tree
Hide file tree
Showing 7 changed files with 193 additions and 47 deletions.
18 changes: 5 additions & 13 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -497,9 +497,9 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client
client.baseService.Name = "Client" // Have to correct the name because base service isn't embedded like it usually is
client.insertNotifyLimiter = notifylimiter.NewLimiter(archetype, config.FetchCooldown)

var plugin rivertype.Plugin
if p, ok := driver.(Pluginable[TTx]); ok {
plugin = p.Plugin(client, config.Logger)
plugin, _ := driver.(driverPlugin[TTx])
if plugin != nil {
plugin.PluginInit(archetype, client)
}

// There are a number of internal components that are only needed/desired if
Expand Down Expand Up @@ -542,11 +542,7 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client
startstop.StartStopFunc(client.handleLeadershipChangeLoop))

if plugin != nil {
// TODO: can just use rivertype.Service for client.Services to avoid this map:
pluginServices := sliceutil.Map(plugin.Services(), func(s rivertype.Service) startstop.Service {
return startstop.Service(s)
})
client.services = append(client.services, pluginServices...)
client.services = append(client.services, plugin.PluginServices()...)
}

//
Expand Down Expand Up @@ -621,11 +617,7 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client
}

if plugin != nil {
// TODO: can just use rivertype.Service for NewQueueMaintainer services to avoid this map:
pluginServices := sliceutil.Map(plugin.MaintenanceServices(), func(s rivertype.Service) startstop.Service {
return startstop.Service(s)
})
maintenanceServices = append(maintenanceServices, pluginServices...)
maintenanceServices = append(maintenanceServices, plugin.PluginMaintenanceServices()...)
}

// Not added to the main services list because the queue maintainer is
Expand Down
24 changes: 19 additions & 5 deletions plugin.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,25 @@
package river

import (
"log/slog"

"github.com/riverqueue/river/rivertype"
"github.com/riverqueue/river/rivershared/baseservice"
"github.com/riverqueue/river/rivershared/startstop"
)

type Pluginable[TTx any] interface {
Plugin(client *Client[TTx], logger *slog.Logger) rivertype.Plugin
// A plugin API that drivers may implement to extend a River client. Driver
// plugins may, for example, add additional maintenance services.
//
// This should be considered a River internal API and its stability is not
// guaranteed. DO NOT USE.
type driverPlugin[TTx any] interface {
// PluginInit initializes a plugin with an archetype and client. It's
// invoked on Client.NewClient.
PluginInit(archetype *baseservice.Archetype, client *Client[TTx])

// PluginMaintenanceServices returns additional maintenance services (will
// only run on an elected leader) for a River client.
PluginMaintenanceServices() []startstop.Service

// PluginServices returns additional non-maintenance services (will run on
// all clients) for a River client.
PluginServices() []startstop.Service
}
110 changes: 110 additions & 0 deletions plugin_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
package river

import (
"context"
"testing"

"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/stretchr/testify/require"

"github.com/riverqueue/river/internal/riverinternaltest"
"github.com/riverqueue/river/riverdriver/riverpgxv5"
"github.com/riverqueue/river/rivershared/baseservice"
"github.com/riverqueue/river/rivershared/riversharedtest"
"github.com/riverqueue/river/rivershared/startstop"
)

func TestClientDriverPlugin(t *testing.T) {
t.Parallel()

ctx := context.Background()

type testBundle struct {
pluginDriver *TestDriverWithPlugin
}

setup := func(t *testing.T) (*Client[pgx.Tx], *testBundle) {
t.Helper()

pluginDriver := newDriverWithPlugin(t, riverinternaltest.TestDB(ctx, t))

client, err := NewClient(pluginDriver, newTestConfig(t, nil))
require.NoError(t, err)

return client, &testBundle{
pluginDriver: pluginDriver,
}
}

t.Run("ServicesStart", func(t *testing.T) {
t.Parallel()

client, bundle := setup(t)

startClient(ctx, t, client)

riversharedtest.WaitOrTimeout(t, startstop.WaitAllStartedC(
bundle.pluginDriver.maintenanceService,
bundle.pluginDriver.service,
))
})
}

var _ driverPlugin[pgx.Tx] = &TestDriverWithPlugin{}

type TestDriverWithPlugin struct {
*riverpgxv5.Driver
initCalled bool
maintenanceService startstop.Service
service startstop.Service
}

func newDriverWithPlugin(t *testing.T, dbPool *pgxpool.Pool) *TestDriverWithPlugin {
t.Helper()

newService := func(name string) startstop.Service {
return startstop.StartStopFunc(func(ctx context.Context, shouldStart bool, started, stopped func()) error {
if !shouldStart {
return nil
}

go func() {
started()
defer stopped() // this defer should come first so it's last out

t.Logf("Test service started: %s", name)

<-ctx.Done()
}()

return nil
})
}

return &TestDriverWithPlugin{
Driver: riverpgxv5.New(dbPool),
maintenanceService: newService("maintenance service"),
service: newService("other service"),
}
}

func (d *TestDriverWithPlugin) PluginInit(archetype *baseservice.Archetype, client *Client[pgx.Tx]) {
d.initCalled = true
}

func (d *TestDriverWithPlugin) PluginMaintenanceServices() []startstop.Service {
if !d.initCalled {
panic("expected PluginInit to be called before this function")
}

return []startstop.Service{d.maintenanceService}
}

func (d *TestDriverWithPlugin) PluginServices() []startstop.Service {
if !d.initCalled {
panic("expected PluginInit to be called before this function")
}

return []startstop.Service{d.service}
}
11 changes: 11 additions & 0 deletions rivershared/startstop/main_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package startstop

import (
"testing"

"github.com/riverqueue/river/rivershared/riversharedtest"
)

func TestMain(m *testing.M) {
riversharedtest.WrapTestMain(m)
}
18 changes: 16 additions & 2 deletions rivershared/startstop/start_stop.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,17 @@ func (s *BaseStartStop) StartInit(ctx context.Context) (context.Context, bool, f
s.mu.Lock()
defer s.mu.Unlock()

if s.started != nil {
// Used stopped rather than started to track started state because the
// started channel may be preallocated by a call to Started.
if s.stopped != nil {
return ctx, false, nil, nil
}

s.started = make(chan struct{})
// Only allocate a started channel if one was preallocated by Started.
if s.started == nil {
s.started = make(chan struct{})
}

s.stopped = make(chan struct{})
ctx, s.cancelFunc = context.WithCancelCause(ctx)

Expand All @@ -133,6 +139,14 @@ func (s *BaseStartStop) Started() <-chan struct{} {
s.mu.Lock()
defer s.mu.Unlock()

// If the call to Started is before the service was actually started,
// preallocate the started channel so that regardless of whether the wait
// started before or after the service started, it will still do the right
// thing.
if s.started == nil {
s.started = make(chan struct{})
}

return s.started
}

Expand Down
32 changes: 32 additions & 0 deletions rivershared/startstop/start_stop_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ func testService(t *testing.T, newService func(t *testing.T) serviceWithStopped)
service, _ := setup(t)

require.NoError(t, service.Start(ctx))
t.Cleanup(service.Stop)

riversharedtest.WaitOrTimeout(t, service.Started())
})
Expand Down Expand Up @@ -134,6 +135,20 @@ func testService(t *testing.T, newService func(t *testing.T) serviceWithStopped)

wg.Wait()
})

t.Run("StartedPreallocated", func(t *testing.T) {
t.Parallel()

service, _ := setup(t)

// Make sure we get the start channel before the service is started.
started := service.Started()

require.NoError(t, service.Start(ctx))
t.Cleanup(service.Stop)

riversharedtest.WaitOrTimeout(t, started)
})
}

func TestBaseStartStop(t *testing.T) {
Expand Down Expand Up @@ -223,6 +238,7 @@ func TestSampleService(t *testing.T) {
service, _ := setup(t)

require.NoError(t, service.Start(ctx))
t.Cleanup(service.Stop)

riversharedtest.WaitOrTimeout(t, service.Started())
require.True(t, service.state)
Expand Down Expand Up @@ -398,6 +414,10 @@ func TestWaitAllStarted(t *testing.T) {
require.NoError(t, service2.Start(ctx))
require.NoError(t, service3.Start(ctx))

t.Cleanup(service1.Stop)
t.Cleanup(service2.Stop)
t.Cleanup(service3.Stop)

WaitAllStarted(service1, service2, service3)

require.True(t, service1.state)
Expand All @@ -418,6 +438,10 @@ func TestWaitAllStarted(t *testing.T) {
require.NoError(t, service2.Start(ctx))
require.ErrorIs(t, service3.Start(ctx), service3.startErr)

t.Cleanup(service1.Stop)
t.Cleanup(service2.Stop)
t.Cleanup(service3.Stop)

WaitAllStarted(service1, service2, service3)
})
}
Expand All @@ -440,6 +464,10 @@ func TestWaitAllStartedC(t *testing.T) {
require.NoError(t, service2.Start(ctx))
require.NoError(t, service3.Start(ctx))

t.Cleanup(service1.Stop)
t.Cleanup(service2.Stop)
t.Cleanup(service3.Stop)

riversharedtest.WaitOrTimeout(t, WaitAllStartedC(service1, service2, service3))

require.True(t, service1.state)
Expand All @@ -460,6 +488,10 @@ func TestWaitAllStartedC(t *testing.T) {
require.NoError(t, service2.Start(ctx))
require.ErrorIs(t, service3.Start(ctx), service3.startErr)

t.Cleanup(service1.Stop)
t.Cleanup(service2.Stop)
t.Cleanup(service3.Stop)

riversharedtest.WaitOrTimeout(t, WaitAllStartedC(service1, service2, service3))
})
}
27 changes: 0 additions & 27 deletions rivertype/river_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
package rivertype

import (
"context"
"errors"
"time"
)
Expand Down Expand Up @@ -219,11 +218,6 @@ type AttemptError struct {
// subsequently remove the periodic job with `Remove()`.
type PeriodicJobHandle int

type Plugin interface {
MaintenanceServices() []Service
Services() []Service
}

// Queue is a configuration for a queue that is currently (or recently was) in
// use by a client.
type Queue struct {
Expand All @@ -249,24 +243,3 @@ type Queue struct {
// deleted from the table by a maintenance process.
UpdatedAt time.Time
}

type Service interface {
// Start starts a service. Services are responsible for backgrounding
// themselves, so this function should be invoked synchronously. Services
// may return an error if they have trouble starting up, so the caller
// should wait and respond to the error if necessary.
Start(ctx context.Context) error

// Started returns a channel that's closed when a service finishes starting,
// or if failed to start and is stopped instead. It can be used in
// conjunction with WaitAllStarted to verify startup of a constellation of
// services.
Started() <-chan struct{}

// Stop stops a service. Services are responsible for making sure their stop
// is complete before returning so a caller can wait on this invocation
// synchronously and be guaranteed the service is fully stopped. Services
// are expected to be able to tolerate (1) being stopped without having been
// started, and (2) being double-stopped.
Stop()
}

0 comments on commit ef13ffc

Please sign in to comment.