Skip to content

Commit

Permalink
[keyvault] add cae support (#23543)
Browse files Browse the repository at this point in the history
  • Loading branch information
gracewilcox authored Oct 18, 2024
1 parent 6055ff7 commit 4a7ae27
Show file tree
Hide file tree
Showing 5 changed files with 208 additions and 17 deletions.
9 changes: 2 additions & 7 deletions sdk/security/keyvault/internal/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
# Release History

## 1.0.2 (Unreleased)
## 1.1.0 (2024-10-21)

### Features Added

### Breaking Changes

### Bugs Fixed

### Other Changes
* Added CAE support

## 1.0.1 (2024-04-09)

Expand Down
196 changes: 196 additions & 0 deletions sdk/security/keyvault/internal/challenge_policy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package internal

import (
"bytes"
"context"
"fmt"
"net/http"
Expand All @@ -17,11 +18,29 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming"
"github.com/Azure/azure-sdk-for-go/sdk/internal/errorinfo"
"github.com/Azure/azure-sdk-for-go/sdk/internal/mock"
"github.com/stretchr/testify/require"
)

const (
challengedToken = "needs more claims"
claimsToken = "all the claims"
kvChallenge = `Bearer authorization="https://login.microsoftonline.com/tenant", resource="https://vault.azure.net"`
caeChallenge1 = `Bearer realm="", authorization_uri="https://login.microsoftonline.com/common/oauth2/authorize", error="insufficient_claims", claims="dGVzdGluZzE="`
caeChallenge2 = `Bearer realm="", authorization_uri="https://login.microsoftonline.com/common/oauth2/authorize", error="insufficient_claims", claims="dGVzdGluZzI="`
)

// requireToken is a mock.Response predicate that checks a request for the expected token
var requireToken = func(t *testing.T, want string) func(req *http.Request) bool {
return func(r *http.Request) bool {
_, actual, _ := strings.Cut(r.Header.Get("Authorization"), " ")
require.Equal(t, want, actual)
return true
}
}

type credentialFunc func(context.Context, policy.TokenRequestOptions) (azcore.AccessToken, error)

func (cf credentialFunc) GetToken(ctx context.Context, options policy.TokenRequestOptions) (azcore.AccessToken, error) {
Expand Down Expand Up @@ -100,6 +119,183 @@ func TestChallengePolicy(t *testing.T) {
}
}

func TestChallengePolicy_CAE(t *testing.T) {
srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl())
defer close()
srv.AppendResponse(
mock.WithHeader("WWW-Authenticate", kvChallenge),
mock.WithStatusCode(401),
mock.WithPredicate(requireToken(t, "")),
)
srv.AppendResponse() // when a response's predicate returns true, srv pops the following one
srv.AppendResponse()

srv.AppendResponse(
mock.WithHeader("WWW-Authenticate", caeChallenge1),
mock.WithStatusCode(401),
mock.WithPredicate(requireToken(t, challengedToken)),
)
srv.AppendResponse() // when a response's predicate returns true, srv pops the following one
srv.AppendResponse(
mock.WithPredicate(requireToken(t, claimsToken)),
)
srv.AppendResponse() // when a response's predicate returns true, srv pops the following one

tkReqs := 0
cred := credentialFunc(func(ctx context.Context, tro policy.TokenRequestOptions) (azcore.AccessToken, error) {
require.True(t, tro.EnableCAE)
tkReqs += 1
tk := challengedToken
switch tkReqs {
case 1:
require.Empty(t, tro.Claims)
case 2:
tk = claimsToken
require.Equal(t, "testing1", tro.Claims)
default:
t.Fatal("unexpected token request")
}
return azcore.AccessToken{Token: tk, ExpiresOn: time.Now().Add(time.Hour)}, nil
})
p := NewKeyVaultChallengePolicy(cred, nil)
pl := runtime.NewPipeline("", "",
runtime.PipelineOptions{PerRetry: []policy.Policy{p}},
&policy.ClientOptions{Transport: srv},
)

// req 1 kv then regular
req, err := runtime.NewRequest(context.Background(), "POST", "https://42.vault.azure.net")
require.NoError(t, err)
err = req.SetBody(streaming.NopCloser(bytes.NewReader([]byte("test"))), "text/plain")
require.NoError(t, err)
res, err := pl.Do(req)
require.NoError(t, err)
require.Equal(t, 200, res.StatusCode)
require.Equal(t, 1, tkReqs)

// req 2 cae
req, err = runtime.NewRequest(context.Background(), "POST", "https://42.vault.azure.net")
require.NoError(t, err)
err = req.SetBody(streaming.NopCloser(bytes.NewReader([]byte("test2"))), "text/plain")
require.NoError(t, err)
res, err = pl.Do(req)
require.NoError(t, err)
require.Equal(t, 200, res.StatusCode)
require.Equal(t, 2, tkReqs)
}

func TestChallengePolicy_KVThenCAE(t *testing.T) {
srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl())
defer close()
srv.AppendResponse(
mock.WithHeader("WWW-Authenticate", kvChallenge),
mock.WithStatusCode(401),
mock.WithPredicate(requireToken(t, "")),
)
srv.AppendResponse() // when a response's predicate returns true, srv pops the following one
srv.AppendResponse(
mock.WithHeader("WWW-Authenticate", caeChallenge1),
mock.WithStatusCode(401),
mock.WithPredicate(requireToken(t, challengedToken)),
)
srv.AppendResponse() // when a response's predicate returns true, srv pops the following one
srv.AppendResponse(
mock.WithPredicate(requireToken(t, claimsToken)),
)
srv.AppendResponse() // when a response's predicate returns true, srv pops the following one

