From 447e04cf07a54f37a435ee98477978a6634c4dfd Mon Sep 17 00:00:00 2001 From: Philip Laine Date: Mon, 30 Aug 2021 10:29:29 +0200 Subject: [PATCH] Fix MSI token refresh Signed-off-by: Philip Laine --- pkg/objstore/azure/azure.go | 4 ++-- pkg/objstore/azure/helpers.go | 36 +++++++++++++++++------------- pkg/objstore/azure/helpers_test.go | 3 ++- 3 files changed, 25 insertions(+), 18 deletions(-) diff --git a/pkg/objstore/azure/azure.go b/pkg/objstore/azure/azure.go index 4b42f2283c..5feb67048c 100644 --- a/pkg/objstore/azure/azure.go +++ b/pkg/objstore/azure/azure.go @@ -174,7 +174,7 @@ func NewBucket(logger log.Logger, azureConfig []byte, component string) (*Bucket } ctx := context.Background() - container, err := createContainer(ctx, conf) + container, err := createContainer(ctx, logger, conf) if err != nil { ret, ok := err.(blob.StorageError) if !ok { @@ -182,7 +182,7 @@ func NewBucket(logger log.Logger, azureConfig []byte, component string) (*Bucket } if ret.ServiceCode() == "ContainerAlreadyExists" { level.Debug(logger).Log("msg", "Getting connection to existing Azure blob container", "container", conf.ContainerName) - container, err = getContainer(ctx, conf) + container, err = getContainer(ctx, logger, conf) if err != nil { return nil, errors.Wrapf(err, "cannot get existing Azure blob container: %s", container) } diff --git a/pkg/objstore/azure/helpers.go b/pkg/objstore/azure/helpers.go index 00a2c1e03c..7d2977cf0c 100644 --- a/pkg/objstore/azure/helpers.go +++ b/pkg/objstore/azure/helpers.go @@ -16,6 +16,8 @@ import ( "github.com/Azure/azure-pipeline-go/pipeline" blob "github.com/Azure/azure-storage-blob-go/azblob" "github.com/Azure/go-autorest/autorest/azure/auth" + "github.com/go-kit/kit/log" + "github.com/go-kit/kit/log/level" ) // DirDelim is the delimiter used to model a directory structure in an object store bucket. @@ -34,24 +36,29 @@ func init() { pipeline.SetForceLogEnabled(false) } -func getAzureStorageCredentials(conf Config) (blob.Credential, error) { +func getAzureStorageCredentials(logger log.Logger, conf Config) (blob.Credential, error) { if conf.MSIResource != "" { msiConfig := auth.NewMSIConfig() msiConfig.Resource = conf.MSIResource - azureServicePrincipalToken, err := msiConfig.ServicePrincipalToken() + spt, err := msiConfig.ServicePrincipalToken() if err != nil { return nil, err } - - // Get a new token. - err = azureServicePrincipalToken.Refresh() - if err != nil { + if err := spt.Refresh(); err != nil { return nil, err } - token := azureServicePrincipalToken.Token() - return blob.NewTokenCredential(token.AccessToken, nil), nil + return blob.NewTokenCredential(spt.Token().AccessToken, func(tc blob.TokenCredential) time.Duration { + err := spt.Refresh() + if err != nil { + level.Error(logger).Log("msg", "could not refresh MSI token", "err", err) + // Retry later as the error can be related to API throttling + return 30 * time.Second + } + tc.SetToken(spt.Token().AccessToken) + return spt.Token().Expires().Sub(time.Now().Add(2 * time.Minute)) + }), nil } credential, err := blob.NewSharedKeyCredential(conf.StorageAccountName, conf.StorageAccountKey) @@ -61,9 +68,8 @@ func getAzureStorageCredentials(conf Config) (blob.Credential, error) { return credential, nil } -func getContainerURL(ctx context.Context, conf Config) (blob.ContainerURL, error) { - - credentials, err := getAzureStorageCredentials(conf) +func getContainerURL(ctx context.Context, logger log.Logger, conf Config) (blob.ContainerURL, error) { + credentials, err := getAzureStorageCredentials(logger, conf) if err != nil { return blob.ContainerURL{}, err @@ -134,8 +140,8 @@ func DefaultTransport(config Config) *http.Transport { } } -func getContainer(ctx context.Context, conf Config) (blob.ContainerURL, error) { - c, err := getContainerURL(ctx, conf) +func getContainer(ctx context.Context, logger log.Logger, conf Config) (blob.ContainerURL, error) { + c, err := getContainerURL(ctx, logger, conf) if err != nil { return blob.ContainerURL{}, err } @@ -144,8 +150,8 @@ func getContainer(ctx context.Context, conf Config) (blob.ContainerURL, error) { return c, err } -func createContainer(ctx context.Context, conf Config) (blob.ContainerURL, error) { - c, err := getContainerURL(ctx, conf) +func createContainer(ctx context.Context, logger log.Logger, conf Config) (blob.ContainerURL, error) { + c, err := getContainerURL(ctx, logger, conf) if err != nil { return blob.ContainerURL{}, err } diff --git a/pkg/objstore/azure/helpers_test.go b/pkg/objstore/azure/helpers_test.go index 3b1b22119d..8b46b9d3b7 100644 --- a/pkg/objstore/azure/helpers_test.go +++ b/pkg/objstore/azure/helpers_test.go @@ -7,6 +7,7 @@ import ( "context" "testing" + "github.com/go-kit/kit/log" "github.com/thanos-io/thanos/pkg/testutil" ) @@ -50,7 +51,7 @@ func Test_getContainerURL(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := context.Background() - got, err := getContainerURL(ctx, tt.args.conf) + got, err := getContainerURL(ctx, log.NewNopLogger(), tt.args.conf) if (err != nil) != tt.wantErr { t.Errorf("getContainerURL() error = %v, wantErr %v", err, tt.wantErr) return