Skip to content

Commit

Permalink
Change Client to depend on driver.Deployment
Browse files Browse the repository at this point in the history
- Move session pooling from Topology to Client.
- Add method to create Client from a Deployment.
- Add interfaces to allow arbitrary Deployments to support connecting, disconnecting, and sessions.

GODRIVER-1250

Change-Id: I64c02c29424a3690d43c995c8bd2bcbc19cb9e96
  • Loading branch information
Divjot Arora committed Aug 26, 2019
1 parent d7f1b55 commit bb9530d
Show file tree
Hide file tree
Showing 24 changed files with 233 additions and 239 deletions.
2 changes: 1 addition & 1 deletion .errcheck-excludes
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
(*go.mongodb.org/mongo-driver/x/network/connection.connection).Close
(go.mongodb.org/mongo-driver/x/network/connection.Connection).Close
(*go.mongodb.org/mongo-driver/x/mongo/driver/topology.connection).close
(*go.mongodb.org/mongo-driver/x/mongo/driver/topology.Subscription).Unsubscribe
(*go.mongodb.org/mongo-driver/x/mongo/driver/topology.Topology).Unsubscribe
(*go.mongodb.org/mongo-driver/x/mongo/driver/topology.Server).Close
(*go.mongodb.org/mongo-driver/x/network/connection.pool).closeConnection
(*go.mongodb.org/mongo-driver/x/mongo/driver/topology.pool).close
Expand Down
22 changes: 22 additions & 0 deletions internal/testutil/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,19 @@ import (
"go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/mongo/driver/topology"
)

var connectionString connstring.ConnString
var connectionStringOnce sync.Once
var connectionStringErr error
var liveTopology *topology.Topology
var liveSessionPool *session.Pool
var liveTopologyOnce sync.Once
var liveTopologyErr error
var monitoredTopology *topology.Topology
var monitoredSessionPool *session.Pool
var monitoredTopologyOnce sync.Once
var monitoredTopologyErr error

Expand Down Expand Up @@ -144,6 +147,10 @@ func GlobalMonitoredTopology(t *testing.T, monitor *event.CommandMonitor) *topol
Database(DBName(t)).ServerSelector(description.WriteSelector()).Deployment(monitoredTopology).Execute(context.Background())

require.NoError(t, err)

sub, err := monitoredTopology.Subscribe()
require.NoError(t, err)
monitoredSessionPool = session.NewPool(sub.Updates)
}
})

Expand All @@ -154,6 +161,12 @@ func GlobalMonitoredTopology(t *testing.T, monitor *event.CommandMonitor) *topol
return monitoredTopology
}

// GlobalMonitoredSessionPool returns the globally configured session pool.
// Must be called after GlobalMonitoredTopology()
func GlobalMonitoredSessionPool() *session.Pool {
return monitoredSessionPool
}

// Topology gets the globally configured topology.
func Topology(t *testing.T) *topology.Topology {
cs := ConnString(t)
Expand All @@ -168,6 +181,10 @@ func Topology(t *testing.T) *topology.Topology {
err = operation.NewCommand(bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "dropDatabase", 1))).
Database(DBName(t)).ServerSelector(description.WriteSelector()).Deployment(liveTopology).Execute(context.Background())
require.NoError(t, err)

sub, err := liveTopology.Subscribe()
require.NoError(t, err)
liveSessionPool = session.NewPool(sub.Updates)
}
})

Expand All @@ -178,6 +195,11 @@ func Topology(t *testing.T) *topology.Topology {
return liveTopology
}

// SessionPool gets the globally configured session pool. Must be called after Topology().
func SessionPool() *session.Pool {
return liveSessionPool
}

