From c9fbb05fd7f7961e60bb79f6fa8e80f1e6b58a2b Mon Sep 17 00:00:00 2001 From: Benjamin Rewis <32186188+benjirewis@users.noreply.github.com> Date: Fri, 5 Aug 2022 14:15:41 -0400 Subject: [PATCH] GODRIVER-2496 Simplify maxTimeMS appension. (#1028) Co-authored-by: Preston Vasquez --- mongo/collection.go | 52 +++++--------- mongo/index_view.go | 16 +---- mongo/integration/operation_legacy_test.go | 2 +- mongo/options/clientoptions.go | 6 +- mongo/session.go | 5 +- x/mongo/driver/errors.go | 2 + x/mongo/driver/operation.go | 67 ++++++++++++------ x/mongo/driver/operation/aggregate.go | 15 ++-- .../driver/operation/commit_transaction.go | 13 ++-- x/mongo/driver/operation/count.go | 14 ++-- x/mongo/driver/operation/createIndexes.go | 13 ++-- x/mongo/driver/operation/distinct.go | 12 ++-- x/mongo/driver/operation/drop_indexes.go | 13 ++-- x/mongo/driver/operation/find.go | 13 ++-- x/mongo/driver/operation/find_and_modify.go | 15 ++-- x/mongo/driver/operation/list_indexes.go | 15 ++-- x/mongo/driver/operation_legacy.go | 24 +++++-- x/mongo/driver/operation_test.go | 70 +++++++++++++++++++ 18 files changed, 205 insertions(+), 162 deletions(-) diff --git a/mongo/collection.go b/mongo/collection.go index aa3ffbe958..a10a63120f 100644 --- a/mongo/collection.go +++ b/mongo/collection.go @@ -846,7 +846,8 @@ func aggregate(a aggregateParams) (cur *Cursor, err error) { Crypt(a.client.cryptFLE). ServerAPI(a.client.serverAPI). HasOutputStage(hasOutputStage). - Timeout(a.client.timeout) + Timeout(a.client.timeout). + MaxTime(ao.MaxTime) if ao.AllowDiskUse != nil { op.AllowDiskUse(*ao.AllowDiskUse) @@ -862,9 +863,6 @@ func aggregate(a aggregateParams) (cur *Cursor, err error) { if ao.Collation != nil { op.Collation(bsoncore.Document(ao.Collation.ToDocument())) } - if ao.MaxTime != nil { - op.MaxTimeMS(int64(*ao.MaxTime / time.Millisecond)) - } if ao.MaxAwaitTime != nil { cursorOpts.MaxTimeMS = int64(*ao.MaxAwaitTime / time.Millisecond) } @@ -971,16 +969,13 @@ func (coll *Collection) CountDocuments(ctx context.Context, filter interface{}, op := operation.NewAggregate(pipelineArr).Session(sess).ReadConcern(rc).ReadPreference(coll.readPreference). CommandMonitor(coll.client.monitor).ServerSelector(selector).ClusterClock(coll.client.clock).Database(coll.db.name). Collection(coll.name).Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE).ServerAPI(coll.client.serverAPI). - Timeout(coll.client.timeout) + Timeout(coll.client.timeout).MaxTime(countOpts.MaxTime) if countOpts.Collation != nil { op.Collation(bsoncore.Document(countOpts.Collation.ToDocument())) } if countOpts.Comment != nil { op.Comment(*countOpts.Comment) } - if countOpts.MaxTime != nil { - op.MaxTimeMS(int64(*countOpts.MaxTime / time.Millisecond)) - } if countOpts.Hint != nil { hintVal, err := transformValue(coll.registry, countOpts.Hint, false, "hint") if err != nil { @@ -1052,14 +1047,15 @@ func (coll *Collection) EstimatedDocumentCount(ctx context.Context, rc = nil } + co := options.MergeEstimatedDocumentCountOptions(opts...) + selector := makeReadPrefSelector(sess, coll.readSelector, coll.client.localThreshold) op := operation.NewCount().Session(sess).ClusterClock(coll.client.clock). Database(coll.db.name).Collection(coll.name).CommandMonitor(coll.client.monitor). Deployment(coll.client.deployment).ReadConcern(rc).ReadPreference(coll.readPreference). ServerSelector(selector).Crypt(coll.client.cryptFLE).ServerAPI(coll.client.serverAPI). - Timeout(coll.client.timeout) + Timeout(coll.client.timeout).MaxTime(co.MaxTime) - co := options.MergeEstimatedDocumentCountOptions(opts...) if co.Comment != nil { comment, err := transformValue(coll.registry, co.Comment, false, "comment") if err != nil { @@ -1067,9 +1063,7 @@ func (coll *Collection) EstimatedDocumentCount(ctx context.Context, } op = op.Comment(comment) } - if co.MaxTime != nil { - op = op.MaxTimeMS(int64(*co.MaxTime / time.Millisecond)) - } + retry := driver.RetryNone if coll.client.retryReads { retry = driver.RetryOncePerCommand @@ -1077,7 +1071,6 @@ func (coll *Collection) EstimatedDocumentCount(ctx context.Context, op.Retry(retry) err = op.Execute(ctx) - return op.Result().N, replaceErrors(err) } @@ -1131,7 +1124,7 @@ func (coll *Collection) Distinct(ctx context.Context, fieldName string, filter i Database(coll.db.name).Collection(coll.name).CommandMonitor(coll.client.monitor). Deployment(coll.client.deployment).ReadConcern(rc).ReadPreference(coll.readPreference). ServerSelector(selector).Crypt(coll.client.cryptFLE).ServerAPI(coll.client.serverAPI). - Timeout(coll.client.timeout) + Timeout(coll.client.timeout).MaxTime(option.MaxTime) if option.Collation != nil { op.Collation(bsoncore.Document(option.Collation.ToDocument())) @@ -1143,9 +1136,6 @@ func (coll *Collection) Distinct(ctx context.Context, fieldName string, filter i } op.Comment(comment) } - if option.MaxTime != nil { - op.MaxTimeMS(int64(*option.MaxTime / time.Millisecond)) - } retry := driver.RetryNone if coll.client.retryReads { retry = driver.RetryOncePerCommand @@ -1225,17 +1215,17 @@ func (coll *Collection) Find(ctx context.Context, filter interface{}, rc = nil } + fo := options.MergeFindOptions(opts...) + selector := makeReadPrefSelector(sess, coll.readSelector, coll.client.localThreshold) op := operation.NewFind(f). Session(sess).ReadConcern(rc).ReadPreference(coll.readPreference). CommandMonitor(coll.client.monitor).ServerSelector(selector). ClusterClock(coll.client.clock).Database(coll.db.name).Collection(coll.name). Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE).ServerAPI(coll.client.serverAPI). - Timeout(coll.client.timeout) + Timeout(coll.client.timeout).MaxTime(fo.MaxTime) - fo := options.MergeFindOptions(opts...) cursorOpts := coll.client.createBaseCursorOptions() - if fo.AllowDiskUse != nil { op.AllowDiskUse(*fo.AllowDiskUse) } @@ -1300,9 +1290,6 @@ func (coll *Collection) Find(ctx context.Context, filter interface{}, if fo.MaxAwaitTime != nil { cursorOpts.MaxTimeMS = int64(*fo.MaxAwaitTime / time.Millisecond) } - if fo.MaxTime != nil { - op.MaxTimeMS(int64(*fo.MaxTime / time.Millisecond)) - } if fo.Min != nil { min, err := transformBsoncoreDocument(coll.registry, fo.Min, true, "min") if err != nil { @@ -1482,7 +1469,8 @@ func (coll *Collection) FindOneAndDelete(ctx context.Context, filter interface{} return &SingleResult{err: err} } fod := options.MergeFindOneAndDeleteOptions(opts...) - op := operation.NewFindAndModify(f).Remove(true).ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout) + op := operation.NewFindAndModify(f).Remove(true).ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout). + MaxTime(fod.MaxTime) if fod.Collation != nil { op = op.Collation(bsoncore.Document(fod.Collation.ToDocument())) } @@ -1493,9 +1481,6 @@ func (coll *Collection) FindOneAndDelete(ctx context.Context, filter interface{} } op = op.Comment(comment) } - if fod.MaxTime != nil { - op = op.MaxTimeMS(int64(*fod.MaxTime / time.Millisecond)) - } if fod.Projection != nil { proj, err := transformBsoncoreDocument(coll.registry, fod.Projection, true, "projection") if err != nil { @@ -1559,7 +1544,7 @@ func (coll *Collection) FindOneAndReplace(ctx context.Context, filter interface{ fo := options.MergeFindOneAndReplaceOptions(opts...) op := operation.NewFindAndModify(f).Update(bsoncore.Value{Type: bsontype.EmbeddedDocument, Data: r}). - ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout) + ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout).MaxTime(fo.MaxTime) if fo.BypassDocumentValidation != nil && *fo.BypassDocumentValidation { op = op.BypassDocumentValidation(*fo.BypassDocumentValidation) } @@ -1573,9 +1558,6 @@ func (coll *Collection) FindOneAndReplace(ctx context.Context, filter interface{ } op = op.Comment(comment) } - if fo.MaxTime != nil { - op = op.MaxTimeMS(int64(*fo.MaxTime / time.Millisecond)) - } if fo.Projection != nil { proj, err := transformBsoncoreDocument(coll.registry, fo.Projection, true, "projection") if err != nil { @@ -1642,7 +1624,8 @@ func (coll *Collection) FindOneAndUpdate(ctx context.Context, filter interface{} } fo := options.MergeFindOneAndUpdateOptions(opts...) - op := operation.NewFindAndModify(f).ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout) + op := operation.NewFindAndModify(f).ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout). + MaxTime(fo.MaxTime) u, err := transformUpdateValue(coll.registry, update, true) if err != nil { @@ -1670,9 +1653,6 @@ func (coll *Collection) FindOneAndUpdate(ctx context.Context, filter interface{} } op = op.Comment(comment) } - if fo.MaxTime != nil { - op = op.MaxTimeMS(int64(*fo.MaxTime / time.Millisecond)) - } if fo.Projection != nil { proj, err := transformBsoncoreDocument(coll.registry, fo.Projection, true, "projection") if err != nil { diff --git a/mongo/index_view.go b/mongo/index_view.go index a393c7e7c5..6bb33f07b2 100644 --- a/mongo/index_view.go +++ b/mongo/index_view.go @@ -12,7 +12,6 @@ import ( "errors" "fmt" "strconv" - "time" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/bsontype" @@ -104,9 +103,7 @@ func (iv IndexView) List(ctx context.Context, opts ...*options.ListIndexesOption op = op.BatchSize(*lio.BatchSize) cursorOpts.BatchSize = *lio.BatchSize } - if lio.MaxTime != nil { - op = op.MaxTimeMS(int64(*lio.MaxTime / time.Millisecond)) - } + op = op.MaxTime(lio.MaxTime) retry := driver.RetryNone if iv.coll.client.retryReads { retry = driver.RetryOncePerCommand @@ -258,11 +255,7 @@ func (iv IndexView) CreateMany(ctx context.Context, models []IndexModel, opts .. Session(sess).WriteConcern(wc).ClusterClock(iv.coll.client.clock). Database(iv.coll.db.name).Collection(iv.coll.name).CommandMonitor(iv.coll.client.monitor). Deployment(iv.coll.client.deployment).ServerSelector(selector).ServerAPI(iv.coll.client.serverAPI). - Timeout(iv.coll.client.timeout) - - if option.MaxTime != nil { - op.MaxTimeMS(int64(*option.MaxTime / time.Millisecond)) - } + Timeout(iv.coll.client.timeout).MaxTime(option.MaxTime) if option.CommitQuorum != nil { commitQuorum, err := transformValue(iv.coll.registry, option.CommitQuorum, true, "commitQuorum") if err != nil { @@ -403,10 +396,7 @@ func (iv IndexView) drop(ctx context.Context, name string, opts ...*options.Drop ServerSelector(selector).ClusterClock(iv.coll.client.clock). Database(iv.coll.db.name).Collection(iv.coll.name). Deployment(iv.coll.client.deployment).ServerAPI(iv.coll.client.serverAPI). - Timeout(iv.coll.client.timeout) - if dio.MaxTime != nil { - op.MaxTimeMS(int64(*dio.MaxTime / time.Millisecond)) - } + Timeout(iv.coll.client.timeout).MaxTime(dio.MaxTime) err = op.Execute(ctx) if err != nil { diff --git a/mongo/integration/operation_legacy_test.go b/mongo/integration/operation_legacy_test.go index 5801f126fc..48bf3fb086 100644 --- a/mongo/integration/operation_legacy_test.go +++ b/mongo/integration/operation_legacy_test.go @@ -190,12 +190,12 @@ func runFindWithOptions(mt *mtest.T) opQuery { {"$comment", "hello"}, {"$hint", "hintFoo"}, {"$max", maxDoc}, - {"$maxTimeMS", int64(10000)}, {"$min", minDoc}, {"$returnKey", false}, {"$showDiskLoc", false}, {"$snapshot", false}, {"$orderby", sort}, + {"$maxTimeMS", int64(10000)}, } return opQuery{ flags: wiremessage.Partial | wiremessage.TailableCursor | wiremessage.NoCursorTimeout | wiremessage.OplogReplay | wiremessage.SecondaryOK, diff --git a/mongo/options/clientoptions.go b/mongo/options/clientoptions.go index 742a4da9b5..603978aff6 100644 --- a/mongo/options/clientoptions.go +++ b/mongo/options/clientoptions.go @@ -728,9 +728,9 @@ func (c *ClientOptions) SetSocketTimeout(d time.Duration) *ClientOptions { // be honored if there is no deadline on the operation Context. Timeout can also be set through the "timeoutMS" URI option // (e.g. "timeoutMS=1000"). The default value is nil, meaning operations do not inherit a timeout from the Client. // -// If any Timeout is set (even 0) on the Client, the values of MaxTime on operations, TransactionOptions.MaxCommitTime and -// SessionOptions.DefaultMaxCommitTime will be ignored. Setting Timeout and ClientOptions.SocketTimeout or WriteConcern.wTimeout -// will result in undefined behavior. +// If any Timeout is set (even 0) on the Client, the values of MaxTime on operation options, TransactionOptions.MaxCommitTime and +// SessionOptions.DefaultMaxCommitTime will be ignored. Setting Timeout and SocketTimeout or WriteConcern.wTimeout will result +// in undefined behavior. // // NOTE(benjirewis): SetTimeout represents unstable, provisional API. The behavior of the driver when a Timeout is specified is // subject to change. diff --git a/mongo/session.go b/mongo/session.go index 06e509810c..a4f18baf01 100644 --- a/mongo/session.go +++ b/mongo/session.go @@ -313,10 +313,7 @@ func (s *sessionImpl) CommitTransaction(ctx context.Context) error { Session(s.clientSession).ClusterClock(s.client.clock).Database("admin").Deployment(s.deployment). WriteConcern(s.clientSession.CurrentWc).ServerSelector(selector).Retry(driver.RetryOncePerCommand). CommandMonitor(s.client.monitor).RecoveryToken(bsoncore.Document(s.clientSession.RecoveryToken)). - ServerAPI(s.client.serverAPI) - if s.clientSession.CurrentMct != nil { - op.MaxTimeMS(int64(*s.clientSession.CurrentMct / time.Millisecond)) - } + ServerAPI(s.client.serverAPI).MaxTime(s.clientSession.CurrentMct) err = op.Execute(ctx) // Return error without updating transaction state if it is a timeout, as the transaction has not diff --git a/x/mongo/driver/errors.go b/x/mongo/driver/errors.go index 20a7de55d6..aa898fbe7a 100644 --- a/x/mongo/driver/errors.go +++ b/x/mongo/driver/errors.go @@ -48,6 +48,8 @@ var ( // ErrDeadlineWouldBeExceeded is returned when a Timeout set on an operation would be exceeded // if the operation were sent to the server. ErrDeadlineWouldBeExceeded = errors.New("operation not sent to server, as Timeout would be exceeded") + // ErrNegativeMaxTime is returned when MaxTime on an operation is a negative value. + ErrNegativeMaxTime = errors.New("a negative value was provided for MaxTime on an operation") ) // QueryFailureError is an error representing a command failure as a document. diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index 7f4148b322..699c63cfb9 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -217,6 +217,9 @@ type Operation struct { // read preference will not be added to the command on wire versions < 13. IsOutputAggregate bool + // MaxTime specifies the maximum amount of time to allow the operation to run on the server. + MaxTime *time.Duration + // Timeout is the amount of time that this operation can execute before returning an error. The default value // nil, which means that the timeout of the operation's caller will be used. Timeout *time.Duration @@ -444,12 +447,18 @@ func (op Operation) Execute(ctx context.Context, scratch []byte) error { first = false } + // Calculate maxTimeMS value to potentially be appended to the wire message. + maxTimeMS, err := op.calculateMaxTimeMS(ctx, srvr.RTTMonitor().P90(), srvr.RTTMonitor().Stats()) + if err != nil { + return err + } + desc := description.SelectedServer{Server: conn.Description(), Kind: op.Deployment.Kind()} scratch = scratch[:0] if desc.WireVersion == nil || desc.WireVersion.Max < 4 { switch op.Legacy { case LegacyFind: - return op.legacyFind(ctx, scratch, srvr, conn, desc) + return op.legacyFind(ctx, scratch, srvr, conn, desc, maxTimeMS) case LegacyGetMore: return op.legacyGetMore(ctx, scratch, srvr, conn, desc) case LegacyKillCursors: @@ -461,7 +470,7 @@ func (op Operation) Execute(ctx context.Context, scratch []byte) error { case LegacyListCollections: return op.legacyListCollections(ctx, scratch, srvr, conn, desc) case LegacyListIndexes: - return op.legacyListIndexes(ctx, scratch, srvr, conn, desc) + return op.legacyListIndexes(ctx, scratch, srvr, conn, desc, maxTimeMS) } } @@ -483,26 +492,6 @@ func (op Operation) Execute(ctx context.Context, scratch []byte) error { } } - // Calculate value of 'maxTimeMS' field to potentially append to the wire message based on the current - // context's deadline and the 90th percentile RTT if the ctx is a Timeout Context. - var maxTimeMS uint64 - if internal.IsTimeoutContext(ctx) { - if deadline, ok := ctx.Deadline(); ok { - remainingTimeout := time.Until(deadline) - - maxTimeMSVal := int64(remainingTimeout/time.Millisecond) - - int64(srvr.RTTMonitor().P90()/time.Millisecond) - - // A maxTimeMS value <= 0 indicates that we are already at or past the Context's deadline. - if maxTimeMSVal <= 0 { - return internal.WrapErrorf(ErrDeadlineWouldBeExceeded, - "remaining time %v until context deadline is less than or equal to 90th percentile RTT\n%v", - remainingTimeout, srvr.RTTMonitor().Stats()) - } - maxTimeMS = uint64(maxTimeMSVal) - } - } - // convert to wire message if len(scratch) > 0 { scratch = scratch[:0] @@ -1273,6 +1262,40 @@ func (op Operation) addClusterTime(dst []byte, desc description.SelectedServer) // return bsoncore.AppendDocumentElement(dst, "$clusterTime", clusterTime) } +// calculateMaxTimeMS calculates the value of the 'maxTimeMS' field to potentially append +// to the wire message based on the current context's deadline and the 90th percentile RTT +// if the ctx is a Timeout context. If the context is not a Timeout context, it uses the +// operation's MaxTimeMS if set. If no MaxTimeMS is set on the operation, and context is +// not a Timeout context, calculateMaxTimeMS returns 0. +func (op Operation) calculateMaxTimeMS(ctx context.Context, rtt90 time.Duration, rttStats string) (uint64, error) { + if internal.IsTimeoutContext(ctx) { + if deadline, ok := ctx.Deadline(); ok { + remainingTimeout := time.Until(deadline) + maxTime := remainingTimeout - rtt90 + + // Always round up to the next millisecond value so we never truncate the calculated + // maxTimeMS value (e.g. 400 microseconds evaluates to 1ms, not 0ms). + maxTimeMS := int64((maxTime + (time.Millisecond - 1)) / time.Millisecond) + if maxTimeMS <= 0 { + return 0, internal.WrapErrorf(ErrDeadlineWouldBeExceeded, + "remaining time %v until context deadline is less than or equal to 90th percentile RTT\n%v", + remainingTimeout, rttStats) + } + return uint64(maxTimeMS), nil + } + } else if op.MaxTime != nil { + // Users are not allowed to pass a negative value as MaxTime. A value of 0 would indicate + // no timeout and is allowed. + if *op.MaxTime < 0 { + return 0, ErrNegativeMaxTime + } + // Always round up to the next millisecond value so we never truncate the requested + // MaxTime value (e.g. 400 microseconds evaluates to 1ms, not 0ms). + return uint64((*op.MaxTime + (time.Millisecond - 1)) / time.Millisecond), nil + } + return 0, nil +} + // updateClusterTimes updates the cluster times for the session and cluster clock attached to this // operation. While the session's AdvanceClusterTime may return an error, this method does not // because an error being returned from this method will not be returned further up. diff --git a/x/mongo/driver/operation/aggregate.go b/x/mongo/driver/operation/aggregate.go index be311780d3..d33e8e298c 100644 --- a/x/mongo/driver/operation/aggregate.go +++ b/x/mongo/driver/operation/aggregate.go @@ -30,7 +30,7 @@ type Aggregate struct { collation bsoncore.Document comment *string hint bsoncore.Value - maxTimeMS *int64 + maxTime *time.Duration pipeline bsoncore.Document session *session.Client clock *session.ClusterClock @@ -109,6 +109,7 @@ func (a *Aggregate) Execute(ctx context.Context) error { MinimumWriteConcernWireVersion: 5, ServerAPI: a.serverAPI, IsOutputAggregate: a.hasOutputStage, + MaxTime: a.maxTime, Timeout: a.timeout, }.Execute(ctx, nil) @@ -148,12 +149,6 @@ func (a *Aggregate) command(dst []byte, desc description.SelectedServer) ([]byte dst = bsoncore.AppendValueElement(dst, "hint", a.hint) } - - // Only append specified maxTimeMS if timeout is not also specified. - if a.maxTimeMS != nil && a.timeout == nil { - - dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", *a.maxTimeMS) - } if a.pipeline != nil { dst = bsoncore.AppendArrayElement(dst, "pipeline", a.pipeline) @@ -230,13 +225,13 @@ func (a *Aggregate) Hint(hint bsoncore.Value) *Aggregate { return a } -// MaxTimeMS specifies the maximum amount of time to allow the query to run. -func (a *Aggregate) MaxTimeMS(maxTimeMS int64) *Aggregate { +// MaxTime specifies the maximum amount of time to allow the query to run on the server. +func (a *Aggregate) MaxTime(maxTime *time.Duration) *Aggregate { if a == nil { a = new(Aggregate) } - a.maxTimeMS = &maxTimeMS + a.maxTime = maxTime return a } diff --git a/x/mongo/driver/operation/commit_transaction.go b/x/mongo/driver/operation/commit_transaction.go index 14ed7bcd8c..815d2db8cd 100644 --- a/x/mongo/driver/operation/commit_transaction.go +++ b/x/mongo/driver/operation/commit_transaction.go @@ -9,6 +9,7 @@ package operation import ( "context" "errors" + "time" "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/mongo/description" @@ -20,7 +21,7 @@ import ( // CommitTransaction attempts to commit a transaction. type CommitTransaction struct { - maxTimeMS *int64 + maxTime *time.Duration recoveryToken bsoncore.Document session *session.Client clock *session.ClusterClock @@ -61,6 +62,7 @@ func (ct *CommitTransaction) Execute(ctx context.Context) error { Crypt: ct.crypt, Database: ct.database, Deployment: ct.deployment, + MaxTime: ct.maxTime, Selector: ct.selector, WriteConcern: ct.writeConcern, ServerAPI: ct.serverAPI, @@ -71,22 +73,19 @@ func (ct *CommitTransaction) Execute(ctx context.Context) error { func (ct *CommitTransaction) command(dst []byte, desc description.SelectedServer) ([]byte, error) { dst = bsoncore.AppendInt32Element(dst, "commitTransaction", 1) - if ct.maxTimeMS != nil { - dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", *ct.maxTimeMS) - } if ct.recoveryToken != nil { dst = bsoncore.AppendDocumentElement(dst, "recoveryToken", ct.recoveryToken) } return dst, nil } -// MaxTimeMS specifies the maximum amount of time to allow the query to run. -func (ct *CommitTransaction) MaxTimeMS(maxTimeMS int64) *CommitTransaction { +// MaxTime specifies the maximum amount of time to allow the query to run on the server. +func (ct *CommitTransaction) MaxTime(maxTime *time.Duration) *CommitTransaction { if ct == nil { ct = new(CommitTransaction) } - ct.maxTimeMS = &maxTimeMS + ct.maxTime = maxTime return ct } diff --git a/x/mongo/driver/operation/count.go b/x/mongo/driver/operation/count.go index 756bb5f620..924e923eea 100644 --- a/x/mongo/driver/operation/count.go +++ b/x/mongo/driver/operation/count.go @@ -24,7 +24,7 @@ import ( // Count represents a count operation. type Count struct { - maxTimeMS *int64 + maxTime *time.Duration query bsoncore.Document session *session.Client clock *session.ClusterClock @@ -120,6 +120,7 @@ func (c *Count) Execute(ctx context.Context) error { Crypt: c.crypt, Database: c.database, Deployment: c.deployment, + MaxTime: c.maxTime, ReadConcern: c.readConcern, ReadPreference: c.readPreference, Selector: c.selector, @@ -142,24 +143,19 @@ func (c *Count) command(dst []byte, desc description.SelectedServer) ([]byte, er if c.query != nil { dst = bsoncore.AppendDocumentElement(dst, "query", c.query) } - - // Only append specified maxTimeMS if timeout is not also specified. - if c.maxTimeMS != nil && c.timeout == nil { - dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", *c.maxTimeMS) - } if c.comment.Type != bsontype.Type(0) { dst = bsoncore.AppendValueElement(dst, "comment", c.comment) } return dst, nil } -// MaxTimeMS specifies the maximum amount of time to allow the query to run. -func (c *Count) MaxTimeMS(maxTimeMS int64) *Count { +// MaxTime specifies the maximum amount of time to allow the query to run on the server. +func (c *Count) MaxTime(maxTime *time.Duration) *Count { if c == nil { c = new(Count) } - c.maxTimeMS = &maxTimeMS + c.maxTime = maxTime return c } diff --git a/x/mongo/driver/operation/createIndexes.go b/x/mongo/driver/operation/createIndexes.go index b828bc1860..3fa55406f1 100644 --- a/x/mongo/driver/operation/createIndexes.go +++ b/x/mongo/driver/operation/createIndexes.go @@ -25,7 +25,7 @@ import ( type CreateIndexes struct { commitQuorum bsoncore.Value indexes bsoncore.Document - maxTimeMS *int64 + maxTime *time.Duration session *session.Client clock *session.ClusterClock collection string @@ -112,6 +112,7 @@ func (ci *CreateIndexes) Execute(ctx context.Context) error { Crypt: ci.crypt, Database: ci.database, Deployment: ci.deployment, + MaxTime: ci.maxTime, Selector: ci.selector, WriteConcern: ci.writeConcern, ServerAPI: ci.serverAPI, @@ -131,10 +132,6 @@ func (ci *CreateIndexes) command(dst []byte, desc description.SelectedServer) ([ if ci.indexes != nil { dst = bsoncore.AppendArrayElement(dst, "indexes", ci.indexes) } - // Only append specified maxTimeMS if timeout is not also specified. - if ci.maxTimeMS != nil && ci.timeout == nil { - dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", *ci.maxTimeMS) - } return dst, nil } @@ -160,13 +157,13 @@ func (ci *CreateIndexes) Indexes(indexes bsoncore.Document) *CreateIndexes { return ci } -// MaxTimeMS specifies the maximum amount of time to allow the query to run. -func (ci *CreateIndexes) MaxTimeMS(maxTimeMS int64) *CreateIndexes { +// MaxTime specifies the maximum amount of time to allow the query to run on the server. +func (ci *CreateIndexes) MaxTime(maxTime *time.Duration) *CreateIndexes { if ci == nil { ci = new(CreateIndexes) } - ci.maxTimeMS = &maxTimeMS + ci.maxTime = maxTime return ci } diff --git a/x/mongo/driver/operation/distinct.go b/x/mongo/driver/operation/distinct.go index c4b16cd7e7..0ea303b418 100644 --- a/x/mongo/driver/operation/distinct.go +++ b/x/mongo/driver/operation/distinct.go @@ -25,7 +25,7 @@ import ( type Distinct struct { collation bsoncore.Document key *string - maxTimeMS *int64 + maxTime *time.Duration query bsoncore.Document session *session.Client clock *session.ClusterClock @@ -99,6 +99,7 @@ func (d *Distinct) Execute(ctx context.Context) error { Crypt: d.crypt, Database: d.database, Deployment: d.deployment, + MaxTime: d.maxTime, ReadConcern: d.readConcern, ReadPreference: d.readPreference, Selector: d.selector, @@ -122,9 +123,6 @@ func (d *Distinct) command(dst []byte, desc description.SelectedServer) ([]byte, if d.key != nil { dst = bsoncore.AppendStringElement(dst, "key", *d.key) } - if d.maxTimeMS != nil { - dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", *d.maxTimeMS) - } if d.query != nil { dst = bsoncore.AppendDocumentElement(dst, "query", d.query) } @@ -151,13 +149,13 @@ func (d *Distinct) Key(key string) *Distinct { return d } -// MaxTimeMS specifies the maximum amount of time to allow the query to run. -func (d *Distinct) MaxTimeMS(maxTimeMS int64) *Distinct { +// MaxTime specifies the maximum amount of time to allow the query to run on the server. +func (d *Distinct) MaxTime(maxTime *time.Duration) *Distinct { if d == nil { d = new(Distinct) } - d.maxTimeMS = &maxTimeMS + d.maxTime = maxTime return d } diff --git a/x/mongo/driver/operation/drop_indexes.go b/x/mongo/driver/operation/drop_indexes.go index 20ca3668be..9ce91f7d39 100644 --- a/x/mongo/driver/operation/drop_indexes.go +++ b/x/mongo/driver/operation/drop_indexes.go @@ -23,7 +23,7 @@ import ( // DropIndexes performs an dropIndexes operation. type DropIndexes struct { index *string - maxTimeMS *int64 + maxTime *time.Duration session *session.Client clock *session.ClusterClock collection string @@ -94,6 +94,7 @@ func (di *DropIndexes) Execute(ctx context.Context) error { Crypt: di.crypt, Database: di.database, Deployment: di.deployment, + MaxTime: di.maxTime, Selector: di.selector, WriteConcern: di.writeConcern, ServerAPI: di.serverAPI, @@ -107,10 +108,6 @@ func (di *DropIndexes) command(dst []byte, desc description.SelectedServer) ([]b if di.index != nil { dst = bsoncore.AppendStringElement(dst, "index", *di.index) } - // Only append specified maxTimeMS if timeout is not also specified. - if di.maxTimeMS != nil && di.timeout == nil { - dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", *di.maxTimeMS) - } return dst, nil } @@ -125,13 +122,13 @@ func (di *DropIndexes) Index(index string) *DropIndexes { return di } -// MaxTimeMS specifies the maximum amount of time to allow the query to run. -func (di *DropIndexes) MaxTimeMS(maxTimeMS int64) *DropIndexes { +// MaxTime specifies the maximum amount of time to allow the query to run on the server. +func (di *DropIndexes) MaxTime(maxTime *time.Duration) *DropIndexes { if di == nil { di = new(DropIndexes) } - di.maxTimeMS = &maxTimeMS + di.maxTime = maxTime return di } diff --git a/x/mongo/driver/operation/find.go b/x/mongo/driver/operation/find.go index 5ccbf9f91a..b03c5866a9 100644 --- a/x/mongo/driver/operation/find.go +++ b/x/mongo/driver/operation/find.go @@ -34,7 +34,7 @@ type Find struct { let bsoncore.Document limit *int64 max bsoncore.Document - maxTimeMS *int64 + maxTime *time.Duration min bsoncore.Document noCursorTimeout *bool oplogReplay *bool @@ -98,6 +98,7 @@ func (f *Find) Execute(ctx context.Context) error { Crypt: f.crypt, Database: f.database, Deployment: f.deployment, + MaxTime: f.maxTime, ReadConcern: f.readConcern, ReadPreference: f.readPreference, Selector: f.selector, @@ -149,10 +150,6 @@ func (f *Find) command(dst []byte, desc description.SelectedServer) ([]byte, err if f.max != nil { dst = bsoncore.AppendDocumentElement(dst, "max", f.max) } - // Only append specified maxTimeMS if timeout is not also specified. - if f.maxTimeMS != nil && f.timeout == nil { - dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", *f.maxTimeMS) - } if f.min != nil { dst = bsoncore.AppendDocumentElement(dst, "min", f.min) } @@ -299,13 +296,13 @@ func (f *Find) Max(max bsoncore.Document) *Find { return f } -// MaxTimeMS specifies the maximum amount of time to allow the query to run. -func (f *Find) MaxTimeMS(maxTimeMS int64) *Find { +// MaxTime specifies the maximum amount of time to allow the query to run on the server. +func (f *Find) MaxTime(maxTime *time.Duration) *Find { if f == nil { f = new(Find) } - f.maxTimeMS = &maxTimeMS + f.maxTime = maxTime return f } diff --git a/x/mongo/driver/operation/find_and_modify.go b/x/mongo/driver/operation/find_and_modify.go index deaabea313..b7c006cd7a 100644 --- a/x/mongo/driver/operation/find_and_modify.go +++ b/x/mongo/driver/operation/find_and_modify.go @@ -29,7 +29,7 @@ type FindAndModify struct { collation bsoncore.Document comment bsoncore.Value fields bsoncore.Document - maxTimeMS *int64 + maxTime *time.Duration newDocument *bool query bsoncore.Document remove *bool @@ -137,6 +137,7 @@ func (fam *FindAndModify) Execute(ctx context.Context) error { CommandMonitor: fam.monitor, Database: fam.database, Deployment: fam.deployment, + MaxTime: fam.maxTime, Selector: fam.selector, WriteConcern: fam.writeConcern, Crypt: fam.crypt, @@ -173,12 +174,6 @@ func (fam *FindAndModify) command(dst []byte, desc description.SelectedServer) ( dst = bsoncore.AppendDocumentElement(dst, "fields", fam.fields) } - - // Only append specified maxTimeMS if timeout is not also specified. - if fam.maxTimeMS != nil && fam.timeout == nil { - - dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", *fam.maxTimeMS) - } if fam.newDocument != nil { dst = bsoncore.AppendBooleanElement(dst, "new", *fam.newDocument) @@ -269,13 +264,13 @@ func (fam *FindAndModify) Fields(fields bsoncore.Document) *FindAndModify { return fam } -// MaxTimeMS specifies the maximum amount of time to allow the operation to run. -func (fam *FindAndModify) MaxTimeMS(maxTimeMS int64) *FindAndModify { +// MaxTime specifies the maximum amount of time to allow the operation to run on the server. +func (fam *FindAndModify) MaxTime(maxTime *time.Duration) *FindAndModify { if fam == nil { fam = new(FindAndModify) } - fam.maxTimeMS = &maxTimeMS + fam.maxTime = maxTime return fam } diff --git a/x/mongo/driver/operation/list_indexes.go b/x/mongo/driver/operation/list_indexes.go index 39e6f25cc8..d8da68c726 100644 --- a/x/mongo/driver/operation/list_indexes.go +++ b/x/mongo/driver/operation/list_indexes.go @@ -21,7 +21,7 @@ import ( // ListIndexes performs a listIndexes operation. type ListIndexes struct { batchSize *int32 - maxTimeMS *int64 + maxTime *time.Duration session *session.Client clock *session.ClusterClock collection string @@ -75,6 +75,7 @@ func (li *ListIndexes) Execute(ctx context.Context) error { CommandMonitor: li.monitor, Database: li.database, Deployment: li.deployment, + MaxTime: li.maxTime, Selector: li.selector, Crypt: li.crypt, Legacy: driver.LegacyListIndexes, @@ -94,12 +95,6 @@ func (li *ListIndexes) command(dst []byte, desc description.SelectedServer) ([]b cursorDoc = bsoncore.AppendInt32Element(cursorDoc, "batchSize", *li.batchSize) } - - // Only append specified maxTimeMS if timeout is not also specified. - if li.maxTimeMS != nil && li.timeout == nil { - - dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", *li.maxTimeMS) - } cursorDoc, _ = bsoncore.AppendDocumentEnd(cursorDoc, cursorIdx) dst = bsoncore.AppendDocumentElement(dst, "cursor", cursorDoc) @@ -116,13 +111,13 @@ func (li *ListIndexes) BatchSize(batchSize int32) *ListIndexes { return li } -// MaxTimeMS specifies the maximum amount of time to allow the query to run. -func (li *ListIndexes) MaxTimeMS(maxTimeMS int64) *ListIndexes { +// MaxTime specifies the maximum amount of time to allow the query to run on the server. +func (li *ListIndexes) MaxTime(maxTime *time.Duration) *ListIndexes { if li == nil { li = new(ListIndexes) } - li.maxTimeMS = &maxTimeMS + li.maxTime = maxTime return li } diff --git a/x/mongo/driver/operation_legacy.go b/x/mongo/driver/operation_legacy.go index 2584f484ad..bbb0cb054f 100644 --- a/x/mongo/driver/operation_legacy.go +++ b/x/mongo/driver/operation_legacy.go @@ -32,8 +32,9 @@ func (op Operation) getFullCollectionName(coll string) string { return op.Database + "." + coll } -func (op Operation) legacyFind(ctx context.Context, dst []byte, srvr Server, conn Connection, desc description.SelectedServer) error { - wm, startedInfo, collName, err := op.createLegacyFindWireMessage(dst, desc) +func (op Operation) legacyFind(ctx context.Context, dst []byte, srvr Server, conn Connection, + desc description.SelectedServer, maxTimeMS uint64) error { + wm, startedInfo, collName, err := op.createLegacyFindWireMessage(dst, desc, maxTimeMS) if err != nil { return err } @@ -68,7 +69,7 @@ func (op Operation) legacyFind(ctx context.Context, dst []byte, srvr Server, con } // returns wire message, collection name, error -func (op Operation) createLegacyFindWireMessage(dst []byte, desc description.SelectedServer) ([]byte, startedInformation, string, error) { +func (op Operation) createLegacyFindWireMessage(dst []byte, desc description.SelectedServer, maxTimeMS uint64) ([]byte, startedInformation, string, error) { info := startedInformation{ requestID: wiremessage.NextRequestID(), cmdName: "find", @@ -84,6 +85,11 @@ func (op Operation) createLegacyFindWireMessage(dst []byte, desc description.Sel if err != nil { return dst, info, "", err } + // If maxTimeMS is greater than 0 append it to wire message. A maxTimeMS value of 0 only explicitly + // specifies the default behavior of no timeout server-side. + if maxTimeMS > 0 { + cmdDoc = bsoncore.AppendInt64Element(cmdDoc, "maxTimeMS", int64(maxTimeMS)) + } cmdDoc, _ = bsoncore.AppendDocumentEnd(cmdDoc, cmdIndex) // for monitoring legacy events, the upconverted document should be captured rather than the legacy one info.cmd = cmdDoc @@ -523,8 +529,9 @@ func (op Operation) transformListCollectionsFilter(filter bsoncore.Document) (bs return combinedFilter, nil } -func (op Operation) legacyListIndexes(ctx context.Context, dst []byte, srvr Server, conn Connection, desc description.SelectedServer) error { - wm, startedInfo, collName, err := op.createLegacyListIndexesWiremessage(dst, desc) +func (op Operation) legacyListIndexes(ctx context.Context, dst []byte, srvr Server, conn Connection, + desc description.SelectedServer, maxTimeMS uint64) error { + wm, startedInfo, collName, err := op.createLegacyListIndexesWiremessage(dst, desc, maxTimeMS) if err != nil { return err } @@ -558,7 +565,7 @@ func (op Operation) legacyListIndexes(ctx context.Context, dst []byte, srvr Serv return nil } -func (op Operation) createLegacyListIndexesWiremessage(dst []byte, desc description.SelectedServer) ([]byte, startedInformation, string, error) { +func (op Operation) createLegacyListIndexesWiremessage(dst []byte, desc description.SelectedServer, maxTimeMS uint64) ([]byte, startedInformation, string, error) { info := startedInformation{ cmdName: "find", requestID: wiremessage.NextRequestID(), @@ -573,6 +580,11 @@ func (op Operation) createLegacyListIndexesWiremessage(dst []byte, desc descript if err != nil { return dst, info, "", err } + // If maxTimeMS is greater than 0 append it to wire message. A maxTimeMS value of 0 only explicitly + // specifies the default behavior of no timeout server-side. + if maxTimeMS > 0 { + cmdDoc = bsoncore.AppendInt64Element(cmdDoc, "maxTimeMS", int64(maxTimeMS)) + } cmdDoc, _ = bsoncore.AppendDocumentEnd(cmdDoc, cmdIndex) info.cmd, err = op.convertCommandToFind(cmdDoc, listIndexesNamespace) if err != nil { diff --git a/x/mongo/driver/operation_test.go b/x/mongo/driver/operation_test.go index 443fd44e53..2734888328 100644 --- a/x/mongo/driver/operation_test.go +++ b/x/mongo/driver/operation_test.go @@ -282,6 +282,76 @@ func TestOperation(t *testing.T) { } }) }) + t.Run("calculateMaxTimeMS", func(t *testing.T) { + timeout := 5 * time.Second + maxTime := 2 * time.Second + negMaxTime := -2 * time.Second + shortRTT := 50 * time.Millisecond + longRTT := 10 * time.Second + timeoutCtx, cancel := internal.MakeTimeoutContext(context.Background(), timeout) + defer cancel() + + testCases := []struct { + name string + op Operation + ctx context.Context + rtt90 time.Duration + want uint64 + err error + }{ + { + name: "uses context deadline and rtt90 with timeout", + op: Operation{MaxTime: &maxTime}, + ctx: timeoutCtx, + rtt90: shortRTT, + want: 5000, + err: nil, + }, + { + name: "uses MaxTime without timeout", + op: Operation{MaxTime: &maxTime}, + ctx: context.Background(), + rtt90: longRTT, + want: 2000, + err: nil, + }, + { + name: "errors when remaining timeout is less than rtt90", + op: Operation{MaxTime: &maxTime}, + ctx: timeoutCtx, + rtt90: timeout, + want: 0, + err: ErrDeadlineWouldBeExceeded, + }, + { + name: "errors when MaxTime is negative", + op: Operation{MaxTime: &negMaxTime}, + ctx: context.Background(), + rtt90: longRTT, + want: 0, + err: ErrNegativeMaxTime, + }, + } + for _, tc := range testCases { + // Capture test-case for parallel sub-test. + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + got, err := tc.op.calculateMaxTimeMS(tc.ctx, tc.rtt90, "") + + // Assert that the calculated maxTimeMS is less than or equal to the expected value. A few + // milliseconds will have elapsed toward the context deadline, and (remainingTimeout + // - rtt90) will be slightly smaller than the expected value. + if got > tc.want { + t.Errorf("maxTimeMS value higher than expected. got %v; wanted at most %v", got, tc.want) + } + if !errors.Is(err, tc.err) { + t.Errorf("error values do not match. got %v; want %v", err, tc.err) + } + }) + } + }) t.Run("updateClusterTimes", func(t *testing.T) { clustertime := bsoncore.BuildDocumentFromElements(nil, bsoncore.AppendDocumentElement(nil, "$clusterTime", bsoncore.BuildDocumentFromElements(nil,