Skip to content

Commit

Permalink
OCM-12660 | feat: Refactor op + acc roles sharedvpc policies
Browse files Browse the repository at this point in the history
  • Loading branch information
hunterkepley committed Nov 20, 2024
1 parent 188084e commit 1a07a4b
Show file tree
Hide file tree
Showing 15 changed files with 68 additions and 76 deletions.
9 changes: 5 additions & 4 deletions cmd/create/accountroles/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,6 @@ func run(cmd *cobra.Command, argv []string) {
os.Exit(1)
}

if !cmd.Flag("hosted-cp").Changed {
rosa.HostedClusterOnlyFlag(r, cmd, route53RoleArnFlag)
rosa.HostedClusterOnlyFlag(r, cmd, vpcEndpointRoleArnFlag)
}
isHcpSharedVpc, err := validateSharedVpcInputs(args.hostedCP, args.vpcEndpointRoleArn, args.route53RoleArn)
if err != nil {
r.Reporter.Errorf("%s", err)
Expand Down Expand Up @@ -409,6 +405,11 @@ func run(cmd *cobra.Command, argv []string) {
isHostedCPValueSet = true
}

if !createHostedCP {
rosa.HostedClusterOnlyFlag(r, cmd, route53RoleArnFlag)
rosa.HostedClusterOnlyFlag(r, cmd, vpcEndpointRoleArnFlag)
}

rolesCreator, createRoles := initCreator(r, managedPolicies, createClassic, createHostedCP,
isClassicValueSet, isHostedCPValueSet)
if !createRoles {
Expand Down
10 changes: 5 additions & 5 deletions cmd/create/accountroles/creators.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ func (up *unmanagedPoliciesCreator) printCommands(r *rosa.Runtime, input *accoun

createPolicy := buildCreatePolicyCommand(policyName, policyDocument, iamTags, input.path)

policyARN := aws.GetPolicyARN(r.Creator.Partition, input.accountID, accRoleName, input.path)
policyARN := aws.GetPolicyArnWithSuffix(r.Creator.Partition, input.accountID, accRoleName, input.path)

attachRolePolicy := buildAttachRolePolicyCommand(accRoleName, policyARN)

Expand Down Expand Up @@ -279,7 +279,7 @@ func createRoleUnmanagedPolicy(r *rosa.Runtime, input *accountRolesCreationInput

policyPermissionDetail := aws.GetPolicyDetails(input.policies, filename)

policyARN := aws.GetPolicyARN(r.Creator.Partition, r.Creator.AccountID, accRoleName, input.path)
policyARN := aws.GetPolicyArnWithSuffix(r.Creator.Partition, r.Creator.AccountID, accRoleName, input.path)

r.Reporter.Debugf("Creating permission policy '%s'", policyARN)
if args.forcePolicyCreation {
Expand Down Expand Up @@ -467,9 +467,9 @@ func attachHcpSharedVpcPolicy(r *rosa.Runtime, sharedVpcRoleArn string, roleName
return err
}
policyName := fmt.Sprintf(aws.AssumeRolePolicyPrefix, userProvidedRoleName)
policyArn := aws.GetPolicyARN(r.Creator.Partition, r.Creator.AccountID, roleName, path)
policyArn, err = r.AWSClient.EnsurePolicyWithName(policyArn, policyDetails,
defaultPolicyVersion, policyTags, path, policyName)
policyArn := aws.GetPolicyArn(r.Creator.Partition, r.Creator.AccountID, policyName, path)
policyArn, err = r.AWSClient.EnsurePolicy(policyArn, policyDetails,
defaultPolicyVersion, policyTags, path)
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions cmd/create/accountroles/creators_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ var _ = Describe("Accountroles", Ordered, func() {
mockClient.EXPECT().AttachRolePolicy(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).Times(6)
mockClient.EXPECT().EnsureRole(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(),
gomock.Any(), gomock.Any(), gomock.Any()).Return("arn::role:role-123", nil).AnyTimes()
mockClient.EXPECT().EnsurePolicyWithName(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(),
gomock.Any(), gomock.Any()).Return("arn::policy:123", nil).Times(2)
mockClient.EXPECT().EnsurePolicy(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(),
gomock.Any()).Return("arn::policy:123", nil).Times(2)

r := rosa.NewRuntime()
r.AWSClient = mockClient
Expand Down
4 changes: 2 additions & 2 deletions cmd/create/ocmrole/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ func buildCommands(prefix string, roleName string, rolePath string, permissionsB
return "", err
}
} else {
policyARN = aws.GetPolicyARN(creator.Partition, creator.AccountID, roleName, rolePath)
policyARN = aws.GetPolicyArnWithSuffix(creator.Partition, creator.AccountID, roleName, rolePath)
}
attachRolePolicy := awscb.NewIAMCommandBuilder().
SetCommand(awscb.AttachRolePolicy).
Expand Down Expand Up @@ -448,7 +448,7 @@ func createRoles(r *rosa.Runtime, prefix string, roleName string, rolePath strin
return "", err
}
} else {
policyARN = aws.GetPolicyARN(r.Creator.Partition, r.Creator.AccountID, roleName, rolePath)
policyARN = aws.GetPolicyArnWithSuffix(r.Creator.Partition, r.Creator.AccountID, roleName, rolePath)
}
if !confirm.Prompt(true, "Create the '%s' role?", roleName) {
os.Exit(0)
Expand Down
27 changes: 17 additions & 10 deletions cmd/create/operatorroles/by_prefix.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ func createRolesByPrefix(r *rosa.Runtime, prefix string, permissionsBoundary str
}

var policyArn string
var policyArns []string
filename := aws.GetOperatorPolicyKey(credrequest, hostedCPPolicies, isSharedVpc)
if managedPolicies {
policyArn, err = aws.GetManagedPolicyARN(policies, filename)
Expand All @@ -312,16 +313,19 @@ func createRolesByPrefix(r *rosa.Runtime, prefix string, permissionsBoundary str
}
if isSharedVpc {
if credrequest == aws.IngressOperatorCloudCredentialsRoleType {
policyArn, err = getHcpSharedVpcPolicy(r, sharedVpcRoleArn, roleName, operator,
path, defaultPolicyVersion)
sharedVpcPolicyArn, err := getHcpSharedVpcPolicy(r, sharedVpcRoleArn, roleName, defaultPolicyVersion)
if err != nil {
return err
}
policyArns = append(policyArns, sharedVpcPolicyArn)
} else if credrequest == aws.ControlPlaneCloudCredentialsRoleType {
policyArn, err = getHcpSharedVpcPolicy(r, sharedVpcEndpointRoleArn, roleName,
operator, path, defaultPolicyVersion)
if err != nil {
return err
for _, arn := range []string{sharedVpcEndpointRoleArn, sharedVpcRoleArn} {
sharedVpcPolicyArn, err := getHcpSharedVpcPolicy(r, arn, path,
defaultPolicyVersion)
if err != nil {
return err
}
policyArns = append(policyArns, sharedVpcPolicyArn)
}
}
}
Expand Down Expand Up @@ -363,6 +367,7 @@ func createRolesByPrefix(r *rosa.Runtime, prefix string, permissionsBoundary str
}
}
}
policyArns = append(policyArns, policyArn)

policyDetails := aws.GetPolicyDetails(policies, "operator_iam_role_policy")
policy, err := aws.GenerateOperatorRolePolicyDocByOidcEndpointUrl(r.Creator.Partition, oidcEndpointUrl,
Expand Down Expand Up @@ -393,10 +398,12 @@ func createRolesByPrefix(r *rosa.Runtime, prefix string, permissionsBoundary str
r.Reporter.Infof("Created role '%s' with ARN '%s'", roleName, roleARN)
}

r.Reporter.Debugf("Attaching permission policy '%s' to role '%s'", policyArn, roleName)
err = r.AWSClient.AttachRolePolicy(r.Reporter, roleName, policyArn)
if err != nil {
return err
for _, arn := range policyArns {
r.Reporter.Debugf("Attaching permission policy '%s' to role '%s'", arn, roleName)
err = r.AWSClient.AttachRolePolicy(r.Reporter, roleName, arn)
if err != nil {
return err
}
}
}

Expand Down
32 changes: 18 additions & 14 deletions cmd/create/operatorroles/common_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"fmt"

awsCommonUtils "github.com/openshift-online/ocm-common/pkg/aws/utils"
cmv1 "github.com/openshift-online/ocm-sdk-go/clustersmgmt/v1"
errors "github.com/zgalor/weberr"

"github.com/openshift/rosa/pkg/aws"
Expand All @@ -26,6 +25,9 @@ const policyDocumentBody = ` \
}
}'`

const policyDetails = "{\n \"Version\": \"2012-10-17\",\n \"Statement\": {\n \"Effect\": \"Allow\",\n " +
"\"Action\": \"sts:AssumeRole\",\n \"Resource\": \"%{shared_vpc_role_arn}\"\n }\n}\n"

func computePolicyARN(creator aws.Creator,
prefix string, namespace string, name string, path string) string {
if prefix == "" {
Expand Down Expand Up @@ -75,18 +77,19 @@ func validateIngressOperatorPolicyOverride(r *rosa.Runtime, policyArn string, sh
return nil
}

func getHcpSharedVpcPolicy(r *rosa.Runtime, roleArn string, roleName string,
operator *cmv1.STSOperator, path string, defaultPolicyVersion string) (string, error) {
policyDetails := "{\n \"Version\": \"2012-10-17\",\n \"Statement\": {\n \"Effect\": \"Allow\",\n " +
"\"Action\": \"sts:AssumeRole\",\n \"Resource\": \"%{shared_vpc_role_arn}\"\n }\n}\n"
policyDetails = aws.InterpolatePolicyDocument(r.Creator.Partition, policyDetails, map[string]string{
func getHcpSharedVpcPolicy(r *rosa.Runtime, roleArn string, path string, defaultPolicyVersion string) (string, error) {

interpolatedPolicyDetails := aws.InterpolatePolicyDocument(r.Creator.Partition, policyDetails, map[string]string{
"shared_vpc_role_arn": roleArn,
})
policy := aws.GetOperatorPolicyARN(r.Creator.Partition, r.Creator.AccountID,
aws.SharedVpcAssumeRolePrefix+"-"+roleName, operator.Namespace(), operator.Name(), path)
userProvidedRoleName, err := aws.GetResourceIdFromARN(roleArn)
if err != nil {
return "", err
}
policyName := fmt.Sprintf(aws.AssumeRolePolicyPrefix, userProvidedRoleName)
policy := aws.GetPolicyArn(r.Creator.Partition, r.Creator.AccountID, policyName, "")

var err error
policyArn, err := r.AWSClient.EnsurePolicy(policy, policyDetails, defaultPolicyVersion,
policyArn, err := r.AWSClient.EnsurePolicy(policy, interpolatedPolicyDetails, defaultPolicyVersion,
map[string]string{tags.RedHatManaged: helper.True}, path)
if err != nil {
return policyArn, err
Expand All @@ -96,16 +99,17 @@ func getHcpSharedVpcPolicy(r *rosa.Runtime, roleArn string, roleName string,

func getHcpSharedVpcPolicyDetails(r *rosa.Runtime, roleArn string, roleName string, iamTags map[string]string,
path string) (string, string) {
policyDetails := aws.InterpolatePolicyDocument(r.Creator.Partition, policyDocumentBody, map[string]string{
"shared_vpc_role_arn": roleArn,
})
interpolatedPolicyDetails := aws.InterpolatePolicyDocument(r.Creator.Partition, policyDocumentBody,
map[string]string{
"shared_vpc_role_arn": roleArn,
})

policyName := aws.SharedVpcAssumeRolePrefix + "-" + roleName

createPolicy := awscb.NewIAMCommandBuilder().
SetCommand(awscb.CreatePolicy).
AddParam(awscb.PolicyName, policyName).
AddParam(awscb.PolicyDocument, policyDetails).
AddParam(awscb.PolicyDocument, interpolatedPolicyDetails).
AddTags(iamTags).
AddParam(awscb.Path, path).
Build()
Expand Down
10 changes: 2 additions & 8 deletions cmd/create/operatorroles/common_utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
cmv1 "github.com/openshift-online/ocm-sdk-go/clustersmgmt/v1"
errors "github.com/zgalor/weberr"

"github.com/openshift/rosa/pkg/aws"
Expand All @@ -25,7 +24,6 @@ var _ = Describe("Create dns domain", func() {
var testRoleName = "test"
var testIamTags = map[string]string{tags.RedHatManaged: aws.TrueString}
var testPath = "/path"
var testOperator *cmv1.STSOperator
var testVersion = "2012-10-17"
var mockClient *aws.MockClient

Expand All @@ -37,10 +35,6 @@ var _ = Describe("Create dns domain", func() {
runtime.AWSClient = mockClient
mockClient.EXPECT().GetCreator().Return(&aws.Creator{Partition: testPartition}, nil)

var err error
testOperator, err = cmv1.NewSTSOperator().Namespace("test").Namespace("test-namespace").Build()
Expect(err).ToNot(HaveOccurred())

creator, err := runtime.AWSClient.GetCreator()
Expect(err).ToNot(HaveOccurred())
runtime.Creator = creator
Expand All @@ -63,14 +57,14 @@ var _ = Describe("Create dns domain", func() {
returnedArn := "arn:aws:iam::123123123123:policy/test"
mockClient.EXPECT().EnsurePolicy(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(),
gomock.Any()).Return(returnedArn, nil)
arn, err := getHcpSharedVpcPolicy(runtime, testArn, testRoleName, testOperator, testPath, testVersion)
arn, err := getHcpSharedVpcPolicy(runtime, testArn, testPath, testVersion)
Expect(err).ToNot(HaveOccurred())
Expect(arn).To(Equal(returnedArn))
})
It("KO: Returns empty policy when fails", func() {
mockClient.EXPECT().EnsurePolicy(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(),
gomock.Any()).Return("", errors.UserErrorf("Failed"))
arn, err := getHcpSharedVpcPolicy(runtime, testArn, testRoleName, testOperator, testPath, testVersion)
arn, err := getHcpSharedVpcPolicy(runtime, testArn, testPath, testVersion)
Expect(err).To(HaveOccurred())
Expect(arn).To(Equal(""))
})
Expand Down
2 changes: 1 addition & 1 deletion cmd/dlt/operatorrole/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ var args struct {

var Cmd = &cobra.Command{
Use: "operator-roles",
Aliases: []string{"operatorrole"},
Aliases: []string{"operatorrole", "operatorroles"},
Short: "Delete Operator Roles",
Long: "Cleans up operator roles of deleted STS cluster.",
Example: ` # Delete Operator roles for cluster named "mycluster"
Expand Down
4 changes: 2 additions & 2 deletions cmd/upgrade/accountroles/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ func upgradeAccountRolePolicies(reporter *rprtr.Object, awsClient aws.Client, pr
continue
}
filename := fmt.Sprintf("sts_%s_permission_policy", file)
policyARN := aws.GetPolicyARN(partition, accountID, roleName, policyPath)
policyARN := aws.GetPolicyArnWithSuffix(partition, accountID, roleName, policyPath)

policyDetails := aws.GetPolicyDetails(policies, filename)
policyARN, err := awsClient.EnsurePolicy(policyARN, policyDetails,
Expand Down Expand Up @@ -333,7 +333,7 @@ func buildCommands(prefix string, partition string, accountID string, isUpgradeN
if isUpgradeNeedForAccountRolePolicies {
for file, role := range aws.AccountRoles {
accRoleName := common.GetRoleName(prefix, role.Name)
policyARN := aws.GetPolicyARN(partition, accountID, accRoleName, policyPath)
policyARN := aws.GetPolicyArnWithSuffix(partition, accountID, accRoleName, policyPath)
_, err := awsClient.IsPolicyExists(policyARN)
policyExists := err == nil
policyName := aws.GetPolicyName(accRoleName)
Expand Down
2 changes: 1 addition & 1 deletion cmd/upgrade/roles/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ func handleAccountRolePolicyARN(

attachedPoliciesDetail := aws.FindAllAttachedPolicyDetails(policiesDetails)

generatedPolicyARN := aws.GetPolicyARN(partition, accountID, roleName, rolePath)
generatedPolicyARN := aws.GetPolicyArnWithSuffix(partition, accountID, roleName, rolePath)
if len(attachedPoliciesDetail) == 0 {
return generatedPolicyARN, nil
}
Expand Down
2 changes: 0 additions & 2 deletions pkg/aws/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,6 @@ type Client interface {
path string) (string, error)
EnsurePolicy(policyArn string, document string, version string, tagList map[string]string,
path string) (string, error)
EnsurePolicyWithName(policyArn string, document string, version string, tagList map[string]string,
path string, policyName string) (string, error)
AttachRolePolicy(reporter *reporter.Object, roleName string, policyARN string) error
CreateOpenIDConnectProvider(issuerURL string, thumbprint string, clusterID string) (string, error)
DeleteOpenIDConnectProvider(providerURL string) error
Expand Down
15 changes: 0 additions & 15 deletions pkg/aws/client_mock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 5 additions & 1 deletion pkg/aws/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -451,10 +451,14 @@ func GetAdminPolicyARN(partition string, accountID string, name string, path str
return getPolicyARN(partition, accountID, GetAdminPolicyName(name), path)
}

func GetPolicyARN(partition string, accountID string, name string, path string) string {
func GetPolicyArnWithSuffix(partition string, accountID string, name string, path string) string {
return getPolicyARN(partition, accountID, GetPolicyName(name), path)
}

func GetPolicyArn(partition string, accountID string, name string, path string) string {
return getPolicyARN(partition, accountID, name, path)
}

func getPolicyARN(partition string, accountID string, name string, path string) string {
str := fmt.Sprintf("arn:%s:iam::%s:policy", partition, accountID)
if path != "" {
Expand Down
5 changes: 0 additions & 5 deletions pkg/aws/policies.go
Original file line number Diff line number Diff line change
Expand Up @@ -325,11 +325,6 @@ func (c *awsClient) EnsurePolicy(policyArn string, document string,
return c.ensurePolicyHelper(policyArn, document, version, tagList, path, false, "")
}

func (c *awsClient) EnsurePolicyWithName(policyArn string, document string,
version string, tagList map[string]string, path string, policyName string) (string, error) {
return c.ensurePolicyHelper(policyArn, document, version, tagList, path, false, policyName)
}

func (c *awsClient) ensurePolicyHelper(policyArn string, document string,
version string, tagList map[string]string, path string, force bool, policyName string) (string, error) {
output, err := c.IsPolicyExists(policyArn)
Expand Down
12 changes: 8 additions & 4 deletions pkg/rosa/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,14 @@ import (
"github.com/spf13/cobra"
)

const hostedCpFlagName = "hosted-cp"

func HostedClusterOnlyFlag(r *Runtime, cmd *cobra.Command, flagName string) {
isFlagSet := cmd.Flags().Changed(flagName)
if isFlagSet {
r.Reporter.Errorf("Setting the `%s` flag is only supported for hosted clusters", flagName)
os.Exit(1)
if cmd.Flag(hostedCpFlagName) == nil || (cmd.Flag(hostedCpFlagName) != nil && !cmd.Flag(hostedCpFlagName).Changed) {
isFlagSet := cmd.Flags().Changed(flagName)
if isFlagSet {
r.Reporter.Errorf("Setting the `%s` flag is only supported for hosted clusters", flagName)
os.Exit(1)
}
}
}

0 comments on commit 1a07a4b

Please sign in to comment.