Skip to content

Commit

Permalink
GODRIVER-2086 Allow custom SRV service names with URI option srvServi…
Browse files Browse the repository at this point in the history
…ceName (mongodb#734)
  • Loading branch information
benjirewis authored Sep 24, 2021
1 parent f9279c4 commit be3e3f2
Show file tree
Hide file tree
Showing 12 changed files with 175 additions and 45 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
{
"uri": "mongodb+srv://test22.test.build.10gen.cc/?srvServiceName=customname",
"seeds": [
"localhost.test.build.10gen.cc:27017",
"localhost.test.build.10gen.cc:27018"
],
"hosts": [
"localhost:27017",
"localhost:27018",
"localhost:27019"
],
"options": {
"ssl": true,
"srvServiceName": "customname"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
uri: "mongodb+srv://test22.test.build.10gen.cc/?srvServiceName=customname"
seeds:
- localhost.test.build.10gen.cc:27017
- localhost.test.build.10gen.cc:27018
hosts:
- localhost:27017
- localhost:27018
- localhost:27019
options:
ssl: true
srvServiceName: "customname"
24 changes: 24 additions & 0 deletions data/uri-options/srv-service-name-option.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
{
"tests": [
{
"description": "SRV URI with custom srvServiceName",
"uri": "mongodb+srv://test22.test.build.10gen.cc/?srvServiceName=customname",
"valid": true,
"warning": false,
"hosts": null,
"auth": null,
"options": {
"srvServiceName": "customname"
}
},
{
"description": "Non-SRV URI with custom srvServiceName",
"uri": "mongodb://example.com/?srvServiceName=customname",
"valid": false,
"warning": true,
"hosts": null,
"auth": null,
"options": {}
}
]
}
16 changes: 16 additions & 0 deletions data/uri-options/srv-service-name-option.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
tests:
- description: "SRV URI with custom srvServiceName"
uri: "mongodb+srv://test22.test.build.10gen.cc/?srvServiceName=customname"
valid: true
warning: false
hosts: ~
auth: ~
options:
srvServiceName: "customname"
- description: "Non-SRV URI with custom srvServiceName"
uri: "mongodb://example.com/?srvServiceName=customname"
valid: false
warning: true
hosts: ~
auth: ~
options: {}
2 changes: 2 additions & 0 deletions internal/testutil/helpers/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ func VerifyConnStringOptions(t *testing.T, cs connstring.ConnString, options map
require.Equal(t, value, cs.RetryWrites)
case "serverselectiontimeoutms":
require.Equal(t, value, float64(cs.ServerSelectionTimeout/time.Millisecond))
case "srvservicename":
require.Equal(t, value, cs.SRVServiceName)
case "ssl", "tls":
require.Equal(t, value, cs.SSL)
case "sockettimeoutms":
Expand Down
18 changes: 12 additions & 6 deletions mongo/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -375,10 +375,16 @@ func (c *Client) configure(opts *options.ClientOptions) error {
// ClusterClock
c.clock = new(session.ClusterClock)

// Pass down URI so topology can determine whether or not SRV polling is required
topologyOpts = append(topologyOpts, topology.WithURI(func(uri string) string {
return opts.GetURI()
}))
// Pass down URI and SRV service name so topology can poll SRV records correctly
topologyOpts = append(topologyOpts,
topology.WithURI(func(uri string) string { return opts.GetURI() }),
topology.WithSRVServiceName(func(srvName string) string {
if opts.SRVServiceName != nil {
return *opts.SRVServiceName
}
return ""
}),
)

// AppName
var appName string
Expand Down Expand Up @@ -669,9 +675,9 @@ func (c *Client) configure(opts *options.ClientOptions) error {

// Deployment
if opts.Deployment != nil {
// topology options: WithSeedlist and WithURI
// topology options: WithSeedlist, WithURI and WithSRVServiceName
// server options: WithClock and WithConnectionOptions
if len(serverOpts) > 2 || len(topologyOpts) > 2 {
if len(serverOpts) > 2 || len(topologyOpts) > 3 {
return errors.New("cannot specify topology or server options with a deployment")
}
c.deployment = opts.Deployment
Expand Down
3 changes: 3 additions & 0 deletions mongo/integration/initial_dns_seedlist_discovery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,9 @@ func verifyConnstringOptions(mt *mtest.T, expected bson.Raw, cs connstring.ConnS
lb := opt.Boolean()
assert.True(mt, cs.LoadBalancedSet, "expected cs.LoadBalancedSet set to be true, got false")
assert.Equal(mt, lb, cs.LoadBalanced, "expected cs.LoadBalanced to be %v, got %v", lb, cs.LoadBalanced)
case "srvServiceName":
srvName := opt.StringValue()
assert.Equal(mt, srvName, cs.SRVServiceName, "expected cs.SRVServiceName to be %q, got %q", srvName, cs.SRVServiceName)
default:
mt.Fatalf("unrecognized connstring option %v", key)
}
Expand Down
16 changes: 16 additions & 0 deletions mongo/options/clientoptions.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ type ClientOptions struct {
ServerAPIOptions *ServerAPIOptions
ServerSelectionTimeout *time.Duration
SocketTimeout *time.Duration
SRVServiceName *string
TLSConfig *tls.Config
WriteConcern *writeconcern.WriteConcern
ZlibLevel *int
Expand Down Expand Up @@ -338,6 +339,10 @@ func (c *ClientOptions) ApplyURI(uri string) *ClientOptions {
c.SocketTimeout = &cs.SocketTimeout
}

if cs.SRVServiceName != "" {
c.SRVServiceName = &cs.SRVServiceName
}

if cs.SSL {
tlsConfig := new(tls.Config)

Expand Down Expand Up @@ -760,6 +765,14 @@ func (c *ClientOptions) SetServerAPIOptions(opts *ServerAPIOptions) *ClientOptio
return c
}

// SetSRVServiceName specifies a custom SRV service name to use in SRV polling. To use a custom SRV service name
// in SRV discovery, this function must be called before ApplyURI. This can also be set through the "srvServiceName"
// URI option.
func (c *ClientOptions) SetSRVServiceName(srvName string) *ClientOptions {
c.SRVServiceName = &srvName
return c
}

// MergeClientOptions combines the given *ClientOptions into a single *ClientOptions in a last one wins fashion.
// The specified options are merged with the existing options on the client, with the specified options taking
// precedence.
Expand Down Expand Up @@ -849,6 +862,9 @@ func MergeClientOptions(opts ...*ClientOptions) *ClientOptions {
if opt.SocketTimeout != nil {
c.SocketTimeout = opt.SocketTimeout
}
if opt.SRVServiceName != nil {
c.SRVServiceName = opt.SRVServiceName
}
if opt.TLSConfig != nil {
c.TLSConfig = opt.TLSConfig
}
Expand Down
73 changes: 48 additions & 25 deletions x/mongo/driver/connstring/connstring.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ type ConnString struct {
ServerSelectionTimeoutSet bool
SocketTimeout time.Duration
SocketTimeoutSet bool
SRVServiceName string
SSL bool
SSLSet bool
SSLClientCertificateKeyFile string
Expand Down Expand Up @@ -254,14 +255,25 @@ func (p *parser) parse(original string) error {
hosts = uri[:idx]
}

var connectionArgsFromTXT []string
parsedHosts := strings.Split(hosts, ",")
uri = uri[len(hosts):]
extractedDatabase, err := extractDatabaseFromURI(uri)
if err != nil {
return err
}

uri = extractedDatabase.uri
p.Database = extractedDatabase.db

// grab connection arguments from URI
connectionArgsFromQueryString, err := extractQueryArgsFromURI(uri)
if err != nil {
return err
}

// grab connection arguments from TXT record and enable SSL if "mongodb+srv://"
var connectionArgsFromTXT []string
if p.Scheme == SchemeMongoDBSRV {
parsedHosts, err = p.dnsResolver.ParseHosts(hosts, true)
if err != nil {
return err
}
connectionArgsFromTXT, err = p.dnsResolver.GetConnectionArgsFromTXT(hosts)
if err != nil {
return err
Expand All @@ -272,35 +284,32 @@ func (p *parser) parse(original string) error {
p.SSLSet = true
}

for _, host := range parsedHosts {
err = p.addHost(host)
// add connection arguments from URI and TXT records to connstring
connectionArgPairs := append(connectionArgsFromTXT, connectionArgsFromQueryString...)
for _, pair := range connectionArgPairs {
err := p.addOption(pair)
if err != nil {
return internal.WrapErrorf(err, "invalid host \"%s\"", host)
return err
}
}
if len(p.Hosts) == 0 {
return fmt.Errorf("must have at least 1 host")
}

uri = uri[len(hosts):]

extractedDatabase, err := extractDatabaseFromURI(uri)
if err != nil {
return err
// do SRV lookup if "mongodb+srv://"
if p.Scheme == SchemeMongoDBSRV {
parsedHosts, err = p.dnsResolver.ParseHosts(hosts, p.SRVServiceName, true)
if err != nil {
return err
}
}

uri = extractedDatabase.uri
p.Database = extractedDatabase.db

connectionArgsFromQueryString, err := extractQueryArgsFromURI(uri)
connectionArgPairs := append(connectionArgsFromTXT, connectionArgsFromQueryString...)

for _, pair := range connectionArgPairs {
err = p.addOption(pair)
for _, host := range parsedHosts {
err = p.addHost(host)
if err != nil {
return err
return internal.WrapErrorf(err, "invalid host \"%s\"", host)
}
}
if len(p.Hosts) == 0 {
return fmt.Errorf("must have at least 1 host")
}

err = p.setDefaultAuthParams(extractedDatabase.db)
if err != nil {
Expand Down Expand Up @@ -762,6 +771,20 @@ func (p *parser) addOption(pair string) error {
}
p.SocketTimeout = time.Duration(n) * time.Millisecond
p.SocketTimeoutSet = true
case "srvservicename":
// srvServiceName can only be set on URIs with the "mongodb+srv" scheme
if p.Scheme != SchemeMongoDBSRV {
return fmt.Errorf("cannot specify srvServiceName on non-SRV URI")
}

// srvServiceName must be between 1 and 62 characters according to
// our specification. Empty service names are not valid, and the service
// name (including prepended underscore) should not exceed the 63 character
// limit for DNS query subdomains.
if len(value) < 1 || len(value) > 62 {
return fmt.Errorf("srvServiceName value must be between 1 and 62 characters")
}
p.SRVServiceName = value
case "ssl", "tls":
switch value {
case "true":
Expand Down
14 changes: 9 additions & 5 deletions x/mongo/driver/dns/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@ type Resolver struct {
// DefaultResolver is a Resolver that uses the default Resolver from the net package.
var DefaultResolver = &Resolver{net.LookupSRV, net.LookupTXT}

// ParseHosts uses the srv string to get the hosts.
func (r *Resolver) ParseHosts(host string, stopOnErr bool) ([]string, error) {
// ParseHosts uses the srv string and service name to get the hosts.
func (r *Resolver) ParseHosts(host string, srvName string, stopOnErr bool) ([]string, error) {
parsedHosts := strings.Split(host, ",")

if len(parsedHosts) != 1 {
return nil, fmt.Errorf("URI with SRV must include one and only one hostname")
}
return r.fetchSeedlistFromSRV(parsedHosts[0], stopOnErr)
return r.fetchSeedlistFromSRV(parsedHosts[0], srvName, stopOnErr)
}

// GetConnectionArgsFromTXT gets the TXT record associated with the host and returns the connection arguments.
Expand Down Expand Up @@ -64,7 +64,7 @@ func (r *Resolver) GetConnectionArgsFromTXT(host string) ([]string, error) {
return connectionArgsFromTXT, nil
}

func (r *Resolver) fetchSeedlistFromSRV(host string, stopOnErr bool) ([]string, error) {
func (r *Resolver) fetchSeedlistFromSRV(host string, srvName string, stopOnErr bool) ([]string, error) {
var err error

_, _, err = net.SplitHostPort(host)
Expand All @@ -75,7 +75,11 @@ func (r *Resolver) fetchSeedlistFromSRV(host string, stopOnErr bool) ([]string,
return nil, fmt.Errorf("URI with srv must not include a port number")
}

_, addresses, err := r.LookupSRV("mongodb", "tcp", host)
// default to "mongodb" as service name if not supplied
if srvName == "" {
srvName = "mongodb"
}
_, addresses, err := r.LookupSRV(srvName, "tcp", host)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion x/mongo/driver/topology/topology.go
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ func (t *Topology) pollSRVRecords() {
break
}

parsedHosts, err := t.dnsResolver.ParseHosts(hosts, false)
parsedHosts, err := t.dnsResolver.ParseHosts(hosts, t.cfg.srvServiceName, false)
// DNS problem or no verified hosts returned
if err != nil || len(parsedHosts) == 0 {
if !t.pollHeartbeatTime.Load().(bool) {
Expand Down
25 changes: 17 additions & 8 deletions x/mongo/driver/topology/topology_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ type config struct {
uri string
serverSelectionTimeout time.Duration
serverMonitor *event.ServerMonitor
srvServiceName string
loadBalanced bool
}

Expand Down Expand Up @@ -304,6 +305,22 @@ func WithURI(fn func(string) string) Option {
}
}

// WithLoadBalanced specifies whether or not the cluster is behind a load balancer.
func WithLoadBalanced(fn func(bool) bool) Option {
return func(cfg *config) error {
cfg.loadBalanced = fn(cfg.loadBalanced)
return nil
}
}

// WithSRVServiceName specifies the SRV service name that was used to create the topology.
func WithSRVServiceName(fn func(string) string) Option {
return func(cfg *config) error {
cfg.srvServiceName = fn(cfg.srvServiceName)
return nil
}
}

// addCACertFromFile adds a root CA certificate to the configuration given a path
// to the containing file.
func addCACertFromFile(cfg *tls.Config, file string) error {
Expand Down Expand Up @@ -422,11 +439,3 @@ func addClientCertFromFile(cfg *tls.Config, clientFile, keyPasswd string) (strin

return x509CertSubject(crt), nil
}

// WithLoadBalanced specifies whether or not the cluster is behind a load balancer.
func WithLoadBalanced(fn func(bool) bool) Option {
return func(cfg *config) error {
cfg.loadBalanced = fn(cfg.loadBalanced)
return nil
}
}

0 comments on commit be3e3f2

Please sign in to comment.