Skip to content

Commit 5e8c3a7

Browse files
committed
Add support for dialing with context to websockets
Right now there is no way to pass context.Context to websocket.Dial. In addition, this method can block indefinitely in the `NewClient` call. Fixes golang/go#57953
1 parent f8411da commit 5e8c3a7

File tree

3 files changed

+89
-15
lines changed

3 files changed

+89
-15
lines changed

websocket/client.go

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@ package websocket
66

77
import (
88
"bufio"
9+
"context"
910
"io"
1011
"net"
1112
"net/http"
1213
"net/url"
14+
"time"
1315
)
1416

1517
// DialError is an error that occurs while dialling a websocket server.
@@ -77,30 +79,60 @@ func parseAuthority(location *url.URL) string {
7779
return location.Host
7880
}
7981

80-
// DialConfig opens a new client connection to a WebSocket with a config.
8182
func DialConfig(config *Config) (ws *Conn, err error) {
82-
var client net.Conn
83+
return config.DialContext(context.Background())
84+
}
85+
86+
// DialContext opens a new client connection to a WebSocket, with context support for timeouts/cancellation.
87+
func (config *Config) DialContext(ctx context.Context) (*Conn, error) {
8388
if config.Location == nil {
8489
return nil, &DialError{config, ErrBadWebSocketLocation}
8590
}
8691
if config.Origin == nil {
8792
return nil, &DialError{config, ErrBadWebSocketOrigin}
8893
}
94+
8995
dialer := config.Dialer
9096
if dialer == nil {
9197
dialer = &net.Dialer{}
9298
}
93-
client, err = dialWithDialer(dialer, config)
94-
if err != nil {
95-
goto Error
96-
}
97-
ws, err = NewClient(config, client)
99+
100+
client, err := dialWithDialer(ctx, dialer, config)
98101
if err != nil {
99-
client.Close()
100-
goto Error
102+
return nil, &DialError{config, err}
101103
}
102-
return
103104

104-
Error:
105-
return nil, &DialError{config, err}
105+
// Cleanup the connection if we fail to create the websocket successfully
106+
success := false
107+
defer func() {
108+
if !success {
109+
_ = client.Close()
110+
}
111+
}()
112+
113+
var ws *Conn
114+
var wsErr error
115+
doneConnecting := make(chan struct{})
116+
go func() {
117+
defer close(doneConnecting)
118+
ws, err = NewClient(config, client)
119+
if err != nil {
120+
wsErr = &DialError{config, err}
121+
}
122+
}()
123+
124+
// The websocket.NewClient() function can block indefinitely, make sure that we
125+
// respect the deadlines specified by the context.
126+
select {
127+
case <-ctx.Done():
128+
// Force the pending operations to fail, terminating the pending connection attempt
129+
_ = client.SetDeadline(time.Now())
130+
<-doneConnecting // Wait for the goroutine that tries to establish the connection to finish
131+
return nil, &DialError{config, ctx.Err()}
132+
case <-doneConnecting:
133+
if wsErr == nil {
134+
success = true // Disarm the deferred connection cleanup
135+
}
136+
return ws, wsErr
137+
}
106138
}

websocket/dial.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,23 @@
55
package websocket
66

77
import (
8+
"context"
89
"crypto/tls"
910
"net"
1011
)
1112

12-
func dialWithDialer(dialer *net.Dialer, config *Config) (conn net.Conn, err error) {
13+
func dialWithDialer(ctx context.Context, dialer *net.Dialer, config *Config) (conn net.Conn, err error) {
1314
switch config.Location.Scheme {
1415
case "ws":
15-
conn, err = dialer.Dial("tcp", parseAuthority(config.Location))
16+
conn, err = dialer.DialContext(ctx, "tcp", parseAuthority(config.Location))
1617

1718
case "wss":
18-
conn, err = tls.DialWithDialer(dialer, "tcp", parseAuthority(config.Location), config.TlsConfig)
19+
tlsDialer := &tls.Dialer{
20+
NetDialer: dialer,
21+
Config: config.TlsConfig,
22+
}
1923

24+
conn, err = tlsDialer.DialContext(ctx, "tcp", parseAuthority(config.Location))
2025
default:
2126
err = ErrBadScheme
2227
}

websocket/dial_test.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,13 @@
55
package websocket
66

77
import (
8+
"context"
89
"crypto/tls"
10+
"errors"
911
"fmt"
1012
"log"
1113
"net"
14+
"net/http"
1215
"net/http/httptest"
1316
"testing"
1417
"time"
@@ -41,3 +44,37 @@ func TestDialConfigTLSWithDialer(t *testing.T) {
4144
t.Fatalf("expected timeout error, got %#v", neterr)
4245
}
4346
}
47+
48+
func TestDialConfigTLSWithTimeouts(t *testing.T) {
49+
t.Parallel()
50+
51+
finishedRequest := make(chan bool)
52+
53+
// Context for cancellation
54+
ctx, cancel := context.WithCancel(context.Background())
55+
56+
// This is a TLS server that blocks each request indefinitely (and cancels the context)
57+
tlsServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
58+
cancel()
59+
<-finishedRequest
60+
}))
61+
62+
tlsServerAddr := tlsServer.Listener.Addr().String()
63+
log.Print("Test TLS WebSocket server listening on ", tlsServerAddr)
64+
defer tlsServer.Close()
65+
defer close(finishedRequest)
66+
67+
config, _ := NewConfig(fmt.Sprintf("wss://%s/echo", tlsServerAddr), "http://localhost")
68+
config.TlsConfig = &tls.Config{
69+
InsecureSkipVerify: true,
70+
}
71+
72+
_, err := config.DialContext(ctx)
73+
dialerr, ok := err.(*DialError)
74+
if !ok {
75+
t.Fatalf("DialError expected, got %#v", err)
76+
}
77+
if !errors.Is(dialerr.Err, context.Canceled) {
78+
t.Fatalf("context.Canceled error expected, got %#v", dialerr.Err)
79+
}
80+
}

0 commit comments

Comments
 (0)