Skip to content

Commit

Permalink
Split the getClientToken invoke handler into smaller methods and add …
Browse files Browse the repository at this point in the history
…unit tests
  • Loading branch information
thomas11 committed Oct 17, 2024
1 parent 0a48980 commit 81928c4
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 25 deletions.
64 changes: 39 additions & 25 deletions provider/pkg/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,31 +301,9 @@ func (k *azureNativeProvider) Invoke(ctx context.Context, req *rpc.InvokeRequest
if err != nil {
return nil, fmt.Errorf("getting auth config: %w", err)
}
endpoint := k.environment.ResourceManagerEndpoint
if endpointArg := args["endpoint"]; endpointArg.HasValue() && endpointArg.IsString() {
endpoint = endpointArg.StringValue()
}

var token string
if useLegacyAuth() {
token, err = k.getOAuthToken(ctx, auth, endpoint)
if err != nil {
return nil, err
}
} else {
cred, err := k.newTokenCredential()
if err != nil {
return nil, err
}
t, err := cred.GetToken(ctx, policy.TokenRequestOptions{
// .default is the well-defined scope for all resources accessible to the user or application.
// https://learn.microsoft.com/en-us/entra/identity-platform/scopes-oidc#the-default-scope
Scopes: []string{endpoint + "/.default"},
})
if err != nil {
return nil, err
}
token = t.Token
token, err := k.getClientToken(ctx, auth, args["endpoint"])
if err != nil {
return nil, err
}
outputs = map[string]interface{}{"token": token}
default:
Expand Down Expand Up @@ -391,6 +369,42 @@ func (k *azureNativeProvider) Invoke(ctx context.Context, req *rpc.InvokeRequest
return &rpc.InvokeResponse{Return: result}, nil
}

func (k *azureNativeProvider) getClientToken(ctx context.Context, authConfig *authConfig, endpointArg resource.PropertyValue) (string, error) {
endpoint := k.tokenEndpoint(endpointArg)

if useLegacyAuth() {
return k.getOAuthToken(ctx, authConfig, endpoint)
}

cred, err := k.newTokenCredential()
if err != nil {
return "", err
}
t, err := cred.GetToken(ctx, tokenRequestOpts(endpoint))
if err != nil {
return "", err
}
return t.Token, nil
}

// Returns the Azure endpoint where tokens can be requested. If the argument is not null or empty,
// it will be used verbatim.
func (k *azureNativeProvider) tokenEndpoint(endpointArg resource.PropertyValue) string {
if endpointArg.HasValue() && endpointArg.IsString() && endpointArg.StringValue() != "" {
return endpointArg.StringValue()
}
return k.environment.ResourceManagerEndpoint
}

func tokenRequestOpts(endpoint string) policy.TokenRequestOptions {
return policy.TokenRequestOptions{
// "".default" is the well-defined scope for all resources accessible to the user or
// application. Despite the URL, it doesn't apply only to OIDC.
// https://learn.microsoft.com/en-us/entra/identity-platform/scopes-oidc#the-default-scope
Scopes: []string{endpoint + "/.default"},
}
}

func (k *azureNativeProvider) invokeResponseToOutputs(response any, res resources.AzureAPIInvoke) map[string]any {
if responseMap, ok := response.(map[string]any); ok {
// Map the raw response to the shape of outputs that the SDKs expect.
Expand Down
48 changes: 48 additions & 0 deletions provider/pkg/provider/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"testing"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/fake"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/pulumi/pulumi-azure-native/v2/provider/pkg/convert"
"github.com/pulumi/pulumi-azure-native/v2/provider/pkg/provider/crud"
"github.com/pulumi/pulumi-azure-native/v2/provider/pkg/resources"
Expand Down Expand Up @@ -466,3 +467,50 @@ func (m *mockAzureClient) Put(ctx context.Context, id string, bodyProps map[stri
func (m *mockAzureClient) IsNotFound(err error) bool {
return false
}

func TestGetTokenEndpoint(t *testing.T) {
t.Parallel()

t.Run("explicit", func(t *testing.T) {
t.Parallel()
p := azureNativeProvider{}
endpoint := p.tokenEndpoint(resource.NewStringProperty("https://management.azure.com/"))
assert.Equal(t, "https://management.azure.com/", endpoint)
})

t.Run("implicit public", func(t *testing.T) {
t.Parallel()
p := azureNativeProvider{
environment: azure.PublicCloud,
}
endpoint := p.tokenEndpoint(resource.NewNullProperty())
assert.Equal(t, "https://management.azure.com/", endpoint)
})

t.Run("implicit usgov", func(t *testing.T) {
t.Parallel()
p := azureNativeProvider{
environment: azure.USGovernmentCloud,
}
endpoint := p.tokenEndpoint(resource.NewNullProperty())
assert.Equal(t, "https://management.usgovcloudapi.net/", endpoint)
})

t.Run("implicit with empty string, public", func(t *testing.T) {
t.Parallel()
p := azureNativeProvider{
environment: azure.PublicCloud,
}
endpoint := p.tokenEndpoint(resource.NewStringProperty(""))
assert.Equal(t, "https://management.azure.com/", endpoint)
})
}

func TestGetTokenRequestOpts(t *testing.T) {
t.Parallel()

opts := tokenRequestOpts("http://endpoint")
assert.Empty(t, opts.Claims)
assert.Empty(t, opts.TenantID)
assert.Equal(t, []string{"http://endpoint/.default"}, opts.Scopes)
}

0 comments on commit 81928c4

Please sign in to comment.