From fd5fd36f83c2ac06a1bf2e366766afb0c9fdb2b9 Mon Sep 17 00:00:00 2001 From: jasonmills Date: Tue, 1 Nov 2022 17:05:20 -0700 Subject: [PATCH] Refactor Done() to use a broadcasted Signal type for future features. (#984) This creates a new public struct type `ShutdownSignal` and moves broadcast of operating system signals to a standalone type in it's own source file. This allows for the cleaner expansion of signaling features. Co-authored-by: Abhinav Gupta Co-authored-by: Sung Yoon Whang --- app.go | 24 ++-------- shutdown.go | 33 +------------- shutdown_test.go | 2 +- signal.go | 115 +++++++++++++++++++++++++++++++++++++++++++++++ signal_test.go | 65 +++++++++++++++++++++++++++ 5 files changed, 186 insertions(+), 53 deletions(-) create mode 100644 signal.go create mode 100644 signal_test.go diff --git a/app.go b/app.go index 0f51531ea..ad2b08ce7 100644 --- a/app.go +++ b/app.go @@ -26,11 +26,9 @@ import ( "errors" "fmt" "os" - "os/signal" "reflect" "sort" "strings" - "sync" "time" "go.uber.org/dig" @@ -286,10 +284,9 @@ type App struct { // Decides how we react to errors when building the graph. errorHooks []ErrorHandler validate bool + // Used to signal shutdowns. - donesMu sync.Mutex // guards dones and shutdownSig - dones []chan os.Signal - shutdownSig os.Signal + receivers signalReceivers osExit func(code int) // os.Exit override; used for testing only } @@ -393,6 +390,7 @@ func New(opts ...Option) *App { clock: fxclock.System, startTimeout: DefaultTimeout, stopTimeout: DefaultTimeout, + receivers: newSignalReceivers(), } app.root = &module{ app: app, @@ -666,21 +664,7 @@ func (app *App) Stop(ctx context.Context) (err error) { // Alternatively, a signal can be broadcast to all done channels manually by // using the Shutdown functionality (see the Shutdowner documentation for details). func (app *App) Done() <-chan os.Signal { - c := make(chan os.Signal, 1) - - app.donesMu.Lock() - defer app.donesMu.Unlock() - // If shutdown signal has been received already - // send it and return. If not, wait for user to send a termination - // signal. - if app.shutdownSig != nil { - c <- app.shutdownSig - return c - } - - signal.Notify(c, os.Interrupt, _sigINT, _sigTERM) - app.dones = append(app.dones, c) - return c + return app.receivers.Done() } // StartTimeout returns the configured startup timeout. Apps default to using diff --git a/shutdown.go b/shutdown.go index d5b8488c0..eebb5f1b5 100644 --- a/shutdown.go +++ b/shutdown.go @@ -20,11 +20,6 @@ package fx -import ( - "fmt" - "os" -) - // Shutdowner provides a method that can manually trigger the shutdown of the // application by sending a signal to all open Done channels. Shutdowner works // on applications using Run as well as Start, Done, and Stop. The Shutdowner is @@ -49,35 +44,9 @@ type shutdowner struct { // In practice this means Shutdowner.Shutdown should not be called from an // fx.Invoke, but from a fx.Lifecycle.OnStart hook. func (s *shutdowner) Shutdown(opts ...ShutdownOption) error { - return s.app.broadcastSignal(_sigTERM) + return s.app.receivers.Broadcast(ShutdownSignal{Signal: _sigTERM}) } func (app *App) shutdowner() Shutdowner { return &shutdowner{app: app} } - -func (app *App) broadcastSignal(signal os.Signal) error { - app.donesMu.Lock() - defer app.donesMu.Unlock() - - app.shutdownSig = signal - - var unsent int - for _, done := range app.dones { - select { - case done <- signal: - default: - // shutdown called when done channel has already received a - // termination signal that has not been cleared - unsent++ - } - } - - if unsent != 0 { - return fmt.Errorf("failed to send %v signal to %v out of %v channels", - signal, unsent, len(app.dones), - ) - } - - return nil -} diff --git a/shutdown_test.go b/shutdown_test.go index b6af93f13..a6d0ad508 100644 --- a/shutdown_test.go +++ b/shutdown_test.go @@ -64,7 +64,7 @@ func TestShutdown(t *testing.T) { defer app.RequireStart().RequireStop() assert.NoError(t, s.Shutdown(), "error returned from first shutdown call") - assert.EqualError(t, s.Shutdown(), "failed to send terminated signal to 1 out of 1 channels", + assert.EqualError(t, s.Shutdown(), "send terminated signal: 1/1 channels are blocked", "unexpected error returned when shutdown is called with a blocked channel") assert.NotNil(t, <-done, "done channel did not receive signal") }) diff --git a/signal.go b/signal.go new file mode 100644 index 000000000..4c0b28763 --- /dev/null +++ b/signal.go @@ -0,0 +1,115 @@ +// Copyright (c) 2022 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPSignalE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package fx + +import ( + "fmt" + "os" + "os/signal" + "sync" +) + +// ShutdownSignal is a signal that caused the application to exit. +type ShutdownSignal struct { + Signal os.Signal +} + +// String will render a ShutdownSignal type as a string suitable for printing. +func (sig ShutdownSignal) String() string { + return fmt.Sprintf("%v", sig.Signal) +} + +func newSignalReceivers() signalReceivers { + return signalReceivers{notify: signal.Notify} +} + +type signalReceivers struct { + m sync.Mutex + last *ShutdownSignal + done []chan os.Signal + notify func(c chan<- os.Signal, sig ...os.Signal) +} + +func (recv *signalReceivers) Done() <-chan os.Signal { + recv.m.Lock() + defer recv.m.Unlock() + + ch := make(chan os.Signal, 1) + + // If we had received a signal prior to the call of done, send it's + // os.Signal to the new channel. + // However we still want to have the operating system notify signals to this + // channel should the application receive another. + if recv.last != nil { + ch <- recv.last.Signal + } + + recv.notify(ch, os.Interrupt, _sigINT, _sigTERM) + recv.done = append(recv.done, ch) + return ch +} + +func (recv *signalReceivers) Broadcast(signal ShutdownSignal) error { + recv.m.Lock() + defer recv.m.Unlock() + recv.last = &signal + + channels, unsent := recv.broadcastDone(signal) + + if unsent != 0 { + return &unsentSignalError{ + Signal: signal, + Total: channels, + Unsent: unsent, + } + } + + return nil +} + +func (recv *signalReceivers) broadcastDone(signal ShutdownSignal) (int, int) { + var unsent int + + for _, reader := range recv.done { + select { + case reader <- signal.Signal: + default: + unsent++ + } + } + + return len(recv.done), unsent +} + +type unsentSignalError struct { + Signal ShutdownSignal + Unsent int + Total int +} + +func (err *unsentSignalError) Error() string { + return fmt.Sprintf( + "send %v signal: %v/%v channels are blocked", + err.Signal, + err.Unsent, + err.Total, + ) +} diff --git a/signal_test.go b/signal_test.go new file mode 100644 index 000000000..481b74ec6 --- /dev/null +++ b/signal_test.go @@ -0,0 +1,65 @@ +// Copyright (c) 2022 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPSignalE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package fx + +import ( + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "syscall" + "testing" +) + +func assertUnsentSignalError( + t *testing.T, + err error, + expected *unsentSignalError, +) { + t.Helper() + + actual := new(unsentSignalError) + + assert.ErrorContains(t, err, "channels are blocked") + if assert.ErrorAs(t, err, &actual, "is unsentSignalError") { + assert.Equal(t, expected, actual) + } +} + +func TestSignal(t *testing.T) { + t.Parallel() + recv := newSignalReceivers() + a := recv.Done() + _ = recv.Done() // we never listen on this + + expected := ShutdownSignal{ + Signal: syscall.SIGTERM, + } + + require.NoError(t, recv.Broadcast(expected), "first broadcast should succeed") + + assertUnsentSignalError(t, recv.Broadcast(expected), &unsentSignalError{ + Signal: expected, + Total: 2, + Unsent: 2, + }) + + assert.Equal(t, expected.Signal, <-a) + assert.Equal(t, expected.Signal, <-recv.Done(), "expect cached signal") +}