Skip to content

Commit a023543

Browse files
Refresh expired websocket sessions (stripe#468)
1 parent fca10fa commit a023543

File tree

14 files changed

+289
-152
lines changed

14 files changed

+289
-152
lines changed

pkg/ansi/ansi.go

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -105,24 +105,30 @@ func Linkify(text, url string, w io.Writer) string {
105105
return fmt.Sprintf("\x1b]8;;%s\x1b\\%s\x1b]8;;\x1b\\", url, text)
106106
}
107107

108-
// StartSpinner starts a spinner with the given message. If the writer doesn't
109-
// support colors, it simply prints the message.
110-
func StartSpinner(msg string, w io.Writer) *spinner.Spinner {
111-
if !shouldUseColors(w) {
112-
fmt.Fprintln(w, msg)
113-
return nil
114-
}
108+
type charset = []string
115109

110+
func getCharset() charset {
116111
// See https://github.com/briandowns/spinner#available-character-sets for
117112
// list of available charsets
118-
charSetIdx := 11
119113
if runtime.GOOS == "windows" {
120114
// Less fancy, but uses ASCII characters so works with Windows default
121115
// console.
122-
charSetIdx = 8
116+
return spinner.CharSets[8]
117+
}
118+
return spinner.CharSets[11]
119+
}
120+
121+
const duration = time.Duration(100) * time.Millisecond
122+
123+
// StartNewSpinner starts a new spinner with the given message. If the writer doesn't
124+
// support colors, it simply prints the message.
125+
func StartNewSpinner(msg string, w io.Writer) *spinner.Spinner {
126+
if !shouldUseColors(w) {
127+
fmt.Fprintln(w, msg)
128+
return nil
123129
}
124130

125-
s := spinner.New(spinner.CharSets[charSetIdx], time.Duration(100)*time.Millisecond)
131+
s := spinner.New(getCharset(), duration)
126132
s.Writer = w
127133

128134
if msg != "" {
@@ -134,6 +140,20 @@ func StartSpinner(msg string, w io.Writer) *spinner.Spinner {
134140
return s
135141
}
136142

143+
// StartSpinner updates an existing spinner's message, and starts it if it was stopped
144+
func StartSpinner(s *spinner.Spinner, msg string, w io.Writer) {
145+
if s == nil {
146+
fmt.Fprintln(w, msg)
147+
return
148+
}
149+
if msg != "" {
150+
s.Suffix = " " + msg
151+
}
152+
if !s.Active() {
153+
s.Start()
154+
}
155+
}
156+
137157
// StopSpinner stops a spinner with the given message. If the writer doesn't
138158
// support colors, it simply prints the message.
139159
func StopSpinner(s *spinner.Spinner, msg string, w io.Writer) {

pkg/cmd/listen.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package cmd
22

33
import (
4+
"context"
45
"errors"
56
"fmt"
67
"net/url"
@@ -163,7 +164,7 @@ func (lc *listenCmd) runListenCmd(cmd *cobra.Command, args []string) error {
163164
NoWSS: lc.noWSS,
164165
}, lc.events)
165166

166-
err = p.Run()
167+
err = p.Run(context.Background())
167168
if err != nil {
168169
return err
169170
}

pkg/cmd/logs/tail.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import (
66
log "github.com/sirupsen/logrus"
77
"github.com/spf13/cobra"
88

9+
"context"
10+
911
"github.com/stripe/stripe-cli/pkg/config"
1012
logTailing "github.com/stripe/stripe-cli/pkg/logtailing"
1113
"github.com/stripe/stripe-cli/pkg/validators"
@@ -158,7 +160,7 @@ func (tailCmd *TailCmd) runTailCmd(cmd *cobra.Command, args []string) error {
158160
WebSocketFeature: requestLogsWebSocketFeature,
159161
})
160162

161-
err = tailer.Run()
163+
err = tailer.Run(context.Background())
162164
if err != nil {
163165
return err
164166
}

pkg/cmd/samples/create.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ To see supported samples, run 'stripe samples list'`, args[0])
8484
return fmt.Errorf("Path already exists for: %s", destination)
8585
}
8686

87-
spinner := ansi.StartSpinner(fmt.Sprintf("Downloading %s", selectedSample), os.Stdout)
87+
spinner := ansi.StartNewSpinner(fmt.Sprintf("Downloading %s", selectedSample), os.Stdout)
8888

8989
if cc.forceRefresh {
9090
err := sample.DeleteCache(selectedSample)
@@ -145,7 +145,7 @@ To see supported samples, run 'stripe samples list'`, args[0])
145145
os.Exit(1)
146146
}()
147147

148-
spinner = ansi.StartSpinner(fmt.Sprintf("Copying files over... %s", selectedSample), os.Stdout)
148+
spinner = ansi.StartNewSpinner(fmt.Sprintf("Copying files over... %s", selectedSample), os.Stdout)
149149
// Create the target folder to copy the sample in to. We do
150150
// this here in case any of the steps above fail, minimizing
151151
// the change that we create a dangling empty folder
@@ -164,7 +164,7 @@ To see supported samples, run 'stripe samples list'`, args[0])
164164
ansi.StopSpinner(spinner, "", os.Stdout)
165165
fmt.Printf("%s %s\n", color.Green("✔"), ansi.Faint("Files copied"))
166166

167-
spinner = ansi.StartSpinner(fmt.Sprintf("Configuring your code... %s", selectedSample), os.Stdout)
167+
spinner = ansi.StartNewSpinner(fmt.Sprintf("Configuring your code... %s", selectedSample), os.Stdout)
168168

169169
err = sample.ConfigureDotEnv(targetPath)
170170
if err != nil {

pkg/cmd/status.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ func (sc *statusCmd) runStatusCmd(cmd *cobra.Command, args []string) error {
7777
if sc.hideSpinner {
7878
time.Sleep(time.Duration(sc.pollRate) * time.Second)
7979
} else {
80-
spinner := ansi.StartSpinner("", os.Stderr)
80+
spinner := ansi.StartNewSpinner("", os.Stderr)
8181
time.Sleep(time.Duration(sc.pollRate) * time.Second)
8282
ansi.StopSpinner(spinner, "", os.Stderr)
8383
}

pkg/login/client_login.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,18 +54,18 @@ func Login(baseURL string, config *config.Config, input io.Reader) error {
5454
if isSSH() {
5555
fmt.Printf("To authenticate with Stripe, please go to: %s\n", links.BrowserURL)
5656

57-
s = ansi.StartSpinner("Waiting for confirmation...", os.Stdout)
57+
s = ansi.StartNewSpinner("Waiting for confirmation...", os.Stdout)
5858
} else {
5959
fmt.Printf("Press Enter to open the browser (^C to quit)")
6060
fmt.Fscanln(input)
6161

62-
s = ansi.StartSpinner("Waiting for confirmation...", os.Stdout)
62+
s = ansi.StartNewSpinner("Waiting for confirmation...", os.Stdout)
6363

6464
err = openBrowser(links.BrowserURL)
6565
if err != nil {
6666
msg := fmt.Sprintf("Failed to open browser, please go to %s manually.", links.BrowserURL)
6767
ansi.StopSpinner(s, msg, os.Stdout)
68-
s = ansi.StartSpinner("Waiting for confirmation...", os.Stdout)
68+
s = ansi.StartNewSpinner("Waiting for confirmation...", os.Stdout)
6969
}
7070
}
7171

pkg/logtailing/tailer.go

Lines changed: 70 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,6 @@ type Tailer struct {
6666
webSocketClient *websocket.Client
6767

6868
interruptCh chan os.Signal
69-
70-
ctx context.Context
7169
}
7270

7371
// EventPayload is the mapping for fields in event payloads from request log tailing
@@ -107,60 +105,84 @@ func New(cfg *Config) *Tailer {
107105
}
108106
}
109107

110-
// Run sets the websocket connection
111-
func (t *Tailer) Run() error {
112-
s := ansi.StartSpinner("Getting ready...", t.cfg.Log.Out)
113-
108+
func withSIGTERMCancel(ctx context.Context, onCancel func()) context.Context {
114109
// Create a context that will be canceled when Ctrl+C is pressed
115-
ctx, cancel := context.WithCancel(context.Background())
116-
t.ctx = ctx
117-
signal.Notify(t.interruptCh, os.Interrupt, syscall.SIGTERM)
110+
ctx, cancel := context.WithCancel(ctx)
111+
112+
interruptCh := make(chan os.Signal, 1)
113+
signal.Notify(interruptCh, os.Interrupt, syscall.SIGTERM)
118114

119115
go func() {
116+
<-interruptCh
117+
onCancel()
118+
cancel()
119+
}()
120+
return ctx
121+
}
122+
123+
const maxConnectAttempts = 3
124+
125+
// Run sets the websocket connection
126+
func (t *Tailer) Run(ctx context.Context) error {
127+
s := ansi.StartNewSpinner("Getting ready...", t.cfg.Log.Out)
128+
129+
ctx = withSIGTERMCancel(ctx, func() {
120130
log.WithFields(log.Fields{
121131
"prefix": "logtailing.Tailer.Run",
122132
}).Debug("Ctrl+C received, cleaning up...")
133+
})
123134

124-
<-t.interruptCh
125-
cancel()
126-
}()
135+
var warned = false
136+
var nAttempts int = 0
127137

128-
// Create the CLI session
129-
session, err := t.createSession()
130-
if err != nil {
131-
ansi.StopSpinner(s, "", t.cfg.Log.Out)
132-
t.cfg.Log.Fatalf("Error while authenticating with Stripe: %v", err)
133-
}
138+
for nAttempts < maxConnectAttempts {
139+
session, err := t.createSession(ctx)
134140

135-
// Create and start the websocket client
136-
t.webSocketClient = websocket.NewClient(
137-
session.WebSocketURL,
138-
session.WebSocketID,
139-
session.WebSocketAuthorizedFeature,
140-
&websocket.Config{
141-
EventHandler: websocket.EventHandlerFunc(t.processRequestLogEvent),
142-
Log: t.cfg.Log,
143-
NoWSS: t.cfg.NoWSS,
144-
ReconnectInterval: time.Duration(session.ReconnectDelay) * time.Second,
145-
},
146-
)
147-
go t.webSocketClient.Run()
148-
149-
select {
150-
case <-t.webSocketClient.Connected():
151-
ansi.StopSpinner(s, "Ready! You're now waiting to receive API request logs (^C to quit)", t.cfg.Log.Out)
152-
case <-t.ctx.Done():
153-
ansi.StopSpinner(s, "", t.cfg.Log.Out)
154-
t.cfg.Log.Fatalf("Aborting")
155-
}
141+
if err != nil {
142+
ansi.StopSpinner(s, "", t.cfg.Log.Out)
143+
t.cfg.Log.Fatalf("Error while authenticating with Stripe: %v", err)
144+
}
156145

157-
if session.DisplayConnectFilterWarning {
158-
color := ansi.Color(os.Stdout)
159-
fmt.Printf("%s you specified the 'account' filter for Connect accounts but are not a Connect user, so the filter will not be applied.\n", color.Yellow("Warning"))
160-
}
146+
if session.DisplayConnectFilterWarning && !warned {
147+
color := ansi.Color(os.Stdout)
148+
fmt.Printf("%s you specified the 'account' filter for Connect accounts but are not a Connect user, so the filter will not be applied.\n", color.Yellow("Warning"))
149+
// Only display this warning once
150+
warned = true
151+
}
161152

162-
// Block until context is done (i.e. Ctrl+C is pressed)
163-
<-t.ctx.Done()
153+
t.webSocketClient = websocket.NewClient(
154+
session.WebSocketURL,
155+
session.WebSocketID,
156+
session.WebSocketAuthorizedFeature,
157+
&websocket.Config{
158+
EventHandler: websocket.EventHandlerFunc(t.processRequestLogEvent),
159+
Log: t.cfg.Log,
160+
NoWSS: t.cfg.NoWSS,
161+
ReconnectInterval: time.Duration(session.ReconnectDelay) * time.Second,
162+
},
163+
)
164+
165+
go func() {
166+
<-t.webSocketClient.Connected()
167+
nAttempts = 0
168+
ansi.StopSpinner(s, "Ready! You're now waiting to receive API request logs (^C to quit)", t.cfg.Log.Out)
169+
}()
170+
171+
go t.webSocketClient.Run(ctx)
172+
nAttempts++
173+
174+
select {
175+
case <-ctx.Done():
176+
ansi.StopSpinner(s, "", t.cfg.Log.Out)
177+
t.cfg.Log.Fatalf("Aborting")
178+
case <-t.webSocketClient.NotifyExpired:
179+
if nAttempts < maxConnectAttempts {
180+
ansi.StartSpinner(s, "Session expired, reconnecting...", t.cfg.Log.Out)
181+
} else {
182+
t.cfg.Log.Fatalf("Session expired. Terminating after %d failed attempts to reauthorize", nAttempts)
183+
}
184+
}
185+
}
164186

165187
if t.webSocketClient != nil {
166188
t.webSocketClient.Stop()
@@ -173,7 +195,7 @@ func (t *Tailer) Run() error {
173195
return nil
174196
}
175197

176-
func (t *Tailer) createSession() (*stripeauth.StripeCLISession, error) {
198+
func (t *Tailer) createSession(ctx context.Context) (*stripeauth.StripeCLISession, error) {
177199
var session *stripeauth.StripeCLISession
178200

179201
var err error
@@ -189,15 +211,15 @@ func (t *Tailer) createSession() (*stripeauth.StripeCLISession, error) {
189211
// Try to authorize at least 5 times before failing. Sometimes we have random
190212
// transient errors that we just need to retry for.
191213
for i := 0; i <= 5; i++ {
192-
session, err = t.stripeAuthClient.Authorize(t.ctx, t.cfg.DeviceName, t.cfg.WebSocketFeature, &filters)
214+
session, err = t.stripeAuthClient.Authorize(ctx, t.cfg.DeviceName, t.cfg.WebSocketFeature, &filters)
193215

194216
if err == nil {
195217
exitCh <- struct{}{}
196218
return
197219
}
198220

199221
select {
200-
case <-t.ctx.Done():
222+
case <-ctx.Done():
201223
exitCh <- struct{}{}
202224
return
203225
case <-time.After(1 * time.Second):

0 commit comments

Comments
 (0)