From ddca85c9ce42bb6f738f85036a2ab2481d95b034 Mon Sep 17 00:00:00 2001 From: Guilherme Branco Date: Wed, 29 May 2024 11:26:39 -0300 Subject: [PATCH] OCM-8867 | fix: include prefix check for clusters running registered oidc configs --- .gitignore | 3 + cmd/list/operatorroles/cmd.go | 13 ++-- pkg/aws/client.go | 2 +- pkg/aws/client_mock.go | 9 ++- pkg/aws/policies.go | 15 ++-- pkg/aws/policies_test.go | 131 ++++++++++++++++++++++++++++++++++ 6 files changed, 158 insertions(+), 15 deletions(-) diff --git a/.gitignore b/.gitignore index 5dc02f8151..5ab65cc310 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,6 @@ docs/ cover.out /temp dist/ +# for VS Code debug sessions +**/__debug_bin* +**/ginkgo.report diff --git a/cmd/list/operatorroles/cmd.go b/cmd/list/operatorroles/cmd.go index 5071758734..1bddb2ab19 100644 --- a/cmd/list/operatorroles/cmd.go +++ b/cmd/list/operatorroles/cmd.go @@ -112,9 +112,10 @@ func run(cmd *cobra.Command, _ []string) { os.Exit(1) } clusterId = cluster.ID() + args.prefix = cluster.AWS().STS().OperatorRolePrefix() } - operatorsMap, err := r.AWSClient.ListOperatorRoles(args.version, clusterId) + operatorsMap, err := r.AWSClient.ListOperatorRoles(args.version, clusterId, args.prefix) prefixes := helper.MapKeys(operatorsMap) helper.SortStringRespectLength(prefixes) @@ -132,16 +133,18 @@ func run(cmd *cobra.Command, _ []string) { if args.version != "" { noOperatorRolesOutput = fmt.Sprintf("%s in version '%s'", noOperatorRolesOutput, args.version) } + if args.prefix != "" { + if _, ok := operatorsMap[args.prefix]; !ok { + r.Reporter.Infof("No operator roles available for prefix '%s'", args.prefix) + os.Exit(0) + } + } r.Reporter.Infof(noOperatorRolesOutput) os.Exit(0) } if output.HasFlag() { var resource interface{} = operatorsMap if args.prefix != "" { - if _, ok := operatorsMap[args.prefix]; !ok { - r.Reporter.Infof("No operator roles available for prefix '%s'", args.prefix) - os.Exit(0) - } resource = operatorsMap[args.prefix] } err = output.Print(resource) diff --git a/pkg/aws/client.go b/pkg/aws/client.go index 9fb543e1f1..ca9bdf492a 100644 --- a/pkg/aws/client.go +++ b/pkg/aws/client.go @@ -138,7 +138,7 @@ type Client interface { ListUserRoles() ([]Role, error) ListOCMRoles() ([]Role, error) ListAccountRoles(version string) ([]Role, error) - ListOperatorRoles(version string, clusterID string) (map[string][]OperatorRoleDetail, error) + ListOperatorRoles(version string, clusterID string, prefix string) (map[string][]OperatorRoleDetail, error) ListAttachedRolePolicies(roleName string) ([]string, error) ListOidcProviders(targetClusterId string, config *cmv1.OidcConfig) ([]OidcProviderOutput, error) GetRoleByARN(roleARN string) (iamtypes.Role, error) diff --git a/pkg/aws/client_mock.go b/pkg/aws/client_mock.go index bcca70a022..a4e87200a4 100644 --- a/pkg/aws/client_mock.go +++ b/pkg/aws/client_mock.go @@ -5,7 +5,6 @@ // // mockgen -source=client.go -package=aws -destination=client_mock.go // - // Package aws is a generated GoMock package. package aws @@ -1248,18 +1247,18 @@ func (mr *MockClientMockRecorder) ListOidcProviders(targetClusterId, config any) } // ListOperatorRoles mocks base method. -func (m *MockClient) ListOperatorRoles(version, clusterID string) (map[string][]OperatorRoleDetail, error) { +func (m *MockClient) ListOperatorRoles(version, clusterID, prefix string) (map[string][]OperatorRoleDetail, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ListOperatorRoles", version, clusterID) + ret := m.ctrl.Call(m, "ListOperatorRoles", version, clusterID, prefix) ret0, _ := ret[0].(map[string][]OperatorRoleDetail) ret1, _ := ret[1].(error) return ret0, ret1 } // ListOperatorRoles indicates an expected call of ListOperatorRoles. -func (mr *MockClientMockRecorder) ListOperatorRoles(version, clusterID any) *gomock.Call { +func (mr *MockClientMockRecorder) ListOperatorRoles(version, clusterID, prefix any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListOperatorRoles", reflect.TypeOf((*MockClient)(nil).ListOperatorRoles), version, clusterID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListOperatorRoles", reflect.TypeOf((*MockClient)(nil).ListOperatorRoles), version, clusterID, prefix) } // ListSubnets mocks base method. diff --git a/pkg/aws/policies.go b/pkg/aws/policies.go index fb7747a3e4..8a646a74e5 100644 --- a/pkg/aws/policies.go +++ b/pkg/aws/policies.go @@ -870,8 +870,8 @@ func (c *awsClient) ListAccountRoles(version string) ([]Role, error) { return c.mapToAccountRoles(version, roles) } -func (c *awsClient) ListOperatorRoles(version string, - targetClusterId string) (map[string][]OperatorRoleDetail, error) { +func (c *awsClient) ListOperatorRoles(targetVersion string, + targetClusterId string, targetPrefix string) (map[string][]OperatorRoleDetail, error) { operatorMap := map[string][]OperatorRoleDetail{} roles, err := c.ListRoles() @@ -958,7 +958,7 @@ func (c *awsClient) ListOperatorRoles(version string, switch aws.ToString(tag.Key) { case common.OpenShiftVersion: tagValue := aws.ToString(tag.Value) - if version != "" && tagValue != version { + if targetVersion != "" && tagValue != targetVersion { skip = true break } @@ -977,8 +977,15 @@ func (c *awsClient) ListOperatorRoles(version string, for key, list := range operatorMap { if len(list) == 0 { emptyListKeys = append(emptyListKeys, key) - } else if targetClusterId != "" && list[0].ClusterID != targetClusterId { + continue + } + if targetClusterId != "" && list[0].ClusterID != "" && list[0].ClusterID != targetClusterId { + emptyListKeys = append(emptyListKeys, key) + continue + } + if targetPrefix != "" && targetPrefix != key { emptyListKeys = append(emptyListKeys, key) + continue } } for _, key := range emptyListKeys { diff --git a/pkg/aws/policies_test.go b/pkg/aws/policies_test.go index a194e18914..40bcbc7496 100644 --- a/pkg/aws/policies_test.go +++ b/pkg/aws/policies_test.go @@ -19,6 +19,137 @@ import ( "github.com/openshift/rosa/pkg/aws/tags" ) +var _ = Describe("ListOperatorRoles", func() { + var ( + client awsClient + mockIamAPI *mocks.MockIamApiClient + mockCtrl *gomock.Controller + ) + + BeforeEach(func() { + mockCtrl = gomock.NewController(GinkgoT()) + mockIamAPI = mocks.NewMockIamApiClient(mockCtrl) + client = awsClient{ + iamClient: mockIamAPI, + } + }) + + It("Retrieves by target version", func() { + mockIamAPI.EXPECT().ListRoles(gomock.Any(), gomock.Any()).Return( + &iam.ListRolesOutput{ + IsTruncated: false, + Roles: []iamtypes.Role{ + { + RoleName: aws.String("some-role-name-openshift"), + }, + }, + }, nil) + mockIamAPI.EXPECT().ListRoleTags(gomock.Any(), gomock.Any()).Return( + &iam.ListRoleTagsOutput{ + IsTruncated: false, + }, nil) + mockIamAPI.EXPECT().ListAttachedRolePolicies(gomock.Any(), gomock.Any()).Return( + &iam.ListAttachedRolePoliciesOutput{ + IsTruncated: false, + AttachedPolicies: []iamtypes.AttachedPolicy{ + { + PolicyName: aws.String("some-policy-name"), + }, + }, + }, nil) + mockIamAPI.EXPECT().ListPolicyTags(gomock.Any(), gomock.Any()).Return( + &iam.ListPolicyTagsOutput{ + IsTruncated: false, + Tags: []iamtypes.Tag{ + { + Key: aws.String(common.OpenShiftVersion), + Value: aws.String("4.13"), + }, + }, + }, nil) + roles, err := client.ListOperatorRoles("4.13", "", "") + Expect(err).ToNot(HaveOccurred()) + Expect(roles).To(HaveLen(1)) + }) + + It("Retrieves by target cluster ID", func() { + mockIamAPI.EXPECT().ListRoles(gomock.Any(), gomock.Any()).Return( + &iam.ListRolesOutput{ + IsTruncated: false, + Roles: []iamtypes.Role{ + { + RoleName: aws.String("some-role-name-openshift"), + }, + }, + }, nil) + mockIamAPI.EXPECT().ListRoleTags(gomock.Any(), gomock.Any()).Return( + &iam.ListRoleTagsOutput{ + IsTruncated: false, + Tags: []iamtypes.Tag{ + { + Key: aws.String(tags.ClusterID), + Value: aws.String("123"), + }, + }, + }, nil) + mockIamAPI.EXPECT().ListAttachedRolePolicies(gomock.Any(), gomock.Any()).Return( + &iam.ListAttachedRolePoliciesOutput{ + IsTruncated: false, + AttachedPolicies: []iamtypes.AttachedPolicy{ + { + PolicyName: aws.String("some-policy-name"), + }, + }, + }, nil) + mockIamAPI.EXPECT().ListPolicyTags(gomock.Any(), gomock.Any()).Return( + &iam.ListPolicyTagsOutput{ + IsTruncated: false, + Tags: []iamtypes.Tag{ + { + Key: aws.String(common.OpenShiftVersion), + Value: aws.String("4.13"), + }, + }, + }, nil) + roles, err := client.ListOperatorRoles("", "123", "") + Expect(err).ToNot(HaveOccurred()) + Expect(roles).To(HaveLen(1)) + }) + + It("Retrieves by target prefix", func() { + mockIamAPI.EXPECT().ListRoles(gomock.Any(), gomock.Any()).Return( + &iam.ListRolesOutput{ + IsTruncated: false, + Roles: []iamtypes.Role{ + { + RoleName: aws.String("some-role-name-openshift"), + }, + }, + }, nil) + mockIamAPI.EXPECT().ListRoleTags(gomock.Any(), gomock.Any()).Return( + &iam.ListRoleTagsOutput{ + IsTruncated: false, + Tags: []iamtypes.Tag{ + { + Key: aws.String(common.ManagedPolicies), + Value: aws.String("true"), + }}, + }, nil) + mockIamAPI.EXPECT().ListAttachedRolePolicies(gomock.Any(), gomock.Any()).Return( + &iam.ListAttachedRolePoliciesOutput{ + IsTruncated: false, + AttachedPolicies: []iamtypes.AttachedPolicy{ + { + PolicyName: aws.String("some-policy-name"), + }, + }, + }, nil) + roles, err := client.ListOperatorRoles("", "", "some-role-name") + Expect(err).ToNot(HaveOccurred()) + Expect(roles).To(HaveLen(1)) + }) +}) + var _ = Describe("mapToAccountRoles", func() { var ( client awsClient