tkReqs := 0
cred := credentialFunc(func(ctx context.Context, tro policy.TokenRequestOptions) (azcore.AccessToken, error) {
require.True(t, tro.EnableCAE)
tkReqs += 1
tk := challengedToken
switch tkReqs {
case 1:
require.Empty(t, tro.Claims)
case 2:
tk = claimsToken
require.Equal(t, "testing1", tro.Claims)
default:
t.Fatal("unexpected token request")
}
return azcore.AccessToken{Token: tk, ExpiresOn: time.Now().Add(time.Hour)}, nil
})
p := NewKeyVaultChallengePolicy(cred, nil)
pl := runtime.NewPipeline("", "",
runtime.PipelineOptions{PerRetry: []policy.Policy{p}},
&policy.ClientOptions{Transport: srv},
)
req, err := runtime.NewRequest(context.Background(), "GET", "https://42.vault.azure.net")
require.NoError(t, err)
res, err := pl.Do(req)
require.NoError(t, err)
require.Equal(t, 200, res.StatusCode)
require.Equal(t, tkReqs, 2)
}

func TestChallengePolicy_TwoCAEChallenges(t *testing.T) {
srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl())
defer close()
srv.AppendResponse(
mock.WithHeader("WWW-Authenticate", kvChallenge),
mock.WithStatusCode(401),
mock.WithPredicate(requireToken(t, "")),
)
srv.AppendResponse() // when a response's predicate returns true, srv pops the following one
srv.AppendResponse()

srv.AppendResponse(
mock.WithHeader("WWW-Authenticate", caeChallenge1),
mock.WithStatusCode(401),
mock.WithPredicate(requireToken(t, challengedToken)),
)
srv.AppendResponse() // when a response's predicate returns true, srv pops the following one
srv.AppendResponse(
mock.WithHeader("WWW-Authenticate", caeChallenge2),
mock.WithStatusCode(401),
mock.WithPredicate(requireToken(t, claimsToken)),
)
srv.AppendResponse() // when a response's predicate returns true, srv pops the following one
tkReqs := 0
cred := credentialFunc(func(ctx context.Context, tro policy.TokenRequestOptions) (azcore.AccessToken, error) {
require.True(t, tro.EnableCAE)
tk := challengedToken
tkReqs += 1
switch tkReqs {
case 1:
require.Empty(t, tro.Claims)
case 2:
tk = claimsToken
require.Equal(t, "testing1", tro.Claims)
default:
t.Fatal("unexpected token request")
}
return azcore.AccessToken{Token: tk, ExpiresOn: time.Now().Add(time.Hour)}, nil
})
p := NewKeyVaultChallengePolicy(cred, nil)
pl := runtime.NewPipeline("", "",
runtime.PipelineOptions{PerRetry: []policy.Policy{p}},
&policy.ClientOptions{Transport: srv},
)

// req 1 kv then regular
req, err := runtime.NewRequest(context.Background(), "GET", "https://42.vault.azure.net")
require.NoError(t, err)
res, err := pl.Do(req)
require.NoError(t, err)
require.Equal(t, 200, res.StatusCode)
require.Equal(t, tkReqs, 1)

// req 2 cae twice
req, err = runtime.NewRequest(context.Background(), "GET", "https://42.vault.azure.net")
require.NoError(t, err)
res, err = pl.Do(req)
require.NoError(t, err)
require.Equal(t, 401, res.StatusCode)
require.Equal(t, caeChallenge2, res.Header.Get("WWW-Authenticate"))
require.Equal(t, tkReqs, 2)
}

func TestParseTenant(t *testing.T) {
actual := parseTenant("")
require.Empty(t, actual)
Expand Down
2 changes: 1 addition & 1 deletion sdk/security/keyvault/internal/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@
package internal

const (
version = "v1.0.2" //nolint
version = "v1.1.0" //nolint
)
6 changes: 3 additions & 3 deletions sdk/security/keyvault/internal/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@ module github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal
go 1.18

require (
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.13.0
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.16.0
github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0
github.com/stretchr/testify v1.9.0
)

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
golang.org/x/net v0.27.0 // indirect
golang.org/x/text v0.16.0 // indirect
golang.org/x/net v0.30.0 // indirect
golang.org/x/text v0.19.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
12 changes: 6 additions & 6 deletions sdk/security/keyvault/internal/go.sum
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.13.0 h1:GJHeeA2N7xrG3q30L2UXDyuWRzDM900/65j70wcM4Ww=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.13.0/go.mod h1:l38EPgmsp71HHLq9j7De57JcKOWPyhrsW1Awm1JS6K0=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.16.0 h1:JZg6HRh6W6U4OLl6lk7BZ7BLisIzM9dG1R50zUk9C/M=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.16.0/go.mod h1:YL1xnZ6QejvQHWJrX/AvhFl4WW4rqHVoKspWNVwFk0M=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 h1:ywEEhmNahHBihViHepv3xPBn1663uRv2t2q/ESv9seY=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0/go.mod h1:iZDifYGJTIgIIkYRNWPENUnqx6bJ2xnSDFI2tjwZNuY=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
Expand All @@ -11,10 +11,10 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys=
golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE=
golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4=
golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU=
golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM=
golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
Expand Down

0 comments on commit 4a7ae27

Please sign in to comment.