From 051a05211230dce595ec6b766cde60db65d4d8c3 Mon Sep 17 00:00:00 2001 From: Brandur Leach Date: Sun, 7 Jul 2024 18:48:32 -0700 Subject: [PATCH] Add basic plugin system for clients (#430) Add basic plugin system for clients Here, follow up #429 to add a basic plugin system for River clients which allows a driver to add maintenance and non-maintenance services to a client before it starts up. The plugin interface is implemented by the drivers themselves, and looks like this: type driverPlugin[TTx any] interface { // PluginInit initializes a plugin with an archetype and client. It's // invoked on Client.NewClient. PluginInit(archetype *baseservice.Archetype, client *Client[TTx]) // PluginMaintenanceServices returns additional maintenance services (will // only run on an elected leader) for a River client. PluginMaintenanceServices() []startstop.Service // PluginServices returns additional non-maintenance services (will run on // all clients) for a River client. PluginServices() []startstop.Service } The change is fairly straightforward, and we make sure to bring in some test cases verifying the plugin services were indeed added correctly. --------- Co-authored-by: Blake Gentry --- client.go | 13 +++ plugin.go | 25 ++++++ plugin_test.go | 110 +++++++++++++++++++++++ rivershared/startstop/main_test.go | 11 +++ rivershared/startstop/start_stop.go | 18 +++- rivershared/startstop/start_stop_test.go | 32 +++++++ 6 files changed, 207 insertions(+), 2 deletions(-) create mode 100644 plugin.go create mode 100644 plugin_test.go create mode 100644 rivershared/startstop/main_test.go diff --git a/client.go b/client.go index babea228..22689c6e 100644 --- a/client.go +++ b/client.go @@ -497,6 +497,11 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client client.baseService.Name = "Client" // Have to correct the name because base service isn't embedded like it usually is client.insertNotifyLimiter = notifylimiter.NewLimiter(archetype, config.FetchCooldown) + plugin, _ := driver.(driverPlugin[TTx]) + if plugin != nil { + plugin.PluginInit(archetype, client) + } + // There are a number of internal components that are only needed/desired if // we're actually going to be working jobs (as opposed to just enqueueing // them): @@ -536,6 +541,10 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client client.services = append(client.services, startstop.StartStopFunc(client.handleLeadershipChangeLoop)) + if plugin != nil { + client.services = append(client.services, plugin.PluginServices()...) + } + // // Maintenance services // @@ -607,6 +616,10 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client client.testSignals.reindexer = &reindexer.TestSignals } + if plugin != nil { + maintenanceServices = append(maintenanceServices, plugin.PluginMaintenanceServices()...) + } + // Not added to the main services list because the queue maintainer is // started conditionally based on whether the client is the leader. client.queueMaintainer = maintenance.NewQueueMaintainer(archetype, maintenanceServices) diff --git a/plugin.go b/plugin.go new file mode 100644 index 00000000..d0676ab6 --- /dev/null +++ b/plugin.go @@ -0,0 +1,25 @@ +package river + +import ( + "github.com/riverqueue/river/rivershared/baseservice" + "github.com/riverqueue/river/rivershared/startstop" +) + +// A plugin API that drivers may implement to extend a River client. Driver +// plugins may, for example, add additional maintenance services. +// +// This should be considered a River internal API and its stability is not +// guaranteed. DO NOT USE. +type driverPlugin[TTx any] interface { + // PluginInit initializes a plugin with an archetype and client. It's + // invoked on Client.NewClient. + PluginInit(archetype *baseservice.Archetype, client *Client[TTx]) + + // PluginMaintenanceServices returns additional maintenance services (will + // only run on an elected leader) for a River client. + PluginMaintenanceServices() []startstop.Service + + // PluginServices returns additional non-maintenance services (will run on + // all clients) for a River client. + PluginServices() []startstop.Service +} diff --git a/plugin_test.go b/plugin_test.go new file mode 100644 index 00000000..b7654ba1 --- /dev/null +++ b/plugin_test.go @@ -0,0 +1,110 @@ +package river + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/stretchr/testify/require" + + "github.com/riverqueue/river/internal/riverinternaltest" + "github.com/riverqueue/river/riverdriver/riverpgxv5" + "github.com/riverqueue/river/rivershared/baseservice" + "github.com/riverqueue/river/rivershared/riversharedtest" + "github.com/riverqueue/river/rivershared/startstop" +) + +func TestClientDriverPlugin(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + type testBundle struct { + pluginDriver *TestDriverWithPlugin + } + + setup := func(t *testing.T) (*Client[pgx.Tx], *testBundle) { + t.Helper() + + pluginDriver := newDriverWithPlugin(t, riverinternaltest.TestDB(ctx, t)) + + client, err := NewClient(pluginDriver, newTestConfig(t, nil)) + require.NoError(t, err) + + return client, &testBundle{ + pluginDriver: pluginDriver, + } + } + + t.Run("ServicesStart", func(t *testing.T) { + t.Parallel() + + client, bundle := setup(t) + + startClient(ctx, t, client) + + riversharedtest.WaitOrTimeout(t, startstop.WaitAllStartedC( + bundle.pluginDriver.maintenanceService, + bundle.pluginDriver.service, + )) + }) +} + +var _ driverPlugin[pgx.Tx] = &TestDriverWithPlugin{} + +type TestDriverWithPlugin struct { + *riverpgxv5.Driver + initCalled bool + maintenanceService startstop.Service + service startstop.Service +} + +func newDriverWithPlugin(t *testing.T, dbPool *pgxpool.Pool) *TestDriverWithPlugin { + t.Helper() + + newService := func(name string) startstop.Service { + return startstop.StartStopFunc(func(ctx context.Context, shouldStart bool, started, stopped func()) error { + if !shouldStart { + return nil + } + + go func() { + started() + defer stopped() // this defer should come first so it's last out + + t.Logf("Test service started: %s", name) + + <-ctx.Done() + }() + + return nil + }) + } + + return &TestDriverWithPlugin{ + Driver: riverpgxv5.New(dbPool), + maintenanceService: newService("maintenance service"), + service: newService("other service"), + } +} + +func (d *TestDriverWithPlugin) PluginInit(archetype *baseservice.Archetype, client *Client[pgx.Tx]) { + d.initCalled = true +} + +func (d *TestDriverWithPlugin) PluginMaintenanceServices() []startstop.Service { + if !d.initCalled { + panic("expected PluginInit to be called before this function") + } + + return []startstop.Service{d.maintenanceService} +} + +func (d *TestDriverWithPlugin) PluginServices() []startstop.Service { + if !d.initCalled { + panic("expected PluginInit to be called before this function") + } + + return []startstop.Service{d.service} +} diff --git a/rivershared/startstop/main_test.go b/rivershared/startstop/main_test.go new file mode 100644 index 00000000..85d07b38 --- /dev/null +++ b/rivershared/startstop/main_test.go @@ -0,0 +1,11 @@ +package startstop + +import ( + "testing" + + "github.com/riverqueue/river/rivershared/riversharedtest" +) + +func TestMain(m *testing.M) { + riversharedtest.WrapTestMain(m) +} diff --git a/rivershared/startstop/start_stop.go b/rivershared/startstop/start_stop.go index 782a2ffc..b32c95ea 100644 --- a/rivershared/startstop/start_stop.go +++ b/rivershared/startstop/start_stop.go @@ -107,11 +107,17 @@ func (s *BaseStartStop) StartInit(ctx context.Context) (context.Context, bool, f s.mu.Lock() defer s.mu.Unlock() - if s.started != nil { + // Used stopped rather than started to track started state because the + // started channel may be preallocated by a call to Started. + if s.stopped != nil { return ctx, false, nil, nil } - s.started = make(chan struct{}) + // Only allocate a started channel if one was preallocated by Started. + if s.started == nil { + s.started = make(chan struct{}) + } + s.stopped = make(chan struct{}) ctx, s.cancelFunc = context.WithCancelCause(ctx) @@ -133,6 +139,14 @@ func (s *BaseStartStop) Started() <-chan struct{} { s.mu.Lock() defer s.mu.Unlock() + // If the call to Started is before the service was actually started, + // preallocate the started channel so that regardless of whether the wait + // started before or after the service started, it will still do the right + // thing. + if s.started == nil { + s.started = make(chan struct{}) + } + return s.started } diff --git a/rivershared/startstop/start_stop_test.go b/rivershared/startstop/start_stop_test.go index 391033e0..6f75a23e 100644 --- a/rivershared/startstop/start_stop_test.go +++ b/rivershared/startstop/start_stop_test.go @@ -95,6 +95,7 @@ func testService(t *testing.T, newService func(t *testing.T) serviceWithStopped) service, _ := setup(t) require.NoError(t, service.Start(ctx)) + t.Cleanup(service.Stop) riversharedtest.WaitOrTimeout(t, service.Started()) }) @@ -134,6 +135,20 @@ func testService(t *testing.T, newService func(t *testing.T) serviceWithStopped) wg.Wait() }) + + t.Run("StartedPreallocated", func(t *testing.T) { + t.Parallel() + + service, _ := setup(t) + + // Make sure we get the start channel before the service is started. + started := service.Started() + + require.NoError(t, service.Start(ctx)) + t.Cleanup(service.Stop) + + riversharedtest.WaitOrTimeout(t, started) + }) } func TestBaseStartStop(t *testing.T) { @@ -223,6 +238,7 @@ func TestSampleService(t *testing.T) { service, _ := setup(t) require.NoError(t, service.Start(ctx)) + t.Cleanup(service.Stop) riversharedtest.WaitOrTimeout(t, service.Started()) require.True(t, service.state) @@ -398,6 +414,10 @@ func TestWaitAllStarted(t *testing.T) { require.NoError(t, service2.Start(ctx)) require.NoError(t, service3.Start(ctx)) + t.Cleanup(service1.Stop) + t.Cleanup(service2.Stop) + t.Cleanup(service3.Stop) + WaitAllStarted(service1, service2, service3) require.True(t, service1.state) @@ -418,6 +438,10 @@ func TestWaitAllStarted(t *testing.T) { require.NoError(t, service2.Start(ctx)) require.ErrorIs(t, service3.Start(ctx), service3.startErr) + t.Cleanup(service1.Stop) + t.Cleanup(service2.Stop) + t.Cleanup(service3.Stop) + WaitAllStarted(service1, service2, service3) }) } @@ -440,6 +464,10 @@ func TestWaitAllStartedC(t *testing.T) { require.NoError(t, service2.Start(ctx)) require.NoError(t, service3.Start(ctx)) + t.Cleanup(service1.Stop) + t.Cleanup(service2.Stop) + t.Cleanup(service3.Stop) + riversharedtest.WaitOrTimeout(t, WaitAllStartedC(service1, service2, service3)) require.True(t, service1.state) @@ -460,6 +488,10 @@ func TestWaitAllStartedC(t *testing.T) { require.NoError(t, service2.Start(ctx)) require.ErrorIs(t, service3.Start(ctx), service3.startErr) + t.Cleanup(service1.Stop) + t.Cleanup(service2.Stop) + t.Cleanup(service3.Stop) + riversharedtest.WaitOrTimeout(t, WaitAllStartedC(service1, service2, service3)) }) }