Skip to content

Commit

Permalink
Stop treating context errors as network errors where possible. (#1045)
Browse files Browse the repository at this point in the history
  • Loading branch information
benjirewis authored Aug 17, 2022
1 parent d4a625b commit e720278
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 60 deletions.
11 changes: 3 additions & 8 deletions mongo/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,10 +200,6 @@ func (s *sessionImpl) WithTransaction(ctx context.Context, fn func(sessCtx Sessi
default:
}

// End if context has timed out or been canceled, as retrying has no chance of success.
if ctx.Err() != nil {
return res, err
}
if errorHasLabel(err, driver.TransientTransactionError) {
continue
}
Expand All @@ -218,10 +214,9 @@ func (s *sessionImpl) WithTransaction(ctx context.Context, fn func(sessCtx Sessi
CommitLoop:
for {
err = s.CommitTransaction(ctx)
// End when error is nil (transaction has been committed), or when context has timed out or been
// canceled, as retrying has no chance of success.
if err == nil || ctx.Err() != nil {
return res, err
// End when error is nil, as transaction has been committed.
if err == nil {
return res, nil
}

select {
Expand Down
12 changes: 7 additions & 5 deletions x/mongo/driver/operation.go
Original file line number Diff line number Diff line change
Expand Up @@ -531,15 +531,17 @@ func (op Operation) Execute(ctx context.Context, scratch []byte) error {
serviceID: startedInfo.serviceID,
}

// Check if there's enough time to perform a round trip before the Context deadline. If ctx is
// a Timeout Context, use the 90th percentile RTT as a threshold. Otherwise, use the minimum observed
// RTT.
if deadline, ok := ctx.Deadline(); ok {
// Check for possible context error. If no context error, check if there's enough time to perform a
// round trip before the Context deadline. If ctx is a Timeout Context, use the 90th percentile RTT
// as a threshold. Otherwise, use the minimum observed RTT.
if ctx.Err() != nil {
err = ctx.Err()
} else if deadline, ok := ctx.Deadline(); ok {
if internal.IsTimeoutContext(ctx) && time.Now().Add(srvr.RTTMonitor().P90()).After(deadline) {
err = internal.WrapErrorf(ErrDeadlineWouldBeExceeded,
"remaining time %v until context deadline is less than 90th percentile RTT\n%v", time.Until(deadline), srvr.RTTMonitor().Stats())
} else if time.Now().Add(srvr.RTTMonitor().Min()).After(deadline) {
err = op.networkError(context.DeadlineExceeded)
err = context.DeadlineExceeded
}
}

Expand Down
42 changes: 42 additions & 0 deletions x/mongo/driver/operation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -675,6 +675,48 @@ func TestOperation(t *testing.T) {
assert.Nil(t, err, "ExecuteExhaust error: %v", err)
assert.True(t, conn.CurrentlyStreaming(), "expected CurrentlyStreaming to be true")
})
t.Run("context deadline exceeded not marked as TransientTransactionError", func(t *testing.T) {
conn := new(mockConnection)
// Create a context that's already timed out.
ctx, cancel := context.WithDeadline(context.Background(), time.Unix(893934480, 0))
defer cancel()

op := Operation{
Database: "foobar",
Deployment: SingleConnectionDeployment{C: conn},
CommandFn: func(dst []byte, desc description.SelectedServer) ([]byte, error) {
dst = bsoncore.AppendInt32Element(dst, "ping", 1)
return dst, nil
},
}

err := op.Execute(ctx, nil)
assert.NotNil(t, err, "expected an error from Execute(), got nil")
// Assert that error is just context deadline exceeded and is therefore not a driver.Error marked
// with the TransientTransactionError label.
assert.Equal(t, err, context.DeadlineExceeded, "expected context.DeadlineExceeded error, got %v", err)
})
t.Run("canceled context not marked as TransientTransactionError", func(t *testing.T) {
conn := new(mockConnection)
// Create a context and cancel it immediately.
ctx, cancel := context.WithCancel(context.Background())
cancel()

op := Operation{
Database: "foobar",
Deployment: SingleConnectionDeployment{C: conn},
CommandFn: func(dst []byte, desc description.SelectedServer) ([]byte, error) {
dst = bsoncore.AppendInt32Element(dst, "ping", 1)
return dst, nil
},
}

err := op.Execute(ctx, nil)
assert.NotNil(t, err, "expected an error from Execute(), got nil")
// Assert that error is just context canceled and is therefore not a driver.Error marked with
// the TransientTransactionError label.
assert.Equal(t, err, context.Canceled, "expected context.Canceled error, got %v", err)
})
}

func createExhaustServerResponse(response bsoncore.Document, moreToCome bool) []byte {
Expand Down
13 changes: 0 additions & 13 deletions x/mongo/driver/topology/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -331,11 +331,6 @@ func (c *connection) writeWireMessage(ctx context.Context, wm []byte) error {
if atomic.LoadInt64(&c.state) != connConnected {
return ConnectionError{ConnectionID: c.id, message: "connection is closed"}
}
select {
case <-ctx.Done():
return ConnectionError{ConnectionID: c.id, Wrapped: ctx.Err(), message: "failed to write"}
default:
}

var deadline time.Time
if c.writeTimeout != 0 {
Expand Down Expand Up @@ -388,14 +383,6 @@ func (c *connection) readWireMessage(ctx context.Context, dst []byte) ([]byte, e
return dst, ConnectionError{ConnectionID: c.id, message: "connection is closed"}
}

select {
case <-ctx.Done():
// We closeConnection the connection because we don't know if there is an unread message on the wire.
c.close()
return nil, ConnectionError{ConnectionID: c.id, Wrapped: ctx.Err(), message: "failed to read"}
default:
}

var deadline time.Time
if c.readTimeout != 0 {
deadline = time.Now().Add(c.readTimeout)
Expand Down
14 changes: 0 additions & 14 deletions x/mongo/driver/topology/connection_errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,19 +50,5 @@ func TestConnectionErrors(t *testing.T) {
err := conn.connect(ctx)
assert.True(t, errors.Is(err, context.Canceled), "expected error %v, got %v", context.Canceled, err)
})
t.Run("write error", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
conn := &connection{id: "foobar", nc: &net.TCPConn{}, state: connConnected}
err := conn.writeWireMessage(ctx, []byte{})
assert.True(t, errors.Is(err, context.Canceled), "expected error %v, got %v", context.Canceled, err)
})
t.Run("read error", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
conn := &connection{id: "foobar", nc: &net.TCPConn{}, state: connConnected}
_, err := conn.readWireMessage(ctx, []byte{})
assert.True(t, errors.Is(err, context.Canceled), "expected error %v, got %v", context.Canceled, err)
})
})
}
20 changes: 0 additions & 20 deletions x/mongo/driver/topology/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -351,16 +351,6 @@ func TestConnection(t *testing.T) {
t.Errorf("errors do not match. got %v; want %v", got, want)
}
})
t.Run("completed context", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
conn := &connection{id: "foobar", nc: &net.TCPConn{}, state: connConnected}
want := ConnectionError{ConnectionID: "foobar", Wrapped: ctx.Err(), message: "failed to write"}
got := conn.writeWireMessage(ctx, []byte{})
if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
t.Errorf("errors do not match. got %v; want %v", got, want)
}
})
t.Run("deadlines", func(t *testing.T) {
testCases := []struct {
name string
Expand Down Expand Up @@ -490,16 +480,6 @@ func TestConnection(t *testing.T) {
t.Errorf("errors do not match. got %v; want %v", got, want)
}
})
t.Run("completed context", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
conn := &connection{id: "foobar", nc: &net.TCPConn{}, state: connConnected}
want := ConnectionError{ConnectionID: "foobar", Wrapped: ctx.Err(), message: "failed to read"}
_, got := conn.readWireMessage(ctx, []byte{})
if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
t.Errorf("errors do not match. got %v; want %v", got, want)
}
})
t.Run("deadlines", func(t *testing.T) {
testCases := []struct {
name string
Expand Down

0 comments on commit e720278

Please sign in to comment.