diff --git a/x/mongo/driver/driver.go b/x/mongo/driver/driver.go index ac28eb830d..0093980064 100644 --- a/x/mongo/driver/driver.go +++ b/x/mongo/driver/driver.go @@ -3,6 +3,7 @@ package driver // import "go.mongodb.org/mongo-driver/x/mongo/driver" import ( "context" + "go.mongodb.org/mongo-driver/x/network/address" "go.mongodb.org/mongo-driver/x/network/description" ) @@ -25,4 +26,12 @@ type Connection interface { Description() description.Server Close() error ID() string + Address() address.Address +} + +// ErrorProcessor implementations can handle processing errors, which may modify their internal state. +// If this type is implemented by a Server, then Operation.Execute will call it's ProcessError +// method after it decodes a wire message. +type ErrorProcessor interface { + ProcessError(error) } diff --git a/x/mongo/driverlegacy/topology/DESIGN.md b/x/mongo/driverlegacy/topology/DESIGN.md new file mode 100644 index 0000000000..69abe5850f --- /dev/null +++ b/x/mongo/driverlegacy/topology/DESIGN.md @@ -0,0 +1,16 @@ +# Topology Package Design +This document outlines the design for this package. + +## Connection +Connections are handled by two main types and an auxiliary type. The two main types are `connection` +and `Connection`. The first holds most of the logic required to actually read and write wire +messages. Instances can be created with the `newConnection` method. Inside the `newConnection` +method the auxiliary type, `initConnection` is used to perform the connection handshake. This is +required because the `connection` type does not fully implement `driver.Connection` which is +required during handshaking. The `Connection` type is what is actually returned to a consumer of the +`topology` package. This type does implement the `driver.Connection` type, holds a reference to a +`connection` instance, and exists mainly to prevent accidental continued usage of a connection after +closing it. + +The connection implementations in this package are conduits for wire messages but they have no +ability to encode, decode, or validate wire messages. That must be handled by consumers. diff --git a/x/mongo/driverlegacy/topology/connection.go b/x/mongo/driverlegacy/topology/connection.go index 9d164285aa..0756023902 100644 --- a/x/mongo/driverlegacy/topology/connection.go +++ b/x/mongo/driverlegacy/topology/connection.go @@ -8,21 +8,319 @@ package topology import ( "context" + "crypto/tls" + "errors" + "fmt" + "io" "net" + "sync" + "sync/atomic" + "time" "strings" + "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/network/address" "go.mongodb.org/mongo-driver/x/network/command" - "go.mongodb.org/mongo-driver/x/network/connection" + connectionlegacy "go.mongodb.org/mongo-driver/x/network/connection" "go.mongodb.org/mongo-driver/x/network/description" "go.mongodb.org/mongo-driver/x/network/wiremessage" ) +// ErrConnectionClosed is returned when attempting to call a method on a Connection that has already +// been closed. +var ErrConnectionClosed = errors.New("the Connection is closed") + +var globalConnectionID uint64 + +func nextConnectionID() uint64 { return atomic.AddUint64(&globalConnectionID, 1) } + +type connection struct { + id string + nc net.Conn // When nil, the connection is closed. + addr address.Address + idleTimeout time.Duration + idleDeadline time.Time + lifetimeDeadline time.Time + readTimeout time.Duration + writeTimeout time.Duration + desc description.Server +} + +// newConnection handles the creation of a connection. It will dial, configure TLS, and perform +// initialization handshakes. +func newConnection(ctx context.Context, addr address.Address, opts ...ConnectionOption) (*connection, error) { + cfg, err := newConnectionConfig(opts...) + if err != nil { + return nil, err + } + + nc, err := cfg.dialer.DialContext(ctx, addr.Network(), addr.String()) + if err != nil { + return nil, err + } + + if cfg.tlsConfig != nil { + tlsConfig := cfg.tlsConfig.Clone() + nc, err = configureTLS(ctx, nc, addr, tlsConfig) + if err != nil { + return nil, err + } + } + + var lifetimeDeadline time.Time + if cfg.lifeTimeout > 0 { + lifetimeDeadline = time.Now().Add(cfg.lifeTimeout) + } + + id := fmt.Sprintf("%s[-%d]", addr, nextConnectionID()) + + c := &connection{ + id: id, + nc: nc, + addr: addr, + idleTimeout: cfg.idleTimeout, + lifetimeDeadline: lifetimeDeadline, + readTimeout: cfg.readTimeout, + writeTimeout: cfg.writeTimeout, + } + + c.bumpIdleDeadline() + + // running isMaster and authentication is handled by a handshaker on the configuration instance. + if cfg.handshaker != nil { + c.desc, err = cfg.handshaker.Handshake(ctx, c.addr, initConnection{c}) + if err != nil { + return nil, err + } + } + return c, nil +} + +func (c *connection) writeWireMessage(ctx context.Context, wm []byte) error { + var err error + if c.nc == nil { + return ConnectionError{ConnectionID: c.id, message: "connection is closed"} + } + select { + case <-ctx.Done(): + return ConnectionError{ConnectionID: c.id, Wrapped: ctx.Err(), message: "failed to write"} + default: + } + + var deadline time.Time + if c.writeTimeout != 0 { + deadline = time.Now().Add(c.writeTimeout) + } + + if dl, ok := ctx.Deadline(); ok && (deadline.IsZero() || dl.Before(deadline)) { + deadline = dl + } + + if err := c.nc.SetWriteDeadline(deadline); err != nil { + return ConnectionError{ConnectionID: c.id, Wrapped: err, message: "failed to set write deadline"} + } + + _, err = c.nc.Write(wm) + if err != nil { + // TODO(GODRIVER-929): Close connection through the pool. + _ = c.nc.Close() + c.nc = nil + return ConnectionError{ConnectionID: c.id, Wrapped: err, message: "unable to write wire message to network"} + } + + c.bumpIdleDeadline() + return nil +} + +// readWireMessage reads a wiremessage from the connection. The dst parameter will be overwritten. +func (c *connection) readWireMessage(ctx context.Context, dst []byte) ([]byte, error) { + if c.nc == nil { + return dst, ConnectionError{ConnectionID: c.id, message: "connection is closed"} + } + + select { + case <-ctx.Done(): + // We close the connection because we don't know if there is an unread message on the wire. + // TODO(GODRIVER-929): Close connection through the pool. + _ = c.nc.Close() + c.nc = nil + return nil, ConnectionError{ConnectionID: c.id, Wrapped: ctx.Err(), message: "failed to read"} + default: + } + + var deadline time.Time + if c.readTimeout != 0 { + deadline = time.Now().Add(c.readTimeout) + } + + if dl, ok := ctx.Deadline(); ok && (deadline.IsZero() || dl.Before(deadline)) { + deadline = dl + } + + if err := c.nc.SetReadDeadline(deadline); err != nil { + return nil, ConnectionError{ConnectionID: c.id, Wrapped: err, message: "failed to set read deadline"} + } + + // We use an array here because it only costs 4 bytes on the stack and means we'll only need to + // reslice dst once instead of twice. + var sizeBuf [4]byte + + // We do a ReadFull into an array here instead of doing an opportunistic ReadAtLeast into dst + // because there might be more than one wire message waiting to be read, for example when + // reading messages from an exhaust cursor. + _, err := io.ReadFull(c.nc, sizeBuf[:]) + if err != nil { + // We close the connection because we don't know if there are other bytes left to read. + // TODO(GODRIVER-929): Close connection through the pool. + _ = c.nc.Close() + c.nc = nil + return nil, ConnectionError{ConnectionID: c.id, Wrapped: err, message: "unable to decode message length"} + } + + // read the length as an int32 + size := (int32(sizeBuf[0])) | (int32(sizeBuf[1]) << 8) | (int32(sizeBuf[2]) << 16) | (int32(sizeBuf[3]) << 24) + + if int(size) > cap(dst) { + // Since we can't grow this slice without allocating, just allocate an entirely new slice. + dst = make([]byte, 0, size) + } + // We need to ensure we don't accidentally read into a subsequent wire message, so we set the + // size to read exactly this wire message. + dst = dst[:size] + copy(dst, sizeBuf[:]) + + _, err = io.ReadFull(c.nc, dst[4:]) + if err != nil { + // We close the connection because we don't know if there are other bytes left to read. + // TODO(GODRIVER-929): Close connection through the pool. + _ = c.nc.Close() + c.nc = nil + return nil, ConnectionError{ConnectionID: c.id, Wrapped: err, message: "unable to read full message"} + } + + c.bumpIdleDeadline() + return dst, nil +} + +func (c *connection) expired() bool { + now := time.Now() + if !c.idleDeadline.IsZero() && now.After(c.idleDeadline) { + return true + } + + if !c.lifetimeDeadline.IsZero() && now.After(c.lifetimeDeadline) { + return true + } + + return c.nc == nil +} + +func (c *connection) bumpIdleDeadline() { + if c.idleTimeout > 0 { + c.idleDeadline = time.Now().Add(c.idleTimeout) + } +} + +// initConnection is an adapter used during connection initialization. It has the minimum +// functionality necessary to implement the driver.Connection interface, which is required to pass a +// *connection to a Handshaker. +type initConnection struct{ *connection } + +var _ driver.Connection = initConnection{} + +func (c initConnection) Description() description.Server { return description.Server{} } +func (c initConnection) Close() error { return c.nc.Close() } +func (c initConnection) ID() string { return c.id } +func (c initConnection) Address() address.Address { return c.addr } +func (c initConnection) WriteWireMessage(ctx context.Context, wm []byte) error { + return c.writeWireMessage(ctx, wm) +} +func (c initConnection) ReadWireMessage(ctx context.Context, dst []byte) ([]byte, error) { + return c.readWireMessage(ctx, dst) +} + +// Connection implements the driver.Connection interface. It allows reading and writing wire +// messages. +type Connection struct { + *connection + s *Server + + mu sync.RWMutex +} + +var _ driver.Connection = (*Connection)(nil) + +// WriteWireMessage handles writing a wire message to the underlying connection. +func (c *Connection) WriteWireMessage(ctx context.Context, wm []byte) error { + c.mu.RLock() + defer c.mu.RUnlock() + if c.connection == nil { + return ErrConnectionClosed + } + return c.writeWireMessage(ctx, wm) +} + +// ReadWireMessage handles reading a wire message from the underlying connection. The dst parameter +// will be overwritten with the new wire message. +func (c *Connection) ReadWireMessage(ctx context.Context, dst []byte) ([]byte, error) { + c.mu.RLock() + defer c.mu.RUnlock() + if c.connection == nil { + return dst, ErrConnectionClosed + } + return c.readWireMessage(ctx, dst) +} + +// Description returns the server description of the server this connection is connected to. +func (c *Connection) Description() description.Server { + c.mu.RLock() + defer c.mu.RUnlock() + if c.connection == nil { + return description.Server{} + } + return c.desc +} + +// Close returns this connection to the connection pool. This method may not close the underlying +// socket. +func (c *Connection) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + if c.connection == nil { + return nil + } + // TODO(GODRIVER-929): Return c.connection to the pool. + // TODO(GODRIVER-929): Release an entry in the semaphore. + c.connection = nil + return nil +} + +// ID returns the ID of this connection. +func (c *Connection) ID() string { + c.mu.RLock() + defer c.mu.RUnlock() + if c.connection == nil { + return "" + } + return c.id +} + +// Address returns the address of this connection. +func (c *Connection) Address() address.Address { + c.mu.RLock() + defer c.mu.RUnlock() + if c.connection == nil { + return address.Address("0.0.0.0") + } + return c.addr +} + // sconn is a wrapper around a connection.Connection. This type is returned by // a Server so that it can track network errors and when a non-timeout network // error is returned, the pool on the server can be cleared. type sconn struct { - connection.Connection + connectionlegacy.Connection s *Server id uint64 } @@ -60,7 +358,7 @@ func (sc *sconn) processErr(err error) { return } - ne, ok := err.(connection.Error) + ne, ok := err.(connectionlegacy.Error) if !ok { return } @@ -96,3 +394,33 @@ func isNotMasterError(err command.Error) bool { } return strings.Contains(err.Error(), "not master") } + +func configureTLS(ctx context.Context, nc net.Conn, addr address.Address, config *tls.Config) (net.Conn, error) { + if !config.InsecureSkipVerify { + hostname := addr.String() + colonPos := strings.LastIndex(hostname, ":") + if colonPos == -1 { + colonPos = len(hostname) + } + + hostname = hostname[:colonPos] + config.ServerName = hostname + } + + client := tls.Client(nc, config) + + errChan := make(chan error, 1) + go func() { + errChan <- client.Handshake() + }() + + select { + case err := <-errChan: + if err != nil { + return nil, err + } + case <-ctx.Done(): + return nil, errors.New("server connection cancelled/timeout during TLS handshake") + } + return client, nil +} diff --git a/x/mongo/driverlegacy/topology/connection_options.go b/x/mongo/driverlegacy/topology/connection_options.go new file mode 100644 index 0000000000..6bfa2dc9dd --- /dev/null +++ b/x/mongo/driverlegacy/topology/connection_options.go @@ -0,0 +1,159 @@ +package topology + +import ( + "context" + "crypto/tls" + "net" + "time" + + "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/network/address" + "go.mongodb.org/mongo-driver/x/network/description" +) + +// Dialer is used to make network connections. +type Dialer interface { + DialContext(ctx context.Context, network, address string) (net.Conn, error) +} + +// DialerFunc is a type implemented by functions that can be used as a Dialer. +type DialerFunc func(ctx context.Context, network, address string) (net.Conn, error) + +// DialContext implements the Dialer interface. +func (df DialerFunc) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + return df(ctx, network, address) +} + +// DefaultDialer is the Dialer implementation that is used by this package. Changing this +// will also change the Dialer used for this package. This should only be changed why all +// of the connections being made need to use a different Dialer. Most of the time, using a +// WithDialer option is more appropriate than changing this variable. +var DefaultDialer Dialer = &net.Dialer{} + +// Handshaker is the interface implemented by types that can perform a MongoDB +// handshake over a provided driver.Connection. This is used during connection +// initialization. Implementations must be goroutine safe. +type Handshaker interface { + Handshake(context.Context, address.Address, driver.Connection) (description.Server, error) +} + +// HandshakerFunc is an adapter to allow the use of ordinary functions as +// connection handshakers. +type HandshakerFunc func(context.Context, address.Address, driver.Connection) (description.Server, error) + +// Handshake implements the Handshaker interface. +func (hf HandshakerFunc) Handshake(ctx context.Context, addr address.Address, conn driver.Connection) (description.Server, error) { + return hf(ctx, addr, conn) +} + +type connectionConfig struct { + appName string + connectTimeout time.Duration + dialer Dialer + handshaker Handshaker + idleTimeout time.Duration + lifeTimeout time.Duration + readTimeout time.Duration + writeTimeout time.Duration + tlsConfig *tls.Config +} + +func newConnectionConfig(opts ...ConnectionOption) (*connectionConfig, error) { + cfg := &connectionConfig{ + connectTimeout: 30 * time.Second, + dialer: nil, + idleTimeout: 10 * time.Minute, + lifeTimeout: 30 * time.Minute, + } + + for _, opt := range opts { + err := opt(cfg) + if err != nil { + return nil, err + } + } + + if cfg.dialer == nil { + cfg.dialer = &net.Dialer{Timeout: cfg.connectTimeout} + } + + return cfg, nil +} + +// ConnectionOption is used to configure a connection. +type ConnectionOption func(*connectionConfig) error + +// WithAppName sets the application name which gets sent to MongoDB when it +// first connects. +func WithAppName(fn func(string) string) ConnectionOption { + return func(c *connectionConfig) error { + c.appName = fn(c.appName) + return nil + } +} + +// WithConnectTimeout configures the maximum amount of time a dial will wait for a +// connect to complete. The default is 30 seconds. +func WithConnectTimeout(fn func(time.Duration) time.Duration) ConnectionOption { + return func(c *connectionConfig) error { + c.connectTimeout = fn(c.connectTimeout) + return nil + } +} + +// WithDialer configures the Dialer to use when making a new connection to MongoDB. +func WithDialer(fn func(Dialer) Dialer) ConnectionOption { + return func(c *connectionConfig) error { + c.dialer = fn(c.dialer) + return nil + } +} + +// WithHandshaker configures the Handshaker that wll be used to initialize newly +// dialed connections. +func WithHandshaker(fn func(Handshaker) Handshaker) ConnectionOption { + return func(c *connectionConfig) error { + c.handshaker = fn(c.handshaker) + return nil + } +} + +// WithIdleTimeout configures the maximum idle time to allow for a connection. +func WithIdleTimeout(fn func(time.Duration) time.Duration) ConnectionOption { + return func(c *connectionConfig) error { + c.idleTimeout = fn(c.idleTimeout) + return nil + } +} + +// WithLifeTimeout configures the maximum life of a connection. +func WithLifeTimeout(fn func(time.Duration) time.Duration) ConnectionOption { + return func(c *connectionConfig) error { + c.lifeTimeout = fn(c.lifeTimeout) + return nil + } +} + +// WithReadTimeout configures the maximum read time for a connection. +func WithReadTimeout(fn func(time.Duration) time.Duration) ConnectionOption { + return func(c *connectionConfig) error { + c.readTimeout = fn(c.readTimeout) + return nil + } +} + +// WithWriteTimeout configures the maximum write time for a connection. +func WithWriteTimeout(fn func(time.Duration) time.Duration) ConnectionOption { + return func(c *connectionConfig) error { + c.writeTimeout = fn(c.writeTimeout) + return nil + } +} + +// WithTLSConfig configures the TLS options for a connection. +func WithTLSConfig(fn func(*tls.Config) *tls.Config) ConnectionOption { + return func(c *connectionConfig) error { + c.tlsConfig = fn(c.tlsConfig) + return nil + } +} diff --git a/x/mongo/driverlegacy/topology/connection_test.go b/x/mongo/driverlegacy/topology/connection_test.go index fb44675df1..8a427020f0 100644 --- a/x/mongo/driverlegacy/topology/connection_test.go +++ b/x/mongo/driverlegacy/topology/connection_test.go @@ -8,11 +8,16 @@ package topology import ( "context" + "errors" + "net" "testing" + "time" + "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/require" + "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/network/address" - "go.mongodb.org/mongo-driver/x/network/connection" + connectionlegacy "go.mongodb.org/mongo-driver/x/network/connection" "go.mongodb.org/mongo-driver/x/network/description" "go.mongodb.org/mongo-driver/x/network/wiremessage" ) @@ -33,7 +38,7 @@ func (n netErr) Temporary() bool { } type connect struct { - err *connection.Error + err *connectionlegacy.Error } func (c connect) WriteWireMessage(ctx context.Context, wm wiremessage.WireMessage) error { @@ -67,7 +72,7 @@ func TestConnectionProcessErrSpec(t *testing.T) { s.connectionstate = connected innerErr := netErr{} - connectErr := connection.Error{ConnectionID: "blah", Wrapped: innerErr} + connectErr := connectionlegacy.Error{ConnectionID: "blah", Wrapped: innerErr} c := connect{&connectErr} sc := sconn{c, s, 1} err = sc.WriteWireMessage(ctx, nil) @@ -76,3 +81,366 @@ func TestConnectionProcessErrSpec(t *testing.T) { require.NotNil(t, desc.LastError) require.Equal(t, desc.Kind, (description.ServerKind)(description.Unknown)) } + +func TestConnection(t *testing.T) { + t.Run("connection", func(t *testing.T) { + t.Run("newConnection", func(t *testing.T) { + t.Run("config error", func(t *testing.T) { + want := errors.New("config error") + _, got := newConnection(context.Background(), address.Address(""), ConnectionOption(func(*connectionConfig) error { return want })) + if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) { + t.Errorf("errors do not match. got %v; want %v", got, want) + } + }) + t.Run("dialer error", func(t *testing.T) { + want := errors.New("dialer error") + _, got := newConnection(context.Background(), address.Address(""), WithDialer(func(Dialer) Dialer { + return DialerFunc(func(context.Context, string, string) (net.Conn, error) { return nil, want }) + })) + if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) { + t.Errorf("errors do not match. got %v; want %v", got, want) + } + }) + t.Run("handshaker error", func(t *testing.T) { + want := errors.New("handshaker error") + _, got := newConnection(context.Background(), address.Address(""), + WithHandshaker(func(Handshaker) Handshaker { + return HandshakerFunc(func(context.Context, address.Address, driver.Connection) (description.Server, error) { + return description.Server{}, want + }) + }), + WithDialer(func(Dialer) Dialer { + return DialerFunc(func(context.Context, string, string) (net.Conn, error) { + return net.Conn(nil), nil + }) + }), + ) + if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) { + t.Errorf("errors do not match. got %v; want %v", got, want) + } + }) + }) + t.Run("writeWireMessage", func(t *testing.T) { + t.Run("closed connection", func(t *testing.T) { + conn := &connection{id: "foobar"} + want := ConnectionError{ConnectionID: "foobar", message: "connection is closed"} + got := conn.writeWireMessage(context.Background(), []byte{}) + if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) { + t.Errorf("errors do not match. got %v; want %v", got, want) + } + }) + t.Run("completed context", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + conn := &connection{id: "foobar", nc: &net.TCPConn{}} + want := ConnectionError{ConnectionID: "foobar", Wrapped: ctx.Err(), message: "failed to write"} + got := conn.writeWireMessage(ctx, []byte{}) + if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) { + t.Errorf("errors do not match. got %v; want %v", got, want) + } + }) + t.Run("deadlines", func(t *testing.T) { + testCases := []struct { + name string + ctxDeadline time.Duration + timeout time.Duration + deadline time.Time + }{ + {"no deadline", 0, 0, time.Now().Add(1 * time.Second)}, + {"ctx deadline", 5 * time.Second, 0, time.Now().Add(6 * time.Second)}, + {"timeout", 0, 10 * time.Second, time.Now().Add(11 * time.Second)}, + {"both (ctx wins)", 15 * time.Second, 20 * time.Second, time.Now().Add(16 * time.Second)}, + {"both (timeout wins)", 30 * time.Second, 25 * time.Second, time.Now().Add(26 * time.Second)}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + if tc.ctxDeadline > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, tc.ctxDeadline) + defer cancel() + } + want := ConnectionError{ + ConnectionID: "foobar", + Wrapped: errors.New("set writeDeadline error"), + message: "failed to set write deadline", + } + tnc := &testNetConn{deadlineerr: errors.New("set writeDeadline error")} + conn := &connection{id: "foobar", nc: tnc, writeTimeout: tc.timeout} + got := conn.writeWireMessage(ctx, []byte{}) + if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) { + t.Errorf("errors do not match. got %v; want %v", got, want) + } + if !tc.deadline.After(tnc.writeDeadline) { + t.Errorf("write deadline not properly set. got %v; want %v", tnc.writeDeadline, tc.deadline) + } + }) + } + }) + t.Run("Write", func(t *testing.T) { + t.Run("error", func(t *testing.T) { + err := errors.New("Write error") + want := ConnectionError{ConnectionID: "foobar", Wrapped: err, message: "unable to write wire message to network"} + tnc := &testNetConn{writeerr: err} + conn := &connection{id: "foobar", nc: tnc} + got := conn.writeWireMessage(context.Background(), []byte{}) + if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) { + t.Errorf("errors do not match. got %v; want %v", got, want) + } + if !tnc.closed { + t.Errorf("failed to close net.Conn after error writing bytes.") + } + }) + tnc := &testNetConn{} + conn := &connection{id: "foobar", nc: tnc} + want := []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A} + err := conn.writeWireMessage(context.Background(), want) + noerr(t, err) + got := tnc.buf + if !cmp.Equal(got, want) { + t.Errorf("writeWireMessage did not write the proper bytes. got %v; want %v", got, want) + } + }) + }) + t.Run("readWireMessage", func(t *testing.T) { + t.Run("closed connection", func(t *testing.T) { + conn := &connection{id: "foobar"} + want := ConnectionError{ConnectionID: "foobar", message: "connection is closed"} + _, got := conn.readWireMessage(context.Background(), []byte{}) + if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) { + t.Errorf("errors do not match. got %v; want %v", got, want) + } + }) + t.Run("completed context", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + conn := &connection{id: "foobar", nc: &net.TCPConn{}} + want := ConnectionError{ConnectionID: "foobar", Wrapped: ctx.Err(), message: "failed to read"} + _, got := conn.readWireMessage(ctx, []byte{}) + if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) { + t.Errorf("errors do not match. got %v; want %v", got, want) + } + }) + t.Run("deadlines", func(t *testing.T) { + testCases := []struct { + name string + ctxDeadline time.Duration + timeout time.Duration + deadline time.Time + }{ + {"no deadline", 0, 0, time.Now().Add(1 * time.Second)}, + {"ctx deadline", 5 * time.Second, 0, time.Now().Add(6 * time.Second)}, + {"timeout", 0, 10 * time.Second, time.Now().Add(11 * time.Second)}, + {"both (ctx wins)", 15 * time.Second, 20 * time.Second, time.Now().Add(16 * time.Second)}, + {"both (timeout wins)", 30 * time.Second, 25 * time.Second, time.Now().Add(26 * time.Second)}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + if tc.ctxDeadline > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, tc.ctxDeadline) + defer cancel() + } + want := ConnectionError{ + ConnectionID: "foobar", + Wrapped: errors.New("set readDeadline error"), + message: "failed to set read deadline", + } + tnc := &testNetConn{deadlineerr: errors.New("set readDeadline error")} + conn := &connection{id: "foobar", nc: tnc, readTimeout: tc.timeout} + _, got := conn.readWireMessage(ctx, []byte{}) + if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) { + t.Errorf("errors do not match. got %v; want %v", got, want) + } + if !tc.deadline.After(tnc.readDeadline) { + t.Errorf("read deadline not properly set. got %v; want %v", tnc.readDeadline, tc.deadline) + } + }) + } + }) + t.Run("Read (size)", func(t *testing.T) { + err := errors.New("Read error") + want := ConnectionError{ConnectionID: "foobar", Wrapped: err, message: "unable to decode message length"} + tnc := &testNetConn{readerr: err} + conn := &connection{id: "foobar", nc: tnc} + _, got := conn.readWireMessage(context.Background(), []byte{}) + if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) { + t.Errorf("errors do not match. got %v; want %v", got, want) + } + if !tnc.closed { + t.Errorf("failed to close net.Conn after error writing bytes.") + } + }) + t.Run("Read (wire message)", func(t *testing.T) { + err := errors.New("Read error") + want := ConnectionError{ConnectionID: "foobar", Wrapped: err, message: "unable to read full message"} + tnc := &testNetConn{readerr: err, buf: []byte{0x11, 0x00, 0x00, 0x00}} + conn := &connection{id: "foobar", nc: tnc} + _, got := conn.readWireMessage(context.Background(), []byte{}) + if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) { + t.Errorf("errors do not match. got %v; want %v", got, want) + } + if !tnc.closed { + t.Errorf("failed to close net.Conn after error writing bytes.") + } + }) + t.Run("Read (success)", func(t *testing.T) { + want := []byte{0x0A, 0x00, 0x00, 0x00, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A} + tnc := &testNetConn{buf: make([]byte, len(want))} + copy(tnc.buf, want) + conn := &connection{id: "foobar", nc: tnc} + got, err := conn.readWireMessage(context.Background(), nil) + noerr(t, err) + if !cmp.Equal(got, want) { + t.Errorf("did not read full wire message. got %v; want %v", got, want) + } + }) + }) + }) + t.Run("Connection", func(t *testing.T) { + t.Run("nil connection does not panic", func(t *testing.T) { + conn := &Connection{} + defer func() { + if r := recover(); r != nil { + t.Fatalf("Methods on a Connection with a nil *connection should not panic, but panicked with %v", r) + } + }() + + var want, got interface{} + + want = ErrConnectionClosed + got = conn.WriteWireMessage(context.Background(), nil) + if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) { + t.Errorf("errors do not match. got %v; want %v", got, want) + } + _, got = conn.ReadWireMessage(context.Background(), nil) + if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) { + t.Errorf("errors do not match. got %v; want %v", got, want) + } + + want = description.Server{} + got = conn.Description() + if !cmp.Equal(got, want) { + t.Errorf("descriptions do not match. got %v; want %v", got, want) + } + + want = nil + got = conn.Close() + if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) { + t.Errorf("errors do not match. got %v; want %v", got, want) + } + + want = "" + got = conn.ID() + if !cmp.Equal(got, want) { + t.Errorf("IDs do not match. got %v; want %v", got, want) + } + + want = address.Address("0.0.0.0") + got = conn.Address() + if !cmp.Equal(got, want) { + t.Errorf("Addresses do not match. got %v; want %v", got, want) + } + }) + }) +} + +type testNetConn struct { + nc net.Conn + buf []byte + + deadlineerr error + writeerr error + readerr error + closed bool + + deadline time.Time + readDeadline time.Time + writeDeadline time.Time +} + +func (tnc *testNetConn) Read(b []byte) (n int, err error) { + if len(tnc.buf) > 0 { + n := copy(b, tnc.buf) + tnc.buf = tnc.buf[n:] + return n, nil + } + if tnc.readerr != nil { + return 0, tnc.readerr + } + if tnc.nc == nil { + return 0, nil + } + return tnc.nc.Read(b) +} + +func (tnc *testNetConn) Write(b []byte) (n int, err error) { + if tnc.writeerr != nil { + return 0, tnc.writeerr + } + if tnc.nc == nil { + idx := len(tnc.buf) + tnc.buf = append(tnc.buf, make([]byte, len(b))...) + copy(tnc.buf[idx:], b) + return len(b), nil + } + return tnc.nc.Write(b) +} + +func (tnc *testNetConn) Close() error { + tnc.closed = true + if tnc.nc == nil { + return nil + } + return tnc.nc.Close() +} + +func (tnc *testNetConn) LocalAddr() net.Addr { + if tnc.nc == nil { + return nil + } + return tnc.nc.LocalAddr() +} + +func (tnc *testNetConn) RemoteAddr() net.Addr { + if tnc.nc == nil { + return nil + } + return tnc.nc.RemoteAddr() +} + +func (tnc *testNetConn) SetDeadline(t time.Time) error { + tnc.deadline = t + if tnc.deadlineerr != nil { + return tnc.deadlineerr + } + if tnc.nc == nil { + return nil + } + return tnc.nc.SetDeadline(t) +} + +func (tnc *testNetConn) SetReadDeadline(t time.Time) error { + tnc.readDeadline = t + if tnc.deadlineerr != nil { + return tnc.deadlineerr + } + if tnc.nc == nil { + return nil + } + return tnc.nc.SetReadDeadline(t) +} + +func (tnc *testNetConn) SetWriteDeadline(t time.Time) error { + tnc.writeDeadline = t + if tnc.deadlineerr != nil { + return tnc.deadlineerr + } + if tnc.nc == nil { + return nil + } + return tnc.nc.SetWriteDeadline(t) +} diff --git a/x/mongo/driverlegacy/topology/errors.go b/x/mongo/driverlegacy/topology/errors.go new file mode 100644 index 0000000000..a6fbf12685 --- /dev/null +++ b/x/mongo/driverlegacy/topology/errors.go @@ -0,0 +1,19 @@ +package topology + +import "fmt" + +// ConnectionError represents a connection error. +type ConnectionError struct { + ConnectionID string + Wrapped error + + message string +} + +// Error implements the error interface. +func (e ConnectionError) Error() string { + if e.Wrapped != nil { + return fmt.Sprintf("connection(%s) %s: %s", e.ConnectionID, e.message, e.Wrapped.Error()) + } + return fmt.Sprintf("connection(%s) %s", e.ConnectionID, e.message) +} diff --git a/x/mongo/driverlegacy/topology/server.go b/x/mongo/driverlegacy/topology/server.go index 8dccbc5fea..71c8f77496 100644 --- a/x/mongo/driverlegacy/topology/server.go +++ b/x/mongo/driverlegacy/topology/server.go @@ -20,7 +20,7 @@ import ( "go.mongodb.org/mongo-driver/x/mongo/driverlegacy/auth" "go.mongodb.org/mongo-driver/x/network/address" "go.mongodb.org/mongo-driver/x/network/command" - "go.mongodb.org/mongo-driver/x/network/connection" + connectionlegacy "go.mongodb.org/mongo-driver/x/network/connection" "go.mongodb.org/mongo-driver/x/network/description" "go.mongodb.org/mongo-driver/x/network/result" ) @@ -87,7 +87,7 @@ type Server struct { done chan struct{} checkNow chan struct{} closewg sync.WaitGroup - pool connection.Pool + pool connectionlegacy.Pool desc atomic.Value // holds a description.Server @@ -144,7 +144,7 @@ func NewServer(addr address.Address, topo func(description.Server), opts ...Serv maxConns = uint64(cfg.maxConns) } - s.pool, err = connection.NewPool(addr, uint64(cfg.maxIdleConns), maxConns, cfg.connectionOpts...) + s.pool, err = connectionlegacy.NewPool(addr, uint64(cfg.maxIdleConns), maxConns, cfg.connectionOpts...) if err != nil { return nil, err } @@ -195,7 +195,7 @@ func (s *Server) Disconnect(ctx context.Context) error { } // Connection gets a connection to the server. -func (s *Server) Connection(ctx context.Context) (connection.Connection, error) { +func (s *Server) Connection(ctx context.Context) (connectionlegacy.Connection, error) { if atomic.LoadInt32(&s.connectionstate) != connected { return nil, ErrServerClosed } @@ -205,7 +205,7 @@ func (s *Server) Connection(ctx context.Context) (connection.Connection, error) // authentication error --> drain connection _ = s.pool.Drain() } - if _, ok := err.(*connection.NetworkError); ok { + if _, ok := err.(*connectionlegacy.NetworkError); ok { // update description to unknown and clears the connection pool if desc != nil { desc.Kind = description.Unknown @@ -326,7 +326,7 @@ func (s *Server) update() { } }() - var conn connection.Connection + var conn connectionlegacy.Connection var desc description.Server desc, conn = s.heartbeat(nil) @@ -406,7 +406,7 @@ func (s *Server) updateDescription(desc description.Server, initial bool) { } // heartbeat sends a heartbeat to the server using the given connection. The connection can be nil. -func (s *Server) heartbeat(conn connection.Connection) (description.Server, connection.Connection) { +func (s *Server) heartbeat(conn connectionlegacy.Connection) (description.Server, connectionlegacy.Connection) { const maxRetry = 2 var saved error var desc description.Server @@ -421,30 +421,30 @@ func (s *Server) heartbeat(conn connection.Connection) (description.Server, conn } if conn == nil { - opts := []connection.Option{ - connection.WithConnectTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }), - connection.WithReadTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }), - connection.WithWriteTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }), + opts := []connectionlegacy.Option{ + connectionlegacy.WithConnectTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }), + connectionlegacy.WithReadTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }), + connectionlegacy.WithWriteTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }), } opts = append(opts, s.cfg.connectionOpts...) // We override whatever handshaker is currently attached to the options with an empty // one because need to make sure we don't do auth. - opts = append(opts, connection.WithHandshaker(func(h connection.Handshaker) connection.Handshaker { + opts = append(opts, connectionlegacy.WithHandshaker(func(h connectionlegacy.Handshaker) connectionlegacy.Handshaker { return nil })) // Override any command monitors specified in options with nil to avoid monitoring heartbeats. - opts = append(opts, connection.WithMonitor(func(*event.CommandMonitor) *event.CommandMonitor { + opts = append(opts, connectionlegacy.WithMonitor(func(*event.CommandMonitor) *event.CommandMonitor { return nil })) - conn, _, err = connection.New(ctx, s.address, opts...) + conn, _, err = connectionlegacy.New(ctx, s.address, opts...) if err != nil { saved = err if conn != nil { conn.Close() } conn = nil - if _, ok := err.(*connection.NetworkError); ok { + if _, ok := err.(*connectionlegacy.NetworkError); ok { _ = s.pool.Drain() // If the server is not connected, give up and exit loop if s.Description().Kind == description.Unknown { @@ -464,7 +464,7 @@ func (s *Server) heartbeat(conn connection.Connection) (description.Server, conn saved = err conn.Close() conn = nil - if _, ok := err.(connection.NetworkError); ok { + if _, ok := err.(connectionlegacy.NetworkError); ok { _ = s.pool.Drain() // If the server is not connected, give up and exit loop if s.Description().Kind == description.Unknown { diff --git a/x/mongo/driverlegacy/topology/server_options.go b/x/mongo/driverlegacy/topology/server_options.go index abd0746643..1c0bb2ed9d 100644 --- a/x/mongo/driverlegacy/topology/server_options.go +++ b/x/mongo/driverlegacy/topology/server_options.go @@ -12,7 +12,7 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/bsoncodec" "go.mongodb.org/mongo-driver/x/mongo/driverlegacy/session" - "go.mongodb.org/mongo-driver/x/network/connection" + connectionlegacy "go.mongodb.org/mongo-driver/x/network/connection" ) var defaultRegistry = bson.NewRegistryBuilder().Build() @@ -20,7 +20,7 @@ var defaultRegistry = bson.NewRegistryBuilder().Build() type serverConfig struct { clock *session.ClusterClock compressionOpts []string - connectionOpts []connection.Option + connectionOpts []connectionlegacy.Option appname string heartbeatInterval time.Duration heartbeatTimeout time.Duration @@ -52,7 +52,7 @@ func newServerConfig(opts ...ServerOption) (*serverConfig, error) { type ServerOption func(*serverConfig) error // WithConnectionOptions configures the server's connections. -func WithConnectionOptions(fn func(...connection.Option) []connection.Option) ServerOption { +func WithConnectionOptions(fn func(...connectionlegacy.Option) []connectionlegacy.Option) ServerOption { return func(cfg *serverConfig) error { cfg.connectionOpts = fn(cfg.connectionOpts...) return nil diff --git a/x/mongo/driverlegacy/topology/server_test.go b/x/mongo/driverlegacy/topology/server_test.go index e7cb92ff80..58209b8298 100644 --- a/x/mongo/driverlegacy/topology/server_test.go +++ b/x/mongo/driverlegacy/topology/server_test.go @@ -14,7 +14,7 @@ import ( "github.com/stretchr/testify/require" "go.mongodb.org/mongo-driver/x/mongo/driverlegacy/auth" "go.mongodb.org/mongo-driver/x/network/address" - "go.mongodb.org/mongo-driver/x/network/connection" + connectionlegacy "go.mongodb.org/mongo-driver/x/network/connection" "go.mongodb.org/mongo-driver/x/network/description" "go.mongodb.org/mongo-driver/x/network/result" ) @@ -26,12 +26,12 @@ type pool struct { desc *description.Server } -func (p *pool) Get(ctx context.Context) (connection.Connection, *description.Server, error) { +func (p *pool) Get(ctx context.Context) (connectionlegacy.Connection, *description.Server, error) { if p.connectionError { return nil, p.desc, &auth.Error{} } if p.networkError { - return nil, p.desc, &connection.NetworkError{} + return nil, p.desc, &connectionlegacy.NetworkError{} } return nil, p.desc, nil } @@ -49,7 +49,7 @@ func (p *pool) Drain() error { return nil } -func NewPool(connectionError bool, networkError bool, desc *description.Server) (connection.Pool, error) { +func NewPool(connectionError bool, networkError bool, desc *description.Server) (connectionlegacy.Pool, error) { p := &pool{ connectionError: connectionError, networkError: networkError, diff --git a/x/mongo/driverlegacy/topology/topology_options.go b/x/mongo/driverlegacy/topology/topology_options.go index 2c041f01a2..ba58d9207b 100644 --- a/x/mongo/driverlegacy/topology/topology_options.go +++ b/x/mongo/driverlegacy/topology/topology_options.go @@ -13,7 +13,7 @@ import ( "go.mongodb.org/mongo-driver/x/mongo/driverlegacy/auth" "go.mongodb.org/mongo-driver/x/network/command" - "go.mongodb.org/mongo-driver/x/network/connection" + connectionlegacy "go.mongodb.org/mongo-driver/x/network/connection" "go.mongodb.org/mongo-driver/x/network/connstring" ) @@ -55,10 +55,10 @@ func WithConnString(fn func(connstring.ConnString) connstring.ConnString) Option c.serverSelectionTimeout = cs.ServerSelectionTimeout } - var connOpts []connection.Option + var connOpts []connectionlegacy.Option if cs.AppName != "" { - connOpts = append(connOpts, connection.WithAppName(func(string) string { return cs.AppName })) + connOpts = append(connOpts, connectionlegacy.WithAppName(func(string) string { return cs.AppName })) } switch cs.Connect { @@ -70,14 +70,14 @@ func WithConnString(fn func(connstring.ConnString) connstring.ConnString) Option if cs.ConnectTimeout > 0 { c.serverOpts = append(c.serverOpts, WithHeartbeatTimeout(func(time.Duration) time.Duration { return cs.ConnectTimeout })) - connOpts = append(connOpts, connection.WithConnectTimeout(func(time.Duration) time.Duration { return cs.ConnectTimeout })) + connOpts = append(connOpts, connectionlegacy.WithConnectTimeout(func(time.Duration) time.Duration { return cs.ConnectTimeout })) } if cs.SocketTimeoutSet { connOpts = append( connOpts, - connection.WithReadTimeout(func(time.Duration) time.Duration { return cs.SocketTimeout }), - connection.WithWriteTimeout(func(time.Duration) time.Duration { return cs.SocketTimeout }), + connectionlegacy.WithReadTimeout(func(time.Duration) time.Duration { return cs.SocketTimeout }), + connectionlegacy.WithWriteTimeout(func(time.Duration) time.Duration { return cs.SocketTimeout }), ) } @@ -86,7 +86,7 @@ func WithConnString(fn func(connstring.ConnString) connstring.ConnString) Option } if cs.MaxConnIdleTime > 0 { - connOpts = append(connOpts, connection.WithIdleTimeout(func(time.Duration) time.Duration { return cs.MaxConnIdleTime })) + connOpts = append(connOpts, connectionlegacy.WithIdleTimeout(func(time.Duration) time.Duration { return cs.MaxConnIdleTime })) } if cs.MaxPoolSizeSet { @@ -100,7 +100,7 @@ func WithConnString(fn func(connstring.ConnString) connstring.ConnString) Option var x509Username string if cs.SSL { - tlsConfig := connection.NewTLSConfig() + tlsConfig := connectionlegacy.NewTLSConfig() if cs.SSLCaFileSet { err := tlsConfig.AddCACertFromFile(cs.SSLCaFile) @@ -137,7 +137,7 @@ func WithConnString(fn func(connstring.ConnString) connstring.ConnString) Option x509Username = b.String() } - connOpts = append(connOpts, connection.WithTLSConfig(func(*connection.TLSConfig) *connection.TLSConfig { return tlsConfig })) + connOpts = append(connOpts, connectionlegacy.WithTLSConfig(func(*connectionlegacy.TLSConfig) *connectionlegacy.TLSConfig { return tlsConfig })) } if cs.Username != "" || cs.AuthMechanism == auth.MongoDBX509 || cs.AuthMechanism == auth.GSSAPI { @@ -170,7 +170,7 @@ func WithConnString(fn func(connstring.ConnString) connstring.ConnString) Option return err } - connOpts = append(connOpts, connection.WithHandshaker(func(h connection.Handshaker) connection.Handshaker { + connOpts = append(connOpts, connectionlegacy.WithHandshaker(func(h connectionlegacy.Handshaker) connectionlegacy.Handshaker { options := &auth.HandshakeOptions{ AppName: cs.AppName, Authenticator: authenticator, @@ -184,19 +184,19 @@ func WithConnString(fn func(connstring.ConnString) connstring.ConnString) Option })) } else { // We need to add a non-auth Handshaker to the connection options - connOpts = append(connOpts, connection.WithHandshaker(func(h connection.Handshaker) connection.Handshaker { + connOpts = append(connOpts, connectionlegacy.WithHandshaker(func(h connectionlegacy.Handshaker) connectionlegacy.Handshaker { return &command.Handshake{Client: command.ClientDoc(cs.AppName), Compressors: cs.Compressors} })) } if len(cs.Compressors) > 0 { - connOpts = append(connOpts, connection.WithCompressors(func(compressors []string) []string { + connOpts = append(connOpts, connectionlegacy.WithCompressors(func(compressors []string) []string { return append(compressors, cs.Compressors...) })) for _, comp := range cs.Compressors { if comp == "zlib" { - connOpts = append(connOpts, connection.WithZlibLevel(func(level *int) *int { + connOpts = append(connOpts, connectionlegacy.WithZlibLevel(func(level *int) *int { return &cs.ZlibLevel })) } @@ -208,7 +208,7 @@ func WithConnString(fn func(connstring.ConnString) connstring.ConnString) Option } if len(connOpts) > 0 { - c.serverOpts = append(c.serverOpts, WithConnectionOptions(func(opts ...connection.Option) []connection.Option { + c.serverOpts = append(c.serverOpts, WithConnectionOptions(func(opts ...connectionlegacy.Option) []connectionlegacy.Option { return append(opts, connOpts...) })) } diff --git a/x/mongo/driverlegacy/topology/topology_test.go b/x/mongo/driverlegacy/topology/topology_test.go index 70c31b2e98..ccb1059ebc 100644 --- a/x/mongo/driverlegacy/topology/topology_test.go +++ b/x/mongo/driverlegacy/topology/topology_test.go @@ -27,6 +27,22 @@ func noerr(t *testing.T, err error) { } } +func compareErrors(err1, err2 error) bool { + if err1 == nil && err2 == nil { + return true + } + + if err1 == nil || err2 == nil { + return false + } + + if err1.Error() != err2.Error() { + return false + } + + return true +} + func TestServerSelection(t *testing.T) { var selectFirst description.ServerSelectorFunc = func(_ description.Topology, candidates []description.Server) ([]description.Server, error) { if len(candidates) == 0 {