Skip to content

Commit fe326ee

Browse files
committed
Revert client pool
1 parent 7710186 commit fe326ee

File tree

1 file changed

+22
-44
lines changed

1 file changed

+22
-44
lines changed

client.go

Lines changed: 22 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import (
77
"os"
88
"strconv"
99
"strings"
10-
"sync/atomic"
10+
"sync"
1111
"time"
1212

1313
"golang.org/x/crypto/ssh"
@@ -25,9 +25,8 @@ func NewDialer(addr string) (*Dialer, error) {
2525

2626
func NewDialerWithConfig(host string, config *ssh.ClientConfig) (*Dialer, error) {
2727
return &Dialer{
28-
host: host,
29-
config: config,
30-
clients: make(chan *ssh.Client, 5),
28+
host: host,
29+
config: config,
3130
}, nil
3231
}
3332

@@ -103,13 +102,12 @@ type Dialer struct {
103102
host string
104103
config *ssh.ClientConfig
105104

106-
conns int32
107-
clients chan *ssh.Client
105+
mut sync.RWMutex
106+
sshCli *ssh.Client
108107
}
109108

110109
func (d *Dialer) Close() error {
111-
// In practice, closing the connection doesn't actually release the ssh.Conn but causes a memory leak
112-
return nil
110+
return d.sshCli.Close()
113111
}
114112

115113
func (d *Dialer) proxyDial(ctx context.Context, network, address string) (net.Conn, error) {
@@ -122,38 +120,20 @@ func (d *Dialer) proxyDial(ctx context.Context, network, address string) (net.Co
122120
}
123121

124122
func (d *Dialer) SSHClient(ctx context.Context) (*ssh.Client, error) {
125-
cli, err := d.getClient(ctx)
126-
if err != nil {
127-
return nil, err
128-
}
129-
d.putClient(cli)
130-
return cli, nil
131-
}
123+
d.mut.RLock()
124+
sshCli := d.sshCli
125+
d.mut.RUnlock()
132126

133-
func (d *Dialer) getClient(ctx context.Context) (*ssh.Client, error) {
134-
if atomic.LoadInt32(&d.conns) >= int32(cap(d.clients)) {
135-
select {
136-
case <-ctx.Done():
137-
return nil, ctx.Err()
138-
case cli := <-d.clients:
139-
return cli, nil
140-
}
127+
if sshCli != nil {
128+
return sshCli, nil
141129
}
142-
atomic.AddInt32(&d.conns, 1)
143130

144-
cli, err := d.sshClient(ctx)
145-
if err != nil {
146-
atomic.AddInt32(&d.conns, -1)
147-
return nil, err
131+
d.mut.Lock()
132+
defer d.mut.Unlock()
133+
if d.sshCli != nil {
134+
return d.sshCli, nil
148135
}
149-
return cli, nil
150-
}
151136

152-
func (d *Dialer) putClient(cli *ssh.Client) {
153-
d.clients <- cli
154-
}
155-
156-
func (d *Dialer) sshClient(ctx context.Context) (*ssh.Client, error) {
157137
conn, err := d.proxyDial(ctx, "tcp", d.host)
158138
if err != nil {
159139
return nil, err
@@ -163,7 +143,9 @@ func (d *Dialer) sshClient(ctx context.Context) (*ssh.Client, error) {
163143
if err != nil {
164144
return nil, err
165145
}
166-
return ssh.NewClient(con, chans, reqs), nil
146+
147+
d.sshCli = ssh.NewClient(con, chans, reqs)
148+
return d.sshCli, nil
167149
}
168150

169151
func buildCmd(name string, args ...string) string {
@@ -176,11 +158,10 @@ func buildCmd(name string, args ...string) string {
176158
}
177159

178160
func (d *Dialer) CommandDialContext(ctx context.Context, name string, args ...string) (net.Conn, error) {
179-
cli, err := d.getClient(ctx)
161+
cli, err := d.SSHClient(ctx)
180162
if err != nil {
181163
return nil, err
182164
}
183-
defer d.putClient(cli)
184165

185166
sess, err := cli.NewSession()
186167
if err != nil {
@@ -222,11 +203,10 @@ func (d *Dialer) CommandDialContext(ctx context.Context, name string, args ...st
222203
}
223204

224205
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
225-
cli, err := d.getClient(ctx)
206+
cli, err := d.SSHClient(ctx)
226207
if err != nil {
227208
return nil, err
228209
}
229-
defer d.putClient(cli)
230210

231211
conn, err := cli.DialContext(ctx, network, address)
232212
if err != nil {
@@ -237,11 +217,10 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
237217
}
238218

239219
func (d *Dialer) Dial(network, address string) (net.Conn, error) {
240-
cli, err := d.getClient(context.Background())
220+
cli, err := d.SSHClient(context.Background())
241221
if err != nil {
242222
return nil, err
243223
}
244-
defer d.putClient(cli)
245224

246225
conn, err := cli.Dial(network, address)
247226
if err != nil {
@@ -259,11 +238,10 @@ func (d *Dialer) Listen(ctx context.Context, network, address string) (net.Liste
259238
}
260239
}
261240

262-
cli, err := d.getClient(ctx)
241+
cli, err := d.SSHClient(ctx)
263242
if err != nil {
264243
return nil, err
265244
}
266-
defer d.putClient(cli)
267245

268246
listener, err := cli.Listen(network, address)
269247
if err != nil {

0 commit comments

Comments
 (0)