diff --git a/.github/workflows/go-coverage.yml b/.github/workflows/go-coverage.yml index 808a1d9..410a579 100644 --- a/.github/workflows/go-coverage.yml +++ b/.github/workflows/go-coverage.yml @@ -9,10 +9,6 @@ on: permissions: contents: read -env: - GOCACHE: "/tmp/go-cache" - GOMODCACHE: "/tmp/go-mod-cache" - jobs: build: runs-on: ubuntu-latest @@ -55,4 +51,4 @@ jobs: -Dsonar.tests=. -Dsonar.test.inclusions=**/*_test.go -Dsonar.language=go - -Dsonar.sourceEncoding=UTF-8 \ No newline at end of file + -Dsonar.sourceEncoding=UTF-8 diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 2ef995a..a27e82f 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -9,10 +9,6 @@ on: permissions: contents: read -env: - GOCACHE: "/tmp/go-cache" - GOMODCACHE: "/tmp/go-mod-cache" - jobs: build: runs-on: ubuntu-latest diff --git a/examples/composite/main.go b/examples/composite/main.go index 4783ca5..025ac9a 100644 --- a/examples/composite/main.go +++ b/examples/composite/main.go @@ -77,7 +77,6 @@ func main() { // Create composite runner runner, err := composite.NewRunner( configCallback, - composite.WithContext[*Worker](ctx), ) if err != nil { logger.Error("Failed to create composite runner", "error", err) diff --git a/examples/composite/reload_membership_test.go b/examples/composite/reload_membership_test.go index af66431..d855862 100644 --- a/examples/composite/reload_membership_test.go +++ b/examples/composite/reload_membership_test.go @@ -145,7 +145,6 @@ func TestMembershipChangesBasic(t *testing.T) { runner, err := composite.NewRunner[*TestWorker]( configCallback, - composite.WithContext[*TestWorker](ctx), composite.WithLogHandler[*TestWorker](logger.Handler()), ) require.NoError(t, err) diff --git a/examples/http/main.go b/examples/http/main.go index 49f23a7..53dfd4e 100644 --- a/examples/http/main.go +++ b/examples/http/main.go @@ -81,27 +81,22 @@ func buildRoutes(logHandler slog.Handler) ([]httpserver.Route, error) { } // RunServer initializes and runs the HTTP server with supervisor -// Returns the supervisor and a cleanup function func RunServer( ctx context.Context, logHandler slog.Handler, routes []httpserver.Route, -) (*supervisor.PIDZero, func(), error) { +) (*supervisor.PIDZero, error) { // Create a config callback function that will be used by the runner configCallback := func() (*httpserver.Config, error) { return httpserver.NewConfig(ListenOn, routes, httpserver.WithDrainTimeout(DrainTimeout)) } - // Create the HTTP server runner with a custom context - customCtx, customCancel := context.WithCancel(ctx) - + // Create the HTTP server runner runner, err := httpserver.NewRunner( - httpserver.WithContext(customCtx), httpserver.WithConfigCallback(configCallback), httpserver.WithLogHandler(logHandler)) if err != nil { - customCancel() - return nil, nil, fmt.Errorf("failed to create HTTP server runner: %w", err) + return nil, fmt.Errorf("failed to create HTTP server runner: %w", err) } // Create a PIDZero supervisor and add the runner @@ -110,11 +105,10 @@ func RunServer( supervisor.WithLogHandler(logHandler), supervisor.WithRunnables(runner)) if err != nil { - customCancel() - return nil, nil, fmt.Errorf("failed to create supervisor: %w", err) + return nil, fmt.Errorf("failed to create supervisor: %w", err) } - return sv, customCancel, nil + return sv, nil } func main() { @@ -135,12 +129,11 @@ func main() { os.Exit(1) } - sv, cancel, err := RunServer(ctx, handler, routes) + sv, err := RunServer(ctx, handler, routes) if err != nil { slog.Error("Failed to setup 
server", "error", err) os.Exit(1) } - defer cancel() // Start the supervisor - this will block until shutdown slog.Info("Starting supervisor with HTTP server on " + ListenOn) diff --git a/examples/http/main_test.go b/examples/http/main_test.go index cff05ae..170bcc8 100644 --- a/examples/http/main_test.go +++ b/examples/http/main_test.go @@ -30,10 +30,9 @@ func TestRunServer(t *testing.T) { require.NoError(t, err, "Failed to build routes") require.NotEmpty(t, routes, "Routes should not be empty") - sv, cleanup, err := RunServer(ctx, logHandler, routes) + sv, err := RunServer(ctx, logHandler, routes) require.NoError(t, err, "RunServer should not return an error") require.NotNil(t, sv, "Supervisor should not be nil") - require.NotNil(t, cleanup, "Cleanup function should not be nil") // Start the server in a goroutine to avoid blocking the test errCh := make(chan error, 1) @@ -54,8 +53,8 @@ func TestRunServer(t *testing.T) { assert.Equal(t, http.StatusOK, resp.StatusCode) assert.Equal(t, "Status: OK\n", string(body)) - // Clean up - cleanup() + // Stop the supervisor + sv.Shutdown() // Check that Run() didn't return an error select { @@ -88,7 +87,6 @@ func TestRunServerInvalidPort(t *testing.T) { // Create HTTP server runner with invalid port runner, err := httpserver.NewRunner( - httpserver.WithContext(ctx), httpserver.WithConfigCallback(configCallback), httpserver.WithLogHandler(logHandler.WithGroup("httpserver")), ) diff --git a/examples/httpcluster/main.go b/examples/httpcluster/main.go index f3a0125..8b650a6 100644 --- a/examples/httpcluster/main.go +++ b/examples/httpcluster/main.go @@ -241,7 +241,6 @@ func createHTTPCluster( // Create the httpcluster cluster, err := httpcluster.NewRunner( httpcluster.WithLogger(logger.WithGroup("httpcluster")), - httpcluster.WithContext(ctx), ) if err != nil { return nil, nil, fmt.Errorf("failed to create httpcluster: %w", err) diff --git a/runnables/composite/options.go b/runnables/composite/options.go index a1c9275..465b8d8 100644 --- a/runnables/composite/options.go +++ b/runnables/composite/options.go @@ -1,7 +1,6 @@ package composite import ( - "context" "log/slog" ) @@ -16,13 +15,3 @@ func WithLogHandler[T runnable](handler slog.Handler) Option[T] { } } } - -// WithContext sets a custom context for the CompositeRunner instance. -// This allows for more granular control over cancellation and timeouts. 
-func WithContext[T runnable](ctx context.Context) Option[T] { - return func(c *Runner[T]) { - if ctx != nil { - c.parentCtx, c.parentCancel = context.WithCancel(ctx) - } - } -} diff --git a/runnables/composite/options_test.go b/runnables/composite/options_test.go index 95cd299..c1e5eda 100644 --- a/runnables/composite/options_test.go +++ b/runnables/composite/options_test.go @@ -6,7 +6,6 @@ import ( "testing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) // Mock runnable implementation for testing @@ -17,10 +16,6 @@ func (m *mockRunnable) Run(ctx context.Context) error { return nil } func (m *mockRunnable) Stop() {} func (m *mockRunnable) String() string { return "mockRunnable" } -type contextKeyType string - -const testContextKey contextKeyType = "testKey" - func TestWithLogHandler(t *testing.T) { t.Parallel() @@ -38,54 +33,3 @@ func TestWithLogHandler(t *testing.T) { WithLogHandler[*mockRunnable](nil)(runner) assert.Equal(t, slog.Default(), runner.logger, "Logger should not change with nil handler") } - -func TestWithContext(t *testing.T) { - t.Parallel() - - originalCtx := context.Background() - runner := &Runner[*mockRunnable]{ - parentCtx: nil, - parentCancel: nil, - } - - WithContext[*mockRunnable](originalCtx)(runner) - - // Verify the context and cancel function are set - require.NotNil(t, runner.parentCtx, "Context should be set") - require.NotNil(t, runner.parentCancel, "Cancel function should be set") - - // Test that the contexts are related (child can access parent values) - originalCtx = context.WithValue(context.Background(), testContextKey, "value") - runner = &Runner[*mockRunnable]{} - WithContext[*mockRunnable](originalCtx)(runner) - - assert.Equal(t, - "value", runner.parentCtx.Value(testContextKey), - "Child context should inherit values from parent", - ) - - // Test with empty context - we should get a new cancellable context - // but still be able to verify it's connected to the background context - runner = &Runner[*mockRunnable]{ - parentCtx: nil, - parentCancel: nil, - } - // Use context.Background() instead of nil - WithContext[*mockRunnable](context.Background())(runner) - - // Instead of checking for equality, verify: - // 1. The context is not nil - // 2. The cancel function is not nil - // 3. 
The context is derived from Background() (it will be a cancel context) - require.NotNil(t, runner.parentCtx, "Context should be set with Background()") - require.NotNil(t, runner.parentCancel, "Cancel function should be set with Background()") - - // Verify it's a cancel context by calling the cancel function and checking if Done channel closes - runner.parentCancel() - select { - case <-runner.parentCtx.Done(): - // This is what we want - the context was canceled - default: - t.Error("Context should be cancellable when created with Background()") - } -} diff --git a/runnables/composite/reload.go b/runnables/composite/reload.go index 773c806..5f01962 100644 --- a/runnables/composite/reload.go +++ b/runnables/composite/reload.go @@ -2,7 +2,6 @@ package composite import ( "fmt" - "log/slog" "github.com/robbyt/go-supervisor/internal/finitestate" "github.com/robbyt/go-supervisor/supervisor" @@ -34,11 +33,13 @@ func (r *Runner[T]) Reload() { newConfig, err := r.configCallback() if err != nil { logger.Error("Failed to get updated config", "error", err) + // TODO: consider removing the setStateError() call here r.setStateError() return } if newConfig == nil { logger.Error("Config callback returned nil during reload") + // TODO: consider removing the setStateError() call here r.setStateError() return } @@ -55,14 +56,14 @@ func (r *Runner[T]) Reload() { logger.Debug( "Membership change detected, stopping all existing runnables before updating membership and config", ) - if err := r.reloadMembershipChanged(newConfig); err != nil { + if err := r.reloadWithRestart(newConfig); err != nil { logger.Error("Failed to reload runnables due to membership change", "error", err) r.setStateError() return } logger.Debug("Reloaded runnables due to membership change") } else { - r.reloadConfig(logger, newConfig) + r.reloadSkipRestart(newConfig) logger.Debug("Reloaded runnables without membership change") } @@ -74,47 +75,60 @@ func (r *Runner[T]) Reload() { } } -// reloadMembershipChanged handles the case where the membership of runnables has changed. -func (r *Runner[T]) reloadMembershipChanged(newConfig *Config[T]) error { +// reloadWithRestart handles the case where the membership of runnables has changed. +func (r *Runner[T]) reloadWithRestart(newConfig *Config[T]) error { + logger := r.logger.WithGroup("reloadWithRestart") + logger.Debug("Reloading runnables due to membership change") + defer logger.Debug("Completed.") + // Stop all existing runnables while we still have the old config // This acquires the runnables mutex - if err := r.stopRunnables(); err != nil { + if err := r.stopAllRunnables(); err != nil { return fmt.Errorf("%w: failed to stop existing runnables during membership change", err) } // Now update the stored config after stopping old runnables // Lock the config mutex for writing + logger.Debug("Updating config after stopping existing runnables") r.configMu.Lock() r.setConfig(newConfig) r.configMu.Unlock() // Start all runnables from the new config // This acquires the runnables mutex - if err := r.boot(r.runCtx); err != nil { + if err := r.boot(r.ctx); err != nil { return fmt.Errorf("%w: failed to start new runnables during membership change", err) } return nil } -// reloadConfig handles the case where the membership of runnables has not changed. 
-func (r *Runner[T]) reloadConfig(logger *slog.Logger, newConfig *Config[T]) { - logger = logger.WithGroup("reloadConfig") - // No membership change, update config and reload existing runnables +// reloadSkipRestart handles the case where the membership of runnables has not changed. +func (r *Runner[T]) reloadSkipRestart(newConfig *Config[T]) { + logger := r.logger.WithGroup("reloadSkipRestart") + logger.Debug("Reloading runnables without membership change") + defer logger.Debug("Completed.") + + logger.Debug("Updating config") r.configMu.Lock() r.setConfig(newConfig) r.configMu.Unlock() + logger.Debug("Reloading configs of existing runnables") // Reload configs of existing runnables // Runnables mutex not locked as membership is not changing for _, entry := range newConfig.Entries { + logger := logger.With("runnable", entry.Runnable.String()) + if reloadableWithConfig, ok := any(entry.Runnable).(ReloadableWithConfig); ok { // If the runnable implements our ReloadableWithConfig interface, use that to pass the new config - logger.Debug("Reloading child runnable with config", "runnable", entry.Runnable) + logger.Debug("Reloading child runnable with config") reloadableWithConfig.ReloadWithConfig(entry.Config) } else if reloadable, ok := any(entry.Runnable).(supervisor.Reloadable); ok { // Fall back to standard Reloadable interface, assume the configCallback // has somehow updated the runnable's internal state - logger.Debug("Reloading child runnable", "runnable", entry.Runnable) + logger.Debug("Reloading child runnable") reloadable.Reload() + } else { + logger.Warn("Child runnable does not implement Reloadable or ReloadableWithConfig") } } } diff --git a/runnables/composite/reload_test.go b/runnables/composite/reload_test.go index a9bc441..8036715 100644 --- a/runnables/composite/reload_test.go +++ b/runnables/composite/reload_test.go @@ -109,7 +109,6 @@ func TestCompositeRunner_Reload(t *testing.T) { ctx := context.Background() runner, err := NewRunner( configCallback, - WithContext[*mocks.Runnable](ctx), WithLogHandler[*mocks.Runnable](handler), ) require.NoError(t, err) @@ -195,10 +194,7 @@ func TestCompositeRunner_Reload(t *testing.T) { // Create runner and set state to Running ctx := t.Context() - runner, err := NewRunner( - configCallback, - WithContext[*mocks.Runnable](ctx), - ) + runner, err := NewRunner(configCallback) require.NoError(t, err) require.Equal(t, 0, callbackCalls) @@ -292,10 +288,7 @@ func TestCompositeRunner_Reload(t *testing.T) { // Create runner and set state to Running ctx := t.Context() - runner, err := NewRunner( - configCallback, - WithContext[*mocks.Runnable](ctx), - ) + runner, err := NewRunner(configCallback) require.NoError(t, err) assert.Equal(t, 0, callbackCalls) @@ -404,10 +397,7 @@ func TestCompositeRunner_Reload(t *testing.T) { // Create runner ctx := t.Context() - runner, err := NewRunner( - configCallback, - WithContext[*mocks.Runnable](ctx), - ) + runner, err := NewRunner(configCallback) require.NoError(t, err) require.Equal(t, 0, callbackCalls, "Callback should not be called during creation") @@ -497,10 +487,7 @@ func TestCompositeRunner_Reload(t *testing.T) { // Create runner with context ctx := t.Context() - runner, err := NewRunner( - configCallback, - WithContext[*MockReloadableWithConfig](ctx), - ) + runner, err := NewRunner(configCallback) require.NoError(t, err) assert.Equal(t, 0, callbackCalls) @@ -1038,8 +1025,7 @@ func TestReloadConfig(t *testing.T) { require.NoError(t, err) // Call reloadConfig directly - logger := 
runner.logger.WithGroup("test") - runner.reloadConfig(logger, config) + runner.reloadSkipRestart(config) // Verify reloadable interface methods were called mockRunnable1.AssertExpectations(t) @@ -1078,8 +1064,7 @@ func TestReloadConfig(t *testing.T) { require.NoError(t, err) // Call reloadConfig directly - logger := runner.logger.WithGroup("test") - runner.reloadConfig(logger, config) + runner.reloadSkipRestart(config) // Verify ReloadWithConfig was called with correct configs mockReloadable1.AssertExpectations(t) @@ -1128,9 +1113,8 @@ func TestReloadConfig(t *testing.T) { require.NoError(t, err) // Call reloadConfig on both runners - logger := runner1.logger.WithGroup("test") - runner1.reloadConfig(logger, config1) - runner2.reloadConfig(logger, config2) + runner1.reloadSkipRestart(config1) + runner2.reloadSkipRestart(config2) // Verify expectations mockRunnable.AssertExpectations(t) @@ -1185,15 +1169,13 @@ func TestReloadMembershipChanged(t *testing.T) { newConfig, err := NewConfig("test", newEntries) require.NoError(t, err) - ctx := t.Context() - // Create callback that initially returns oldEntries configCallback := func() (*Config[*mocks.Runnable], error) { return initialConfig, nil } // Create runner - runner, err := NewRunner(configCallback, WithContext[*mocks.Runnable](ctx)) + runner, err := NewRunner(configCallback) require.NoError(t, err) // Make sure initial config is loaded @@ -1202,13 +1184,12 @@ func TestReloadMembershipChanged(t *testing.T) { assert.Equal(t, 2, len(initialConfigLoaded.Entries)) // Set runCtx (normally done by Run) - runCtx := t.Context() runner.runnablesMu.Lock() - runner.runCtx = runCtx + runner.ctx = t.Context() runner.runnablesMu.Unlock() // Call reloadMembershipChanged directly - err = runner.reloadMembershipChanged(newConfig) + err = runner.reloadWithRestart(newConfig) require.NoError(t, err) // Verify config was updated @@ -1247,14 +1228,13 @@ func TestReloadMembershipChanged(t *testing.T) { require.NoError(t, err) // Set runCtx (normally done by Run) - runCtx := t.Context() runner.runnablesMu.Lock() - runner.runCtx = runCtx + runner.ctx = t.Context() runner.runnablesMu.Unlock() // Call reloadMembershipChanged // This should fail because getConfig() returns nil in stopRunnables - err = runner.reloadMembershipChanged(newConfig) + err = runner.reloadWithRestart(newConfig) // Verify error require.Error(t, err) @@ -1292,13 +1272,10 @@ func TestReloadMembershipChanged(t *testing.T) { runner, err := NewRunner(configCallback) require.NoError(t, err) - // setup a new context for the runner that will be cancelled in bit - ctx, cancel := context.WithCancel(context.Background()) - // Run in a goroutine to avoid blocking errCh := make(chan error, 1) go func() { - errCh <- runner.Run(ctx) + errCh <- runner.Run(t.Context()) }() // Wait for the runner to start @@ -1326,7 +1303,7 @@ func TestReloadMembershipChanged(t *testing.T) { assert.Same(t, mockRunnable, updatedConfig.Entries[0].Runnable) // Clean shutdown - cancel() + runner.Stop() // Wait for runner to complete select { @@ -1352,32 +1329,25 @@ func TestReloadMembershipChanged(t *testing.T) { newConfig, err := NewConfig("test", newEntries) require.NoError(t, err) - // context for the runner - ctx := t.Context() - // Create callback that returns the config configCallback := func() (*Config[*mocks.Runnable], error) { return initialConfig, nil } // Create runner - runner, err := NewRunner(configCallback, - WithContext[*mocks.Runnable](ctx), - ) + runner, err := NewRunner(configCallback) require.NoError(t, err) // 
Make sure initial config is loaded runner.currentConfig.Store(initialConfig) // Set runCtx (normally done by Run) - runCtx, runCancel := context.WithCancel(context.Background()) - defer runCancel() runner.runnablesMu.Lock() - runner.runCtx = runCtx + runner.ctx = t.Context() runner.runnablesMu.Unlock() // Call reloadMembershipChanged directly - should no longer error with empty entries - err = runner.reloadMembershipChanged(newConfig) + err = runner.reloadWithRestart(newConfig) require.NoError(t, err) // Verify config was updated @@ -1529,17 +1499,16 @@ func TestHasMembershipChanged(t *testing.T) { } ctx := context.Background() // Use a clean context instead of t.Context() - runner, err := NewRunner(configCallback, - WithContext[*mocks.Runnable](ctx), - ) + runner, err := NewRunner(configCallback) require.NoError(t, err) - runCtx, cancel := context.WithCancel(ctx) - defer cancel() + runner.runnablesMu.Lock() + runner.ctx = ctx + runner.runnablesMu.Unlock() errCh := make(chan error, 1) go func() { - errCh <- runner.Run(runCtx) + errCh <- runner.Run(ctx) }() // Verify runner reaches Running state @@ -1571,7 +1540,7 @@ func TestHasMembershipChanged(t *testing.T) { assert.Equal(t, mockRunnable4, config.Entries[1].Runnable) // Clean shutdown before verification - cancel() + runner.Stop() // Wait for runner to complete require.Eventually(t, func() bool { diff --git a/runnables/composite/runner.go b/runnables/composite/runner.go index 7d3df85..3655306 100644 --- a/runnables/composite/runner.go +++ b/runnables/composite/runner.go @@ -21,12 +21,13 @@ type Runner[T runnable] struct { currentConfig atomic.Pointer[Config[T]] configCallback ConfigCallback[T] - runnablesMu sync.Mutex // Protects runnable operations (start/stop) + runnablesMu sync.Mutex fsm finitestate.Machine - runCtx context.Context // set by Run() to track THIS instance run context - parentCtx context.Context // set by NewRunner() to track the parent context - parentCancel context.CancelFunc + // will be set by Run() + ctx context.Context + cancel context.CancelFunc + serverErrors chan error logger *slog.Logger } @@ -41,18 +42,12 @@ func NewRunner[T runnable]( configCallback ConfigCallback[T], opts ...Option[T], ) (*Runner[T], error) { - // Setup defaults logger := slog.Default().WithGroup("composite.Runner") - parentCtx, parentCancel := context.WithCancel(context.Background()) r := &Runner[T]{ currentConfig: atomic.Pointer[Config[T]]{}, configCallback: configCallback, - - runCtx: context.Background(), // will be replaced in Run() - parentCtx: parentCtx, - parentCancel: parentCancel, - serverErrors: make(chan error, 1), - logger: logger, + serverErrors: make(chan error, 1), + logger: logger, } // Apply options, to override defaults if provided @@ -95,9 +90,10 @@ func (r *Runner[T]) Run(ctx context.Context) error { runCtx, runCancel := context.WithCancel(ctx) defer runCancel() + // store the Run context and cancel function in the runner so that Reload() and Stop() can use them later r.runnablesMu.Lock() - // store the Run context in the runner so that Reload() can use it later - r.runCtx = runCtx + r.ctx = runCtx + r.cancel = runCancel r.runnablesMu.Unlock() // Transition from New to Booting @@ -121,25 +117,18 @@ func (r *Runner[T]) Run(ctx context.Context) error { select { case <-runCtx.Done(): r.logger.Debug("Local context canceled") - case <-r.parentCtx.Done(): - r.logger.Debug("Parent context canceled") case err := <-r.serverErrors: r.setStateError() return fmt.Errorf("%w: %w", ErrRunnableFailed, err) } - // Try to transition to 
Stopping state - if !r.fsm.TransitionBool(finitestate.StatusStopping) { - if r.fsm.GetState() == finitestate.StatusStopping { - r.logger.Debug("Already in Stopping state, continuing shutdown") - } else { - r.setStateError() - return fmt.Errorf("failed to transition to Stopping state") - } + if err := r.fsm.TransitionIfCurrentState(finitestate.StatusRunning, finitestate.StatusStopping); err != nil { + // This error is expected if we're already stopping, so only log at debug level + r.logger.Debug("Not transitioning to Stopping state", "error", err) } // Stop all child runnables - if err := r.stopRunnables(); err != nil { + if err := r.stopAllRunnables(); err != nil { r.setStateError() return fmt.Errorf("failed to stop runnables: %w", err) } @@ -154,14 +143,17 @@ func (r *Runner[T]) Run(ctx context.Context) error { return nil } -// Stop will cancel the parent context, causing all child runnables to stop. +// Stop will cancel the context, causing all child runnables to stop. func (r *Runner[T]) Stop() { - // Only transition to Stopping if we're currently Running - if err := r.fsm.TransitionIfCurrentState(finitestate.StatusRunning, finitestate.StatusStopping); err != nil { - // This error is expected if we're already stopping, so only log at debug level - r.logger.Debug("Not transitioning to Stopping state", "error", err) + r.runnablesMu.Lock() + cancel := r.cancel + r.runnablesMu.Unlock() + + if cancel == nil { + r.logger.Warn("Cancel function is nil, skipping Stop") + return } - r.parentCancel() + cancel() } // boot starts all child runnables in the order they're defined. @@ -227,8 +219,8 @@ func (r *Runner[T]) startRunnable(ctx context.Context, subRunnable T, idx int) { } } -// stopRunnables stops all child runnables in reverse order (last to first). -func (r *Runner[T]) stopRunnables() error { +// stopAllRunnables stops all child runnables in reverse order (last to first). 
+func (r *Runner[T]) stopAllRunnables() error { r.runnablesMu.Lock() defer r.runnablesMu.Unlock() diff --git a/runnables/composite/runner_run_test.go b/runnables/composite/runner_run_test.go index 6018bd6..1daf55a 100644 --- a/runnables/composite/runner_run_test.go +++ b/runnables/composite/runner_run_test.go @@ -59,66 +59,6 @@ func TestCompositeRunner_Run_AdditionalScenarios(t *testing.T) { mockFSM.AssertExpectations(t) }) - t.Run("parent context cancellation", func(t *testing.T) { - t.Parallel() - - // Setup mock runnables - mockRunnable1 := mocks.NewMockRunnable() - mockRunnable1.On("String").Return("runnable1").Maybe() - mockRunnable1.On("Run", mock.Anything).Run(func(args mock.Arguments) { - // Wait until context is cancelled - <-args.Get(0).(context.Context).Done() - }).Return(nil) - mockRunnable1.On("Stop").Maybe() - - // Create entries - entries := []RunnableEntry[*mocks.Runnable]{ - {Runnable: mockRunnable1, Config: nil}, - } - - // Create config callback - configCallback := func() (*Config[*mocks.Runnable], error) { - return NewConfig("test", entries) - } - - // Create runner with a cancellable parent context - parentCtx, parentCancel := context.WithCancel(context.Background()) - defer parentCancel() - - runner, err := NewRunner( - configCallback, - WithContext[*mocks.Runnable](parentCtx), - ) - require.NoError(t, err) - - // Run in goroutine - errCh := make(chan error, 1) - go func() { - errCh <- runner.Run(context.Background()) - }() - - // Wait for states to transition to Running - require.Eventually(t, func() bool { - return runner.GetState() == finitestate.StatusRunning - }, 500*time.Millisecond, 10*time.Millisecond, "Runner should transition to Running state") - - // Cancel the parent context - parentCancel() - - // Wait for Run to complete - var runErr error - select { - case runErr = <-errCh: - case <-time.After(200 * time.Millisecond): - t.Fatal("timeout waiting for Run to complete") - } - - // Verify clean shutdown - assert.NoError(t, runErr) - assert.Equal(t, finitestate.StatusStopped, runner.GetState()) - mockRunnable1.AssertExpectations(t) - }) - t.Run("child runnable error propagation", func(t *testing.T) { t.Parallel() @@ -158,7 +98,8 @@ func TestCompositeRunner_Run_AdditionalScenarios(t *testing.T) { mockFSM := new(MockStateMachine) mockFSM.On("Transition", finitestate.StatusBooting).Return(nil) mockFSM.On("Transition", finitestate.StatusRunning).Return(nil) - mockFSM.On("TransitionBool", finitestate.StatusStopping).Return(true) + mockFSM.On("TransitionIfCurrentState", finitestate.StatusRunning, finitestate.StatusStopping). + Return(nil) mockFSM.On("Transition", finitestate.StatusStopped).Return(errors.New("transition error")) mockFSM.On("SetState", finitestate.StatusError).Return(nil) mockFSM.On("GetState").Return(finitestate.StatusStopping).Maybe() @@ -249,7 +190,8 @@ func TestCompositeRunner_Run_AdditionalScenarios(t *testing.T) { mockFSM := new(MockStateMachine) mockFSM.On("Transition", finitestate.StatusBooting).Return(nil) mockFSM.On("Transition", finitestate.StatusRunning).Return(nil) - mockFSM.On("TransitionBool", finitestate.StatusStopping).Return(true) + mockFSM.On("TransitionIfCurrentState", finitestate.StatusRunning, finitestate.StatusStopping). 
+ Return(nil) mockFSM.On("Transition", finitestate.StatusStopped).Return(nil).Maybe() mockFSM.On("GetState").Return(finitestate.StatusStopping).Maybe() diff --git a/runnables/composite/runner_test.go b/runnables/composite/runner_test.go index 8b7311b..eec8921 100644 --- a/runnables/composite/runner_test.go +++ b/runnables/composite/runner_test.go @@ -77,17 +77,6 @@ func TestNewRunner(t *testing.T) { }, expectError: false, }, - { - name: "with custom context", - callback: func() (*Config[*mocks.Runnable], error) { - entries := []RunnableEntry[*mocks.Runnable]{} - return NewConfig("test", entries) - }, - opts: []Option[*mocks.Runnable]{ - WithContext[*mocks.Runnable](context.Background()), - }, - expectError: false, - }, } for _, tt := range tests { @@ -470,17 +459,33 @@ func TestCompositeRunner_Stop(t *testing.T) { // Setup mock runnables and config _, configCallback := setupMocksAndConfig() - // Create runner and manually set state to Running + // Create runner runner, err := NewRunner(configCallback) require.NoError(t, err) + + // Set up cancel function as Run() would + ctx, cancel := context.WithCancel(t.Context()) + runner.runnablesMu.Lock() + runner.ctx = ctx + runner.cancel = cancel + runner.runnablesMu.Unlock() + err = runner.fsm.SetState(finitestate.StatusRunning) require.NoError(t, err) - // Call Stop - just testing the state transition + // Call Stop - should just cancel the context, not change state runner.Stop() - // Verify state transition - assert.Equal(t, finitestate.StatusStopping, runner.GetState()) + // Verify state did not change (Stop only cancels context) + assert.Equal(t, finitestate.StatusRunning, runner.GetState()) + + // Verify context was cancelled + select { + case <-ctx.Done(): + // Good, context was cancelled + default: + t.Error("Context should be cancelled after Stop()") + } }) t.Run("stop from non-running state", func(t *testing.T) { diff --git a/runnables/httpcluster/options.go b/runnables/httpcluster/options.go index 67a8873..862c5ae 100644 --- a/runnables/httpcluster/options.go +++ b/runnables/httpcluster/options.go @@ -1,7 +1,6 @@ package httpcluster import ( - "context" "fmt" "log/slog" "time" @@ -12,14 +11,6 @@ import ( // Option is a function that configures a Runner. type Option func(*Runner) error -// WithContext sets the parent context for the cluster. -func WithContext(ctx context.Context) Option { - return func(r *Runner) error { - r.parentCtx = ctx - return nil - } -} - // WithLogger sets the logger for the cluster. 
func WithLogger(logger *slog.Logger) Option { return func(r *Runner) error { diff --git a/runnables/httpcluster/options_test.go b/runnables/httpcluster/options_test.go index c070c34..74731a7 100644 --- a/runnables/httpcluster/options_test.go +++ b/runnables/httpcluster/options_test.go @@ -1,7 +1,6 @@ package httpcluster import ( - "context" "log/slog" "testing" @@ -74,18 +73,15 @@ func TestOptionApplicationOrder(t *testing.T) { // Test that multiple options are applied correctly logger := slog.Default().WithGroup("test") - ctx := context.Background() runner, err := NewRunner( WithLogger(logger), - WithContext(ctx), WithStateChanBufferSize(15), WithSiphonBuffer(3), ) require.NoError(t, err) assert.Equal(t, logger, runner.logger) - assert.Equal(t, ctx, runner.parentCtx) assert.Equal(t, 15, runner.stateChanBufferSize) assert.Equal(t, 3, cap(runner.configSiphon)) } diff --git a/runnables/httpcluster/runner.go b/runnables/httpcluster/runner.go index 8ad7c6c..aa35edb 100644 --- a/runnables/httpcluster/runner.go +++ b/runnables/httpcluster/runner.go @@ -27,10 +27,9 @@ type Runner struct { restartDelay time.Duration deadlineServerStart time.Duration - // Context management - similar to composite pattern - parentCtx context.Context // Set during construction - runCtx context.Context // Set during Run() - runCancel context.CancelFunc + // Set by Run() + ctx context.Context + cancel context.CancelFunc // Configuration siphon channel configSiphon chan map[string]*httpserver.Config @@ -64,7 +63,6 @@ func defaultRunnerFactory( return httpserver.NewRunner( httpserver.WithName(id), httpserver.WithConfig(cfg), - httpserver.WithContext(ctx), httpserver.WithLogHandler(handler), ) } @@ -76,7 +74,6 @@ func NewRunner(opts ...Option) (*Runner, error) { logger: slog.Default().WithGroup("httpcluster.Runner"), restartDelay: defaultRestartDelay, deadlineServerStart: defaultDeadlineServerStart, - parentCtx: context.Background(), configSiphon: make( chan map[string]*httpserver.Config, ), // unbuffered by default @@ -172,8 +169,8 @@ func (r *Runner) Run(ctx context.Context) error { r.mu.Lock() runCtx, runCancel := context.WithCancel(ctx) defer runCancel() - r.runCtx = runCtx - r.runCancel = runCancel + r.ctx = runCtx + r.cancel = runCancel r.mu.Unlock() // Transition to running (no servers initially) @@ -189,10 +186,6 @@ func (r *Runner) Run(ctx context.Context) error { logger.Debug("Run context cancelled, initiating shutdown") return r.shutdown(runCtx) - case <-r.parentCtx.Done(): - logger.Debug("Parent context cancelled, initiating shutdown") - return r.shutdown(runCtx) - case newConfigs, ok := <-r.configSiphon: if !ok { logger.Debug("Config siphon closed, initiating shutdown") @@ -214,7 +207,7 @@ func (r *Runner) Stop() { logger.Debug("Stopping") r.mu.Lock() - r.runCancel() + r.cancel() r.mu.Unlock() } diff --git a/runnables/httpcluster/runner_test.go b/runnables/httpcluster/runner_test.go index fad638e..2b9976e 100644 --- a/runnables/httpcluster/runner_test.go +++ b/runnables/httpcluster/runner_test.go @@ -89,7 +89,6 @@ func TestNewRunner(t *testing.T) { require.NotNil(t, runner) assert.NotNil(t, runner.logger) - assert.NotNil(t, runner.parentCtx) assert.NotNil(t, runner.configSiphon) assert.NotNil(t, runner.currentEntries) assert.NotNil(t, runner.fsm) @@ -97,13 +96,10 @@ func TestNewRunner(t *testing.T) { }) t.Run("with options", func(t *testing.T) { - ctx := t.Context() runner, err := NewRunner( - WithContext(ctx), WithSiphonBuffer(1), ) require.NoError(t, err) - assert.Equal(t, ctx, runner.parentCtx) cfg := 
make(map[string]*httpserver.Config) select { @@ -558,7 +554,7 @@ func TestRunnerExecuteActions(t *testing.T) { ctx := t.Context() runner.mu.Lock() - runner.runCtx = ctx + runner.ctx = ctx runner.mu.Unlock() mockEntries := &MockEntriesManager{} @@ -603,35 +599,6 @@ func TestRunnerExecuteActions(t *testing.T) { func TestRunnerContextManagement(t *testing.T) { t.Parallel() - t.Run("parent context cancellation", func(t *testing.T) { - parentCtx, parentCancel := context.WithCancel(t.Context()) - - runner, err := NewRunner(WithContext(parentCtx)) - require.NoError(t, err) - - runCtx := context.Background() - - runErr := make(chan error, 1) - go func() { - runErr <- runner.Run(runCtx) - }() - - // Wait for running - require.Eventually(t, func() bool { - return runner.IsRunning() - }, time.Second, 10*time.Millisecond) - - // Cancel parent context - parentCancel() - - // Should stop gracefully - select { - case err := <-runErr: - assert.NoError(t, err) - case <-time.After(time.Second): - t.Fatal("Runner should stop when parent context cancelled") - } - }) t.Run("run context setup", func(t *testing.T) { runner, err := NewRunner() @@ -652,8 +619,8 @@ func TestRunnerContextManagement(t *testing.T) { // Check run context is set runner.mu.RLock() - runCtx := runner.runCtx - runCancel := runner.runCancel + runCtx := runner.ctx + runCancel := runner.cancel runner.mu.RUnlock() assert.NotNil(t, runCtx) diff --git a/runnables/httpserver/config_test.go b/runnables/httpserver/config_test.go index 32675f7..b371844 100644 --- a/runnables/httpserver/config_test.go +++ b/runnables/httpserver/config_test.go @@ -298,12 +298,10 @@ func TestContextPropagation(t *testing.T) { return NewConfig(listenPort, hConfig, WithDrainTimeout(2*time.Second)) } - // Create a new context that we'll cancel to trigger shutdown - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + // Context for the server Run method + ctx := t.Context() server, err := NewRunner( - WithContext(ctx), WithConfigCallback(cfgCallback), ) require.NoError(t, err) @@ -313,7 +311,7 @@ func TestContextPropagation(t *testing.T) { // Start the server in a goroutine go func() { - err := server.Run(context.Background()) + err := server.Run(ctx) runComplete <- err }() @@ -346,7 +344,7 @@ func TestContextPropagation(t *testing.T) { } // Initiate server shutdown - cancel() // This should cancel the context passed to the server + server.Stop() // This should cancel the server's context // Verify that the handler's context was canceled select { diff --git a/runnables/httpserver/helpers_test.go b/runnables/httpserver/helpers_test.go index a52cb98..139216e 100644 --- a/runnables/httpserver/helpers_test.go +++ b/runnables/httpserver/helpers_test.go @@ -1,7 +1,6 @@ package httpserver import ( - "context" "fmt" "net" "net/http" @@ -44,7 +43,7 @@ func createTestServer( return NewConfig(listenPort, hConfig, WithDrainTimeout(drainTimeout)) } - server, err := NewRunner(WithContext(context.Background()), WithConfigCallback(cfgCallback)) + server, err := NewRunner(WithConfigCallback(cfgCallback)) require.NoError(t, err) require.NotNil(t, server) diff --git a/runnables/httpserver/options.go b/runnables/httpserver/options.go index d702b15..f91164e 100644 --- a/runnables/httpserver/options.go +++ b/runnables/httpserver/options.go @@ -1,7 +1,6 @@ package httpserver import ( - "context" "log/slog" ) @@ -21,16 +20,6 @@ func WithLogHandler(handler slog.Handler) Option { } } -// WithContext sets a custom context for the Runner instance. 
-// This allows for more granular control over cancellation and timeouts. -func WithContext(ctx context.Context) Option { - return func(r *Runner) { - if ctx != nil { - r.ctx, r.cancel = context.WithCancel(ctx) - } - } -} - // WithConfigCallback sets the function that will be called to load or reload configuration. // Either this option or WithConfig initializes the Runner instance by providing the // configuration for the HTTP server managed by the Runner. diff --git a/runnables/httpserver/options_test.go b/runnables/httpserver/options_test.go index d1463b6..78f89f4 100644 --- a/runnables/httpserver/options_test.go +++ b/runnables/httpserver/options_test.go @@ -1,7 +1,6 @@ package httpserver import ( - "context" "log/slog" "net/http" "strings" @@ -13,47 +12,6 @@ import ( "github.com/stretchr/testify/require" ) -// Define a custom type for context keys to avoid string collision -type contextKey string - -// TestWithContext verifies the WithContext option works correctly -func TestWithContext(t *testing.T) { - t.Parallel() - // Create a custom context with a value using the type-safe key - testKey := contextKey("test-key") - customCtx := context.WithValue(context.Background(), testKey, "test-value") - // Create a server with the custom context - handler := func(w http.ResponseWriter, r *http.Request) {} - route, err := NewRoute("v1", "/", handler) - require.NoError(t, err) - hConfig := Routes{*route} - cfgCallback := func() (*Config, error) { - return NewConfig(":0", hConfig, WithDrainTimeout(1*time.Second)) - } - server, err := NewRunner(WithContext(context.Background()), - WithConfigCallback(cfgCallback), - WithContext(customCtx)) - require.NoError(t, err) - // Verify the custom context was applied - actualValue := server.ctx.Value(testKey) - assert.Equal(t, "test-value", actualValue, "Context value should be preserved") - // Verify cancellation works through server.Stop() - done := make(chan struct{}) - go func() { - <-server.ctx.Done() - close(done) - }() - // Call Stop to cancel the internal context - server.Stop() - // Wait for the server context to be canceled or timeout - select { - case <-done: - // Success, context was canceled - case <-time.After(1 * time.Second): - t.Fatal("Context cancellation not propagated") - } -} - func TestWithLogHandler(t *testing.T) { t.Parallel() // Create a custom logger with a buffer for testing output @@ -69,7 +27,6 @@ func TestWithLogHandler(t *testing.T) { } // Create a server with the custom logger server, err := NewRunner( - WithContext(context.Background()), WithConfigCallback(cfgCallback), WithLogHandler(customHandler), ) @@ -95,7 +52,6 @@ func TestWithConfig(t *testing.T) { require.NoError(t, err) // Create a server with the static config server, err := NewRunner( - WithContext(context.Background()), WithConfig(staticConfig), ) require.NoError(t, err) @@ -143,7 +99,6 @@ func TestWithServerCreator(t *testing.T) { } // Create a server with the config that has a custom server creator server, err := NewRunner( - WithContext(context.Background()), WithConfigCallback(cfgCallback), ) require.NoError(t, err) diff --git a/runnables/httpserver/reload_mocked_test.go b/runnables/httpserver/reload_mocked_test.go index aa93bf9..d62e733 100644 --- a/runnables/httpserver/reload_mocked_test.go +++ b/runnables/httpserver/reload_mocked_test.go @@ -1,7 +1,6 @@ package httpserver import ( - "context" "errors" "net/http" "testing" @@ -203,7 +202,6 @@ func TestReloadConfig_WithFullRunner(t *testing.T) { // Create the Runner with the config callback runner, err := 
NewRunner( - WithContext(context.Background()), WithConfigCallback(configCallback), ) require.NoError(t, err) diff --git a/runnables/httpserver/reload_test.go b/runnables/httpserver/reload_test.go index 1a4c1c4..c85a57f 100644 --- a/runnables/httpserver/reload_test.go +++ b/runnables/httpserver/reload_test.go @@ -58,14 +58,19 @@ func TestRapidReload(t *testing.T) { } // Create the Runner instance - server, err := NewRunner(WithContext(context.Background()), WithConfigCallback(cfgCallback)) + server, err := NewRunner(WithConfigCallback(cfgCallback)) require.NoError(t, err) require.NotNil(t, server) // Start the server + // Use a background context here because this test needs to control + // the server lifecycle independent of test timeouts + serverCtx, serverCancel := context.WithCancel(context.Background()) + defer serverCancel() + done := make(chan error, 1) go func() { - err := server.Run(context.Background()) + err := server.Run(serverCtx) done <- err }() t.Cleanup(func() { @@ -79,7 +84,7 @@ func TestRapidReload(t *testing.T) { }, 2*time.Second, 10*time.Millisecond) // Create context for state monitoring - stateCtx, stateCancel := context.WithCancel(context.Background()) + stateCtx, stateCancel := context.WithCancel(t.Context()) defer stateCancel() stateChan := server.GetStateChan(stateCtx) @@ -170,14 +175,19 @@ func TestReload(t *testing.T) { } // Create the Runner instance - server, err := NewRunner(WithContext(context.Background()), WithConfigCallback(cfgCallback)) + server, err := NewRunner(WithConfigCallback(cfgCallback)) require.NoError(t, err) require.NotNil(t, server) // Start the server with valid config + // Use a background context here because this test needs to control + // the server lifecycle independent of test timeouts + serverCtx, serverCancel := context.WithCancel(context.Background()) + defer serverCancel() + done := make(chan error, 1) go func() { - err := server.Run(context.Background()) + err := server.Run(serverCtx) done <- err }() @@ -233,14 +243,19 @@ func TestReload(t *testing.T) { } // Create the Runner instance - server, err := NewRunner(WithContext(context.Background()), WithConfigCallback(cfgCallback)) + server, err := NewRunner(WithConfigCallback(cfgCallback)) require.NoError(t, err) require.NotNil(t, server) // Start the server + // Use a background context here because this test needs to control + // the server lifecycle independent of test timeouts + serverCtx, serverCancel := context.WithCancel(context.Background()) + defer serverCancel() + errChan := make(chan error, 1) go func() { - err := server.Run(context.Background()) + err := server.Run(serverCtx) errChan <- err close(errChan) }() @@ -299,7 +314,7 @@ func TestReload(t *testing.T) { // Create the Runner instance but don't start it - leaving it in New state // This will make the Reloading transition fail since it's not valid from New state - server, err := NewRunner(WithContext(context.Background()), WithConfigCallback(cfgCallback)) + server, err := NewRunner(WithConfigCallback(cfgCallback)) require.NoError(t, err) require.NotNil(t, server) @@ -360,7 +375,6 @@ func TestReload(t *testing.T) { // Create a new runner with our callback updatedServer, err := NewRunner( - WithContext(context.Background()), WithConfigCallback(newCfgCallback), ) require.NoError(t, err) @@ -411,7 +425,7 @@ func TestReload(t *testing.T) { } // Set up the runner with the callback - server, err := NewRunner(WithContext(context.Background()), WithConfigCallback(cfgCallback)) + server, err := 
NewRunner(WithConfigCallback(cfgCallback)) require.NoError(t, err) // Store the initial config @@ -452,14 +466,19 @@ func TestReload(t *testing.T) { } // Create the Runner instance - server, err := NewRunner(WithContext(context.Background()), WithConfigCallback(cfgCallback)) + server, err := NewRunner(WithConfigCallback(cfgCallback)) require.NoError(t, err) require.NotNil(t, server) // Start the server + // Use a background context here because this test needs to control + // the server lifecycle independent of test timeouts + serverCtx, serverCancel := context.WithCancel(context.Background()) + defer serverCancel() + done := make(chan error, 1) go func() { - err := server.Run(context.Background()) + err := server.Run(serverCtx) done <- err }() t.Cleanup(func() { @@ -475,7 +494,7 @@ func TestReload(t *testing.T) { server.Reload() // Setup state monitoring - stateCtx, stateCancel := context.WithCancel(context.Background()) + stateCtx, stateCancel := context.WithCancel(t.Context()) defer stateCancel() stateChan := server.GetStateChan(stateCtx) diff --git a/runnables/httpserver/runner.go b/runnables/httpserver/runner.go index 184cd10..04bfd76 100644 --- a/runnables/httpserver/runner.go +++ b/runnables/httpserver/runner.go @@ -46,7 +46,8 @@ type Runner struct { serverMutex sync.RWMutex serverErrors chan error - fsm finitestate.Machine + fsm finitestate.Machine + // Set during Run() ctx context.Context cancel context.CancelFunc logger *slog.Logger @@ -57,16 +58,11 @@ func NewRunner(opts ...Option) (*Runner, error) { // Set default logger logger := slog.Default().WithGroup("httpserver.Runner") - // Initialize with a background context by default - ctx, cancel := context.WithCancel(context.Background()) - r := &Runner{ name: "", config: atomic.Pointer[Config]{}, serverCloseOnce: sync.Once{}, serverErrors: make(chan error, 1), - ctx: ctx, - cancel: cancel, logger: logger, } @@ -121,6 +117,12 @@ func (r *Runner) Run(ctx context.Context) error { runCtx, runCancel := context.WithCancel(ctx) defer runCancel() + // Store the context and cancel function + r.mutex.Lock() + r.ctx = runCtx + r.cancel = runCancel + r.mutex.Unlock() + // Transition from New to Booting err := r.fsm.Transition(finitestate.StatusBooting) if err != nil { @@ -145,9 +147,7 @@ func (r *Runner) Run(ctx context.Context) error { select { case <-runCtx.Done(): - r.logger.Debug("Local context canceled") - case <-r.ctx.Done(): - r.logger.Debug("Parent context canceled") + r.logger.Debug("Context canceled") case err := <-r.serverErrors: r.setStateError() return fmt.Errorf("%w: %w", ErrHttpServer, err) @@ -159,7 +159,16 @@ func (r *Runner) Run(ctx context.Context) error { // Stop signals the HTTP server to shut down by canceling its context. func (r *Runner) Stop() { r.logger.Debug("Stopping HTTP server") - r.cancel() + + r.mutex.RLock() + cancel := r.cancel + r.mutex.RUnlock() + + if cancel == nil { + r.logger.Warn("Cancel function is nil, skipping Stop") + return + } + cancel() } // serverReadinessProbe verifies the HTTP server is accepting connections by @@ -227,7 +236,7 @@ func (r *Runner) boot() error { r.serverCloseOnce = sync.Once{} r.serverMutex.Unlock() - r.logger.Info("Starting HTTP server", + r.logger.Debug("Starting HTTP server", "listenOn", listenAddr, "readTimeout", serverCfg.ReadTimeout, "writeTimeout", serverCfg.WriteTimeout, @@ -307,6 +316,7 @@ func (r *Runner) getConfig() *Config { // It uses sync.Once to ensure shutdown occurs only once per server instance. 
func (r *Runner) stopServer(ctx context.Context) error { var shutdownErr error + //nolint:contextcheck // We intentionally use context.Background() for shutdown timeout r.serverCloseOnce.Do(func() { r.serverMutex.RLock() defer r.serverMutex.RUnlock() @@ -325,7 +335,7 @@ func (r *Runner) stopServer(ctx context.Context) error { } r.logger.Debug("Waiting for graceful HTTP server shutdown...", "timeout", drainTimeout) - shutdownCtx, shutdownCancel := context.WithTimeout(ctx, drainTimeout) + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), drainTimeout) defer shutdownCancel() localErr := r.server.Shutdown(shutdownCtx) diff --git a/runnables/httpserver/runner_context_test.go b/runnables/httpserver/runner_context_test.go index 8f6dcb2..ca474fe 100644 --- a/runnables/httpserver/runner_context_test.go +++ b/runnables/httpserver/runner_context_test.go @@ -22,7 +22,7 @@ func TestContextValuePropagation(t *testing.T) { type contextKey string const testKey contextKey = "test-key" const testValue = "test-value" - parentCtx := context.WithValue(context.Background(), testKey, testValue) + parentCtx := context.WithValue(t.Context(), testKey, testValue) // Create a cancellable context to test cancellation propagation ctx, cancel := context.WithCancel(parentCtx) @@ -66,7 +66,6 @@ func TestContextValuePropagation(t *testing.T) { // Create the runner runner, err := NewRunner( - WithContext(ctx), WithConfigCallback(cfgCallback), ) require.NoError(t, err) diff --git a/runnables/httpserver/runner_race_test.go b/runnables/httpserver/runner_race_test.go index f407370..5c86c33 100644 --- a/runnables/httpserver/runner_race_test.go +++ b/runnables/httpserver/runner_race_test.go @@ -41,7 +41,7 @@ func TestConcurrentReloadsRaceCondition(t *testing.T) { require.NoError(t, err) errChan := make(chan error, 1) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() go func() { @@ -125,7 +125,7 @@ func TestRunnerRaceConditions(t *testing.T) { require.NoError(t, err) errChan := make(chan error, 1) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() go func() { diff --git a/runnables/httpserver/runner_readiness_test.go b/runnables/httpserver/runner_readiness_test.go index fba3526..61abf18 100644 --- a/runnables/httpserver/runner_readiness_test.go +++ b/runnables/httpserver/runner_readiness_test.go @@ -28,7 +28,7 @@ func TestServerReadinessProbe(t *testing.T) { require.NoError(t, err) t.Run("probe_timeout", func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + ctx, cancel := context.WithTimeout(t.Context(), 10*time.Millisecond) defer cancel() err := runner.serverReadinessProbe(ctx, "test.invalid:80") diff --git a/runnables/httpserver/runner_test.go b/runnables/httpserver/runner_test.go index d63d86f..a207e6a 100644 --- a/runnables/httpserver/runner_test.go +++ b/runnables/httpserver/runner_test.go @@ -28,7 +28,6 @@ func TestBootFailure(t *testing.T) { t.Run("Config callback returns nil", func(t *testing.T) { callback := func() (*Config, error) { return nil, nil } runner, err := NewRunner( - WithContext(context.Background()), WithConfigCallback(callback), ) @@ -40,7 +39,6 @@ func TestBootFailure(t *testing.T) { t.Run("Config callback returns error", func(t *testing.T) { callback := func() (*Config, error) { return nil, errors.New("failed to load config") } runner, err := NewRunner( - WithContext(context.Background()), 
WithConfigCallback(callback), ) @@ -63,7 +61,6 @@ func TestBootFailure(t *testing.T) { } runner, err := NewRunner( - WithContext(context.Background()), WithConfigCallback(callback), ) @@ -71,7 +68,7 @@ func TestBootFailure(t *testing.T) { assert.NotNil(t, runner) // Test actual run - err = runner.Run(context.Background()) + err = runner.Run(t.Context()) assert.Error(t, err) // With our readiness probe, the error format is different but should be propagated properly assert.ErrorIs(t, err, ErrServerBoot) @@ -114,7 +111,6 @@ func TestCustomServerCreator(t *testing.T) { // Create the runner runner, err := NewRunner( - WithContext(context.Background()), WithConfigCallback(cfgCallback), ) require.NoError(t, err) @@ -177,7 +173,7 @@ func TestRun_ShutdownDeadlineExceeded(t *testing.T) { return NewConfig(listenPort, hConfig, WithDrainTimeout(drainTimeout)) } - server, err := NewRunner(WithContext(context.Background()), WithConfigCallback(cfgCallback)) + server, err := NewRunner(WithConfigCallback(cfgCallback)) require.NoError(t, err) // Channel to capture Run's completion @@ -185,7 +181,7 @@ func TestRun_ShutdownDeadlineExceeded(t *testing.T) { // Start the server in a goroutine go func() { - err := server.Run(context.Background()) + err := server.Run(t.Context()) done <- err }() @@ -259,7 +255,7 @@ func TestRun_ShutdownWithDrainTimeout(t *testing.T) { return NewConfig(listenPort, hConfig, WithDrainTimeout(drainTimeout)) } - server, err := NewRunner(WithContext(context.Background()), WithConfigCallback(cfgCallback)) + server, err := NewRunner(WithConfigCallback(cfgCallback)) require.NoError(t, err) // Channel to capture Run's completion @@ -267,7 +263,7 @@ func TestRun_ShutdownWithDrainTimeout(t *testing.T) { // Start the server in a goroutine go func() { - err := server.Run(context.Background()) + err := server.Run(t.Context()) done <- err }() @@ -326,7 +322,7 @@ func TestServerErr(t *testing.T) { // Create two server configs using the same port cfg1 := func() (*Config, error) { return NewConfig(port, hConfig, WithDrainTimeout(0)) } - server1, err := NewRunner(WithContext(context.Background()), WithConfigCallback(cfg1)) + server1, err := NewRunner(WithConfigCallback(cfg1)) require.NoError(t, err) // Start the first server @@ -347,11 +343,11 @@ func TestServerErr(t *testing.T) { // Create a second server with the same port cfg2 := func() (*Config, error) { return NewConfig(port, hConfig, WithDrainTimeout(0)) } - server2, err := NewRunner(WithContext(context.Background()), WithConfigCallback(cfg2)) + server2, err := NewRunner(WithConfigCallback(cfg2)) require.NoError(t, err) // The second server should fail to start with "address already in use" - err = server2.Run(context.Background()) + err = server2.Run(t.Context()) require.Error(t, err) // The error contains ErrServerBoot, but we can't use ErrorIs here directly // because of how the error is wrapped @@ -380,13 +376,13 @@ func TestServerLifecycle(t *testing.T) { return NewConfig(listenPort, Routes{*route}, WithDrainTimeout(1*time.Second)) } - server, err := NewRunner(WithContext(context.Background()), WithConfigCallback(cfgCallback)) + server, err := NewRunner(WithConfigCallback(cfgCallback)) require.NoError(t, err) // Run the server in a goroutine done := make(chan error, 1) go func() { - err := server.Run(context.Background()) + err := server.Run(t.Context()) done <- err }() @@ -445,7 +441,7 @@ func TestString(t *testing.T) { return NewConfig(listenPort, hConfig, WithDrainTimeout(0)) } - server, err := NewRunner(WithContext(t.Context()), 
WithConfigCallback(cfgCallback)) + server, err := NewRunner(WithConfigCallback(cfgCallback)) require.NoError(t, err) // Test string representation before starting @@ -455,7 +451,7 @@ func TestString(t *testing.T) { // Start the server done := make(chan error, 1) go func() { - err := server.Run(context.Background()) + err := server.Run(t.Context()) done <- err }() @@ -490,7 +486,6 @@ func TestString(t *testing.T) { } testName := "TestServer" server, err := NewRunner( - WithContext(t.Context()), WithConfigCallback(cfgCallback), WithName(testName), ) diff --git a/runnables/httpserver/state_test.go b/runnables/httpserver/state_test.go index 7d21df2..4bfea81 100644 --- a/runnables/httpserver/state_test.go +++ b/runnables/httpserver/state_test.go @@ -52,7 +52,7 @@ func TestGetStateChan(t *testing.T) { }) // Create a context with timeout for safety - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second) defer cancel() // Get the state channel
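
Note (not part of the diff above): the changes in this patch consistently remove the WithContext option from the composite, httpcluster, and httpserver runners and move context ownership into Run(), with Stop() cancelling the context that Run() captured. The following is a minimal, hypothetical sketch of that lifecycle pattern under those assumptions; the type and function names (lifecycleRunner, etc.) are illustrative and are not the go-supervisor API.

package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// lifecycleRunner illustrates the pattern: Run() derives and stores its own
// cancellable context, and Stop() cancels that context rather than a
// constructor-supplied parent context.
type lifecycleRunner struct {
	mu     sync.Mutex
	ctx    context.Context    // stored so later operations (e.g. a reload) could reuse it
	cancel context.CancelFunc // stored so Stop can cancel this specific run
}

// Run derives a cancellable context from the caller's context, records it,
// and blocks until either the caller cancels or Stop is called.
func (r *lifecycleRunner) Run(ctx context.Context) error {
	runCtx, runCancel := context.WithCancel(ctx)
	defer runCancel()

	r.mu.Lock()
	r.ctx = runCtx
	r.cancel = runCancel
	r.mu.Unlock()

	<-runCtx.Done() // wait for cancellation from either side
	return nil
}

// Stop cancels the context captured by Run, if Run has started.
func (r *lifecycleRunner) Stop() {
	r.mu.Lock()
	cancel := r.cancel
	r.mu.Unlock()

	if cancel == nil {
		return // Run was never called; nothing to cancel
	}
	cancel()
}

func main() {
	r := &lifecycleRunner{}
	done := make(chan error, 1)
	go func() { done <- r.Run(context.Background()) }()

	time.Sleep(10 * time.Millisecond)
	r.Stop()
	fmt.Println("run returned:", <-done)
}

This mirrors why the tests above now call runner.Stop() (or cancel the context passed to Run) instead of cancelling a parent context created at construction time.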