Skip to content

Commit

Permalink
Merge pull request #26 from martinohansen/main
Browse files Browse the repository at this point in the history
fix: token is invalid or expired
  • Loading branch information
frieser authored Dec 13, 2024
2 parents b9f2fcc + 0aa9f5f commit 3683fd6
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 82 deletions.
112 changes: 68 additions & 44 deletions client.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package nordigen

import (
"context"
"errors"
"fmt"
"net/http"
"strings"
Expand All @@ -12,67 +14,84 @@ const baseUrl = "bankaccountdata.gocardless.com"
const apiPath = "/api/v2"

type Client struct {
c *http.Client
secretId string
secretKey string
expiration time.Time
token *Token
m *sync.Mutex
stopChan chan struct{}
c *http.Client
secretId string
secretKey string

token *Token
nextRefresh time.Time

m *sync.RWMutex
stopChan chan struct{}
}

type Transport struct {
rt http.RoundTripper
cli *Client
}

func (c *Client) refreshTokenIfNeeded() error {
// refreshTokenIfNeeded refreshes the token if refresh time has passed
func (c *Client) refreshTokenIfNeeded(ctx context.Context) error {
c.m.Lock()
defer c.m.Unlock()

if time.Now().Add(time.Minute).Before(c.expiration) {
return nil
} else {
// Refresh the token if its expiration is less than a minute away
newToken, err := c.refreshToken(c.token.Refresh)
if err != nil {
return err
}
c.token = newToken
c.expiration = time.Now().Add(time.Duration(newToken.RefreshExpires-60) * time.Second)
if time.Now().Before(c.nextRefresh) {
return nil
}

newToken, err := c.refreshToken(ctx, c.token.Refresh)
if err != nil {
return err
}
c.updateToken(newToken)
return nil
}

func (c *Client) StartTokenHandler() {
c.stopChan = make(chan struct{})
// updateToken updates the client token and sets the refresh time to half the
// access token lifetime.
func (c *Client) updateToken(t *Token) {
c.token = t
c.nextRefresh = time.Now().Add(time.Duration(t.AccessExpires/2) * time.Second)
}

// Initialize the first token and start the token handler
token, err := c.newToken()
// StartTokenHandler handles token refreshes in the background
func (c *Client) StartTokenHandler(ctx context.Context) error {
// Initialize the first token
token, err := c.newToken(ctx)
if err != nil {
panic("Failed to get initial token: " + err.Error())
return errors.New("failed to get initial token: " + err.Error())
}
c.token = token

go func() {
for {
timeToWait := time.Until(c.expiration) - time.Minute
if timeToWait < 0 {
// If the token is already expired, try to refresh immediately
timeToWait = 0
}

select {
case <-c.stopChan:
return
case <-time.After(timeToWait):
if err := c.refreshTokenIfNeeded(); err != nil {
// TODO(Martin): add retry logic
panic("Failed to refresh token: " + err.Error())
}
c.m.Lock()
c.updateToken(token)
c.m.Unlock()

go c.tokenRefreshLoop(ctx)
return nil
}

func (c *Client) tokenRefreshLoop(ctx context.Context) {
for {
c.m.RLock()
refreshTime := c.nextRefresh
c.m.RUnlock()

timeToWait := time.Until(refreshTime)
if timeToWait < 0 {
timeToWait = 0
}

select {
case <-c.stopChan:
return
case <-time.After(timeToWait):
if err := c.refreshTokenIfNeeded(ctx); err != nil {
panic(fmt.Sprintf("failed to refresh token: %s", err))
}
case <-ctx.Done():
return
}
}()
}
}

func (c *Client) StopTokenHandler() {
Expand All @@ -98,17 +117,22 @@ func (t Transport) RoundTrip(req *http.Request) (*http.Response, error) {
// NewClient creates a new Nordigen client that handles token refreshes and adds
// the necessary headers, host, and path to all requests.
func NewClient(secretId, secretKey string) (*Client, error) {
c := &Client{c: &http.Client{Timeout: 60 * time.Second}, m: &sync.Mutex{},
c := &Client{
c: &http.Client{Timeout: 60 * time.Second},
secretId: secretId,
secretKey: secretKey,

m: &sync.RWMutex{},
stopChan: make(chan struct{}),
}

// Add transport to handle headers, host and path for all requests
c.c.Transport = Transport{rt: http.DefaultTransport, cli: c}

// Start token handler
c.StartTokenHandler()
defer c.StopTokenHandler()
if err := c.StartTokenHandler(context.Background()); err != nil {
return nil, err
}

return c, nil
}
2 changes: 1 addition & 1 deletion client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func TestClientTokenRefresh(t *testing.T) {
t.Fatalf("NewClient: %s", err)
}

c.expiration = time.Now().Add(-time.Hour)
c.nextRefresh = time.Now().Add(-time.Hour)
_, err = c.ListRequisitions()
if err != nil {
t.Fatalf("ListRequisitions: %s", err)
Expand Down
78 changes: 41 additions & 37 deletions token.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package nordigen

import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
Expand All @@ -16,87 +17,90 @@ type Token struct {
RefreshExpires int `json:"refresh_expires"`
}

type TokenRefresh struct {
Refresh string `json:"refresh"`
}

type Secret struct {
SecretId string `json:"secret_id"`
AccessId string `json:"secret_key"`
}

const tokenPath = "token"
const tokenNewPath = "new/"
const tokenRefreshPath = "refresh"

func (c Client) newToken() (*Token, error) {
req := http.Request{
Method: http.MethodPost,
URL: &url.URL{
Path: strings.Join([]string{tokenPath, tokenNewPath}, "/"),
},
}
const tokenRefreshPath = "refresh/"

func (c Client) newToken(ctx context.Context) (*Token, error) {
data, err := json.Marshal(Secret{
SecretId: c.secretId,
AccessId: c.secretKey,
})
if err != nil {
return nil, err
}
req.Body = io.NopCloser(bytes.NewBuffer(data))
resp, err := c.c.Do(&req)

if err != nil {
return nil, err
req := &http.Request{
Method: http.MethodPost,
Body: io.NopCloser(bytes.NewBuffer(data)),
URL: &url.URL{
Path: strings.Join([]string{tokenPath, tokenNewPath}, "/"),
},
}
body, err := io.ReadAll(resp.Body)
req = req.WithContext(ctx)

resp, err := c.c.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()

body, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return nil, readErr
}
if resp.StatusCode != http.StatusOK {
return nil, &APIError{resp.StatusCode, string(body), err}
return nil, &APIError{StatusCode: resp.StatusCode, Body: string(body)}
}
t := &Token{}
err = json.Unmarshal(body, &t)

if err != nil {
t := &Token{}
if err := json.Unmarshal(body, t); err != nil {
return nil, err
}

return t, nil
}

func (c Client) refreshToken(refresh string) (*Token, error) {
req := http.Request{
func (c Client) refreshToken(ctx context.Context, refresh string) (*Token, error) {
data, err := json.Marshal(TokenRefresh{Refresh: refresh})
if err != nil {
return nil, err
}

req := &http.Request{
Method: http.MethodPost,
Body: io.NopCloser(bytes.NewBuffer(data)),
URL: &url.URL{
Path: strings.Join([]string{tokenPath, tokenRefreshPath}, "/"),
},
}
data, err := json.Marshal(refresh)

if err != nil {
return &Token{}, err
}
req.Body = io.NopCloser(bytes.NewBuffer(data))

resp, err := c.c.Do(&req)
req = req.WithContext(ctx)

resp, err := c.c.Do(req)
if err != nil {
return nil, err
}
body, err := io.ReadAll(resp.Body)
defer resp.Body.Close()

if err != nil {
return nil, err
body, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return nil, readErr
}
if resp.StatusCode != http.StatusOK {
return nil, &APIError{resp.StatusCode, string(body), err}
return nil, &APIError{StatusCode: resp.StatusCode, Body: string(body)}
}
t := &Token{}
err = json.Unmarshal(body, &t)

if err != nil {
t := &Token{}
if err := json.Unmarshal(body, t); err != nil {
return nil, err
}

return t, nil
}

0 comments on commit 3683fd6

Please sign in to comment.