From 0aa9f5fe0b3ba58537488b8b232ba4a0f83bce38 Mon Sep 17 00:00:00 2001 From: Martin Hansen Date: Tue, 10 Dec 2024 11:41:24 +0100 Subject: [PATCH] fix: token is invalid or expired The refresh goroutine was stopped immediately after starting, fix that logic error to refresh the token in the background after half the access token lifetime. Fixes martinohansen/ynabber#97 --- client.go | 112 ++++++++++++++++++++++++++++++------------------- client_test.go | 2 +- token.go | 78 ++++++++++++++++++---------------- 3 files changed, 110 insertions(+), 82 deletions(-) diff --git a/client.go b/client.go index ad0aded..230fbd3 100644 --- a/client.go +++ b/client.go @@ -1,6 +1,8 @@ package nordigen import ( + "context" + "errors" "fmt" "net/http" "strings" @@ -12,13 +14,15 @@ const baseUrl = "bankaccountdata.gocardless.com" const apiPath = "/api/v2" type Client struct { - c *http.Client - secretId string - secretKey string - expiration time.Time - token *Token - m *sync.Mutex - stopChan chan struct{} + c *http.Client + secretId string + secretKey string + + token *Token + nextRefresh time.Time + + m *sync.RWMutex + stopChan chan struct{} } type Transport struct { @@ -26,53 +30,68 @@ type Transport struct { cli *Client } -func (c *Client) refreshTokenIfNeeded() error { +// refreshTokenIfNeeded refreshes the token if refresh time has passed +func (c *Client) refreshTokenIfNeeded(ctx context.Context) error { c.m.Lock() defer c.m.Unlock() - if time.Now().Add(time.Minute).Before(c.expiration) { - return nil - } else { - // Refresh the token if its expiration is less than a minute away - newToken, err := c.refreshToken(c.token.Refresh) - if err != nil { - return err - } - c.token = newToken - c.expiration = time.Now().Add(time.Duration(newToken.RefreshExpires-60) * time.Second) + if time.Now().Before(c.nextRefresh) { return nil } + + newToken, err := c.refreshToken(ctx, c.token.Refresh) + if err != nil { + return err + } + c.updateToken(newToken) + return nil } -func (c *Client) StartTokenHandler() { - c.stopChan = make(chan struct{}) +// updateToken updates the client token and sets the refresh time to half the +// access token lifetime. +func (c *Client) updateToken(t *Token) { + c.token = t + c.nextRefresh = time.Now().Add(time.Duration(t.AccessExpires/2) * time.Second) +} - // Initialize the first token and start the token handler - token, err := c.newToken() +// StartTokenHandler handles token refreshes in the background +func (c *Client) StartTokenHandler(ctx context.Context) error { + // Initialize the first token + token, err := c.newToken(ctx) if err != nil { - panic("Failed to get initial token: " + err.Error()) + return errors.New("failed to get initial token: " + err.Error()) } - c.token = token - - go func() { - for { - timeToWait := time.Until(c.expiration) - time.Minute - if timeToWait < 0 { - // If the token is already expired, try to refresh immediately - timeToWait = 0 - } - select { - case <-c.stopChan: - return - case <-time.After(timeToWait): - if err := c.refreshTokenIfNeeded(); err != nil { - // TODO(Martin): add retry logic - panic("Failed to refresh token: " + err.Error()) - } + c.m.Lock() + c.updateToken(token) + c.m.Unlock() + + go c.tokenRefreshLoop(ctx) + return nil +} + +func (c *Client) tokenRefreshLoop(ctx context.Context) { + for { + c.m.RLock() + refreshTime := c.nextRefresh + c.m.RUnlock() + + timeToWait := time.Until(refreshTime) + if timeToWait < 0 { + timeToWait = 0 + } + + select { + case <-c.stopChan: + return + case <-time.After(timeToWait): + if err := c.refreshTokenIfNeeded(ctx); err != nil { + panic(fmt.Sprintf("failed to refresh token: %s", err)) } + case <-ctx.Done(): + return } - }() + } } func (c *Client) StopTokenHandler() { @@ -98,17 +117,22 @@ func (t Transport) RoundTrip(req *http.Request) (*http.Response, error) { // NewClient creates a new Nordigen client that handles token refreshes and adds // the necessary headers, host, and path to all requests. func NewClient(secretId, secretKey string) (*Client, error) { - c := &Client{c: &http.Client{Timeout: 60 * time.Second}, m: &sync.Mutex{}, + c := &Client{ + c: &http.Client{Timeout: 60 * time.Second}, secretId: secretId, secretKey: secretKey, + + m: &sync.RWMutex{}, + stopChan: make(chan struct{}), } // Add transport to handle headers, host and path for all requests c.c.Transport = Transport{rt: http.DefaultTransport, cli: c} // Start token handler - c.StartTokenHandler() - defer c.StopTokenHandler() + if err := c.StartTokenHandler(context.Background()); err != nil { + return nil, err + } return c, nil } diff --git a/client_test.go b/client_test.go index e452181..ca27667 100644 --- a/client_test.go +++ b/client_test.go @@ -22,7 +22,7 @@ func TestClientTokenRefresh(t *testing.T) { t.Fatalf("NewClient: %s", err) } - c.expiration = time.Now().Add(-time.Hour) + c.nextRefresh = time.Now().Add(-time.Hour) _, err = c.ListRequisitions() if err != nil { t.Fatalf("ListRequisitions: %s", err) diff --git a/token.go b/token.go index 77fb6c2..c3b5318 100644 --- a/token.go +++ b/token.go @@ -2,6 +2,7 @@ package nordigen import ( "bytes" + "context" "encoding/json" "io" "net/http" @@ -16,6 +17,10 @@ type Token struct { RefreshExpires int `json:"refresh_expires"` } +type TokenRefresh struct { + Refresh string `json:"refresh"` +} + type Secret struct { SecretId string `json:"secret_id"` AccessId string `json:"secret_key"` @@ -23,16 +28,9 @@ type Secret struct { const tokenPath = "token" const tokenNewPath = "new/" -const tokenRefreshPath = "refresh" - -func (c Client) newToken() (*Token, error) { - req := http.Request{ - Method: http.MethodPost, - URL: &url.URL{ - Path: strings.Join([]string{tokenPath, tokenNewPath}, "/"), - }, - } +const tokenRefreshPath = "refresh/" +func (c Client) newToken(ctx context.Context) (*Token, error) { data, err := json.Marshal(Secret{ SecretId: c.secretId, AccessId: c.secretKey, @@ -40,63 +38,69 @@ func (c Client) newToken() (*Token, error) { if err != nil { return nil, err } - req.Body = io.NopCloser(bytes.NewBuffer(data)) - resp, err := c.c.Do(&req) - if err != nil { - return nil, err + req := &http.Request{ + Method: http.MethodPost, + Body: io.NopCloser(bytes.NewBuffer(data)), + URL: &url.URL{ + Path: strings.Join([]string{tokenPath, tokenNewPath}, "/"), + }, } - body, err := io.ReadAll(resp.Body) + req = req.WithContext(ctx) + resp, err := c.c.Do(req) if err != nil { return nil, err } + defer resp.Body.Close() + + body, readErr := io.ReadAll(resp.Body) + if readErr != nil { + return nil, readErr + } if resp.StatusCode != http.StatusOK { - return nil, &APIError{resp.StatusCode, string(body), err} + return nil, &APIError{StatusCode: resp.StatusCode, Body: string(body)} } - t := &Token{} - err = json.Unmarshal(body, &t) - if err != nil { + t := &Token{} + if err := json.Unmarshal(body, t); err != nil { return nil, err } - return t, nil } -func (c Client) refreshToken(refresh string) (*Token, error) { - req := http.Request{ +func (c Client) refreshToken(ctx context.Context, refresh string) (*Token, error) { + data, err := json.Marshal(TokenRefresh{Refresh: refresh}) + if err != nil { + return nil, err + } + + req := &http.Request{ Method: http.MethodPost, + Body: io.NopCloser(bytes.NewBuffer(data)), URL: &url.URL{ Path: strings.Join([]string{tokenPath, tokenRefreshPath}, "/"), }, } - data, err := json.Marshal(refresh) - - if err != nil { - return &Token{}, err - } - req.Body = io.NopCloser(bytes.NewBuffer(data)) - - resp, err := c.c.Do(&req) + req = req.WithContext(ctx) + resp, err := c.c.Do(req) if err != nil { return nil, err } - body, err := io.ReadAll(resp.Body) + defer resp.Body.Close() - if err != nil { - return nil, err + body, readErr := io.ReadAll(resp.Body) + if readErr != nil { + return nil, readErr } if resp.StatusCode != http.StatusOK { - return nil, &APIError{resp.StatusCode, string(body), err} + return nil, &APIError{StatusCode: resp.StatusCode, Body: string(body)} } - t := &Token{} - err = json.Unmarshal(body, &t) - if err != nil { + t := &Token{} + if err := json.Unmarshal(body, t); err != nil { return nil, err } - return t, nil }