Skip to content

Commit 3dfd003

Browse files
Cyberaxgopherbot
authored andcommitted
websocket: add support for dialing with context
Right now there is no way to pass context.Context to websocket.Dial. In addition, this method can block indefinitely in the NewClient call. Fixes golang/go#57953. Change-Id: Ic52d4b8306cd0850e78d683abb1bf11f0d4247ca GitHub-Last-Rev: 5e8c3a7 GitHub-Pull-Request: #160 Reviewed-on: https://go-review.googlesource.com/c/net/+/463097 Auto-Submit: Damien Neil <dneil@google.com> Reviewed-by: Damien Neil <dneil@google.com> LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
1 parent fa11427 commit 3dfd003

File tree

3 files changed

+89
-15
lines changed

3 files changed

+89
-15
lines changed

websocket/client.go

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@ package websocket
66

77
import (
88
"bufio"
9+
"context"
910
"io"
1011
"net"
1112
"net/http"
1213
"net/url"
14+
"time"
1315
)
1416

1517
// DialError is an error that occurs while dialling a websocket server.
@@ -77,30 +79,60 @@ func parseAuthority(location *url.URL) string {
7779
return location.Host
7880
}
7981

80-
// DialConfig opens a new client connection to a WebSocket with a config.
8182
func DialConfig(config *Config) (ws *Conn, err error) {
82-
var client net.Conn
83+
return config.DialContext(context.Background())
84+
}
85+
86+
// DialContext opens a new client connection to a WebSocket, with context support for timeouts/cancellation.
87+
func (config *Config) DialContext(ctx context.Context) (*Conn, error) {
8388
if config.Location == nil {
8489
return nil, &DialError{config, ErrBadWebSocketLocation}
8590
}
8691
if config.Origin == nil {
8792
return nil, &DialError{config, ErrBadWebSocketOrigin}
8893
}
94+
8995
dialer := config.Dialer
9096
if dialer == nil {
9197
dialer = &net.Dialer{}
9298
}
93-
client, err = dialWithDialer(dialer, config)
94-
if err != nil {
95-
goto Error
96-
}
97-
ws, err = NewClient(config, client)
99+
100+
client, err := dialWithDialer(ctx, dialer, config)
98101
if err != nil {
99-
client.Close()
100-
goto Error
102+
return nil, &DialError{config, err}
101103
}
102-
return
103104

104-
Error:
105-
return nil, &DialError{config, err}
105+
// Cleanup the connection if we fail to create the websocket successfully
106+
success := false
107+
defer func() {
108+
if !success {
109+
_ = client.Close()
110+
}
111+
}()
112+
113+
var ws *Conn
114+
var wsErr error
115+
doneConnecting := make(chan struct{})
116+
go func() {
117+
defer close(doneConnecting)
118+
ws, err = NewClient(config, client)
119+
if err != nil {
120+
wsErr = &DialError{config, err}
121+
}
122+
}()
123+
124+
// The websocket.NewClient() function can block indefinitely, make sure that we
125+
// respect the deadlines specified by the context.
126+
select {
127+
case <-ctx.Done():
128+
// Force the pending operations to fail, terminating the pending connection attempt
129+
_ = client.SetDeadline(time.Now())
130+
<-doneConnecting // Wait for the goroutine that tries to establish the connection to finish
131+
return nil, &DialError{config, ctx.Err()}
132+
case <-doneConnecting:
133+
if wsErr == nil {
134+
success = true // Disarm the deferred connection cleanup
135+
}
136+
return ws, wsErr
137+
}
106138
}

websocket/dial.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,23 @@
55
package websocket
66

77
import (
8+
"context"
89
"crypto/tls"
910
"net"
1011
)
1112

12-
func dialWithDialer(dialer *net.Dialer, config *Config) (conn net.Conn, err error) {
13+
func dialWithDialer(ctx context.Context, dialer *net.Dialer, config *Config) (conn net.Conn, err error) {
1314
switch config.Location.Scheme {
1415
case "ws":
15-
conn, err = dialer.Dial("tcp", parseAuthority(config.Location))
16+
conn, err = dialer.DialContext(ctx, "tcp", parseAuthority(config.Location))
1617

1718
case "wss":
18-
conn, err = tls.DialWithDialer(dialer, "tcp", parseAuthority(config.Location), config.TlsConfig)
19+
tlsDialer := &tls.Dialer{
20+
NetDialer: dialer,
21+
Config: config.TlsConfig,
22+
}
1923

24+
conn, err = tlsDialer.DialContext(ctx, "tcp", parseAuthority(config.Location))
2025
default:
2126
err = ErrBadScheme
2227
}

websocket/dial_test.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,13 @@
55
package websocket
66

77
import (
8+
"context"
89
"crypto/tls"
10+
"errors"
911
"fmt"
1012
"log"
1113
"net"
14+
"net/http"
1215
"net/http/httptest"
1316
"testing"
1417
"time"
@@ -41,3 +44,37 @@ func TestDialConfigTLSWithDialer(t *testing.T) {
4144
t.Fatalf("expected timeout error, got %#v", neterr)
4245
}
4346
}
47+
48+
func TestDialConfigTLSWithTimeouts(t *testing.T) {
49+
t.Parallel()
50+
51+
finishedRequest := make(chan bool)
52+
53+
// Context for cancellation
54+
ctx, cancel := context.WithCancel(context.Background())
55+
56+
// This is a TLS server that blocks each request indefinitely (and cancels the context)
57+
tlsServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
58+
cancel()
59+
<-finishedRequest
60+
}))
61+
62+
tlsServerAddr := tlsServer.Listener.Addr().String()
63+
log.Print("Test TLS WebSocket server listening on ", tlsServerAddr)
64+
defer tlsServer.Close()
65+
defer close(finishedRequest)
66+
67+
config, _ := NewConfig(fmt.Sprintf("wss://%s/echo", tlsServerAddr), "http://localhost")
68+
config.TlsConfig = &tls.Config{
69+
InsecureSkipVerify: true,
70+
}
71+
72+
_, err := config.DialContext(ctx)
73+
dialerr, ok := err.(*DialError)
74+
if !ok {
75+
t.Fatalf("DialError expected, got %#v", err)
76+
}
77+
if !errors.Is(dialerr.Err, context.Canceled) {
78+
t.Fatalf("context.Canceled error expected, got %#v", dialerr.Err)
79+
}
80+
}

0 commit comments

Comments
 (0)