Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

.*: Fix Azure MSI token refresh #4611

Merged
merged 1 commit into from
Sep 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pkg/objstore/azure/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,15 +174,15 @@ func NewBucket(logger log.Logger, azureConfig []byte, component string) (*Bucket
}

ctx := context.Background()
container, err := createContainer(ctx, conf)
container, err := createContainer(ctx, logger, conf)
if err != nil {
ret, ok := err.(blob.StorageError)
if !ok {
return nil, errors.Wrapf(err, "Azure API return unexpected error: %T\n", err)
}
if ret.ServiceCode() == "ContainerAlreadyExists" {
level.Debug(logger).Log("msg", "Getting connection to existing Azure blob container", "container", conf.ContainerName)
container, err = getContainer(ctx, conf)
container, err = getContainer(ctx, logger, conf)
if err != nil {
return nil, errors.Wrapf(err, "cannot get existing Azure blob container: %s", container)
}
Expand Down
36 changes: 21 additions & 15 deletions pkg/objstore/azure/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ import (
"github.com/Azure/azure-pipeline-go/pipeline"
blob "github.com/Azure/azure-storage-blob-go/azblob"
"github.com/Azure/go-autorest/autorest/azure/auth"
"github.com/go-kit/kit/log"
"github.com/go-kit/kit/log/level"
)

// DirDelim is the delimiter used to model a directory structure in an object store bucket.
Expand All @@ -34,24 +36,29 @@ func init() {
pipeline.SetForceLogEnabled(false)
}

func getAzureStorageCredentials(conf Config) (blob.Credential, error) {
func getAzureStorageCredentials(logger log.Logger, conf Config) (blob.Credential, error) {
if conf.MSIResource != "" {
msiConfig := auth.NewMSIConfig()
msiConfig.Resource = conf.MSIResource

azureServicePrincipalToken, err := msiConfig.ServicePrincipalToken()
spt, err := msiConfig.ServicePrincipalToken()
if err != nil {
return nil, err
}

// Get a new token.
err = azureServicePrincipalToken.Refresh()
if err != nil {
if err := spt.Refresh(); err != nil {
return nil, err
}
token := azureServicePrincipalToken.Token()

return blob.NewTokenCredential(token.AccessToken, nil), nil
return blob.NewTokenCredential(spt.Token().AccessToken, func(tc blob.TokenCredential) time.Duration {
err := spt.Refresh()
if err != nil {
level.Error(logger).Log("msg", "could not refresh MSI token", "err", err)
// Retry later as the error can be related to API throttling
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would 100% add metric or error log to indicate that we hit this and why (err)

return 30 * time.Second
bwplotka marked this conversation as resolved.
Show resolved Hide resolved
}
tc.SetToken(spt.Token().AccessToken)
return spt.Token().Expires().Sub(time.Now().Add(2 * time.Minute))
}), nil
}

credential, err := blob.NewSharedKeyCredential(conf.StorageAccountName, conf.StorageAccountKey)
Expand All @@ -61,9 +68,8 @@ func getAzureStorageCredentials(conf Config) (blob.Credential, error) {
return credential, nil
}

func getContainerURL(ctx context.Context, conf Config) (blob.ContainerURL, error) {

credentials, err := getAzureStorageCredentials(conf)
func getContainerURL(ctx context.Context, logger log.Logger, conf Config) (blob.ContainerURL, error) {
credentials, err := getAzureStorageCredentials(logger, conf)

if err != nil {
return blob.ContainerURL{}, err
Expand Down Expand Up @@ -134,8 +140,8 @@ func DefaultTransport(config Config) *http.Transport {
}
}

func getContainer(ctx context.Context, conf Config) (blob.ContainerURL, error) {
c, err := getContainerURL(ctx, conf)
func getContainer(ctx context.Context, logger log.Logger, conf Config) (blob.ContainerURL, error) {
c, err := getContainerURL(ctx, logger, conf)
if err != nil {
return blob.ContainerURL{}, err
}
Expand All @@ -144,8 +150,8 @@ func getContainer(ctx context.Context, conf Config) (blob.ContainerURL, error) {
return c, err
}

func createContainer(ctx context.Context, conf Config) (blob.ContainerURL, error) {
c, err := getContainerURL(ctx, conf)
func createContainer(ctx context.Context, logger log.Logger, conf Config) (blob.ContainerURL, error) {
c, err := getContainerURL(ctx, logger, conf)
if err != nil {
return blob.ContainerURL{}, err
}
Expand Down
3 changes: 2 additions & 1 deletion pkg/objstore/azure/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"context"
"testing"

"github.com/go-kit/kit/log"
"github.com/thanos-io/thanos/pkg/testutil"
)

Expand Down Expand Up @@ -50,7 +51,7 @@ func Test_getContainerURL(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
got, err := getContainerURL(ctx, tt.args.conf)
got, err := getContainerURL(ctx, log.NewNopLogger(), tt.args.conf)
if (err != nil) != tt.wantErr {
t.Errorf("getContainerURL() error = %v, wantErr %v", err, tt.wantErr)
return
Expand Down