From 35c957097431cf87d5257de6ae950af132f01f2b Mon Sep 17 00:00:00 2001 From: Tyler Yahn Date: Thu, 4 Apr 2024 13:36:34 -0700 Subject: [PATCH] Prevent default ErrorHandler self-delegation (#5137) --- CHANGELOG.md | 1 + handler.go | 12 +- handler_test.go | 3 + internal/global/handler.go | 71 +------ internal/global/handler_test.go | 230 ++--------------------- internal/global/internal_logging_test.go | 2 +- internal/global/state.go | 54 ++++++ internal/global/state_test.go | 63 +++++++ internal/global/util_test.go | 2 + 9 files changed, 156 insertions(+), 282 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f71f4e1d062..0126380d5dc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,7 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm ### Fixed - Clarify the documentation about equivalence guarantees for the `Set` and `Distinct` types in `go.opentelemetry.io/otel/attribute`. (#5027) +- Prevent default `ErrorHandler` self-delegation. (#5137) - Update all dependencies to address [GO-2024-2687]. (#5139) ### Removed diff --git a/handler.go b/handler.go index f63fbdfc13a..07623b67914 100644 --- a/handler.go +++ b/handler.go @@ -7,12 +7,8 @@ import ( "go.opentelemetry.io/otel/internal/global" ) -var ( - // Compile-time check global.ErrDelegator implements ErrorHandler. - _ ErrorHandler = (*global.ErrDelegator)(nil) - // Compile-time check global.ErrLogger implements ErrorHandler. - _ ErrorHandler = (*global.ErrLogger)(nil) -) +// Compile-time check global.ErrDelegator implements ErrorHandler. +var _ ErrorHandler = (*global.ErrDelegator)(nil) // GetErrorHandler returns the global ErrorHandler instance. // @@ -33,5 +29,5 @@ func GetErrorHandler() ErrorHandler { return global.GetErrorHandler() } // delegate errors to h. func SetErrorHandler(h ErrorHandler) { global.SetErrorHandler(h) } -// Handle is a convenience function for ErrorHandler().Handle(err). -func Handle(err error) { global.Handle(err) } +// Handle is a convenience function for GetErrorHandler().Handle(err). +func Handle(err error) { global.GetErrorHandler().Handle(err) } diff --git a/handler_test.go b/handler_test.go index 4bf8283c358..14cf8ff9c68 100644 --- a/handler_test.go +++ b/handler_test.go @@ -18,6 +18,9 @@ var _ ErrorHandler = &testErrHandler{} func (eh *testErrHandler) Handle(err error) { eh.err = err } func TestGlobalErrorHandler(t *testing.T) { + SetErrorHandler(GetErrorHandler()) + assert.NotPanics(t, func() { Handle(assert.AnError) }, "Default assignment") + e1 := &testErrHandler{} SetErrorHandler(e1) Handle(assert.AnError) diff --git a/internal/global/handler.go b/internal/global/handler.go index 423ee4f8f86..c657ff8e755 100644 --- a/internal/global/handler.go +++ b/internal/global/handler.go @@ -5,23 +5,9 @@ package global // import "go.opentelemetry.io/otel/internal/global" import ( "log" - "os" "sync/atomic" ) -var ( - // GlobalErrorHandler provides an ErrorHandler that can be used - // throughout an OpenTelemetry instrumented project. When a user - // specified ErrorHandler is registered (`SetErrorHandler`) all calls to - // `Handle` and will be delegated to the registered ErrorHandler. - GlobalErrorHandler = defaultErrorHandler() - - // Compile-time check that delegator implements ErrorHandler. - _ ErrorHandler = (*ErrDelegator)(nil) - // Compile-time check that errLogger implements ErrorHandler. - _ ErrorHandler = (*ErrLogger)(nil) -) - // ErrorHandler handles irremediable events. type ErrorHandler interface { // Handle handles any error deemed irremediable by an OpenTelemetry @@ -33,59 +19,18 @@ type ErrDelegator struct { delegate atomic.Pointer[ErrorHandler] } -func (d *ErrDelegator) Handle(err error) { - d.getDelegate().Handle(err) -} +// Compile-time check that delegator implements ErrorHandler. +var _ ErrorHandler = (*ErrDelegator)(nil) -func (d *ErrDelegator) getDelegate() ErrorHandler { - return *d.delegate.Load() +func (d *ErrDelegator) Handle(err error) { + if eh := d.delegate.Load(); eh != nil { + (*eh).Handle(err) + return + } + log.Print(err) } // setDelegate sets the ErrorHandler delegate. func (d *ErrDelegator) setDelegate(eh ErrorHandler) { d.delegate.Store(&eh) } - -func defaultErrorHandler() *ErrDelegator { - d := &ErrDelegator{} - d.setDelegate(&ErrLogger{l: log.New(os.Stderr, "", log.LstdFlags)}) - return d -} - -// ErrLogger logs errors if no delegate is set, otherwise they are delegated. -type ErrLogger struct { - l *log.Logger -} - -// Handle logs err if no delegate is set, otherwise it is delegated. -func (h *ErrLogger) Handle(err error) { - h.l.Print(err) -} - -// GetErrorHandler returns the global ErrorHandler instance. -// -// The default ErrorHandler instance returned will log all errors to STDERR -// until an override ErrorHandler is set with SetErrorHandler. All -// ErrorHandler returned prior to this will automatically forward errors to -// the set instance instead of logging. -// -// Subsequent calls to SetErrorHandler after the first will not forward errors -// to the new ErrorHandler for prior returned instances. -func GetErrorHandler() ErrorHandler { - return GlobalErrorHandler -} - -// SetErrorHandler sets the global ErrorHandler to h. -// -// The first time this is called all ErrorHandler previously returned from -// GetErrorHandler will send errors to h instead of the default logging -// ErrorHandler. Subsequent calls will set the global ErrorHandler, but not -// delegate errors to h. -func SetErrorHandler(h ErrorHandler) { - GlobalErrorHandler.setDelegate(h) -} - -// Handle is a convenience function for ErrorHandler().Handle(err). -func Handle(err error) { - GetErrorHandler().Handle(err) -} diff --git a/internal/global/handler_test.go b/internal/global/handler_test.go index db2c1e6e0b0..b887060894c 100644 --- a/internal/global/handler_test.go +++ b/internal/global/handler_test.go @@ -6,225 +6,35 @@ package global import ( "bytes" "errors" - "io" "log" - "sync" + "os" + "strings" "testing" - - "github.com/stretchr/testify/suite" ) -type testErrCatcher []string - -func (l *testErrCatcher) Write(p []byte) (int, error) { - msg := bytes.TrimRight(p, "\n") - (*l) = append(*l, string(msg)) - return len(msg), nil -} - -func (l *testErrCatcher) Reset() { - *l = testErrCatcher([]string{}) -} - -func (l *testErrCatcher) Got() []string { - return []string(*l) -} - -func causeErr(text string) { - Handle(errors.New(text)) -} - -type HandlerTestSuite struct { - suite.Suite - - origHandler ErrorHandler - errCatcher *testErrCatcher -} - -func (s *HandlerTestSuite) SetupSuite() { - s.errCatcher = new(testErrCatcher) - s.origHandler = GlobalErrorHandler.getDelegate() - - GlobalErrorHandler.setDelegate(&ErrLogger{l: log.New(s.errCatcher, "", 0)}) -} - -func (s *HandlerTestSuite) TearDownSuite() { - GlobalErrorHandler.setDelegate(s.origHandler) -} - -func (s *HandlerTestSuite) SetupTest() { - s.errCatcher.Reset() -} - -func (s *HandlerTestSuite) TearDownTest() { - GlobalErrorHandler.setDelegate(&ErrLogger{l: log.New(s.errCatcher, "", 0)}) -} - -func (s *HandlerTestSuite) TestGlobalHandler() { - errs := []string{"one", "two"} - GetErrorHandler().Handle(errors.New(errs[0])) - Handle(errors.New(errs[1])) - s.Assert().Equal(errs, s.errCatcher.Got()) -} +func TestErrDelegator(t *testing.T) { + buf := new(bytes.Buffer) + log.Default().SetOutput(buf) + t.Cleanup(func() { log.Default().SetOutput(os.Stderr) }) -func (s *HandlerTestSuite) TestDelegatedHandler() { - eh := GetErrorHandler() + e := &ErrDelegator{} - newErrLogger := new(testErrCatcher) - SetErrorHandler(&ErrLogger{l: log.New(newErrLogger, "", 0)}) + err := errors.New("testing") + e.Handle(err) - errs := []string{"TestDelegatedHandler"} - eh.Handle(errors.New(errs[0])) - s.Assert().Equal(errs, newErrLogger.Got()) -} - -func (s *HandlerTestSuite) TestNoDropsOnDelegate() { - causeErr("") - s.Require().Len(s.errCatcher.Got(), 1) - - // Change to another Handler. We are testing this is loss-less. - newErrLogger := new(testErrCatcher) - secondary := &ErrLogger{ - l: log.New(newErrLogger, "", 0), + got := buf.String() + if !strings.Contains(got, err.Error()) { + t.Error("default handler did not log") } - SetErrorHandler(secondary) - - causeErr("") - s.Assert().Len(s.errCatcher.Got(), 1, "original Handler used after delegation") - s.Assert().Len(newErrLogger.Got(), 1, "new Handler not used after delegation") -} - -func (s *HandlerTestSuite) TestAllowMultipleSets() { - notUsed := new(testErrCatcher) + buf.Reset() - secondary := &ErrLogger{l: log.New(notUsed, "", 0)} - SetErrorHandler(secondary) - s.Require().Same(GetErrorHandler(), GlobalErrorHandler, "set changed globalErrorHandler") - s.Require().Same(GlobalErrorHandler.getDelegate(), secondary, "new Handler not set") + var gotErr error + e.setDelegate(fnErrHandler(func(e error) { gotErr = e })) + e.Handle(err) - tertiary := &ErrLogger{l: log.New(notUsed, "", 0)} - SetErrorHandler(tertiary) - s.Require().Same(GetErrorHandler(), GlobalErrorHandler, "set changed globalErrorHandler") - s.Assert().Same(GlobalErrorHandler.getDelegate(), tertiary, "user Handler not overridden") -} - -func TestHandlerTestSuite(t *testing.T) { - suite.Run(t, new(HandlerTestSuite)) -} - -func TestHandlerConcurrentSafe(t *testing.T) { - // In order not to pollute the test output. - SetErrorHandler(&ErrLogger{log.New(io.Discard, "", 0)}) - - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - SetErrorHandler(&ErrLogger{log.New(io.Discard, "", 0)}) - }() - wg.Add(1) - go func() { - defer wg.Done() - Handle(errors.New("error")) - }() - - wg.Wait() - reset() -} - -func BenchmarkErrorHandler(b *testing.B) { - primary := &ErrLogger{l: log.New(io.Discard, "", 0)} - secondary := &ErrLogger{l: log.New(io.Discard, "", 0)} - tertiary := &ErrLogger{l: log.New(io.Discard, "", 0)} - - GlobalErrorHandler.setDelegate(primary) - - err := errors.New("benchmark error handler") - - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - GetErrorHandler().Handle(err) - Handle(err) - - SetErrorHandler(secondary) - GetErrorHandler().Handle(err) - Handle(err) - - SetErrorHandler(tertiary) - GetErrorHandler().Handle(err) - Handle(err) - - GlobalErrorHandler.setDelegate(primary) - } - - reset() -} - -var eh ErrorHandler - -func BenchmarkGetDefaultErrorHandler(b *testing.B) { - b.ReportAllocs() - for i := 0; i < b.N; i++ { - eh = GetErrorHandler() + if buf.String() != "" { + t.Error("delegate not set") + } else if !errors.Is(gotErr, err) { + t.Error("error not passed to delegate") } } - -func BenchmarkGetDelegatedErrorHandler(b *testing.B) { - SetErrorHandler(&ErrLogger{l: log.New(io.Discard, "", 0)}) - - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - eh = GetErrorHandler() - } - - reset() -} - -func BenchmarkDefaultErrorHandlerHandle(b *testing.B) { - GlobalErrorHandler.setDelegate( - &ErrLogger{l: log.New(io.Discard, "", 0)}, - ) - - eh := GetErrorHandler() - err := errors.New("benchmark default error handler handle") - - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - eh.Handle(err) - } - - reset() -} - -func BenchmarkDelegatedErrorHandlerHandle(b *testing.B) { - eh := GetErrorHandler() - SetErrorHandler(&ErrLogger{l: log.New(io.Discard, "", 0)}) - err := errors.New("benchmark delegated error handler handle") - - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - eh.Handle(err) - } - - reset() -} - -func BenchmarkSetErrorHandlerDelegation(b *testing.B) { - alt := &ErrLogger{l: log.New(io.Discard, "", 0)} - - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - SetErrorHandler(alt) - - reset() - } -} - -func reset() { - GlobalErrorHandler = defaultErrorHandler() -} diff --git a/internal/global/internal_logging_test.go b/internal/global/internal_logging_test.go index 2b55050c8bc..96287146add 100644 --- a/internal/global/internal_logging_test.go +++ b/internal/global/internal_logging_test.go @@ -33,7 +33,7 @@ func TestLoggerConcurrentSafe(t *testing.T) { }() wg.Wait() - reset() + ResetForTest(t) } func TestLogLevel(t *testing.T) { diff --git a/internal/global/state.go b/internal/global/state.go index 976fa610387..204ea142a50 100644 --- a/internal/global/state.go +++ b/internal/global/state.go @@ -14,6 +14,10 @@ import ( ) type ( + errorHandlerHolder struct { + eh ErrorHandler + } + tracerProviderHolder struct { tp trace.TracerProvider } @@ -28,15 +32,59 @@ type ( ) var ( + globalErrorHandler = defaultErrorHandler() globalTracer = defaultTracerValue() globalPropagators = defaultPropagatorsValue() globalMeterProvider = defaultMeterProvider() + delegateErrorHandlerOnce sync.Once delegateTraceOnce sync.Once delegateTextMapPropagatorOnce sync.Once delegateMeterOnce sync.Once ) +// GetErrorHandler returns the global ErrorHandler instance. +// +// The default ErrorHandler instance returned will log all errors to STDERR +// until an override ErrorHandler is set with SetErrorHandler. All +// ErrorHandler returned prior to this will automatically forward errors to +// the set instance instead of logging. +// +// Subsequent calls to SetErrorHandler after the first will not forward errors +// to the new ErrorHandler for prior returned instances. +func GetErrorHandler() ErrorHandler { + return globalErrorHandler.Load().(errorHandlerHolder).eh +} + +// SetErrorHandler sets the global ErrorHandler to h. +// +// The first time this is called all ErrorHandler previously returned from +// GetErrorHandler will send errors to h instead of the default logging +// ErrorHandler. Subsequent calls will set the global ErrorHandler, but not +// delegate errors to h. +func SetErrorHandler(h ErrorHandler) { + current := GetErrorHandler() + + if _, cOk := current.(*ErrDelegator); cOk { + if _, ehOk := h.(*ErrDelegator); ehOk && current == h { + // Do not assign to the delegate of the default ErrDelegator to be + // itself. + Error( + errors.New("no ErrorHandler delegate configured"), + "ErrorHandler remains its current value.", + ) + return + } + } + + delegateErrorHandlerOnce.Do(func() { + if def, ok := current.(*ErrDelegator); ok { + def.setDelegate(h) + } + }) + globalErrorHandler.Store(errorHandlerHolder{eh: h}) +} + // TracerProvider is the internal implementation for global.TracerProvider. func TracerProvider() trace.TracerProvider { return globalTracer.Load().(tracerProviderHolder).tp @@ -126,6 +174,12 @@ func SetMeterProvider(mp metric.MeterProvider) { globalMeterProvider.Store(meterProviderHolder{mp: mp}) } +func defaultErrorHandler() *atomic.Value { + v := &atomic.Value{} + v.Store(errorHandlerHolder{eh: &ErrDelegator{}}) + return v +} + func defaultTracerValue() *atomic.Value { v := &atomic.Value{} v.Store(tracerProviderHolder{tp: &tracerProvider{}}) diff --git a/internal/global/state_test.go b/internal/global/state_test.go index f54ba7ef09b..fe74e7eab8f 100644 --- a/internal/global/state_test.go +++ b/internal/global/state_test.go @@ -15,6 +15,12 @@ import ( tracenoop "go.opentelemetry.io/otel/trace/noop" ) +type nonComparableErrorHandler struct { + ErrorHandler + + nonComparable func() //nolint:structcheck,unused // This is not called. +} + type nonComparableTracerProvider struct { trace.TracerProvider @@ -27,6 +33,63 @@ type nonComparableMeterProvider struct { nonComparable func() //nolint:structcheck,unused // This is not called. } +type fnErrHandler func(error) + +func (f fnErrHandler) Handle(err error) { f(err) } + +var noopEH = fnErrHandler(func(error) {}) + +func TestSetErrorHandler(t *testing.T) { + t.Run("Set With default is a noop", func(t *testing.T) { + ResetForTest(t) + SetErrorHandler(GetErrorHandler()) + + eh, ok := GetErrorHandler().(*ErrDelegator) + if !ok { + t.Fatal("Global ErrorHandler should be the default ErrorHandler") + } + + if eh.delegate.Load() != nil { + t.Fatal("ErrorHandler should not delegate when setting itself") + } + }) + + t.Run("First Set() should replace the delegate", func(t *testing.T) { + ResetForTest(t) + + SetErrorHandler(noopEH) + + _, ok := GetErrorHandler().(*ErrDelegator) + if ok { + t.Fatal("Global ErrorHandler was not changed") + } + }) + + t.Run("Set() should delegate existing ErrorHandlers", func(t *testing.T) { + ResetForTest(t) + + eh := GetErrorHandler() + SetErrorHandler(noopEH) + + errDel, ok := eh.(*ErrDelegator) + if !ok { + t.Fatal("Wrong ErrorHandler returned") + } + + if errDel.delegate.Load() == nil { + t.Fatal("The ErrDelegator should have a delegate") + } + }) + + t.Run("non-comparable types should not panic", func(t *testing.T) { + ResetForTest(t) + + eh := nonComparableErrorHandler{} + assert.NotPanics(t, func() { SetErrorHandler(eh) }, "delegate") + assert.NotPanics(t, func() { SetErrorHandler(eh) }, "replacement") + }) +} + func TestSetTracerProvider(t *testing.T) { t.Run("Set With default is a noop", func(t *testing.T) { ResetForTest(t) diff --git a/internal/global/util_test.go b/internal/global/util_test.go index 0e0659c0ac3..a23d6228d3e 100644 --- a/internal/global/util_test.go +++ b/internal/global/util_test.go @@ -12,9 +12,11 @@ import ( // its Cleanup step. func ResetForTest(t testing.TB) { t.Cleanup(func() { + globalErrorHandler = defaultErrorHandler() globalTracer = defaultTracerValue() globalPropagators = defaultPropagatorsValue() globalMeterProvider = defaultMeterProvider() + delegateErrorHandlerOnce = sync.Once{} delegateTraceOnce = sync.Once{} delegateTextMapPropagatorOnce = sync.Once{} delegateMeterOnce = sync.Once{}