diff --git a/go.mod b/go.mod index 14518e3653..192c671d91 100644 --- a/go.mod +++ b/go.mod @@ -40,6 +40,7 @@ require ( github.com/sirupsen/logrus v1.9.3 github.com/spf13/cobra v1.8.0 github.com/spf13/pflag v1.0.6-0.20210604193023-d5e0c0615ace + github.com/stretchr/testify v1.9.0 github.com/zgalor/weberr v0.6.0 golang.org/x/crypto v0.22.0 golang.org/x/text v0.14.0 @@ -169,6 +170,7 @@ require ( github.com/pelletier/go-toml v1.9.5 // indirect github.com/pelletier/go-toml/v2 v2.1.0 // indirect github.com/peterbourgon/diskv v2.0.1+incompatible // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/prometheus/client_model v0.6.1 // indirect github.com/prometheus/common v0.52.2 // indirect github.com/prometheus/procfs v0.13.0 // indirect @@ -185,7 +187,6 @@ require ( github.com/spf13/cast v1.6.0 // indirect github.com/spf13/viper v1.18.2 // indirect github.com/stoewer/go-strcase v1.2.0 // indirect - github.com/stretchr/testify v1.9.0 // indirect github.com/subosito/gotenv v1.6.0 // indirect github.com/valyala/fastjson v1.6.4 // indirect github.com/vincent-petithory/dataurl v1.0.0 // indirect diff --git a/pkg/cloud/services/ec2/instances.go b/pkg/cloud/services/ec2/instances.go index fa4522c566..495f4ae20c 100644 --- a/pkg/cloud/services/ec2/instances.go +++ b/pkg/cloud/services/ec2/instances.go @@ -619,17 +619,10 @@ func (s *Service) runInstance(role string, i *infrav1.Instance) (*infrav1.Instan resources := []string{ec2.ResourceTypeInstance, ec2.ResourceTypeVolume, ec2.ResourceTypeNetworkInterface} for _, r := range resources { spec := &ec2.TagSpecification{ResourceType: aws.String(r)} - - // We need to sort keys for tests to work - keys := make([]string, 0, len(i.Tags)) - for k := range i.Tags { - keys = append(keys, k) - } - sort.Strings(keys) - for _, key := range keys { + for tagKey, tagValue := range i.Tags { spec.Tags = append(spec.Tags, &ec2.Tag{ - Key: aws.String(key), - Value: aws.String(i.Tags[key]), + Key: aws.String(tagKey), + Value: aws.String(tagValue), }) } diff --git a/pkg/cloud/services/ec2/instances_test.go b/pkg/cloud/services/ec2/instances_test.go index 9d235fc00c..74cdb1f2b8 100644 --- a/pkg/cloud/services/ec2/instances_test.go +++ b/pkg/cloud/services/ec2/instances_test.go @@ -19,6 +19,7 @@ package ec2 import ( "context" "encoding/base64" + "fmt" "strings" "testing" @@ -30,6 +31,7 @@ import ( "github.com/google/go-cmp/cmp" . "github.com/onsi/gomega" "github.com/pkg/errors" + "github.com/stretchr/testify/assert" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" @@ -45,6 +47,72 @@ import ( clusterv1 "sigs.k8s.io/cluster-api/api/v1beta1" ) +type runInstancesInputMatcher struct { + runInstancesInput *ec2.RunInstancesInput +} + +func (m runInstancesInputMatcher) Matches(arg interface{}) bool { + riiArg, ok := arg.(*ec2.RunInstancesInput) + if !ok { + return false + } + + if *m.runInstancesInput.ImageId != *riiArg.ImageId { + return false + } + if *m.runInstancesInput.InstanceType != *riiArg.InstanceType { + return false + } + if *m.runInstancesInput.KeyName != *riiArg.KeyName { + return false + } + if *m.runInstancesInput.SubnetId != *riiArg.SubnetId { + return false + } + if *m.runInstancesInput.UserData != *riiArg.UserData { + return false + } + if *m.runInstancesInput.MaxCount != *riiArg.MaxCount { + return false + } + if *m.runInstancesInput.MinCount != *riiArg.MinCount { + return false + } + + if !assert.ElementsMatch(nil, m.runInstancesInput.SecurityGroupIds, riiArg.SecurityGroupIds) { + return false + } + + if len(m.runInstancesInput.TagSpecifications) != len(riiArg.TagSpecifications) { + return false + } + + for _, instanceTagSpec := range m.runInstancesInput.TagSpecifications { + found := false + for _, argTagSpec := range riiArg.TagSpecifications { + if *instanceTagSpec.ResourceType == *argTagSpec.ResourceType { + found = true + if !assert.ElementsMatch(nil, instanceTagSpec.Tags, argTagSpec.Tags) { + return false + } + } + } + if !found { + return false + } + } + + return true +} + +func (m runInstancesInputMatcher) String() string { + return fmt.Sprintf("has the same elements as %v", m.runInstancesInput) +} + +func RunInstancesInputEq(runInstancesInput *ec2.RunInstancesInput) gomock.Matcher { + return runInstancesInputMatcher{runInstancesInput} +} + func TestInstanceIfExists(t *testing.T) { mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() @@ -671,7 +739,7 @@ func TestCreateInstance(t *testing.T) { }, }, nil) m. - RunInstancesWithContext(context.TODO(), &ec2.RunInstancesInput{ + RunInstancesWithContext(context.TODO(), RunInstancesInputEq(&ec2.RunInstancesInput{ ImageId: aws.String("abc"), InstanceType: aws.String("m5.2xlarge"), KeyName: aws.String("default"), @@ -757,7 +825,7 @@ func TestCreateInstance(t *testing.T) { UserData: aws.String(base64.StdEncoding.EncodeToString(userDataCompressed)), MaxCount: aws.Int64(1), MinCount: aws.Int64(1), - }).Return(&ec2.Reservation{ + })).Return(&ec2.Reservation{ Instances: []*ec2.Instance{ { State: &ec2.InstanceState{ @@ -919,7 +987,7 @@ func TestCreateInstance(t *testing.T) { }, }, nil) m. - RunInstancesWithContext(context.TODO(), &ec2.RunInstancesInput{ + RunInstancesWithContext(context.TODO(), RunInstancesInputEq(&ec2.RunInstancesInput{ ImageId: aws.String("abc"), InstanceType: aws.String("m5.2xlarge"), KeyName: aws.String("default"), @@ -1005,7 +1073,7 @@ func TestCreateInstance(t *testing.T) { UserData: aws.String(base64.StdEncoding.EncodeToString(userDataCompressed)), MaxCount: aws.Int64(1), MinCount: aws.Int64(1), - }).Return(&ec2.Reservation{ + })).Return(&ec2.Reservation{ Instances: []*ec2.Instance{ { State: &ec2.InstanceState{ @@ -3337,7 +3405,7 @@ func TestCreateInstance(t *testing.T) { }, expect: func(m *mocks.MockEC2APIMockRecorder) { m. // TODO: Restore these parameters, but with the tags as well - RunInstancesWithContext(context.TODO(), gomock.Eq(&ec2.RunInstancesInput{ + RunInstancesWithContext(context.TODO(), RunInstancesInputEq(&ec2.RunInstancesInput{ ImageId: aws.String("abc"), InstanceType: aws.String("m5.large"), KeyName: aws.String("default"), @@ -3546,7 +3614,7 @@ func TestCreateInstance(t *testing.T) { }, expect: func(m *mocks.MockEC2APIMockRecorder) { m. // TODO: Restore these parameters, but with the tags as well - RunInstancesWithContext(context.TODO(), gomock.Eq(&ec2.RunInstancesInput{ + RunInstancesWithContext(context.TODO(), RunInstancesInputEq(&ec2.RunInstancesInput{ ImageId: aws.String("abc"), InstanceType: aws.String("m5.large"), KeyName: aws.String("default"), @@ -3775,7 +3843,7 @@ func TestCreateInstance(t *testing.T) { }, }, nil) m. // TODO: Restore these parameters, but with the tags as well - RunInstancesWithContext(context.TODO(), gomock.Eq(&ec2.RunInstancesInput{ + RunInstancesWithContext(context.TODO(), RunInstancesInputEq(&ec2.RunInstancesInput{ ImageId: aws.String("abc"), InstanceType: aws.String("m5.large"), KeyName: aws.String("default"), @@ -3966,7 +4034,7 @@ func TestCreateInstance(t *testing.T) { }, expect: func(m *mocks.MockEC2APIMockRecorder) { m. // TODO: Restore these parameters, but with the tags as well - RunInstancesWithContext(context.TODO(), gomock.Eq(&ec2.RunInstancesInput{ + RunInstancesWithContext(context.TODO(), RunInstancesInputEq(&ec2.RunInstancesInput{ ImageId: aws.String("abc"), InstanceType: aws.String("m5.large"), KeyName: aws.String("default"),