// TopologyWithConnString takes a connection string and returns a connected
// topology, or else bails out of testing
func TopologyWithConnString(t *testing.T, cs connstring.ConnString) *topology.Topology {
Expand Down
6 changes: 3 additions & 3 deletions mongo/bulk_write.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ func (bw *bulkWrite) runInsert(ctx context.Context, batch bulkWriteBatch) (opera
Session(bw.session).WriteConcern(bw.writeConcern).CommandMonitor(bw.collection.client.monitor).
ServerSelector(bw.selector).ClusterClock(bw.collection.client.clock).
Database(bw.collection.db.name).Collection(bw.collection.name).
Deployment(bw.collection.client.topology).Crypt(bw.collection.client.crypt)
Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.crypt)
if bw.bypassDocumentValidation != nil && *bw.bypassDocumentValidation {
op = op.BypassDocumentValidation(*bw.bypassDocumentValidation)
}
Expand Down Expand Up @@ -223,7 +223,7 @@ func (bw *bulkWrite) runDelete(ctx context.Context, batch bulkWriteBatch) (opera
Session(bw.session).WriteConcern(bw.writeConcern).CommandMonitor(bw.collection.client.monitor).
ServerSelector(bw.selector).ClusterClock(bw.collection.client.clock).
Database(bw.collection.db.name).Collection(bw.collection.name).
Deployment(bw.collection.client.topology).Crypt(bw.collection.client.crypt)
Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.crypt)
if bw.ordered != nil {
op = op.Ordered(*bw.ordered)
}
Expand Down Expand Up @@ -287,7 +287,7 @@ func (bw *bulkWrite) runUpdate(ctx context.Context, batch bulkWriteBatch) (opera
Session(bw.session).WriteConcern(bw.writeConcern).CommandMonitor(bw.collection.client.monitor).
ServerSelector(bw.selector).ClusterClock(bw.collection.client.clock).
Database(bw.collection.db.name).Collection(bw.collection.name).
Deployment(bw.collection.client.topology).Crypt(bw.collection.client.crypt)
Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.crypt)
if bw.ordered != nil {
op = op.Ordered(*bw.ordered)
}
Expand Down
10 changes: 5 additions & 5 deletions mongo/change_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ func newChangeStream(ctx context.Context, config changeStreamConfig, pipeline in
}

cs.sess = sessionFromContext(ctx)
if cs.sess == nil && cs.client.topology.SessionPool != nil {
cs.sess, cs.err = session.NewClientSession(cs.client.topology.SessionPool, cs.client.id, session.Implicit)
if cs.sess == nil && cs.client.sessionPool != nil {
cs.sess, cs.err = session.NewClientSession(cs.client.sessionPool, cs.client.id, session.Implicit)
if cs.err != nil {
return nil, cs.Err()
}
Expand All @@ -100,7 +100,7 @@ func newChangeStream(ctx context.Context, config changeStreamConfig, pipeline in

cs.aggregate = operation.NewAggregate(nil).
ReadPreference(config.readPreference).ReadConcern(config.readConcern).
Deployment(cs.client.topology).ClusterClock(cs.client.clock).
Deployment(cs.client.deployment).ClusterClock(cs.client.clock).
CommandMonitor(cs.client.monitor).Session(cs.sess).ServerSelector(cs.selector).Retry(driver.RetryNone)

if cs.options.Collation != nil {
Expand Down Expand Up @@ -161,7 +161,7 @@ func newChangeStream(ctx context.Context, config changeStreamConfig, pipeline in
func (cs *ChangeStream) executeOperation(ctx context.Context, resuming bool) error {
var server driver.Server
var conn driver.Connection
if server, cs.err = cs.client.topology.SelectServer(ctx, cs.selector); cs.err != nil {
if server, cs.err = cs.client.deployment.SelectServer(ctx, cs.selector); cs.err != nil {
return cs.Err()
}
if conn, cs.err = server.Connection(ctx); cs.err != nil {
Expand Down Expand Up @@ -204,7 +204,7 @@ func (cs *ChangeStream) executeOperation(ctx context.Context, resuming bool) err
break
}

server, err := cs.client.topology.SelectServer(ctx, cs.selector)
server, err := cs.client.deployment.SelectServer(ctx, cs.selector)
if err != nil {
break
}
Expand Down
6 changes: 0 additions & 6 deletions mongo/change_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,6 @@ func TestChangeStream(t *testing.T) {
skipIfBelow36(t)

t.Run("TestFirstStage", func(t *testing.T) {
t.Parallel()

if testing.Short() {
t.Skip()
}
Expand Down Expand Up @@ -236,8 +234,6 @@ func TestChangeStream(t *testing.T) {
})

t.Run("TestReplaceRoot", func(t *testing.T) {
t.Parallel()

if testing.Short() {
t.Skip()
}
Expand Down Expand Up @@ -280,8 +276,6 @@ func TestChangeStream(t *testing.T) {
})

t.Run("TestNoCustomStandaloneError", func(t *testing.T) {
t.Parallel()

if testing.Short() {
t.Skip()
}
Expand Down
63 changes: 45 additions & 18 deletions mongo/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package mongo
import (
"context"
"crypto/tls"
"errors"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -42,7 +43,7 @@ var keyVaultCollOpts = options.Collection().SetReadConcern(readconcern.Majority(
type Client struct {
id uuid.UUID
topologyOptions []topology.Option
topology *topology.Topology
deployment driver.Deployment
connString connstring.ConnString
localThreshold time.Duration
retryWrites bool
Expand All @@ -54,6 +55,7 @@ type Client struct {
registry *bsoncodec.Registry
marshaller BSONAppender
monitor *event.CommandMonitor
sessionPool *session.Pool

// client-side encryption fields
keyVaultClient *Client
Expand Down Expand Up @@ -102,7 +104,7 @@ func NewClient(opts ...*options.ClientOptions) (*Client, error) {
return nil, err
}

client.topology, err = topology.New(client.topologyOptions...)
client.deployment, err = topology.New(client.topologyOptions...)
if err != nil {
return nil, replaceErrors(err)
}
Expand All @@ -113,20 +115,33 @@ func NewClient(opts ...*options.ClientOptions) (*Client, error) {
// Connect initializes the Client by starting background monitoring goroutines.
// This method must be called before a Client can be used.
func (c *Client) Connect(ctx context.Context) error {
err := c.topology.Connect()
if err != nil {
return replaceErrors(err)
if connector, ok := c.deployment.(driver.Connector); ok {
err := connector.Connect()
if err != nil {
return replaceErrors(err)
}
}

if c.mongocryptd != nil {
if err = c.mongocryptd.connect(ctx); err != nil {
if err := c.mongocryptd.connect(ctx); err != nil {
return err
}
}
if c.keyVaultClient != nil {
if err = c.keyVaultClient.Connect(ctx); err != nil {
if err := c.keyVaultClient.Connect(ctx); err != nil {
return err
}
}

var updateChan <-chan description.Topology
if subscriber, ok := c.deployment.(driver.Subscriber); ok {
sub, err := subscriber.Subscribe()
if err != nil {
return replaceErrors(err)
}
updateChan = sub.Updates
}
c.sessionPool = session.NewPool(updateChan)
return nil
}

Expand Down Expand Up @@ -157,7 +172,11 @@ func (c *Client) Disconnect(ctx context.Context) error {
if c.crypt != nil {
c.crypt.Close()
}
return replaceErrors(c.topology.Disconnect(ctx))

if disconnector, ok := c.deployment.(driver.Disconnector); ok {
return replaceErrors(disconnector.Disconnect(ctx))
}
return nil
}

// Ping verifies that the client can connect to the topology.
Expand All @@ -182,7 +201,7 @@ func (c *Client) Ping(ctx context.Context, rp *readpref.ReadPref) error {

// StartSession starts a new session.
func (c *Client) StartSession(opts ...*options.SessionOptions) (Session, error) {
if c.topology.SessionPool == nil {
if c.sessionPool == nil {
return nil, ErrClientDisconnected
}

Expand All @@ -208,7 +227,7 @@ func (c *Client) StartSession(opts ...*options.SessionOptions) (Session, error)
coreOpts.DefaultMaxCommitTime = sopts.DefaultMaxCommitTime
}

sess, err := session.NewClientSession(c.topology.SessionPool, c.id, session.Explicit, coreOpts)
sess, err := session.NewClientSession(c.sessionPool, c.id, session.Explicit, coreOpts)
if err != nil {
return nil, replaceErrors(err)
}
Expand All @@ -219,24 +238,24 @@ func (c *Client) StartSession(opts ...*options.SessionOptions) (Session, error)
return &sessionImpl{
clientSession: sess,
client: c,
topo: c.topology,
deployment: c.deployment,
}, nil
}

func (c *Client) endSessions(ctx context.Context) {
if c.topology.SessionPool == nil {
if c.sessionPool == nil {
return
}

ids := c.topology.SessionPool.IDSlice()
ids := c.sessionPool.IDSlice()
idx, idArray := bsoncore.AppendArrayStart(nil)
for i, id := range ids {
idDoc, _ := id.MarshalBSON()
idArray = bsoncore.AppendDocumentElement(idArray, strconv.Itoa(i), idDoc)
}
idArray, _ = bsoncore.AppendArrayEnd(idArray, idx)

op := operation.NewEndSessions(idArray).ClusterClock(c.clock).Deployment(c.topology).
op := operation.NewEndSessions(idArray).ClusterClock(c.clock).Deployment(c.deployment).
ServerSelector(description.ReadPrefSelector(readpref.PrimaryPreferred())).CommandMonitor(c.monitor).
Database("admin").Crypt(c.crypt)

Expand Down Expand Up @@ -493,6 +512,14 @@ func (c *Client) configure(opts *options.ClientOptions) error {
func(...topology.ServerOption) []topology.ServerOption { return serverOpts },
))

// Deployment
if opts.Deployment != nil {
if len(serverOpts) > 2 || len(topologyOpts) > 1 {
return errors.New("cannot specify topology or server options with a deployment")
}
c.deployment = opts.Deployment
}

return nil
}

Expand Down Expand Up @@ -585,8 +612,8 @@ func (c *Client) ListDatabases(ctx context.Context, filter interface{}, opts ...
sess := sessionFromContext(ctx)

err := c.validSession(sess)
if sess == nil && c.topology.SessionPool != nil {
sess, err = session.NewClientSession(c.topology.SessionPool, c.id, session.Implicit)
if sess == nil && c.sessionPool != nil {
sess, err = session.NewClientSession(c.sessionPool, c.id, session.Implicit)
if err != nil {
return ListDatabasesResult{}, err
}
Expand All @@ -612,7 +639,7 @@ func (c *Client) ListDatabases(ctx context.Context, filter interface{}, opts ...
ldo := options.MergeListDatabasesOptions(opts...)
op := operation.NewListDatabases(filterDoc).
Session(sess).ReadPreference(c.readPreference).CommandMonitor(c.monitor).
ServerSelector(selector).ClusterClock(c.clock).Database("admin").Deployment(c.topology).Crypt(c.crypt)
ServerSelector(selector).ClusterClock(c.clock).Database("admin").Deployment(c.deployment).Crypt(c.crypt)
if ldo.NameOnly != nil {
op = op.NameOnly(*ldo.NameOnly)
}
Expand Down Expand Up @@ -700,7 +727,7 @@ func (c *Client) UseSessionWithOptions(ctx context.Context, opts *options.Sessio
// The client must have read concern majority or no read concern for a change stream to be created successfully.
func (c *Client) Watch(ctx context.Context, pipeline interface{},
opts ...*options.ChangeStreamOptions) (*ChangeStream, error) {
if c.topology.SessionPool == nil {
if c.sessionPool == nil {
return nil, ErrClientDisconnected
}

Expand Down
Loading

0 comments on commit bb9530d

Please sign in to comment.