Skip to content

GODRIVER-1898 SDAM error handling changes for LB mode #611

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions event/monitoring.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ type PoolEvent struct {
ConnectionID uint64 `json:"connectionId"`
PoolOptions *MonitorPoolOptions `json:"options"`
Reason string `json:"reason"`
// ServerID is only set if the Type is PoolCleared and the server is deployed behind a load balancer. This field
// can be used to distinguish between individual servers in a load balanced deployment.
ServerID *primitive.ObjectID `json:"serverId"`
}

// PoolMonitor is a function that allows the user to gain access to events occurring in the pool
Expand Down
2 changes: 1 addition & 1 deletion x/mongo/driver/topology/CMAP_spec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ func runOperationInThread(t *testing.T, operation map[string]interface{}, testIn
}
return c.Close()
case "clear":
s.pool.clear()
s.pool.clear(nil)
case "close":
return s.pool.disconnect(context.Background())
default:
Expand Down
35 changes: 34 additions & 1 deletion x/mongo/driver/topology/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ func newConnection(addr address.Address, opts ...ConnectionOption) (*connection,
cancellationListener: internal.NewCancellationListener(),
poolMonitor: cfg.poolMonitor,
}
// Connections to non-load-balanced deployments should eagerly set their generation number so errors encountered
// at any point during connection establishment can be processed without the connection being considered stale.
if !c.config.loadBalanced {
c.setGenerationNumber()
}
atomic.StoreInt32(&c.connected, initialized)

return c, nil
Expand All @@ -104,8 +109,29 @@ func (c *connection) processInitializationError(err error) {

c.connectErr = ConnectionError{Wrapped: err, init: true}
if c.config.errorHandlingCallback != nil {
c.config.errorHandlingCallback(c.connectErr, c.generation)
c.config.errorHandlingCallback(c.connectErr, c.generation, c.desc.ServerID)
}
}

// setGenerationNumber sets the connection's generation number if a callback has been provided to do so in connection
// configuration.
func (c *connection) setGenerationNumber() {
if c.config.getGenerationFn != nil {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would getGenerationFn ever be nil? The serverID => (generation, count) map is used to determine generation regardless of whether we are in load balanced mode.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For unit tests, it can be. I verified that removing the line in pool.go that sets this function causes some of the new spec tests to fail though.

c.generation = c.config.getGenerationFn(c.desc.ServerID)
}
}

// hasGenerationNumber reports whether the connection has set its generation number. A true result indicates that the
// generationNumberFn provided via the connection options has been called exactly once.
func (c *connection) hasGenerationNumber() bool {
	// Non-LB clusters: the generation is known as soon as the connection object has been created, so the answer is
	// always true and the description is never consulted.
	//
	// LB clusters: the generation is set after the initial handshake, so it is known to be set once the connection
	// description has been updated to reflect that the connection is behind an LB.
	return !c.config.loadBalanced || c.desc.LoadBalanced()
}

