Skip to content

Commit

Permalink
fix refresh too many times issue
Browse files Browse the repository at this point in the history
  • Loading branch information
mozillazg committed Aug 17, 2023
1 parent a8cac0f commit 1638505
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 6 deletions.
31 changes: 25 additions & 6 deletions pkg/credentials/provider/updater.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ type Updater struct {
cred *Credentials
lockForCred sync.RWMutex

Logger Logger
Logger Logger
nowFunc func() time.Time
}

type UpdaterOptions struct {
Expand All @@ -40,6 +41,7 @@ func NewUpdater(getter getCredentialsFunc, opts UpdaterOptions) *Updater {
cred: nil,
lockForCred: sync.RWMutex{},
Logger: opts.Logger,
nowFunc: time.Now,
}
return u
}
Expand Down Expand Up @@ -83,19 +85,25 @@ func (u *Updater) Credentials(ctx context.Context) (*Credentials, error) {
func (u *Updater) refreshCredForLoop(ctx context.Context) {
exp := u.expiration()

if exp.Add(-u.expiryWindowForRefreshLoop).Before(time.Now()) {
if !u.expired(u.expiryWindowForRefreshLoop) {
return
}

u.logger().Debug(fmt.Sprintf("start refresh credentials, current expiration: %s",
exp.Format("2006-01-02T15:04:05Z")))

for i := 0; i < 5; i++ {
maxRetry := 5
for i := 0; i < maxRetry; i++ {
err := u.refreshCred(ctx)
if err == nil {
return
}
if _, ok := err.(*NotEnableError); ok {
return
}
time.Sleep(time.Second * time.Duration(i))
if i < maxRetry-1 {
time.Sleep(time.Second * time.Duration(i))
}
}
}

Expand Down Expand Up @@ -135,9 +143,13 @@ func (u *Updater) getCred() *Credentials {
}

func (u *Updater) Expired() bool {
return u.expired(0)
}

func (u *Updater) expired(expiryDelta time.Duration) bool {
exp := u.expiration()

return exp.Before(time.Now())
return exp.Add(-expiryDelta).Before(u.now())
}

func (u *Updater) expiration() time.Time {
Expand All @@ -147,7 +159,14 @@ func (u *Updater) expiration() time.Time {
return time.Time{}
}

return cred.Expiration
return cred.Expiration.Round(0)
}

func (u *Updater) now() time.Time {
if u.nowFunc == nil {
return time.Now()
}
return u.nowFunc()
}

func (u *Updater) logger() Logger {
Expand Down
141 changes: 141 additions & 0 deletions pkg/credentials/provider/updater_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
package provider

import (
"context"
"errors"
"fmt"
"testing"
"time"
)

type TLogger struct {
t *testing.T
}

func (d TLogger) Info(msg string) {
d.t.Logf(fmt.Sprintf("%s, %s", time.Now().Format(time.RFC3339), msg))
}

func (d TLogger) Debug(msg string) {
d.t.Logf(fmt.Sprintf("%s, %s", time.Now().Format(time.RFC3339), msg))
}

func (d TLogger) Error(err error, msg string) {
d.t.Logf(fmt.Sprintf("%s, %s", time.Now().Format(time.RFC3339), msg))
}

func TestUpdater_refreshCredForLoop_refresh(t *testing.T) {
var callCount int
fakeCred := Credentials{
Expiration: time.Now().Add(time.Minute),
}
u := NewUpdater(func(ctx context.Context) (*Credentials, error) {
callCount++
return &fakeCred, nil
}, UpdaterOptions{
ExpiryWindow: 0,
RefreshPeriod: 0,
Logger: TLogger{t: t},
})

u.refreshCredForLoop(context.TODO())
if callCount != 1 {
t.Errorf("callCount should be 1 but got %d", callCount)
}
ret := u.Expired()
if ret {
t.Errorf("should not expired")
}

u.refreshCredForLoop(context.TODO())
if callCount != 1 {
t.Errorf("callCount should be 1 but got %d", callCount)
}

u.nowFunc = func() time.Time {
return time.Now().Add(time.Minute)
}
ret = u.Expired()
if !ret {
t.Errorf("should expired")
}

fakeCred.Expiration = time.Now().Add(time.Minute * 5)
u.refreshCredForLoop(context.TODO())
if callCount != 2 {
t.Errorf("callCount should be 2 but got %d", callCount)
}
ret = u.Expired()
if ret {
t.Errorf("should not expired")
}
}

func TestUpdater_refreshCredForLoop_erorr(t *testing.T) {
var callCount int

u := NewUpdater(func(ctx context.Context) (*Credentials, error) {
callCount++
return nil, errors.New("error message")
}, UpdaterOptions{
ExpiryWindow: 0,
RefreshPeriod: 0,
Logger: TLogger{t: t},
})

u.refreshCredForLoop(context.TODO())
if callCount != 5 {
t.Errorf("callCount should be 5 but got %d", callCount)
}
ret := u.Expired()
if !ret {
t.Errorf("should expired")
}
}

func TestUpdater_Credentials_refresh(t *testing.T) {
var callCount int
fakeCred := Credentials{
Expiration: time.Now().Add(time.Minute),
}
u := NewUpdater(func(ctx context.Context) (*Credentials, error) {
callCount++
return &fakeCred, nil
}, UpdaterOptions{
ExpiryWindow: 0,
RefreshPeriod: 0,
Logger: TLogger{t: t},
})

u.Credentials(context.TODO())
if callCount != 1 {
t.Errorf("callCount should be 1 but got %d", callCount)
}
ret := u.Expired()
if ret {
t.Errorf("should not expired")
}

u.Credentials(context.TODO())
if callCount != 1 {
t.Errorf("callCount should be 1 but got %d", callCount)
}

u.nowFunc = func() time.Time {
return time.Now().Add(time.Minute)
}
ret = u.Expired()
if !ret {
t.Errorf("should expired")
}

fakeCred.Expiration = time.Now().Add(time.Minute * 5)
u.Credentials(context.TODO())
if callCount != 2 {
t.Errorf("callCount should be 2 but got %d", callCount)
}
ret = u.Expired()
if ret {
t.Errorf("should not expired")
}
}

0 comments on commit 1638505

Please sign in to comment.