Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions path_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ func pathLogin(b *jwtAuthBackend) *framework.Path {
Type: framework.TypeString,
Description: "An optional token used to fetch group memberships specified by the distributed claim source in the jwt. This is supported only on Azure/Entra ID",
},
"force_fetch_groups": {
Type: framework.TypeBool,
Description: "If true, groups are fetched from Microsoft Graph API. This is supported only on Azure/Entra ID",
},
},

Operations: map[logical.Operation]framework.OperationHandler{
Expand Down Expand Up @@ -118,6 +122,8 @@ func (b *jwtAuthBackend) pathLogin(ctx context.Context, req *logical.Request, d

distClaimAccessToken := d.Get("distributed_claim_access_token").(string)

forceFetchGroups := d.Get("force_fetch_groups").(bool)

if len(role.TokenBoundCIDRs) > 0 {
if req.Connection == nil {
b.Logger().Warn("token bound CIDRs found but no connection information available for validation")
Expand Down Expand Up @@ -179,7 +185,7 @@ func (b *jwtAuthBackend) pathLogin(ctx context.Context, req *logical.Request, d
}
}

alias, groupAliases, err := b.createIdentity(ctx, allClaims, roleName, role, &accessTokenSrc{accessToken: distClaimAccessToken})
alias, groupAliases, err := b.createIdentity(ctx, allClaims, roleName, role, &accessTokenSrc{accessToken: distClaimAccessToken}, forceFetchGroups)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}
Expand Down Expand Up @@ -234,7 +240,7 @@ func (b *jwtAuthBackend) pathLoginRenew(ctx context.Context, req *logical.Reques

// createIdentity creates an alias and set of groups aliases based on the role
// definition and received claims.
func (b *jwtAuthBackend) createIdentity(ctx context.Context, allClaims map[string]interface{}, roleName string, role *jwtRole, tokenSource oauth2.TokenSource) (*logical.Alias, []*logical.Alias, error) {
func (b *jwtAuthBackend) createIdentity(ctx context.Context, allClaims map[string]interface{}, roleName string, role *jwtRole, tokenSource oauth2.TokenSource, forceFetchGroups bool) (*logical.Alias, []*logical.Alias, error) {
var userClaimRaw interface{}
if role.UserClaimJSONPointer {
userClaimRaw = getClaim(b.Logger(), allClaims, role.UserClaim)
Expand Down Expand Up @@ -277,7 +283,7 @@ func (b *jwtAuthBackend) createIdentity(ctx context.Context, allClaims map[strin
return alias, groupAliases, nil
}

groupsClaimRaw, err := b.fetchGroups(ctx, pConfig, allClaims, role, tokenSource)
groupsClaimRaw, err := b.fetchGroups(ctx, pConfig, allClaims, role, tokenSource, forceFetchGroups)
if err != nil {
return nil, nil, fmt.Errorf("failed to fetch groups: %s", err)
}
Expand Down Expand Up @@ -316,12 +322,12 @@ func (b *jwtAuthBackend) fetchUserInfo(ctx context.Context, pConfig CustomProvid
}

// Checks if there's a custom provider_config and calls FetchGroups() if implemented
func (b *jwtAuthBackend) fetchGroups(ctx context.Context, pConfig CustomProvider, allClaims map[string]interface{}, role *jwtRole, tokenSource oauth2.TokenSource) (interface{}, error) {
func (b *jwtAuthBackend) fetchGroups(ctx context.Context, pConfig CustomProvider, allClaims map[string]interface{}, role *jwtRole, tokenSource oauth2.TokenSource, forceFetchGroups bool) (interface{}, error) {
// If the custom provider implements interface GroupsFetcher, call it,
// otherwise fall through to the default method
if pConfig != nil {
if gf, ok := pConfig.(GroupsFetcher); ok {
groupsRaw, err := gf.FetchGroups(ctx, b, allClaims, role, tokenSource)
groupsRaw, err := gf.FetchGroups(ctx, b, allClaims, role, tokenSource, forceFetchGroups)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion path_oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request,
}
}

alias, groupAliases, err := b.createIdentity(ctx, allClaims, roleName, role, tokenSource)
alias, groupAliases, err := b.createIdentity(ctx, allClaims, roleName, role, tokenSource, false)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}
Expand Down
107 changes: 91 additions & 16 deletions provider_azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,12 @@ import (
"encoding/json"
"errors"
"fmt"
log "github.com/hashicorp/go-hclog"
"golang.org/x/oauth2"
"io/ioutil"
"net/http"
"net/url"
"strings"

log "github.com/hashicorp/go-hclog"
"golang.org/x/oauth2"
)

const (
Expand Down Expand Up @@ -49,29 +48,50 @@ func (a *AzureProvider) SensitiveKeys() []string {
}

// FetchGroups - custom groups fetching for azure - satisfying GroupsFetcher interface
func (a *AzureProvider) FetchGroups(_ context.Context, b *jwtAuthBackend, allClaims map[string]interface{}, role *jwtRole, tokenSource oauth2.TokenSource) (interface{}, error) {
groupsClaimRaw := getClaim(b.Logger(), allClaims, role.GroupsClaim)

if groupsClaimRaw == nil {
// If the "groups" claim is missing, it might be because the user is a
// member of more than 200 groups, which means the token contains
// distributed claim information. Attempt to look that up here.
azureClaimSourcesURL, err := a.getClaimSource(b.Logger(), allClaims, role)
if err != nil {
return nil, fmt.Errorf("unable to get claim sources: %s", err)
}
func (a *AzureProvider) FetchGroups(_ context.Context, b *jwtAuthBackend, allClaims map[string]interface{}, role *jwtRole, tokenSource oauth2.TokenSource, forceFetchGroups bool) (interface{}, error) {
var groupsClaimRaw interface{}
if forceFetchGroups == true {
azureClaimSourcesURL := "https://graph.microsoft.com/v1.0/me/memberOf"

var err error
a.ctx, err = b.createCAContext(b.providerCtx, b.cachedConfig.OIDCDiscoveryCAPEM)
if err != nil {
return nil, fmt.Errorf("unable to create CA Context: %s", err)
}

azureGroups, err := a.getAzureGroups(azureClaimSourcesURL, tokenSource)
azureGroups, err := a.getAzureGroupsFromMeCall(azureClaimSourcesURL, tokenSource)

if err != nil {
return nil, fmt.Errorf("%q claim not found in token: %v", role.GroupsClaim, err)
return nil, fmt.Errorf("Unable to fetch groups from graph API: %v", err)
}
groupsClaimRaw = azureGroups
} else {
groupsClaimRaw = getClaim(b.Logger(), allClaims, role.GroupsClaim)

if groupsClaimRaw == nil {
// If the "groups" claim is missing, it might be because the user is a
// member of more than 200 groups, which means the token contains
// distributed claim information. Attempt to look that up here.
azureClaimSourcesURL, err := a.getClaimSource(b.Logger(), allClaims, role)
if err != nil {
return nil, fmt.Errorf("unable to get claim sources: %s", err)
}
//TODO: check any version support or any other restrictions for API https://graph.microsoft.com/v1.0/me/memberOf"

a.ctx, err = b.createCAContext(b.providerCtx, b.cachedConfig.OIDCDiscoveryCAPEM)
if err != nil {
return nil, fmt.Errorf("unable to create CA Context: %s", err)
}

azureGroups, err := a.getAzureGroups(azureClaimSourcesURL, tokenSource)

if err != nil {
return nil, fmt.Errorf("%q claim not found in token: %v", role.GroupsClaim, err)
}
groupsClaimRaw = azureGroups
}
}

b.Logger().Debug(fmt.Sprintf("groups claim raw is %v", groupsClaimRaw))
return groupsClaimRaw, nil
}
Expand Down Expand Up @@ -179,6 +199,61 @@ func (a *AzureProvider) getAzureGroups(groupsURL string, tokenSource oauth2.Toke
return target.Value, nil
}

// Fetch user groups from the Microsoft Graph API /me/memberOf
func (a *AzureProvider) getAzureGroupsFromMeCall(groupsURL string, tokenSource oauth2.TokenSource) (interface{}, error) {
// Use the Access Token that was pre-negotiated between the Claims Provider and RP
// via https://openid.net/specs/openid-connect-core-1_0.html#AggregatedDistributedClaims.
if tokenSource == nil {
return nil, errors.New("token unavailable to call Microsoft Graph API")
}
token, err := tokenSource.Token()
if err != nil {
return nil, fmt.Errorf("unable to get token: %s", err)
}
//payload := strings.NewReader("{\"securityEnabledOnly\": false}")
req, err := http.NewRequest("GET", groupsURL, nil)
if err != nil {
return nil, fmt.Errorf("error constructing groups endpoint request: %s", err)
}
token.SetAuthHeader(req)

client := http.DefaultClient
if c, ok := a.ctx.Value(oauth2.HTTPClient).(*http.Client); ok {
client = c
}
res, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("unable to call Microsoft Graph API: %s", err)
}
defer res.Body.Close()
body, err := ioutil.ReadAll(res.Body)
if err != nil {
return nil, fmt.Errorf("failed to read Microsoft Graph API response: %s", err)
}
if res.StatusCode != http.StatusOK {
return nil, fmt.Errorf("failed to get groups: %s", string(body))
}

var resp response
if err := json.Unmarshal(body, &resp); err != nil {
return nil, fmt.Errorf("unabled to decode response: %s", err)
}

var target azureGroups
for _, group := range resp.Value {
target.Value = append(target.Value, group.ID)
}

return target.Value, nil
}

type azureGroups struct {
Value []interface{} `json:"value"`
}

type groupObject struct {
ID string `json:"id"`
}
type response struct {
Value []groupObject `json:"value"`
}
2 changes: 1 addition & 1 deletion provider_azure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ func TestLogin_fetchGroups(t *testing.T) {

// Ensure groups are as expected
tokenSource := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: "test.access.token"})
groupsResp, err := b.(*jwtAuthBackend).fetchGroups(ctx, provider, allClaims, role, tokenSource)
groupsResp, err := b.(*jwtAuthBackend).fetchGroups(ctx, provider, allClaims, role, tokenSource, false)
assert.NoError(t, err)
assert.Equal(t, []interface{}{"group1", "group2"}, groupsResp)
}
Expand Down
2 changes: 1 addition & 1 deletion provider_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,5 +65,5 @@ type UserInfoFetcher interface {
// GroupsFetcher - Optional support for custom groups handling
type GroupsFetcher interface {
// FetchGroups queries for groups claims during login
FetchGroups(context.Context, *jwtAuthBackend, map[string]interface{}, *jwtRole, oauth2.TokenSource) (interface{}, error)
FetchGroups(context.Context, *jwtAuthBackend, map[string]interface{}, *jwtRole, oauth2.TokenSource, bool) (interface{}, error)
}
2 changes: 1 addition & 1 deletion provider_gsuite.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ func (g *GSuiteProvider) SensitiveKeys() []string {
}

// FetchGroups fetches and returns groups from G Suite.
func (g *GSuiteProvider) FetchGroups(ctx context.Context, b *jwtAuthBackend, allClaims map[string]interface{}, role *jwtRole, _ oauth2.TokenSource) (interface{}, error) {
func (g *GSuiteProvider) FetchGroups(ctx context.Context, b *jwtAuthBackend, allClaims map[string]interface{}, role *jwtRole, _ oauth2.TokenSource, _ bool) (interface{}, error) {
if !g.config.FetchGroups {
return nil, nil
}
Expand Down
6 changes: 3 additions & 3 deletions provider_gsuite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func TestGSuiteProvider_FetchGroups(t *testing.T) {
UserClaim: "sub",
GroupsClaim: "groups",
}
groupsRaw, err := gProvider.FetchGroups(ctx, b.(*jwtAuthBackend), allClaims, role, nil)
groupsRaw, err := gProvider.FetchGroups(ctx, b.(*jwtAuthBackend), allClaims, role, nil, false)
assert.NoError(t, err)

// Assert that groups are as expected
Expand Down Expand Up @@ -562,7 +562,7 @@ func TestGSuiteProvider_validateBoundClaims(t *testing.T) {
provider.adminSvc.BasePath = gServer.URL

// Fetch the groups
_, err = b.(*jwtAuthBackend).fetchGroups(ctx, provider, allClaims, jwtRole, nil)
_, err = b.(*jwtAuthBackend).fetchGroups(ctx, provider, allClaims, jwtRole, nil, false)
assert.NoError(t, err)

// Fetch the user info
Expand Down Expand Up @@ -622,6 +622,6 @@ func TestGSuiteProvider_domain(t *testing.T) {

// Fetch the groups
claims := map[string]interface{}{"email": "user1@example.com"}
_, err = b.(*jwtAuthBackend).fetchGroups(ctx, provider, claims, jwtRole, nil)
_, err = b.(*jwtAuthBackend).fetchGroups(ctx, provider, claims, jwtRole, nil, false)
assert.NoError(t, err)
}
2 changes: 1 addition & 1 deletion provider_ibmisam.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func (a *IBMISAMProvider) SensitiveKeys() []string {
// FetchGroups - custom groups fetching for ibmisam - satisfying GroupsFetcher interface
// IBMISAM by default will return groups not as a json list but as a list of space seperated strings
// We need to convert this to a json list
func (a *IBMISAMProvider) FetchGroups(_ context.Context, b *jwtAuthBackend, allClaims map[string]interface{}, role *jwtRole, _ oauth2.TokenSource) (interface{}, error) {
func (a *IBMISAMProvider) FetchGroups(_ context.Context, b *jwtAuthBackend, allClaims map[string]interface{}, role *jwtRole, _ oauth2.TokenSource, _ bool) (interface{}, error) {
groupsClaimRaw := getClaim(b.Logger(), allClaims, role.GroupsClaim)

if groupsClaimRaw != nil {
Expand Down
2 changes: 1 addition & 1 deletion provider_ibmisam_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ func TestLogin_ibmisam_fetchGroups(t *testing.T) {

// Ensure groups are as expected
tokenSource := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: "test.access.token"})
groupsRaw, err := b.(*jwtAuthBackend).fetchGroups(ctx, provider, allClaims, role, tokenSource)
groupsRaw, err := b.(*jwtAuthBackend).fetchGroups(ctx, provider, allClaims, role, tokenSource, false)
assert.NoError(t, err)

groupsResp, ok := normalizeList(groupsRaw)
Expand Down
2 changes: 1 addition & 1 deletion provider_secureauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func (a *SecureAuthProvider) SensitiveKeys() []string {
// FetchGroups - custom groups fetching for secureauth - satisfying GroupsFetcher interface
// SecureAuth by default will return groups not as a json list but as a list of comma seperated strings
// We need to convert this to a json list
func (a *SecureAuthProvider) FetchGroups(_ context.Context, b *jwtAuthBackend, allClaims map[string]interface{}, role *jwtRole, _ oauth2.TokenSource) (interface{}, error) {
func (a *SecureAuthProvider) FetchGroups(_ context.Context, b *jwtAuthBackend, allClaims map[string]interface{}, role *jwtRole, _ oauth2.TokenSource, _ bool) (interface{}, error) {
groupsClaimRaw := getClaim(b.Logger(), allClaims, role.GroupsClaim)

if groupsClaimRaw != nil {
Expand Down
2 changes: 1 addition & 1 deletion provider_secureauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ func TestLogin_secureauth_fetchGroups(t *testing.T) {

// Ensure groups are as expected
tokenSource := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: "test.access.token"})
groupsRaw, err := b.(*jwtAuthBackend).fetchGroups(ctx, provider, allClaims, role, tokenSource)
groupsRaw, err := b.(*jwtAuthBackend).fetchGroups(ctx, provider, allClaims, role, tokenSource, false)
assert.NoError(t, err)

groupsResp, ok := normalizeList(groupsRaw)
Expand Down
Loading