Skip to content

Commit

Permalink
contextutil: improvements; to be squashed
Browse files Browse the repository at this point in the history
- change semantics of Err(ctx) to tie it to the closest WithErrCancel
  ctx, not the furthest
- override Err() on WithErrCancel such that it returns the same as
  Err(ctx)
  • Loading branch information
andreimatei committed Jan 8, 2022
1 parent 8497ce1 commit 4d2a15f
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 31 deletions.
93 changes: 78 additions & 15 deletions pkg/util/contextutil/cancel.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,39 @@ import (

type errCancelKey struct{}

// CtxCanceledError is returned by a Context created with WithErrCancel() from
// its Err() method after it is canceled.
//
// errors.Is(CtxCanceledError, context.Canceled) returns true so that
// CtxCanceledError looks somewhat like context.Canceled.
type CtxCanceledError struct {
inner error
}

// Error implements the error interface.
func (e CtxCanceledError) Error() string {
return e.inner.Error()
}

// Unwrap implements the causer interface needed by the errors library.
func (e CtxCanceledError) Unwrap() error {
return e.inner
}

// Is makes errors.Is(CtxCanceledError{}, context.Canceled) return true.
func (e CtxCanceledError) Is(other error) bool {
return other == context.Canceled
}

// NormalFinish is a sentinel error that can be passed to the cancel() function
// returned by WithErrCancel() to signal that the cancellation is done after the
// respective operation has finished. As such, the cancel function will be
// cheaper than otherwise, as it will avoid capturing a stack trace on the
// argument that nobody is expected to be looking at this error.
var NormalFinish = CtxCanceledError{inner: errors.Wrapf(context.Canceled, "operation finished normally")}

// WithErrCancel returns a cancelable context that whose cancellation function
// takes an error. While that error will *not* be returned from `ctx.Err`, the
// takes an error. error will *not* be returned from `ctx.Err`, the
// package-level method `Err` will return the error (annotated with
// errors.WithStackDepth) for the returned context and its descendants.
func WithErrCancel(parent context.Context) (context.Context, func(error)) {
Expand All @@ -32,24 +63,48 @@ func WithErrCancel(parent context.Context) (context.Context, func(error)) {
if err == nil {
err = context.Canceled
}
err = errors.WithStackDepth(err, 1 /* depth */)
defer wrappedCancel() // actually cancel after we've populated our ctx's err
if err != NormalFinish {
err = errors.WithStackDepth(CtxCanceledError{inner: err}, 1 /* depth */)
}
defer wrappedCancel() // actually cancel after we've populated our ctx's inner
ctx.mu.Lock()
defer ctx.mu.Unlock()
ctx.err = err

// The function has already been called.
if ctx.err != nil {
return
}

// If the parent has already been canceled, we primarily keep the parent's
// error.
if pErr := Err(wrappedCtx); pErr != nil {
if err != NormalFinish {
ctx.err = errors.WithSecondaryError(pErr, err)
} else {
ctx.err = pErr
}
} else {
// From this moment on, ctx.Err() and Err(ctx) will return err. Similarly,
// Err(childCtx) will also return err for every derived context that has
// not yet been canceled, with the exception of contexts created by
// context.WithCancel() for whom Err(childCtx) will start returning err
// even if those contexts were already canceled (because we can't do
// better).
ctx.err = err
}
}
}

// Err returns an error associated to the Context. This is nil unless the
// Context is canceled, and will match `ctx.Err()` for contexts that were
// not derived from a since-canceled parent created via WithErrCancel.
// Err returns an error associated to the Context. This is nil if the Context is
// not canceled. Otherwise, this will match `ctx.Err()` for contexts created
// with WithErrCancel. For such contexts, this is the error passed to the cancel
// function returned by WithErrCancel, or to the cancel function of one of its
// parents.
//
// However, for a Context derived from a since-canceled parent created via
// WithErrCancel, Err returns the error passed to the cancellation function,
// wrapped in an `errors.WithStackDepth` that identifies the caller of the
// `cancel(err)` call. When the Context passed to `Err` has multiple such
// parents, the "most distant" one is returned, under the assumption that
// it provides the original reason for the context chain's cancellation.
// For other contexts, Err returns the parentCtx.Error() on the closest parent
// created with WithErrCancel, if this context was canceled (explicitly or
// implicitly by calling one of its parents). If such a parent doesn't exist, or
// it exists but hasn't yet been been canceled, Err(ctx) returns ctx.Err().
//
// See ExampleWithErrCancel for an example.
func Err(ctx context.Context) error {
Expand All @@ -66,9 +121,8 @@ func Err(ctx context.Context) error {
}
ctx = c.Context

// If it's canceled, remember the error.
if extErr := c.getErr(); extErr != nil {
err = extErr
return extErr
}

// Keep walking.
Expand All @@ -94,3 +148,12 @@ func (ctx *errCancelCtx) Value(key interface{}) interface{} {
}
return ctx.Context.Value(key)
}

func (ctx *errCancelCtx) Err() error {
ctx.mu.Lock()
defer ctx.mu.Unlock()
if ctx.err != nil {
return ctx.err
}
return Err(ctx.Context)
}
65 changes: 49 additions & 16 deletions pkg/util/contextutil/cancel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ func ExampleWithErrCancel() {
cancel1(errors.New("explody"))
fmt.Println("ctx1 also canceled:")
fmt.Println(Err(ctx1))
// Note that Err(ctx2) changes from "context canceled" to "explody". That is
// unfortunate, but we don't have enough control to prevent it; we cannot tell
// whether ctx2 or ctx1 was canceled first, so we prefer to go with the rich
// error from ctx1.
fmt.Println(Err(ctx2))
fmt.Println(Err(ctx3))

Expand All @@ -53,7 +57,7 @@ func ExampleWithErrCancel() {
// ctx1 also canceled:
// explody
// explody
// explody
// boom
}

type ctxs struct {
Expand Down Expand Up @@ -102,38 +106,40 @@ func TestWithErrCancel(t *testing.T) {
require.Equal(t, context.Canceled, Err(c.ctx2))
require.Equal(t, context.Canceled, Err(c.ctx3))
}},
{name: "cancel1", do: func(t *testing.T, c ctxs) {
{name: "cancel11", do: func(t *testing.T, c ctxs) { // !!! remove
// ctx1 is an extended context, so it does nice things.
c.cancel1(err1)
require.Equal(t, context.Canceled, c.ctx1.Err())
require.Equal(t, context.Canceled, c.ctx2.Err())
require.Equal(t, context.Canceled, c.ctx3.Err())
require.True(t, errors.Is(c.ctx1.Err(), context.Canceled))
require.True(t, errors.Is(c.ctx2.Err(), context.Canceled))
require.True(t, errors.Is(c.ctx3.Err(), context.Canceled))
require.True(t, errors.Is(Err(c.ctx1), err1))
require.True(t, errors.Is(Err(c.ctx2), err1))
require.True(t, errors.Is(Err(c.ctx3), err1)) // vanilla context
require.True(t, errors.Is(Err(c.ctx3), err1))
}},
{name: "cancel2", do: func(t *testing.T, c ctxs) {
// ctx2 is an extended context, so it does nice things.
c.cancel2(err2)
require.Nil(t, c.ctx1.Err())
require.Equal(t, context.Canceled, c.ctx2.Err())
require.Equal(t, context.Canceled, c.ctx3.Err())
require.True(t, errors.Is(c.ctx2.Err(), context.Canceled))
require.True(t, errors.Is(c.ctx3.Err(), context.Canceled))
require.Nil(t, Err(c.ctx1))
require.True(t, errors.Is(Err(c.ctx2), err2))
require.True(t, errors.Is(Err(c.ctx3), err2)) // vanilla context
}},
{name: "cancel123", do: func(t *testing.T, c ctxs) {
// When multiple rich contexts are canceled, we get the topmost
// nice error back.
// When multiple contexts are canceled, the one canceled first matters.
c.cancel0()
c.cancel1(err1)
c.cancel2(err2)
require.Equal(t, context.Canceled, c.ctx1.Err())
require.Equal(t, context.Canceled, c.ctx2.Err())
require.Equal(t, context.Canceled, c.ctx3.Err())
require.True(t, errors.Is(Err(c.ctx1), err1))
require.True(t, errors.Is(Err(c.ctx2), err1))
require.True(t, errors.Is(Err(c.ctx3), err1)) // vanilla context
require.True(t, errors.Is(c.ctx1.Err(), context.Canceled))
require.True(t, errors.Is(c.ctx2.Err(), context.Canceled))
require.True(t, errors.Is(c.ctx3.Err(), context.Canceled))
// !!! require.Equal(t, context.Canceled, Err(c.ctx1))
require.True(t, errors.Is(Err(c.ctx1), context.Canceled))
require.False(t, errors.Is(Err(c.ctx2), err1))
require.False(t, errors.Is(Err(c.ctx2), err2))
require.False(t, errors.Is(Err(c.ctx3), err1)) // vanilla context
require.False(t, errors.Is(Err(c.ctx3), err2)) // vanilla context
}},
}
for _, tt := range tests {
Expand All @@ -143,6 +149,33 @@ func TestWithErrCancel(t *testing.T) {
}
}

type boomError struct{}

func (boomError) Error() string {
return "boom"
}

func TestWithErr(t *testing.T) {
t.Run("cancel with arbitrary error", func(t *testing.T) {
ctx, cancel := WithErrCancel(context.Background())
cancel(boomError{})
err := ctx.Err()
require.True(t, errors.Is(err, boomError{}))
require.True(t, errors.As(err, &CtxCanceledError{}))
require.True(t, errors.Is(err, context.Canceled))
})

t.Run("cancel NormalFinish", func(t *testing.T) {
ctx, cancel := WithErrCancel(context.Background())
cancel(NormalFinish)
err := ctx.Err()
require.Regexp(t, "context canceled", err)
require.Equal(t, NormalFinish, err)
require.True(t, errors.As(err, &CtxCanceledError{}))
require.True(t, errors.Is(err, context.Canceled))
})
}

func TestWithErrCancelStack(t *testing.T) {
ctx, cancel := WithErrCancel(context.Background())
cancel(errors.New("boom"))
Expand Down

0 comments on commit 4d2a15f

Please sign in to comment.