Skip to content

Commit caeb478

Browse files
author
Divjot Arora
authored
GODRIVER-1898 SDAM error handling changes for LB mode (#611)
1 parent 31673c0 commit caeb478

File tree

11 files changed

+371
-47
lines changed

11 files changed

+371
-47
lines changed

event/monitoring.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,9 @@ type PoolEvent struct {
8989
ConnectionID uint64 `json:"connectionId"`
9090
PoolOptions *MonitorPoolOptions `json:"options"`
9191
Reason string `json:"reason"`
92+
// ServerID is only set if the Type is PoolCleared and the server is deployed behind a load balancer. This field
93+
// can be used to distinguish between individual servers in a load balanced deployment.
94+
ServerID *primitive.ObjectID `json:"serverId"`
9295
}
9396

9497
// PoolMonitor is a function that allows the user to gain access to events occurring in the pool

x/mongo/driver/topology/CMAP_spec_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -419,7 +419,7 @@ func runOperationInThread(t *testing.T, operation map[string]interface{}, testIn
419419
}
420420
return c.Close()
421421
case "clear":
422-
s.pool.clear()
422+
s.pool.clear(nil)
423423
case "close":
424424
return s.pool.disconnect(context.Background())
425425
default:

x/mongo/driver/topology/connection.go

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,11 @@ func newConnection(addr address.Address, opts ...ConnectionOption) (*connection,
9191
cancellationListener: internal.NewCancellationListener(),
9292
poolMonitor: cfg.poolMonitor,
9393
}
94+
// Connections to non-load balanced deployments should eagerly set the generation numbers so errors encountered
95+
// at any point during connection establishment can be processed without the connection being considered stale.
96+
if !c.config.loadBalanced {
97+
c.setGenerationNumber()
98+
}
9499
atomic.StoreInt32(&c.connected, initialized)
95100

96101
return c, nil
@@ -104,8 +109,29 @@ func (c *connection) processInitializationError(err error) {
104109

105110
c.connectErr = ConnectionError{Wrapped: err, init: true}
106111
if c.config.errorHandlingCallback != nil {
107-
c.config.errorHandlingCallback(c.connectErr, c.generation)
112+
c.config.errorHandlingCallback(c.connectErr, c.generation, c.desc.ServerID)
113+
}
114+
}
115+
116+
// setGenerationNumber sets the connection's generation number if a callback has been provided to do so in connection
117+
// configuration.
118+
func (c *connection) setGenerationNumber() {
119+
if c.config.getGenerationFn != nil {
120+
c.generation = c.config.getGenerationFn(c.desc.ServerID)
121+
}
122+
}
123+
124+
// hasGenerationNumber returns true if the connection has set its generation number. If so, this indicates that the
125+
// generationNumberFn provided via the connection options has been called exactly once.
126+
func (c *connection) hasGenerationNumber() bool {
127+
if !c.config.loadBalanced {
128+
// The generation is known for all non-LB clusters once the connection object has been created.
129+
return true
108130
}
131+
132+
// For LB clusters, we set the generation after the initial handshake, so we know it's set if the connection
133+
// description has been updated to reflect that it's behind an LB.
134+
return c.desc.LoadBalanced()
109135
}
110136

111137
// connect handles the I/O for a connection. It will dial, configure TLS, and perform
@@ -212,6 +238,13 @@ func (c *connection) connect(ctx context.Context) {
212238
}
213239
}
214240
if err == nil {
241+
// For load-balanced connections, the generation number depends on the server ID, which isn't known until the
242+
// initial MongoDB handshake is done. To account for this, we don't attempt to set the connection's generation
243+
// number unless GetHandshakeInformation succeeds.
244+
if c.config.loadBalanced {
245+
c.setGenerationNumber()
246+
}
247+
215248
// If we successfully finished the first part of the handshake and verified LB state, continue with the rest of
216249
// the handshake.
217250
err = handshaker.FinishHandshake(handshakeCtx, handshakeConn)

x/mongo/driver/topology/connection_options.go

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"net"
77
"time"
88

9+
"go.mongodb.org/mongo-driver/bson/primitive"
910
"go.mongodb.org/mongo-driver/event"
1011
"go.mongodb.org/mongo-driver/x/mongo/driver"
1112
"go.mongodb.org/mongo-driver/x/mongo/driver/ocsp"
@@ -35,6 +36,9 @@ var DefaultDialer Dialer = &net.Dialer{}
3536
// initialization. Implementations must be goroutine safe.
3637
type Handshaker = driver.Handshaker
3738

39+
// generationNumberFn is a callback type used by a connection to fetch its generation number given its server ID.
40+
type generationNumberFn func(serverID *primitive.ObjectID) uint64
41+
3842
type connectionConfig struct {
3943
appName string
4044
connectTimeout time.Duration
@@ -51,9 +55,10 @@ type connectionConfig struct {
5155
zstdLevel *int
5256
ocspCache ocsp.Cache
5357
disableOCSPEndpointCheck bool
54-
errorHandlingCallback func(error, uint64)
58+
errorHandlingCallback func(error, uint64, *primitive.ObjectID)
5559
tlsConnectionSource tlsConnectionSource
5660
loadBalanced bool
61+
getGenerationFn generationNumberFn
5762
}
5863

5964
func newConnectionConfig(opts ...ConnectionOption) (*connectionConfig, error) {
@@ -87,7 +92,7 @@ func withTLSConnectionSource(fn func(tlsConnectionSource) tlsConnectionSource) C
8792
}
8893
}
8994

90-
func withErrorHandlingCallback(fn func(error, uint64)) ConnectionOption {
95+
func withErrorHandlingCallback(fn func(error, uint64, *primitive.ObjectID)) ConnectionOption {
9196
return func(c *connectionConfig) error {
9297
c.errorHandlingCallback = fn
9398
return nil
@@ -217,3 +222,10 @@ func WithConnectionLoadBalanced(fn func(bool) bool) ConnectionOption {
217222
return nil
218223
}
219224
}
225+
226+
func withGenerationNumberFn(fn func(generationNumberFn) generationNumberFn) ConnectionOption {
227+
return func(c *connectionConfig) error {
228+
c.getGenerationFn = fn(c.getGenerationFn)
229+
return nil
230+
}
231+
}

x/mongo/driver/topology/connection_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"time"
1818

1919
"github.com/google/go-cmp/cmp"
20+
"go.mongodb.org/mongo-driver/bson/primitive"
2021
"go.mongodb.org/mongo-driver/internal"
2122
"go.mongodb.org/mongo-driver/internal/testutil/assert"
2223
"go.mongodb.org/mongo-driver/mongo/address"
@@ -128,7 +129,7 @@ func TestConnection(t *testing.T) {
128129
return &net.TCPConn{}, nil
129130
})
130131
}),
131-
withErrorHandlingCallback(func(err error, _ uint64) {
132+
withErrorHandlingCallback(func(err error, _ uint64, _ *primitive.ObjectID) {
132133
got = err
133134
}),
134135
)

x/mongo/driver/topology/pool.go

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"sync/atomic"
1414
"time"
1515

16+
"go.mongodb.org/mongo-driver/bson/primitive"
1617
"go.mongodb.org/mongo-driver/event"
1718
"go.mongodb.org/mongo-driver/mongo/address"
1819
"golang.org/x/sync/semaphore"
@@ -60,7 +61,7 @@ type pool struct {
6061
address address.Address
6162
opts []ConnectionOption
6263
conns *resourcePool // pool for non-checked out connections
63-
generation uint64 // must be accessed using atomic package
64+
generation *poolGenerationMap
6465
monitor *event.PoolMonitor
6566

6667
connected int32 // Must be accessed using the sync/atomic package.
@@ -148,13 +149,15 @@ func newPool(config poolConfig, connOpts ...ConnectionOption) (*pool, error) {
148149
}
149150

150151
pool := &pool{
151-
address: config.Address,
152-
monitor: config.PoolMonitor,
153-
connected: disconnected,
154-
opened: make(map[uint64]*connection),
155-
opts: opts,
156-
sem: semaphore.NewWeighted(int64(maxConns)),
152+
address: config.Address,
153+
monitor: config.PoolMonitor,
154+
connected: disconnected,
155+
opened: make(map[uint64]*connection),
156+
opts: opts,
157+
sem: semaphore.NewWeighted(int64(maxConns)),
158+
generation: newPoolGenerationMap(),
157159
}
160+
pool.opts = append(pool.opts, withGenerationNumberFn(func(_ generationNumberFn) generationNumberFn { return pool.getGenerationForNewConnection }))
158161

159162
// we do not pass in config.MaxPoolSize because we manage the max size at this level rather than the resource pool level
160163
rpc := resourcePoolConfig{
@@ -189,14 +192,15 @@ func newPool(config poolConfig, connOpts ...ConnectionOption) (*pool, error) {
189192

190193
// stale reports whether a given connection is nil or its generation is stale for its server ID according to the pool's generation map
191194
func (p *pool) stale(c *connection) bool {
192-
return c == nil || c.generation < atomic.LoadUint64(&p.generation)
195+
return c == nil || p.generation.stale(c.desc.ServerID, c.generation)
193196
}
194197

195198
// connect puts the pool into the connected state, allowing it to be used and allowing items to begin being processed from the wait queue
196199
func (p *pool) connect() error {
197200
if !atomic.CompareAndSwapInt32(&p.connected, disconnected, connected) {
198201
return ErrPoolConnected
199202
}
203+
p.generation.connect()
200204
p.conns.initialize()
201205
return nil
202206
}
@@ -212,7 +216,7 @@ func (p *pool) disconnect(ctx context.Context) error {
212216
}
213217

214218
p.conns.Close()
215-
atomic.AddUint64(&p.generation, 1)
219+
p.generation.disconnect()
216220

217221
var err error
218222
if dl, ok := ctx.Deadline(); ok {
@@ -277,7 +281,6 @@ func (p *pool) makeNewConnection() (*connection, string, error) {
277281

278282
c.pool = p
279283
c.poolID = atomic.AddUint64(&p.nextid, 1)
280-
c.generation = atomic.LoadUint64(&p.generation)
281284

282285
if p.monitor != nil {
283286
p.monitor.Event(&event.PoolEvent{
@@ -310,10 +313,6 @@ func (p *pool) makeNewConnection() (*connection, string, error) {
310313

311314
}
312315

313-
func (p *pool) getGeneration() uint64 {
314-
return atomic.LoadUint64(&p.generation)
315-
}
316-
317316
// get checks out and returns a connection from the pool
318317
func (p *pool) get(ctx context.Context) (*connection, error) {
319318
if ctx == nil {
@@ -487,6 +486,10 @@ func (p *pool) closeConnection(c *connection) error {
487486
return nil
488487
}
489488

489+
func (p *pool) getGenerationForNewConnection(serverID *primitive.ObjectID) uint64 {
490+
return p.generation.addConnection(serverID)
491+
}
492+
490493
// removeConnection removes a connection from the pool.
491494
func (p *pool) removeConnection(c *connection, reason string) error {
492495
if c.pool != p {
@@ -501,6 +504,12 @@ func (p *pool) removeConnection(c *connection, reason string) error {
501504
}
502505
p.Unlock()
503506

507+
// Only update the generation numbers map if the connection has retrieved its generation number. Otherwise, we'd
508+
// decrement the count for the generation even though it had never been incremented.
509+
if c.hasGenerationNumber() {
510+
p.generation.removeConnection(c.desc.ServerID)
511+
}
512+
504513
if publishEvent && p.monitor != nil {
505514
c.pool.monitor.Event(&event.PoolEvent{
506515
Type: event.ConnectionClosed,
@@ -545,12 +554,13 @@ func (p *pool) put(c *connection) error {
545554
}
546555

547556
// clear clears the pool by incrementing the generation
548-
func (p *pool) clear() {
557+
func (p *pool) clear(serverID *primitive.ObjectID) {
549558
if p.monitor != nil {
550559
p.monitor.Event(&event.PoolEvent{
551-
Type: event.PoolCleared,
552-
Address: p.address.String(),
560+
Type: event.PoolCleared,
561+
Address: p.address.String(),
562+
ServerID: serverID,
553563
})
554564
}
555-
atomic.AddUint64(&p.generation, 1)
565+
p.generation.clear(serverID)
556566
}

0 commit comments

Comments
 (0)