Skip to content

Commit

Permalink
GODRIVER-2496 Simplify maxTimeMS appension. (#1028)
Browse files Browse the repository at this point in the history
Co-authored-by: Preston Vasquez <prestonvs10@gmail.com>
  • Loading branch information
benjirewis and prestonvasquez authored Aug 5, 2022
1 parent f62a497 commit c9fbb05
Show file tree
Hide file tree
Showing 18 changed files with 205 additions and 162 deletions.
52 changes: 16 additions & 36 deletions mongo/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -846,7 +846,8 @@ func aggregate(a aggregateParams) (cur *Cursor, err error) {
Crypt(a.client.cryptFLE).
ServerAPI(a.client.serverAPI).
HasOutputStage(hasOutputStage).
Timeout(a.client.timeout)
Timeout(a.client.timeout).
MaxTime(ao.MaxTime)

if ao.AllowDiskUse != nil {
op.AllowDiskUse(*ao.AllowDiskUse)
Expand All @@ -862,9 +863,6 @@ func aggregate(a aggregateParams) (cur *Cursor, err error) {
if ao.Collation != nil {
op.Collation(bsoncore.Document(ao.Collation.ToDocument()))
}
if ao.MaxTime != nil {
op.MaxTimeMS(int64(*ao.MaxTime / time.Millisecond))
}
if ao.MaxAwaitTime != nil {
cursorOpts.MaxTimeMS = int64(*ao.MaxAwaitTime / time.Millisecond)
}
Expand Down Expand Up @@ -971,16 +969,13 @@ func (coll *Collection) CountDocuments(ctx context.Context, filter interface{},
op := operation.NewAggregate(pipelineArr).Session(sess).ReadConcern(rc).ReadPreference(coll.readPreference).
CommandMonitor(coll.client.monitor).ServerSelector(selector).ClusterClock(coll.client.clock).Database(coll.db.name).
Collection(coll.name).Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE).ServerAPI(coll.client.serverAPI).
Timeout(coll.client.timeout)
Timeout(coll.client.timeout).MaxTime(countOpts.MaxTime)
if countOpts.Collation != nil {
op.Collation(bsoncore.Document(countOpts.Collation.ToDocument()))
}
if countOpts.Comment != nil {
op.Comment(*countOpts.Comment)
}
if countOpts.MaxTime != nil {
op.MaxTimeMS(int64(*countOpts.MaxTime / time.Millisecond))
}
if countOpts.Hint != nil {
hintVal, err := transformValue(coll.registry, countOpts.Hint, false, "hint")
if err != nil {
Expand Down Expand Up @@ -1052,32 +1047,30 @@ func (coll *Collection) EstimatedDocumentCount(ctx context.Context,
rc = nil
}

co := options.MergeEstimatedDocumentCountOptions(opts...)

selector := makeReadPrefSelector(sess, coll.readSelector, coll.client.localThreshold)
op := operation.NewCount().Session(sess).ClusterClock(coll.client.clock).
Database(coll.db.name).Collection(coll.name).CommandMonitor(coll.client.monitor).
Deployment(coll.client.deployment).ReadConcern(rc).ReadPreference(coll.readPreference).
ServerSelector(selector).Crypt(coll.client.cryptFLE).ServerAPI(coll.client.serverAPI).
Timeout(coll.client.timeout)
Timeout(coll.client.timeout).MaxTime(co.MaxTime)

co := options.MergeEstimatedDocumentCountOptions(opts...)
if co.Comment != nil {
comment, err := transformValue(coll.registry, co.Comment, false, "comment")
if err != nil {
return 0, err
}
op = op.Comment(comment)
}
if co.MaxTime != nil {
op = op.MaxTimeMS(int64(*co.MaxTime / time.Millisecond))
}

retry := driver.RetryNone
if coll.client.retryReads {
retry = driver.RetryOncePerCommand
}
op.Retry(retry)

err = op.Execute(ctx)

return op.Result().N, replaceErrors(err)
}

Expand Down Expand Up @@ -1131,7 +1124,7 @@ func (coll *Collection) Distinct(ctx context.Context, fieldName string, filter i
Database(coll.db.name).Collection(coll.name).CommandMonitor(coll.client.monitor).
Deployment(coll.client.deployment).ReadConcern(rc).ReadPreference(coll.readPreference).
ServerSelector(selector).Crypt(coll.client.cryptFLE).ServerAPI(coll.client.serverAPI).
Timeout(coll.client.timeout)
Timeout(coll.client.timeout).MaxTime(option.MaxTime)

if option.Collation != nil {
op.Collation(bsoncore.Document(option.Collation.ToDocument()))
Expand All @@ -1143,9 +1136,6 @@ func (coll *Collection) Distinct(ctx context.Context, fieldName string, filter i
}
op.Comment(comment)
}
if option.MaxTime != nil {
op.MaxTimeMS(int64(*option.MaxTime / time.Millisecond))
}
retry := driver.RetryNone
if coll.client.retryReads {
retry = driver.RetryOncePerCommand
Expand Down Expand Up @@ -1225,17 +1215,17 @@ func (coll *Collection) Find(ctx context.Context, filter interface{},
rc = nil
}

fo := options.MergeFindOptions(opts...)

selector := makeReadPrefSelector(sess, coll.readSelector, coll.client.localThreshold)
op := operation.NewFind(f).
Session(sess).ReadConcern(rc).ReadPreference(coll.readPreference).
CommandMonitor(coll.client.monitor).ServerSelector(selector).
ClusterClock(coll.client.clock).Database(coll.db.name).Collection(coll.name).
Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE).ServerAPI(coll.client.serverAPI).
Timeout(coll.client.timeout)
Timeout(coll.client.timeout).MaxTime(fo.MaxTime)

fo := options.MergeFindOptions(opts...)
cursorOpts := coll.client.createBaseCursorOptions()

if fo.AllowDiskUse != nil {
op.AllowDiskUse(*fo.AllowDiskUse)
}
Expand Down Expand Up @@ -1300,9 +1290,6 @@ func (coll *Collection) Find(ctx context.Context, filter interface{},
if fo.MaxAwaitTime != nil {
cursorOpts.MaxTimeMS = int64(*fo.MaxAwaitTime / time.Millisecond)
}
if fo.MaxTime != nil {
op.MaxTimeMS(int64(*fo.MaxTime / time.Millisecond))
}
if fo.Min != nil {
min, err := transformBsoncoreDocument(coll.registry, fo.Min, true, "min")
if err != nil {
Expand Down Expand Up @@ -1482,7 +1469,8 @@ func (coll *Collection) FindOneAndDelete(ctx context.Context, filter interface{}
return &SingleResult{err: err}
}
fod := options.MergeFindOneAndDeleteOptions(opts...)
op := operation.NewFindAndModify(f).Remove(true).ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout)
op := operation.NewFindAndModify(f).Remove(true).ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout).
MaxTime(fod.MaxTime)
if fod.Collation != nil {
op = op.Collation(bsoncore.Document(fod.Collation.ToDocument()))
}
Expand All @@ -1493,9 +1481,6 @@ func (coll *Collection) FindOneAndDelete(ctx context.Context, filter interface{}
}
op = op.Comment(comment)
}
if fod.MaxTime != nil {
op = op.MaxTimeMS(int64(*fod.MaxTime / time.Millisecond))
}
if fod.Projection != nil {
proj, err := transformBsoncoreDocument(coll.registry, fod.Projection, true, "projection")
if err != nil {
Expand Down Expand Up @@ -1559,7 +1544,7 @@ func (coll *Collection) FindOneAndReplace(ctx context.Context, filter interface{

fo := options.MergeFindOneAndReplaceOptions(opts...)
op := operation.NewFindAndModify(f).Update(bsoncore.Value{Type: bsontype.EmbeddedDocument, Data: r}).
ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout)
ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout).MaxTime(fo.MaxTime)
if fo.BypassDocumentValidation != nil && *fo.BypassDocumentValidation {
op = op.BypassDocumentValidation(*fo.BypassDocumentValidation)
}
Expand All @@ -1573,9 +1558,6 @@ func (coll *Collection) FindOneAndReplace(ctx context.Context, filter interface{
}
op = op.Comment(comment)
}
if fo.MaxTime != nil {
op = op.MaxTimeMS(int64(*fo.MaxTime / time.Millisecond))
}
if fo.Projection != nil {
proj, err := transformBsoncoreDocument(coll.registry, fo.Projection, true, "projection")
if err != nil {
Expand Down Expand Up @@ -1642,7 +1624,8 @@ func (coll *Collection) FindOneAndUpdate(ctx context.Context, filter interface{}
}

fo := options.MergeFindOneAndUpdateOptions(opts...)
op := operation.NewFindAndModify(f).ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout)
op := operation.NewFindAndModify(f).ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout).
MaxTime(fo.MaxTime)

u, err := transformUpdateValue(coll.registry, update, true)
if err != nil {
Expand Down Expand Up @@ -1670,9 +1653,6 @@ func (coll *Collection) FindOneAndUpdate(ctx context.Context, filter interface{}
}
op = op.Comment(comment)
}
if fo.MaxTime != nil {
op = op.MaxTimeMS(int64(*fo.MaxTime / time.Millisecond))
}
if fo.Projection != nil {
proj, err := transformBsoncoreDocument(coll.registry, fo.Projection, true, "projection")
if err != nil {
Expand Down
16 changes: 3 additions & 13 deletions mongo/index_view.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"errors"
"fmt"
"strconv"
"time"

"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/bsontype"
Expand Down Expand Up @@ -104,9 +103,7 @@ func (iv IndexView) List(ctx context.Context, opts ...*options.ListIndexesOption
op = op.BatchSize(*lio.BatchSize)
cursorOpts.BatchSize = *lio.BatchSize
}
if lio.MaxTime != nil {
op = op.MaxTimeMS(int64(*lio.MaxTime / time.Millisecond))
}
op = op.MaxTime(lio.MaxTime)
retry := driver.RetryNone
if iv.coll.client.retryReads {
retry = driver.RetryOncePerCommand
Expand Down Expand Up @@ -258,11 +255,7 @@ func (iv IndexView) CreateMany(ctx context.Context, models []IndexModel, opts ..
Session(sess).WriteConcern(wc).ClusterClock(iv.coll.client.clock).
Database(iv.coll.db.name).Collection(iv.coll.name).CommandMonitor(iv.coll.client.monitor).
Deployment(iv.coll.client.deployment).ServerSelector(selector).ServerAPI(iv.coll.client.serverAPI).
Timeout(iv.coll.client.timeout)

if option.MaxTime != nil {
op.MaxTimeMS(int64(*option.MaxTime / time.Millisecond))
}
Timeout(iv.coll.client.timeout).MaxTime(option.MaxTime)
if option.CommitQuorum != nil {
commitQuorum, err := transformValue(iv.coll.registry, option.CommitQuorum, true, "commitQuorum")
if err != nil {
Expand Down Expand Up @@ -403,10 +396,7 @@ func (iv IndexView) drop(ctx context.Context, name string, opts ...*options.Drop
ServerSelector(selector).ClusterClock(iv.coll.client.clock).
Database(iv.coll.db.name).Collection(iv.coll.name).
Deployment(iv.coll.client.deployment).ServerAPI(iv.coll.client.serverAPI).
Timeout(iv.coll.client.timeout)
if dio.MaxTime != nil {
op.MaxTimeMS(int64(*dio.MaxTime / time.Millisecond))
}
Timeout(iv.coll.client.timeout).MaxTime(dio.MaxTime)

err = op.Execute(ctx)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion mongo/integration/operation_legacy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,12 +190,12 @@ func runFindWithOptions(mt *mtest.T) opQuery {
{"$comment", "hello"},
{"$hint", "hintFoo"},
{"$max", maxDoc},
{"$maxTimeMS", int64(10000)},
{"$min", minDoc},
{"$returnKey", false},
{"$showDiskLoc", false},
{"$snapshot", false},
{"$orderby", sort},
{"$maxTimeMS", int64(10000)},
}
return opQuery{
flags: wiremessage.Partial | wiremessage.TailableCursor | wiremessage.NoCursorTimeout | wiremessage.OplogReplay | wiremessage.SecondaryOK,
Expand Down
6 changes: 3 additions & 3 deletions mongo/options/clientoptions.go
Original file line number Diff line number Diff line change
Expand Up @@ -728,9 +728,9 @@ func (c *ClientOptions) SetSocketTimeout(d time.Duration) *ClientOptions {
// be honored if there is no deadline on the operation Context. Timeout can also be set through the "timeoutMS" URI option
// (e.g. "timeoutMS=1000"). The default value is nil, meaning operations do not inherit a timeout from the Client.
//
// If any Timeout is set (even 0) on the Client, the values of MaxTime on operations, TransactionOptions.MaxCommitTime and
// SessionOptions.DefaultMaxCommitTime will be ignored. Setting Timeout and ClientOptions.SocketTimeout or WriteConcern.wTimeout
// will result in undefined behavior.
// If any Timeout is set (even 0) on the Client, the values of MaxTime on operation options, TransactionOptions.MaxCommitTime and
// SessionOptions.DefaultMaxCommitTime will be ignored. Setting Timeout and SocketTimeout or WriteConcern.wTimeout will result
// in undefined behavior.
//
// NOTE(benjirewis): SetTimeout represents unstable, provisional API. The behavior of the driver when a Timeout is specified is
// subject to change.
Expand Down
5 changes: 1 addition & 4 deletions mongo/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,10 +313,7 @@ func (s *sessionImpl) CommitTransaction(ctx context.Context) error {
Session(s.clientSession).ClusterClock(s.client.clock).Database("admin").Deployment(s.deployment).
WriteConcern(s.clientSession.CurrentWc).ServerSelector(selector).Retry(driver.RetryOncePerCommand).
CommandMonitor(s.client.monitor).RecoveryToken(bsoncore.Document(s.clientSession.RecoveryToken)).
ServerAPI(s.client.serverAPI)
if s.clientSession.CurrentMct != nil {
op.MaxTimeMS(int64(*s.clientSession.CurrentMct / time.Millisecond))
}
ServerAPI(s.client.serverAPI).MaxTime(s.clientSession.CurrentMct)

err = op.Execute(ctx)
// Return error without updating transaction state if it is a timeout, as the transaction has not
Expand Down
2 changes: 2 additions & 0 deletions x/mongo/driver/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ var (
// ErrDeadlineWouldBeExceeded is returned when a Timeout set on an operation would be exceeded
// if the operation were sent to the server.
ErrDeadlineWouldBeExceeded = errors.New("operation not sent to server, as Timeout would be exceeded")
// ErrNegativeMaxTime is returned when MaxTime on an operation is a negative value.
ErrNegativeMaxTime = errors.New("a negative value was provided for MaxTime on an operation")
)

// QueryFailureError is an error representing a command failure as a document.
Expand Down
67 changes: 45 additions & 22 deletions x/mongo/driver/operation.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,9 @@ type Operation struct {
// read preference will not be added to the command on wire versions < 13.
IsOutputAggregate bool

// MaxTime specifies the maximum amount of time to allow the operation to run on the server.
MaxTime *time.Duration

// Timeout is the amount of time that this operation can execute before returning an error. The default value
// nil, which means that the timeout of the operation's caller will be used.
Timeout *time.Duration
Expand Down Expand Up @@ -444,12 +447,18 @@ func (op Operation) Execute(ctx context.Context, scratch []byte) error {
first = false
}

// Calculate maxTimeMS value to potentially be appended to the wire message.
maxTimeMS, err := op.calculateMaxTimeMS(ctx, srvr.RTTMonitor().P90(), srvr.RTTMonitor().Stats())
if err != nil {
return err
}

desc := description.SelectedServer{Server: conn.Description(), Kind: op.Deployment.Kind()}
scratch = scratch[:0]
if desc.WireVersion == nil || desc.WireVersion.Max < 4 {
switch op.Legacy {
case LegacyFind:
return op.legacyFind(ctx, scratch, srvr, conn, desc)
return op.legacyFind(ctx, scratch, srvr, conn, desc, maxTimeMS)
case LegacyGetMore:
return op.legacyGetMore(ctx, scratch, srvr, conn, desc)
case LegacyKillCursors:
Expand All @@ -461,7 +470,7 @@ func (op Operation) Execute(ctx context.Context, scratch []byte) error {
case LegacyListCollections:
return op.legacyListCollections(ctx, scratch, srvr, conn, desc)
case LegacyListIndexes:
return op.legacyListIndexes(ctx, scratch, srvr, conn, desc)
return op.legacyListIndexes(ctx, scratch, srvr, conn, desc, maxTimeMS)
}
}

Expand All @@ -483,26 +492,6 @@ func (op Operation) Execute(ctx context.Context, scratch []byte) error {
}
}

// Calculate value of 'maxTimeMS' field to potentially append to the wire message based on the current
// context's deadline and the 90th percentile RTT if the ctx is a Timeout Context.
var maxTimeMS uint64
if internal.IsTimeoutContext(ctx) {
if deadline, ok := ctx.Deadline(); ok {
remainingTimeout := time.Until(deadline)

maxTimeMSVal := int64(remainingTimeout/time.Millisecond) -
int64(srvr.RTTMonitor().P90()/time.Millisecond)

// A maxTimeMS value <= 0 indicates that we are already at or past the Context's deadline.
if maxTimeMSVal <= 0 {
return internal.WrapErrorf(ErrDeadlineWouldBeExceeded,
"remaining time %v until context deadline is less than or equal to 90th percentile RTT\n%v",
remainingTimeout, srvr.RTTMonitor().Stats())
}
maxTimeMS = uint64(maxTimeMSVal)
}
}

// convert to wire message
if len(scratch) > 0 {
scratch = scratch[:0]
Expand Down Expand Up @@ -1273,6 +1262,40 @@ func (op Operation) addClusterTime(dst []byte, desc description.SelectedServer)
// return bsoncore.AppendDocumentElement(dst, "$clusterTime", clusterTime)
}

// calculateMaxTimeMS calculates the value of the 'maxTimeMS' field to potentially append
// to the wire message based on the current context's deadline and the 90th percentile RTT
// if the ctx is a Timeout context. If the context is not a Timeout context, it uses the
// operation's MaxTimeMS if set. If no MaxTimeMS is set on the operation, and context is
// not a Timeout context, calculateMaxTimeMS returns 0.
func (op Operation) calculateMaxTimeMS(ctx context.Context, rtt90 time.Duration, rttStats string) (uint64, error) {
if internal.IsTimeoutContext(ctx) {
if deadline, ok := ctx.Deadline(); ok {
remainingTimeout := time.Until(deadline)
maxTime := remainingTimeout - rtt90

// Always round up to the next millisecond value so we never truncate the calculated
// maxTimeMS value (e.g. 400 microseconds evaluates to 1ms, not 0ms).
maxTimeMS := int64((maxTime + (time.Millisecond - 1)) / time.Millisecond)
if maxTimeMS <= 0 {
return 0, internal.WrapErrorf(ErrDeadlineWouldBeExceeded,
"remaining time %v until context deadline is less than or equal to 90th percentile RTT\n%v",
remainingTimeout, rttStats)
}
return uint64(maxTimeMS), nil
}
} else if op.MaxTime != nil {
// Users are not allowed to pass a negative value as MaxTime. A value of 0 would indicate
// no timeout and is allowed.
if *op.MaxTime < 0 {
return 0, ErrNegativeMaxTime
}
// Always round up to the next millisecond value so we never truncate the requested
// MaxTime value (e.g. 400 microseconds evaluates to 1ms, not 0ms).
return uint64((*op.MaxTime + (time.Millisecond - 1)) / time.Millisecond), nil
}
return 0, nil
}

// updateClusterTimes updates the cluster times for the session and cluster clock attached to this
// operation. While the session's AdvanceClusterTime may return an error, this method does not
// because an error being returned from this method will not be returned further up.
Expand Down
Loading

0 comments on commit c9fbb05

Please sign in to comment.