Skip to content

Commit

Permalink
azcontainerregistry: delegate token caching
Browse files Browse the repository at this point in the history
Mechanically refactor token caching, retrieval and exchange behavior
from the authentication policy into a dedicated object, allowing us to
separate token management concerns from Azure client policy concerns.

Signed-off-by: Steve Kuznetsov <skuznets@redhat.com>
  • Loading branch information
stevekuznetsov committed Aug 2, 2024
1 parent 523f4e1 commit f993345
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 46 deletions.
40 changes: 18 additions & 22 deletions sdk/containers/azcontainerregistry/authentication_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,12 @@ import (
"fmt"
"net/http"
"strings"
"sync/atomic"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/internal/temporal"
)

const (
Expand All @@ -30,6 +28,7 @@ const (
)

type authenticationPolicyOptions struct {
*authenticationTokenCacheOptions
}

// authenticationPolicy is a policy to do the challenge-based authentication for container registry service. The authorization flow is as follows:
Expand All @@ -47,19 +46,15 @@ type authenticationPolicyOptions struct {
// Each registry service shares one refresh token, it will be cached in refreshTokenCache until expire time.
// Since the scope will be different for different API/repository/artifact, accessTokenCache will only work when continuously calling same API.
type authenticationPolicy struct {
refreshTokenCache *temporal.Resource[azcore.AccessToken, acquiringResourceState]
accessTokenCache atomic.Value
cred azcore.TokenCredential
aadScopes []string
authClient *authenticationClient
accessTokenCache *authenticationTokenCache
}

func newAuthenticationPolicy(cred azcore.TokenCredential, scopes []string, authClient *authenticationClient, opts *authenticationPolicyOptions) *authenticationPolicy {
if opts == nil {
opts = &authenticationPolicyOptions{}
}
return &authenticationPolicy{
cred: cred,
aadScopes: scopes,
authClient: authClient,
refreshTokenCache: temporal.NewResource(acquireRefreshToken),
accessTokenCache: newAuthenticationTokenCache(cred, scopes, authClient, opts.authenticationTokenCacheOptions),
}
}

Expand All @@ -69,7 +64,7 @@ func (p *authenticationPolicy) Do(req *policy.Request) (*http.Response, error) {
if req.Raw().Header.Get(headerAuthorization) != "" {
// retry request could do the request with existed token directly
resp, err = req.Next()
} else if accessToken := p.accessTokenCache.Load(); accessToken != nil && accessToken != "" {
} else if accessToken := p.accessTokenCache.Load(); accessToken != "" {
// if there is a previous access token, then we try to use this token to do the request
req.Raw().Header.Set(
headerAuthorization,
Expand All @@ -95,10 +90,9 @@ func (p *authenticationPolicy) Do(req *policy.Request) (*http.Response, error) {
if service, scope, err = findServiceAndScope(resp); err != nil {
return nil, err
}
if accessToken, err = p.getAccessToken(req.Raw().Context(), service, scope); err != nil {
if accessToken, err = p.accessTokenCache.AcquireAccessToken(req.Raw().Context(), service, scope); err != nil {
return nil, err
}
p.accessTokenCache.Store(accessToken)
req.Raw().Header.Set(
headerAuthorization,
fmt.Sprintf("%s%s", bearerHeader, accessToken),
Expand All @@ -113,34 +107,36 @@ func (p *authenticationPolicy) Do(req *policy.Request) (*http.Response, error) {
return resp, nil
}

func (p *authenticationPolicy) getAccessToken(ctx context.Context, service, scope string) (string, error) {
func (c *authenticationTokenCache) AcquireAccessToken(ctx context.Context, service, scope string) (string, error) {
// anonymous access
if p.cred == nil {
resp, err := p.authClient.ExchangeACRRefreshTokenForACRAccessToken(ctx, service, scope, "", &authenticationClientExchangeACRRefreshTokenForACRAccessTokenOptions{GrantType: to.Ptr(tokenGrantTypePassword)})
if c.cred == nil {
resp, err := c.authClient.ExchangeACRRefreshTokenForACRAccessToken(ctx, service, scope, "", &authenticationClientExchangeACRRefreshTokenForACRAccessTokenOptions{GrantType: to.Ptr(tokenGrantTypePassword)})
if err != nil {
return "", err
}
c.accessTokenCache.Store(*resp.acrAccessToken.AccessToken)
return *resp.acrAccessToken.AccessToken, nil
}

// access with token
// get refresh token from cache/request
refreshToken, err := p.refreshTokenCache.Get(acquiringResourceState{
refreshToken, err := c.refreshTokenCache.Get(acquiringResourceState{
ctx: ctx,
aadCredential: p.cred,
aadScopes: p.aadScopes,
authClient: p.authClient,
aadCredential: c.cred,
aadScopes: c.aadScopes,
authClient: c.authClient,
service: service,
})
if err != nil {
return "", err
}

// get access token from request
resp, err := p.authClient.ExchangeACRRefreshTokenForACRAccessToken(ctx, service, scope, refreshToken.Token, &authenticationClientExchangeACRRefreshTokenForACRAccessTokenOptions{GrantType: to.Ptr(tokenGrantTypeRefreshToken)})
resp, err := c.authClient.ExchangeACRRefreshTokenForACRAccessToken(ctx, service, scope, refreshToken.Token, &authenticationClientExchangeACRRefreshTokenForACRAccessTokenOptions{GrantType: to.Ptr(tokenGrantTypeRefreshToken)})
if err != nil {
return "", err
}
c.accessTokenCache.Store(*resp.acrAccessToken.AccessToken)
return *resp.acrAccessToken.AccessToken, nil
}

Expand Down
56 changes: 32 additions & 24 deletions sdk/containers/azcontainerregistry/authentication_policy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,15 +136,17 @@ func Test_authenticationPolicy_getAccessToken_live(t *testing.T) {
authClient, err := newAuthenticationClient(endpoint, &authenticationClientOptions{options})
require.NoError(t, err)
p := &authenticationPolicy{
temporal.NewResource(acquireRefreshToken),
atomic.Value{},
cred,
[]string{options.Cloud.Services[ServiceName].Audience + "/.default"},
authClient,
accessTokenCache: &authenticationTokenCache{
temporal.NewResource(acquireRefreshToken),
atomic.Value{},
cred,
[]string{options.Cloud.Services[ServiceName].Audience + "/.default"},
authClient,
},
}
request, err := runtime.NewRequest(context.Background(), http.MethodGet, "https://test.com")
require.NoError(t, err)
token, err := p.getAccessToken(request.Raw().Context(), strings.TrimPrefix(endpoint, "https://"), "registry:catalog:*")
token, err := p.accessTokenCache.AcquireAccessToken(request.Raw().Context(), strings.TrimPrefix(endpoint, "https://"), "registry:catalog:*")
require.NoError(t, err)
require.NotEmpty(t, token)
}
Expand All @@ -161,22 +163,24 @@ func Test_authenticationPolicy_getAccessToken_error(t *testing.T) {
require.NoError(t, err)

p := &authenticationPolicy{
temporal.NewResource(acquireRefreshToken),
atomic.Value{},
&credential.Fake{},
[]string{"test"},
authClient,
accessTokenCache: &authenticationTokenCache{
temporal.NewResource(acquireRefreshToken),
atomic.Value{},
&credential.Fake{},
[]string{"test"},
authClient,
},
}
request, err := runtime.NewRequest(context.Background(), http.MethodGet, srv.URL())
require.NoError(t, err)
_, err = p.getAccessToken(request.Raw().Context(), "service", "scope")
_, err = p.accessTokenCache.AcquireAccessToken(request.Raw().Context(), "service", "scope")
require.Error(t, err)
_, err = p.getAccessToken(request.Raw().Context(), "service", "scope")
_, err = p.accessTokenCache.AcquireAccessToken(request.Raw().Context(), "service", "scope")
require.Error(t, err)
_, err = p.getAccessToken(request.Raw().Context(), "service", "scope")
_, err = p.accessTokenCache.AcquireAccessToken(request.Raw().Context(), "service", "scope")
require.Error(t, err)
p.cred = nil
_, err = p.getAccessToken(request.Raw().Context(), "service", "scope")
p.accessTokenCache.cred = nil
_, err = p.accessTokenCache.AcquireAccessToken(request.Raw().Context(), "service", "scope")
require.Error(t, err)
}

Expand All @@ -186,12 +190,14 @@ func Test_authenticationPolicy_getAccessToken_live_anonymous(t *testing.T) {
authClient, err := newAuthenticationClient(endpoint, &authenticationClientOptions{options})
require.NoError(t, err)
p := &authenticationPolicy{
refreshTokenCache: temporal.NewResource(acquireRefreshToken),
authClient: authClient,
accessTokenCache: &authenticationTokenCache{
refreshTokenCache: temporal.NewResource(acquireRefreshToken),
authClient: authClient,
},
}
request, err := runtime.NewRequest(context.Background(), http.MethodGet, "https://test.com")
require.NoError(t, err)
token, err := p.getAccessToken(request.Raw().Context(), strings.TrimPrefix(endpoint, "https://"), "registry:catalog:*")
token, err := p.accessTokenCache.AcquireAccessToken(request.Raw().Context(), strings.TrimPrefix(endpoint, "https://"), "registry:catalog:*")
require.NoError(t, err)
require.NotEmpty(t, token)
}
Expand Down Expand Up @@ -244,11 +250,13 @@ func Test_authenticationPolicy(t *testing.T) {
authClient, err := newAuthenticationClient(srv.URL(), &authenticationClientOptions{ClientOptions: azcore.ClientOptions{Transport: srv}})
require.NoError(t, err)
authPolicy := &authenticationPolicy{
temporal.NewResource(acquireRefreshToken),
atomic.Value{},
&credential.Fake{},
[]string{"test"},
authClient,
accessTokenCache: &authenticationTokenCache{
temporal.NewResource(acquireRefreshToken),
atomic.Value{},
&credential.Fake{},
[]string{"test"},
authClient,
},
}
pl := runtime.NewPipeline("testmodule", "v0.1.0", runtime.PipelineOptions{PerRetry: []policy.Policy{authPolicy}}, &policy.ClientOptions{Transport: srv})

Expand Down
41 changes: 41 additions & 0 deletions sdk/containers/azcontainerregistry/authentication_token_cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
//go:build go1.18
// +build go1.18

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See License.txt in the project root for license information.

package azcontainerregistry

import (
"sync/atomic"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/internal/temporal"
)

type authenticationTokenCacheOptions struct{}

type authenticationTokenCache struct {
refreshTokenCache *temporal.Resource[azcore.AccessToken, acquiringResourceState]
accessTokenCache atomic.Value
cred azcore.TokenCredential
aadScopes []string
authClient *authenticationClient
}

func newAuthenticationTokenCache(cred azcore.TokenCredential, scopes []string, authClient *authenticationClient, opts *authenticationTokenCacheOptions) *authenticationTokenCache {
return &authenticationTokenCache{
cred: cred,
aadScopes: scopes,
authClient: authClient,
refreshTokenCache: temporal.NewResource(acquireRefreshToken),
}
}

func (c *authenticationTokenCache) Load() string {
value, ok := c.accessTokenCache.Load().(string)
if !ok {
return ""
}
return value
}

0 comments on commit f993345

Please sign in to comment.