From 8748f69dc04ca3ba3757273fb923c4ff14f12f58 Mon Sep 17 00:00:00 2001 From: Kris Brandow Date: Mon, 15 Apr 2019 11:37:21 -0400 Subject: [PATCH] Update topology.Server type GODRIVER-932 Change-Id: I1234c66952cde375a272052f87f9753ab14e7706 --- .errcheck-excludes | 2 + internal/testutil/config.go | 15 +- internal/testutil/ops.go | 6 +- mongo/change_stream.go | 2 +- mongo/client.go | 39 +- mongo/retryable_writes_test.go | 7 +- mongo/sessions_test.go | 9 +- mongo/transactions_test.go | 2 +- x/mongo/driver/batches.go | 66 ++ x/mongo/driver/command.generated.go | 96 +++ x/mongo/driver/command.go | 75 ++ x/mongo/driver/driver.go | 101 ++- x/mongo/driver/drivertest/channel_conn.go | 72 ++ x/mongo/driver/errors.go | 290 +++++++ x/mongo/driver/ismaster.go | 153 ++++ x/mongo/driver/operation.go | 740 ++++++++++++++++++ x/mongo/driver/wiremessage/wiremessage.go | 236 ++++++ x/mongo/driverlegacy/abort_transaction.go | 2 +- x/mongo/driverlegacy/aggregate.go | 2 +- x/mongo/driverlegacy/auth/auth.go | 69 +- x/mongo/driverlegacy/auth/auth_test.go | 90 ++- x/mongo/driverlegacy/auth/default.go | 6 +- x/mongo/driverlegacy/auth/gssapi.go | 6 +- x/mongo/driverlegacy/auth/mongodbcr.go | 32 +- x/mongo/driverlegacy/auth/mongodbcr_test.go | 81 +- x/mongo/driverlegacy/auth/plain.go | 6 +- x/mongo/driverlegacy/auth/plain_test.go | 110 +-- x/mongo/driverlegacy/auth/sasl.go | 43 +- x/mongo/driverlegacy/auth/scram.go | 6 +- x/mongo/driverlegacy/auth/x509.go | 25 +- x/mongo/driverlegacy/batch_cursor.go | 8 +- x/mongo/driverlegacy/commit_transaction.go | 2 +- x/mongo/driverlegacy/count.go | 2 +- x/mongo/driverlegacy/count_documents.go | 2 +- x/mongo/driverlegacy/create_indexes.go | 2 +- x/mongo/driverlegacy/delete.go | 2 +- x/mongo/driverlegacy/delete_indexes.go | 2 +- x/mongo/driverlegacy/distinct.go | 2 +- x/mongo/driverlegacy/drop_collection.go | 2 +- x/mongo/driverlegacy/drop_database.go | 2 +- x/mongo/driverlegacy/end_sessions.go | 2 +- x/mongo/driverlegacy/find.go | 2 +- x/mongo/driverlegacy/find_one_and_delete.go | 2 +- x/mongo/driverlegacy/find_one_and_replace.go | 2 +- x/mongo/driverlegacy/find_one_and_update.go | 2 +- x/mongo/driverlegacy/insert.go | 2 +- x/mongo/driverlegacy/kill_cursors.go | 2 +- x/mongo/driverlegacy/list_collections.go | 2 +- x/mongo/driverlegacy/list_databases.go | 2 +- x/mongo/driverlegacy/list_indexes.go | 2 +- x/mongo/driverlegacy/read.go | 2 +- x/mongo/driverlegacy/read_cursor.go | 2 +- x/mongo/driverlegacy/topology/DESIGN.md | 3 + x/mongo/driverlegacy/topology/connection.go | 46 +- .../topology/connection_legacy.go | 21 +- .../topology/connection_options.go | 20 +- .../driverlegacy/topology/connection_test.go | 36 +- x/mongo/driverlegacy/topology/errors.go | 3 + x/mongo/driverlegacy/topology/pool_test.go | 5 +- x/mongo/driverlegacy/topology/server.go | 240 ++++-- .../driverlegacy/topology/server_options.go | 5 +- x/mongo/driverlegacy/topology/server_test.go | 76 +- x/mongo/driverlegacy/topology/topology.go | 10 +- .../driverlegacy/topology/topology_options.go | 29 +- .../driverlegacy/topology/topology_test.go | 7 +- x/mongo/driverlegacy/update.go | 2 +- x/mongo/driverlegacy/write.go | 2 +- x/network/examples/server_monitoring/main.go | 7 +- x/network/examples/workload/main.go | 2 +- x/network/integration/aggregate_test.go | 6 +- x/network/integration/command_test.go | 10 +- x/network/integration/compressor_test.go | 2 +- .../integration/list_collections_test.go | 2 +- x/network/integration/list_databases_test.go | 2 +- x/network/integration/main_test.go | 8 +- x/network/integration/opmsg_test.go | 2 +- x/network/integration/server_test.go | 75 +- x/network/integration/topology_test.go | 6 +- 78 files changed, 2548 insertions(+), 518 deletions(-) create mode 100644 x/mongo/driver/batches.go create mode 100644 x/mongo/driver/command.generated.go create mode 100644 x/mongo/driver/command.go create mode 100644 x/mongo/driver/drivertest/channel_conn.go create mode 100644 x/mongo/driver/errors.go create mode 100644 x/mongo/driver/ismaster.go create mode 100644 x/mongo/driver/operation.go create mode 100644 x/mongo/driver/wiremessage/wiremessage.go diff --git a/.errcheck-excludes b/.errcheck-excludes index 1cedf0b47f..f820488c70 100644 --- a/.errcheck-excludes +++ b/.errcheck-excludes @@ -1,5 +1,7 @@ +(go.mongodb.org/mongo-driver/x/mongo/driver.Connection).Close (*go.mongodb.org/mongo-driver/x/network/connection.connection).Close (go.mongodb.org/mongo-driver/x/network/connection.Connection).Close +(*go.mongodb.org/mongo-driver/x/mongo/driverlegacy/topology.connection).close (*go.mongodb.org/mongo-driver/x/mongo/driverlegacy/topology.Subscription).Unsubscribe (*go.mongodb.org/mongo-driver/x/mongo/driverlegacy/topology.Server).Close (*go.mongodb.org/mongo-driver/x/network/connection.pool).closeConnection diff --git a/internal/testutil/config.go b/internal/testutil/config.go index 7e3ac9ea79..7fbd62536d 100644 --- a/internal/testutil/config.go +++ b/internal/testutil/config.go @@ -22,7 +22,6 @@ import ( "go.mongodb.org/mongo-driver/x/bsonx" "go.mongodb.org/mongo-driver/x/mongo/driverlegacy/topology" "go.mongodb.org/mongo-driver/x/network/command" - "go.mongodb.org/mongo-driver/x/network/connection" "go.mongodb.org/mongo-driver/x/network/connstring" "go.mongodb.org/mongo-driver/x/network/description" ) @@ -86,10 +85,10 @@ func MonitoredTopology(t *testing.T, dbName string, monitor *event.CommandMonito topology.WithServerOptions(func(opts ...topology.ServerOption) []topology.ServerOption { return append( opts, - topology.WithConnectionOptions(func(opts ...connection.Option) []connection.Option { + topology.WithConnectionOptions(func(opts ...topology.ConnectionOption) []topology.ConnectionOption { return append( opts, - connection.WithMonitor(func(*event.CommandMonitor) *event.CommandMonitor { + topology.WithMonitor(func(*event.CommandMonitor) *event.CommandMonitor { return monitor }), ) @@ -106,7 +105,7 @@ func MonitoredTopology(t *testing.T, dbName string, monitor *event.CommandMonito s, err := monitoredTopology.SelectServer(context.Background(), description.WriteSelector()) require.NoError(t, err) - c, err := s.Connection(context.Background()) + c, err := s.ConnectionLegacy(context.Background()) require.NoError(t, err) _, err = (&command.Write{ @@ -128,10 +127,10 @@ func GlobalMonitoredTopology(t *testing.T, monitor *event.CommandMonitor) *topol topology.WithServerOptions(func(opts ...topology.ServerOption) []topology.ServerOption { return append( opts, - topology.WithConnectionOptions(func(opts ...connection.Option) []connection.Option { + topology.WithConnectionOptions(func(opts ...topology.ConnectionOption) []topology.ConnectionOption { return append( opts, - connection.WithMonitor(func(*event.CommandMonitor) *event.CommandMonitor { + topology.WithMonitor(func(*event.CommandMonitor) *event.CommandMonitor { return monitor }), ) @@ -150,7 +149,7 @@ func GlobalMonitoredTopology(t *testing.T, monitor *event.CommandMonitor) *topol s, err := monitoredTopology.SelectServer(context.Background(), description.WriteSelector()) require.NoError(t, err) - c, err := s.Connection(context.Background()) + c, err := s.ConnectionLegacy(context.Background()) require.NoError(t, err) _, err = (&command.Write{ @@ -182,7 +181,7 @@ func Topology(t *testing.T) *topology.Topology { s, err := liveTopology.SelectServer(context.Background(), description.WriteSelector()) require.NoError(t, err) - c, err := s.Connection(context.Background()) + c, err := s.ConnectionLegacy(context.Background()) require.NoError(t, err) _, err = (&command.Write{ diff --git a/internal/testutil/ops.go b/internal/testutil/ops.go index 8b533f8dd9..823be6df93 100644 --- a/internal/testutil/ops.go +++ b/internal/testutil/ops.go @@ -119,7 +119,7 @@ func EnableMaxTimeFailPoint(t *testing.T, s *topology.Server) error { {"mode", bsonx.String("alwaysOn")}, }, } - conn, err := s.Connection(context.Background()) + conn, err := s.ConnectionLegacy(context.Background()) require.NoError(t, err) defer testhelpers.RequireNoErrorOnClose(t, conn) _, err = cmd.RoundTrip(context.Background(), s.SelectedDescription(), conn) @@ -135,7 +135,7 @@ func DisableMaxTimeFailPoint(t *testing.T, s *topology.Server) { {"mode", bsonx.String("off")}, }, } - conn, err := s.Connection(context.Background()) + conn, err := s.ConnectionLegacy(context.Background()) require.NoError(t, err) defer testhelpers.RequireNoErrorOnClose(t, conn) _, err = cmd.RoundTrip(context.Background(), s.SelectedDescription(), conn) @@ -144,7 +144,7 @@ func DisableMaxTimeFailPoint(t *testing.T, s *topology.Server) { // RunCommand runs an arbitrary command on a given database of target server func RunCommand(t *testing.T, s *topology.Server, db string, b bsonx.Doc) (bson.Raw, error) { - conn, err := s.Connection(context.Background()) + conn, err := s.ConnectionLegacy(context.Background()) if err != nil { return nil, err } diff --git a/mongo/change_stream.go b/mongo/change_stream.go index 351dc3194c..0ba58294fb 100644 --- a/mongo/change_stream.go +++ b/mongo/change_stream.go @@ -181,7 +181,7 @@ func (cs *ChangeStream) runCommand(ctx context.Context, replaceOptions bool) err } desc := ss.Description() - conn, err := ss.Connection(ctx) + conn, err := ss.ConnectionLegacy(ctx) if err != nil { return replaceErrors(err) } diff --git a/mongo/client.go b/mongo/client.go index 62a61414f8..653c24c946 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -8,6 +8,7 @@ package mongo import ( "context" + "crypto/tls" "strings" "time" @@ -18,13 +19,13 @@ import ( "go.mongodb.org/mongo-driver/mongo/readconcern" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/mongo/writeconcern" + "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driverlegacy" "go.mongodb.org/mongo-driver/x/mongo/driverlegacy/auth" "go.mongodb.org/mongo-driver/x/mongo/driverlegacy/session" "go.mongodb.org/mongo-driver/x/mongo/driverlegacy/topology" "go.mongodb.org/mongo-driver/x/mongo/driverlegacy/uuid" "go.mongodb.org/mongo-driver/x/network/command" - "go.mongodb.org/mongo-driver/x/network/connection" "go.mongodb.org/mongo-driver/x/network/connstring" "go.mongodb.org/mongo-driver/x/network/description" ) @@ -195,7 +196,7 @@ func (c *Client) configure(opts *options.ClientOptions) error { return err } - var connOpts []connection.Option + var connOpts []topology.ConnectionOption var serverOpts []topology.ServerOption var topologyOpts []topology.Option @@ -211,7 +212,7 @@ func (c *Client) configure(opts *options.ClientOptions) error { if len(opts.Compressors) > 0 { comps = opts.Compressors - connOpts = append(connOpts, connection.WithCompressors( + connOpts = append(connOpts, topology.WithCompressors( func(compressors []string) []string { return append(compressors, comps...) }, @@ -219,7 +220,7 @@ func (c *Client) configure(opts *options.ClientOptions) error { for _, comp := range comps { if comp == "zlib" { - connOpts = append(connOpts, connection.WithZlibLevel(func(level *int) *int { + connOpts = append(connOpts, topology.WithZlibLevel(func(level *int) *int { return opts.ZlibLevel })) } @@ -230,8 +231,8 @@ func (c *Client) configure(opts *options.ClientOptions) error { )) } // Handshaker - var handshaker = func(connection.Handshaker) connection.Handshaker { - return &command.Handshake{Client: command.ClientDoc(appName), Compressors: comps} + var handshaker = func(driver.Handshaker) driver.Handshaker { + return driver.IsMaster().AppName(appName).Compressors(comps) } // Auth & Database & Password & Username if opts.Auth != nil { @@ -274,24 +275,24 @@ func (c *Client) configure(opts *options.ClientOptions) error { } } - handshaker = func(connection.Handshaker) connection.Handshaker { + handshaker = func(driver.Handshaker) driver.Handshaker { return auth.Handshaker(nil, handshakeOpts) } } - connOpts = append(connOpts, connection.WithHandshaker(handshaker)) + connOpts = append(connOpts, topology.WithHandshaker(handshaker)) // ConnectTimeout if opts.ConnectTimeout != nil { serverOpts = append(serverOpts, topology.WithHeartbeatTimeout( func(time.Duration) time.Duration { return *opts.ConnectTimeout }, )) - connOpts = append(connOpts, connection.WithConnectTimeout( + connOpts = append(connOpts, topology.WithConnectTimeout( func(time.Duration) time.Duration { return *opts.ConnectTimeout }, )) } // Dialer if opts.Dialer != nil { - connOpts = append(connOpts, connection.WithDialer( - func(connection.Dialer) connection.Dialer { return opts.Dialer }, + connOpts = append(connOpts, topology.WithDialer( + func(topology.Dialer) topology.Dialer { return opts.Dialer }, )) } // Direct @@ -320,7 +321,7 @@ func (c *Client) configure(opts *options.ClientOptions) error { } // MaxConIdleTime if opts.MaxConnIdleTime != nil { - connOpts = append(connOpts, connection.WithIdleTimeout( + connOpts = append(connOpts, topology.WithIdleTimeout( func(time.Duration) time.Duration { return *opts.MaxConnIdleTime }, )) } @@ -334,7 +335,7 @@ func (c *Client) configure(opts *options.ClientOptions) error { } // Monitor if opts.Monitor != nil { - connOpts = append(connOpts, connection.WithMonitor( + connOpts = append(connOpts, topology.WithMonitor( func(*event.CommandMonitor) *event.CommandMonitor { return opts.Monitor }, )) } @@ -373,15 +374,15 @@ func (c *Client) configure(opts *options.ClientOptions) error { if opts.SocketTimeout != nil { connOpts = append( connOpts, - connection.WithReadTimeout(func(time.Duration) time.Duration { return *opts.SocketTimeout }), - connection.WithWriteTimeout(func(time.Duration) time.Duration { return *opts.SocketTimeout }), + topology.WithReadTimeout(func(time.Duration) time.Duration { return *opts.SocketTimeout }), + topology.WithWriteTimeout(func(time.Duration) time.Duration { return *opts.SocketTimeout }), ) } // TLSConfig if opts.TLSConfig != nil { - connOpts = append(connOpts, connection.WithTLSConfig( - func(*connection.TLSConfig) *connection.TLSConfig { - return &connection.TLSConfig{Config: opts.TLSConfig} + connOpts = append(connOpts, topology.WithTLSConfig( + func(*tls.Config) *tls.Config { + return opts.TLSConfig }, )) } @@ -396,7 +397,7 @@ func (c *Client) configure(opts *options.ClientOptions) error { serverOpts = append( serverOpts, topology.WithClock(func(*session.ClusterClock) *session.ClusterClock { return c.clock }), - topology.WithConnectionOptions(func(...connection.Option) []connection.Option { return connOpts }), + topology.WithConnectionOptions(func(...topology.ConnectionOption) []topology.ConnectionOption { return connOpts }), ) c.topologyOptions = append(topologyOpts, topology.WithServerOptions( func(...topology.ServerOption) []topology.ServerOption { return serverOpts }, diff --git a/mongo/retryable_writes_test.go b/mongo/retryable_writes_test.go index 754455749c..74164f2346 100644 --- a/mongo/retryable_writes_test.go +++ b/mongo/retryable_writes_test.go @@ -24,14 +24,13 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/testutil" - "go.mongodb.org/mongo-driver/internal/testutil/helpers" + testhelpers "go.mongodb.org/mongo-driver/internal/testutil/helpers" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx" "go.mongodb.org/mongo-driver/x/mongo/driverlegacy/session" "go.mongodb.org/mongo-driver/x/mongo/driverlegacy/topology" - "go.mongodb.org/mongo-driver/x/network/connection" "go.mongodb.org/mongo-driver/x/network/connstring" ) @@ -361,10 +360,10 @@ func createRetryMonitoredTopology(t *testing.T, clock *session.ClusterClock, mon topology.WithServerOptions(func(opts ...topology.ServerOption) []topology.ServerOption { return append( opts, - topology.WithConnectionOptions(func(opts ...connection.Option) []connection.Option { + topology.WithConnectionOptions(func(opts ...topology.ConnectionOption) []topology.ConnectionOption { return append( opts, - connection.WithMonitor(func(*event.CommandMonitor) *event.CommandMonitor { + topology.WithMonitor(func(*event.CommandMonitor) *event.CommandMonitor { return monitor }), ) diff --git a/mongo/sessions_test.go b/mongo/sessions_test.go index dce4afc4e9..628241938c 100644 --- a/mongo/sessions_test.go +++ b/mongo/sessions_test.go @@ -22,7 +22,7 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/testutil" - "go.mongodb.org/mongo-driver/internal/testutil/helpers" + testhelpers "go.mongodb.org/mongo-driver/internal/testutil/helpers" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/readconcern" "go.mongodb.org/mongo-driver/mongo/readpref" @@ -31,7 +31,6 @@ import ( "go.mongodb.org/mongo-driver/x/mongo/driverlegacy/session" "go.mongodb.org/mongo-driver/x/mongo/driverlegacy/topology" "go.mongodb.org/mongo-driver/x/network/command" - "go.mongodb.org/mongo-driver/x/network/connection" "go.mongodb.org/mongo-driver/x/network/connstring" "go.mongodb.org/mongo-driver/x/network/description" ) @@ -182,10 +181,10 @@ func createMonitoredTopology(t *testing.T, clock *session.ClusterClock, monitor topology.WithServerOptions(func(opts ...topology.ServerOption) []topology.ServerOption { return append( opts, - topology.WithConnectionOptions(func(opts ...connection.Option) []connection.Option { + topology.WithConnectionOptions(func(opts ...topology.ConnectionOption) []topology.ConnectionOption { return append( opts, - connection.WithMonitor(func(*event.CommandMonitor) *event.CommandMonitor { + topology.WithMonitor(func(*event.CommandMonitor) *event.CommandMonitor { return monitor }), ) @@ -211,7 +210,7 @@ func createMonitoredTopology(t *testing.T, clock *session.ClusterClock, monitor t.Fatal(err) } - c, err := s.Connection(context.Background()) + c, err := s.ConnectionLegacy(context.Background()) if err != nil { t.Fatal(err) } diff --git a/mongo/transactions_test.go b/mongo/transactions_test.go index cafb8162bb..a444546465 100644 --- a/mongo/transactions_test.go +++ b/mongo/transactions_test.go @@ -334,7 +334,7 @@ func killSessions(t *testing.T, client *Client) { DB: "admin", Command: bsonx.Doc{{"killAllSessions", bsonx.Array(vals)}}, } - conn, err := s.Connection(ctx) + conn, err := s.ConnectionLegacy(ctx) require.NoError(t, err) defer testhelpers.RequireNoErrorOnClose(t, conn) // ignore the error because command kills its own implicit session diff --git a/x/mongo/driver/batches.go b/x/mongo/driver/batches.go new file mode 100644 index 0000000000..7c4d7bacc1 --- /dev/null +++ b/x/mongo/driver/batches.go @@ -0,0 +1,66 @@ +package driver + +import ( + "errors" + + "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" +) + +// this is the amount of reserved buffer space in a message that the +// driver reserves for command overhead. +const reservedCommandBufferBytes = 16 * 10 * 10 * 10 + +// ErrDocumentTooLarge occurs when a document that is larger than the maximum size accepted by a +// server is passed to an insert command. +var ErrDocumentTooLarge = errors.New("an inserted document is too large") + +// Batches contains the necessary information to batch split an operation. This is only used for write +// oeprations. +type Batches struct { + Identifier string + Documents []bsoncore.Document + Current []bsoncore.Document + Ordered *bool +} + +// Valid returns true if Batches contains both an identifier and the length of Documents is greater +// than zero. +func (b *Batches) Valid() bool { return b != nil && b.Identifier != "" && len(b.Documents) > 0 } + +// ClearBatch clears the Current batch. This must be called before AdvanceBatch will advance to the +// next batch. +func (b *Batches) ClearBatch() { b.Current = b.Current[:0] } + +// AdvanceBatch splits the next batch using maxCount and targetBbatchSize. This method will do nothing if +// the current batch has not been cleared. We do this so that when this is called during execute we +// can call it without first needing to check if we already have a batch, which makes the code +// simpler and makes retrying easier. +func (b *Batches) AdvanceBatch(maxCount, targetBatchSize int) error { + if len(b.Current) > 0 { + return nil + } + if targetBatchSize > reservedCommandBufferBytes { + targetBatchSize -= reservedCommandBufferBytes + } + + if maxCount <= 0 { + maxCount = 1 + } + + splitAfter := 0 + size := 1 + for _, doc := range b.Documents { + if len(doc) > targetBatchSize { + return ErrDocumentTooLarge + } + if size+len(doc) > targetBatchSize { + break + } + + size += len(doc) + splitAfter++ + } + + b.Current, b.Documents = b.Documents[:splitAfter], b.Documents[splitAfter:] + return nil +} diff --git a/x/mongo/driver/command.generated.go b/x/mongo/driver/command.generated.go new file mode 100644 index 0000000000..812def6fbd --- /dev/null +++ b/x/mongo/driver/command.generated.go @@ -0,0 +1,96 @@ +// Code will be generated by drivergen. DO NOT EDIT. + +package driver + +import ( + "go.mongodb.org/mongo-driver/mongo/readconcern" + "go.mongodb.org/mongo-driver/mongo/readpref" + "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/x/mongo/driverlegacy/session" + "go.mongodb.org/mongo-driver/x/network/description" +) + +// Command constructs and returns a new CommandOperation. +func Command(cmd bsoncore.Document) *CommandOperation { + return &CommandOperation{cmd: cmd} +} + +// Session sets the session for this operation. +func (co *CommandOperation) Session(client *session.Client) *CommandOperation { + if co == nil { + co = new(CommandOperation) + } + + co.client = client + return co +} + +// Clock sets the cluster clock for this operation. +func (co *CommandOperation) Clock(clock *session.ClusterClock) *CommandOperation { + if co == nil { + co = new(CommandOperation) + } + + co.clock = clock + return co +} + +// Command sets the command that will be run. +func (co *CommandOperation) Command(cmd bsoncore.Document) *CommandOperation { + if co == nil { + co = new(CommandOperation) + } + + co.cmd = cmd + return co +} + +// Deployment sets the Deployment to run the command against. +func (co *CommandOperation) Deployment(d Deployment) *CommandOperation { + if co == nil { + co = new(CommandOperation) + } + + co.d = d + return co +} + +// Database sets the database to run the command against. +func (co *CommandOperation) Database(database string) *CommandOperation { + if co == nil { + co = new(CommandOperation) + } + + co.database = database + return co +} + +// ReadConcern sets the read concern to use when running the command. +func (co *CommandOperation) ReadConcern(rc *readconcern.ReadConcern) *CommandOperation { + if co == nil { + co = new(CommandOperation) + } + + co.rc = rc + return co +} + +// ReadPreference sets the read preference for this operation. +func (co *CommandOperation) ReadPreference(readPref *readpref.ReadPref) *CommandOperation { + if co == nil { + co = new(CommandOperation) + } + + co.readPref = readPref + return co +} + +// ServerSelector sets the server selector for this operaiton. +func (co *CommandOperation) ServerSelector(selector description.ServerSelector) *CommandOperation { + if co == nil { + co = new(CommandOperation) + } + + co.selector = selector + return co +} diff --git a/x/mongo/driver/command.go b/x/mongo/driver/command.go new file mode 100644 index 0000000000..11e066d87d --- /dev/null +++ b/x/mongo/driver/command.go @@ -0,0 +1,75 @@ +package driver + +import ( + "context" + "errors" + + "go.mongodb.org/mongo-driver/mongo/readconcern" + "go.mongodb.org/mongo-driver/mongo/readpref" + "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/x/mongo/driverlegacy/session" + "go.mongodb.org/mongo-driver/x/network/description" +) + +// CommandOperation is used to run a generic operation against a server. +type CommandOperation struct { + _ struct{} `drivergen:"-"` + // Command sets the command that will be run. + cmd bsoncore.Document `drivergen:"Command,constructorArg"` + // ReadConcern sets the read concern to use when running the command. + rc *readconcern.ReadConcern `drivergen:"ReadConcern,pointerExempt"` + + // Database sets the database to run the command against. + database string + // Deployment sets the Deployment to run the command against. + d Deployment `drivergen:"Deployment"` + + selector description.ServerSelector `drivergen:"ServerSelector"` + readPref *readpref.ReadPref `drivergen:"ReadPreference,pointerExempt"` + clock *session.ClusterClock `drivergen:"Clock,pointerExempt"` + client *session.Client `drivergen:"Session,pointerExempt"` + + result bsoncore.Document `drivergen:"-"` +} + +// Result returns the result of executing this operation. +// +// TODO(GODRIVER-617): This should be generated by drivergen. +func (co *CommandOperation) Result() bsoncore.Document { return co.result } + +func (co *CommandOperation) processResponse(response bsoncore.Document, _ Server) error { + co.result = response + return nil +} + +// TODO(GODRIVER-617): This should be generated by drivergen. +func (co *CommandOperation) command(dst []byte, _ description.SelectedServer) ([]byte, error) { + return append(dst, co.cmd[4:len(co.cmd)-1]...), nil +} + +// Execute runs this operations. +// +// TODO(GODRIVER-617): This should be generated by drivergen. +func (co *CommandOperation) Execute(ctx context.Context) error { + if co.d == nil { + return errors.New("a CommandOperation must have a Deployment set before Execute can be called") + } + + if co.database == "" { + return errors.New("Database must be of non-zero length") + } + return OperationContext{ + CommandFn: co.command, + Deployment: co.d, + Database: co.database, + + ProcessResponseFn: co.processResponse, + + Selector: co.selector, + ReadPreference: co.readPref, + ReadConcern: co.rc, + + Client: co.client, + Clock: co.clock, + }.Execute(ctx) +} diff --git a/x/mongo/driver/driver.go b/x/mongo/driver/driver.go index 0093980064..67b08a05a3 100644 --- a/x/mongo/driver/driver.go +++ b/x/mongo/driver/driver.go @@ -10,7 +10,8 @@ import ( // Deployment is implemented by types that can select a server from a deployment. type Deployment interface { SelectServer(context.Context, description.ServerSelector) (Server, error) - Description() description.Topology + SupportsRetry() bool + Kind() description.TopologyKind } // Server represents a MongoDB server. Implementations should pool connections and handle the @@ -35,3 +36,101 @@ type Connection interface { type ErrorProcessor interface { ProcessError(error) } + +// Handshaker is the interface implemented by types that can perform a MongoDB +// handshake over a provided driver.Connection. This is used during connection +// initialization. Implementations must be goroutine safe. +type Handshaker interface { + Handshake(context.Context, address.Address, Connection) (description.Server, error) +} + +// HandshakerFunc is an adapter to allow the use of ordinary functions as +// connection handshakers. +type HandshakerFunc func(context.Context, address.Address, Connection) (description.Server, error) + +// Handshake implements the Handshaker interface. +func (hf HandshakerFunc) Handshake(ctx context.Context, addr address.Address, conn Connection) (description.Server, error) { + return hf(ctx, addr, conn) +} + +// SingleServerDeployment is an implementation of Deployment that always returns a single server. +type SingleServerDeployment struct{ Server } + +var _ Deployment = SingleServerDeployment{} + +// SelectServer implements the Deployment interface. This method does not use the +// description.SelectedServer provided and instead returns the embedded Server. +func (ssd SingleServerDeployment) SelectServer(context.Context, description.ServerSelector) (Server, error) { + return ssd.Server, nil +} + +// SupportsRetry implements the Deployment interface. It always returns false, because a single +// server does not support retryability. +func (SingleServerDeployment) SupportsRetry() bool { return false } + +// Kind implements the Deployment interface. It always returns description.Single. +func (SingleServerDeployment) Kind() description.TopologyKind { return description.Single } + +// SingleConnectionDeployment is an implementation of Deployment that always returns the same +// Connection. +type SingleConnectionDeployment struct{ C Connection } + +var _ Deployment = SingleConnectionDeployment{} +var _ Server = SingleConnectionDeployment{} + +// SelectServer implements the Deployment interface. This method does not use the +// description.SelectedServer provided and instead returns itself. +func (ssd SingleConnectionDeployment) SelectServer(context.Context, description.ServerSelector) (Server, error) { + return ssd, nil +} + +// SupportsRetry implements the Deployment interface. It always returns false, because a single +// connection does not support retryability. +func (ssd SingleConnectionDeployment) SupportsRetry() bool { return false } + +// Kind implements the Deployment interface. It always returns description.Single. +func (ssd SingleConnectionDeployment) Kind() description.TopologyKind { return description.Single } + +// Connection implements the Server interface. It always returns the embedded connection. +func (ssd SingleConnectionDeployment) Connection(context.Context) (Connection, error) { + return ssd.C, nil +} + +// TODO(GODRUVER-617): We can likely use 1 type for both the RetryType and the RetryMode by using +// 2 bits for the mode and 1 bit for the type. Although in the practical sense, we might not want to +// do that since the type of retryability is tied to the operation itself and isn't going change, +// e.g. and insert operation will always be a write, however some operations are both reads and +// writes, for instance aggregate is a read but with a $out parameter it's a write. + +// RetryType specifies whether a retry is a read, write, or disabled. +type RetryType uint + +// THese are the availables types of retry. +const ( + _ RetryType = iota + RetryWrite + RetryRead +) + +// RetryMode specifies the way that retries are handled for retryable operations. +type RetryMode uint + +// These are the modes available for retrying. +const ( + // RetryNone disables retrying. + RetryNone RetryMode = iota + // RetryOnce will enable retrying the entire operation once. + RetryOnce + // RetryOncePerCommand will enable retrying each command associated with an operation. For + // example, if an insert is batch split into 4 commands then each of those commands is eligible + // for one retry. + RetryOncePerCommand + // RetryContext will enable retrying until the context.Context's deadline is exceeded or it is + // cancelled. + RetryContext +) + +// Enabled returns if this RetryMode enables retrying. +func (rm RetryMode) Enabled() bool { + return rm == RetryOnce || rm == RetryOncePerCommand || rm == RetryContext +} diff --git a/x/mongo/driver/drivertest/channel_conn.go b/x/mongo/driver/drivertest/channel_conn.go new file mode 100644 index 0000000000..60a3310099 --- /dev/null +++ b/x/mongo/driver/drivertest/channel_conn.go @@ -0,0 +1,72 @@ +package drivertest + +import ( + "context" + "errors" + + "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" + wiremessagex "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" + "go.mongodb.org/mongo-driver/x/network/address" + "go.mongodb.org/mongo-driver/x/network/description" + "go.mongodb.org/mongo-driver/x/network/wiremessage" +) + +// ChannelConn implements the driver.Connection interface by reading and writing wire messages +// to a channel +type ChannelConn struct { + WriteErr error + Written chan []byte + ReadResp chan []byte + ReadErr chan error + Desc description.Server +} + +// WriteWireMessage implements the driver.Connection interface. +func (c *ChannelConn) WriteWireMessage(ctx context.Context, wm []byte) error { + select { + case c.Written <- wm: + default: + c.WriteErr = errors.New("could not write wiremessage to written channel") + } + return c.WriteErr +} + +// ReadWireMessage implements the driver.Connection interface. +func (c *ChannelConn) ReadWireMessage(ctx context.Context, dst []byte) ([]byte, error) { + var wm []byte + var err error + select { + case wm = <-c.ReadResp: + case err = <-c.ReadErr: + case <-ctx.Done(): + } + return wm, err +} + +// Description implements the driver.Connection interface. +func (c *ChannelConn) Description() description.Server { return c.Desc } + +// Close implements the driver.Connection interface. +func (c *ChannelConn) Close() error { + return nil +} + +// ID implements the driver.Connection interface. +func (c *ChannelConn) ID() string { + return "faked" +} + +// Address implements the driver.Connection interface. +func (c *ChannelConn) Address() address.Address { return address.Address("0.0.0.0") } + +// MakeReply creates an OP_REPLY wiremessage from a BSON document +func MakeReply(doc bsoncore.Document) []byte { + var dst []byte + idx, dst := wiremessagex.AppendHeaderStart(dst, 10, 9, wiremessage.OpReply) + dst = wiremessagex.AppendReplyFlags(dst, 0) + dst = wiremessagex.AppendReplyCursorID(dst, 0) + dst = wiremessagex.AppendReplyStartingFrom(dst, 0) + dst = wiremessagex.AppendReplyNumberReturned(dst, 1) + dst = append(dst, doc...) + return bsoncore.UpdateLength(dst, idx, int32(len(dst[idx:]))) +} diff --git a/x/mongo/driver/errors.go b/x/mongo/driver/errors.go new file mode 100644 index 0000000000..eede853dc2 --- /dev/null +++ b/x/mongo/driver/errors.go @@ -0,0 +1,290 @@ +package driver + +import ( + "bytes" + "fmt" + "strings" + + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" +) + +var retryableCodes = []int32{11600, 11602, 10107, 13435, 13436, 189, 91, 7, 6, 89, 9001} + +var ( + // TransientTransactionError is an error label for transient errors with transactions. + TransientTransactionError = "TransientTransactionError" + // NetworkError is an error label for network errors. + NetworkError = "NetworkError" +) + +// QueryFailureError is an error representing a command failure as a document. +type QueryFailureError struct { + Message string + Response bsoncore.Document +} + +// Error implements the error interface. +func (e QueryFailureError) Error() string { + return fmt.Sprintf("%s: %v", e.Message, e.Response) +} + +// ResponseError is an error parsing the response to a command. +type ResponseError struct { + Message string + Wrapped error +} + +// NewCommandResponseError creates a CommandResponseError. +func NewCommandResponseError(msg string, err error) ResponseError { + return ResponseError{Message: msg, Wrapped: err} +} + +// Error implements the error interface. +func (e ResponseError) Error() string { + if e.Wrapped != nil { + return fmt.Sprintf("%s: %s", e.Message, e.Wrapped) + } + return fmt.Sprintf("%s", e.Message) +} + +// WriteCommandError is an error for a write command. +type WriteCommandError struct { + WriteConcernError *WriteConcernError + WriteErrors WriteErrors +} + +func (wce WriteCommandError) Error() string { + var buf bytes.Buffer + fmt.Fprint(&buf, "write command error: [") + fmt.Fprintf(&buf, "{%s}, ", wce.WriteErrors) + fmt.Fprintf(&buf, "{%s}]", wce.WriteConcernError) + return buf.String() +} + +// Retryable returns true if the error is retryable +func (wce WriteCommandError) Retryable() bool { + if wce.WriteConcernError == nil { + return false + } + return (*wce.WriteConcernError).Retryable() +} + +// WriteConcernError is a write concern failure that occurred as a result of a +// write operation. +type WriteConcernError struct { + Code int64 + Message string + Details bsoncore.Document +} + +func (wce WriteConcernError) Error() string { return wce.Message } + +// Retryable returns true if the error is retryable +func (wce WriteConcernError) Retryable() bool { + for _, code := range retryableCodes { + if wce.Code == int64(code) { + return true + } + } + if strings.Contains(wce.Message, "not master") || strings.Contains(wce.Message, "node is recovering") { + return true + } + + return false +} + +// WriteError is a non-write concern failure that occurred as a result of a write +// operation. +type WriteError struct { + Index int64 + Code int64 + Message string +} + +func (we WriteError) Error() string { return we.Message } + +// WriteErrors is a group of non-write concern failures that occurred as a result +// of a write operation. +type WriteErrors []WriteError + +func (we WriteErrors) Error() string { + var buf bytes.Buffer + fmt.Fprint(&buf, "write errors: [") + for idx, err := range we { + if idx != 0 { + fmt.Fprintf(&buf, ", ") + } + fmt.Fprintf(&buf, "{%s}", err) + } + fmt.Fprint(&buf, "]") + return buf.String() +} + +// Error is a command execution error from the database. +type Error struct { + Code int32 + Message string + Labels []string + Name string +} + +// Error implements the error interface. +func (e Error) Error() string { + if e.Name != "" { + return fmt.Sprintf("(%v) %v", e.Name, e.Message) + } + return e.Message +} + +// HasErrorLabel returns true if the error contains the specified label. +func (e Error) HasErrorLabel(label string) bool { + if e.Labels != nil { + for _, l := range e.Labels { + if l == label { + return true + } + } + } + return false +} + +// Retryable returns true if the error is retryable +func (e Error) Retryable() bool { + for _, label := range e.Labels { + if label == NetworkError { + return true + } + } + for _, code := range retryableCodes { + if e.Code == code { + return true + } + } + if strings.Contains(e.Message, "not master") || strings.Contains(e.Message, "node is recovering") { + return true + } + + return false +} + +// helper method to extract an error from a reader if there is one; first returned item is the +// error if it exists, the second holds parsing errors +func extractError(rdr bsoncore.Document) error { + var errmsg, codeName string + var code int32 + var labels []string + var ok bool + var wcError WriteCommandError + elems, err := rdr.Elements() + if err != nil { + return err + } + + // TODO(GODRIVER-617): We need to handle write errors and write concern errors here. + for _, elem := range elems { + switch elem.Key() { + case "ok": + switch elem.Value().Type { + case bson.TypeInt32: + if elem.Value().Int32() == 1 { + ok = true + } + case bson.TypeInt64: + if elem.Value().Int64() == 1 { + ok = true + } + case bson.TypeDouble: + if elem.Value().Double() == 1 { + ok = true + } + } + case "errmsg": + if str, okay := elem.Value().StringValueOK(); okay { + errmsg = str + } + case "codeName": + if str, okay := elem.Value().StringValueOK(); okay { + codeName = str + } + case "code": + if c, okay := elem.Value().Int32OK(); okay { + code = c + } + case "errorLabels": + if arr, okay := elem.Value().ArrayOK(); okay { + elems, err := arr.Elements() + if err != nil { + continue + } + for _, elem := range elems { + if str, ok := elem.Value().StringValueOK(); ok { + labels = append(labels, str) + } + } + + } + case "writeErrors": + arr, exists := elem.Value().ArrayOK() + if !exists { + break + } + vals, err := arr.Values() + if err != nil { + continue + } + for _, val := range vals { + var we WriteError + doc, exists := val.DocumentOK() + if !exists { + continue + } + if index, exists := doc.Lookup("index").AsInt64OK(); exists { + we.Index = index + } + if code, exists := doc.Lookup("code").AsInt64OK(); exists { + we.Code = code + } + if msg, exists := doc.Lookup("errMsg").StringValueOK(); exists { + we.Message = msg + } + wcError.WriteErrors = append(wcError.WriteErrors, we) + } + case "writeConcernError": + doc, exists := elem.Value().DocumentOK() + if !exists { + break + } + wcError.WriteConcernError = new(WriteConcernError) + if code, exists := doc.Lookup("code").AsInt64OK(); exists { + wcError.WriteConcernError.Code = code + } + if msg, exists := doc.Lookup("errMsg").StringValueOK(); exists { + wcError.WriteConcernError.Message = msg + } + if info, exists := doc.Lookup("errInfo").DocumentOK(); exists { + wcError.WriteConcernError.Details = make([]byte, len(info)) + copy(wcError.WriteConcernError.Details, info) + } + } + } + + if !ok { + if errmsg == "" { + errmsg = "command failed" + } + + return Error{ + Code: code, + Message: errmsg, + Name: codeName, + Labels: labels, + } + } + + if len(wcError.WriteErrors) > 0 || wcError.WriteConcernError != nil { + return wcError + } + + return nil +} diff --git a/x/mongo/driver/ismaster.go b/x/mongo/driver/ismaster.go new file mode 100644 index 0000000000..75ba0f0e70 --- /dev/null +++ b/x/mongo/driver/ismaster.go @@ -0,0 +1,153 @@ +package driver + +import ( + "context" + "errors" + "runtime" + "strconv" + + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/version" + "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/x/network/address" + "go.mongodb.org/mongo-driver/x/network/description" + "go.mongodb.org/mongo-driver/x/network/result" +) + +// IsMasterOperation is used to run the isMaster handshake operation. +type IsMasterOperation struct { + appname string + compressors []string + saslSupportedMechs string + + server Server + conn Connection + tkind description.TopologyKind + + res result.IsMaster +} + +// IsMaster constructs an IsMasterOperation. +func IsMaster() *IsMasterOperation { return &IsMasterOperation{} } + +// AppName sets the application name in the client metadata sent in this operation. +func (imo *IsMasterOperation) AppName(appname string) *IsMasterOperation { + imo.appname = appname + return imo +} + +// Compressors sets the compressors that can be used. +func (imo *IsMasterOperation) Compressors(compressors []string) *IsMasterOperation { + imo.compressors = compressors + return imo +} + +// SASLSupportedMechs retrieves the supported SASL mechanism for the given user when this operation +// is run. +func (imo *IsMasterOperation) SASLSupportedMechs(username string) *IsMasterOperation { + imo.saslSupportedMechs = username + return imo +} + +// Server sets the server for this operation. +func (imo *IsMasterOperation) Server(server Server) *IsMasterOperation { + imo.server = server + return imo +} + +// Connection sets the connection for this operation. +func (imo *IsMasterOperation) Connection(conn Connection) *IsMasterOperation { + imo.conn = conn + return imo +} + +// Result returns the result of executing this operaiton. +func (imo *IsMasterOperation) Result() result.IsMaster { return imo.res } + +func (imo *IsMasterOperation) processResponse(response bsoncore.Document, _ Server) error { + // Replace this with direct unmarshaling. + err := bson.Unmarshal(response, &imo.res) + if err != nil { + return err + } + + // Reconstructs the $clusterTime doc after decode + if imo.res.ClusterTime != nil { + imo.res.ClusterTime = bsoncore.BuildDocument(nil, bsoncore.AppendDocumentElement(nil, "$clusterTime", imo.res.ClusterTime)) + } + return nil +} + +func (imo *IsMasterOperation) command(dst []byte, _ description.SelectedServer) ([]byte, error) { + dst = bsoncore.AppendInt32Element(dst, "isMaster", 1) + + idx, dst := bsoncore.AppendDocumentElementStart(dst, "client") + + didx, dst := bsoncore.AppendDocumentElementStart(dst, "driver") + dst = bsoncore.AppendStringElement(dst, "name", "mongo-go-driver") + dst = bsoncore.AppendStringElement(dst, "version", version.Driver) + dst, _ = bsoncore.AppendDocumentEnd(dst, didx) + + didx, dst = bsoncore.AppendDocumentElementStart(dst, "os") + dst = bsoncore.AppendStringElement(dst, "type", runtime.GOOS) + dst = bsoncore.AppendStringElement(dst, "architecture", runtime.GOARCH) + dst, _ = bsoncore.AppendDocumentEnd(dst, didx) + + dst = bsoncore.AppendStringElement(dst, "platform", runtime.Version()) + if imo.appname != "" { + didx, dst = bsoncore.AppendDocumentElementStart(dst, "application") + dst = bsoncore.AppendStringElement(dst, "name", imo.appname) + dst, _ = bsoncore.AppendDocumentEnd(dst, didx) + } + dst, _ = bsoncore.AppendDocumentEnd(dst, idx) + + if imo.saslSupportedMechs != "" { + dst = bsoncore.AppendStringElement(dst, "saslSupportedMechs", imo.saslSupportedMechs) + } + + idx, dst = bsoncore.AppendArrayElementStart(dst, "compression") + for i, compressor := range imo.compressors { + dst = bsoncore.AppendStringElement(dst, strconv.Itoa(i), compressor) + } + dst, _ = bsoncore.AppendArrayEnd(dst, idx) + + return dst, nil +} + +// Execute runs this operation. +func (imo *IsMasterOperation) Execute(ctx context.Context) error { + if imo.server == nil && imo.conn == nil { + return errors.New("an IsMasterOperation must have a Server or Connection set before Execute can be called") + } + + server := imo.server + if imo.conn != nil { + server = connectionServer{imo.conn} + } + return OperationContext{ + CommandFn: imo.command, + Server: server, + Database: "admin", + ProcessResponseFn: imo.processResponse, + }.Execute(ctx) +} + +// Handshake implements the Handshaker interface. +func (imo *IsMasterOperation) Handshake(ctx context.Context, _ address.Address, c Connection) (description.Server, error) { + err := OperationContext{ + CommandFn: imo.command, + Server: connectionServer{c}, + Database: "admin", + ProcessResponseFn: imo.processResponse, + }.Execute(ctx) + if err != nil { + return description.Server{}, err + } + return description.NewServer(c.Address(), imo.res), nil +} + +type connectionServer struct{ c Connection } + +var _ Server = connectionServer{} + +func (cs connectionServer) Connection(context.Context) (Connection, error) { return cs.c, nil } diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go new file mode 100644 index 0000000000..9d47a6eba3 --- /dev/null +++ b/x/mongo/driver/operation.go @@ -0,0 +1,740 @@ +package driver + +import ( + "context" + "errors" + "fmt" + "strconv" + "time" + + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/bsontype" + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/mongo/readconcern" + "go.mongodb.org/mongo-driver/mongo/readpref" + "go.mongodb.org/mongo-driver/mongo/writeconcern" + "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" + wiremessagex "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" + "go.mongodb.org/mongo-driver/x/mongo/driverlegacy/session" + "go.mongodb.org/mongo-driver/x/network/description" + "go.mongodb.org/mongo-driver/x/network/wiremessage" +) + +var dollarCmd = [...]byte{'.', '$', 'c', 'm', 'd'} + +var ( + // ErrNoDocCommandResponse occurs when the server indicated a response existed, but none was found. + ErrNoDocCommandResponse = errors.New("command returned no documents") + // ErrMultiDocCommandResponse occurs when the server sent multiple documents in response to a command. + ErrMultiDocCommandResponse = errors.New("command returned multiple documents") +) + +// OperationContext is used to execute an operation. It contains all of the common code required to +// select a server, transform an operation into a command, write the command to a connection from +// the selected server, read a response from that connection, process the response, and potentially +// retry. +// +// The required fields are Database, CommandFn, and either Deployment or ServerSelector and +// TopologyKind. All other fields are optional. +type OperationContext struct { + CommandFn func(dst []byte, desc description.SelectedServer) ([]byte, error) + Database string + + Deployment Deployment + Server Server + TopologyKind description.TopologyKind + + ProcessResponseFn func(response bsoncore.Document, srvr Server) error + RetryableFn func(description.Server) RetryType + + Selector description.ServerSelector + ReadPreference *readpref.ReadPref + ReadConcern *readconcern.ReadConcern + WriteConcern *writeconcern.WriteConcern + + Client *session.Client + Clock *session.ClusterClock + + RetryMode *RetryMode + Batches *Batches + + RetryType RetryType +} + +func (oc OperationContext) selectServer(ctx context.Context) (Server, error) { + if err := oc.Validate(); err != nil { + return nil, err + } + + if oc.Server != nil { + return oc.Server, nil + } + + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + rp := oc.ReadPreference + if rp == nil { + rp = readpref.Primary() + } + + return oc.Deployment.SelectServer(ctx, createReadPrefSelector(rp, oc.Selector)) +} + +// Validate validates this operation, ensuring the fields are set properly. +func (oc OperationContext) Validate() error { + if oc.CommandFn == nil { + return errors.New("the CommandFn field must be set on OperationContext") + } + if (oc.Deployment == nil) && (oc.Server == nil && oc.TopologyKind == description.TopologyKind(0)) { + return errors.New("the Deployment field or the Server and TopologyKind fields must be set on OperationContext") + } + if oc.Database == "" { + return errors.New("the Database field must be non-empty on OperationContext") + } + return nil +} + +// Execute runs this operation. +func (oc OperationContext) Execute(ctx context.Context) error { + err := oc.Validate() + if err != nil { + return err + } + + srvr, err := oc.selectServer(ctx) + if err != nil { + return err + } + + conn, err := srvr.Connection(ctx) + if err != nil { + return err + } + defer conn.Close() + + kind := oc.TopologyKind + if oc.Deployment != nil { + kind = oc.Deployment.Kind() + } + desc := description.SelectedServer{Server: conn.Description(), Kind: kind} + + var retryable RetryType + if oc.RetryableFn != nil { + retryable = oc.RetryableFn(desc.Server) + } + if retryable == RetryWrite && oc.Client != nil { + oc.Client.RetryWrite = true + oc.Client.IncrementTxnNumber() + } + + var res bsoncore.Document + var operationErr WriteCommandError + var original error + var retries int + // TODO(GODRIVER-617): Add support for retryable reads. + if retryable == RetryWrite && oc.RetryMode != nil { + switch *oc.RetryMode { + case RetryOnce, RetryOncePerCommand: + retries = 1 + case RetryContext: + retries = -1 + } + } + batching := oc.Batches.Valid() + for { + if batching { + err = oc.Batches.AdvanceBatch(int(desc.MaxBatchCount), int(desc.MaxDocumentSize)) + if err != nil { + return err + } + } + + // convert to wire message + wm, err := oc.createWireMessage(nil, desc) + if err != nil { + return err + } + + // roundtrip + res, err = oc.roundTrip(ctx, conn, wm) + + // Pull out $clusterTime and operationTime and update session and clock. We handle this before + // handling the error to ensure we are properly gossiping the cluster time. + _ = updateClusterTimes(oc.Client, oc.Clock, res) + _ = updateOperationTime(oc.Client, res) + if err != nil { + return err + } + + var perr error + if oc.ProcessResponseFn != nil { + perr = oc.ProcessResponseFn(res, srvr) + } + switch tt := err.(type) { + case WriteCommandError: + if retryable == RetryWrite && tt.Retryable() && retries != 0 { + retries-- + original, err = err, nil + conn.Close() // Avoid leaking the connection. + srvr, err = oc.selectServer(ctx) + if err != nil { + return original + } + conn, err := srvr.Connection(ctx) + // We know that oc.RetryableFn is not nil because retryable is a valid retryable + // value. + if err != nil || conn == nil || oc.RetryableFn(conn.Description()) == RetryWrite { + return original + } + defer conn.Close() // Avoid leaking the new connection. + continue + } + if batching && oc.Batches.Ordered != nil && *oc.Batches.Ordered == true && len(tt.WriteErrors) > 0 { + return tt + } + operationErr.WriteConcernError = tt.WriteConcernError + operationErr.WriteErrors = append(operationErr.WriteErrors, tt.WriteErrors...) + case Error: + if retryable == RetryWrite && tt.Retryable() && retries != 0 { + retries-- + original, err = err, nil + conn.Close() // Avoid leaking the connection. + srvr, err = oc.selectServer(ctx) + if err != nil { + return original + } + conn, err := srvr.Connection(ctx) + // We know that oc.RetryableFn is not nil because retryable is a valid retryable + // value. + if err != nil || conn == nil || oc.RetryableFn(conn.Description()) == RetryWrite { + return original + } + defer conn.Close() // Avoid leaking the new connection. + continue + } + return err + case nil: + if perr != nil { + return perr + } + default: + return err + } + + if batching && len(oc.Batches.Documents) > 0 { + if retryable == RetryWrite && oc.Client != nil { + oc.Client.IncrementTxnNumber() + if oc.RetryMode != nil && *oc.RetryMode == RetryOncePerCommand { + retries = 1 + } + } + oc.Batches.ClearBatch() + continue + } + break + } + return nil +} + +// roundTrip writes a wiremessage to the connection, reads a wiremessage, and then decodes the +// response into a result or an error. The wm parameter is reused when reading the wiremessage. +func (OperationContext) roundTrip(ctx context.Context, conn Connection, wm []byte) (bsoncore.Document, error) { + err := conn.WriteWireMessage(ctx, wm) + if err != nil { + return nil, Error{Message: err.Error(), Labels: []string{TransientTransactionError, NetworkError}} + } + + res, err := conn.ReadWireMessage(ctx, wm[:0]) + if err != nil { + err = Error{Message: err.Error(), Labels: []string{TransientTransactionError, NetworkError}} + } + return decodeResult(res) +} + +func (oc OperationContext) createWireMessage(dst []byte, desc description.SelectedServer) ([]byte, error) { + if desc.WireVersion == nil || desc.WireVersion.Max < wiremessage.OpmsgWireVersion { + return oc.createQueryWireMessage(dst, desc) + } + return oc.createMsgWireMessage(dst, desc) +} + +func (oc OperationContext) createQueryWireMessage(dst []byte, desc description.SelectedServer) ([]byte, error) { + flags := slaveOK(desc, nil) + var wmindex int32 + wmindex, dst = wiremessagex.AppendHeaderStart(dst, wiremessage.NextRequestID(), 0, wiremessage.OpQuery) + dst = wiremessagex.AppendQueryFlags(dst, flags) + // FullCollectionName + dst = append(dst, oc.Database...) + dst = append(dst, dollarCmd[:]...) + dst = append(dst, 0x00) + dst = wiremessagex.AppendQueryNumberToSkip(dst, 0) + dst = wiremessagex.AppendQueryNumberToReturn(dst, -1) + + wrapper := int32(-1) + rp := createReadPref(oc.ReadPreference, desc.Server.Kind, desc.Kind, true) + if len(rp) > 0 { + wrapper, dst = bsoncore.AppendDocumentStart(dst) + dst = bsoncore.AppendHeader(dst, bsontype.EmbeddedDocument, "$query") + } + idx, dst := bsoncore.AppendDocumentStart(dst) + dst, err := oc.CommandFn(dst, desc) + if err != nil { + return dst, err + } + + if oc.Batches != nil && len(oc.Batches.Current) > 0 { + aidx, dst := bsoncore.AppendArrayElementStart(dst, oc.Batches.Identifier) + for i, doc := range oc.Batches.Current { + dst = bsoncore.AppendDocumentElement(dst, strconv.Itoa(i), doc) + } + dst, _ = bsoncore.AppendArrayEnd(dst, aidx) + } + + dst, err = addReadConcern(dst, oc.ReadConcern, oc.Client, desc) + if err != nil { + return dst, err + } + + dst, err = addWriteConcern(dst, oc.WriteConcern) + if err != nil { + return dst, err + } + + dst, err = addSession(dst, oc.Client, desc) + if err != nil { + return dst, err + } + + // TODO(GODRIVER-617): This should likely be part of addSession, but we need to ensure that we + // either turn off RetryWrite when we are doing a retryable read or that we pass in RetryType to + // addSession. We should also only be adding this if the connection supports sessions, but I + // think that's a given if we've set RetryWrite to true. + if oc.RetryType == RetryWrite && oc.Client != nil && oc.Client.RetryWrite { + dst = bsoncore.AppendInt64Element(dst, "txnNumber", oc.Client.TxnNumber) + } + + dst = addClusterTime(dst, oc.Client, oc.Clock, desc) + + dst, _ = bsoncore.AppendDocumentEnd(dst, idx) + + if len(rp) > 0 { + var err error + dst = bsoncore.AppendDocumentElement(dst, "$readPreference", rp) + dst, err = bsoncore.AppendDocumentEnd(dst, wrapper) + if err != nil { + return dst, err + } + } + + return bsoncore.UpdateLength(dst, wmindex, int32(len(dst[wmindex:]))), nil +} + +func (oc OperationContext) createMsgWireMessage(dst []byte, desc description.SelectedServer) ([]byte, error) { + // TODO(GODRIVER-617): We need to figure out how to include the writeconcern here so that we can + // set the moreToCome bit. + var flags wiremessage.MsgFlag + var wmindex int32 + wmindex, dst = wiremessagex.AppendHeaderStart(dst, wiremessage.NextRequestID(), 0, wiremessage.OpMsg) + dst = wiremessagex.AppendMsgFlags(dst, flags) + // Body + dst = wiremessagex.AppendMsgSectionType(dst, wiremessage.SingleDocument) + + idx, dst := bsoncore.AppendDocumentStart(dst) + + dst, err := oc.CommandFn(dst, desc) + if err != nil { + return dst, err + } + dst, err = addReadConcern(dst, oc.ReadConcern, oc.Client, desc) + if err != nil { + return dst, err + } + dst, err = addWriteConcern(dst, oc.WriteConcern) + if err != nil { + return dst, err + } + + dst, err = addSession(dst, oc.Client, desc) + if err != nil { + return dst, err + } + + // TODO(GODRIVER-617): This should likely be part of addSession, but we need to ensure that we + // either turn off RetryWrite when we are doing a retryable read or that we pass in RetryType to + // addSession. We should also only be adding this if the connection supports sessions, but I + // think that's a given if we've set RetryWrite to true. + if oc.RetryType == RetryWrite && oc.Client != nil && oc.Client.RetryWrite { + dst = bsoncore.AppendInt64Element(dst, "txnNumber", oc.Client.TxnNumber) + } + + dst = addClusterTime(dst, oc.Client, oc.Clock, desc) + + dst = bsoncore.AppendStringElement(dst, "$db", oc.Database) + rp := createReadPref(oc.ReadPreference, desc.Server.Kind, desc.Kind, false) + if len(rp) > 0 { + dst = bsoncore.AppendDocumentElement(dst, "$readPreference", rp) + } + + dst, _ = bsoncore.AppendDocumentEnd(dst, idx) + + if oc.Batches != nil && len(oc.Batches.Current) > 0 { + dst = wiremessagex.AppendMsgSectionType(dst, wiremessage.DocumentSequence) + idx, dst = bsoncore.ReserveLength(dst) + + dst = append(dst, oc.Batches.Identifier...) + dst = append(dst, 0x00) + + for _, doc := range oc.Batches.Current { + dst = append(dst, doc...) + } + + dst = bsoncore.UpdateLength(dst, idx, int32(len(dst[idx:]))) + } + + return bsoncore.UpdateLength(dst, wmindex, int32(len(dst[wmindex:]))), nil +} + +// Retryable writes are supported if the server supports sessions, the operation is not +// within a transaction, and the write is acknowledged +func retrySupported( + tdesc description.Topology, + desc description.Server, + sess *session.Client, + wc *writeconcern.WriteConcern, +) bool { + return (tdesc.SessionTimeoutMinutes != 0 && tdesc.Kind != description.Single) && + description.SessionsSupported(desc.WireVersion) && + sess != nil && + !(sess.TransactionInProgress() || sess.TransactionStarting()) && + writeconcern.AckWrite(wc) +} + +// createReadPrefSelector will either return the first non-nil selector or create a read preference +// selector with the provided read preference. +func createReadPrefSelector(rp *readpref.ReadPref, selectors ...description.ServerSelector) description.ServerSelector { + for _, selector := range selectors { + if selector != nil { + return selector + } + } + if rp == nil { + rp = readpref.Primary() + } + return description.CompositeSelector([]description.ServerSelector{ + description.ReadPrefSelector(rp), + description.LatencySelector(15 * time.Millisecond), + }) +} + +func addReadConcern(dst []byte, rc *readconcern.ReadConcern, client *session.Client, desc description.SelectedServer) ([]byte, error) { + // Starting transaction's read concern overrides all others + if client != nil && client.TransactionStarting() && client.CurrentRc != nil { + rc = client.CurrentRc + } + + // start transaction must append afterclustertime IF causally consistent and operation time exists + if rc == nil && client != nil && client.TransactionStarting() && client.Consistent && client.OperationTime != nil { + rc = readconcern.New() + } + + if rc == nil { + return dst, nil + } + + _, data, err := rc.MarshalBSONValue() // always returns a document + if err != nil { + return dst, err + } + + if description.SessionsSupported(desc.WireVersion) && client != nil && client.Consistent && client.OperationTime != nil { + data = data[:len(data)-1] // remove the null byte + data = bsoncore.AppendTimestampElement(data, "afterClusterTime", client.OperationTime.T, client.OperationTime.I) + data, _ = bsoncore.AppendDocumentEnd(data, 0) + } + + return bsoncore.AppendDocumentElement(dst, "readConcern", data), nil +} + +func addWriteConcern(dst []byte, wc *writeconcern.WriteConcern) ([]byte, error) { + if wc == nil { + return dst, nil + } + + t, data, err := wc.MarshalBSONValue() + if err == writeconcern.ErrEmptyWriteConcern { + return dst, nil + } + if err != nil { + return dst, err + } + + return append(bsoncore.AppendHeader(dst, t, "writeConcern"), data...), nil +} + +func addSession(dst []byte, client *session.Client, desc description.SelectedServer) ([]byte, error) { + if client == nil || !description.SessionsSupported(desc.WireVersion) || desc.SessionTimeoutMinutes == 0 { + return dst, nil + } + if client.Terminated { + return dst, session.ErrSessionEnded + } + lsid, _ := client.SessionID.MarshalBSON() + dst = bsoncore.AppendDocumentElement(dst, "lsid", lsid) + + if client.TransactionRunning() || client.RetryingCommit { + dst = bsoncore.AppendInt64Element(dst, "txnNumber", client.TxnNumber) + if client.TransactionStarting() { + dst = bsoncore.AppendBooleanElement(dst, "startTransaction", true) + } + dst = bsoncore.AppendBooleanElement(dst, "autocommit", false) + } + + client.ApplyCommand(desc.Server) + + return dst, nil +} + +func addClusterTime(dst []byte, client *session.Client, clock *session.ClusterClock, desc description.SelectedServer) []byte { + if (clock == nil && client == nil) || !description.SessionsSupported(desc.WireVersion) { + return dst + } + clusterTime := clock.GetClusterTime() + if client != nil { + clusterTime = session.MaxClusterTime(clusterTime, client.ClusterTime) + } + if clusterTime == nil { + return dst + } + val, err := clusterTime.LookupErr("$clusterTime") + if err != nil { + return dst + } + return append(bsoncore.AppendHeader(dst, val.Type, "$clusterTime"), val.Value...) + // return bsoncore.AppendDocumentElement(dst, "$clusterTime", clusterTime) +} + +func responseClusterTime(response bsoncore.Document) bsoncore.Document { + clusterTime, err := response.LookupErr("$clusterTime") + if err != nil { + // $clusterTime not included by the server + return nil + } + idx, doc := bsoncore.AppendDocumentStart(nil) + doc = bsoncore.AppendHeader(doc, clusterTime.Type, "$clusterTime") + doc = append(doc, clusterTime.Data...) + doc, _ = bsoncore.AppendDocumentEnd(doc, idx) + return doc +} + +func updateClusterTimes(sess *session.Client, clock *session.ClusterClock, response bsoncore.Document) error { + clusterTime := responseClusterTime(response) + if clusterTime == nil { + return nil + } + + if sess != nil { + err := sess.AdvanceClusterTime(bson.Raw(clusterTime)) + if err != nil { + return err + } + } + + if clock != nil { + clock.AdvanceClusterTime(bson.Raw(clusterTime)) + } + + return nil +} + +func updateOperationTime(sess *session.Client, response bsoncore.Document) error { + if sess == nil { + return nil + } + + opTimeElem, err := response.LookupErr("operationTime") + if err != nil { + // operationTime not included by the server + return nil + } + + t, i := opTimeElem.Timestamp() + return sess.AdvanceOperationTime(&primitive.Timestamp{ + T: t, + I: i, + }) +} + +func createReadPref(rp *readpref.ReadPref, serverKind description.ServerKind, topologyKind description.TopologyKind, isOpQuery bool) bsoncore.Document { + idx, doc := bsoncore.AppendDocumentStart(nil) + + if rp == nil { + if topologyKind == description.Single && serverKind != description.Mongos { + doc = bsoncore.AppendStringElement(doc, "mode", "primaryPreferred") + doc, _ = bsoncore.AppendDocumentEnd(doc, idx) + return doc + } + return nil + } + + switch rp.Mode() { + case readpref.PrimaryMode: + if serverKind == description.Mongos { + return nil + } + if topologyKind == description.Single { + doc = bsoncore.AppendStringElement(doc, "mode", "primaryPreferred") + doc, _ = bsoncore.AppendDocumentEnd(doc, idx) + return doc + } + doc = bsoncore.AppendStringElement(doc, "mode", "primary") + case readpref.PrimaryPreferredMode: + doc = bsoncore.AppendStringElement(doc, "mode", "primaryPreferred") + case readpref.SecondaryPreferredMode: + _, ok := rp.MaxStaleness() + if serverKind == description.Mongos && isOpQuery && !ok && len(rp.TagSets()) == 0 { + return nil + } + doc = bsoncore.AppendStringElement(doc, "mode", "secondaryPreferred") + case readpref.SecondaryMode: + doc = bsoncore.AppendStringElement(doc, "mode", "secondary") + case readpref.NearestMode: + doc = bsoncore.AppendStringElement(doc, "mode", "nearest") + } + + sets := make([]bsoncore.Document, 0, len(rp.TagSets())) + for _, ts := range rp.TagSets() { + if len(ts) == 0 { + continue + } + i, set := bsoncore.AppendDocumentStart(nil) + for _, t := range ts { + set = bsoncore.AppendStringElement(set, t.Name, t.Value) + } + set, _ = bsoncore.AppendDocumentEnd(set, i) + sets = append(sets, set) + } + if len(sets) > 0 { + var aidx int32 + aidx, doc = bsoncore.AppendArrayElementStart(doc, "tags") + for i, set := range sets { + doc = bsoncore.AppendDocumentElement(doc, strconv.Itoa(i), set) + } + doc, _ = bsoncore.AppendArrayEnd(doc, aidx) + } + + if d, ok := rp.MaxStaleness(); ok { + doc = bsoncore.AppendInt32Element(doc, "maxStalenessSeconds", int32(d.Seconds())) + } + + return doc +} + +func slaveOK(desc description.SelectedServer, rp []byte) wiremessage.QueryFlag { + if desc.Kind == description.Single && desc.Server.Kind != description.Mongos { + return wiremessage.SlaveOK + } + + if mode, ok := bsoncore.Document(rp).Lookup("mode").StringValueOK(); ok && mode != "primary" { + return wiremessage.SlaveOK + } + + return 0 +} + +func decodeResult(wm []byte) (bsoncore.Document, error) { + wmLength := len(wm) + length, _, _, opcode, wm, ok := wiremessagex.ReadHeader(wm) + if !ok || int(length) > wmLength { + return nil, errors.New("malformed wire message: insufficient bytes") + } + + wm = wm[:wmLength-16] // constrain to just this wiremessage, incase there are multiple in the slice + + switch opcode { + case wiremessage.OpReply: + var flags wiremessage.ReplyFlag + flags, wm, ok = wiremessagex.ReadReplyFlags(wm) + if !ok { + return nil, errors.New("malformed OP_REPLY: missing flags") + } + _, wm, ok = wiremessagex.ReadReplyCursorID(wm) + if !ok { + return nil, errors.New("malformed OP_REPLY: missing cursorID") + } + _, wm, ok = wiremessagex.ReadReplyStartingFrom(wm) + if !ok { + return nil, errors.New("malformed OP_REPLY: missing startingFrom") + } + var numReturned int32 + numReturned, wm, ok = wiremessagex.ReadReplyNumberReturned(wm) + if !ok { + return nil, errors.New("malformed OP_REPLY: missing numberReturned") + } + if numReturned == 0 { + return nil, ErrNoDocCommandResponse + } + if numReturned > 1 { + return nil, ErrMultiDocCommandResponse + } + var rdr bsoncore.Document + rdr, rem, ok := wiremessagex.ReadReplyDocument(wm) + if !ok || len(rem) > 0 { + return nil, NewCommandResponseError("malformed OP_REPLY: NumberReturned does not match number of documents returned", nil) + } + err := rdr.Validate() + if err != nil { + return nil, NewCommandResponseError("malformed OP_REPLY: invalid document", err) + } + if flags&wiremessage.QueryFailure == wiremessage.QueryFailure { + return nil, QueryFailureError{ + Message: "command failure", + Response: rdr, + } + } + + return rdr, extractError(rdr) + case wiremessage.OpMsg: + _, wm, ok = wiremessagex.ReadMsgFlags(wm) + if !ok { + return nil, errors.New("malformed wire message: missing OP_MSG flags") + } + + var res bsoncore.Document + for len(wm) > 0 { + var stype wiremessage.SectionType + stype, wm, ok = wiremessagex.ReadMsgSectionType(wm) + if !ok { + return nil, errors.New("malformed wire message: insuffienct bytes to read section type") + } + + switch stype { + case wiremessage.SingleDocument: + res, wm, ok = wiremessagex.ReadMsgSectionSingleDocument(wm) + if !ok { + return nil, errors.New("malformed wire message: insufficient bytes to read single document") + } + case wiremessage.DocumentSequence: + // TODO(GODRIVER-617): Implement document sequence returns. + _, _, wm, ok = wiremessagex.ReadMsgSectionDocumentSequence(wm) + if !ok { + return nil, errors.New("malformed wire message: insufficient bytes to read document sequence") + } + default: + return nil, fmt.Errorf("malformed wire message: uknown section type %v", stype) + } + } + + err := res.Validate() + if err != nil { + return nil, NewCommandResponseError("malformed OP_MSG: invalid document", err) + } + + return res, extractError(res) + default: + return nil, fmt.Errorf("cannot decode result from %s", opcode) + } +} diff --git a/x/mongo/driver/wiremessage/wiremessage.go b/x/mongo/driver/wiremessage/wiremessage.go new file mode 100644 index 0000000000..43786b8a02 --- /dev/null +++ b/x/mongo/driver/wiremessage/wiremessage.go @@ -0,0 +1,236 @@ +package wiremessage + +import ( + "bytes" + + "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/x/network/wiremessage" +) + +// WireMessage represents a MongoDB wire message in binary form. +type WireMessage []byte + +// OpCode represents a MongoDB wire protocol opcode. +type OpCode = wiremessage.OpCode + +// AppendHeaderStart appends a header to the dst slice and returns an index where the wire message +// starts in dst and the updated slice. +func AppendHeaderStart(dst []byte, reqid, respto int32, opcode OpCode) (index int32, b []byte) { + index, dst = bsoncore.ReserveLength(dst) + dst = appendi32(dst, reqid) + dst = appendi32(dst, respto) + dst = appendi32(dst, int32(opcode)) + return index, dst +} + +// ReadHeader reads a wire message header from src. +func ReadHeader(src []byte) (length, requestID, responseTo int32, opcode OpCode, rem []byte, ok bool) { + if len(src) < 16 { + return 0, 0, 0, 0, src, false + } + length = (int32(src[0]) | int32(src[1])<<8 | int32(src[2])<<16 | int32(src[3])<<24) + requestID = (int32(src[4]) | int32(src[5])<<8 | int32(src[6])<<16 | int32(src[7])<<24) + responseTo = (int32(src[8]) | int32(src[9])<<8 | int32(src[10])<<16 | int32(src[11])<<24) + opcode = OpCode(int32(src[12]) | int32(src[13])<<8 | int32(src[14])<<16 | int32(src[15])<<24) + return length, requestID, responseTo, opcode, src[16:], true +} + +// AppendQueryFlags appends the flags for an OP_QUERY wire message. +func AppendQueryFlags(dst []byte, flags wiremessage.QueryFlag) []byte { + return appendi32(dst, int32(flags)) +} + +// AppendMsgFlags appends the flags for an OP_MSG wire message. +func AppendMsgFlags(dst []byte, flags wiremessage.MsgFlag) []byte { + return appendi32(dst, int32(flags)) +} + +// AppendReplyFlags appends the flags for an OP_REPLY wire message. +func AppendReplyFlags(dst []byte, flags wiremessage.ReplyFlag) []byte { + return appendi32(dst, int32(flags)) +} + +// AppendMsgSectionType appends the section type to dst. +func AppendMsgSectionType(dst []byte, stype wiremessage.SectionType) []byte { + return append(dst, byte(stype)) +} + +// AppendQueryFullCollectionName appends the full collection name to dst. +func AppendQueryFullCollectionName(dst []byte, ns string) []byte { + return appendCString(dst, ns) +} + +// AppendQueryNumberToSkip appends the number to skip to dst. +func AppendQueryNumberToSkip(dst []byte, skip int32) []byte { + return appendi32(dst, skip) +} + +// AppendQueryNumberToReturn appends the number to return to dst. +func AppendQueryNumberToReturn(dst []byte, nor int32) []byte { + return appendi32(dst, nor) +} + +// AppendReplyCursorID appends the cursor ID to dst. +func AppendReplyCursorID(dst []byte, id int64) []byte { + return appendi64(dst, id) +} + +// AppendReplyStartingFrom appends the starting from field to dst. +func AppendReplyStartingFrom(dst []byte, sf int32) []byte { + return appendi32(dst, sf) +} + +// AppendReplyNumberReturned appends the number returned to dst. +func AppendReplyNumberReturned(dst []byte, nr int32) []byte { + return appendi32(dst, nr) +} + +// ReadMsgFlags reads the OP_MSG flags from src. +func ReadMsgFlags(src []byte) (flags wiremessage.MsgFlag, rem []byte, ok bool) { + i32, rem, ok := readi32(src) + return wiremessage.MsgFlag(i32), rem, ok +} + +// ReadMsgSectionType reads the section type from src. +func ReadMsgSectionType(src []byte) (stype wiremessage.SectionType, rem []byte, ok bool) { + if len(src) < 1 { + return 0, src, false + } + return wiremessage.SectionType(src[0]), src[1:], true +} + +// ReadMsgSectionSingleDocument reads a single document from src. +func ReadMsgSectionSingleDocument(src []byte) (doc bsoncore.Document, rem []byte, ok bool) { + return bsoncore.ReadDocument(src) +} + +// ReadMsgSectionDocumentSequence reads an identifier and document sequence from src. +func ReadMsgSectionDocumentSequence(src []byte) (identifier string, docs []bsoncore.Document, rem []byte, ok bool) { + length, rem, ok := readi32(src) + if !ok || int(length) > len(src) { + return "", nil, rem, false + } + + rem, ret := rem[:length-4], rem[length-4:] // reslice so we can just iterate a loop later + + identifier, rem, ok = readcstring(rem) + if !ok { + return "", nil, rem, false + } + + docs = make([]bsoncore.Document, 0) + var doc bsoncore.Document + for { + doc, rem, ok = bsoncore.ReadDocument(rem) + if !ok { + break + } + docs = append(docs, doc) + } + if len(rem) > 0 { + return "", nil, append(rem, ret...), false + } + + return identifier, docs, ret, true +} + +// ReadMsgChecksum reads a checksum from src. +func ReadMsgChecksum(src []byte) (checksum uint32, rem []byte, ok bool) { + i32, rem, ok := readi32(src) + return uint32(i32), rem, ok +} + +// ReadQueryFlags reads OP_QUERY flags from src. +func ReadQueryFlags(src []byte) (flags wiremessage.QueryFlag, rem []byte, ok bool) { + i32, rem, ok := readi32(src) + return wiremessage.QueryFlag(i32), rem, ok +} + +// ReadQueryFullCollectionName reads the full collection name from src. +func ReadQueryFullCollectionName(src []byte) (collname string, rem []byte, ok bool) { + return readcstring(src) +} + +// ReadQueryNumberToSkip reads the number to skip from src. +func ReadQueryNumberToSkip(src []byte) (nts int32, rem []byte, ok bool) { + return readi32(src) +} + +// ReadQueryNumberToReturn reads the number to return from src. +func ReadQueryNumberToReturn(src []byte) (ntr int32, rem []byte, ok bool) { + return readi32(src) +} + +// ReadQueryQuery reads the query from src. +func ReadQueryQuery(src []byte) (query bsoncore.Document, rem []byte, ok bool) { + return bsoncore.ReadDocument(src) +} + +// ReadQueryReturnFieldsSelector reads a return fields selector document from src. +func ReadQueryReturnFieldsSelector(src []byte) (rfs bsoncore.Document, rem []byte, ok bool) { + return bsoncore.ReadDocument(src) +} + +// ReadReplyFlags reads OP_REPLY flags from src. +func ReadReplyFlags(src []byte) (flags wiremessage.ReplyFlag, rem []byte, ok bool) { + i32, rem, ok := readi32(src) + return wiremessage.ReplyFlag(i32), rem, ok +} + +// ReadReplyCursorID reads a cursor ID from src. +func ReadReplyCursorID(src []byte) (cursorID int64, rem []byte, ok bool) { + return readi64(src) +} + +// ReadReplyStartingFrom reads the starting from from src. +func ReadReplyStartingFrom(src []byte) (startingFrom int32, rem []byte, ok bool) { + return readi32(src) +} + +// ReadReplyNumberReturned reads the numbered returned from src. +func ReadReplyNumberReturned(src []byte) (numberReturned int32, rem []byte, ok bool) { + return readi32(src) +} + +// ReadReplyDocument reads a reply document from src. +func ReadReplyDocument(src []byte) (doc bsoncore.Document, rem []byte, ok bool) { + return bsoncore.ReadDocument(src) +} + +func appendi32(dst []byte, i32 int32) []byte { + return append(dst, byte(i32), byte(i32>>8), byte(i32>>16), byte(i32>>24)) +} + +func appendi64(b []byte, i int64) []byte { + return append(b, byte(i), byte(i>>8), byte(i>>16), byte(i>>24), byte(i>>32), byte(i>>40), byte(i>>48), byte(i>>56)) +} + +func appendCString(b []byte, str string) []byte { + b = append(b, str...) + return append(b, 0x00) +} + +func readi32(src []byte) (int32, []byte, bool) { + if len(src) < 4 { + return 0, src, false + } + + return (int32(src[0]) | int32(src[1])<<8 | int32(src[2])<<16 | int32(src[3])<<24), src[4:], true +} + +func readi64(src []byte) (int64, []byte, bool) { + if len(src) < 8 { + return 0, src, false + } + i64 := (int64(src[0]) | int64(src[1])<<8 | int64(src[2])<<16 | int64(src[3])<<24 | + int64(src[4])<<32 | int64(src[5])<<40 | int64(src[6])<<48 | int64(src[7])<<56) + return i64, src[8:], true +} + +func readcstring(src []byte) (string, []byte, bool) { + idx := bytes.IndexByte(src, 0x00) + if idx < 0 { + return "", src, false + } + return string(src[:idx]), src[idx+1:], true +} diff --git a/x/mongo/driverlegacy/abort_transaction.go b/x/mongo/driverlegacy/abort_transaction.go index 692a0ee23b..8a0026f953 100644 --- a/x/mongo/driverlegacy/abort_transaction.go +++ b/x/mongo/driverlegacy/abort_transaction.go @@ -61,7 +61,7 @@ func abortTransaction( return result.TransactionResult{}, oldErr } - conn, err := ss.Connection(ctx) + conn, err := ss.ConnectionLegacy(ctx) if err != nil { if oldErr != nil { return result.TransactionResult{}, oldErr diff --git a/x/mongo/driverlegacy/aggregate.go b/x/mongo/driverlegacy/aggregate.go index ceff3dc145..26661e5072 100644 --- a/x/mongo/driverlegacy/aggregate.go +++ b/x/mongo/driverlegacy/aggregate.go @@ -59,7 +59,7 @@ func Aggregate( } desc := ss.Description() - conn, err := ss.Connection(ctx) + conn, err := ss.ConnectionLegacy(ctx) if err != nil { return nil, err } diff --git a/x/mongo/driverlegacy/auth/auth.go b/x/mongo/driverlegacy/auth/auth.go index 4464e3870c..8baf41dae0 100644 --- a/x/mongo/driverlegacy/auth/auth.go +++ b/x/mongo/driverlegacy/auth/auth.go @@ -10,11 +10,9 @@ import ( "context" "fmt" + "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/network/address" - "go.mongodb.org/mongo-driver/x/network/command" - "go.mongodb.org/mongo-driver/x/network/connection" "go.mongodb.org/mongo-driver/x/network/description" - "go.mongodb.org/mongo-driver/x/network/wiremessage" ) // AuthenticatorFactory constructs an authenticator. @@ -46,51 +44,6 @@ func RegisterAuthenticatorFactory(name string, factory AuthenticatorFactory) { authFactories[name] = factory } -// // Opener returns a connection opener that will open and authenticate the connection. -// func Opener(opener conn.Opener, authenticator Authenticator) conn.Opener { -// return func(ctx context.Context, addr model.Addr, opts ...conn.Option) (conn.Connection, error) { -// return NewConnection(ctx, authenticator, opener, addr, opts...) -// } -// } -// -// // NewConnection opens a connection and authenticates it. -// func NewConnection(ctx context.Context, authenticator Authenticator, opener conn.Opener, addr model.Addr, opts ...conn.Option) (conn.Connection, error) { -// conn, err := opener(ctx, addr, opts...) -// if err != nil { -// if conn != nil { -// // Ignore any error that occurs since we're already returning a different one. -// _ = conn.Close() -// } -// return nil, err -// } -// -// err = authenticator.Auth(ctx, conn) -// if err != nil { -// // Ignore any error that occurs since we're already returning a different one. -// _ = conn.Close() -// return nil, err -// } -// -// return conn, nil -// } - -// Configurer creates a connection configurer for the given authenticator. -// -// TODO(skriptble): Fully implement this once this package is moved over to the new connection type. -// func Configurer(configurer connection.Configurer, authenticator Authenticator) connection.Configurer { -// return connection.ConfigurerFunc(func(ctx context.Context, conn connection.Connection) (connection.Connection, error) { -// err := authenticator.Auth(ctx, conn) -// if err != nil { -// conn.Close() -// return nil, err -// } -// if configurer == nil { -// return conn, nil -// } -// return configurer.Configure(ctx, conn) -// }) -// } - // HandshakeOptions packages options that can be passed to the Handshaker() // function. DBUser is optional but must be of the form ; // if non-empty, then the connection will do SASL mechanism negotiation. @@ -103,13 +56,13 @@ type HandshakeOptions struct { } // Handshaker creates a connection handshaker for the given authenticator. -func Handshaker(h connection.Handshaker, options *HandshakeOptions) connection.Handshaker { - return connection.HandshakerFunc(func(ctx context.Context, addr address.Address, rw wiremessage.ReadWriter) (description.Server, error) { - desc, err := (&command.Handshake{ - Client: command.ClientDoc(options.AppName), - Compressors: options.Compressors, - SaslSupportedMechs: options.DBUser, - }).Handshake(ctx, addr, rw) +func Handshaker(h driver.Handshaker, options *HandshakeOptions) driver.Handshaker { + return driver.HandshakerFunc(func(ctx context.Context, addr address.Address, conn driver.Connection) (description.Server, error) { + desc, err := driver.IsMaster(). + AppName(options.AppName). + Compressors(options.Compressors). + SASLSupportedMechs(options.DBUser). + Handshake(ctx, addr, conn) if err != nil { return description.Server{}, newAuthError("handshake failure", err) @@ -125,7 +78,7 @@ func Handshaker(h connection.Handshaker, options *HandshakeOptions) connection.H } } if performAuth(desc) && options.Authenticator != nil { - err = options.Authenticator.Auth(ctx, desc, rw) + err = options.Authenticator.Auth(ctx, desc, conn) if err != nil { return description.Server{}, newAuthError("auth error", err) } @@ -134,14 +87,14 @@ func Handshaker(h connection.Handshaker, options *HandshakeOptions) connection.H if h == nil { return desc, nil } - return h.Handshake(ctx, addr, rw) + return h.Handshake(ctx, addr, conn) }) } // Authenticator handles authenticating a connection. type Authenticator interface { // Auth authenticates the connection. - Auth(context.Context, description.Server, wiremessage.ReadWriter) error + Auth(context.Context, description.Server, driver.Connection) error } func newAuthError(msg string, inner error) error { diff --git a/x/mongo/driverlegacy/auth/auth_test.go b/x/mongo/driverlegacy/auth/auth_test.go index f047586b6c..7650cff04c 100644 --- a/x/mongo/driverlegacy/auth/auth_test.go +++ b/x/mongo/driverlegacy/auth/auth_test.go @@ -9,10 +9,10 @@ package auth_test import ( "testing" - "reflect" - + "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/require" - "go.mongodb.org/mongo-driver/x/bsonx" + "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" + wiremessagex "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" . "go.mongodb.org/mongo-driver/x/mongo/driverlegacy/auth" "go.mongodb.org/mongo-driver/x/network/wiremessage" ) @@ -47,21 +47,81 @@ func TestCreateAuthenticator(t *testing.T) { } } -func compareResponses(t *testing.T, wm wiremessage.WireMessage, expectedPayload bsonx.Doc, dbName string) { - switch converted := wm.(type) { - case wiremessage.Query: - payloadBytes, err := expectedPayload.MarshalBSON() - if err != nil { - t.Fatalf("couldn't marshal query bson: %v", err) +func compareResponses(t *testing.T, wm []byte, expectedPayload bsoncore.Document, dbName string) { + _, _, _, opcode, wm, ok := wiremessagex.ReadHeader(wm) + if !ok { + t.Fatalf("wiremessage is too short to unmarshal") + } + var actualPayload bsoncore.Document + switch opcode { + case wiremessage.OpQuery: + _, wm, ok := wiremessagex.ReadQueryFlags(wm) + if !ok { + t.Fatalf("wiremessage is too short to unmarshal") + } + _, wm, ok = wiremessagex.ReadQueryFullCollectionName(wm) + if !ok { + t.Fatalf("wiremessage is too short to unmarshal") + } + _, wm, ok = wiremessagex.ReadQueryNumberToSkip(wm) + if !ok { + t.Fatalf("wiremessage is too short to unmarshal") + } + _, wm, ok = wiremessagex.ReadQueryNumberToReturn(wm) + if !ok { + t.Fatalf("wiremessage is too short to unmarshal") + } + actualPayload, _, ok = wiremessagex.ReadQueryQuery(wm) + if !ok { + t.Fatalf("wiremessage is too short to unmarshal") } - require.True(t, reflect.DeepEqual([]byte(converted.Query), payloadBytes)) - case wiremessage.Msg: - msgPayload := append(expectedPayload, bsonx.Elem{"$db", bsonx.String(dbName)}) - payloadBytes, err := msgPayload.MarshalBSON() + case wiremessage.OpMsg: + // Append the $db field. + elems, err := expectedPayload.Elements() if err != nil { - t.Fatalf("couldn't marshal msg bson: %v", err) + t.Fatalf("expectedPayload is not valid: %v", err) } + elems = append(elems, bsoncore.AppendStringElement(nil, "$db", dbName)) + elems = append(elems, bsoncore.AppendDocumentElement(nil, + "$readPreference", + bsoncore.BuildDocumentFromElements(nil, bsoncore.AppendStringElement(nil, "mode", "primaryPreferred")), + )) + bslc := make([][]byte, 0, len(elems)) // BuildDocumentFromElements takes a [][]byte, not a []bsoncore.Element. + for _, elem := range elems { + bslc = append(bslc, elem) + } + expectedPayload = bsoncore.BuildDocumentFromElements(nil, bslc...) + + _, wm, ok := wiremessagex.ReadMsgFlags(wm) + if !ok { + t.Fatalf("wiremessage is too short to unmarshal") + } + loop: + for { + var stype wiremessage.SectionType + stype, wm, ok = wiremessagex.ReadMsgSectionType(wm) + if !ok { + t.Fatalf("wiremessage is too short to unmarshal") + break + } + switch stype { + case wiremessage.DocumentSequence: + _, _, wm, ok = wiremessagex.ReadMsgSectionDocumentSequence(wm) + if !ok { + t.Fatalf("wiremessage is too short to unmarshal") + break loop + } + case wiremessage.SingleDocument: + actualPayload, wm, ok = wiremessagex.ReadMsgSectionSingleDocument(wm) + if !ok { + t.Fatalf("wiremessage is too short to unmarshal") + } + break loop + } + } + } - require.True(t, reflect.DeepEqual([]byte(converted.Sections[0].(wiremessage.SectionBody).Document), payloadBytes)) + if !cmp.Equal(actualPayload, expectedPayload) { + t.Errorf("Payloads don't match. got %v; want %v", actualPayload, expectedPayload) } } diff --git a/x/mongo/driverlegacy/auth/default.go b/x/mongo/driverlegacy/auth/default.go index e251f90360..d974b32a6b 100644 --- a/x/mongo/driverlegacy/auth/default.go +++ b/x/mongo/driverlegacy/auth/default.go @@ -9,8 +9,8 @@ package auth import ( "context" + "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/network/description" - "go.mongodb.org/mongo-driver/x/network/wiremessage" ) func newDefaultAuthenticator(cred *Cred) (Authenticator, error) { @@ -26,7 +26,7 @@ type DefaultAuthenticator struct { } // Auth authenticates the connection. -func (a *DefaultAuthenticator) Auth(ctx context.Context, desc description.Server, rw wiremessage.ReadWriter) error { +func (a *DefaultAuthenticator) Auth(ctx context.Context, desc description.Server, conn driver.Connection) error { var actual Authenticator var err error @@ -43,7 +43,7 @@ func (a *DefaultAuthenticator) Auth(ctx context.Context, desc description.Server return newAuthError("error creating authenticator", err) } - return actual.Auth(ctx, desc, rw) + return actual.Auth(ctx, desc, conn) } // If a server provides a list of supported mechanisms, we choose diff --git a/x/mongo/driverlegacy/auth/gssapi.go b/x/mongo/driverlegacy/auth/gssapi.go index 4fb27f5ed2..64f5b67ec1 100644 --- a/x/mongo/driverlegacy/auth/gssapi.go +++ b/x/mongo/driverlegacy/auth/gssapi.go @@ -14,9 +14,9 @@ import ( "fmt" "net" + "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driverlegacy/auth/internal/gssapi" "go.mongodb.org/mongo-driver/x/network/description" - "go.mongodb.org/mongo-driver/x/network/wiremessage" ) // GSSAPI is the mechanism name for GSSAPI. @@ -44,7 +44,7 @@ type GSSAPIAuthenticator struct { } // Auth authenticates the connection. -func (a *GSSAPIAuthenticator) Auth(ctx context.Context, desc description.Server, rw wiremessage.ReadWriter) error { +func (a *GSSAPIAuthenticator) Auth(ctx context.Context, desc description.Server, conn driver.Connection) error { target := desc.Addr.String() hostname, _, err := net.SplitHostPort(target) if err != nil { @@ -56,5 +56,5 @@ func (a *GSSAPIAuthenticator) Auth(ctx context.Context, desc description.Server, if err != nil { return newAuthError("error creating gssapi", err) } - return ConductSaslConversation(ctx, desc, rw, "$external", client) + return ConductSaslConversation(ctx, conn, "$external", client) } diff --git a/x/mongo/driverlegacy/auth/mongodbcr.go b/x/mongo/driverlegacy/auth/mongodbcr.go index ccc1b0f922..a643fbdade 100644 --- a/x/mongo/driverlegacy/auth/mongodbcr.go +++ b/x/mongo/driverlegacy/auth/mongodbcr.go @@ -14,10 +14,9 @@ import ( "io" "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/x/bsonx" - "go.mongodb.org/mongo-driver/x/network/command" + "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/network/description" - "go.mongodb.org/mongo-driver/x/network/wiremessage" ) // MONGODBCR is the mechanism name for MONGODB-CR. @@ -45,19 +44,20 @@ type MongoDBCRAuthenticator struct { // Auth authenticates the connection. // // The MONGODB-CR authentication mechanism is deprecated in MongoDB 4.0. -func (a *MongoDBCRAuthenticator) Auth(ctx context.Context, desc description.Server, rw wiremessage.ReadWriter) error { +func (a *MongoDBCRAuthenticator) Auth(ctx context.Context, _ description.Server, conn driver.Connection) error { db := a.DB if db == "" { db = defaultAuthDB } - cmd := command.Read{DB: db, Command: bsonx.Doc{{"getnonce", bsonx.Int32(1)}}} - ssdesc := description.SelectedServer{Server: desc} - rdr, err := cmd.RoundTrip(ctx, ssdesc, rw) + doc := bsoncore.BuildDocumentFromElements(nil, bsoncore.AppendInt32Element(nil, "getnonce", 1)) + cmd := driver.Command(doc).Database(db).Deployment(driver.SingleConnectionDeployment{conn}) + err := cmd.Execute(ctx) if err != nil { return newError(err, MONGODBCR) } + rdr := cmd.Result() var getNonceResult struct { Nonce string `bson:"nonce"` @@ -68,16 +68,14 @@ func (a *MongoDBCRAuthenticator) Auth(ctx context.Context, desc description.Serv return newAuthError("unmarshal error", err) } - cmd = command.Read{ - DB: db, - Command: bsonx.Doc{ - {"authenticate", bsonx.Int32(1)}, - {"user", bsonx.String(a.Username)}, - {"nonce", bsonx.String(getNonceResult.Nonce)}, - {"key", bsonx.String(a.createKey(getNonceResult.Nonce))}, - }, - } - _, err = cmd.RoundTrip(ctx, ssdesc, rw) + doc = bsoncore.BuildDocumentFromElements(nil, + bsoncore.AppendInt32Element(nil, "authenticate", 1), + bsoncore.AppendStringElement(nil, "user", a.Username), + bsoncore.AppendStringElement(nil, "nonce", getNonceResult.Nonce), + bsoncore.AppendStringElement(nil, "key", a.createKey(getNonceResult.Nonce)), + ) + cmd = driver.Command(doc).Database(db).Deployment(driver.SingleConnectionDeployment{conn}) + err = cmd.Execute(ctx) if err != nil { return newError(err, MONGODBCR) } diff --git a/x/mongo/driverlegacy/auth/mongodbcr_test.go b/x/mongo/driverlegacy/auth/mongodbcr_test.go index f1314c2dd5..23c4588d1e 100644 --- a/x/mongo/driverlegacy/auth/mongodbcr_test.go +++ b/x/mongo/driverlegacy/auth/mongodbcr_test.go @@ -12,11 +12,10 @@ import ( "strings" - "go.mongodb.org/mongo-driver/internal" - "go.mongodb.org/mongo-driver/x/bsonx" + "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/x/mongo/driver/drivertest" . "go.mongodb.org/mongo-driver/x/mongo/driverlegacy/auth" "go.mongodb.org/mongo-driver/x/network/description" - "go.mongodb.org/mongo-driver/x/network/wiremessage" ) func TestMongoDBCRAuthenticator_Fails(t *testing.T) { @@ -28,21 +27,26 @@ func TestMongoDBCRAuthenticator_Fails(t *testing.T) { Password: "pencil", } - resps := make(chan wiremessage.WireMessage, 2) - writeReplies(t, resps, bsonx.Doc{ - {"ok", bsonx.Int32(1)}, - {"nonce", bsonx.String("2375531c32080ae8")}, - }, bsonx.Doc{ - {"ok", bsonx.Int32(0)}, - }) + resps := make(chan []byte, 2) + writeReplies(t, resps, bsoncore.BuildDocumentFromElements(nil, + bsoncore.AppendInt32Element(nil, "ok", 1), + bsoncore.AppendStringElement(nil, "nonce", "2375531c32080ae8"), + ), bsoncore.BuildDocumentFromElements(nil, + bsoncore.AppendInt32Element(nil, "ok", 0), + )) - c := &internal.ChannelConn{Written: make(chan wiremessage.WireMessage, 2), ReadResp: resps} - - err := authenticator.Auth(context.Background(), description.Server{ + desc := description.Server{ WireVersion: &description.VersionRange{ Max: 6, }, - }, c) + } + c := &drivertest.ChannelConn{ + Written: make(chan []byte, 2), + ReadResp: resps, + Desc: desc, + } + + err := authenticator.Auth(context.Background(), desc, c) if err == nil { t.Fatalf("expected an error but got none") } @@ -62,21 +66,26 @@ func TestMongoDBCRAuthenticator_Succeeds(t *testing.T) { Password: "pencil", } - resps := make(chan wiremessage.WireMessage, 2) - writeReplies(t, resps, bsonx.Doc{ - {"ok", bsonx.Int32(1)}, - {"nonce", bsonx.String("2375531c32080ae8")}, - }, bsonx.Doc{ - {"ok", bsonx.Int32(1)}, - }) - - c := &internal.ChannelConn{Written: make(chan wiremessage.WireMessage, 2), ReadResp: resps} + resps := make(chan []byte, 2) + writeReplies(t, resps, bsoncore.BuildDocumentFromElements(nil, + bsoncore.AppendInt32Element(nil, "ok", 1), + bsoncore.AppendStringElement(nil, "nonce", "2375531c32080ae8"), + ), bsoncore.BuildDocumentFromElements(nil, + bsoncore.AppendInt32Element(nil, "ok", 1), + )) - err := authenticator.Auth(context.Background(), description.Server{ + desc := description.Server{ WireVersion: &description.VersionRange{ Max: 6, }, - }, c) + } + c := &drivertest.ChannelConn{ + Written: make(chan []byte, 2), + ReadResp: resps, + Desc: desc, + } + + err := authenticator.Auth(context.Background(), desc, c) if err != nil { t.Fatalf("expected no error but got \"%s\"", err) } @@ -85,25 +94,21 @@ func TestMongoDBCRAuthenticator_Succeeds(t *testing.T) { t.Fatalf("expected 2 messages to be sent but had %d", len(c.Written)) } - want := bsonx.Doc{{"getnonce", bsonx.Int32(1)}} + want := bsoncore.BuildDocumentFromElements(nil, bsoncore.AppendInt32Element(nil, "getnonce", 1)) compareResponses(t, <-c.Written, want, "source") - expectedAuthenticateDoc := bsonx.Doc{ - {"authenticate", bsonx.Int32(1)}, - {"user", bsonx.String("user")}, - {"nonce", bsonx.String("2375531c32080ae8")}, - {"key", bsonx.String("21742f26431831d5cfca035a08c5bdf6")}, - } + expectedAuthenticateDoc := bsoncore.BuildDocumentFromElements(nil, + bsoncore.AppendInt32Element(nil, "authenticate", 1), + bsoncore.AppendStringElement(nil, "user", "user"), + bsoncore.AppendStringElement(nil, "nonce", "2375531c32080ae8"), + bsoncore.AppendStringElement(nil, "key", "21742f26431831d5cfca035a08c5bdf6"), + ) compareResponses(t, <-c.Written, expectedAuthenticateDoc, "source") } -func writeReplies(t *testing.T, c chan wiremessage.WireMessage, docs ...bsonx.Doc) { +func writeReplies(t *testing.T, c chan []byte, docs ...bsoncore.Document) { for _, doc := range docs { - reply, err := internal.MakeReply(doc) - if err != nil { - t.Fatalf("error constructing reply: %v", err) - } - + reply := drivertest.MakeReply(doc) c <- reply } } diff --git a/x/mongo/driverlegacy/auth/plain.go b/x/mongo/driverlegacy/auth/plain.go index 4f76eacc93..9174859b63 100644 --- a/x/mongo/driverlegacy/auth/plain.go +++ b/x/mongo/driverlegacy/auth/plain.go @@ -9,8 +9,8 @@ package auth import ( "context" + "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/network/description" - "go.mongodb.org/mongo-driver/x/network/wiremessage" ) // PLAIN is the mechanism name for PLAIN. @@ -30,8 +30,8 @@ type PlainAuthenticator struct { } // Auth authenticates the connection. -func (a *PlainAuthenticator) Auth(ctx context.Context, desc description.Server, rw wiremessage.ReadWriter) error { - return ConductSaslConversation(ctx, desc, rw, "$external", &plainSaslClient{ +func (a *PlainAuthenticator) Auth(ctx context.Context, _ description.Server, conn driver.Connection) error { + return ConductSaslConversation(ctx, conn, "$external", &plainSaslClient{ username: a.Username, password: a.Password, }) diff --git a/x/mongo/driverlegacy/auth/plain_test.go b/x/mongo/driverlegacy/auth/plain_test.go index c64f42eef5..c5c5f69c57 100644 --- a/x/mongo/driverlegacy/auth/plain_test.go +++ b/x/mongo/driverlegacy/auth/plain_test.go @@ -13,11 +13,10 @@ import ( "encoding/base64" - "go.mongodb.org/mongo-driver/internal" - "go.mongodb.org/mongo-driver/x/bsonx" + "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/x/mongo/driver/drivertest" . "go.mongodb.org/mongo-driver/x/mongo/driverlegacy/auth" "go.mongodb.org/mongo-driver/x/network/description" - "go.mongodb.org/mongo-driver/x/network/wiremessage" ) func TestPlainAuthenticator_Fails(t *testing.T) { @@ -28,22 +27,27 @@ func TestPlainAuthenticator_Fails(t *testing.T) { Password: "pencil", } - resps := make(chan wiremessage.WireMessage, 1) - writeReplies(t, resps, bsonx.Doc{ - {"ok", bsonx.Int32(1)}, - {"conversationId", bsonx.Int32(1)}, - {"payload", bsonx.Binary(0x00, []byte{})}, - {"code", bsonx.Int32(143)}, - {"done", bsonx.Boolean(true)}, - }) + resps := make(chan []byte, 1) + writeReplies(t, resps, bsoncore.BuildDocumentFromElements(nil, + bsoncore.AppendInt32Element(nil, "ok", 1), + bsoncore.AppendInt32Element(nil, "conversationId", 1), + bsoncore.AppendBinaryElement(nil, "payload", 0x00, []byte{}), + bsoncore.AppendInt32Element(nil, "code", 143), + bsoncore.AppendBooleanElement(nil, "done", true), + )) - c := &internal.ChannelConn{Written: make(chan wiremessage.WireMessage, 1), ReadResp: resps} - - err := authenticator.Auth(context.Background(), description.Server{ + desc := description.Server{ WireVersion: &description.VersionRange{ Max: 6, }, - }, c) + } + c := &drivertest.ChannelConn{ + Written: make(chan []byte, 1), + ReadResp: resps, + Desc: desc, + } + + err := authenticator.Auth(context.Background(), desc, c) if err == nil { t.Fatalf("expected an error but got none") } @@ -62,26 +66,31 @@ func TestPlainAuthenticator_Extra_server_message(t *testing.T) { Password: "pencil", } - resps := make(chan wiremessage.WireMessage, 2) - writeReplies(t, resps, bsonx.Doc{ - {"ok", bsonx.Int32(1)}, - {"conversationId", bsonx.Int32(1)}, - {"payload", bsonx.Binary(0x00, []byte{})}, - {"done", bsonx.Boolean(false)}, - }, bsonx.Doc{ - {"ok", bsonx.Int32(1)}, - {"conversationId", bsonx.Int32(1)}, - {"payload", bsonx.Binary(0x00, []byte{})}, - {"done", bsonx.Boolean(true)}, - }) - - c := &internal.ChannelConn{Written: make(chan wiremessage.WireMessage, 1), ReadResp: resps} - - err := authenticator.Auth(context.Background(), description.Server{ + resps := make(chan []byte, 2) + writeReplies(t, resps, bsoncore.BuildDocumentFromElements(nil, + bsoncore.AppendInt32Element(nil, "ok", 1), + bsoncore.AppendInt32Element(nil, "conversationId", 1), + bsoncore.AppendBinaryElement(nil, "payload", 0x00, []byte{}), + bsoncore.AppendBooleanElement(nil, "done", false), + ), bsoncore.BuildDocumentFromElements(nil, + bsoncore.AppendInt32Element(nil, "ok", 1), + bsoncore.AppendInt32Element(nil, "conversationId", 1), + bsoncore.AppendBinaryElement(nil, "payload", 0x00, []byte{}), + bsoncore.AppendBooleanElement(nil, "done", true), + )) + + desc := description.Server{ WireVersion: &description.VersionRange{ Max: 6, }, - }, c) + } + c := &drivertest.ChannelConn{ + Written: make(chan []byte, 1), + ReadResp: resps, + Desc: desc, + } + + err := authenticator.Auth(context.Background(), desc, c) if err == nil { t.Fatalf("expected an error but got none") } @@ -100,21 +109,26 @@ func TestPlainAuthenticator_Succeeds(t *testing.T) { Password: "pencil", } - resps := make(chan wiremessage.WireMessage, 1) - writeReplies(t, resps, bsonx.Doc{ - {"ok", bsonx.Int32(1)}, - {"conversationId", bsonx.Int32(1)}, - {"payload", bsonx.Binary(0x00, []byte{})}, - {"done", bsonx.Boolean(true)}, - }) - - c := &internal.ChannelConn{Written: make(chan wiremessage.WireMessage, 1), ReadResp: resps} + resps := make(chan []byte, 1) + writeReplies(t, resps, bsoncore.BuildDocumentFromElements(nil, + bsoncore.AppendInt32Element(nil, "ok", 1), + bsoncore.AppendInt32Element(nil, "conversationId", 1), + bsoncore.AppendBinaryElement(nil, "payload", 0x00, []byte{}), + bsoncore.AppendBooleanElement(nil, "done", true), + )) - err := authenticator.Auth(context.Background(), description.Server{ + desc := description.Server{ WireVersion: &description.VersionRange{ Max: 6, }, - }, c) + } + c := &drivertest.ChannelConn{ + Written: make(chan []byte, 1), + ReadResp: resps, + Desc: desc, + } + + err := authenticator.Auth(context.Background(), desc, c) if err != nil { t.Fatalf("expected no error but got \"%s\"", err) } @@ -124,10 +138,10 @@ func TestPlainAuthenticator_Succeeds(t *testing.T) { } payload, _ := base64.StdEncoding.DecodeString("AHVzZXIAcGVuY2ls") - expectedCmd := bsonx.Doc{ - {"saslStart", bsonx.Int32(1)}, - {"mechanism", bsonx.String("PLAIN")}, - {"payload", bsonx.Binary(0x00, payload)}, - } + expectedCmd := bsoncore.BuildDocumentFromElements(nil, + bsoncore.AppendInt32Element(nil, "saslStart", 1), + bsoncore.AppendStringElement(nil, "mechanism", "PLAIN"), + bsoncore.AppendBinaryElement(nil, "payload", 0x00, payload), + ) compareResponses(t, <-c.Written, expectedCmd, "$external") } diff --git a/x/mongo/driverlegacy/auth/sasl.go b/x/mongo/driverlegacy/auth/sasl.go index 98ee4cfb5d..19e756f5e4 100644 --- a/x/mongo/driverlegacy/auth/sasl.go +++ b/x/mongo/driverlegacy/auth/sasl.go @@ -10,10 +10,8 @@ import ( "context" "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/x/bsonx" - "go.mongodb.org/mongo-driver/x/network/command" - "go.mongodb.org/mongo-driver/x/network/description" - "go.mongodb.org/mongo-driver/x/network/wiremessage" + "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/x/mongo/driver" ) // SaslClient is the client piece of a sasl conversation. @@ -30,7 +28,7 @@ type SaslClientCloser interface { } // ConductSaslConversation handles running a sasl conversation with MongoDB. -func ConductSaslConversation(ctx context.Context, desc description.Server, rw wiremessage.ReadWriter, db string, client SaslClient) error { +func ConductSaslConversation(ctx context.Context, conn driver.Connection, db string, client SaslClient) error { if db == "" { db = defaultAuthDB @@ -45,14 +43,12 @@ func ConductSaslConversation(ctx context.Context, desc description.Server, rw wi return newError(err, mech) } - saslStartCmd := command.Read{ - DB: db, - Command: bsonx.Doc{ - {"saslStart", bsonx.Int32(1)}, - {"mechanism", bsonx.String(mech)}, - {"payload", bsonx.Binary(0x00, payload)}, - }, - } + doc := bsoncore.BuildDocumentFromElements(nil, + bsoncore.AppendInt32Element(nil, "saslStart", 1), + bsoncore.AppendStringElement(nil, "mechanism", mech), + bsoncore.AppendBinaryElement(nil, "payload", 0x00, payload), + ) + saslStartCmd := driver.Command(doc).Database(db).Deployment(driver.SingleConnectionDeployment{conn}) type saslResponse struct { ConversationID int `bson:"conversationId"` @@ -63,11 +59,11 @@ func ConductSaslConversation(ctx context.Context, desc description.Server, rw wi var saslResp saslResponse - ssdesc := description.SelectedServer{Server: desc} - rdr, err := saslStartCmd.RoundTrip(ctx, ssdesc, rw) + err = saslStartCmd.Execute(ctx) if err != nil { return newError(err, mech) } + rdr := saslStartCmd.Result() err = bson.Unmarshal(rdr, &saslResp) if err != nil { @@ -94,19 +90,18 @@ func ConductSaslConversation(ctx context.Context, desc description.Server, rw wi return nil } - saslContinueCmd := command.Read{ - DB: db, - Command: bsonx.Doc{ - {"saslContinue", bsonx.Int32(1)}, - {"conversationId", bsonx.Int32(int32(cid))}, - {"payload", bsonx.Binary(0x00, payload)}, - }, - } + doc := bsoncore.BuildDocumentFromElements(nil, + bsoncore.AppendInt32Element(nil, "saslContinue", 1), + bsoncore.AppendInt32Element(nil, "conversationId", int32(cid)), + bsoncore.AppendBinaryElement(nil, "payload", 0x00, payload), + ) + saslContinueCmd := driver.Command(doc).Database(db).Deployment(driver.SingleConnectionDeployment{conn}) - rdr, err = saslContinueCmd.RoundTrip(ctx, ssdesc, rw) + err = saslContinueCmd.Execute(ctx) if err != nil { return newError(err, mech) } + rdr = saslContinueCmd.Result() err = bson.Unmarshal(rdr, &saslResp) if err != nil { diff --git a/x/mongo/driverlegacy/auth/scram.go b/x/mongo/driverlegacy/auth/scram.go index fa6d51e7ed..c6a31818a0 100644 --- a/x/mongo/driverlegacy/auth/scram.go +++ b/x/mongo/driverlegacy/auth/scram.go @@ -18,8 +18,8 @@ import ( "github.com/xdg/scram" "github.com/xdg/stringprep" + "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/network/description" - "go.mongodb.org/mongo-driver/x/network/wiremessage" ) // SCRAMSHA1 holds the mechanism name "SCRAM-SHA-1" @@ -67,9 +67,9 @@ type ScramAuthenticator struct { } // Auth authenticates the connection. -func (a *ScramAuthenticator) Auth(ctx context.Context, desc description.Server, rw wiremessage.ReadWriter) error { +func (a *ScramAuthenticator) Auth(ctx context.Context, _ description.Server, conn driver.Connection) error { adapter := &scramSaslAdapter{conversation: a.client.NewConversation(), mechanism: a.mechanism} - err := ConductSaslConversation(ctx, desc, rw, a.source, adapter) + err := ConductSaslConversation(ctx, conn, a.source, adapter) if err != nil { return newAuthError("sasl conversation error", err) } diff --git a/x/mongo/driverlegacy/auth/x509.go b/x/mongo/driverlegacy/auth/x509.go index 3a681428ad..a46aea7e29 100644 --- a/x/mongo/driverlegacy/auth/x509.go +++ b/x/mongo/driverlegacy/auth/x509.go @@ -9,10 +9,9 @@ package auth import ( "context" - "go.mongodb.org/mongo-driver/x/bsonx" - "go.mongodb.org/mongo-driver/x/network/command" + "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/network/description" - "go.mongodb.org/mongo-driver/x/network/wiremessage" ) // MongoDBX509 is the mechanism name for MongoDBX509. @@ -28,19 +27,19 @@ type MongoDBX509Authenticator struct { } // Auth implements the Authenticator interface. -func (a *MongoDBX509Authenticator) Auth(ctx context.Context, desc description.Server, rw wiremessage.ReadWriter) error { - authRequestDoc := bsonx.Doc{ - {"authenticate", bsonx.Int32(1)}, - {"mechanism", bsonx.String(MongoDBX509)}, - } +func (a *MongoDBX509Authenticator) Auth(ctx context.Context, desc description.Server, conn driver.Connection) error { + requestDoc := bsoncore.AppendInt32Element(nil, "authenticate", 1) + requestDoc = bsoncore.AppendStringElement(requestDoc, "mechanism", MongoDBX509) - if desc.WireVersion.Max < 5 { - authRequestDoc = append(authRequestDoc, bsonx.Elem{"user", bsonx.String(a.User)}) + if desc.WireVersion == nil || desc.WireVersion.Max < 5 { + requestDoc = bsoncore.AppendStringElement(requestDoc, "user", a.User) } - authCmd := command.Read{DB: "$external", Command: authRequestDoc} - ssdesc := description.SelectedServer{Server: desc} - _, err := authCmd.RoundTrip(ctx, ssdesc, rw) + authCmd := driver. + Command(bsoncore.BuildDocument(nil, requestDoc)). + Database("$external"). + Deployment(driver.SingleConnectionDeployment{conn}) + err := authCmd.Execute(ctx) if err != nil { return newAuthError("round trip error", err) } diff --git a/x/mongo/driverlegacy/batch_cursor.go b/x/mongo/driverlegacy/batch_cursor.go index 53a8aa60f1..48fe2ab9d9 100644 --- a/x/mongo/driverlegacy/batch_cursor.go +++ b/x/mongo/driverlegacy/batch_cursor.go @@ -195,7 +195,7 @@ func (bc *BatchCursor) Close(ctx context.Context) error { } defer bc.closeImplicitSession() - conn, err := bc.server.Connection(ctx) + conn, err := bc.server.ConnectionLegacy(ctx) if err != nil { return err } @@ -234,7 +234,7 @@ func (bc *BatchCursor) getMore(ctx context.Context) { return } - conn, err := bc.server.Connection(ctx) + conn, err := bc.server.ConnectionLegacy(ctx) if err != nil { bc.err = err return @@ -299,7 +299,7 @@ func (bc *BatchCursor) legacy() bool { } func (bc *BatchCursor) legacyKillCursor(ctx context.Context) error { - conn, err := bc.server.Connection(ctx) + conn, err := bc.server.ConnectionLegacy(ctx) if err != nil { return err } @@ -333,7 +333,7 @@ func (bc *BatchCursor) legacyGetMore(ctx context.Context) { return } - conn, err := bc.server.Connection(ctx) + conn, err := bc.server.ConnectionLegacy(ctx) if err != nil { bc.err = err return diff --git a/x/mongo/driverlegacy/commit_transaction.go b/x/mongo/driverlegacy/commit_transaction.go index fb99cffa59..0919dcda94 100644 --- a/x/mongo/driverlegacy/commit_transaction.go +++ b/x/mongo/driverlegacy/commit_transaction.go @@ -78,7 +78,7 @@ func commitTransaction( return result.TransactionResult{}, oldErr } - conn, err := ss.Connection(ctx) + conn, err := ss.ConnectionLegacy(ctx) if err != nil { if oldErr != nil { return result.TransactionResult{}, oldErr diff --git a/x/mongo/driverlegacy/count.go b/x/mongo/driverlegacy/count.go index 4b5191e06e..2966fdbb2b 100644 --- a/x/mongo/driverlegacy/count.go +++ b/x/mongo/driverlegacy/count.go @@ -40,7 +40,7 @@ func Count( } desc := ss.Description() - conn, err := ss.Connection(ctx) + conn, err := ss.ConnectionLegacy(ctx) if err != nil { return 0, err } diff --git a/x/mongo/driverlegacy/count_documents.go b/x/mongo/driverlegacy/count_documents.go index ed94096639..4976323f90 100644 --- a/x/mongo/driverlegacy/count_documents.go +++ b/x/mongo/driverlegacy/count_documents.go @@ -42,7 +42,7 @@ func CountDocuments( } desc := ss.Description() - conn, err := ss.Connection(ctx) + conn, err := ss.ConnectionLegacy(ctx) if err != nil { return 0, err } diff --git a/x/mongo/driverlegacy/create_indexes.go b/x/mongo/driverlegacy/create_indexes.go index ef4fec1fa6..680732eb4c 100644 --- a/x/mongo/driverlegacy/create_indexes.go +++ b/x/mongo/driverlegacy/create_indexes.go @@ -43,7 +43,7 @@ func CreateIndexes( return result.CreateIndexes{}, ErrCollation } - conn, err := ss.Connection(ctx) + conn, err := ss.ConnectionLegacy(ctx) if err != nil { return result.CreateIndexes{}, err } diff --git a/x/mongo/driverlegacy/delete.go b/x/mongo/driverlegacy/delete.go index d4ba908151..d42529c64e 100644 --- a/x/mongo/driverlegacy/delete.go +++ b/x/mongo/driverlegacy/delete.go @@ -98,7 +98,7 @@ func delete( ) (result.Delete, error) { desc := ss.Description() - conn, err := ss.Connection(ctx) + conn, err := ss.ConnectionLegacy(ctx) if err != nil { if oldErr != nil { return result.Delete{}, oldErr diff --git a/x/mongo/driverlegacy/delete_indexes.go b/x/mongo/driverlegacy/delete_indexes.go index 915e9a98ae..47e996282f 100644 --- a/x/mongo/driverlegacy/delete_indexes.go +++ b/x/mongo/driverlegacy/delete_indexes.go @@ -38,7 +38,7 @@ func DropIndexes( return nil, err } - conn, err := ss.Connection(ctx) + conn, err := ss.ConnectionLegacy(ctx) if err != nil { return nil, err } diff --git a/x/mongo/driverlegacy/distinct.go b/x/mongo/driverlegacy/distinct.go index c233723aeb..94bb887902 100644 --- a/x/mongo/driverlegacy/distinct.go +++ b/x/mongo/driverlegacy/distinct.go @@ -41,7 +41,7 @@ func Distinct( } desc := ss.Description() - conn, err := ss.Connection(ctx) + conn, err := ss.ConnectionLegacy(ctx) if err != nil { return result.Distinct{}, err } diff --git a/x/mongo/driverlegacy/drop_collection.go b/x/mongo/driverlegacy/drop_collection.go index d13e516791..7046f96b49 100644 --- a/x/mongo/driverlegacy/drop_collection.go +++ b/x/mongo/driverlegacy/drop_collection.go @@ -33,7 +33,7 @@ func DropCollection( return nil, err } - conn, err := ss.Connection(ctx) + conn, err := ss.ConnectionLegacy(ctx) if err != nil { return nil, err } diff --git a/x/mongo/driverlegacy/drop_database.go b/x/mongo/driverlegacy/drop_database.go index 408bbd040e..89b7710206 100644 --- a/x/mongo/driverlegacy/drop_database.go +++ b/x/mongo/driverlegacy/drop_database.go @@ -33,7 +33,7 @@ func DropDatabase( return nil, err } - conn, err := ss.Connection(ctx) + conn, err := ss.ConnectionLegacy(ctx) if err != nil { return nil, err } diff --git a/x/mongo/driverlegacy/end_sessions.go b/x/mongo/driverlegacy/end_sessions.go index b9a99aa30e..b8baee0f60 100644 --- a/x/mongo/driverlegacy/end_sessions.go +++ b/x/mongo/driverlegacy/end_sessions.go @@ -30,7 +30,7 @@ func EndSessions( } desc := ss.Description() - conn, err := ss.Connection(ctx) + conn, err := ss.ConnectionLegacy(ctx) if err != nil { return nil, []error{err} } diff --git a/x/mongo/driverlegacy/find.go b/x/mongo/driverlegacy/find.go index 0fdcd4b6b9..971775f071 100644 --- a/x/mongo/driverlegacy/find.go +++ b/x/mongo/driverlegacy/find.go @@ -49,7 +49,7 @@ func Find( } desc := ss.Description() - conn, err := ss.Connection(ctx) + conn, err := ss.ConnectionLegacy(ctx) if err != nil { return nil, err } diff --git a/x/mongo/driverlegacy/find_one_and_delete.go b/x/mongo/driverlegacy/find_one_and_delete.go index 4311f7d8be..dd0aa3c4cf 100644 --- a/x/mongo/driverlegacy/find_one_and_delete.go +++ b/x/mongo/driverlegacy/find_one_and_delete.go @@ -121,7 +121,7 @@ func findOneAndDelete( oldErr error, ) (result.FindAndModify, error) { desc := ss.Description() - conn, err := ss.Connection(ctx) + conn, err := ss.ConnectionLegacy(ctx) if err != nil { if oldErr != nil { return result.FindAndModify{}, oldErr diff --git a/x/mongo/driverlegacy/find_one_and_replace.go b/x/mongo/driverlegacy/find_one_and_replace.go index 2633bb0e83..4bc9c91335 100644 --- a/x/mongo/driverlegacy/find_one_and_replace.go +++ b/x/mongo/driverlegacy/find_one_and_replace.go @@ -130,7 +130,7 @@ func findOneAndReplace( oldErr error, ) (result.FindAndModify, error) { desc := ss.Description() - conn, err := ss.Connection(ctx) + conn, err := ss.ConnectionLegacy(ctx) if err != nil { if oldErr != nil { return result.FindAndModify{}, oldErr diff --git a/x/mongo/driverlegacy/find_one_and_update.go b/x/mongo/driverlegacy/find_one_and_update.go index 36d854317b..0143f6de2f 100644 --- a/x/mongo/driverlegacy/find_one_and_update.go +++ b/x/mongo/driverlegacy/find_one_and_update.go @@ -146,7 +146,7 @@ func findOneAndUpdate( oldErr error, ) (result.FindAndModify, error) { desc := ss.Description() - conn, err := ss.Connection(ctx) + conn, err := ss.ConnectionLegacy(ctx) if err != nil { if oldErr != nil { return result.FindAndModify{}, oldErr diff --git a/x/mongo/driverlegacy/insert.go b/x/mongo/driverlegacy/insert.go index 44d0aba487..d0df37a4e0 100644 --- a/x/mongo/driverlegacy/insert.go +++ b/x/mongo/driverlegacy/insert.go @@ -97,7 +97,7 @@ func insert( oldErr error, ) (result.Insert, error) { desc := ss.Description() - conn, err := ss.Connection(ctx) + conn, err := ss.ConnectionLegacy(ctx) if err != nil { if oldErr != nil { return result.Insert{}, oldErr diff --git a/x/mongo/driverlegacy/kill_cursors.go b/x/mongo/driverlegacy/kill_cursors.go index 34e92ed6d4..3eea0bd259 100644 --- a/x/mongo/driverlegacy/kill_cursors.go +++ b/x/mongo/driverlegacy/kill_cursors.go @@ -26,7 +26,7 @@ func KillCursors( cursorID int64, ) (result.KillCursors, error) { desc := server.SelectedDescription() - conn, err := server.Connection(ctx) + conn, err := server.ConnectionLegacy(ctx) if err != nil { return result.KillCursors{}, err } diff --git a/x/mongo/driverlegacy/list_collections.go b/x/mongo/driverlegacy/list_collections.go index 3459e2609b..53e6df1218 100644 --- a/x/mongo/driverlegacy/list_collections.go +++ b/x/mongo/driverlegacy/list_collections.go @@ -42,7 +42,7 @@ func ListCollections( return nil, err } - conn, err := ss.Connection(ctx) + conn, err := ss.ConnectionLegacy(ctx) if err != nil { return nil, err } diff --git a/x/mongo/driverlegacy/list_databases.go b/x/mongo/driverlegacy/list_databases.go index dca709d2b5..766b45689f 100644 --- a/x/mongo/driverlegacy/list_databases.go +++ b/x/mongo/driverlegacy/list_databases.go @@ -36,7 +36,7 @@ func ListDatabases( return result.ListDatabases{}, err } - conn, err := ss.Connection(ctx) + conn, err := ss.ConnectionLegacy(ctx) if err != nil { return result.ListDatabases{}, err } diff --git a/x/mongo/driverlegacy/list_indexes.go b/x/mongo/driverlegacy/list_indexes.go index 7fc17b7442..fce2bb103e 100644 --- a/x/mongo/driverlegacy/list_indexes.go +++ b/x/mongo/driverlegacy/list_indexes.go @@ -39,7 +39,7 @@ func ListIndexes( return nil, err } - conn, err := ss.Connection(ctx) + conn, err := ss.ConnectionLegacy(ctx) if err != nil { return nil, err } diff --git a/x/mongo/driverlegacy/read.go b/x/mongo/driverlegacy/read.go index 3d6870cc18..858988535d 100644 --- a/x/mongo/driverlegacy/read.go +++ b/x/mongo/driverlegacy/read.go @@ -37,7 +37,7 @@ func Read( return nil, err } - conn, err := ss.Connection(ctx) + conn, err := ss.ConnectionLegacy(ctx) if err != nil { return nil, err } diff --git a/x/mongo/driverlegacy/read_cursor.go b/x/mongo/driverlegacy/read_cursor.go index 0babf979d9..7404c564ec 100644 --- a/x/mongo/driverlegacy/read_cursor.go +++ b/x/mongo/driverlegacy/read_cursor.go @@ -39,7 +39,7 @@ func ReadCursor( } desc := ss.Description() - conn, err := ss.Connection(ctx) + conn, err := ss.ConnectionLegacy(ctx) if err != nil { return nil, err } diff --git a/x/mongo/driverlegacy/topology/DESIGN.md b/x/mongo/driverlegacy/topology/DESIGN.md index 013c845d3a..148ed22e11 100644 --- a/x/mongo/driverlegacy/topology/DESIGN.md +++ b/x/mongo/driverlegacy/topology/DESIGN.md @@ -1,6 +1,9 @@ # Topology Package Design This document outlines the design for this package. +## Server +The `Server` type handles heartbeating a MongoDB server and holds a pool of connections. + ## Connection Connections are handled by two main types and an auxiliary type. The two main types are `connection` and `Connection`. The first holds most of the logic required to actually read and write wire diff --git a/x/mongo/driverlegacy/topology/connection.go b/x/mongo/driverlegacy/topology/connection.go index 405f11315c..1cb686c74a 100644 --- a/x/mongo/driverlegacy/topology/connection.go +++ b/x/mongo/driverlegacy/topology/connection.go @@ -58,14 +58,14 @@ func newConnection(ctx context.Context, addr address.Address, opts ...Connection nc, err := cfg.dialer.DialContext(ctx, addr.Network(), addr.String()) if err != nil { - return nil, err + return nil, ConnectionError{Wrapped: err, init: true} } if cfg.tlsConfig != nil { tlsConfig := cfg.tlsConfig.Clone() nc, err = configureTLS(ctx, nc, addr, tlsConfig) if err != nil { - return nil, err + return nil, ConnectionError{Wrapped: err, init: true} } } @@ -92,7 +92,11 @@ func newConnection(ctx context.Context, addr address.Address, opts ...Connection if cfg.handshaker != nil { c.desc, err = cfg.handshaker.Handshake(ctx, c.addr, initConnection{c}) if err != nil { - return nil, err + c.nc.Close() + return nil, ConnectionError{Wrapped: err, init: true} + } + if cfg.descCallback != nil { + cfg.descCallback(c.desc) } } return c, nil @@ -124,9 +128,7 @@ func (c *connection) writeWireMessage(ctx context.Context, wm []byte) error { _, err = c.nc.Write(wm) if err != nil { - // TODO(GODRIVER-929): Close connection through the pool. - _ = c.nc.Close() - c.nc = nil + c.close() return ConnectionError{ConnectionID: c.id, Wrapped: err, message: "unable to write wire message to network"} } @@ -143,9 +145,7 @@ func (c *connection) readWireMessage(ctx context.Context, dst []byte) ([]byte, e select { case <-ctx.Done(): // We close the connection because we don't know if there is an unread message on the wire. - // TODO(GODRIVER-929): Close connection through the pool. - _ = c.nc.Close() - c.nc = nil + c.close() return nil, ConnectionError{ConnectionID: c.id, Wrapped: ctx.Err(), message: "failed to read"} default: } @@ -173,9 +173,7 @@ func (c *connection) readWireMessage(ctx context.Context, dst []byte) ([]byte, e _, err := io.ReadFull(c.nc, sizeBuf[:]) if err != nil { // We close the connection because we don't know if there are other bytes left to read. - // TODO(GODRIVER-929): Close connection through the pool. - _ = c.nc.Close() - c.nc = nil + c.close() return nil, ConnectionError{ConnectionID: c.id, Wrapped: err, message: "unable to decode message length"} } @@ -194,9 +192,7 @@ func (c *connection) readWireMessage(ctx context.Context, dst []byte) ([]byte, e _, err = io.ReadFull(c.nc, dst[4:]) if err != nil { // We close the connection because we don't know if there are other bytes left to read. - // TODO(GODRIVER-929): Close connection through the pool. - _ = c.nc.Close() - c.nc = nil + c.close() return nil, ConnectionError{ConnectionID: c.id, Wrapped: err, message: "unable to read full message"} } @@ -204,6 +200,18 @@ func (c *connection) readWireMessage(ctx context.Context, dst []byte) ([]byte, e return dst, nil } +func (c *connection) close() error { + if c.nc == nil { + return nil + } + if c.pool == nil { + err := c.nc.Close() + c.nc = nil + return err + } + return c.pool.close(c) +} + func (c *connection) expired() bool { now := time.Now() if !c.idleDeadline.IsZero() && now.After(c.idleDeadline) { @@ -231,7 +239,7 @@ type initConnection struct{ *connection } var _ driver.Connection = initConnection{} func (c initConnection) Description() description.Server { return description.Server{} } -func (c initConnection) Close() error { return c.nc.Close() } +func (c initConnection) Close() error { return nil } func (c initConnection) ID() string { return c.id } func (c initConnection) Address() address.Address { return c.addr } func (c initConnection) WriteWireMessage(ctx context.Context, wm []byte) error { @@ -291,7 +299,9 @@ func (c *Connection) Close() error { if c.connection == nil { return nil } - // TODO(GODRIVER-932): Release an entry in the semaphore. + if c.s != nil { + c.s.sem.Release(1) + } err := c.pool.put(c.connection) if err != nil { return err @@ -358,7 +368,7 @@ func (sc *sconn) processErr(err error) { // updates description to unknown sc.s.updateDescription(desc, false) sc.s.RequestImmediateCheck() - _ = sc.s.pool.Drain() + sc.s.pool.drain() return } diff --git a/x/mongo/driverlegacy/topology/connection_legacy.go b/x/mongo/driverlegacy/topology/connection_legacy.go index e7d6de8676..1e957f137b 100644 --- a/x/mongo/driverlegacy/topology/connection_legacy.go +++ b/x/mongo/driverlegacy/topology/connection_legacy.go @@ -10,6 +10,7 @@ import ( "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/x/bsonx" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/x/network/command" "go.mongodb.org/mongo-driver/x/network/compressor" "go.mongodb.org/mongo-driver/x/network/wiremessage" ) @@ -30,10 +31,12 @@ type connectionLegacy struct { // server can compress response with any compressor supported by driver compressorMap map[wiremessage.CompressorID]compressor.Compressor + s *Server + sync.RWMutex } -func newConnectionLegacy(c *connection, opts ...ConnectionOption) (*connectionLegacy, error) { +func newConnectionLegacy(c *connection, s *Server, opts ...ConnectionOption) (*connectionLegacy, error) { cfg, err := newConnectionConfig(opts...) if err != nil { return nil, err @@ -65,6 +68,8 @@ func newConnectionLegacy(c *connection, opts ...ConnectionOption) (*connectionLe uncompressBuf: make([]byte, 256), writeBuf: make([]byte, 0, 256), wireMessageBuf: make([]byte, 256), + + s: s, } d := c.desc @@ -117,6 +122,9 @@ func (c *connectionLegacy) WriteWireMessage(ctx context.Context, wm wiremessage. } err = c.writeWireMessage(ctx, c.writeBuf) + if c.s != nil { + c.s.ProcessError(err) + } if err != nil { // The error we got back was probably a ConnectionError already, so we don't really need to // wrap it here. @@ -136,6 +144,9 @@ func (c *connectionLegacy) ReadWireMessage(ctx context.Context) (wiremessage.Wir var err error c.readBuf, err = c.readWireMessage(ctx, c.readBuf) + if c.s != nil { + c.s.ProcessError(err) + } if err != nil { // The error we got back was probably a ConnectionError already, so we don't really need to // wrap it here. @@ -211,6 +222,10 @@ func (c *connectionLegacy) ReadWireMessage(ctx context.Context) (wiremessage.Wir } } + if c.s != nil { + c.s.ProcessError(command.DecodeError(wm)) + } + // TODO: do we care if monitoring fails? return wm, c.commandFinishedEvent(ctx, wm) } @@ -221,7 +236,9 @@ func (c *connectionLegacy) Close() error { if c.connection == nil { return nil } - // TODO(GODRIVER-932): Release an entry in the semaphore. + if c.s != nil { + c.s.sem.Release(1) + } err := c.pool.put(c.connection) if err != nil { return err diff --git a/x/mongo/driverlegacy/topology/connection_options.go b/x/mongo/driverlegacy/topology/connection_options.go index d3b5749b82..5b9ee6778a 100644 --- a/x/mongo/driverlegacy/topology/connection_options.go +++ b/x/mongo/driverlegacy/topology/connection_options.go @@ -8,7 +8,6 @@ import ( "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/x/mongo/driver" - "go.mongodb.org/mongo-driver/x/network/address" "go.mongodb.org/mongo-driver/x/network/description" ) @@ -34,18 +33,11 @@ var DefaultDialer Dialer = &net.Dialer{} // Handshaker is the interface implemented by types that can perform a MongoDB // handshake over a provided driver.Connection. This is used during connection // initialization. Implementations must be goroutine safe. -type Handshaker interface { - Handshake(context.Context, address.Address, driver.Connection) (description.Server, error) -} +type Handshaker = driver.Handshaker // HandshakerFunc is an adapter to allow the use of ordinary functions as // connection handshakers. -type HandshakerFunc func(context.Context, address.Address, driver.Connection) (description.Server, error) - -// Handshake implements the Handshaker interface. -func (hf HandshakerFunc) Handshake(ctx context.Context, addr address.Address, conn driver.Connection) (description.Server, error) { - return hf(ctx, addr, conn) -} +type HandshakerFunc = driver.HandshakerFunc type connectionConfig struct { appName string @@ -60,6 +52,7 @@ type connectionConfig struct { tlsConfig *tls.Config compressors []string zlibLevel *int + descCallback func(description.Server) } func newConnectionConfig(opts ...ConnectionOption) (*connectionConfig, error) { @@ -84,6 +77,13 @@ func newConnectionConfig(opts ...ConnectionOption) (*connectionConfig, error) { return cfg, nil } +func withServerDescriptionCallback(callback func(description.Server), opts ...ConnectionOption) []ConnectionOption { + return append(opts, ConnectionOption(func(c *connectionConfig) error { + c.descCallback = callback + return nil + })) +} + // ConnectionOption is used to configure a connection. type ConnectionOption func(*connectionConfig) error diff --git a/x/mongo/driverlegacy/topology/connection_test.go b/x/mongo/driverlegacy/topology/connection_test.go index f1a254f6ac..8a5d410313 100644 --- a/x/mongo/driverlegacy/topology/connection_test.go +++ b/x/mongo/driverlegacy/topology/connection_test.go @@ -64,7 +64,7 @@ func (c connect) ID() string { // Test case for sconn processErr func TestConnectionProcessErrSpec(t *testing.T) { ctx := context.Background() - s, err := NewServer(address.Address("localhost"), nil) + s, err := NewServer(address.Address("localhost")) require.NoError(t, err) desc := s.Description() @@ -94,25 +94,27 @@ func TestConnection(t *testing.T) { } }) t.Run("dialer error", func(t *testing.T) { - want := errors.New("dialer error") + err := errors.New("dialer error") + var want error = ConnectionError{Wrapped: err} _, got := newConnection(context.Background(), address.Address(""), WithDialer(func(Dialer) Dialer { - return DialerFunc(func(context.Context, string, string) (net.Conn, error) { return nil, want }) + return DialerFunc(func(context.Context, string, string) (net.Conn, error) { return nil, err }) })) if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) { t.Errorf("errors do not match. got %v; want %v", got, want) } }) t.Run("handshaker error", func(t *testing.T) { - want := errors.New("handshaker error") + err := errors.New("handshaker error") + var want error = ConnectionError{Wrapped: err} _, got := newConnection(context.Background(), address.Address(""), WithHandshaker(func(Handshaker) Handshaker { return HandshakerFunc(func(context.Context, address.Address, driver.Connection) (description.Server, error) { - return description.Server{}, want + return description.Server{}, err }) }), WithDialer(func(Dialer) Dialer { return DialerFunc(func(context.Context, string, string) (net.Conn, error) { - return net.Conn(nil), nil + return &net.TCPConn{}, nil }) }), ) @@ -120,6 +122,28 @@ func TestConnection(t *testing.T) { t.Errorf("errors do not match. got %v; want %v", got, want) } }) + t.Run("calls description callback", func(t *testing.T) { + want := description.Server{Addr: address.Address("1.2.3.4:56789")} + var got description.Server + _, err := newConnection(context.Background(), address.Address(""), + withServerDescriptionCallback(func(desc description.Server) { got = desc }, + WithHandshaker(func(Handshaker) Handshaker { + return HandshakerFunc(func(context.Context, address.Address, driver.Connection) (description.Server, error) { + return want, nil + }) + }), + WithDialer(func(Dialer) Dialer { + return DialerFunc(func(context.Context, string, string) (net.Conn, error) { + return &net.TCPConn{}, nil + }) + }), + )..., + ) + noerr(t, err) + if !cmp.Equal(got, want) { + t.Errorf("Server descriptions do not match. got %v; want %v", got, want) + } + }) }) t.Run("writeWireMessage", func(t *testing.T) { t.Run("closed connection", func(t *testing.T) { diff --git a/x/mongo/driverlegacy/topology/errors.go b/x/mongo/driverlegacy/topology/errors.go index a6fbf12685..34a4a8c1cb 100644 --- a/x/mongo/driverlegacy/topology/errors.go +++ b/x/mongo/driverlegacy/topology/errors.go @@ -7,6 +7,9 @@ type ConnectionError struct { ConnectionID string Wrapped error + // init will be set to true if this error occured during connection initialization or + // during a connection handshake. + init bool message string } diff --git a/x/mongo/driverlegacy/topology/pool_test.go b/x/mongo/driverlegacy/topology/pool_test.go index c17422962b..1b293d4df1 100644 --- a/x/mongo/driverlegacy/topology/pool_test.go +++ b/x/mongo/driverlegacy/topology/pool_test.go @@ -275,8 +275,9 @@ func TestPool(t *testing.T) { close(cleanup) }) t.Run("return error when attempting to create new connection", func(t *testing.T) { - want := errors.New("create new connection error") - var dialer DialerFunc = func(context.Context, string, string) (net.Conn, error) { return nil, want } + wanterr := errors.New("create new connection error") + var want error = ConnectionError{Wrapped: wanterr, init: true} + var dialer DialerFunc = func(context.Context, string, string) (net.Conn, error) { return nil, wanterr } p := newPool(address.Address(""), 2, WithDialer(func(Dialer) Dialer { return dialer })) err := p.connect() noerr(t, err) diff --git a/x/mongo/driverlegacy/topology/server.go b/x/mongo/driverlegacy/topology/server.go index 71c8f77496..5dabf0f1e6 100644 --- a/x/mongo/driverlegacy/topology/server.go +++ b/x/mongo/driverlegacy/topology/server.go @@ -11,18 +11,20 @@ import ( "errors" "fmt" "math" + "net" "strings" "sync" "sync/atomic" "time" "go.mongodb.org/mongo-driver/event" - "go.mongodb.org/mongo-driver/x/mongo/driverlegacy/auth" + "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/network/address" "go.mongodb.org/mongo-driver/x/network/command" connectionlegacy "go.mongodb.org/mongo-driver/x/network/connection" "go.mongodb.org/mongo-driver/x/network/description" "go.mongodb.org/mongo-driver/x/network/result" + "golang.org/x/sync/semaphore" ) const minHeartbeatInterval = 500 * time.Millisecond @@ -80,37 +82,40 @@ func connectionStateString(state int32) string { // Server is a single server within a topology. type Server struct { - cfg *serverConfig - address address.Address - + cfg *serverConfig + address address.Address connectionstate int32 - done chan struct{} - checkNow chan struct{} - closewg sync.WaitGroup - pool connectionlegacy.Pool - desc atomic.Value // holds a description.Server + // connection related fields + pool *pool + sem *semaphore.Weighted + + // goroutine management fields + done chan struct{} + checkNow chan struct{} + closewg sync.WaitGroup - averageRTTSet bool - averageRTT time.Duration + // description related fields + desc atomic.Value // holds a description.Server + updateTopologyCallback atomic.Value + averageRTTSet bool + averageRTT time.Duration + // subscriber related fields subLock sync.Mutex subscribers map[uint64]chan description.Server currentSubscriberID uint64 - subscriptionsClosed bool - - updateTopologyCallback atomic.Value } // ConnectServer creates a new Server and then initializes it using the // Connect method. -func ConnectServer(ctx context.Context, addr address.Address, topo func(description.Server), opts ...ServerOption) (*Server, error) { - srvr, err := NewServer(addr, topo, opts...) +func ConnectServer(addr address.Address, updateCallback func(description.Server), opts ...ServerOption) (*Server, error) { + srvr, err := NewServer(addr, opts...) if err != nil { return nil, err } - err = srvr.Connect(ctx) + err = srvr.Connect(updateCallback) if err != nil { return nil, err } @@ -119,49 +124,47 @@ func ConnectServer(ctx context.Context, addr address.Address, topo func(descript // NewServer creates a new server. The mongodb server at the address will be monitored // on an internal monitoring goroutine. -func NewServer(addr address.Address, topo func(description.Server), opts ...ServerOption) (*Server, error) { +func NewServer(addr address.Address, opts ...ServerOption) (*Server, error) { cfg, err := newServerConfig(opts...) if err != nil { return nil, err } + var maxConns = uint64(cfg.maxConns) + if maxConns == 0 { + maxConns = math.MaxInt64 + } + s := &Server{ cfg: cfg, address: addr, + sem: semaphore.NewWeighted(int64(maxConns)), + done: make(chan struct{}), checkNow: make(chan struct{}, 1), subscribers: make(map[uint64]chan description.Server), } s.desc.Store(description.Server{Addr: addr}) - s.updateTopologyCallback.Store(topo) - var maxConns uint64 - if cfg.maxConns == 0 { - maxConns = math.MaxInt64 - } else { - maxConns = uint64(cfg.maxConns) - } - - s.pool, err = connectionlegacy.NewPool(addr, uint64(cfg.maxIdleConns), maxConns, cfg.connectionOpts...) - if err != nil { - return nil, err - } + callback := func(desc description.Server) { s.updateDescription(desc, false) } + s.pool = newPool(addr, uint64(cfg.maxIdleConns), withServerDescriptionCallback(callback, cfg.connectionOpts...)...) return s, nil } // Connect initializes the Server by starting background monitoring goroutines. // This method must be called before a Server can be used. -func (s *Server) Connect(ctx context.Context) error { +func (s *Server) Connect(updateCallback func(description.Server)) error { if !atomic.CompareAndSwapInt32(&s.connectionstate, disconnected, connected) { return ErrServerConnected } s.desc.Store(description.Server{Addr: s.address}) + s.updateTopologyCallback.Store(updateCallback) go s.update() s.closewg.Add(1) - return s.pool.Connect(ctx) + return s.pool.connect() } // Disconnect closes sockets to the server referenced by this Server. @@ -183,7 +186,7 @@ func (s *Server) Disconnect(ctx context.Context) error { // For every call to Connect there must be at least 1 goroutine that is // waiting on the done channel. s.done <- struct{}{} - err := s.pool.Disconnect(ctx) + err := s.pool.disconnect(ctx) if err != nil { return err } @@ -195,33 +198,68 @@ func (s *Server) Disconnect(ctx context.Context) error { } // Connection gets a connection to the server. -func (s *Server) Connection(ctx context.Context) (connectionlegacy.Connection, error) { +func (s *Server) Connection(ctx context.Context) (driver.Connection, error) { if atomic.LoadInt32(&s.connectionstate) != connected { return nil, ErrServerClosed } - conn, desc, err := s.pool.Get(ctx) + + err := s.sem.Acquire(ctx, 1) if err != nil { - if _, ok := err.(*auth.Error); ok { - // authentication error --> drain connection - _ = s.pool.Drain() + return nil, err + } + + conn, err := s.pool.get(ctx) + if err != nil { + s.sem.Release(1) + connerr, ok := err.(ConnectionError) + if !ok { + return nil, err } - if _, ok := err.(*connectionlegacy.NetworkError); ok { - // update description to unknown and clears the connection pool - if desc != nil { - desc.Kind = description.Unknown - desc.LastError = err - s.updateDescription(*desc, false) - } else { - _ = s.pool.Drain() - } + + // Since the only kind of ConnectionError we receive from pool.get will be an initialization + // error, we should set the description.Server appropriately. + desc := description.Server{ + Kind: description.Unknown, + LastError: connerr.Wrapped, } + s.updateDescription(desc, false) + return nil, err } - if desc != nil { - go s.updateDescription(*desc, false) + + return &Connection{connection: conn, s: s}, nil +} + +// ConnectionLegacy gets a connection to the server. +func (s *Server) ConnectionLegacy(ctx context.Context) (connectionlegacy.Connection, error) { + if atomic.LoadInt32(&s.connectionstate) != connected { + return nil, ErrServerClosed + } + + err := s.sem.Acquire(ctx, 1) + if err != nil { + return nil, err } - sc := &sconn{Connection: conn, s: s} - return sc, nil + + conn, err := s.pool.get(ctx) + if err != nil { + s.sem.Release(1) + connerr, ok := err.(ConnectionError) + if !ok { + return nil, err + } + + // Since the only kind of ConnectionError we receive from pool.get will be an initialization + // error, we should set the description.Server appropriately. + desc := description.Server{ + Kind: description.Unknown, + LastError: connerr.Wrapped, + } + s.updateDescription(desc, false) + + return nil, err + } + return newConnectionLegacy(conn, s, s.cfg.connectionOpts...) } // Description returns a description of the server as of the last heartbeat. @@ -277,6 +315,39 @@ func (s *Server) RequestImmediateCheck() { } } +// ProcessError handles SDAM error handling and implements driver.ErrorProcessor. +func (s *Server) ProcessError(err error) { + // Invalidate server description if not master or node recovering error occurs + if cerr, ok := err.(command.Error); ok && (isRecoveringError(cerr) || isNotMasterError(cerr)) { + desc := s.Description() + desc.Kind = description.Unknown + desc.LastError = err + // updates description to unknown + s.updateDescription(desc, false) + s.RequestImmediateCheck() + s.pool.drain() + return + } + + ne, ok := err.(connectionlegacy.Error) + if !ok { + return + } + + if netErr, ok := ne.Wrapped.(net.Error); ok && netErr.Timeout() { + return + } + if ne.Wrapped == context.Canceled || ne.Wrapped == context.DeadlineExceeded { + return + } + + desc := s.Description() + desc.Kind = description.Unknown + desc.LastError = err + // updates description to unknown + s.updateDescription(desc, false) +} + // ProcessWriteConcernError checks if a WriteConcernError is an isNotMaster or // isRecovering error, and if so updates the server accordingly. func (s *Server) ProcessWriteConcernError(err *result.WriteConcernError) { @@ -289,7 +360,6 @@ func (s *Server) ProcessWriteConcernError(err *result.WriteConcernError) { // updates description to unknown s.updateDescription(desc, false) s.RequestImmediateCheck() - _ = s.pool.Drain() } func wceIsNotMasterOrRecovering(wce *result.WriteConcernError) bool { @@ -326,7 +396,7 @@ func (s *Server) update() { } }() - var conn connectionlegacy.Connection + var conn *connection var desc description.Server desc, conn = s.heartbeat(nil) @@ -341,10 +411,10 @@ func (s *Server) update() { } s.subscriptionsClosed = true s.subLock.Unlock() - if conn == nil { + if conn == nil || conn.nc == nil { return } - conn.Close() + conn.nc.Close() } for { select { @@ -378,9 +448,9 @@ func (s *Server) updateDescription(desc description.Server, initial bool) { }() s.desc.Store(desc) - topo := s.updateTopologyCallback.Load().(func(description.Server)) - if topo != nil { - topo(desc) + callback, ok := s.updateTopologyCallback.Load().(func(description.Server)) + if ok && callback != nil { + callback(desc) } s.subLock.Lock() @@ -401,12 +471,12 @@ func (s *Server) updateDescription(desc description.Server, initial bool) { switch desc.Kind { case description.Unknown: - _ = s.pool.Drain() + s.pool.drain() } } // heartbeat sends a heartbeat to the server using the given connection. The connection can be nil. -func (s *Server) heartbeat(conn connectionlegacy.Connection) (description.Server, connectionlegacy.Connection) { +func (s *Server) heartbeat(conn *connection) (description.Server, *connection) { const maxRetry = 2 var saved error var desc description.Server @@ -415,37 +485,39 @@ func (s *Server) heartbeat(conn connectionlegacy.Connection) (description.Server ctx := context.Background() for i := 1; i <= maxRetry; i++ { - if conn != nil && conn.Expired() { - conn.Close() + if conn != nil && conn.expired() { + if conn.nc != nil { + conn.nc.Close() + } conn = nil } if conn == nil { - opts := []connectionlegacy.Option{ - connectionlegacy.WithConnectTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }), - connectionlegacy.WithReadTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }), - connectionlegacy.WithWriteTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }), + opts := []ConnectionOption{ + WithConnectTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }), + WithReadTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }), + WithWriteTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }), } opts = append(opts, s.cfg.connectionOpts...) // We override whatever handshaker is currently attached to the options with an empty // one because need to make sure we don't do auth. - opts = append(opts, connectionlegacy.WithHandshaker(func(h connectionlegacy.Handshaker) connectionlegacy.Handshaker { + opts = append(opts, WithHandshaker(func(h Handshaker) Handshaker { return nil })) // Override any command monitors specified in options with nil to avoid monitoring heartbeats. - opts = append(opts, connectionlegacy.WithMonitor(func(*event.CommandMonitor) *event.CommandMonitor { + opts = append(opts, WithMonitor(func(*event.CommandMonitor) *event.CommandMonitor { return nil })) - conn, _, err = connectionlegacy.New(ctx, s.address, opts...) + conn, err = newConnection(ctx, s.address, opts...) if err != nil { saved = err - if conn != nil { - conn.Close() + if conn != nil && conn.nc != nil { + conn.nc.Close() } conn = nil - if _, ok := err.(*connectionlegacy.NetworkError); ok { - _ = s.pool.Drain() + if _, ok := err.(ConnectionError); ok { + s.pool.drain() // If the server is not connected, give up and exit loop if s.Description().Kind == description.Unknown { break @@ -457,15 +529,17 @@ func (s *Server) heartbeat(conn connectionlegacy.Connection) (description.Server now := time.Now() - isMasterCmd := &command.IsMaster{Compressors: s.cfg.compressionOpts} - isMaster, err := isMasterCmd.RoundTrip(ctx, conn) + op := driver.IsMaster().AppName(s.cfg.appname).Compressors(s.cfg.compressionOpts).Connection(initConnection{conn}) + err = op.Execute(ctx) // we do a retry if the server is connected, if succeed return new server desc (see below) if err != nil { saved = err - conn.Close() + if conn.nc != nil { + conn.nc.Close() + } conn = nil - if _, ok := err.(connectionlegacy.NetworkError); ok { - _ = s.pool.Drain() + if _, ok := err.(ConnectionError); ok { + s.pool.drain() // If the server is not connected, give up and exit loop if s.Description().Kind == description.Unknown { break @@ -474,6 +548,8 @@ func (s *Server) heartbeat(conn connectionlegacy.Connection) (description.Server continue } + isMaster := op.Result() + clusterTime := isMaster.ClusterTime if s.cfg.clock != nil { s.cfg.clock.AdvanceClusterTime(clusterTime) @@ -514,7 +590,13 @@ func (s *Server) updateAverageRTT(delay time.Duration) time.Duration { // This is exposed here so we don't have to wrap the Connection type and sniff responses // for errors that would cause the pool to be drained, which can in turn centralize the // logic for handling errors in the Client type. -func (s *Server) Drain() error { return s.pool.Drain() } +// +// TODO(GODRIVER-617): I don't think we actually need this method. It's likely replaced by +// ProcessError. +func (s *Server) Drain() error { + s.pool.drain() + return nil +} // String implements the Stringer interface. func (s *Server) String() string { diff --git a/x/mongo/driverlegacy/topology/server_options.go b/x/mongo/driverlegacy/topology/server_options.go index 1c0bb2ed9d..53f237ccdf 100644 --- a/x/mongo/driverlegacy/topology/server_options.go +++ b/x/mongo/driverlegacy/topology/server_options.go @@ -12,7 +12,6 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/bsoncodec" "go.mongodb.org/mongo-driver/x/mongo/driverlegacy/session" - connectionlegacy "go.mongodb.org/mongo-driver/x/network/connection" ) var defaultRegistry = bson.NewRegistryBuilder().Build() @@ -20,7 +19,7 @@ var defaultRegistry = bson.NewRegistryBuilder().Build() type serverConfig struct { clock *session.ClusterClock compressionOpts []string - connectionOpts []connectionlegacy.Option + connectionOpts []ConnectionOption appname string heartbeatInterval time.Duration heartbeatTimeout time.Duration @@ -52,7 +51,7 @@ func newServerConfig(opts ...ServerOption) (*serverConfig, error) { type ServerOption func(*serverConfig) error // WithConnectionOptions configures the server's connections. -func WithConnectionOptions(fn func(...connectionlegacy.Option) []connectionlegacy.Option) ServerOption { +func WithConnectionOptions(fn func(...ConnectionOption) []ConnectionOption) ServerOption { return func(cfg *serverConfig) error { cfg.connectionOpts = fn(cfg.connectionOpts...) return nil diff --git a/x/mongo/driverlegacy/topology/server_test.go b/x/mongo/driverlegacy/topology/server_test.go index 6bb5e721d5..02bd248c66 100644 --- a/x/mongo/driverlegacy/topology/server_test.go +++ b/x/mongo/driverlegacy/topology/server_test.go @@ -8,10 +8,13 @@ package topology import ( "context" + "net" "sync/atomic" "testing" + "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/require" + "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driverlegacy/auth" "go.mongodb.org/mongo-driver/x/network/address" connectionlegacy "go.mongodb.org/mongo-driver/x/network/connection" @@ -72,9 +75,35 @@ func TestServer(t *testing.T) { {"network_error_desc", false, true, true}, } + authErr := ConnectionError{Wrapped: &auth.Error{}} + netErr := ConnectionError{Wrapped: &net.AddrError{}} for _, tt := range serverTestTable { t.Run(tt.name, func(t *testing.T) { - s, err := NewServer(address.Address("localhost"), nil) + s, err := NewServer( + address.Address("localhost"), + WithConnectionOptions(func(connOpts ...ConnectionOption) []ConnectionOption { + return append(connOpts, + WithHandshaker(func(Handshaker) Handshaker { + return HandshakerFunc(func(context.Context, address.Address, driver.Connection) (description.Server, error) { + var err error + if tt.connectionError { + err = authErr.Wrapped + } + return description.Server{}, err + }) + }), + WithDialer(func(Dialer) Dialer { + return DialerFunc(func(context.Context, string, string) (net.Conn, error) { + var err error + if tt.networkError { + err = netErr.Wrapped + } + return &net.TCPConn{}, err + }) + }), + ) + }), + ) require.NoError(t, err) var desc *description.Server @@ -83,35 +112,40 @@ func TestServer(t *testing.T) { desc = &descript require.Nil(t, desc.LastError) } - s.pool, err = NewTestPool(tt.connectionError, tt.networkError, desc) s.connectionstate = connected + s.pool.connected = connected - _, err = s.Connection(context.Background()) + _, err = s.ConnectionLegacy(context.Background()) - if tt.connectionError || tt.networkError { - require.Error(t, err) - } else { - require.NoError(t, err) + switch { + case tt.connectionError && !cmp.Equal(err, authErr, cmp.Comparer(compareErrors)): + t.Errorf("Expected connection error. got %v; want %v", err, authErr) + case tt.networkError && !cmp.Equal(err, netErr, cmp.Comparer(compareErrors)): + t.Errorf("Expected network error. got %v; want %v", err, netErr) + case !tt.connectionError && !tt.networkError && err != nil: + t.Errorf("Expected error to be nil. got %v; want %v", err, "") } if tt.hasDesc { - require.Equal(t, desc.Kind, (description.ServerKind)(description.Unknown)) - require.NotNil(t, desc.LastError) + require.Equal(t, s.Description().Kind, (description.ServerKind)(description.Unknown)) + require.NotNil(t, s.Description().LastError) + } + + if (tt.connectionError || tt.networkError) && s.pool.generation != 1 { + t.Errorf("Expected pool to be drained once on connection or network error. got %d; want %d", s.pool.generation, 1) } - drained := s.pool.(*testpool).drainCalled.Load().(bool) - require.Equal(t, drained, tt.connectionError || tt.networkError) }) } t.Run("WriteConcernError", func(t *testing.T) { - s, err := NewServer(address.Address("localhost"), nil) + s, err := NewServer(address.Address("localhost")) require.NoError(t, err) var desc *description.Server descript := s.Description() desc = &descript require.Nil(t, desc.LastError) - s.pool, err = NewTestPool(false, false, desc) s.connectionstate = connected + s.pool.connected = connected wce := result.WriteConcernError{10107, "not master", []byte{}} require.Equal(t, wceIsNotMasterOrRecovering(&wce), true) @@ -123,19 +157,20 @@ func TestServer(t *testing.T) { require.Equal(t, resultDesc.LastError, &wce) // pool should be drained - drained := s.pool.(*testpool).drainCalled.Load().(bool) - require.Equal(t, drained, true) + if s.pool.generation != 1 { + t.Errorf("Expected pool to be drained once from a write concern error. got %d; want %d", s.pool.generation, 1) + } }) t.Run("no WriteConcernError", func(t *testing.T) { - s, err := NewServer(address.Address("localhost"), nil) + s, err := NewServer(address.Address("localhost")) require.NoError(t, err) var desc *description.Server descript := s.Description() desc = &descript require.Nil(t, desc.LastError) - s.pool, err = NewTestPool(false, false, desc) s.connectionstate = connected + s.pool.connected = connected wce := result.WriteConcernError{} require.Equal(t, wceIsNotMasterOrRecovering(&wce), false) @@ -145,12 +180,13 @@ func TestServer(t *testing.T) { require.Nil(t, s.Description().LastError) // pool should not be drained - drained := s.pool.(*testpool).drainCalled.Load().(bool) - require.Equal(t, drained, false) + if s.pool.generation != 0 { + t.Errorf("Expected pool to not be drained. got %d; want %d", s.pool.generation, 0) + } }) t.Run("update topology", func(t *testing.T) { var updated bool - s, err := NewServer(address.Address("localhost"), func(description.Server) { updated = true }) + s, err := ConnectServer(address.Address("localhost"), func(description.Server) { updated = true }) require.NoError(t, err) s.updateDescription(description.Server{Addr: s.address}, false) require.True(t, updated) diff --git a/x/mongo/driverlegacy/topology/topology.go b/x/mongo/driverlegacy/topology/topology.go index 931f517575..d0a17eaac8 100644 --- a/x/mongo/driverlegacy/topology/topology.go +++ b/x/mongo/driverlegacy/topology/topology.go @@ -137,7 +137,7 @@ func (t *Topology) Connect(ctx context.Context) error { for _, a := range t.cfg.seedList { addr := address.Address(a).Canonicalize() t.fsm.Servers = append(t.fsm.Servers, description.Server{Addr: addr}) - err = t.addServer(ctx, addr) + err = t.addServer(addr) } t.serversLock.Unlock() @@ -445,7 +445,7 @@ func (t *Topology) processSRVResults(parsedHosts []string) bool { } for _, a := range diff.Added { addr := address.Address(a).Canonicalize() - _ = t.addServer(context.TODO(), addr) + _ = t.addServer(addr) t.fsm.addServer(addr) } //store new description @@ -502,7 +502,7 @@ func (t *Topology) apply(ctx context.Context, desc description.Server) { } for _, added := range diff.Added { - _ = t.addServer(ctx, added.Addr) + _ = t.addServer(added.Addr) } t.desc.Store(current) @@ -520,7 +520,7 @@ func (t *Topology) apply(ctx context.Context, desc description.Server) { } -func (t *Topology) addServer(ctx context.Context, addr address.Address) error { +func (t *Topology) addServer(addr address.Address) error { if _, ok := t.servers[addr]; ok { return nil } @@ -528,7 +528,7 @@ func (t *Topology) addServer(ctx context.Context, addr address.Address) error { topoFunc := func(desc description.Server) { t.apply(context.TODO(), desc) } - svr, err := ConnectServer(ctx, addr, topoFunc, t.cfg.serverOpts...) + svr, err := ConnectServer(addr, topoFunc, t.cfg.serverOpts...) if err != nil { return err } diff --git a/x/mongo/driverlegacy/topology/topology_options.go b/x/mongo/driverlegacy/topology/topology_options.go index ba58d9207b..9df10355e4 100644 --- a/x/mongo/driverlegacy/topology/topology_options.go +++ b/x/mongo/driverlegacy/topology/topology_options.go @@ -8,11 +8,12 @@ package topology import ( "bytes" + "crypto/tls" "strings" "time" + "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driverlegacy/auth" - "go.mongodb.org/mongo-driver/x/network/command" connectionlegacy "go.mongodb.org/mongo-driver/x/network/connection" "go.mongodb.org/mongo-driver/x/network/connstring" ) @@ -55,10 +56,10 @@ func WithConnString(fn func(connstring.ConnString) connstring.ConnString) Option c.serverSelectionTimeout = cs.ServerSelectionTimeout } - var connOpts []connectionlegacy.Option + var connOpts []ConnectionOption if cs.AppName != "" { - connOpts = append(connOpts, connectionlegacy.WithAppName(func(string) string { return cs.AppName })) + connOpts = append(connOpts, WithAppName(func(string) string { return cs.AppName })) } switch cs.Connect { @@ -70,14 +71,14 @@ func WithConnString(fn func(connstring.ConnString) connstring.ConnString) Option if cs.ConnectTimeout > 0 { c.serverOpts = append(c.serverOpts, WithHeartbeatTimeout(func(time.Duration) time.Duration { return cs.ConnectTimeout })) - connOpts = append(connOpts, connectionlegacy.WithConnectTimeout(func(time.Duration) time.Duration { return cs.ConnectTimeout })) + connOpts = append(connOpts, WithConnectTimeout(func(time.Duration) time.Duration { return cs.ConnectTimeout })) } if cs.SocketTimeoutSet { connOpts = append( connOpts, - connectionlegacy.WithReadTimeout(func(time.Duration) time.Duration { return cs.SocketTimeout }), - connectionlegacy.WithWriteTimeout(func(time.Duration) time.Duration { return cs.SocketTimeout }), + WithReadTimeout(func(time.Duration) time.Duration { return cs.SocketTimeout }), + WithWriteTimeout(func(time.Duration) time.Duration { return cs.SocketTimeout }), ) } @@ -86,7 +87,7 @@ func WithConnString(fn func(connstring.ConnString) connstring.ConnString) Option } if cs.MaxConnIdleTime > 0 { - connOpts = append(connOpts, connectionlegacy.WithIdleTimeout(func(time.Duration) time.Duration { return cs.MaxConnIdleTime })) + connOpts = append(connOpts, WithIdleTimeout(func(time.Duration) time.Duration { return cs.MaxConnIdleTime })) } if cs.MaxPoolSizeSet { @@ -137,7 +138,7 @@ func WithConnString(fn func(connstring.ConnString) connstring.ConnString) Option x509Username = b.String() } - connOpts = append(connOpts, connectionlegacy.WithTLSConfig(func(*connectionlegacy.TLSConfig) *connectionlegacy.TLSConfig { return tlsConfig })) + connOpts = append(connOpts, WithTLSConfig(func(*tls.Config) *tls.Config { return tlsConfig.Config })) } if cs.Username != "" || cs.AuthMechanism == auth.MongoDBX509 || cs.AuthMechanism == auth.GSSAPI { @@ -170,7 +171,7 @@ func WithConnString(fn func(connstring.ConnString) connstring.ConnString) Option return err } - connOpts = append(connOpts, connectionlegacy.WithHandshaker(func(h connectionlegacy.Handshaker) connectionlegacy.Handshaker { + connOpts = append(connOpts, WithHandshaker(func(h Handshaker) Handshaker { options := &auth.HandshakeOptions{ AppName: cs.AppName, Authenticator: authenticator, @@ -184,19 +185,19 @@ func WithConnString(fn func(connstring.ConnString) connstring.ConnString) Option })) } else { // We need to add a non-auth Handshaker to the connection options - connOpts = append(connOpts, connectionlegacy.WithHandshaker(func(h connectionlegacy.Handshaker) connectionlegacy.Handshaker { - return &command.Handshake{Client: command.ClientDoc(cs.AppName), Compressors: cs.Compressors} + connOpts = append(connOpts, WithHandshaker(func(h driver.Handshaker) driver.Handshaker { + return driver.IsMaster().AppName(cs.AppName).Compressors(cs.Compressors) })) } if len(cs.Compressors) > 0 { - connOpts = append(connOpts, connectionlegacy.WithCompressors(func(compressors []string) []string { + connOpts = append(connOpts, WithCompressors(func(compressors []string) []string { return append(compressors, cs.Compressors...) })) for _, comp := range cs.Compressors { if comp == "zlib" { - connOpts = append(connOpts, connectionlegacy.WithZlibLevel(func(level *int) *int { + connOpts = append(connOpts, WithZlibLevel(func(level *int) *int { return &cs.ZlibLevel })) } @@ -208,7 +209,7 @@ func WithConnString(fn func(connstring.ConnString) connstring.ConnString) Option } if len(connOpts) > 0 { - c.serverOpts = append(c.serverOpts, WithConnectionOptions(func(opts ...connectionlegacy.Option) []connectionlegacy.Option { + c.serverOpts = append(c.serverOpts, WithConnectionOptions(func(opts ...ConnectionOption) []ConnectionOption { return append(opts, connOpts...) })) } diff --git a/x/mongo/driverlegacy/topology/topology_test.go b/x/mongo/driverlegacy/topology/topology_test.go index ccb1059ebc..d30885c022 100644 --- a/x/mongo/driverlegacy/topology/topology_test.go +++ b/x/mongo/driverlegacy/topology/topology_test.go @@ -21,6 +21,7 @@ import ( const testTimeout = 2 * time.Second func noerr(t *testing.T, err error) { + t.Helper() if err != nil { t.Errorf("Unexpected error: %v", err) t.FailNow() @@ -225,7 +226,7 @@ func TestServerSelection(t *testing.T) { topo, err := New() noerr(t, err) atomic.StoreInt32(&topo.connectionstate, connected) - srvr, err := NewServer(address.Address("one"), func(desc description.Server) { topo.apply(context.Background(), desc) }) + srvr, err := ConnectServer(address.Address("one"), func(desc description.Server) { topo.apply(context.Background(), desc) }) noerr(t, err) topo.servers[address.Address("one")] = srvr desc := topo.desc.Load().(description.Topology) @@ -259,7 +260,7 @@ func TestServerSelection(t *testing.T) { // manually add the servers to the topology for _, srv := range desc.Servers { - s, err := NewServer(srv.Addr, func(desc description.Server) { topo.apply(context.Background(), desc) }) + s, err := ConnectServer(srv.Addr, func(desc description.Server) { topo.apply(context.Background(), desc) }) noerr(t, err) topo.servers[srv.Addr] = s } @@ -279,8 +280,6 @@ func TestServerSelection(t *testing.T) { // send a not master error to the server forcing an update serv, err := topo.FindServer(desc.Servers[0]) noerr(t, err) - err = serv.pool.Connect(context.Background()) - noerr(t, err) atomic.StoreInt32(&serv.connectionstate, connected) sc := &sconn{s: serv.Server} sc.processErr(command.Error{Message: "not master"}) diff --git a/x/mongo/driverlegacy/update.go b/x/mongo/driverlegacy/update.go index 34e0eecdf6..7bc2ed5332 100644 --- a/x/mongo/driverlegacy/update.go +++ b/x/mongo/driverlegacy/update.go @@ -125,7 +125,7 @@ func update( ) (result.Update, error) { desc := ss.Description() - conn, err := ss.Connection(ctx) + conn, err := ss.ConnectionLegacy(ctx) if err != nil { if oldErr != nil { return result.Update{}, oldErr diff --git a/x/mongo/driverlegacy/write.go b/x/mongo/driverlegacy/write.go index 05bae76876..3d69c540f7 100644 --- a/x/mongo/driverlegacy/write.go +++ b/x/mongo/driverlegacy/write.go @@ -38,7 +38,7 @@ func Write( } desc := ss.Description() - conn, err := ss.Connection(ctx) + conn, err := ss.ConnectionLegacy(ctx) if err != nil { return nil, err } diff --git a/x/network/examples/server_monitoring/main.go b/x/network/examples/server_monitoring/main.go index b243de064c..f13c813bf3 100644 --- a/x/network/examples/server_monitoring/main.go +++ b/x/network/examples/server_monitoring/main.go @@ -7,25 +7,22 @@ package main import ( - "context" "log" "time" "github.com/kr/pretty" "go.mongodb.org/mongo-driver/x/mongo/driverlegacy/topology" "go.mongodb.org/mongo-driver/x/network/address" - "go.mongodb.org/mongo-driver/x/network/connection" ) func main() { s, err := topology.ConnectServer( - context.Background(), address.Address("localhost:27017"), nil, topology.WithHeartbeatInterval(func(time.Duration) time.Duration { return 2 * time.Second }), topology.WithConnectionOptions( - func(opts ...connection.Option) []connection.Option { - return append(opts, connection.WithAppName(func(string) string { return "server monitoring test" })) + func(opts ...topology.ConnectionOption) []topology.ConnectionOption { + return append(opts, topology.WithAppName(func(string) string { return "server monitoring test" })) }, ), ) diff --git a/x/network/examples/workload/main.go b/x/network/examples/workload/main.go index c6a070027c..aafb9813e0 100644 --- a/x/network/examples/workload/main.go +++ b/x/network/examples/workload/main.go @@ -89,7 +89,7 @@ func prep(ctx context.Context, c *topology.Topology) error { return err } - conn, err := s.Connection(ctx) + conn, err := s.ConnectionLegacy(ctx) if err != nil { return err } diff --git a/x/network/integration/aggregate_test.go b/x/network/integration/aggregate_test.go index 40cdefdf7c..6d974eb547 100644 --- a/x/network/integration/aggregate_test.go +++ b/x/network/integration/aggregate_test.go @@ -91,7 +91,7 @@ func TestCommandAggregate(t *testing.T) { t.Run("AllowDiskUse", func(t *testing.T) { server, err := testutil.Topology(t).SelectServer(context.Background(), description.WriteSelector()) noerr(t, err) - conn, err := server.Connection(context.Background()) + conn, err := server.ConnectionLegacy(context.Background()) noerr(t, err) ds := []bsonx.Doc{ {{"_id", bsonx.Int32(1)}}, @@ -118,9 +118,9 @@ func TestCommandAggregate(t *testing.T) { t.Run("MaxTime", func(t *testing.T) { t.Skip("max time is flaky on the server") - server, err := topology.ConnectServer(context.Background(), address.Address(*host), nil) + server, err := topology.ConnectServer(address.Address(*host), nil) noerr(t, err) - conn, err := server.Connection(context.Background()) + conn, err := server.ConnectionLegacy(context.Background()) noerr(t, err) _, err = (&command.Write{ diff --git a/x/network/integration/command_test.go b/x/network/integration/command_test.go index 1e2ace7ac9..9c8f82f1fb 100644 --- a/x/network/integration/command_test.go +++ b/x/network/integration/command_test.go @@ -31,7 +31,7 @@ func TestCommand(t *testing.T) { } t.Parallel() - server, err := topology.ConnectServer(context.Background(), address.Address(*host), nil, serveropts(t)...) + server, err := topology.ConnectServer(address.Address(*host), nil, serveropts(t)...) noerr(t, err) ctx := context.Background() @@ -41,7 +41,7 @@ func TestCommand(t *testing.T) { DB: "admin", Command: bsonx.Doc{{"getnonce", bsonx.Int32(1)}}, } - rw, err := server.Connection(ctx) + rw, err := server.ConnectionLegacy(ctx) noerr(t, err) rdr, err := cmd.RoundTrip(ctx, server.SelectedDescription(), rw) @@ -67,7 +67,7 @@ func TestCommand(t *testing.T) { result = result[:0] cmd.Command = bsonx.Doc{{"ping", bsonx.Int32(1)}} - rw, err = server.Connection(ctx) + rw, err = server.ConnectionLegacy(ctx) noerr(t, err) rdr, err = cmd.RoundTrip(ctx, server.SelectedDescription(), rw) noerr(t, err) @@ -87,7 +87,7 @@ func TestWriteCommands(t *testing.T) { ctx := context.TODO() server, err := testutil.Topology(t).SelectServer(context.Background(), description.WriteSelector()) noerr(t, err) - conn, err := server.Connection(context.Background()) + conn, err := server.ConnectionLegacy(context.Background()) noerr(t, err) cmd := &command.Insert{ @@ -98,7 +98,7 @@ func TestWriteCommands(t *testing.T) { _, err = cmd.RoundTrip(ctx, server.SelectedDescription(), conn) noerr(t, err) - conn, err = server.Connection(context.Background()) + conn, err = server.ConnectionLegacy(context.Background()) noerr(t, err) res, err := cmd.RoundTrip(ctx, server.SelectedDescription(), conn) noerr(t, err) diff --git a/x/network/integration/compressor_test.go b/x/network/integration/compressor_test.go index 7686e674a3..7cf920e95c 100644 --- a/x/network/integration/compressor_test.go +++ b/x/network/integration/compressor_test.go @@ -42,7 +42,7 @@ func TestCompression(t *testing.T) { } ctx := context.Background() - rw, err := server.Connection(ctx) + rw, err := server.ConnectionLegacy(ctx) noerr(t, err) rdr, err := cmd.RoundTrip(ctx, server.SelectedDescription(), rw) diff --git a/x/network/integration/list_collections_test.go b/x/network/integration/list_collections_test.go index 423b08db32..04f434d117 100644 --- a/x/network/integration/list_collections_test.go +++ b/x/network/integration/list_collections_test.go @@ -62,7 +62,7 @@ func TestCommandListCollections(t *testing.T) { server, err := testutil.Topology(t).SelectServer(context.Background(), description.WriteSelector()) noerr(t, err) - conn, err := server.Connection(context.Background()) + conn, err := server.ConnectionLegacy(context.Background()) noerr(t, err) _, err = (&command.ListCollections{}).RoundTrip(context.Background(), server.SelectedDescription(), conn) diff --git a/x/network/integration/list_databases_test.go b/x/network/integration/list_databases_test.go index d1235c38e8..4b4b298d06 100644 --- a/x/network/integration/list_databases_test.go +++ b/x/network/integration/list_databases_test.go @@ -27,7 +27,7 @@ func TestListDatabases(t *testing.T) { } server, err := testutil.Topology(t).SelectServer(context.Background(), description.WriteSelector()) noerr(t, err) - conn, err := server.Connection(context.Background()) + conn, err := server.ConnectionLegacy(context.Background()) noerr(t, err) wc := writeconcern.New(writeconcern.WMajority()) diff --git a/x/network/integration/main_test.go b/x/network/integration/main_test.go index e4ac223946..df9d3f6dac 100644 --- a/x/network/integration/main_test.go +++ b/x/network/integration/main_test.go @@ -17,6 +17,7 @@ import ( "testing" "go.mongodb.org/mongo-driver/x/mongo/driverlegacy/auth" + "go.mongodb.org/mongo-driver/x/mongo/driverlegacy/topology" "go.mongodb.org/mongo-driver/x/network/connection" "go.mongodb.org/mongo-driver/x/network/connstring" ) @@ -60,7 +61,12 @@ func noerr(t *testing.T, err error) { func autherr(t *testing.T, err error) { t.Helper() - switch err.(type) { + switch e := err.(type) { + case topology.ConnectionError: + _, ok := e.Wrapped.(*auth.Error) + if !ok { + t.Fatal("Expected auth error and didn't get one") + } case *auth.Error: return default: diff --git a/x/network/integration/opmsg_test.go b/x/network/integration/opmsg_test.go index 148e7a6149..eb676d3cb7 100644 --- a/x/network/integration/opmsg_test.go +++ b/x/network/integration/opmsg_test.go @@ -27,7 +27,7 @@ import ( func createServerConn(t *testing.T) (*topology.SelectedServer, connection.Connection) { server, err := testutil.Topology(t).SelectServer(context.Background(), description.WriteSelector()) noerr(t, err) - conn, err := server.Connection(context.Background()) + conn, err := server.ConnectionLegacy(context.Background()) noerr(t, err) return server, conn diff --git a/x/network/integration/server_test.go b/x/network/integration/server_test.go index de3f728b75..61f8fb998e 100644 --- a/x/network/integration/server_test.go +++ b/x/network/integration/server_test.go @@ -8,16 +8,17 @@ package integration import ( "context" + "crypto/tls" "net" "strings" "testing" "time" "go.mongodb.org/mongo-driver/internal/testutil" + "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driverlegacy/auth" "go.mongodb.org/mongo-driver/x/mongo/driverlegacy/topology" "go.mongodb.org/mongo-driver/x/network/address" - "go.mongodb.org/mongo-driver/x/network/command" "go.mongodb.org/mongo-driver/x/network/connection" ) @@ -31,11 +32,11 @@ func TestTopologyServer(t *testing.T) { } t.Run("After close, should not return new connection", func(t *testing.T) { - s, err := topology.ConnectServer(context.Background(), address.Address(*host), nil, serveropts(t)...) + s, err := topology.ConnectServer(address.Address(*host), nil, serveropts(t)...) noerr(t, err) err = s.Disconnect(context.TODO()) noerr(t, err) - _, err = s.Connection(context.Background()) + _, err = s.ConnectionLegacy(context.Background()) if err != topology.ErrServerClosed { t.Errorf("Expected error from getting a connection from closed server, but got %v", err) } @@ -43,7 +44,7 @@ func TestTopologyServer(t *testing.T) { t.Run("Shouldn't be able to get more than max connections", func(t *testing.T) { t.Parallel() - s, err := topology.ConnectServer(context.Background(), address.Address(*host), nil, + s, err := topology.ConnectServer(address.Address(*host), nil, serveropts( t, topology.WithMaxConnections(func(uint16) uint16 { return 2 }), @@ -51,15 +52,15 @@ func TestTopologyServer(t *testing.T) { )..., ) noerr(t, err) - c1, err := s.Connection(context.Background()) + c1, err := s.ConnectionLegacy(context.Background()) noerr(t, err) defer c1.Close() - c2, err := s.Connection(context.Background()) + c2, err := s.ConnectionLegacy(context.Background()) noerr(t, err) defer c2.Close() ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) defer cancel() - _, err = s.Connection(ctx) + _, err = s.ConnectionLegacy(ctx) if !strings.Contains(err.Error(), "deadline exceeded") { t.Errorf("Expected timeout while trying to open more than max connections, but got %v", err) } @@ -83,7 +84,7 @@ func TestTopologyServer(t *testing.T) { t.Run("Write network timeout", func(t *testing.T) {}) }) t.Run("Close should close all subscription channels", func(t *testing.T) { - s, err := topology.ConnectServer(context.Background(), address.Address(*host), nil, serveropts(t)...) + s, err := topology.ConnectServer(address.Address(*host), nil, serveropts(t)...) noerr(t, err) var done1, done2 = make(chan struct{}), make(chan struct{}) @@ -124,7 +125,7 @@ func TestTopologyServer(t *testing.T) { } }) t.Run("Subscribe after Close should return an error", func(t *testing.T) { - s, err := topology.ConnectServer(context.Background(), address.Address(*host), nil, serveropts(t)...) + s, err := topology.ConnectServer(address.Address(*host), nil, serveropts(t)...) noerr(t, err) sub, err := s.Subscribe() @@ -142,7 +143,7 @@ func TestTopologyServer(t *testing.T) { }) t.Run("Disconnect", func(t *testing.T) { t.Run("cannot disconnect before connecting", func(t *testing.T) { - s, err := topology.NewServer(address.Address(*host), nil, serveropts(t)...) + s, err := topology.NewServer(address.Address(*host), serveropts(t)...) noerr(t, err) got := s.Disconnect(context.TODO()) @@ -151,9 +152,9 @@ func TestTopologyServer(t *testing.T) { } }) t.Run("cannot disconnect twice", func(t *testing.T) { - s, err := topology.NewServer(address.Address(*host), nil, serveropts(t)...) + s, err := topology.NewServer(address.Address(*host), serveropts(t)...) noerr(t, err) - err = s.Connect(context.TODO()) + err = s.Connect(nil) noerr(t, err) got := s.Disconnect(context.TODO()) @@ -168,21 +169,21 @@ func TestTopologyServer(t *testing.T) { t.Run("all open sockets should be closed after disconnect", func(t *testing.T) { d := newdialer(&net.Dialer{}) s, err := topology.NewServer( - address.Address(*host), nil, + address.Address(*host), serveropts( t, - topology.WithConnectionOptions(func(opts ...connection.Option) []connection.Option { - return append(opts, connection.WithDialer(func(connection.Dialer) connection.Dialer { return d })) + topology.WithConnectionOptions(func(opts ...topology.ConnectionOption) []topology.ConnectionOption { + return append(opts, topology.WithDialer(func(topology.Dialer) topology.Dialer { return d })) }), )..., ) noerr(t, err) - err = s.Connect(context.TODO()) + err = s.Connect(nil) noerr(t, err) conns := [3]connection.Connection{} for idx := range [3]struct{}{} { - conns[idx], err = s.Connection(context.TODO()) + conns[idx], err = s.ConnectionLegacy(context.TODO()) noerr(t, err) } for idx := range [2]struct{}{} { @@ -203,50 +204,50 @@ func TestTopologyServer(t *testing.T) { }) t.Run("Connect", func(t *testing.T) { t.Run("can reconnect a disconnected server", func(t *testing.T) { - s, err := topology.NewServer(address.Address(*host), nil, serveropts(t)...) + s, err := topology.NewServer(address.Address(*host), serveropts(t)...) noerr(t, err) - err = s.Connect(context.TODO()) + err = s.Connect(nil) noerr(t, err) err = s.Disconnect(context.TODO()) noerr(t, err) - err = s.Connect(context.TODO()) + err = s.Connect(nil) noerr(t, err) }) t.Run("cannot connect multiple times without disconnect", func(t *testing.T) { - s, err := topology.NewServer(address.Address(*host), nil, serveropts(t)...) + s, err := topology.NewServer(address.Address(*host), serveropts(t)...) noerr(t, err) - err = s.Connect(context.TODO()) + err = s.Connect(nil) noerr(t, err) err = s.Disconnect(context.TODO()) noerr(t, err) - err = s.Connect(context.TODO()) + err = s.Connect(nil) noerr(t, err) - err = s.Connect(context.TODO()) + err = s.Connect(nil) if err != topology.ErrServerConnected { t.Errorf("Did not receive expected error. got %v; want %v", err, topology.ErrServerConnected) } }) t.Run("can disconnect and reconnect multiple times", func(t *testing.T) { - s, err := topology.NewServer(address.Address(*host), nil, serveropts(t)...) + s, err := topology.NewServer(address.Address(*host), serveropts(t)...) noerr(t, err) - err = s.Connect(context.TODO()) + err = s.Connect(nil) noerr(t, err) err = s.Disconnect(context.TODO()) noerr(t, err) - err = s.Connect(context.TODO()) + err = s.Connect(nil) noerr(t, err) err = s.Disconnect(context.TODO()) noerr(t, err) - err = s.Connect(context.TODO()) + err = s.Connect(nil) noerr(t, err) err = s.Disconnect(context.TODO()) noerr(t, err) - err = s.Connect(context.TODO()) + err = s.Connect(nil) noerr(t, err) }) }) @@ -260,7 +261,7 @@ func serveropts(t *testing.T, opts ...topology.ServerOption) []topology.ServerOp } } cs := testutil.ConnString(t) - var connOpts []connection.Option + var connOpts []topology.ConnectionOption if cs.Username != "" || cs.AuthMechanism == auth.GSSAPI { cred := &auth.Cred{ Source: "admin", @@ -284,15 +285,15 @@ func serveropts(t *testing.T, opts ...topology.ServerOption) []topology.ServerOp authenticator, err := auth.CreateAuthenticator(cs.AuthMechanism, cred) noerr(t, err) - connOpts = append(connOpts, connection.WithHandshaker(func(h connection.Handshaker) connection.Handshaker { + connOpts = append(connOpts, topology.WithHandshaker(func(h driver.Handshaker) driver.Handshaker { return auth.Handshaker(h, &auth.HandshakeOptions{ AppName: cs.AppName, Authenticator: authenticator, }) })) } else { - connOpts = append(connOpts, connection.WithHandshaker(func(h connection.Handshaker) connection.Handshaker { - return &command.Handshake{Client: command.ClientDoc(cs.AppName), Compressors: cs.Compressors} + connOpts = append(connOpts, topology.WithHandshaker(func(h driver.Handshaker) driver.Handshaker { + return driver.IsMaster().AppName(cs.AppName).Compressors(cs.Compressors) })) } @@ -308,17 +309,17 @@ func serveropts(t *testing.T, opts ...topology.ServerOption) []topology.ServerOp tlsConfig.SetInsecure(true) } - connOpts = append(connOpts, connection.WithTLSConfig(func(*connection.TLSConfig) *connection.TLSConfig { return tlsConfig })) + connOpts = append(connOpts, topology.WithTLSConfig(func(*tls.Config) *tls.Config { return tlsConfig.Config })) } if len(cs.Compressors) > 0 { - connOpts = append(connOpts, connection.WithCompressors(func(compressors []string) []string { + connOpts = append(connOpts, topology.WithCompressors(func(compressors []string) []string { return append(compressors, cs.Compressors...) })) for _, comp := range cs.Compressors { if comp == "zlib" { - connOpts = append(connOpts, connection.WithZlibLevel(func(level *int) *int { + connOpts = append(connOpts, topology.WithZlibLevel(func(level *int) *int { return &cs.ZlibLevel })) } @@ -326,7 +327,7 @@ func serveropts(t *testing.T, opts ...topology.ServerOption) []topology.ServerOp } if len(connOpts) > 0 { - opts = append(opts, topology.WithConnectionOptions(func(opts ...connection.Option) []connection.Option { + opts = append(opts, topology.WithConnectionOptions(func(opts ...topology.ConnectionOption) []topology.ConnectionOption { return append(opts, connOpts...) })) } diff --git a/x/network/integration/topology_test.go b/x/network/integration/topology_test.go index 2442189aaa..f2885e8a06 100644 --- a/x/network/integration/topology_test.go +++ b/x/network/integration/topology_test.go @@ -49,10 +49,10 @@ func TestTopologyTopology(t *testing.T) { topology.WithServerOptions(func(opts ...topology.ServerOption) []topology.ServerOption { return append( opts, - topology.WithConnectionOptions(func(opts ...connection.Option) []connection.Option { + topology.WithConnectionOptions(func(opts ...topology.ConnectionOption) []topology.ConnectionOption { return append( opts, - connection.WithDialer(func(connection.Dialer) connection.Dialer { return d }), + topology.WithDialer(func(topology.Dialer) topology.Dialer { return d }), ) }), ) @@ -66,7 +66,7 @@ func TestTopologyTopology(t *testing.T) { conns := [3]connection.Connection{} for idx := range [3]struct{}{} { - conns[idx], err = ss.Connection(context.TODO()) + conns[idx], err = ss.ConnectionLegacy(context.TODO()) noerr(t, err) } for idx := range [2]struct{}{} {