diff --git a/sdk/security/keyvault/internal/CHANGELOG.md b/sdk/security/keyvault/internal/CHANGELOG.md index 94a36cf13b3e..c55b2eb6be47 100644 --- a/sdk/security/keyvault/internal/CHANGELOG.md +++ b/sdk/security/keyvault/internal/CHANGELOG.md @@ -1,14 +1,9 @@ # Release History -## 1.0.2 (Unreleased) +## 1.1.0 (2024-10-21) ### Features Added - -### Breaking Changes - -### Bugs Fixed - -### Other Changes +* Added CAE support ## 1.0.1 (2024-04-09) diff --git a/sdk/security/keyvault/internal/challenge_policy_test.go b/sdk/security/keyvault/internal/challenge_policy_test.go index 0529b0a73c8e..786e919550aa 100644 --- a/sdk/security/keyvault/internal/challenge_policy_test.go +++ b/sdk/security/keyvault/internal/challenge_policy_test.go @@ -7,6 +7,7 @@ package internal import ( + "bytes" "context" "fmt" "net/http" @@ -17,11 +18,29 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming" "github.com/Azure/azure-sdk-for-go/sdk/internal/errorinfo" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" "github.com/stretchr/testify/require" ) +const ( + challengedToken = "needs more claims" + claimsToken = "all the claims" + kvChallenge = `Bearer authorization="https://login.microsoftonline.com/tenant", resource="https://vault.azure.net"` + caeChallenge1 = `Bearer realm="", authorization_uri="https://login.microsoftonline.com/common/oauth2/authorize", error="insufficient_claims", claims="dGVzdGluZzE="` + caeChallenge2 = `Bearer realm="", authorization_uri="https://login.microsoftonline.com/common/oauth2/authorize", error="insufficient_claims", claims="dGVzdGluZzI="` +) + +// requireToken is a mock.Response predicate that checks a request for the expected token +var requireToken = func(t *testing.T, want string) func(req *http.Request) bool { + return func(r *http.Request) bool { + _, actual, _ := strings.Cut(r.Header.Get("Authorization"), " ") + require.Equal(t, want, actual) + return true + } +} + type credentialFunc func(context.Context, policy.TokenRequestOptions) (azcore.AccessToken, error) func (cf credentialFunc) GetToken(ctx context.Context, options policy.TokenRequestOptions) (azcore.AccessToken, error) { @@ -100,6 +119,183 @@ func TestChallengePolicy(t *testing.T) { } } +func TestChallengePolicy_CAE(t *testing.T) { + srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl()) + defer close() + srv.AppendResponse( + mock.WithHeader("WWW-Authenticate", kvChallenge), + mock.WithStatusCode(401), + mock.WithPredicate(requireToken(t, "")), + ) + srv.AppendResponse() // when a response's predicate returns true, srv pops the following one + srv.AppendResponse() + + srv.AppendResponse( + mock.WithHeader("WWW-Authenticate", caeChallenge1), + mock.WithStatusCode(401), + mock.WithPredicate(requireToken(t, challengedToken)), + ) + srv.AppendResponse() // when a response's predicate returns true, srv pops the following one + srv.AppendResponse( + mock.WithPredicate(requireToken(t, claimsToken)), + ) + srv.AppendResponse() // when a response's predicate returns true, srv pops the following one + + tkReqs := 0 + cred := credentialFunc(func(ctx context.Context, tro policy.TokenRequestOptions) (azcore.AccessToken, error) { + require.True(t, tro.EnableCAE) + tkReqs += 1 + tk := challengedToken + switch tkReqs { + case 1: + require.Empty(t, tro.Claims) + case 2: + tk = claimsToken + require.Equal(t, "testing1", tro.Claims) + default: + t.Fatal("unexpected token request") + } + return azcore.AccessToken{Token: tk, ExpiresOn: time.Now().Add(time.Hour)}, nil + }) + p := NewKeyVaultChallengePolicy(cred, nil) + pl := runtime.NewPipeline("", "", + runtime.PipelineOptions{PerRetry: []policy.Policy{p}}, + &policy.ClientOptions{Transport: srv}, + ) + + // req 1 kv then regular + req, err := runtime.NewRequest(context.Background(), "POST", "https://42.vault.azure.net") + require.NoError(t, err) + err = req.SetBody(streaming.NopCloser(bytes.NewReader([]byte("test"))), "text/plain") + require.NoError(t, err) + res, err := pl.Do(req) + require.NoError(t, err) + require.Equal(t, 200, res.StatusCode) + require.Equal(t, 1, tkReqs) + + // req 2 cae + req, err = runtime.NewRequest(context.Background(), "POST", "https://42.vault.azure.net") + require.NoError(t, err) + err = req.SetBody(streaming.NopCloser(bytes.NewReader([]byte("test2"))), "text/plain") + require.NoError(t, err) + res, err = pl.Do(req) + require.NoError(t, err) + require.Equal(t, 200, res.StatusCode) + require.Equal(t, 2, tkReqs) +} + +func TestChallengePolicy_KVThenCAE(t *testing.T) { + srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl()) + defer close() + srv.AppendResponse( + mock.WithHeader("WWW-Authenticate", kvChallenge), + mock.WithStatusCode(401), + mock.WithPredicate(requireToken(t, "")), + ) + srv.AppendResponse() // when a response's predicate returns true, srv pops the following one + srv.AppendResponse( + mock.WithHeader("WWW-Authenticate", caeChallenge1), + mock.WithStatusCode(401), + mock.WithPredicate(requireToken(t, challengedToken)), + ) + srv.AppendResponse() // when a response's predicate returns true, srv pops the following one + srv.AppendResponse( + mock.WithPredicate(requireToken(t, claimsToken)), + ) + srv.AppendResponse() // when a response's predicate returns true, srv pops the following one + + tkReqs := 0 + cred := credentialFunc(func(ctx context.Context, tro policy.TokenRequestOptions) (azcore.AccessToken, error) { + require.True(t, tro.EnableCAE) + tkReqs += 1 + tk := challengedToken + switch tkReqs { + case 1: + require.Empty(t, tro.Claims) + case 2: + tk = claimsToken + require.Equal(t, "testing1", tro.Claims) + default: + t.Fatal("unexpected token request") + } + return azcore.AccessToken{Token: tk, ExpiresOn: time.Now().Add(time.Hour)}, nil + }) + p := NewKeyVaultChallengePolicy(cred, nil) + pl := runtime.NewPipeline("", "", + runtime.PipelineOptions{PerRetry: []policy.Policy{p}}, + &policy.ClientOptions{Transport: srv}, + ) + req, err := runtime.NewRequest(context.Background(), "GET", "https://42.vault.azure.net") + require.NoError(t, err) + res, err := pl.Do(req) + require.NoError(t, err) + require.Equal(t, 200, res.StatusCode) + require.Equal(t, tkReqs, 2) +} + +func TestChallengePolicy_TwoCAEChallenges(t *testing.T) { + srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl()) + defer close() + srv.AppendResponse( + mock.WithHeader("WWW-Authenticate", kvChallenge), + mock.WithStatusCode(401), + mock.WithPredicate(requireToken(t, "")), + ) + srv.AppendResponse() // when a response's predicate returns true, srv pops the following one + srv.AppendResponse() + + srv.AppendResponse( + mock.WithHeader("WWW-Authenticate", caeChallenge1), + mock.WithStatusCode(401), + mock.WithPredicate(requireToken(t, challengedToken)), + ) + srv.AppendResponse() // when a response's predicate returns true, srv pops the following one + srv.AppendResponse( + mock.WithHeader("WWW-Authenticate", caeChallenge2), + mock.WithStatusCode(401), + mock.WithPredicate(requireToken(t, claimsToken)), + ) + srv.AppendResponse() // when a response's predicate returns true, srv pops the following one + tkReqs := 0 + cred := credentialFunc(func(ctx context.Context, tro policy.TokenRequestOptions) (azcore.AccessToken, error) { + require.True(t, tro.EnableCAE) + tk := challengedToken + tkReqs += 1 + switch tkReqs { + case 1: + require.Empty(t, tro.Claims) + case 2: + tk = claimsToken + require.Equal(t, "testing1", tro.Claims) + default: + t.Fatal("unexpected token request") + } + return azcore.AccessToken{Token: tk, ExpiresOn: time.Now().Add(time.Hour)}, nil + }) + p := NewKeyVaultChallengePolicy(cred, nil) + pl := runtime.NewPipeline("", "", + runtime.PipelineOptions{PerRetry: []policy.Policy{p}}, + &policy.ClientOptions{Transport: srv}, + ) + + // req 1 kv then regular + req, err := runtime.NewRequest(context.Background(), "GET", "https://42.vault.azure.net") + require.NoError(t, err) + res, err := pl.Do(req) + require.NoError(t, err) + require.Equal(t, 200, res.StatusCode) + require.Equal(t, tkReqs, 1) + + // req 2 cae twice + req, err = runtime.NewRequest(context.Background(), "GET", "https://42.vault.azure.net") + require.NoError(t, err) + res, err = pl.Do(req) + require.NoError(t, err) + require.Equal(t, 401, res.StatusCode) + require.Equal(t, caeChallenge2, res.Header.Get("WWW-Authenticate")) + require.Equal(t, tkReqs, 2) +} + func TestParseTenant(t *testing.T) { actual := parseTenant("") require.Empty(t, actual) diff --git a/sdk/security/keyvault/internal/constants.go b/sdk/security/keyvault/internal/constants.go index 5c00886d183e..5a037978fa06 100644 --- a/sdk/security/keyvault/internal/constants.go +++ b/sdk/security/keyvault/internal/constants.go @@ -7,5 +7,5 @@ package internal const ( - version = "v1.0.2" //nolint + version = "v1.1.0" //nolint ) diff --git a/sdk/security/keyvault/internal/go.mod b/sdk/security/keyvault/internal/go.mod index 59ddc4d24a2e..ee021ab88053 100644 --- a/sdk/security/keyvault/internal/go.mod +++ b/sdk/security/keyvault/internal/go.mod @@ -3,7 +3,7 @@ module github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal go 1.18 require ( - github.com/Azure/azure-sdk-for-go/sdk/azcore v1.13.0 + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.16.0 github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 github.com/stretchr/testify v1.9.0 ) @@ -11,7 +11,7 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - golang.org/x/net v0.27.0 // indirect - golang.org/x/text v0.16.0 // indirect + golang.org/x/net v0.30.0 // indirect + golang.org/x/text v0.19.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/sdk/security/keyvault/internal/go.sum b/sdk/security/keyvault/internal/go.sum index 7da3690ef0b5..6d4b60840d12 100644 --- a/sdk/security/keyvault/internal/go.sum +++ b/sdk/security/keyvault/internal/go.sum @@ -1,5 +1,5 @@ -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.13.0 h1:GJHeeA2N7xrG3q30L2UXDyuWRzDM900/65j70wcM4Ww= -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.13.0/go.mod h1:l38EPgmsp71HHLq9j7De57JcKOWPyhrsW1Awm1JS6K0= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.16.0 h1:JZg6HRh6W6U4OLl6lk7BZ7BLisIzM9dG1R50zUk9C/M= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.16.0/go.mod h1:YL1xnZ6QejvQHWJrX/AvhFl4WW4rqHVoKspWNVwFk0M= github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 h1:ywEEhmNahHBihViHepv3xPBn1663uRv2t2q/ESv9seY= github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0/go.mod h1:iZDifYGJTIgIIkYRNWPENUnqx6bJ2xnSDFI2tjwZNuY= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -11,10 +11,10 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= -golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= -golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= -golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4= +golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU= +golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= +golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=