From 90bf9b8ac4dec4b8be4faea37d096918cf634b4d Mon Sep 17 00:00:00 2001 From: Robin Date: Sat, 22 Oct 2022 20:18:45 +0200 Subject: [PATCH] Correct context --- ws/dial.go | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/ws/dial.go b/ws/dial.go index 575c7b3..0e2e6d4 100644 --- a/ws/dial.go +++ b/ws/dial.go @@ -26,15 +26,17 @@ func (o *ClientOptions) dial(ctx context.Context) (context.Context, *websocket.C header := metadata.Join( metadata.MD(o.DialOptions.HTTPHeader), - mdFromContext(o.DialContext), - mdFromContext(ctx), + fromOutgoingContext(o.DialContext), + fromOutgoingContext(ctx), ) - if o.DialContext != nil { - ctx = o.DialContext + + dialContext := o.DialContext + if dialContext == nil { + dialContext = ctx } // Dial service - ws, res, err := websocket.Dial(ctx, o.URL, &websocket.DialOptions{ + ws, res, err := websocket.Dial(dialContext, o.URL, &websocket.DialOptions{ HTTPClient: o.DialOptions.HTTPClient, HTTPHeader: http.Header(header), Subprotocols: o.DialOptions.Subprotocols, @@ -71,7 +73,7 @@ func (e *dialErr) Error() string { } func (e *dialErr) Unwrap() error { return e.error } -func mdFromContext(ctx context.Context) metadata.MD { +func fromOutgoingContext(ctx context.Context) metadata.MD { if ctx == nil { return nil }