From bb9530df7191054d68c7879dc4a2045eb71ad856 Mon Sep 17 00:00:00 2001 From: Divjot Arora Date: Mon, 8 Jul 2019 19:53:43 -0400 Subject: [PATCH] Change Client to depend on driver.Deployment - Move session pooling from Topology to Client. - Add method to create Client from a Deployment. - Add interfaces to allow arbitrary Deployments to support connecting, disconnecting, and sessions. GODRIVER-1250 Change-Id: I64c02c29424a3690d43c995c8bd2bcbc19cb9e96 --- .errcheck-excludes | 2 +- internal/testutil/config.go | 22 +++++ mongo/bulk_write.go | 6 +- mongo/change_stream.go | 10 +-- mongo/change_stream_test.go | 6 -- mongo/client.go | 63 ++++++++++----- mongo/client_internal_test.go | 35 ++------ mongo/collection.go | 64 +++++++-------- mongo/command_monitoring_test.go | 3 +- mongo/database.go | 18 ++--- mongo/database_internal_test.go | 12 --- mongo/index_view.go | 18 ++--- mongo/index_view_internal_test.go | 22 ----- mongo/operation_legacy_test.go | 8 +- mongo/options/clientoptions.go | 6 +- mongo/results_test.go | 6 -- mongo/retryable_writes_test.go | 6 +- mongo/session.go | 7 +- mongo/sessions_test.go | 21 +++-- mongo/transactions_test.go | 22 +++-- x/mongo/driver/driver.go | 24 ++++++ .../examples/cluster_monitoring/main.go | 2 +- .../topology/polling_srv_records_test.go | 8 +- x/mongo/driver/topology/topology.go | 81 +++++++++---------- 24 files changed, 233 insertions(+), 239 deletions(-) diff --git a/.errcheck-excludes b/.errcheck-excludes index c35c977ae7..b0b95df09a 100644 --- a/.errcheck-excludes +++ b/.errcheck-excludes @@ -2,7 +2,7 @@ (*go.mongodb.org/mongo-driver/x/network/connection.connection).Close (go.mongodb.org/mongo-driver/x/network/connection.Connection).Close (*go.mongodb.org/mongo-driver/x/mongo/driver/topology.connection).close -(*go.mongodb.org/mongo-driver/x/mongo/driver/topology.Subscription).Unsubscribe +(*go.mongodb.org/mongo-driver/x/mongo/driver/topology.Topology).Unsubscribe (*go.mongodb.org/mongo-driver/x/mongo/driver/topology.Server).Close (*go.mongodb.org/mongo-driver/x/network/connection.pool).closeConnection (*go.mongodb.org/mongo-driver/x/mongo/driver/topology.pool).close diff --git a/internal/testutil/config.go b/internal/testutil/config.go index 54a5e9c1c6..7d5dc56e95 100644 --- a/internal/testutil/config.go +++ b/internal/testutil/config.go @@ -23,6 +23,7 @@ import ( "go.mongodb.org/mongo-driver/x/mongo/driver/connstring" "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/operation" + "go.mongodb.org/mongo-driver/x/mongo/driver/session" "go.mongodb.org/mongo-driver/x/mongo/driver/topology" ) @@ -30,9 +31,11 @@ var connectionString connstring.ConnString var connectionStringOnce sync.Once var connectionStringErr error var liveTopology *topology.Topology +var liveSessionPool *session.Pool var liveTopologyOnce sync.Once var liveTopologyErr error var monitoredTopology *topology.Topology +var monitoredSessionPool *session.Pool var monitoredTopologyOnce sync.Once var monitoredTopologyErr error @@ -144,6 +147,10 @@ func GlobalMonitoredTopology(t *testing.T, monitor *event.CommandMonitor) *topol Database(DBName(t)).ServerSelector(description.WriteSelector()).Deployment(monitoredTopology).Execute(context.Background()) require.NoError(t, err) + + sub, err := monitoredTopology.Subscribe() + require.NoError(t, err) + monitoredSessionPool = session.NewPool(sub.Updates) } }) @@ -154,6 +161,12 @@ func GlobalMonitoredTopology(t *testing.T, monitor *event.CommandMonitor) *topol return monitoredTopology } +// 
GlobalMonitoredSessionPool returns the globally configured session pool. +// Must be called after GlobalMonitoredTopology() +func GlobalMonitoredSessionPool() *session.Pool { + return monitoredSessionPool +} + // Topology gets the globally configured topology. func Topology(t *testing.T) *topology.Topology { cs := ConnString(t) @@ -168,6 +181,10 @@ func Topology(t *testing.T) *topology.Topology { err = operation.NewCommand(bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "dropDatabase", 1))). Database(DBName(t)).ServerSelector(description.WriteSelector()).Deployment(liveTopology).Execute(context.Background()) require.NoError(t, err) + + sub, err := liveTopology.Subscribe() + require.NoError(t, err) + liveSessionPool = session.NewPool(sub.Updates) } }) @@ -178,6 +195,11 @@ func Topology(t *testing.T) *topology.Topology { return liveTopology } +// SessionPool gets the globally configured session pool. Must be called after Topology(). +func SessionPool() *session.Pool { + return liveSessionPool +} + // TopologyWithConnString takes a connection string and returns a connected // topology, or else bails out of testing func TopologyWithConnString(t *testing.T, cs connstring.ConnString) *topology.Topology { diff --git a/mongo/bulk_write.go b/mongo/bulk_write.go index e4597840b0..282e4665df 100644 --- a/mongo/bulk_write.go +++ b/mongo/bulk_write.go @@ -177,7 +177,7 @@ func (bw *bulkWrite) runInsert(ctx context.Context, batch bulkWriteBatch) (opera Session(bw.session).WriteConcern(bw.writeConcern).CommandMonitor(bw.collection.client.monitor). ServerSelector(bw.selector).ClusterClock(bw.collection.client.clock). Database(bw.collection.db.name).Collection(bw.collection.name). - Deployment(bw.collection.client.topology).Crypt(bw.collection.client.crypt) + Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.crypt) if bw.bypassDocumentValidation != nil && *bw.bypassDocumentValidation { op = op.BypassDocumentValidation(*bw.bypassDocumentValidation) } @@ -223,7 +223,7 @@ func (bw *bulkWrite) runDelete(ctx context.Context, batch bulkWriteBatch) (opera Session(bw.session).WriteConcern(bw.writeConcern).CommandMonitor(bw.collection.client.monitor). ServerSelector(bw.selector).ClusterClock(bw.collection.client.clock). Database(bw.collection.db.name).Collection(bw.collection.name). - Deployment(bw.collection.client.topology).Crypt(bw.collection.client.crypt) + Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.crypt) if bw.ordered != nil { op = op.Ordered(*bw.ordered) } @@ -287,7 +287,7 @@ func (bw *bulkWrite) runUpdate(ctx context.Context, batch bulkWriteBatch) (opera Session(bw.session).WriteConcern(bw.writeConcern).CommandMonitor(bw.collection.client.monitor). ServerSelector(bw.selector).ClusterClock(bw.collection.client.clock). Database(bw.collection.db.name).Collection(bw.collection.name). 
- Deployment(bw.collection.client.topology).Crypt(bw.collection.client.crypt) + Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.crypt) if bw.ordered != nil { op = op.Ordered(*bw.ordered) } diff --git a/mongo/change_stream.go b/mongo/change_stream.go index 85a86aeca0..8ffed0ee4b 100644 --- a/mongo/change_stream.go +++ b/mongo/change_stream.go @@ -87,8 +87,8 @@ func newChangeStream(ctx context.Context, config changeStreamConfig, pipeline in } cs.sess = sessionFromContext(ctx) - if cs.sess == nil && cs.client.topology.SessionPool != nil { - cs.sess, cs.err = session.NewClientSession(cs.client.topology.SessionPool, cs.client.id, session.Implicit) + if cs.sess == nil && cs.client.sessionPool != nil { + cs.sess, cs.err = session.NewClientSession(cs.client.sessionPool, cs.client.id, session.Implicit) if cs.err != nil { return nil, cs.Err() } @@ -100,7 +100,7 @@ func newChangeStream(ctx context.Context, config changeStreamConfig, pipeline in cs.aggregate = operation.NewAggregate(nil). ReadPreference(config.readPreference).ReadConcern(config.readConcern). - Deployment(cs.client.topology).ClusterClock(cs.client.clock). + Deployment(cs.client.deployment).ClusterClock(cs.client.clock). CommandMonitor(cs.client.monitor).Session(cs.sess).ServerSelector(cs.selector).Retry(driver.RetryNone) if cs.options.Collation != nil { @@ -161,7 +161,7 @@ func newChangeStream(ctx context.Context, config changeStreamConfig, pipeline in func (cs *ChangeStream) executeOperation(ctx context.Context, resuming bool) error { var server driver.Server var conn driver.Connection - if server, cs.err = cs.client.topology.SelectServer(ctx, cs.selector); cs.err != nil { + if server, cs.err = cs.client.deployment.SelectServer(ctx, cs.selector); cs.err != nil { return cs.Err() } if conn, cs.err = server.Connection(ctx); cs.err != nil { @@ -204,7 +204,7 @@ func (cs *ChangeStream) executeOperation(ctx context.Context, resuming bool) err break } - server, err := cs.client.topology.SelectServer(ctx, cs.selector) + server, err := cs.client.deployment.SelectServer(ctx, cs.selector) if err != nil { break } diff --git a/mongo/change_stream_test.go b/mongo/change_stream_test.go index c64efbc03f..bfb6cc79db 100644 --- a/mongo/change_stream_test.go +++ b/mongo/change_stream_test.go @@ -205,8 +205,6 @@ func TestChangeStream(t *testing.T) { skipIfBelow36(t) t.Run("TestFirstStage", func(t *testing.T) { - t.Parallel() - if testing.Short() { t.Skip() } @@ -236,8 +234,6 @@ func TestChangeStream(t *testing.T) { }) t.Run("TestReplaceRoot", func(t *testing.T) { - t.Parallel() - if testing.Short() { t.Skip() } @@ -280,8 +276,6 @@ func TestChangeStream(t *testing.T) { }) t.Run("TestNoCustomStandaloneError", func(t *testing.T) { - t.Parallel() - if testing.Short() { t.Skip() } diff --git a/mongo/client.go b/mongo/client.go index e37944fd4e..0e11bbde20 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -9,6 +9,7 @@ package mongo import ( "context" "crypto/tls" + "errors" "strconv" "strings" "time" @@ -42,7 +43,7 @@ var keyVaultCollOpts = options.Collection().SetReadConcern(readconcern.Majority( type Client struct { id uuid.UUID topologyOptions []topology.Option - topology *topology.Topology + deployment driver.Deployment connString connstring.ConnString localThreshold time.Duration retryWrites bool @@ -54,6 +55,7 @@ type Client struct { registry *bsoncodec.Registry marshaller BSONAppender monitor *event.CommandMonitor + sessionPool *session.Pool // client-side encryption fields keyVaultClient *Client @@ -102,7 +104,7 @@ func 
NewClient(opts ...*options.ClientOptions) (*Client, error) { return nil, err } - client.topology, err = topology.New(client.topologyOptions...) + client.deployment, err = topology.New(client.topologyOptions...) if err != nil { return nil, replaceErrors(err) } @@ -113,20 +115,33 @@ func NewClient(opts ...*options.ClientOptions) (*Client, error) { // Connect initializes the Client by starting background monitoring goroutines. // This method must be called before a Client can be used. func (c *Client) Connect(ctx context.Context) error { - err := c.topology.Connect() - if err != nil { - return replaceErrors(err) + if connector, ok := c.deployment.(driver.Connector); ok { + err := connector.Connect() + if err != nil { + return replaceErrors(err) + } } + if c.mongocryptd != nil { - if err = c.mongocryptd.connect(ctx); err != nil { + if err := c.mongocryptd.connect(ctx); err != nil { return err } } if c.keyVaultClient != nil { - if err = c.keyVaultClient.Connect(ctx); err != nil { + if err := c.keyVaultClient.Connect(ctx); err != nil { return err } } + + var updateChan <-chan description.Topology + if subscriber, ok := c.deployment.(driver.Subscriber); ok { + sub, err := subscriber.Subscribe() + if err != nil { + return replaceErrors(err) + } + updateChan = sub.Updates + } + c.sessionPool = session.NewPool(updateChan) return nil } @@ -157,7 +172,11 @@ func (c *Client) Disconnect(ctx context.Context) error { if c.crypt != nil { c.crypt.Close() } - return replaceErrors(c.topology.Disconnect(ctx)) + + if disconnector, ok := c.deployment.(driver.Disconnector); ok { + return replaceErrors(disconnector.Disconnect(ctx)) + } + return nil } // Ping verifies that the client can connect to the topology. @@ -182,7 +201,7 @@ func (c *Client) Ping(ctx context.Context, rp *readpref.ReadPref) error { // StartSession starts a new session. func (c *Client) StartSession(opts ...*options.SessionOptions) (Session, error) { - if c.topology.SessionPool == nil { + if c.sessionPool == nil { return nil, ErrClientDisconnected } @@ -208,7 +227,7 @@ func (c *Client) StartSession(opts ...*options.SessionOptions) (Session, error) coreOpts.DefaultMaxCommitTime = sopts.DefaultMaxCommitTime } - sess, err := session.NewClientSession(c.topology.SessionPool, c.id, session.Explicit, coreOpts) + sess, err := session.NewClientSession(c.sessionPool, c.id, session.Explicit, coreOpts) if err != nil { return nil, replaceErrors(err) } @@ -219,16 +238,16 @@ func (c *Client) StartSession(opts ...*options.SessionOptions) (Session, error) return &sessionImpl{ clientSession: sess, client: c, - topo: c.topology, + deployment: c.deployment, }, nil } func (c *Client) endSessions(ctx context.Context) { - if c.topology.SessionPool == nil { + if c.sessionPool == nil { return } - ids := c.topology.SessionPool.IDSlice() + ids := c.sessionPool.IDSlice() idx, idArray := bsoncore.AppendArrayStart(nil) for i, id := range ids { idDoc, _ := id.MarshalBSON() @@ -236,7 +255,7 @@ func (c *Client) endSessions(ctx context.Context) { } idArray, _ = bsoncore.AppendArrayEnd(idArray, idx) - op := operation.NewEndSessions(idArray).ClusterClock(c.clock).Deployment(c.topology). + op := operation.NewEndSessions(idArray).ClusterClock(c.clock).Deployment(c.deployment). ServerSelector(description.ReadPrefSelector(readpref.PrimaryPreferred())).CommandMonitor(c.monitor). 
Database("admin").Crypt(c.crypt) @@ -493,6 +512,14 @@ func (c *Client) configure(opts *options.ClientOptions) error { func(...topology.ServerOption) []topology.ServerOption { return serverOpts }, )) + // Deployment + if opts.Deployment != nil { + if len(serverOpts) > 2 || len(topologyOpts) > 1 { + return errors.New("cannot specify topology or server options with a deployment") + } + c.deployment = opts.Deployment + } + return nil } @@ -585,8 +612,8 @@ func (c *Client) ListDatabases(ctx context.Context, filter interface{}, opts ... sess := sessionFromContext(ctx) err := c.validSession(sess) - if sess == nil && c.topology.SessionPool != nil { - sess, err = session.NewClientSession(c.topology.SessionPool, c.id, session.Implicit) + if sess == nil && c.sessionPool != nil { + sess, err = session.NewClientSession(c.sessionPool, c.id, session.Implicit) if err != nil { return ListDatabasesResult{}, err } @@ -612,7 +639,7 @@ func (c *Client) ListDatabases(ctx context.Context, filter interface{}, opts ... ldo := options.MergeListDatabasesOptions(opts...) op := operation.NewListDatabases(filterDoc). Session(sess).ReadPreference(c.readPreference).CommandMonitor(c.monitor). - ServerSelector(selector).ClusterClock(c.clock).Database("admin").Deployment(c.topology).Crypt(c.crypt) + ServerSelector(selector).ClusterClock(c.clock).Database("admin").Deployment(c.deployment).Crypt(c.crypt) if ldo.NameOnly != nil { op = op.NameOnly(*ldo.NameOnly) } @@ -700,7 +727,7 @@ func (c *Client) UseSessionWithOptions(ctx context.Context, opts *options.Sessio // The client must have read concern majority or no read concern for a change stream to be created successfully. func (c *Client) Watch(ctx context.Context, pipeline interface{}, opts ...*options.ChangeStreamOptions) (*ChangeStream, error) { - if c.topology.SessionPool == nil { + if c.sessionPool == nil { return nil, ErrClientDisconnected } diff --git a/mongo/client_internal_test.go b/mongo/client_internal_test.go index af42d68506..379c7fb56a 100644 --- a/mongo/client_internal_test.go +++ b/mongo/client_internal_test.go @@ -43,12 +43,13 @@ func createTestClient(t *testing.T) *Client { id, _ := uuid.New() return &Client{ id: id, - topology: testutil.Topology(t), + deployment: testutil.Topology(t), connString: testutil.ConnString(t), readPreference: readpref.Primary(), clock: &session.ClusterClock{}, registry: bson.DefaultRegistry, retryWrites: true, + sessionPool: testutil.SessionPool(), } } @@ -56,7 +57,7 @@ func createTestClientWithConnstring(t *testing.T, cs connstring.ConnString) *Cli id, _ := uuid.New() return &Client{ id: id, - topology: testutil.TopologyWithConnString(t, cs), + deployment: testutil.TopologyWithConnString(t, cs), connString: cs, readPreference: readpref.Primary(), clock: &session.ClusterClock{}, @@ -74,15 +75,11 @@ func skipIfBelow30(t *testing.T) { } func TestNewClient(t *testing.T) { - t.Parallel() - c := createTestClient(t) - require.NotNil(t, c.topology) + require.NotNil(t, c.deployment) } func TestClient_Database(t *testing.T) { - t.Parallel() - dbName := "foo" c := createTestClient(t) @@ -149,8 +146,6 @@ func TestClientRegistryPassedToCursors(t *testing.T) { func TestClient_TLSConnection(t *testing.T) { skipIfBelow30(t) // 3.0 doesn't return a security field in the serverStatus response - t.Parallel() - if testing.Short() { t.Skip() } @@ -182,8 +177,6 @@ func TestClient_TLSConnection(t *testing.T) { } func TestClient_X509Auth(t *testing.T) { - t.Parallel() - if testing.Short() { t.Skip() } @@ -265,8 +258,6 @@ func TestClient_X509Auth(t 
*testing.T) { } func TestClient_ReplaceTopologyError(t *testing.T) { - t.Parallel() - if testing.Short() { t.Skip() } @@ -427,7 +418,7 @@ func TestRetryWritesError20Wrapped(t *testing.T) { op := operation.NewInsert(writeError).CommandMonitor(coll.client.monitor).ClusterClock(coll.client.clock). Database(coll.db.name).Collection(coll.name). - Deployment(coll.client.topology).Deployment(deployment).Retry(driver.RetryOnce).Session(sess.(*sessionImpl).clientSession) + Deployment(coll.client.deployment).Deployment(deployment).Retry(driver.RetryOnce).Session(sess.(*sessionImpl).clientSession) err = op.Execute(context.Background()) if test.shouldError { @@ -446,8 +437,6 @@ func TestRetryWritesError20Wrapped(t *testing.T) { } func TestClient_ListDatabases_noFilter(t *testing.T) { - t.Parallel() - if testing.Short() { t.Skip() } @@ -478,8 +467,6 @@ func TestClient_ListDatabases_noFilter(t *testing.T) { } func TestClient_ListDatabases_filter(t *testing.T) { - t.Parallel() - if testing.Short() { t.Skip() } @@ -508,8 +495,6 @@ func TestClient_ListDatabases_filter(t *testing.T) { } func TestClient_ListDatabaseNames_noFilter(t *testing.T) { - t.Parallel() - if testing.Short() { t.Skip() } @@ -540,8 +525,6 @@ func TestClient_ListDatabaseNames_noFilter(t *testing.T) { } func TestClient_ListDatabaseNames_filter(t *testing.T) { - t.Parallel() - if testing.Short() { t.Skip() } @@ -571,8 +554,6 @@ func TestClient_ListDatabaseNames_filter(t *testing.T) { } func TestClient_NilDocumentError(t *testing.T) { - t.Parallel() - c := createTestClient(t) _, err := c.Watch(context.Background(), nil) @@ -586,8 +567,6 @@ func TestClient_NilDocumentError(t *testing.T) { } func TestClient_ReadPreference(t *testing.T) { - t.Parallel() - if testing.Short() { t.Skip() } @@ -619,8 +598,6 @@ func TestClient_ReadPreference(t *testing.T) { } func TestClient_ReadPreferenceAbsent(t *testing.T) { - t.Parallel() - cs := testutil.ConnString(t) c, err := NewClient(options.Client().ApplyURI(cs.String())) require.NoError(t, err) @@ -754,7 +731,7 @@ func TestIsMaster(t *testing.T) { ) require.NoError(t, err) - isMaster := operation.NewIsMaster().ClusterClock(client.clock).Deployment(client.topology). + isMaster := operation.NewIsMaster().ClusterClock(client.clock).Deployment(client.deployment). 
AppName(cs.AppName).Compressors(cs.Compressors) err = isMaster.Execute(ctx) diff --git a/mongo/collection.go b/mongo/collection.go index fd598304ca..7415712b84 100644 --- a/mongo/collection.go +++ b/mongo/collection.go @@ -179,8 +179,8 @@ func (coll *Collection) BulkWrite(ctx context.Context, models []WriteModel, } sess := sessionFromContext(ctx) - if sess == nil && coll.client.topology.SessionPool != nil { - sess, err := session.NewClientSession(coll.client.topology.SessionPool, coll.client.id, session.Implicit) + if sess == nil && coll.client.sessionPool != nil { + sess, err := session.NewClientSession(coll.client.sessionPool, coll.client.id, session.Implicit) if err != nil { return nil, err } @@ -244,9 +244,9 @@ func (coll *Collection) insert(ctx context.Context, documents []interface{}, } sess := sessionFromContext(ctx) - if sess == nil && coll.client.topology.SessionPool != nil { + if sess == nil && coll.client.sessionPool != nil { var err error - sess, err = session.NewClientSession(coll.client.topology.SessionPool, coll.client.id, session.Implicit) + sess, err = session.NewClientSession(coll.client.sessionPool, coll.client.id, session.Implicit) if err != nil { return nil, err } @@ -272,7 +272,7 @@ func (coll *Collection) insert(ctx context.Context, documents []interface{}, Session(sess).WriteConcern(wc).CommandMonitor(coll.client.monitor). ServerSelector(selector).ClusterClock(coll.client.clock). Database(coll.db.name).Collection(coll.name). - Deployment(coll.client.topology).Crypt(coll.client.crypt) + Deployment(coll.client.deployment).Crypt(coll.client.crypt) imo := options.MergeInsertManyOptions(opts...) if imo.BypassDocumentValidation != nil && *imo.BypassDocumentValidation { op = op.BypassDocumentValidation(*imo.BypassDocumentValidation) @@ -361,8 +361,8 @@ func (coll *Collection) delete(ctx context.Context, filter interface{}, deleteOn } sess := sessionFromContext(ctx) - if sess == nil && coll.client.topology.SessionPool != nil { - sess, err = session.NewClientSession(coll.client.topology.SessionPool, coll.client.id, session.Implicit) + if sess == nil && coll.client.sessionPool != nil { + sess, err = session.NewClientSession(coll.client.sessionPool, coll.client.id, session.Implicit) if err != nil { return nil, err } @@ -401,7 +401,7 @@ func (coll *Collection) delete(ctx context.Context, filter interface{}, deleteOn Session(sess).WriteConcern(wc).CommandMonitor(coll.client.monitor). ServerSelector(selector).ClusterClock(coll.client.clock). Database(coll.db.name).Collection(coll.name). - Deployment(coll.client.topology).Crypt(coll.client.crypt) + Deployment(coll.client.deployment).Crypt(coll.client.crypt) // deleteMany cannot be retried retryMode := driver.RetryNone @@ -468,9 +468,9 @@ func (coll *Collection) updateOrReplace(ctx context.Context, filter bsoncore.Doc updateDoc, _ = bsoncore.AppendDocumentEnd(updateDoc, uidx) sess := sessionFromContext(ctx) - if sess == nil && coll.client.topology.SessionPool != nil { + if sess == nil && coll.client.sessionPool != nil { var err error - sess, err = session.NewClientSession(coll.client.topology.SessionPool, coll.client.id, session.Implicit) + sess, err = session.NewClientSession(coll.client.sessionPool, coll.client.id, session.Implicit) if err != nil { return nil, err } @@ -496,7 +496,7 @@ func (coll *Collection) updateOrReplace(ctx context.Context, filter bsoncore.Doc Session(sess).WriteConcern(wc).CommandMonitor(coll.client.monitor). ServerSelector(selector).ClusterClock(coll.client.clock). 
Database(coll.db.name).Collection(coll.name). - Deployment(coll.client.topology).Crypt(coll.client.crypt) + Deployment(coll.client.deployment).Crypt(coll.client.crypt) if uo.BypassDocumentValidation != nil && *uo.BypassDocumentValidation { op = op.BypassDocumentValidation(*uo.BypassDocumentValidation) @@ -630,8 +630,8 @@ func aggregate(a aggregateParams) (*Cursor, error) { } sess := sessionFromContext(a.ctx) - if sess == nil && a.client.topology.SessionPool != nil { - sess, err = session.NewClientSession(a.client.topology.SessionPool, a.client.id, session.Implicit) + if sess == nil && a.client.sessionPool != nil { + sess, err = session.NewClientSession(a.client.sessionPool, a.client.id, session.Implicit) if err != nil { return nil, err } @@ -666,7 +666,7 @@ func aggregate(a aggregateParams) (*Cursor, error) { } op := operation.NewAggregate(pipelineArr).Session(sess).WriteConcern(wc).ReadConcern(rc).ReadPreference(a.readPreference).CommandMonitor(a.client.monitor). - ServerSelector(selector).ClusterClock(a.client.clock).Database(a.db).Collection(a.col).Deployment(a.client.topology).Crypt(a.client.crypt) + ServerSelector(selector).ClusterClock(a.client.clock).Database(a.db).Collection(a.col).Deployment(a.client.deployment).Crypt(a.client.crypt) if ao.AllowDiskUse != nil { op.AllowDiskUse(*ao.AllowDiskUse) } @@ -740,8 +740,8 @@ func (coll *Collection) CountDocuments(ctx context.Context, filter interface{}, } sess := sessionFromContext(ctx) - if sess == nil && coll.client.topology.SessionPool != nil { - sess, err = session.NewClientSession(coll.client.topology.SessionPool, coll.client.id, session.Implicit) + if sess == nil && coll.client.sessionPool != nil { + sess, err = session.NewClientSession(coll.client.sessionPool, coll.client.id, session.Implicit) if err != nil { return 0, err } @@ -759,7 +759,7 @@ func (coll *Collection) CountDocuments(ctx context.Context, filter interface{}, selector := makeReadPrefSelector(sess, coll.readSelector, coll.client.localThreshold) op := operation.NewAggregate(pipelineArr).Session(sess).ReadConcern(rc).ReadPreference(coll.readPreference). CommandMonitor(coll.client.monitor).ServerSelector(selector).ClusterClock(coll.client.clock).Database(coll.db.name). - Collection(coll.name).Deployment(coll.client.topology).Crypt(coll.client.crypt) + Collection(coll.name).Deployment(coll.client.deployment).Crypt(coll.client.crypt) if countOpts.Collation != nil { op.Collation(bsoncore.Document(countOpts.Collation.ToDocument())) } @@ -813,8 +813,8 @@ func (coll *Collection) EstimatedDocumentCount(ctx context.Context, sess := sessionFromContext(ctx) var err error - if sess == nil && coll.client.topology.SessionPool != nil { - sess, err = session.NewClientSession(coll.client.topology.SessionPool, coll.client.id, session.Implicit) + if sess == nil && coll.client.sessionPool != nil { + sess, err = session.NewClientSession(coll.client.sessionPool, coll.client.id, session.Implicit) if err != nil { return 0, err } @@ -834,7 +834,7 @@ func (coll *Collection) EstimatedDocumentCount(ctx context.Context, selector := makeReadPrefSelector(sess, coll.readSelector, coll.client.localThreshold) op := operation.NewCount().Session(sess).ClusterClock(coll.client.clock). Database(coll.db.name).Collection(coll.name).CommandMonitor(coll.client.monitor). - Deployment(coll.client.topology).ReadConcern(rc).ReadPreference(coll.readPreference). + Deployment(coll.client.deployment).ReadConcern(rc).ReadPreference(coll.readPreference). 
ServerSelector(selector).Crypt(coll.client.crypt) co := options.MergeEstimatedDocumentCountOptions(opts...) @@ -868,8 +868,8 @@ func (coll *Collection) Distinct(ctx context.Context, fieldName string, filter i sess := sessionFromContext(ctx) - if sess == nil && coll.client.topology.SessionPool != nil { - sess, err = session.NewClientSession(coll.client.topology.SessionPool, coll.client.id, session.Implicit) + if sess == nil && coll.client.sessionPool != nil { + sess, err = session.NewClientSession(coll.client.sessionPool, coll.client.id, session.Implicit) if err != nil { return nil, err } @@ -892,7 +892,7 @@ func (coll *Collection) Distinct(ctx context.Context, fieldName string, filter i op := operation.NewDistinct(fieldName, bsoncore.Document(f)). Session(sess).ClusterClock(coll.client.clock). Database(coll.db.name).Collection(coll.name).CommandMonitor(coll.client.monitor). - Deployment(coll.client.topology).ReadConcern(rc).ReadPreference(coll.readPreference). + Deployment(coll.client.deployment).ReadConcern(rc).ReadPreference(coll.readPreference). ServerSelector(selector).Crypt(coll.client.crypt) if option.Collation != nil { @@ -949,9 +949,9 @@ func (coll *Collection) Find(ctx context.Context, filter interface{}, } sess := sessionFromContext(ctx) - if sess == nil && coll.client.topology.SessionPool != nil { + if sess == nil && coll.client.sessionPool != nil { var err error - sess, err = session.NewClientSession(coll.client.topology.SessionPool, coll.client.id, session.Implicit) + sess, err = session.NewClientSession(coll.client.sessionPool, coll.client.id, session.Implicit) if err != nil { return nil, err } @@ -973,7 +973,7 @@ func (coll *Collection) Find(ctx context.Context, filter interface{}, Session(sess).ReadConcern(rc).ReadPreference(coll.readPreference). CommandMonitor(coll.client.monitor).ServerSelector(selector). ClusterClock(coll.client.clock).Database(coll.db.name).Collection(coll.name). - Deployment(coll.client.topology).Crypt(coll.client.crypt) + Deployment(coll.client.deployment).Crypt(coll.client.crypt) fo := options.MergeFindOptions(opts...) cursorOpts := driver.CursorOptions{ @@ -1140,8 +1140,8 @@ func (coll *Collection) findAndModify(ctx context.Context, op *operation.FindAnd sess := sessionFromContext(ctx) var err error - if sess == nil && coll.client.topology.SessionPool != nil { - sess, err = session.NewClientSession(coll.client.topology.SessionPool, coll.client.id, session.Implicit) + if sess == nil && coll.client.sessionPool != nil { + sess, err = session.NewClientSession(coll.client.sessionPool, coll.client.id, session.Implicit) if err != nil { return &SingleResult{err: err} } @@ -1175,7 +1175,7 @@ func (coll *Collection) findAndModify(ctx context.Context, op *operation.FindAnd ClusterClock(coll.client.clock). Database(coll.db.name). Collection(coll.name). - Deployment(coll.client.topology). + Deployment(coll.client.deployment). Retry(retry). 
Crypt(coll.client.crypt) @@ -1369,9 +1369,9 @@ func (coll *Collection) Drop(ctx context.Context) error { } sess := sessionFromContext(ctx) - if sess == nil && coll.client.topology.SessionPool != nil { + if sess == nil && coll.client.sessionPool != nil { var err error - sess, err = session.NewClientSession(coll.client.topology.SessionPool, coll.client.id, session.Implicit) + sess, err = session.NewClientSession(coll.client.sessionPool, coll.client.id, session.Implicit) if err != nil { return err } @@ -1397,7 +1397,7 @@ func (coll *Collection) Drop(ctx context.Context) error { Session(sess).WriteConcern(wc).CommandMonitor(coll.client.monitor). ServerSelector(selector).ClusterClock(coll.client.clock). Database(coll.db.name).Collection(coll.name). - Deployment(coll.client.topology).Crypt(coll.client.crypt) + Deployment(coll.client.deployment).Crypt(coll.client.crypt) err = op.Execute(ctx) // ignore namespace not found erorrs diff --git a/mongo/command_monitoring_test.go b/mongo/command_monitoring_test.go index 5184858e26..f2611bb8b3 100644 --- a/mongo/command_monitoring_test.go +++ b/mongo/command_monitoring_test.go @@ -52,7 +52,8 @@ var monitor = &event.CommandMonitor{ func createMonitoredClient(t *testing.T, monitor *event.CommandMonitor) *Client { client, err := NewClient() testhelpers.RequireNil(t, err, "unable to create client") - client.topology = testutil.GlobalMonitoredTopology(t, monitor) + client.deployment = testutil.GlobalMonitoredTopology(t, monitor) + client.sessionPool = testutil.GlobalMonitoredSessionPool() client.connString = testutil.ConnString(t) client.readPreference = readpref.Primary() client.clock = &session.ClusterClock{} diff --git a/mongo/database.go b/mongo/database.go index e76c53890f..6aad8ce044 100644 --- a/mongo/database.go +++ b/mongo/database.go @@ -120,9 +120,9 @@ func (db *Database) Aggregate(ctx context.Context, pipeline interface{}, func (db *Database) processRunCommand(ctx context.Context, cmd interface{}, opts ...*options.RunCmdOptions) (*operation.Command, *session.Client, error) { sess := sessionFromContext(ctx) - if sess == nil && db.client.topology.SessionPool != nil { + if sess == nil && db.client.sessionPool != nil { var err error - sess, err = session.NewClientSession(db.client.topology.SessionPool, db.client.id, session.Implicit) + sess, err = session.NewClientSession(db.client.sessionPool, db.client.id, session.Implicit) if err != nil { return nil, sess, err } @@ -153,7 +153,7 @@ func (db *Database) processRunCommand(ctx context.Context, cmd interface{}, return operation.NewCommand(runCmdDoc). Session(sess).CommandMonitor(db.client.monitor). ServerSelector(readSelect).ClusterClock(db.client.clock). - Database(db.name).Deployment(db.client.topology).ReadConcern(db.readConcern).Crypt(db.client.crypt), sess, nil + Database(db.name).Deployment(db.client.deployment).ReadConcern(db.readConcern).Crypt(db.client.crypt), sess, nil } // RunCommand runs a command on the database. A user can supply a custom @@ -211,8 +211,8 @@ func (db *Database) Drop(ctx context.Context) error { } sess := sessionFromContext(ctx) - if sess == nil && db.client.topology.SessionPool != nil { - sess, err := session.NewClientSession(db.client.topology.SessionPool, db.client.id, session.Implicit) + if sess == nil && db.client.sessionPool != nil { + sess, err := session.NewClientSession(db.client.sessionPool, db.client.id, session.Implicit) if err != nil { return err } @@ -237,7 +237,7 @@ func (db *Database) Drop(ctx context.Context) error { op := operation.NewDropDatabase(). 
Session(sess).WriteConcern(wc).CommandMonitor(db.client.monitor). ServerSelector(selector).ClusterClock(db.client.clock). - Database(db.name).Deployment(db.client.topology).Crypt(db.client.crypt) + Database(db.name).Deployment(db.client.deployment).Crypt(db.client.crypt) err = op.Execute(ctx) @@ -260,8 +260,8 @@ func (db *Database) ListCollections(ctx context.Context, filter interface{}, opt } sess := sessionFromContext(ctx) - if sess == nil && db.client.topology.SessionPool != nil { - sess, err = session.NewClientSession(db.client.topology.SessionPool, db.client.id, session.Implicit) + if sess == nil && db.client.sessionPool != nil { + sess, err = session.NewClientSession(db.client.sessionPool, db.client.id, session.Implicit) if err != nil { return nil, err } @@ -283,7 +283,7 @@ func (db *Database) ListCollections(ctx context.Context, filter interface{}, opt op := operation.NewListCollections(filterDoc). Session(sess).ReadPreference(db.readPreference).CommandMonitor(db.client.monitor). ServerSelector(selector).ClusterClock(db.client.clock). - Database(db.name).Deployment(db.client.topology).Crypt(db.client.crypt) + Database(db.name).Deployment(db.client.deployment).Crypt(db.client.crypt) if lco.NameOnly != nil { op = op.NameOnly(*lco.NameOnly) } diff --git a/mongo/database_internal_test.go b/mongo/database_internal_test.go index 4bd8be745d..133c92d109 100644 --- a/mongo/database_internal_test.go +++ b/mongo/database_internal_test.go @@ -41,8 +41,6 @@ func createTestDatabase(t *testing.T, name *string, opts ...*options.DatabaseOpt } func TestDatabase_initialize(t *testing.T) { - t.Parallel() - name := "foo" db := createTestDatabase(t, &name) @@ -110,8 +108,6 @@ func TestDatabase_InheritOptions(t *testing.T) { } func TestDatabase_ReplaceTopologyError(t *testing.T) { - t.Parallel() - if testing.Short() { t.Skip() } @@ -134,8 +130,6 @@ func TestDatabase_ReplaceTopologyError(t *testing.T) { } func TestDatabase_RunCommand(t *testing.T) { - t.Parallel() - db := createTestDatabase(t, nil) var result bsonx.Doc @@ -154,8 +148,6 @@ func TestDatabase_RunCommand(t *testing.T) { } func TestDatabase_RunCommand_DecodeStruct(t *testing.T) { - t.Parallel() - db := createTestDatabase(t, nil) result := struct { @@ -170,8 +162,6 @@ func TestDatabase_RunCommand_DecodeStruct(t *testing.T) { } func TestDatabase_NilDocumentError(t *testing.T) { - t.Parallel() - db := createTestDatabase(t, nil) err := db.RunCommand(context.Background(), nil).Err() @@ -188,8 +178,6 @@ func TestDatabase_NilDocumentError(t *testing.T) { } func TestDatabase_Drop(t *testing.T) { - t.Parallel() - name := "TestDatabase_Drop" db := createTestDatabase(t, &name) diff --git a/mongo/index_view.go b/mongo/index_view.go index e21e9b045f..ebb075d8ed 100644 --- a/mongo/index_view.go +++ b/mongo/index_view.go @@ -61,9 +61,9 @@ func (iv IndexView) List(ctx context.Context, opts ...*options.ListIndexesOption } sess := sessionFromContext(ctx) - if sess == nil && iv.coll.client.topology.SessionPool != nil { + if sess == nil && iv.coll.client.sessionPool != nil { var err error - sess, err = session.NewClientSession(iv.coll.client.topology.SessionPool, iv.coll.client.id, session.Implicit) + sess, err = session.NewClientSession(iv.coll.client.sessionPool, iv.coll.client.id, session.Implicit) if err != nil { return nil, err } @@ -84,7 +84,7 @@ func (iv IndexView) List(ctx context.Context, opts ...*options.ListIndexesOption Session(sess).CommandMonitor(iv.coll.client.monitor). ServerSelector(selector).ClusterClock(iv.coll.client.clock). 
Database(iv.coll.db.name).Collection(iv.coll.name). - Deployment(iv.coll.client.topology) + Deployment(iv.coll.client.deployment) var cursorOpts driver.CursorOptions lio := options.MergeListIndexesOptions(opts...) @@ -185,8 +185,8 @@ func (iv IndexView) CreateMany(ctx context.Context, models []IndexModel, opts .. sess := sessionFromContext(ctx) - if sess == nil && iv.coll.client.topology.SessionPool != nil { - sess, err = session.NewClientSession(iv.coll.client.topology.SessionPool, iv.coll.client.id, session.Implicit) + if sess == nil && iv.coll.client.sessionPool != nil { + sess, err = session.NewClientSession(iv.coll.client.sessionPool, iv.coll.client.id, session.Implicit) if err != nil { return nil, err } @@ -205,7 +205,7 @@ func (iv IndexView) CreateMany(ctx context.Context, models []IndexModel, opts .. op := operation.NewCreateIndexes(indexes). Session(sess).ClusterClock(iv.coll.client.clock). Database(iv.coll.db.name).Collection(iv.coll.name).CommandMonitor(iv.coll.client.monitor). - Deployment(iv.coll.client.topology).ServerSelector(selector) + Deployment(iv.coll.client.deployment).ServerSelector(selector) if option.MaxTime != nil { op.MaxTimeMS(int64(*option.MaxTime / time.Millisecond)) @@ -308,9 +308,9 @@ func (iv IndexView) drop(ctx context.Context, name string, opts ...*options.Drop } sess := sessionFromContext(ctx) - if sess == nil && iv.coll.client.topology.SessionPool != nil { + if sess == nil && iv.coll.client.sessionPool != nil { var err error - sess, err = session.NewClientSession(iv.coll.client.topology.SessionPool, iv.coll.client.id, session.Implicit) + sess, err = session.NewClientSession(iv.coll.client.sessionPool, iv.coll.client.id, session.Implicit) if err != nil { return nil, err } @@ -337,7 +337,7 @@ func (iv IndexView) drop(ctx context.Context, name string, opts ...*options.Drop Session(sess).WriteConcern(wc).CommandMonitor(iv.coll.client.monitor). ServerSelector(selector).ClusterClock(iv.coll.client.clock). Database(iv.coll.db.name).Collection(iv.coll.name). 
- Deployment(iv.coll.client.topology) + Deployment(iv.coll.client.deployment) if dio.MaxTime != nil { op.MaxTimeMS(int64(*dio.MaxTime / time.Millisecond)) } diff --git a/mongo/index_view_internal_test.go b/mongo/index_view_internal_test.go index 0b96402d4b..dfc2d53584 100644 --- a/mongo/index_view_internal_test.go +++ b/mongo/index_view_internal_test.go @@ -88,8 +88,6 @@ func getIndexableCollection(t *testing.T) (string, *Collection) { } func TestIndexView_List(t *testing.T) { - t.Parallel() - if testing.Short() { t.Skip() } @@ -121,8 +119,6 @@ func TestIndexView_List(t *testing.T) { } func TestIndexView_CreateOne(t *testing.T) { - t.Parallel() - if testing.Short() { t.Skip() } @@ -162,8 +158,6 @@ func TestIndexView_CreateOne(t *testing.T) { } func TestIndexView_CreateOneWithNameOption(t *testing.T) { - t.Parallel() - if testing.Short() { t.Skip() } @@ -206,8 +200,6 @@ func TestIndexView_CreateOneWithNameOption(t *testing.T) { // Omits collation option because it's incompatible with version option func TestIndexView_CreateOneWithAllOptions(t *testing.T) { - t.Parallel() - if testing.Short() { t.Skip() } @@ -248,8 +240,6 @@ func TestIndexView_CreateOneWithAllOptions(t *testing.T) { func TestIndexView_CreateOneWithCollationOption(t *testing.T) { skipIfBelow34(t, createTestDatabase(t, nil)) // collation invalid for server versions < 3.4 - t.Parallel() - if testing.Short() { t.Skip() } @@ -315,8 +305,6 @@ func TestIndexView_CreateOneWildcard(t *testing.T) { } func TestIndexView_CreateOneWithNilKeys(t *testing.T) { - t.Parallel() - if testing.Short() { t.Skip() } @@ -334,8 +322,6 @@ func TestIndexView_CreateOneWithNilKeys(t *testing.T) { } func TestIndexView_CreateMany(t *testing.T) { - t.Parallel() - if testing.Short() { t.Skip() } @@ -397,8 +383,6 @@ func TestIndexView_CreateMany(t *testing.T) { } func TestIndexView_DropOne(t *testing.T) { - t.Parallel() - if testing.Short() { t.Skip() } @@ -447,8 +431,6 @@ func TestIndexView_DropOne(t *testing.T) { } func TestIndexView_DropAll(t *testing.T) { - t.Parallel() - if testing.Short() { t.Skip() } @@ -497,8 +479,6 @@ func TestIndexView_DropAll(t *testing.T) { } func TestIndexView_CreateIndexesOptioner(t *testing.T) { - t.Parallel() - if testing.Short() { t.Skip() } @@ -567,8 +547,6 @@ func TestIndexView_CreateIndexesOptioner(t *testing.T) { } func TestIndexView_DropIndexesOptioner(t *testing.T) { - t.Parallel() - if testing.Short() { t.Skip() } diff --git a/mongo/operation_legacy_test.go b/mongo/operation_legacy_test.go index 01000e381b..20ead8e78d 100644 --- a/mongo/operation_legacy_test.go +++ b/mongo/operation_legacy_test.go @@ -442,7 +442,7 @@ func TestOperationLegacy(t *testing.T) { expected = append(expected, docBytes) } - res, srvr := runOperationWithDeployment(t, cmd, db.client.topology, driver.LegacyFind) + res, srvr := runOperationWithDeployment(t, cmd, db.client.deployment, driver.LegacyFind) docs := parseAndIterateCursor(t, res, srvr, 2) if len(docs) != len(expected) { t.Fatalf("documents length match; expected %d, got %d", len(expected), len(docs)) @@ -459,7 +459,7 @@ func TestOperationLegacy(t *testing.T) { } cmd := bson.D{{"listCollections", 1}} - res, srvr := runOperationWithDeployment(t, cmd, db.client.topology, driver.LegacyListCollections) + res, srvr := runOperationWithDeployment(t, cmd, db.client.deployment, driver.LegacyListCollections) docs := parseAndIterateCursor(t, res, srvr, 2) if len(docs) != 3 { t.Fatalf("documents length mismatch; expected 3, got %d", len(docs)) @@ -481,7 +481,7 @@ func TestOperationLegacy(t 
*testing.T) { cmd := bson.D{ {"listIndexes", "foo"}, } - res, srvr := runOperationWithDeployment(t, cmd, db.client.topology, driver.LegacyListIndexes) + res, srvr := runOperationWithDeployment(t, cmd, db.client.deployment, driver.LegacyListIndexes) docs := parseAndIterateCursor(t, res, srvr, 2) if len(docs) != 3 { t.Fatalf("documents length mismatch; expected 3, got %d", len(docs)) @@ -509,7 +509,7 @@ func TestOperationLegacy(t *testing.T) { {"killCursors", "foo"}, {"cursors", bson.A{int64(1), int64(2)}}, } - res, _ := runOperationWithDeployment(t, cmd, db.client.topology, driver.LegacyKillCursors) + res, _ := runOperationWithDeployment(t, cmd, db.client.deployment, driver.LegacyKillCursors) if len(res) != 0 { t.Fatalf("got non-empty response %v", res) } diff --git a/mongo/options/clientoptions.go b/mongo/options/clientoptions.go index 629e0e6e6d..c60dd81667 100644 --- a/mongo/options/clientoptions.go +++ b/mongo/options/clientoptions.go @@ -25,6 +25,7 @@ import ( "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/tag" + "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driver/connstring" ) @@ -94,9 +95,10 @@ type ClientOptions struct { err error - // Adds an option for internal use only and should not be set. This option is deprecated and is - // not part of the stability guarantee. It may be removed in the future. + // These options are for internal use only and should not be set. They are deprecated and are + // not part of the stability guarantee. They may be removed in the future. AuthenticateToAnything *bool + Deployment driver.Deployment } // Client creates a new ClientOptions instance. diff --git a/mongo/results_test.go b/mongo/results_test.go index b7aa3eecf4..6eda93dcde 100644 --- a/mongo/results_test.go +++ b/mongo/results_test.go @@ -15,8 +15,6 @@ import ( ) func TestDeleteResult_unmarshalInto(t *testing.T) { - t.Parallel() - doc := bsonx.Doc{ {"n", bsonx.Int64(2)}, {"ok", bsonx.Int64(1)}, @@ -32,8 +30,6 @@ func TestDeleteResult_unmarshalInto(t *testing.T) { } func TestDeleteResult_marshalFrom(t *testing.T) { - t.Parallel() - result := DeleteResult{DeletedCount: 1} buf, err := bson.Marshal(result) require.Nil(t, err) @@ -49,8 +45,6 @@ func TestDeleteResult_marshalFrom(t *testing.T) { } func TestUpdateOneResult_unmarshalInto(t *testing.T) { - t.Parallel() - doc := bsonx.Doc{ {"n", bsonx.Int32(1)}, {"nModified", bsonx.Int32(2)}, diff --git a/mongo/retryable_writes_test.go b/mongo/retryable_writes_test.go index dd96367ce4..42f7346de8 100644 --- a/mongo/retryable_writes_test.go +++ b/mongo/retryable_writes_test.go @@ -370,7 +370,7 @@ func createRetryMonitoredClient(t *testing.T, monitor *event.CommandMonitor) *Cl clock := &session.ClusterClock{} c := &Client{ - topology: createRetryMonitoredTopology(t, clock, monitor), + deployment: createRetryMonitoredTopology(t, clock, monitor), connString: testutil.ConnString(t), readPreference: readpref.Primary(), clock: clock, @@ -378,9 +378,9 @@ func createRetryMonitoredClient(t *testing.T, monitor *event.CommandMonitor) *Cl monitor: monitor, } - subscription, err := c.topology.Subscribe() + subscription, err := c.deployment.(driver.Subscriber).Subscribe() testhelpers.RequireNil(t, err, "error subscribing to topology: %s", err) - c.topology.SessionPool = session.NewPool(subscription.C) + c.sessionPool = session.NewPool(subscription.Updates) return c } diff --git a/mongo/session.go b/mongo/session.go index a84a33bad1..00ccf09360 100644 --- 
a/mongo/session.go +++ b/mongo/session.go @@ -19,7 +19,6 @@ import ( "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/operation" "go.mongodb.org/mongo-driver/x/mongo/driver/session" - "go.mongodb.org/mongo-driver/x/mongo/driver/topology" ) // ErrWrongClient is returned when a user attempts to pass in a session created by a different client than @@ -65,7 +64,7 @@ type Session interface { type sessionImpl struct { clientSession *session.Client client *Client - topo *topology.Topology + deployment driver.Deployment didCommitAfterStart bool // true if commit was called after start with no other operations } @@ -186,7 +185,7 @@ func (s *sessionImpl) AbortTransaction(ctx context.Context) error { s.clientSession.Aborting = true _ = operation.NewAbortTransaction().Session(s.clientSession).ClusterClock(s.client.clock).Database("admin"). - Deployment(s.topo).WriteConcern(s.clientSession.CurrentWc).ServerSelector(selector). + Deployment(s.deployment).WriteConcern(s.clientSession.CurrentWc).ServerSelector(selector). Retry(driver.RetryOncePerCommand).CommandMonitor(s.client.monitor).RecoveryToken(bsoncore.Document(s.clientSession.RecoveryToken)).Execute(ctx) s.clientSession.Aborting = false @@ -216,7 +215,7 @@ func (s *sessionImpl) CommitTransaction(ctx context.Context) error { s.clientSession.Committing = true op := operation.NewCommitTransaction(). - Session(s.clientSession).ClusterClock(s.client.clock).Database("admin").Deployment(s.topo). + Session(s.clientSession).ClusterClock(s.client.clock).Database("admin").Deployment(s.deployment). WriteConcern(s.clientSession.CurrentWc).ServerSelector(selector).Retry(driver.RetryOncePerCommand). CommandMonitor(s.client.monitor).RecoveryToken(bsoncore.Document(s.clientSession.RecoveryToken)) if s.clientSession.CurrentMct != nil { diff --git a/mongo/sessions_test.go b/mongo/sessions_test.go index 39f2e6fe5a..1a9d10cfe4 100644 --- a/mongo/sessions_test.go +++ b/mongo/sessions_test.go @@ -7,18 +7,16 @@ package mongo import ( + "bytes" "context" + "fmt" + "os" "path" "reflect" + "strings" "testing" - - "fmt" - "os" "time" - "bytes" - "strings" - "github.com/stretchr/testify/require" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/event" @@ -30,6 +28,7 @@ import ( "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driver/connstring" "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/operation" @@ -223,7 +222,7 @@ func createSessionsMonitoredClient(t *testing.T, monitor *event.CommandMonitor) clock := &session.ClusterClock{} c := &Client{ - topology: createMonitoredTopology(t, clock, monitor, nil), + deployment: createMonitoredTopology(t, clock, monitor, nil), connString: testutil.ConnString(t), readPreference: readpref.Primary(), readConcern: readconcern.Local(), @@ -232,9 +231,9 @@ func createSessionsMonitoredClient(t *testing.T, monitor *event.CommandMonitor) monitor: monitor, } - subscription, err := c.topology.Subscribe() + subscription, err := c.deployment.(driver.Subscriber).Subscribe() testhelpers.RequireNil(t, err, "error subscribing to topology: %s", err) - c.topology.SessionPool = session.NewPool(subscription.C) + c.sessionPool = session.NewPool(subscription.Updates) return c } @@ -294,7 +293,7 @@ func getTestName(t *testing.T) string { } func verifySessionsReturned(t 
*testing.T, client *Client) { - checkedOut := client.topology.SessionPool.CheckedOut() + checkedOut := client.sessionPool.CheckedOut() if checkedOut != 0 { t.Fatalf("%d sessions not returned for %s", checkedOut, t.Name()) } @@ -323,7 +322,7 @@ func drainCursor(returnVals []reflect.Value) { } func testCheckedOut(t *testing.T, client *Client, expected int) { - actual := client.topology.SessionPool.CheckedOut() + actual := client.sessionPool.CheckedOut() if actual != expected { t.Fatalf("checked out mismatch. expected %d got %d", expected, actual) } diff --git a/mongo/transactions_test.go b/mongo/transactions_test.go index cc4f03b703..bfafeb06bd 100644 --- a/mongo/transactions_test.go +++ b/mongo/transactions_test.go @@ -7,18 +7,15 @@ package mongo import ( + "bytes" + "context" "encoding/json" "io/ioutil" - "testing" - - "context" - - "strings" - "time" - - "bytes" "os" "path" + "strings" + "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/require" @@ -33,6 +30,7 @@ import ( "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/operation" "go.mongodb.org/mongo-driver/x/mongo/driver/session" @@ -365,7 +363,7 @@ func runTransactionsTestCase(t *testing.T, test *transTestCase, testfile transTe func killSessions(t *testing.T, client *Client) { err := operation.NewCommand(bsoncore.BuildDocument(nil, bsoncore.AppendArrayElement(nil, "killAllSessions", bsoncore.BuildArray(nil)))). - Database("admin").ServerSelector(description.WriteSelector()).Deployment(client.topology).Execute(context.Background()) + Database("admin").ServerSelector(description.WriteSelector()).Deployment(client.deployment).Execute(context.Background()) require.NoError(t, err) } @@ -396,7 +394,7 @@ func createTransactionsMonitoredClient(t *testing.T, monitor *event.CommandMonit cs.Hosts = []string{host} } c := &Client{ - topology: createMonitoredTopology(t, clock, monitor, &cs), + deployment: createMonitoredTopology(t, clock, monitor, &cs), connString: cs, readPreference: readpref.Primary(), clock: clock, @@ -405,9 +403,9 @@ func createTransactionsMonitoredClient(t *testing.T, monitor *event.CommandMonit } addClientOptions(c, opts) - subscription, err := c.topology.Subscribe() + subscription, err := c.deployment.(driver.Subscriber).Subscribe() testhelpers.RequireNil(t, err, "error subscribing to topology: %s", err) - c.topology.SessionPool = session.NewPool(subscription.C) + c.sessionPool = session.NewPool(subscription.Updates) return c } diff --git a/x/mongo/driver/driver.go b/x/mongo/driver/driver.go index 4ff81a8ae1..a9c8601700 100644 --- a/x/mongo/driver/driver.go +++ b/x/mongo/driver/driver.go @@ -14,6 +14,30 @@ type Deployment interface { Kind() description.TopologyKind } +// Connector represents a type that can connect to a server. +type Connector interface { + Connect() error +} + +// Disconnector represents a type that can disconnect from a server. +type Disconnector interface { + Disconnect(context.Context) error +} + +// Subscription represents a subscription to topology updates. A subscriber can receive updates through the +// Updates field. +type Subscription struct { + Updates <-chan description.Topology + ID uint64 +} + +// Subscriber represents a type to which another type can subscribe. 
A subscription contains a channel that +// is updated with topology descriptions. +type Subscriber interface { + Subscribe() (*Subscription, error) + Unsubscribe(*Subscription) error +} + // Server represents a MongoDB server. Implementations should pool connections and handle the // retrieving and returning of connections. type Server interface { diff --git a/x/mongo/driver/examples/cluster_monitoring/main.go b/x/mongo/driver/examples/cluster_monitoring/main.go index 485db0801b..f36eefbaa2 100644 --- a/x/mongo/driver/examples/cluster_monitoring/main.go +++ b/x/mongo/driver/examples/cluster_monitoring/main.go @@ -28,7 +28,7 @@ func main() { log.Fatalf("could not subscribe to topology: %v", err) } - for desc := range sub.C { + for desc := range sub.Updates { log.Printf("%# v", pretty.Formatter(desc)) } } diff --git a/x/mongo/driver/topology/polling_srv_records_test.go b/x/mongo/driver/topology/polling_srv_records_test.go index 33f82f5301..a7b93a7147 100644 --- a/x/mongo/driver/topology/polling_srv_records_test.go +++ b/x/mongo/driver/topology/polling_srv_records_test.go @@ -141,7 +141,7 @@ func TestPollingSRVRecordsSpec(t *testing.T) { require.NoError(t, err, "Couldn't subscribe: %v", err) var desc description.Topology for atomic.LoadInt32(&mockRes.ranLookup) < 2 { - desc = <-sub.C + desc = <-sub.Updates } require.True(t, tt.heartbeatTime == topo.pollHeartbeatTime.Load().(bool), "Not polling on correct intervals") @@ -187,7 +187,7 @@ func TestPollSRVRecords(t *testing.T) { } for i := 0; i < 4; i++ { - <-sub.C + <-sub.Updates } require.False(t, atomic.LoadInt32(&mockRes.ranLookup) > 0) @@ -219,7 +219,7 @@ func TestPollSRVRecords(t *testing.T) { require.NoError(t, err, "Couldn't subscribe: %v", err) var desc description.Topology for atomic.LoadInt32(&mockRes.ranLookup) < 2 { - desc = <-sub.C + desc = <-sub.Updates } require.False(t, topo.pollHeartbeatTime.Load().(bool)) @@ -251,7 +251,7 @@ func TestPollSRVRecords(t *testing.T) { require.NoError(t, err, "Couldn't subscribe: %v", err) var desc description.Topology for atomic.LoadInt32(&mockRes.ranLookup) < 3 { - desc = <-sub.C + desc = <-sub.Updates } require.False(t, topo.pollHeartbeatTime.Load().(bool)) diff --git a/x/mongo/driver/topology/topology.go b/x/mongo/driver/topology/topology.go index 97cedeb47a..a36007a336 100644 --- a/x/mongo/driver/topology/topology.go +++ b/x/mongo/driver/topology/topology.go @@ -25,7 +25,6 @@ import ( "go.mongodb.org/mongo-driver/x/mongo/driver/address" "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/dns" - "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) // ErrSubscribeAfterClosed is returned when a user attempts to subscribe to a @@ -72,8 +71,6 @@ type Topology struct { fsm *fsm - SessionPool *session.Pool - // This should really be encapsulated into it's own type. This will likely // require a redesign so we can share a minimum of data between the // subscribers and the topology. @@ -91,6 +88,9 @@ type Topology struct { servers map[address.Address]*Server } +var _ driver.Deployment = &Topology{} +var _ driver.Subscriber = &Topology{} + // New creates a new topology. func New(opts ...Option) (*Topology, error) { cfg, err := newConfig(opts...) 
@@ -136,6 +136,9 @@ func (t *Topology) Connect() error { addr := address.Address(a).Canonicalize() t.fsm.Servers = append(t.fsm.Servers, description.Server{Addr: addr}) err = t.addServer(addr) + if err != nil { + return err + } } t.serversLock.Unlock() @@ -147,11 +150,7 @@ func (t *Topology) Connect() error { t.subscriptionsClosed = false // explicitly set in case topology was disconnected and then reconnected atomic.StoreInt32(&t.connectionstate, connected) - - // After connection, make a subscription to keep the pool updated - sub, err := t.Subscribe() - t.SessionPool = session.NewPool(sub.C) - return err + return nil } // Disconnect closes the topology. It stops the monitoring thread and @@ -211,7 +210,8 @@ func (t *Topology) Kind() description.TopologyKind { return t.Description().Kind // Subscribe returns a Subscription on which all updated description.Topologys // will be sent. The channel of the subscription will have a buffer size of one, // and will be pre-populated with the current description.Topology. -func (t *Topology) Subscribe() (*Subscription, error) { +// Subscribe implements the driver.Subscriber interface. +func (t *Topology) Subscribe() (*driver.Subscription, error) { if atomic.LoadInt32(&t.connectionstate) != connected { return nil, errors.New("cannot subscribe to Topology that is not connected") } @@ -231,13 +231,32 @@ func (t *Topology) Subscribe() (*Subscription, error) { t.subscribers[id] = ch t.currentSubscriberID++ - return &Subscription{ - C: ch, - t: t, - id: id, + return &driver.Subscription{ + Updates: ch, + ID: id, }, nil } +// Unsubscribe unsubscribes the given subscription from the topology and closes the subscription channel. +// Unsubscribe implements the driver.Subscriber interface. +func (t *Topology) Unsubscribe(sub *driver.Subscription) error { + t.subLock.Lock() + defer t.subLock.Unlock() + + if t.subscriptionsClosed { + return nil + } + + ch, ok := t.subscribers[sub.ID] + if !ok { + return nil + } + + close(ch) + delete(t.subscribers, sub.ID) + return nil +} + // RequestImmediateCheck will send heartbeats to all the servers in the // topology right away, instead of waiting for the heartbeat timeout. func (t *Topology) RequestImmediateCheck() { @@ -280,10 +299,10 @@ func (t *Topology) SelectServer(ctx context.Context, ss description.ServerSelect if err != nil { return nil, err } - defer sub.Unsubscribe() + defer t.Unsubscribe(sub) for { - suitable, err := t.selectServer(ctx, sub.C, ss, ssTimeoutCh) + suitable, err := t.selectServer(ctx, sub.Updates, ss, ssTimeoutCh) if err != nil { return nil, err } @@ -323,10 +342,10 @@ func (t *Topology) SelectServerLegacy(ctx context.Context, ss description.Server if err != nil { return nil, err } - defer sub.Unsubscribe() + defer t.Unsubscribe(sub) for { - suitable, err := t.selectServer(ctx, sub.C, ss, ssTimeoutCh) + suitable, err := t.selectServer(ctx, sub.Updates, ss, ssTimeoutCh) if err != nil { return nil, err } @@ -596,31 +615,3 @@ func (t *Topology) String() string { } return str } - -// Subscription is a subscription to updates to the description of the Topology that created this -// Subscription. -type Subscription struct { - C <-chan description.Topology - t *Topology - id uint64 -} - -// Unsubscribe unsubscribes this Subscription from updates and closes the -// subscription channel. 
-func (s *Subscription) Unsubscribe() error { - s.t.subLock.Lock() - defer s.t.subLock.Unlock() - if s.t.subscriptionsClosed { - return nil - } - - ch, ok := s.t.subscribers[s.id] - if !ok { - return nil - } - - close(ch) - delete(s.t.subscribers, s.id) - - return nil -}
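
Note on the new capability interfaces: after this patch the Client no longer assumes its Deployment is a *topology.Topology. Connect, Disconnect, and session-pool construction each type-assert the Deployment against driver.Connector, driver.Disconnector, and driver.Subscriber, and skip the step when the capability is absent, so a Deployment injected through options.ClientOptions.Deployment only needs to implement whatever it actually supports. The following is a minimal, self-contained sketch of that pattern only; the lowercase interfaces, topologyDesc, and fakeDeployment are simplified stand-ins invented for illustration (description.Topology is replaced by a toy struct), not the driver's real types.

package main

import (
	"context"
	"fmt"
)

// topologyDesc stands in for description.Topology in this sketch.
type topologyDesc struct{ Kind string }

// These mirror the shapes of the interfaces added in x/mongo/driver/driver.go.
type connector interface{ Connect() error }

type disconnector interface {
	Disconnect(ctx context.Context) error
}

type subscription struct {
	Updates <-chan topologyDesc
	ID      uint64
}

type subscriber interface {
	Subscribe() (*subscription, error)
	Unsubscribe(*subscription) error
}

// fakeDeployment implements all three optional capabilities.
type fakeDeployment struct {
	updates chan topologyDesc
}

func (d *fakeDeployment) Connect() error {
	d.updates = make(chan topologyDesc, 1)
	return nil
}

func (d *fakeDeployment) Disconnect(ctx context.Context) error {
	close(d.updates)
	return nil
}

func (d *fakeDeployment) Subscribe() (*subscription, error) {
	// Pre-populate with the current description, as Topology.Subscribe does.
	d.updates <- topologyDesc{Kind: "Single"}
	return &subscription{Updates: d.updates, ID: 1}, nil
}

func (d *fakeDeployment) Unsubscribe(*subscription) error { return nil }

// connectDeployment mirrors the shape of the new Client.Connect: every
// capability is detected with a type assertion and is optional.
func connectDeployment(deployment interface{}) (<-chan topologyDesc, error) {
	if c, ok := deployment.(connector); ok {
		if err := c.Connect(); err != nil {
			return nil, err
		}
	}
	var updates <-chan topologyDesc
	if s, ok := deployment.(subscriber); ok {
		sub, err := s.Subscribe()
		if err != nil {
			return nil, err
		}
		updates = sub.Updates
	}
	// The real Client hands this channel to session.NewPool; a nil channel
	// just means the session pool never receives topology updates.
	return updates, nil
}

func main() {
	var d interface{} = &fakeDeployment{}

	updates, err := connectDeployment(d)
	if err != nil {
		panic(err)
	}
	fmt.Printf("first update: %+v\n", <-updates)

	if dc, ok := d.(disconnector); ok {
		_ = dc.Disconnect(context.Background())
	}
}

Because each capability is optional, a test double that only implements SelectServer and Kind can still be wired in as the deployment; it simply is not connected, disconnected, or subscribed to, and the session pool is then built from a nil updates channel.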