Skip to content

Commit 8cdf277

Browse files
authored
feat: enforce proxy dial timeouts and improve timeout handling (#98)
* feat: add timeout handling and error detection for proxy connections - Add support for proxy dial timeouts with a dedicated error type and detection. - Apply connection timeout logic when connecting through a proxy. - Update Run method to correctly set timeout flag if proxy dial timeout occurs. - Introduce tests to verify proxy timeouts and error handling on proxy connections. Signed-off-by: appleboy <appleboy.tw@gmail.com> * fix: prevent goroutine leaks on proxy connections - Prevent goroutine leaks on proxy connection by handling context cancellation and closing connections if necessary - Add a test to verify that proxy dial timeouts do not cause goroutine leaks Signed-off-by: appleboy <appleboy.tw@gmail.com> --------- Signed-off-by: appleboy <appleboy.tw@gmail.com>
1 parent abf4c52 commit 8cdf277

File tree

2 files changed

+200
-1
lines changed

2 files changed

+200
-1
lines changed

easyssh.go

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@ var (
2828
defaultBufferSize = 4096
2929
)
3030

31+
var (
32+
// ErrProxyDialTimeout is returned when proxy dial connection times out
33+
ErrProxyDialTimeout = errors.New("proxy dial timeout")
34+
)
35+
3136
type Protocol string
3237

3338
const (
@@ -253,7 +258,43 @@ func (ssh_conf *MakeConfig) Connect() (*ssh.Session, *ssh.Client, error) {
253258
return nil, nil, err
254259
}
255260

256-
conn, err := proxyClient.Dial(string(ssh_conf.Protocol), net.JoinHostPort(ssh_conf.Server, ssh_conf.Port))
261+
// Apply timeout to the connection from proxy to target server
262+
timeout := ssh_conf.Timeout
263+
if timeout == 0 {
264+
timeout = defaultTimeout
265+
}
266+
267+
ctx, cancel := context.WithTimeout(context.Background(), timeout)
268+
defer cancel()
269+
270+
type connResult struct {
271+
conn net.Conn
272+
err error
273+
}
274+
275+
connCh := make(chan connResult, 1)
276+
go func() {
277+
conn, err := proxyClient.Dial(string(ssh_conf.Protocol), net.JoinHostPort(ssh_conf.Server, ssh_conf.Port))
278+
select {
279+
case connCh <- connResult{conn: conn, err: err}:
280+
// Successfully sent result
281+
case <-ctx.Done():
282+
// Context was cancelled, clean up the connection if it was established
283+
if conn != nil {
284+
conn.Close()
285+
}
286+
}
287+
}()
288+
289+
var conn net.Conn
290+
select {
291+
case result := <-connCh:
292+
conn = result.conn
293+
err = result.err
294+
case <-ctx.Done():
295+
return nil, nil, fmt.Errorf("%w: %v", ErrProxyDialTimeout, ctx.Err())
296+
}
297+
257298
if err != nil {
258299
return nil, nil, err
259300
}
@@ -413,6 +454,10 @@ func (ssh_conf *MakeConfig) Stream(command string, timeout ...time.Duration) (<-
413454
func (ssh_conf *MakeConfig) Run(command string, timeout ...time.Duration) (outStr string, errStr string, isTimeout bool, err error) {
414455
stdoutChan, stderrChan, doneChan, errChan, err := ssh_conf.Stream(command, timeout...)
415456
if err != nil {
457+
// Check if the error is from a proxy dial timeout
458+
if errors.Is(err, ErrProxyDialTimeout) {
459+
isTimeout = true
460+
}
416461
return outStr, errStr, isTimeout, err
417462
}
418463
// read from the output channel until the done signal is passed

easyssh_test.go

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@ package easyssh
22

33
import (
44
"context"
5+
"errors"
56
"os"
67
"os/user"
78
"path"
9+
"runtime"
810
"testing"
911
"time"
1012

@@ -512,3 +514,155 @@ func TestCommandTimeout(t *testing.T) {
512514
assert.NotNil(t, err)
513515
assert.Equal(t, "Run Command Timeout: "+context.DeadlineExceeded.Error(), err.Error())
514516
}
517+
518+
// TestProxyTimeoutHandling tests that timeout is properly respected when using proxy connections
519+
// This test uses a non-existent proxy server to force a timeout during proxy connection
520+
func TestProxyTimeoutHandling(t *testing.T) {
521+
ssh := &MakeConfig{
522+
Server: "example.com",
523+
User: "testuser",
524+
Port: "22",
525+
KeyPath: "./tests/.ssh/id_rsa",
526+
Timeout: 1 * time.Second, // Short timeout for testing
527+
Proxy: DefaultConfig{
528+
User: "testuser",
529+
Server: "10.255.255.1", // Non-routable IP that should timeout
530+
Port: "22",
531+
KeyPath: "./tests/.ssh/id_rsa",
532+
Timeout: 1 * time.Second,
533+
},
534+
}
535+
536+
// Test Connect() method directly to test proxy connection timeout
537+
start := time.Now()
538+
session, client, err := ssh.Connect()
539+
elapsed := time.Since(start)
540+
541+
// Should timeout within reasonable bounds
542+
assert.True(t, elapsed < 3*time.Second, "Connection should timeout within 3 seconds, took %v", elapsed)
543+
assert.True(t, elapsed >= 1*time.Second, "Connection should take at least 1 second (timeout value), took %v", elapsed)
544+
545+
// Should return nil session and client
546+
assert.Nil(t, session)
547+
assert.Nil(t, client)
548+
549+
// Should have error
550+
assert.NotNil(t, err)
551+
}
552+
553+
// TestProxyDialTimeout tests the specific scenario described in issue #93
554+
// where proxy dial timeout should be respected and properly detected
555+
func TestProxyDialTimeout(t *testing.T) {
556+
ssh := &MakeConfig{
557+
Server: "10.255.255.1", // Non-routable IP that should timeout
558+
User: "testuser",
559+
Port: "22",
560+
KeyPath: "./tests/.ssh/id_rsa",
561+
Timeout: 2 * time.Second, // Short timeout for testing
562+
Proxy: DefaultConfig{
563+
User: "testuser",
564+
Server: "10.255.255.2", // Another non-routable IP for proxy
565+
Port: "22",
566+
KeyPath: "./tests/.ssh/id_rsa",
567+
Timeout: 2 * time.Second,
568+
},
569+
}
570+
571+
// Test Connect() method directly to avoid SSH server dependency
572+
start := time.Now()
573+
session, client, err := ssh.Connect()
574+
elapsed := time.Since(start)
575+
576+
// Should timeout within reasonable bounds
577+
assert.True(t, elapsed < 5*time.Second, "Connection should timeout within 5 seconds, took %v", elapsed)
578+
assert.True(t, elapsed >= 2*time.Second, "Connection should take at least 2 seconds (timeout value), took %v", elapsed)
579+
580+
// Should return nil session and client
581+
assert.Nil(t, session)
582+
assert.Nil(t, client)
583+
584+
// Should have error
585+
assert.NotNil(t, err)
586+
// Note: This will timeout at the proxy connection level, not at proxy dial level
587+
// so it won't be ErrProxyDialTimeout, but we can still verify the timeout behavior
588+
}
589+
590+
// TestProxyDialTimeoutInRun tests timeout detection in Run method
591+
func TestProxyDialTimeoutInRun(t *testing.T) {
592+
ssh := &MakeConfig{
593+
Server: "example.com",
594+
User: "testuser",
595+
Port: "22",
596+
KeyPath: "./tests/.ssh/id_rsa",
597+
Timeout: 2 * time.Second,
598+
Proxy: DefaultConfig{
599+
User: "testuser",
600+
Server: "127.0.0.1", // Assume localhost SSH exists
601+
Port: "22",
602+
KeyPath: "./tests/.ssh/id_rsa",
603+
Timeout: 2 * time.Second,
604+
},
605+
}
606+
607+
// Mock a scenario where Connect() returns ErrProxyDialTimeout
608+
// by temporarily changing the target to a non-routable address
609+
ssh.Server = "10.255.255.1"
610+
611+
start := time.Now()
612+
outStr, errStr, isTimeout, err := ssh.Run("whoami")
613+
elapsed := time.Since(start)
614+
615+
// Should timeout within reasonable bounds
616+
assert.True(t, elapsed < 5*time.Second, "Should timeout within 5 seconds, took %v", elapsed)
617+
618+
// Should return empty output
619+
assert.Equal(t, "", outStr)
620+
assert.Equal(t, "", errStr)
621+
622+
// Should have error
623+
assert.NotNil(t, err)
624+
625+
// If it's specifically a proxy dial timeout, isTimeout should be true
626+
if errors.Is(err, ErrProxyDialTimeout) {
627+
assert.True(t, isTimeout, "isTimeout should be true for proxy dial timeout")
628+
}
629+
}
630+
631+
// TestProxyGoroutineLeak tests that no goroutines are leaked when proxy dial times out
632+
func TestProxyGoroutineLeak(t *testing.T) {
633+
// Get initial goroutine count
634+
initialGoroutines := runtime.NumGoroutine()
635+
636+
ssh := &MakeConfig{
637+
Server: "10.255.255.1", // Non-routable IP that should timeout
638+
User: "testuser",
639+
Port: "22",
640+
KeyPath: "./tests/.ssh/id_rsa",
641+
Timeout: 1 * time.Second, // Short timeout
642+
Proxy: DefaultConfig{
643+
User: "testuser",
644+
Server: "10.255.255.2", // Another non-routable IP for proxy
645+
Port: "22",
646+
KeyPath: "./tests/.ssh/id_rsa",
647+
Timeout: 1 * time.Second,
648+
},
649+
}
650+
651+
// Run multiple timeout operations
652+
for i := 0; i < 5; i++ {
653+
_, _, err := ssh.Connect()
654+
assert.NotNil(t, err) // Should have error due to timeout
655+
}
656+
657+
// Give some time for goroutines to cleanup
658+
time.Sleep(100 * time.Millisecond)
659+
runtime.GC() // Force garbage collection
660+
661+
// Check final goroutine count - should not have grown significantly
662+
finalGoroutines := runtime.NumGoroutine()
663+
664+
// Allow for some variance due to test framework overhead, but shouldn't grow by more than 2-3 goroutines
665+
assert.True(t, finalGoroutines <= initialGoroutines+3,
666+
"Goroutine leak detected: initial=%d, final=%d", initialGoroutines, finalGoroutines)
667+
}
668+

0 commit comments

Comments
 (0)