Skip to content

Commit

Permalink
GODRIVER-2742 Do not perform server selection to determine sessions s…
Browse files Browse the repository at this point in the history
…upport (mongodb#1295)

Co-authored-by: Matt Dale <9760375+matthewdale@users.noreply.github.com>
Co-authored-by: Kevin Albertson <kevin.albertson@10gen.com>
  • Loading branch information
3 people authored Jul 26, 2023
1 parent 5ee10b9 commit 26f508f
Show file tree
Hide file tree
Showing 23 changed files with 1,002 additions and 390 deletions.
37 changes: 37 additions & 0 deletions internal/assert/assertions.go
Original file line number Diff line number Diff line change
Expand Up @@ -899,6 +899,29 @@ func EqualError(t TestingT, theError error, errString string, msgAndArgs ...inte
return true
}

// ErrorIs asserts that at least one of the errors in err's chain matches target.
// This is a wrapper for errors.Is.
func ErrorIs(t TestingT, err, target error, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if errors.Is(err, target) {
return true
}

var expectedText string
if target != nil {
expectedText = target.Error()
}

chain := buildErrorChainString(err)

return Fail(t, fmt.Sprintf("Target error should be in err chain:\n"+
"expected: %q\n"+
"in chain: %s", expectedText, chain,
), msgAndArgs...)
}

// ErrorContains asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
Expand Down Expand Up @@ -1036,3 +1059,17 @@ func Eventually(t TestingT, condition func() bool, waitFor time.Duration, tick t
}
}
}

func buildErrorChainString(err error) string {
if err == nil {
return ""
}

e := errors.Unwrap(err)
chain := fmt.Sprintf("%q", err.Error())
for e != nil {
chain += fmt.Sprintf("\n\t%q", e.Error())
e = errors.Unwrap(e)
}
return chain
}
39 changes: 39 additions & 0 deletions internal/ptrutil/int64.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package ptrutil

// CompareInt64 is a piecewise function with the following return conditions:
//
// (1) 2, ptr1 != nil AND ptr2 == nil
// (2) 1, *ptr1 > *ptr2
// (3) 0, ptr1 == ptr2 or *ptr1 == *ptr2
// (4) -1, *ptr1 < *ptr2
// (5) -2, ptr1 == nil AND ptr2 != nil
func CompareInt64(ptr1, ptr2 *int64) int {
if ptr1 == ptr2 {
// This will catch the double nil or same-pointer cases.
return 0
}

if ptr1 == nil && ptr2 != nil {
return -2
}

if ptr1 != nil && ptr2 == nil {
return 2
}

if *ptr1 > *ptr2 {
return 1
}

if *ptr1 < *ptr2 {
return -1
}

return 0
}
76 changes: 76 additions & 0 deletions internal/ptrutil/int64_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package ptrutil

import (
"testing"

"go.mongodb.org/mongo-driver/internal/assert"
)