// connect handles the I/O for a connection. It will dial, configure TLS, and perform
Expand Down Expand Up @@ -212,6 +238,13 @@ func (c *connection) connect(ctx context.Context) {
}
}
if err == nil {
// For load-balanced connections, the generation number depends on the server ID, which isn't known until the
// initial MongoDB handshake is done. To account for this, we don't attempt to set the connection's generation
// number unless GetHandshakeInformation succeeds.
if c.config.loadBalanced {
c.setGenerationNumber()
}

// If we successfully finished the first part of the handshake and verified LB state, continue with the rest of
// the handshake.
err = handshaker.FinishHandshake(handshakeCtx, handshakeConn)
Expand Down
16 changes: 14 additions & 2 deletions x/mongo/driver/topology/connection_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"net"
"time"

"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/event"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/ocsp"
Expand Down Expand Up @@ -35,6 +36,9 @@ var DefaultDialer Dialer = &net.Dialer{}
// initialization. Implementations must be goroutine safe.
type Handshaker = driver.Handshaker

// generationNumberFn is a callback type used by a connection to fetch its generation number given its server ID. It
// receives the server ID as a *primitive.ObjectID and returns the generation as a uint64.
// NOTE(review): serverID is presumably nil when the deployment is not load balanced — confirm against callers.
type generationNumberFn func(serverID *primitive.ObjectID) uint64

type connectionConfig struct {
appName string
connectTimeout time.Duration
Expand All @@ -51,9 +55,10 @@ type connectionConfig struct {
zstdLevel *int
ocspCache ocsp.Cache
disableOCSPEndpointCheck bool
errorHandlingCallback func(error, uint64)
errorHandlingCallback func(error, uint64, *primitive.ObjectID)
tlsConnectionSource tlsConnectionSource
loadBalanced bool
getGenerationFn generationNumberFn
}

func newConnectionConfig(opts ...ConnectionOption) (*connectionConfig, error) {
Expand Down Expand Up @@ -87,7 +92,7 @@ func withTLSConnectionSource(fn func(tlsConnectionSource) tlsConnectionSource) C
}
}

func withErrorHandlingCallback(fn func(error, uint64)) ConnectionOption {
func withErrorHandlingCallback(fn func(error, uint64, *primitive.ObjectID)) ConnectionOption {
return func(c *connectionConfig) error {
c.errorHandlingCallback = fn
return nil
Expand Down Expand Up @@ -217,3 +222,10 @@ func WithConnectionLoadBalanced(fn func(bool) bool) ConnectionOption {
return nil
}
}

// withGenerationNumberFn returns a ConnectionOption that replaces the configuration's generation-number callback
// with the result of calling fn on the callback currently stored in the configuration. The option itself never
// returns an error.
func withGenerationNumberFn(fn func(generationNumberFn) generationNumberFn) ConnectionOption {
	apply := func(cfg *connectionConfig) error {
		// fn sees the existing callback so it can wrap it or discard it.
		updated := fn(cfg.getGenerationFn)
		cfg.getGenerationFn = updated
		return nil
	}
	return apply
}
3 changes: 2 additions & 1 deletion x/mongo/driver/topology/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"time"

"github.com/google/go-cmp/cmp"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/internal"
"go.mongodb.org/mongo-driver/internal/testutil/assert"
"go.mongodb.org/mongo-driver/mongo/address"
Expand Down Expand Up @@ -128,7 +129,7 @@ func TestConnection(t *testing.T) {
return &net.TCPConn{}, nil
})
}),
withErrorHandlingCallback(func(err error, _ uint64) {
withErrorHandlingCallback(func(err error, _ uint64, _ *primitive.ObjectID) {
got = err
}),
)
Expand Down
46 changes: 28 additions & 18 deletions x/mongo/driver/topology/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"sync/atomic"
"time"

"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/event"
"go.mongodb.org/mongo-driver/mongo/address"
"golang.org/x/sync/semaphore"
Expand Down Expand Up @@ -60,7 +61,7 @@ type pool struct {
address address.Address
opts []ConnectionOption
conns *resourcePool // pool for non-checked out connections
generation uint64 // must be accessed using atomic package
generation *poolGenerationMap
monitor *event.PoolMonitor

connected int32 // Must be accessed using the sync/atomic package.
Expand Down Expand Up @@ -148,13 +149,15 @@ func newPool(config poolConfig, connOpts ...ConnectionOption) (*pool, error) {
}

pool := &pool{
address: config.Address,
monitor: config.PoolMonitor,
connected: disconnected,
opened: make(map[uint64]*connection),
opts: opts,
sem: semaphore.NewWeighted(int64(maxConns)),
address: config.Address,
monitor: config.PoolMonitor,
connected: disconnected,
opened: make(map[uint64]*connection),
opts: opts,
sem: semaphore.NewWeighted(int64(maxConns)),
generation: newPoolGenerationMap(),
}
pool.opts = append(pool.opts, withGenerationNumberFn(func(_ generationNumberFn) generationNumberFn { return pool.getGenerationForNewConnection }))

// we do not pass in config.MaxPoolSize because we manage the max size at this level rather than the resource pool level
rpc := resourcePoolConfig{
Expand Down Expand Up @@ -189,14 +192,15 @@ func newPool(config poolConfig, connOpts ...ConnectionOption) (*pool, error) {

// stale reports whether a given connection is nil or is considered stale by the pool's generation map for the connection's server ID and generation
func (p *pool) stale(c *connection) bool {
return c == nil || c.generation < atomic.LoadUint64(&p.generation)
return c == nil || p.generation.stale(c.desc.ServerID, c.generation)
}

// connect moves the pool from the disconnected state into the connected state, allowing it to be used. It then calls
// connect on the pool's generation map and initialize on the underlying resource pool. If the pool is not currently
// in the disconnected state, nothing is changed and ErrPoolConnected is returned.
func (p *pool) connect() error {
	// The state transition is a single atomic compare-and-swap, so only one caller can win it.
	swapped := atomic.CompareAndSwapInt32(&p.connected, disconnected, connected)
	if !swapped {
		return ErrPoolConnected
	}

	p.generation.connect()
	p.conns.initialize()
	return nil
}
Expand All @@ -212,7 +216,7 @@ func (p *pool) disconnect(ctx context.Context) error {
}

p.conns.Close()
atomic.AddUint64(&p.generation, 1)
p.generation.disconnect()

var err error
if dl, ok := ctx.Deadline(); ok {
Expand Down Expand Up @@ -277,7 +281,6 @@ func (p *pool) makeNewConnection() (*connection, string, error) {

c.pool = p
c.poolID = atomic.AddUint64(&p.nextid, 1)
c.generation = atomic.LoadUint64(&p.generation)

if p.monitor != nil {
p.monitor.Event(&event.PoolEvent{
Expand Down Expand Up @@ -310,10 +313,6 @@ func (p *pool) makeNewConnection() (*connection, string, error) {

}

func (p *pool) getGeneration() uint64 {
return atomic.LoadUint64(&p.generation)
}

// Checkout returns a connection from the pool
func (p *pool) get(ctx context.Context) (*connection, error) {
if ctx == nil {
Expand Down Expand Up @@ -487,6 +486,10 @@ func (p *pool) closeConnection(c *connection) error {
return nil
}

// getGenerationForNewConnection returns the generation number for a new connection to the server identified by
// serverID, as reported by the pool generation map's addConnection. newPool installs this method as the
// generationNumberFn for connections created by the pool.
// NOTE(review): addConnection presumably also counts the connection against that generation (removeConnection
// later calls the map's removeConnection) — confirm in poolGenerationMap.
func (p *pool) getGenerationForNewConnection(serverID *primitive.ObjectID) uint64 {
return p.generation.addConnection(serverID)
}

// removeConnection removes a connection from the pool.
func (p *pool) removeConnection(c *connection, reason string) error {
if c.pool != p {
Expand All @@ -501,6 +504,12 @@ func (p *pool) removeConnection(c *connection, reason string) error {
}
p.Unlock()

// Only update the generation numbers map if the connection has retrieved its generation number. Otherwise, we'd
// decrement the count for the generation even though it had never been incremented.
if c.hasGenerationNumber() {
p.generation.removeConnection(c.desc.ServerID)
}

if publishEvent && p.monitor != nil {
c.pool.monitor.Event(&event.PoolEvent{
Type: event.ConnectionClosed,
Expand Down Expand Up @@ -545,12 +554,13 @@ func (p *pool) put(c *connection) error {
}

// clear clears the pool by incrementing the generation
func (p *pool) clear() {
func (p *pool) clear(serverID *primitive.ObjectID) {
if p.monitor != nil {
p.monitor.Event(&event.PoolEvent{
Type: event.PoolCleared,
Address: p.address.String(),
Type: event.PoolCleared,
Address: p.address.String(),
ServerID: serverID,
})
}
atomic.AddUint64(&p.generation, 1)
p.generation.clear(serverID)
}
Loading