diff --git a/client.go b/client.go index 8bdd0f4b..22689c6e 100644 --- a/client.go +++ b/client.go @@ -497,9 +497,9 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client client.baseService.Name = "Client" // Have to correct the name because base service isn't embedded like it usually is client.insertNotifyLimiter = notifylimiter.NewLimiter(archetype, config.FetchCooldown) - var plugin rivertype.Plugin - if p, ok := driver.(Pluginable[TTx]); ok { - plugin = p.Plugin(client, config.Logger) + plugin, _ := driver.(driverPlugin[TTx]) + if plugin != nil { + plugin.PluginInit(archetype, client) } // There are a number of internal components that are only needed/desired if @@ -542,11 +542,7 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client startstop.StartStopFunc(client.handleLeadershipChangeLoop)) if plugin != nil { - // TODO: can just use rivertype.Service for client.Services to avoid this map: - pluginServices := sliceutil.Map(plugin.Services(), func(s rivertype.Service) startstop.Service { - return startstop.Service(s) - }) - client.services = append(client.services, pluginServices...) + client.services = append(client.services, plugin.PluginServices()...) } // @@ -621,11 +617,7 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client } if plugin != nil { - // TODO: can just use rivertype.Service for NewQueueMaintainer services to avoid this map: - pluginServices := sliceutil.Map(plugin.MaintenanceServices(), func(s rivertype.Service) startstop.Service { - return startstop.Service(s) - }) - maintenanceServices = append(maintenanceServices, pluginServices...) + maintenanceServices = append(maintenanceServices, plugin.PluginMaintenanceServices()...) } // Not added to the main services list because the queue maintainer is diff --git a/plugin.go b/plugin.go index 2487ed94..11b06bb2 100644 --- a/plugin.go +++ b/plugin.go @@ -1,11 +1,25 @@ package river import ( - "log/slog" - - "github.com/riverqueue/river/rivertype" + "github.com/riverqueue/river/rivershared/baseservice" + "github.com/riverqueue/river/rivershared/startstop" ) -type Pluginable[TTx any] interface { - Plugin(client *Client[TTx], logger *slog.Logger) rivertype.Plugin +// A plugin API that drivers may implemented to extend a River client. Driver +// plugins may, for example, add additional maintenance services. +// +// This should be considered a River internal API and its stability is not +// guaranteed. DO NOT USE. +type driverPlugin[TTx any] interface { + // PluginInit initializes a plugin with an archetype and client. It's + // invoked on Client.NewClient. + PluginInit(archetype *baseservice.Archetype, client *Client[TTx]) + + // PluginMaintenanceServices returns additional maintenance services (will + // only run on an elected leader) for a River client. + PluginMaintenanceServices() []startstop.Service + + // PluginServices returns additional non-maintenance services (will run on + // all clients) for a River client. + PluginServices() []startstop.Service } diff --git a/plugin_test.go b/plugin_test.go new file mode 100644 index 00000000..4bdd4389 --- /dev/null +++ b/plugin_test.go @@ -0,0 +1,99 @@ +package river + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/stretchr/testify/require" + + "github.com/riverqueue/river/internal/riverinternaltest" + "github.com/riverqueue/river/riverdriver/riverpgxv5" + "github.com/riverqueue/river/rivershared/baseservice" + "github.com/riverqueue/river/rivershared/riversharedtest" + "github.com/riverqueue/river/rivershared/startstop" +) + +func TestClientDriverPlugin(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + type testBundle struct { + pluginDriver *TestDriverWithPlugin + } + + setup := func(t *testing.T) (*Client[pgx.Tx], *testBundle) { + t.Helper() + + pluginDriver := newDriverWithPlugin(t, riverinternaltest.TestDB(ctx, t)) + + client, err := NewClient(pluginDriver, newTestConfig(t, nil)) + require.NoError(t, err) + + return client, &testBundle{ + pluginDriver: pluginDriver, + } + } + + t.Run("ServicesStart", func(t *testing.T) { + t.Parallel() + + client, bundle := setup(t) + + startClient(ctx, t, client) + + riversharedtest.WaitOrTimeout(t, startstop.WaitAllStartedC( + bundle.pluginDriver.maintenanceService, + bundle.pluginDriver.service, + )) + }) +} + +var _ driverPlugin[pgx.Tx] = &TestDriverWithPlugin{} + +type TestDriverWithPlugin struct { + *riverpgxv5.Driver + maintenanceService startstop.Service + service startstop.Service +} + +func newDriverWithPlugin(t *testing.T, dbPool *pgxpool.Pool) *TestDriverWithPlugin { + t.Helper() + + newService := func(name string) startstop.Service { + return startstop.StartStopFunc(func(ctx context.Context, shouldStart bool, started, stopped func()) error { + if !shouldStart { + return nil + } + + go func() { + started() + defer stopped() // this defer should come first so it's last out + + t.Logf("Test service started: %s", name) + + <-ctx.Done() + }() + + return nil + }) + } + + return &TestDriverWithPlugin{ + Driver: riverpgxv5.New(dbPool), + maintenanceService: newService("maintenance service"), + service: newService("other service"), + } +} + +func (d *TestDriverWithPlugin) PluginInit(archetype *baseservice.Archetype, client *Client[pgx.Tx]) {} + +func (d *TestDriverWithPlugin) PluginMaintenanceServices() []startstop.Service { + return []startstop.Service{d.maintenanceService} +} + +func (d *TestDriverWithPlugin) PluginServices() []startstop.Service { + return []startstop.Service{d.service} +} diff --git a/rivershared/startstop/start_stop.go b/rivershared/startstop/start_stop.go index 782a2ffc..b32c95ea 100644 --- a/rivershared/startstop/start_stop.go +++ b/rivershared/startstop/start_stop.go @@ -107,11 +107,17 @@ func (s *BaseStartStop) StartInit(ctx context.Context) (context.Context, bool, f s.mu.Lock() defer s.mu.Unlock() - if s.started != nil { + // Used stopped rather than started to track started state because the + // started channel may be preallocated by a call to Started. + if s.stopped != nil { return ctx, false, nil, nil } - s.started = make(chan struct{}) + // Only allocate a started channel if one was preallocated by Started. + if s.started == nil { + s.started = make(chan struct{}) + } + s.stopped = make(chan struct{}) ctx, s.cancelFunc = context.WithCancelCause(ctx) @@ -133,6 +139,14 @@ func (s *BaseStartStop) Started() <-chan struct{} { s.mu.Lock() defer s.mu.Unlock() + // If the call to Started is before the service was actually started, + // preallocate the started channel so that regardless of whether the wait + // started before or after the service started, it will still do the right + // thing. + if s.started == nil { + s.started = make(chan struct{}) + } + return s.started } diff --git a/rivershared/startstop/start_stop_test.go b/rivershared/startstop/start_stop_test.go index 391033e0..b82a12be 100644 --- a/rivershared/startstop/start_stop_test.go +++ b/rivershared/startstop/start_stop_test.go @@ -134,6 +134,20 @@ func testService(t *testing.T, newService func(t *testing.T) serviceWithStopped) wg.Wait() }) + + t.Run("StartedPreallocated", func(t *testing.T) { + t.Parallel() + + service, _ := setup(t) + + // Make sure we get the start channel before the service is started. + started := service.Started() + + require.NoError(t, service.Start(ctx)) + t.Cleanup(service.Stop) + + riversharedtest.WaitOrTimeout(t, started) + }) } func TestBaseStartStop(t *testing.T) { diff --git a/rivertype/river_type.go b/rivertype/river_type.go index 6a1e62d1..df4f6038 100644 --- a/rivertype/river_type.go +++ b/rivertype/river_type.go @@ -4,7 +4,6 @@ package rivertype import ( - "context" "errors" "time" ) @@ -219,11 +218,6 @@ type AttemptError struct { // subsequently remove the periodic job with `Remove()`. type PeriodicJobHandle int -type Plugin interface { - MaintenanceServices() []Service - Services() []Service -} - // Queue is a configuration for a queue that is currently (or recently was) in // use by a client. type Queue struct { @@ -249,24 +243,3 @@ type Queue struct { // deleted from the table by a maintenance process. UpdatedAt time.Time } - -type Service interface { - // Start starts a service. Services are responsible for backgrounding - // themselves, so this function should be invoked synchronously. Services - // may return an error if they have trouble starting up, so the caller - // should wait and respond to the error if necessary. - Start(ctx context.Context) error - - // Started returns a channel that's closed when a service finishes starting, - // or if failed to start and is stopped instead. It can be used in - // conjunction with WaitAllStarted to verify startup of a constellation of - // services. - Started() <-chan struct{} - - // Stop stops a service. Services are responsible for making sure their stop - // is complete before returning so a caller can wait on this invocation - // synchronously and be guaranteed the service is fully stopped. Services - // are expected to be able to tolerate (1) being stopped without having been - // started, and (2) being double-stopped. - Stop() -}