7
7
"os"
8
8
"strconv"
9
9
"strings"
10
- "sync"
10
+ "sync/atomic"
11
+ "time"
11
12
12
13
"golang.org/x/crypto/ssh"
13
14
)
@@ -24,8 +25,9 @@ func NewDialer(addr string) (*Dialer, error) {
24
25
25
26
func NewDialerWithConfig (host string , config * ssh.ClientConfig ) (* Dialer , error ) {
26
27
return & Dialer {
27
- host : host ,
28
- config : config ,
28
+ host : host ,
29
+ config : config ,
30
+ clients : make (chan * ssh.Client , 5 ),
29
31
}, nil
30
32
}
31
33
@@ -69,6 +71,17 @@ func parseClientConfig(addr string) (*clientConfig, error) {
69
71
config .Auth = append (config .Auth , ssh .PublicKeys (signer ))
70
72
}
71
73
74
+ var timeout = 30 * time .Second
75
+ timeoutStr := ur .Query ().Get ("timeout" )
76
+ if timeoutStr != "" {
77
+ timeout , err = time .ParseDuration (timeoutStr )
78
+ if err != nil {
79
+ return nil , err
80
+ }
81
+ }
82
+
83
+ config .Timeout = timeout
84
+
72
85
host := ur .Hostname ()
73
86
port := ur .Port ()
74
87
if port == "" {
@@ -82,7 +95,6 @@ func parseClientConfig(addr string) (*clientConfig, error) {
82
95
}
83
96
84
97
type Dialer struct {
85
- mut sync.Mutex
86
98
localAddr net.Addr
87
99
// ProxyDial specifies the optional dial function for
88
100
// establishing the transport connection.
@@ -91,17 +103,12 @@ type Dialer struct {
91
103
host string
92
104
config * ssh.ClientConfig
93
105
94
- pool sync.Pool
106
+ conns int32
107
+ clients chan * ssh.Client
95
108
}
96
109
97
110
func (d * Dialer ) Close () error {
98
- for {
99
- a := d .pool .Get ()
100
- if a == nil {
101
- break
102
- }
103
- a .(* ssh.Client ).Close ()
104
- }
111
+ // In practice, closing the connection doesn't actually release the ssh.Conn but causes a memory leak
105
112
return nil
106
113
}
107
114
@@ -115,33 +122,35 @@ func (d *Dialer) proxyDial(ctx context.Context, network, address string) (net.Co
115
122
}
116
123
117
124
func (d * Dialer ) SSHClient (ctx context.Context ) (* ssh.Client , error ) {
118
- return d .GetClient (ctx )
119
- }
120
-
121
- func (d * Dialer ) GetClient (ctx context.Context ) (* ssh.Client , error ) {
122
- a := d .pool .Get ()
123
- if a != nil {
124
- return a .(* ssh.Client ), nil
125
+ cli , err := d .getClient (ctx )
126
+ if err != nil {
127
+ return nil , err
125
128
}
129
+ d .putClient (cli )
130
+ return cli , nil
131
+ }
126
132
127
- d .mut .Lock ()
128
- defer d .mut .Unlock ()
129
-
130
- a = d .pool .Get ()
131
- if a != nil {
132
- return a .(* ssh.Client ), nil
133
+ func (d * Dialer ) getClient (ctx context.Context ) (* ssh.Client , error ) {
134
+ if atomic .LoadInt32 (& d .conns ) >= int32 (cap (d .clients )) {
135
+ select {
136
+ case <- ctx .Done ():
137
+ return nil , ctx .Err ()
138
+ case cli := <- d .clients :
139
+ return cli , nil
140
+ }
133
141
}
142
+ atomic .AddInt32 (& d .conns , 1 )
134
143
135
144
cli , err := d .sshClient (ctx )
136
145
if err != nil {
146
+ atomic .AddInt32 (& d .conns , - 1 )
137
147
return nil , err
138
148
}
139
-
140
149
return cli , nil
141
150
}
142
151
143
- func (d * Dialer ) PutClient (cli * ssh.Client ) {
144
- d .pool . Put ( cli )
152
+ func (d * Dialer ) putClient (cli * ssh.Client ) {
153
+ d .clients <- cli
145
154
}
146
155
147
156
func (d * Dialer ) sshClient (ctx context.Context ) (* ssh.Client , error ) {
@@ -167,20 +176,16 @@ func buildCmd(name string, args ...string) string {
167
176
}
168
177
169
178
func (d * Dialer ) CommandDialContext (ctx context.Context , name string , args ... string ) (net.Conn , error ) {
170
- cli , err := d .GetClient (ctx )
179
+ cli , err := d .getClient (ctx )
171
180
if err != nil {
172
181
return nil , err
173
182
}
183
+ defer d .putClient (cli )
184
+
174
185
sess , err := cli .NewSession ()
175
186
if err != nil {
176
- if isSSHError (err ) {
177
- d .PutClient (cli )
178
- } else {
179
- cli .Close ()
180
- }
181
187
return nil , err
182
188
}
183
- defer d .PutClient (cli )
184
189
185
190
conn1 , conn2 := net .Pipe ()
186
191
sess .Stdin = conn1
@@ -217,42 +222,32 @@ func (d *Dialer) CommandDialContext(ctx context.Context, name string, args ...st
217
222
}
218
223
219
224
func (d * Dialer ) DialContext (ctx context.Context , network , address string ) (net.Conn , error ) {
220
- cli , err := d .GetClient (ctx )
225
+ cli , err := d .getClient (ctx )
221
226
if err != nil {
222
227
return nil , err
223
228
}
229
+ defer d .putClient (cli )
224
230
225
231
conn , err := cli .DialContext (ctx , network , address )
226
232
if err != nil {
227
- if isSSHError (err ) {
228
- d .PutClient (cli )
229
- } else {
230
- cli .Close ()
231
- }
232
233
return nil , err
233
234
}
234
235
235
- d .PutClient (cli )
236
236
return conn , nil
237
237
}
238
238
239
239
func (d * Dialer ) Dial (network , address string ) (net.Conn , error ) {
240
- cli , err := d .GetClient (context .Background ())
240
+ cli , err := d .getClient (context .Background ())
241
241
if err != nil {
242
242
return nil , err
243
243
}
244
+ defer d .putClient (cli )
244
245
245
246
conn , err := cli .Dial (network , address )
246
247
if err != nil {
247
- if isSSHError (err ) {
248
- d .PutClient (cli )
249
- } else {
250
- cli .Close ()
251
- }
252
248
return nil , err
253
249
}
254
250
255
- d .PutClient (cli )
256
251
return conn , nil
257
252
}
258
253
@@ -264,43 +259,16 @@ func (d *Dialer) Listen(ctx context.Context, network, address string) (net.Liste
264
259
}
265
260
}
266
261
267
- cli , err := d .GetClient (ctx )
262
+ cli , err := d .getClient (ctx )
268
263
if err != nil {
269
264
return nil , err
270
265
}
266
+ defer d .putClient (cli )
271
267
272
268
listener , err := cli .Listen (network , address )
273
269
if err != nil {
274
- if isSSHError (err ) {
275
- d .PutClient (cli )
276
- } else {
277
- cli .Close ()
278
- }
279
270
return nil , err
280
271
}
281
272
282
- listener = & listenerWithCleanup {
283
- Listener : listener ,
284
- cleanup : func () {
285
- d .PutClient (cli )
286
- },
287
- }
288
-
289
273
return listener , nil
290
274
}
291
-
292
- type listenerWithCleanup struct {
293
- net.Listener
294
- cleanup func ()
295
- }
296
-
297
- func (l * listenerWithCleanup ) Close () error {
298
- err := l .Listener .Close ()
299
- l .cleanup ()
300
- return err
301
- }
302
-
303
- func isSSHError (err error ) bool {
304
- msg := err .Error ()
305
- return strings .Contains (msg , "ssh: " )
306
- }
0 commit comments