diff --git a/client.go b/client.go index ad0aded..230fbd3 100644 --- a/client.go +++ b/client.go @@ -1,6 +1,8 @@ package nordigen import ( + "context" + "errors" "fmt" "net/http" "strings" @@ -12,13 +14,15 @@ const baseUrl = "bankaccountdata.gocardless.com" const apiPath = "/api/v2" type Client struct { - c *http.Client - secretId string - secretKey string - expiration time.Time - token *Token - m *sync.Mutex - stopChan chan struct{} + c *http.Client + secretId string + secretKey string + + token *Token + nextRefresh time.Time + + m *sync.RWMutex + stopChan chan struct{} } type Transport struct { @@ -26,53 +30,68 @@ type Transport struct { cli *Client } -func (c *Client) refreshTokenIfNeeded() error { +// refreshTokenIfNeeded refreshes the token if refresh time has passed +func (c *Client) refreshTokenIfNeeded(ctx context.Context) error { c.m.Lock() defer c.m.Unlock() - if time.Now().Add(time.Minute).Before(c.expiration) { - return nil - } else { - // Refresh the token if its expiration is less than a minute away - newToken, err := c.refreshToken(c.token.Refresh) - if err != nil { - return err - } - c.token = newToken - c.expiration = time.Now().Add(time.Duration(newToken.RefreshExpires-60) * time.Second) + if time.Now().Before(c.nextRefresh) { return nil } + + newToken, err := c.refreshToken(ctx, c.token.Refresh) + if err != nil { + return err + } + c.updateToken(newToken) + return nil } -func (c *Client) StartTokenHandler() { - c.stopChan = make(chan struct{}) +// updateToken updates the client token and sets the refresh time to half the +// access token lifetime. +func (c *Client) updateToken(t *Token) { + c.token = t + c.nextRefresh = time.Now().Add(time.Duration(t.AccessExpires/2) * time.Second) +} - // Initialize the first token and start the token handler - token, err := c.newToken() +// StartTokenHandler handles token refreshes in the background +func (c *Client) StartTokenHandler(ctx context.Context) error { + // Initialize the first token + token, err := c.newToken(ctx) if err != nil { - panic("Failed to get initial token: " + err.Error()) + return errors.New("failed to get initial token: " + err.Error()) } - c.token = token - - go func() { - for { - timeToWait := time.Until(c.expiration) - time.Minute - if timeToWait < 0 { - // If the token is already expired, try to refresh immediately - timeToWait = 0 - } - select { - case <-c.stopChan: - return - case <-time.After(timeToWait): - if err := c.refreshTokenIfNeeded(); err != nil { - // TODO(Martin): add retry logic - panic("Failed to refresh token: " + err.Error()) - } + c.m.Lock() + c.updateToken(token) + c.m.Unlock() + + go c.tokenRefreshLoop(ctx) + return nil +} + +func (c *Client) tokenRefreshLoop(ctx context.Context) { + for { + c.m.RLock() + refreshTime := c.nextRefresh + c.m.RUnlock() + + timeToWait := time.Until(refreshTime) + if timeToWait < 0 { + timeToWait = 0 + } + + select { + case <-c.stopChan: + return + case <-time.After(timeToWait): + if err := c.refreshTokenIfNeeded(ctx); err != nil { + panic(fmt.Sprintf("failed to refresh token: %s", err)) } + case <-ctx.Done(): + return } - }() + } } func (c *Client) StopTokenHandler() { @@ -98,17 +117,22 @@ func (t Transport) RoundTrip(req *http.Request) (*http.Response, error) { // NewClient creates a new Nordigen client that handles token refreshes and adds // the necessary headers, host, and path to all requests. func NewClient(secretId, secretKey string) (*Client, error) { - c := &Client{c: &http.Client{Timeout: 60 * time.Second}, m: &sync.Mutex{}, + c := &Client{ + c: &http.Client{Timeout: 60 * time.Second}, secretId: secretId, secretKey: secretKey, + + m: &sync.RWMutex{}, + stopChan: make(chan struct{}), } // Add transport to handle headers, host and path for all requests c.c.Transport = Transport{rt: http.DefaultTransport, cli: c} // Start token handler - c.StartTokenHandler() - defer c.StopTokenHandler() + if err := c.StartTokenHandler(context.Background()); err != nil { + return nil, err + } return c, nil } diff --git a/client_test.go b/client_test.go index e452181..ca27667 100644 --- a/client_test.go +++ b/client_test.go @@ -22,7 +22,7 @@ func TestClientTokenRefresh(t *testing.T) { t.Fatalf("NewClient: %s", err) } - c.expiration = time.Now().Add(-time.Hour) + c.nextRefresh = time.Now().Add(-time.Hour) _, err = c.ListRequisitions() if err != nil { t.Fatalf("ListRequisitions: %s", err) diff --git a/token.go b/token.go index 77fb6c2..c3b5318 100644 --- a/token.go +++ b/token.go @@ -2,6 +2,7 @@ package nordigen import ( "bytes" + "context" "encoding/json" "io" "net/http" @@ -16,6 +17,10 @@ type Token struct { RefreshExpires int `json:"refresh_expires"` } +type TokenRefresh struct { + Refresh string `json:"refresh"` +} + type Secret struct { SecretId string `json:"secret_id"` AccessId string `json:"secret_key"` @@ -23,16 +28,9 @@ type Secret struct { const tokenPath = "token" const tokenNewPath = "new/" -const tokenRefreshPath = "refresh" - -func (c Client) newToken() (*Token, error) { - req := http.Request{ - Method: http.MethodPost, - URL: &url.URL{ - Path: strings.Join([]string{tokenPath, tokenNewPath}, "/"), - }, - } +const tokenRefreshPath = "refresh/" +func (c Client) newToken(ctx context.Context) (*Token, error) { data, err := json.Marshal(Secret{ SecretId: c.secretId, AccessId: c.secretKey, @@ -40,63 +38,69 @@ func (c Client) newToken() (*Token, error) { if err != nil { return nil, err } - req.Body = io.NopCloser(bytes.NewBuffer(data)) - resp, err := c.c.Do(&req) - if err != nil { - return nil, err + req := &http.Request{ + Method: http.MethodPost, + Body: io.NopCloser(bytes.NewBuffer(data)), + URL: &url.URL{ + Path: strings.Join([]string{tokenPath, tokenNewPath}, "/"), + }, } - body, err := io.ReadAll(resp.Body) + req = req.WithContext(ctx) + resp, err := c.c.Do(req) if err != nil { return nil, err } + defer resp.Body.Close() + + body, readErr := io.ReadAll(resp.Body) + if readErr != nil { + return nil, readErr + } if resp.StatusCode != http.StatusOK { - return nil, &APIError{resp.StatusCode, string(body), err} + return nil, &APIError{StatusCode: resp.StatusCode, Body: string(body)} } - t := &Token{} - err = json.Unmarshal(body, &t) - if err != nil { + t := &Token{} + if err := json.Unmarshal(body, t); err != nil { return nil, err } - return t, nil } -func (c Client) refreshToken(refresh string) (*Token, error) { - req := http.Request{ +func (c Client) refreshToken(ctx context.Context, refresh string) (*Token, error) { + data, err := json.Marshal(TokenRefresh{Refresh: refresh}) + if err != nil { + return nil, err + } + + req := &http.Request{ Method: http.MethodPost, + Body: io.NopCloser(bytes.NewBuffer(data)), URL: &url.URL{ Path: strings.Join([]string{tokenPath, tokenRefreshPath}, "/"), }, } - data, err := json.Marshal(refresh) - - if err != nil { - return &Token{}, err - } - req.Body = io.NopCloser(bytes.NewBuffer(data)) - - resp, err := c.c.Do(&req) + req = req.WithContext(ctx) + resp, err := c.c.Do(req) if err != nil { return nil, err } - body, err := io.ReadAll(resp.Body) + defer resp.Body.Close() - if err != nil { - return nil, err + body, readErr := io.ReadAll(resp.Body) + if readErr != nil { + return nil, readErr } if resp.StatusCode != http.StatusOK { - return nil, &APIError{resp.StatusCode, string(body), err} + return nil, &APIError{StatusCode: resp.StatusCode, Body: string(body)} } - t := &Token{} - err = json.Unmarshal(body, &t) - if err != nil { + t := &Token{} + if err := json.Unmarshal(body, t); err != nil { return nil, err } - return t, nil }