7
7
"os"
8
8
"strconv"
9
9
"strings"
10
- "sync/atomic "
10
+ "sync"
11
11
"time"
12
12
13
13
"golang.org/x/crypto/ssh"
@@ -25,9 +25,8 @@ func NewDialer(addr string) (*Dialer, error) {
25
25
26
26
func NewDialerWithConfig (host string , config * ssh.ClientConfig ) (* Dialer , error ) {
27
27
return & Dialer {
28
- host : host ,
29
- config : config ,
30
- clients : make (chan * ssh.Client , 5 ),
28
+ host : host ,
29
+ config : config ,
31
30
}, nil
32
31
}
33
32
@@ -103,13 +102,12 @@ type Dialer struct {
103
102
host string
104
103
config * ssh.ClientConfig
105
104
106
- conns int32
107
- clients chan * ssh.Client
105
+ mut sync. RWMutex
106
+ sshCli * ssh.Client
108
107
}
109
108
110
109
func (d * Dialer ) Close () error {
111
- // In practice, closing the connection doesn't actually release the ssh.Conn but causes a memory leak
112
- return nil
110
+ return d .sshCli .Close ()
113
111
}
114
112
115
113
func (d * Dialer ) proxyDial (ctx context.Context , network , address string ) (net.Conn , error ) {
@@ -122,38 +120,20 @@ func (d *Dialer) proxyDial(ctx context.Context, network, address string) (net.Co
122
120
}
123
121
124
122
func (d * Dialer ) SSHClient (ctx context.Context ) (* ssh.Client , error ) {
125
- cli , err := d .getClient (ctx )
126
- if err != nil {
127
- return nil , err
128
- }
129
- d .putClient (cli )
130
- return cli , nil
131
- }
123
+ d .mut .RLock ()
124
+ sshCli := d .sshCli
125
+ d .mut .RUnlock ()
132
126
133
- func (d * Dialer ) getClient (ctx context.Context ) (* ssh.Client , error ) {
134
- if atomic .LoadInt32 (& d .conns ) >= int32 (cap (d .clients )) {
135
- select {
136
- case <- ctx .Done ():
137
- return nil , ctx .Err ()
138
- case cli := <- d .clients :
139
- return cli , nil
140
- }
127
+ if sshCli != nil {
128
+ return sshCli , nil
141
129
}
142
- atomic .AddInt32 (& d .conns , 1 )
143
130
144
- cli , err := d . sshClient ( ctx )
145
- if err != nil {
146
- atomic . AddInt32 ( & d . conns , - 1 )
147
- return nil , err
131
+ d . mut . Lock ( )
132
+ defer d . mut . Unlock ()
133
+ if d . sshCli != nil {
134
+ return d . sshCli , nil
148
135
}
149
- return cli , nil
150
- }
151
136
152
- func (d * Dialer ) putClient (cli * ssh.Client ) {
153
- d .clients <- cli
154
- }
155
-
156
- func (d * Dialer ) sshClient (ctx context.Context ) (* ssh.Client , error ) {
157
137
conn , err := d .proxyDial (ctx , "tcp" , d .host )
158
138
if err != nil {
159
139
return nil , err
@@ -163,7 +143,9 @@ func (d *Dialer) sshClient(ctx context.Context) (*ssh.Client, error) {
163
143
if err != nil {
164
144
return nil , err
165
145
}
166
- return ssh .NewClient (con , chans , reqs ), nil
146
+
147
+ d .sshCli = ssh .NewClient (con , chans , reqs )
148
+ return d .sshCli , nil
167
149
}
168
150
169
151
func buildCmd (name string , args ... string ) string {
@@ -176,11 +158,10 @@ func buildCmd(name string, args ...string) string {
176
158
}
177
159
178
160
func (d * Dialer ) CommandDialContext (ctx context.Context , name string , args ... string ) (net.Conn , error ) {
179
- cli , err := d .getClient (ctx )
161
+ cli , err := d .SSHClient (ctx )
180
162
if err != nil {
181
163
return nil , err
182
164
}
183
- defer d .putClient (cli )
184
165
185
166
sess , err := cli .NewSession ()
186
167
if err != nil {
@@ -222,11 +203,10 @@ func (d *Dialer) CommandDialContext(ctx context.Context, name string, args ...st
222
203
}
223
204
224
205
func (d * Dialer ) DialContext (ctx context.Context , network , address string ) (net.Conn , error ) {
225
- cli , err := d .getClient (ctx )
206
+ cli , err := d .SSHClient (ctx )
226
207
if err != nil {
227
208
return nil , err
228
209
}
229
- defer d .putClient (cli )
230
210
231
211
conn , err := cli .DialContext (ctx , network , address )
232
212
if err != nil {
@@ -237,11 +217,10 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
237
217
}
238
218
239
219
func (d * Dialer ) Dial (network , address string ) (net.Conn , error ) {
240
- cli , err := d .getClient (context .Background ())
220
+ cli , err := d .SSHClient (context .Background ())
241
221
if err != nil {
242
222
return nil , err
243
223
}
244
- defer d .putClient (cli )
245
224
246
225
conn , err := cli .Dial (network , address )
247
226
if err != nil {
@@ -259,11 +238,10 @@ func (d *Dialer) Listen(ctx context.Context, network, address string) (net.Liste
259
238
}
260
239
}
261
240
262
- cli , err := d .getClient (ctx )
241
+ cli , err := d .SSHClient (ctx )
263
242
if err != nil {
264
243
return nil , err
265
244
}
266
- defer d .putClient (cli )
267
245
268
246
listener , err := cli .Listen (network , address )
269
247
if err != nil {
0 commit comments