Skip to content

Commit 28c8df9

Browse files
committed
GODRIVER-2577 Fix SSL dialer
1 parent 5f63f1b commit 28c8df9

File tree

1 file changed

+31
-13
lines changed

1 file changed

+31
-13
lines changed

x/mongo/driver/topology/server_test.go

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,12 @@ package topology
1111

1212
import (
1313
"context"
14+
"crypto/tls"
15+
"crypto/x509"
1416
"errors"
17+
"io/ioutil"
1518
"net"
19+
"os"
1620
"runtime"
1721
"sync"
1822
"sync/atomic"
@@ -100,36 +104,47 @@ type timeoutDialer struct {
100104
}
101105

102106
func (d *timeoutDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
103-
var dialer net.Dialer
104-
c, e := dialer.DialContext(ctx, network, address)
105-
return &timeoutConn{c, d.errors}, e
106-
}
107+
c, e := d.Dialer.DialContext(ctx, network, address)
107108

108-
type timeoutErr struct {
109-
net.UnknownNetworkError
110-
}
109+
if caFile := os.Getenv("MONGO_GO_DRIVER_CA_FILE"); len(caFile) > 0 {
110+
pem, err := ioutil.ReadFile(caFile)
111+
if err != nil {
112+
return nil, err
113+
}
111114

112-
func (e *timeoutErr) Timeout() bool {
113-
return true
114-
}
115+
ca := x509.NewCertPool()
116+
if !ca.AppendCertsFromPEM(pem) {
117+
return nil, errors.New("unable to load CA file")
118+
}
115119

116-
var timeout = &timeoutErr{"test timeout"}
120+
config := &tls.Config{
121+
InsecureSkipVerify: true,
122+
RootCAs: ca,
123+
}
124+
c = tls.Client(c, config)
125+
}
126+
return &timeoutConn{c, d.errors}, e
127+
}
117128

118129
// TestServerHeartbeatTimeout tests timeout retry for GODRIVER-2577.
119130
func TestServerHeartbeatTimeout(t *testing.T) {
131+
networkTimeoutError := &net.DNSError{
132+
IsTimeout: true,
133+
}
134+
120135
testCases := []struct {
121136
desc string
122137
ioErrors []error
123138
expectPoolCleared bool
124139
}{
125140
{
126141
desc: "one single timeout should not clear the pool",
127-
ioErrors: []error{nil, timeout, nil, timeout, nil},
142+
ioErrors: []error{nil, networkTimeoutError, nil, networkTimeoutError, nil},
128143
expectPoolCleared: false,
129144
},
130145
{
131146
desc: "continuous timeouts should clear the pool",
132-
ioErrors: []error{nil, timeout, timeout, nil},
147+
ioErrors: []error{nil, networkTimeoutError, networkTimeoutError, nil},
133148
expectPoolCleared: true,
134149
},
135150
}
@@ -165,6 +180,9 @@ func TestServerHeartbeatTimeout(t *testing.T) {
165180
},
166181
}
167182
}),
183+
WithHeartbeatInterval(func(time.Duration) time.Duration {
184+
return 2 * time.Second
185+
}),
168186
)
169187
require.NoError(t, server.Connect(nil))
170188
wg.Wait()

0 commit comments

Comments
 (0)