@@ -11,8 +11,14 @@ package topology
11
11
12
12
import (
13
13
"context"
14
+ "crypto/tls"
15
+ "crypto/x509"
14
16
"errors"
17
+ "fmt"
18
+ "io/ioutil"
19
+ "log"
15
20
"net"
21
+ "os"
16
22
"runtime"
17
23
"sync"
18
24
"sync/atomic"
@@ -54,14 +60,14 @@ type errorQueue struct {
54
60
mutex sync.Mutex
55
61
}
56
62
57
- func (eq * errorQueue ) head () error {
58
- eq .mutex .Lock ()
59
- defer eq .mutex .Unlock ()
60
- if len (eq .errors ) > 0 {
61
- return eq .errors [0 ]
62
- }
63
- return nil
64
- }
63
+ // func (eq *errorQueue) head() error {
64
+ // eq.mutex.Lock()
65
+ // defer eq.mutex.Unlock()
66
+ // if len(eq.errors) > 0 {
67
+ // return eq.errors[0]
68
+ // }
69
+ // return nil
70
+ // }
65
71
66
72
func (eq * errorQueue ) dequeue () bool {
67
73
eq .mutex .Lock ()
@@ -73,78 +79,95 @@ func (eq *errorQueue) dequeue() bool {
73
79
return false
74
80
}
75
81
82
+ /*
76
83
type timeoutConn struct {
77
84
net.Conn
78
85
errors *errorQueue
79
86
}
80
87
81
88
func (c *timeoutConn) Read(b []byte) (int, error) {
82
- n , err := 0 , c .errors .head ()
83
- if err == nil {
84
- n , err = c .Conn .Read (b )
85
- }
89
+ //n, err := 0, c.errors.head()
90
+ //if err == nil {
91
+ n, err := c.Conn.Read(b)
92
+ //}
93
+ log.Println("read", n, err)
86
94
return n, err
87
95
}
88
96
89
97
func (c *timeoutConn) Write(b []byte) (int, error) {
90
- n , err := 0 , c .errors .head ()
91
- if err == nil {
92
- n , err = c .Conn .Write (b )
93
- }
98
+ //n, err := 0, c.errors.head()
99
+ //if err == nil {
100
+ n, err := c.Conn.Write(b)
101
+ //}
102
+ log.Println("write", n, err)
94
103
return n, err
95
104
}
105
+ */
96
106
97
107
type timeoutDialer struct {
98
108
Dialer
99
109
errors * errorQueue
100
110
}
101
111
102
112
func (d * timeoutDialer ) DialContext (ctx context.Context , network , address string ) (net.Conn , error ) {
103
- var dialer net.Dialer
104
- c , e := dialer .DialContext (ctx , network , address )
105
- return & timeoutConn {c , d .errors }, e
106
- }
113
+ c , e := d .Dialer .DialContext (ctx , network , address )
114
+
115
+ caFile := os .Getenv ("MONGO_GO_DRIVER_CA_FILE" )
116
+ log .Println ("dial" , caFile )
117
+ if len (caFile ) > 0 {
118
+ pem , err := ioutil .ReadFile (caFile )
119
+ if err != nil {
120
+ return nil , err
121
+ }
107
122
108
- type timeoutErr struct {
109
- net.UnknownNetworkError
110
- }
123
+ ca := x509 .NewCertPool ()
124
+ if ! ca .AppendCertsFromPEM (pem ) {
125
+ return nil , fmt .Errorf ("unable to load CA file" )
126
+ }
111
127
112
- func (e * timeoutErr ) Timeout () bool {
113
- return true
128
+ config := & tls.Config {
129
+ InsecureSkipVerify : true ,
130
+ RootCAs : ca ,
131
+ }
132
+ c = tls .Client (c , config )
133
+ }
134
+ return c , e
114
135
}
115
136
116
- var timeout = & timeoutErr {"test timeout" }
117
-
118
137
// TestServerHeartbeatTimeout tests timeout retry for GODRIVER-2577.
119
138
func TestServerHeartbeatTimeout (t * testing.T ) {
139
+ networkTimeoutError := & net.DNSError {
140
+ IsTimeout : true ,
141
+ }
142
+
120
143
testCases := []struct {
121
144
desc string
122
145
ioErrors []error
123
146
expectPoolCleared bool
124
147
}{
125
148
{
126
149
desc : "one single timeout should not clear the pool" ,
127
- ioErrors : []error {nil , timeout , nil , timeout , nil },
150
+ ioErrors : []error {nil , networkTimeoutError , nil , networkTimeoutError , nil },
128
151
expectPoolCleared : false ,
129
152
},
130
- {
131
- desc : "continuous timeouts should clear the pool" ,
132
- ioErrors : []error {nil , timeout , timeout , nil },
133
- expectPoolCleared : true ,
134
- },
153
+ // {
154
+ // desc: "continuous timeouts should clear the pool",
155
+ // ioErrors: []error{nil, networkTimeoutError, networkTimeoutError , nil},
156
+ // expectPoolCleared: true,
157
+ // },
135
158
}
136
159
for _ , tc := range testCases {
137
160
tc := tc
138
161
t .Run (tc .desc , func (t * testing.T ) {
139
- t .Parallel ()
162
+ // t.Parallel()
140
163
141
164
var wg sync.WaitGroup
142
165
wg .Add (1 )
143
166
144
167
errors := & errorQueue {errors : tc .ioErrors }
145
168
tpm := monitor .NewTestPoolMonitor ()
146
169
server := NewServer (
147
- address .Address ("localhost:27017 " ),
170
+ address .Address ("localhost" ),
148
171
primitive .NewObjectID (),
149
172
WithConnectionPoolMonitor (func (* event.PoolMonitor ) * event.PoolMonitor {
150
173
return tpm .PoolMonitor
@@ -165,6 +188,9 @@ func TestServerHeartbeatTimeout(t *testing.T) {
165
188
},
166
189
}
167
190
}),
191
+ WithHeartbeatInterval (func (time.Duration ) time.Duration {
192
+ return 3 * time .Second
193
+ }),
168
194
)
169
195
require .NoError (t , server .Connect (nil ))
170
196
wg .Wait ()
0 commit comments