Skip to content

Commit

Permalink
azcontainerregistry: move code
Browse files Browse the repository at this point in the history
This commit simply moves code from one location to another, broken out
from other commits for ease of review.

Signed-off-by: Steve Kuznetsov <skuznets@redhat.com>
  • Loading branch information
stevekuznetsov committed Aug 6, 2024
1 parent 3206f97 commit 308795f
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 113 deletions.
113 changes: 0 additions & 113 deletions sdk/containers/azcontainerregistry/authentication_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,13 @@
package azcontainerregistry

import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
)

const (
Expand Down Expand Up @@ -106,39 +101,6 @@ func (p *authenticationPolicy) Do(req *policy.Request) (*http.Response, error) {
return resp, nil
}

func (c *authenticationTokenCache) AcquireAccessToken(ctx context.Context, service, scope string) (string, error) {
// anonymous access
if c.cred == nil {
resp, err := c.authClient.ExchangeACRRefreshTokenForACRAccessToken(ctx, service, scope, "", &authenticationClientExchangeACRRefreshTokenForACRAccessTokenOptions{GrantType: to.Ptr(tokenGrantTypePassword)})
if err != nil {
return "", err
}
c.accessTokenCache.Store(*resp.acrAccessToken.AccessToken)
return *resp.acrAccessToken.AccessToken, nil
}

// access with token
// get refresh token from cache/request
refreshToken, err := c.refreshTokenCache.Get(acquiringResourceState{
ctx: ctx,
aadCredential: c.cred,
aadScopes: c.aadScopes,
authClient: c.authClient,
service: service,
})
if err != nil {
return "", err
}

// get access token from request
resp, err := c.authClient.ExchangeACRRefreshTokenForACRAccessToken(ctx, service, scope, refreshToken.Token, &authenticationClientExchangeACRRefreshTokenForACRAccessTokenOptions{GrantType: to.Ptr(tokenGrantTypeRefreshToken)})
if err != nil {
return "", err
}
c.accessTokenCache.Store(*resp.acrAccessToken.AccessToken)
return *resp.acrAccessToken.AccessToken, nil
}

func findServiceAndScope(resp *http.Response) (string, string, error) {
authHeader := resp.Header.Get("WWW-Authenticate")
if authHeader == "" {
Expand Down Expand Up @@ -175,78 +137,3 @@ func getChallengeRequest(oriReq policy.Request) (*policy.Request, error) {
copied.Raw().Header.Del("Content-Type")
return copied, nil
}

type acquiringResourceState struct {
ctx context.Context

aadCredential azcore.TokenCredential
aadScopes []string

authClient *authenticationClient
service string
}

// acquireRefreshToken acquires or updates the refresh token of ACR service; only one thread/goroutine at a time ever calls this function
func acquireRefreshToken(state acquiringResourceState) (newResource azcore.AccessToken, newExpiration time.Time, err error) {
// get AAD token from credential
aadToken, err := state.aadCredential.GetToken(
state.ctx,
policy.TokenRequestOptions{
Scopes: state.aadScopes,
},
)
if err != nil {
return azcore.AccessToken{}, time.Time{}, err
}

// exchange refresh token with AAD token
refreshResp, err := state.authClient.ExchangeAADAccessTokenForACRRefreshToken(state.ctx, postContentSchemaGrantTypeAccessToken, state.service, &authenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions{
AccessToken: &aadToken.Token,
})
if err != nil {
return azcore.AccessToken{}, time.Time{}, err
}

refreshToken := azcore.AccessToken{
Token: *refreshResp.acrRefreshToken.RefreshToken,
}

// get refresh token expire time
refreshToken.ExpiresOn, err = getJWTExpireTime(*refreshResp.acrRefreshToken.RefreshToken)
if err != nil {
return azcore.AccessToken{}, time.Time{}, err
}

// return refresh token
return refreshToken, refreshToken.ExpiresOn, nil
}

func getJWTExpireTime(token string) (time.Time, error) {
values := strings.Split(token, ".")
if len(values) > 2 {
value := values[1]
padding := len(value) % 4
if padding > 0 {
for i := 0; i < 4-padding; i++ {
value += "="
}
}
parsedValue, err := base64.StdEncoding.DecodeString(value)
if err != nil {
return time.Time{}, err
}

var jsonValue *jwtOnlyWithExp
err = json.Unmarshal(parsedValue, &jsonValue)
if err != nil {
return time.Time{}, err
}
return time.Unix(jsonValue.Exp, 0), nil
}

return time.Time{}, errors.New("could not parse refresh token expire time")
}

type jwtOnlyWithExp struct {
Exp int64 `json:"exp"`
}
116 changes: 116 additions & 0 deletions sdk/containers/azcontainerregistry/authentication_token_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,17 @@
package azcontainerregistry

import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"strings"
"sync/atomic"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/internal/temporal"
)

Expand Down Expand Up @@ -39,3 +47,111 @@ func (c *authenticationTokenCache) Load() string {
}
return value
}

func (c *authenticationTokenCache) AcquireAccessToken(ctx context.Context, service, scope string) (string, error) {
// anonymous access
if c.cred == nil {
resp, err := c.authClient.ExchangeACRRefreshTokenForACRAccessToken(ctx, service, scope, "", &authenticationClientExchangeACRRefreshTokenForACRAccessTokenOptions{GrantType: to.Ptr(tokenGrantTypePassword)})
if err != nil {
return "", err
}
c.accessTokenCache.Store(*resp.acrAccessToken.AccessToken)
return *resp.acrAccessToken.AccessToken, nil
}

// access with token
// get refresh token from cache/request
refreshToken, err := c.refreshTokenCache.Get(acquiringResourceState{
ctx: ctx,
aadCredential: c.cred,
aadScopes: c.aadScopes,
authClient: c.authClient,
service: service,
})
if err != nil {
return "", err
}

// get access token from request
resp, err := c.authClient.ExchangeACRRefreshTokenForACRAccessToken(ctx, service, scope, refreshToken.Token, &authenticationClientExchangeACRRefreshTokenForACRAccessTokenOptions{GrantType: to.Ptr(tokenGrantTypeRefreshToken)})
if err != nil {
return "", err
}
c.accessTokenCache.Store(*resp.acrAccessToken.AccessToken)
return *resp.acrAccessToken.AccessToken, nil
}

type acquiringResourceState struct {
ctx context.Context

aadCredential azcore.TokenCredential
aadScopes []string

authClient *authenticationClient
service string
}

// acquireRefreshToken acquires or updates the refresh token of ACR service; only one thread/goroutine at a time ever calls this function
func acquireRefreshToken(state acquiringResourceState) (newResource azcore.AccessToken, newExpiration time.Time, err error) {
// get AAD token from credential
aadToken, err := state.aadCredential.GetToken(
state.ctx,
policy.TokenRequestOptions{
Scopes: state.aadScopes,
},
)
if err != nil {
return azcore.AccessToken{}, time.Time{}, err
}

// exchange refresh token with AAD token
refreshResp, err := state.authClient.ExchangeAADAccessTokenForACRRefreshToken(state.ctx, postContentSchemaGrantTypeAccessToken, state.service, &authenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions{
AccessToken: &aadToken.Token,
})
if err != nil {
return azcore.AccessToken{}, time.Time{}, err
}

refreshToken := azcore.AccessToken{
Token: *refreshResp.acrRefreshToken.RefreshToken,
}

// get refresh token expire time
refreshToken.ExpiresOn, err = getJWTExpireTime(*refreshResp.acrRefreshToken.RefreshToken)
if err != nil {
return azcore.AccessToken{}, time.Time{}, err
}

// return refresh token
return refreshToken, refreshToken.ExpiresOn, nil
}

func getJWTExpireTime(token string) (time.Time, error) {
values := strings.Split(token, ".")
if len(values) > 2 {
value := values[1]
padding := len(value) % 4
if padding > 0 {
for i := 0; i < 4-padding; i++ {
value += "="
}
}
parsedValue, err := base64.StdEncoding.DecodeString(value)
if err != nil {
return time.Time{}, err
}

var jsonValue *jwtOnlyWithExp
err = json.Unmarshal(parsedValue, &jsonValue)
if err != nil {
return time.Time{}, err
}
return time.Unix(jsonValue.Exp, 0), nil
}

return time.Time{}, errors.New("could not parse refresh token expire time")
}

type jwtOnlyWithExp struct {
Exp int64 `json:"exp"`
}

0 comments on commit 308795f

Please sign in to comment.