Skip to content

Commit 28df838

Browse files
author
Emi Gutekanst
authored
gateway: benchmark: fix 'all websocket requests fail' case (#1131)
* gateway: benchmark: fix authentication Signed-off-by: Stephen Gutekanst <stephen@sourcegraph.com> * gateway: benchmark: only connect to websocket once ready, add reconnect logic Signed-off-by: Stephen Gutekanst <stephen@sourcegraph.com> --------- Signed-off-by: Stephen Gutekanst <stephen@sourcegraph.com>
1 parent 7a2b789 commit 28df838

File tree

1 file changed

+74
-74
lines changed

1 file changed

+74
-74
lines changed

cmd/src/gateway_benchmark.go

Lines changed: 74 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -26,23 +26,6 @@ type Stats struct {
2626
Total time.Duration
2727
}
2828

29-
// httpEndpointConfig represents the configuration for an HTTP endpoint.
30-
type httpEndpointConfig struct {
31-
client *http.Client
32-
url string
33-
}
34-
35-
// sgAuthTransport is an http.RoundTripper that adds an Authorization header to requests.
36-
// It is used to add the Sourcegraph access token to requests to Sourcegraph endpoints.
37-
type sgAuthTransport struct {
38-
token string
39-
base http.RoundTripper
40-
}
41-
func (t *sgAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
42-
req.Header.Add("Authorization", "token "+t.token)
43-
return t.base.RoundTrip(req)
44-
}
45-
4629
func init() {
4730
usage := `
4831
'src gateway benchmark' runs performance benchmarks against Cody Gateway endpoints.
@@ -53,10 +36,10 @@ Usage:
5336
5437
Examples:
5538
56-
$ src gateway benchmark
57-
$ src gateway benchmark --requests 50
58-
$ src gateway benchmark --gateway http://localhost:9992 --sourcegraph http://localhost:3082 --sgp sgp_***** --requests 50
59-
$ src gateway benchmark --requests 50 --csv results.csv
39+
$ src gateway benchmark --sgp <token>
40+
$ src gateway benchmark --requests 50 --sgp <token>
41+
$ src gateway benchmark --gateway http://localhost:9992 --sourcegraph http://localhost:3082 --sgp <token>
42+
$ src gateway benchmark --requests 50 --csv results.csv --sgp <token>
6043
`
6144

6245
flagSet := flag.NewFlagSet("benchmark", flag.ExitOnError)
@@ -66,7 +49,7 @@ Examples:
6649
csvOutput = flagSet.String("csv", "", "Export results to CSV file (provide filename)")
6750
gatewayEndpoint = flagSet.String("gateway", "https://cody-gateway.sourcegraph.com", "Cody Gateway endpoint")
6851
sgEndpoint = flagSet.String("sourcegraph", "https://sourcegraph.com", "Sourcegraph endpoint")
69-
sgpToken = flagSet.String("sgp", "sgp_*****", "Sourcegraph personal access token for the called instance")
52+
sgpToken = flagSet.String("sgp", "", "Sourcegraph personal access token for the called instance")
7053
)
7154

7255
handler := func(args []string) error {
@@ -79,78 +62,49 @@ Examples:
7962
}
8063

8164
var (
82-
gatewayWebsocket, sourcegraphWebsocket *websocket.Conn
83-
err error
84-
gatewayClient = &http.Client{}
85-
sourcegraphClient = &http.Client{}
86-
endpoints = map[string]any{} // Values: URL `string`s or `*websocket.Conn`s
65+
httpClient = &http.Client{}
66+
endpoints = map[string]any{} // Values: URL `string`s or `*webSocketClient`s
8767
)
88-
89-
// Connect to endpoints
9068
if *gatewayEndpoint != "" {
9169
fmt.Println("Benchmarking Cody Gateway instance:", *gatewayEndpoint)
92-
wsURL := strings.Replace(fmt.Sprint(*gatewayEndpoint, "/v2/websocket"), "http", "ws", 1)
93-
fmt.Println("Connecting to Cody Gateway via WebSocket..", wsURL)
94-
gatewayWebsocket, _, err = websocket.DefaultDialer.Dial(wsURL, nil)
95-
if err != nil {
96-
return fmt.Errorf("WebSocket dial(%s): %v", wsURL, err)
97-
}
98-
fmt.Println("Connected!")
99-
endpoints["ws(s): gateway"] = gatewayWebsocket
100-
endpoints["http(s): gateway"] = &httpEndpointConfig{
101-
client: gatewayClient,
102-
url: fmt.Sprint(*gatewayEndpoint, "/v2/http"),
70+
endpoints["ws(s): gateway"] = &webSocketClient{
71+
conn: nil,
72+
URL: strings.Replace(fmt.Sprint(*gatewayEndpoint, "/v2/websocket"), "http", "ws", 1),
10373
}
74+
endpoints["http(s): gateway"] = fmt.Sprint(*gatewayEndpoint, "/v2/http")
10475
} else {
10576
fmt.Println("warning: not benchmarking Cody Gateway (-gateway endpoint not provided)")
10677
}
10778
if *sgEndpoint != "" {
108-
// Add auth header to sourcegraphClient transport
109-
if *sgpToken != "" {
110-
sourcegraphClient.Transport = &sgAuthTransport{
111-
token: *sgpToken,
112-
base: http.DefaultTransport,
113-
}
79+
if *sgpToken == "" {
80+
return cmderrors.Usage("must specify --sgp <Sourcegraph personal access token>")
11481
}
11582
fmt.Println("Benchmarking Sourcegraph instance:", *sgEndpoint)
116-
wsURL := strings.Replace(fmt.Sprint(*sgEndpoint, "/.api/gateway/websocket"), "http", "ws", 1)
117-
header := http.Header{}
118-
header.Add("Authorization", "token "+*sgpToken)
119-
fmt.Println("Connecting to Sourcegraph instance via WebSocket..", wsURL)
120-
sourcegraphWebsocket, _, err = websocket.DefaultDialer.Dial(wsURL, header)
121-
if err != nil {
122-
return fmt.Errorf("WebSocket dial(%s): %v", wsURL, err)
123-
}
124-
fmt.Println("Connected!")
125-
126-
endpoints["ws(s): sourcegraph"] = sourcegraphWebsocket
127-
endpoints["http(s): sourcegraph"] = &httpEndpointConfig{
128-
client: sourcegraphClient,
129-
url: fmt.Sprint(*sgEndpoint, "/.api/gateway/http"),
130-
}
131-
endpoints["http(s): http-then-ws"] = &httpEndpointConfig{
132-
client: sourcegraphClient,
133-
url: fmt.Sprint(*sgEndpoint, "/.api/gateway/http-then-websocket"),
83+
endpoints["ws(s): sourcegraph"] = &webSocketClient{
84+
conn: nil,
85+
URL: strings.Replace(fmt.Sprint(*sgEndpoint, "/.api/gateway/websocket"), "http", "ws", 1),
13486
}
87+
endpoints["http(s): sourcegraph"] = fmt.Sprint(*sgEndpoint, "/.api/gateway/http")
88+
endpoints["http(s): http-then-ws"] = fmt.Sprint(*sgEndpoint, "/.api/gateway/http-then-websocket")
13589
} else {
13690
fmt.Println("warning: not benchmarking Sourcegraph instance (-sourcegraph endpoint not provided)")
13791
}
13892

13993
fmt.Printf("Starting benchmark with %d requests per endpoint...\n", *requestCount)
14094

14195
var results []endpointResult
142-
for name, clientOrEndpointConfig := range endpoints {
96+
for name, clientOrURL := range endpoints {
14397
durations := make([]time.Duration, 0, *requestCount)
14498
fmt.Printf("\nTesting %s...", name)
14599

146100
for i := 0; i < *requestCount; i++ {
147-
if ws, ok := clientOrEndpointConfig.(*websocket.Conn); ok {
101+
if ws, ok := clientOrURL.(*webSocketClient); ok {
148102
duration := benchmarkEndpointWebSocket(ws)
149103
if duration > 0 {
150104
durations = append(durations, duration)
151105
}
152-
} else if epConf, ok := clientOrEndpointConfig.(*httpEndpointConfig); ok {
153-
duration := benchmarkEndpointHTTP(epConf)
106+
} else if url, ok := clientOrURL.(string); ok {
107+
duration := benchmarkEndpointHTTP(httpClient, url, *sgpToken)
154108
if duration > 0 {
155109
durations = append(durations, duration)
156110
}
@@ -200,6 +154,26 @@ Examples:
200154
})
201155
}
202156

157+
type webSocketClient struct {
158+
conn *websocket.Conn
159+
URL string
160+
}
161+
162+
func (c *webSocketClient) reconnect() error {
163+
if c.conn != nil {
164+
c.conn.Close() // don't leak connections
165+
}
166+
fmt.Println("Connecting to WebSocket..", c.URL)
167+
var err error
168+
c.conn, _, err = websocket.DefaultDialer.Dial(c.URL, nil)
169+
if err != nil {
170+
c.conn = nil // retry again later
171+
return fmt.Errorf("WebSocket dial(%s): %v", c.URL, err)
172+
}
173+
fmt.Println("Connected!")
174+
return nil
175+
}
176+
203177
type endpointResult struct {
204178
name string
205179
avg time.Duration
@@ -212,11 +186,18 @@ type endpointResult struct {
212186
successful int
213187
}
214188

215-
func benchmarkEndpointHTTP(epConfig *httpEndpointConfig) time.Duration {
189+
func benchmarkEndpointHTTP(client *http.Client, url, accessToken string) time.Duration {
216190
start := time.Now()
217-
resp, err := epConfig.client.Post(epConfig.url, "application/json", strings.NewReader("ping"))
191+
req, err := http.NewRequest("POST", url, strings.NewReader("ping"))
192+
if err != nil {
193+
fmt.Printf("Error creating request: %v\n", err)
194+
return 0
195+
}
196+
req.Header.Set("Content-Type", "application/json")
197+
req.Header.Set("Authorization", "token "+accessToken)
198+
resp, err := client.Do(req)
218199
if err != nil {
219-
fmt.Printf("Error calling %s: %v\n", epConfig.url, err)
200+
fmt.Printf("Error calling %s: %v\n", url, err)
220201
return 0
221202
}
222203
defer func() {
@@ -242,20 +223,39 @@ func benchmarkEndpointHTTP(epConfig *httpEndpointConfig) time.Duration {
242223
return time.Since(start)
243224
}
244225

245-
func benchmarkEndpointWebSocket(conn *websocket.Conn) time.Duration {
226+
func benchmarkEndpointWebSocket(client *webSocketClient) time.Duration {
227+
// Perform initial websocket connection, if needed.
228+
if client.conn == nil {
229+
if err := client.reconnect(); err != nil {
230+
fmt.Printf("Error reconnecting: %v\n", err)
231+
return 0
232+
}
233+
}
234+
235+
// Perform the benchmarked request using the websocket.
246236
start := time.Now()
247-
err := conn.WriteMessage(websocket.TextMessage, []byte("ping"))
237+
err := client.conn.WriteMessage(websocket.TextMessage, []byte("ping"))
248238
if err != nil {
249239
fmt.Printf("WebSocket write error: %v\n", err)
240+
if err := client.reconnect(); err != nil {
241+
fmt.Printf("Error reconnecting: %v\n", err)
242+
}
250243
return 0
251244
}
252-
_, message, err := conn.ReadMessage()
245+
_, message, err := client.conn.ReadMessage()
246+
253247
if err != nil {
254248
fmt.Printf("WebSocket read error: %v\n", err)
249+
if err := client.reconnect(); err != nil {
250+
fmt.Printf("Error reconnecting: %v\n", err)
251+
}
255252
return 0
256253
}
257254
if string(message) != "pong" {
258255
fmt.Printf("Expected 'pong' response, got: %q\n", string(message))
256+
if err := client.reconnect(); err != nil {
257+
fmt.Printf("Error reconnecting: %v\n", err)
258+
}
259259
return 0
260260
}
261261
return time.Since(start)

0 commit comments

Comments
 (0)