func TestCompareInt64(t *testing.T) {
t.Parallel()

int64ToPtr := func(i64 int64) *int64 { return &i64 }
int64Ptr := int64ToPtr(1)

tests := []struct {
name string
ptr1, ptr2 *int64
want int
}{
{
name: "empty",
want: 0,
},
{
name: "ptr1 nil",
ptr2: int64ToPtr(1),
want: -2,
},
{
name: "ptr2 nil",
ptr1: int64ToPtr(1),
want: 2,
},
{
name: "ptr1 and ptr2 have same value, different address",
ptr1: int64ToPtr(1),
ptr2: int64ToPtr(1),
want: 0,
},
{
name: "ptr1 and ptr2 have the same address",
ptr1: int64Ptr,
ptr2: int64Ptr,
want: 0,
},
{
name: "ptr1 GT ptr2",
ptr1: int64ToPtr(1),
ptr2: int64ToPtr(0),
want: 1,
},
{
name: "ptr1 LT ptr2",
ptr1: int64ToPtr(0),
ptr2: int64ToPtr(1),
want: -1,
},
}

for _, test := range tests {
test := test // capture the range variable

t.Run(test.name, func(t *testing.T) {
t.Parallel()

got := CompareInt64(test.ptr1, test.ptr2)
assert.Equal(t, test.want, got, "compareInt64() = %v, wanted %v", got, test.want)
})
}
}
65 changes: 35 additions & 30 deletions mongo/description/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/internal"
"go.mongodb.org/mongo-driver/internal/ptrutil"
"go.mongodb.org/mongo-driver/mongo/address"
"go.mongodb.org/mongo-driver/tag"
)
Expand All @@ -31,35 +32,37 @@ type SelectedServer struct {
type Server struct {
Addr address.Address

Arbiters []string
AverageRTT time.Duration
AverageRTTSet bool
Compression []string // compression methods returned by server
CanonicalAddr address.Address
ElectionID primitive.ObjectID
HeartbeatInterval time.Duration
HelloOK bool
Hosts []string
IsCryptd bool
LastError error
LastUpdateTime time.Time
LastWriteTime time.Time
MaxBatchCount uint32
MaxDocumentSize uint32
MaxMessageSize uint32
Members []address.Address
Passives []string
Passive bool
Primary address.Address
ReadOnly bool
ServiceID *primitive.ObjectID // Only set for servers that are deployed behind a load balancer.
SessionTimeoutMinutes uint32
SetName string
SetVersion uint32
Tags tag.Set
TopologyVersion *TopologyVersion
Kind ServerKind
WireVersion *VersionRange
Arbiters []string
AverageRTT time.Duration
AverageRTTSet bool
Compression []string // compression methods returned by server
CanonicalAddr address.Address
ElectionID primitive.ObjectID
HeartbeatInterval time.Duration
HelloOK bool
Hosts []string
IsCryptd bool
LastError error
LastUpdateTime time.Time
LastWriteTime time.Time
MaxBatchCount uint32
MaxDocumentSize uint32
MaxMessageSize uint32
Members []address.Address
Passives []string
Passive bool
Primary address.Address
ReadOnly bool
ServiceID *primitive.ObjectID // Only set for servers that are deployed behind a load balancer.
// Deprecated: Use SessionTimeoutMinutesPtr instead.
SessionTimeoutMinutes uint32
SessionTimeoutMinutesPtr *int64
SetName string
SetVersion uint32
Tags tag.Set
TopologyVersion *TopologyVersion
Kind ServerKind
WireVersion *VersionRange
}

// NewServer creates a new server description from the given hello command response.
Expand Down Expand Up @@ -166,7 +169,9 @@ func NewServer(addr address.Address, response bson.Raw) Server {
desc.LastError = fmt.Errorf("expected 'logicalSessionTimeoutMinutes' to be an integer but it's a BSON %s", element.Value().Type)
return desc
}

desc.SessionTimeoutMinutes = uint32(i64)
desc.SessionTimeoutMinutesPtr = &i64
case "maxBsonObjectSize":
i64, ok := element.Value().AsInt64OK()
if !ok {
Expand Down Expand Up @@ -462,7 +467,7 @@ func (s Server) Equal(other Server) bool {
return false
}

if s.SessionTimeoutMinutes != other.SessionTimeoutMinutes {
if ptrutil.CompareInt64(s.SessionTimeoutMinutesPtr, other.SessionTimeoutMinutesPtr) != 0 {
return false
}

Expand Down
11 changes: 10 additions & 1 deletion mongo/description/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import (
)

func TestServer(t *testing.T) {
int64ToPtr := func(i64 int64) *int64 { return &i64 }

t.Run("equals", func(t *testing.T) {
defaultServer := Server{}
// Only some of the Server fields affect equality
Expand Down Expand Up @@ -46,7 +48,14 @@ func TestServer(t *testing.T) {
{"passive", Server{Passive: true}, true},
{"primary", Server{Primary: address.Address("foo")}, false},
{"readOnly", Server{ReadOnly: true}, true},
{"sessionTimeoutMinutes", Server{SessionTimeoutMinutes: 1}, false},
{
"sessionTimeoutMinutes",
Server{
SessionTimeoutMinutesPtr: int64ToPtr(1),
SessionTimeoutMinutes: 1,
},
false,
},
{"setName", Server{SetName: "foo"}, false},
{"setVersion", Server{SetVersion: 1}, false},
{"tags", Server{Tags: tag.Set{tag.Tag{"foo", "bar"}}}, false},
Expand Down
12 changes: 7 additions & 5 deletions mongo/description/topology.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@ import (

// Topology contains information about a MongoDB cluster.
type Topology struct {
Servers []Server
SetName string
Kind TopologyKind
SessionTimeoutMinutes uint32
CompatibilityErr error
Servers []Server
SetName string
Kind TopologyKind
// Deprecated: Use SessionTimeoutMinutesPtr instead.
SessionTimeoutMinutes uint32
SessionTimeoutMinutesPtr *int64
CompatibilityErr error
}

// String implements the Stringer interface.
Expand Down
32 changes: 21 additions & 11 deletions mongo/integration/mtest/opmsg_deployment.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,28 @@ import (
)

const (
serverAddress = address.Address("localhost:27017")
maxDocumentSize uint32 = 16777216
maxMessageSize uint32 = 48000000
maxBatchCount uint32 = 100000
sessionTimeoutMinutes uint32 = 30
serverAddress = address.Address("localhost:27017")
maxDocumentSize uint32 = 16777216
maxMessageSize uint32 = 48000000
maxBatchCount uint32 = 100000
)

var (
sessionTimeoutMinutes uint32 = 30
sessionTimeoutMinutesInt64 = int64(sessionTimeoutMinutes)

// MockDescription is the server description used for the mock deployment. Each mocked connection returns this
// value from its Description method.
MockDescription = description.Server{
CanonicalAddr: serverAddress,
MaxDocumentSize: maxDocumentSize,
MaxMessageSize: maxMessageSize,
MaxBatchCount: maxBatchCount,
SessionTimeoutMinutes: sessionTimeoutMinutes,
Kind: description.RSPrimary,
CanonicalAddr: serverAddress,
MaxDocumentSize: maxDocumentSize,
MaxMessageSize: maxMessageSize,
MaxBatchCount: maxBatchCount,
// TODO(GODRIVER-2885): This can be removed once legacy
// SessionTimeoutMinutes is removed.
SessionTimeoutMinutes: sessionTimeoutMinutes,
SessionTimeoutMinutesPtr: &sessionTimeoutMinutesInt64,
Kind: description.RSPrimary,
WireVersion: &description.VersionRange{
Max: topology.SupportedWireVersions.Max,
},
Expand Down Expand Up @@ -162,7 +167,12 @@ func (md *mockDeployment) Disconnect(context.Context) error {
func (md *mockDeployment) Subscribe() (*driver.Subscription, error) {
if md.updates == nil {
md.updates = make(chan description.Topology, 1)

md.updates <- description.Topology{
SessionTimeoutMinutesPtr: &sessionTimeoutMinutesInt64,

// TODO(GODRIVER-2885): This can be removed once legacy
// SessionTimeoutMinutes is removed.
SessionTimeoutMinutes: sessionTimeoutMinutes,
}
}
Expand Down
Loading

0 comments on commit 26f508f

Please sign in to comment.