Skip to content

Commit

Permalink
Enforce that payer accounts cannot be assigned to individual users
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexVulaj committed Jul 11, 2023
1 parent 4201665 commit 88b66ea
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 55 deletions.
2 changes: 1 addition & 1 deletion cmd/account/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (

"github.com/aws/aws-sdk-go/aws/arn"
"github.com/aws/aws-sdk-go/service/sts"
osdCloud "github.com/openshift/osdctl/pkg/osdCloud"
"github.com/openshift/osdctl/pkg/osdCloud"
"github.com/openshift/osdctl/pkg/provider/aws"
"github.com/openshift/osdctl/pkg/utils"
"github.com/spf13/cobra"
Expand Down
32 changes: 20 additions & 12 deletions cmd/account/mgmt/account-assign.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package mgmt

import (
"fmt"

"github.com/aws/aws-sdk-go/service/sts"
"math/rand"
"time"

Expand Down Expand Up @@ -214,8 +214,18 @@ func (o *accountAssignOptions) findUntaggedAccount(rootOu string) (string, error
return "", ErrNoUntaggedAccounts
}

identity, err := o.awsClient.GetCallerIdentity(&sts.GetCallerIdentityInput{})
if err != nil {
return "", err
}

// Loop through accounts and check that it's untagged and assign ID to user
for _, a := range accounts.Accounts {
if *a.Id == *identity.Account {
// Don't allow the payer account to be assigned to an individual user
continue
}

isOwned, err := isOwned(*a.Id, &o.awsClient)
if err != nil {
return "", err
Expand Down Expand Up @@ -321,16 +331,14 @@ func (o *accountAssignOptions) buildAccount(seedVal int64) (string, error) {
return newAccountId, nil
}

var ErrAwsAccountLimitExceeded error = fmt.Errorf("ErrAwsAccountLimitExceeded")
var ErrEmailAlreadyExist error = fmt.Errorf("ErrEmailAlreadyExist")
var ErrAwsInternalFailure error = fmt.Errorf("ErrAwsInternalFailure")
var ErrAwsTooManyRequests error = fmt.Errorf("ErrAwsTooManyRequests")
var ErrAwsFailedCreateAccount error = fmt.Errorf("ErrAwsFailedCreateAccount")
var ErrAwsAccountLimitExceeded = fmt.Errorf("ErrAwsAccountLimitExceeded")
var ErrEmailAlreadyExist = fmt.Errorf("ErrEmailAlreadyExist")
var ErrAwsInternalFailure = fmt.Errorf("ErrAwsInternalFailure")
var ErrAwsTooManyRequests = fmt.Errorf("ErrAwsTooManyRequests")
var ErrAwsFailedCreateAccount = fmt.Errorf("ErrAwsFailedCreateAccount")

func (o *accountAssignOptions) createAccount(seedVal int64) (*organizations.DescribeCreateAccountStatusOutput, error) {

rand.Seed(seedVal)
randStr := RandomString(6)
randStr := RandomString(rand.New(rand.NewSource(seedVal)), 6)
accountName := "osd-creds-mgmt+" + randStr
email := accountName + "@redhat.com"

Expand Down Expand Up @@ -382,12 +390,12 @@ func (o *accountAssignOptions) createAccount(seedVal int64) (*organizations.Desc
return accountStatus, nil
}

func RandomString(n int) string {
func RandomString(r *rand.Rand, length int) string {
var letters = []byte("abcdefghijklmnopqrstuvwxyz0123456789")

s := make([]byte, n)
s := make([]byte, length)
for i := range s {
s[i] = letters[rand.Intn(len(letters))] //#nosec G404 -- math/rand is not used for a secret here, hence it's okay
s[i] = letters[r.Intn(len(letters))] //#nosec G404 -- math/rand is not used for a secret here, hence it's okay
}
return string(s)
}
Expand Down
95 changes: 54 additions & 41 deletions cmd/account/mgmt/account-assign_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package mgmt

import (
"fmt"
"github.com/aws/aws-sdk-go/service/sts"
"math/rand"
"testing"

Expand Down Expand Up @@ -91,14 +92,16 @@ func TestFindUntaggedAccount(t *testing.T) {
var genericAWSError error = fmt.Errorf("Generic AWS error")

testData := []struct {
name string
accountsList []string
tags map[string]string
suspendCheck bool
accountStatus string
expectedAccountId string
expectErr error
expectedAWSError error
name string
accountsList []string
tags map[string]string
suspendCheck bool
accountStatus string
callerIdentityAccount string
expectedGetCallerIdentityErr error
expectedAccountId string
expectErr error
expectedListAccountsForParentErr error
}{
{
name: "test for untagged account present",
Expand All @@ -107,50 +110,54 @@ func TestFindUntaggedAccount(t *testing.T) {
tags: map[string]string{},
suspendCheck: true,
accountStatus: organizations.AccountStatusActive,
expectErr: nil,
expectedAWSError: nil,
},
{
name: "test for only partially tagged accounts present",
accountsList: []string{"111111111111"},
expectedAccountId: "",
name: "test for only payer account present",
accountsList: []string{"222222222222"},
callerIdentityAccount: "222222222222",
expectErr: ErrNoUntaggedAccounts,
},
{
name: "test for only partially tagged accounts present",
accountsList: []string{"111111111111"},
tags: map[string]string{
"claimed": "true",
},
suspendCheck: false,
expectErr: ErrNoUntaggedAccounts,
expectedAWSError: nil,
expectErr: ErrNoUntaggedAccounts,
},
{
name: "test for only tagged accounts present",
accountsList: []string{"111111111111"},
expectedAccountId: "",
name: "test for no untagged accounts present",
accountsList: []string{},
expectErr: ErrNoUntaggedAccounts,
},
{
name: "test for only tagged accounts present",
accountsList: []string{"111111111111"},
tags: map[string]string{
"owner": "randuser",
"claimed": "true",
},
suspendCheck: false,
expectErr: ErrNoUntaggedAccounts,
expectedAWSError: nil,
expectErr: ErrNoUntaggedAccounts,
},
{
name: "test for AWS list accounts error",
accountsList: []string{},
expectedAccountId: "",
tags: nil,
suspendCheck: false,
expectErr: genericAWSError,
expectedAWSError: genericAWSError,
name: "test for AWS list accounts error",
accountsList: []string{},
expectErr: genericAWSError,
expectedListAccountsForParentErr: genericAWSError,
},
{
name: "test for suspended account error",
accountsList: []string{"111111111111"},
expectedAccountId: "",
tags: map[string]string{},
suspendCheck: true,
accountStatus: organizations.AccountStatusSuspended,
expectErr: ErrNoUntaggedAccounts,
expectedAWSError: nil,
name: "test for AWS get caller identity error",
accountsList: []string{"111111111111"},
expectErr: genericAWSError,
expectedGetCallerIdentityErr: genericAWSError,
},
{
name: "test for suspended account error",
accountsList: []string{"111111111111"},
tags: map[string]string{},
suspendCheck: true,
accountStatus: organizations.AccountStatusSuspended,
expectErr: ErrNoUntaggedAccounts,
},
}

Expand Down Expand Up @@ -194,7 +201,7 @@ func TestFindUntaggedAccount(t *testing.T) {
ResourceId: &test.accountsList[0],
}).Return(
awsOutputTags,
test.expectedAWSError,
test.expectedListAccountsForParentErr,
)
}

Expand All @@ -215,9 +222,16 @@ func TestFindUntaggedAccount(t *testing.T) {

mockAWSClient.EXPECT().ListAccountsForParent(gomock.Any()).Return(
awsOutputAccounts,
test.expectedAWSError,
test.expectedListAccountsForParentErr,
)

if test.expectedListAccountsForParentErr == nil && len(test.accountsList) > 0 {
mockAWSClient.EXPECT().GetCallerIdentity(gomock.Any()).Return(
&sts.GetCallerIdentityOutput{Account: aws.String(test.callerIdentityAccount)},
test.expectedGetCallerIdentityErr,
)
}

returnValue, err := o.findUntaggedAccount(rootOuId)
if test.expectErr != err {
t.Errorf("expected error %s and got %s", test.expectErr, err)
Expand All @@ -235,8 +249,7 @@ func TestCreateAccount(t *testing.T) {
mockAWSClient := mock.NewMockClient(mocks.mockCtrl)

seed := int64(1)
rand.Seed(seed)
randStr := RandomString(6)
randStr := RandomString(rand.New(rand.NewSource(seed)), 6)
accountName := "osd-creds-mgmt+" + randStr
email := accountName + "@redhat.com"

Expand Down
1 change: 0 additions & 1 deletion cmd/account/servicequotas/describe.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"errors"
"fmt"

//"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/servicequotas"

"github.com/spf13/cobra"
Expand Down

0 comments on commit 88b66ea

Please sign in to comment.