Skip to content

Commit

Permalink
Add custom vpc support in AWS cloud prepare
Browse files Browse the repository at this point in the history
Signed-off-by: Aswin Suryanarayanan <asuryana@redhat.com>
Signed-off-by: Tom Pantelis <tompantelis@gmail.com>
  • Loading branch information
aswinsuryan authored and tpantelis committed Sep 25, 2024
1 parent fa86cf1 commit e1aa569
Show file tree
Hide file tree
Showing 9 changed files with 323 additions and 103 deletions.
118 changes: 94 additions & 24 deletions pkg/aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/aws-sdk-go-v2/service/ec2/types"
"github.com/pkg/errors"
"github.com/submariner-io/admiral/pkg/reporter"
"github.com/submariner-io/cloud-prepare/pkg/api"
Expand All @@ -40,36 +41,84 @@ const (
messageValidatedPrerequisites = "Validated pre-requisites"
)

type CloudOption func(*awsCloud)

const (
ControlPlaneSecurityGroupIDKey = "controlPlaneSecurityGroupID"
WorkerSecurityGroupIDKey = "workerSecurityGroupID"
PublicSubnetListKey = "PublicSubnetList"
VPCIDKey = "VPCID"
)

func WithControlPlaneSecurityGroup(id string) CloudOption {
return func(cloud *awsCloud) {
cloud.cloudConfig[ControlPlaneSecurityGroupIDKey] = id
}
}

func WithWorkerSecurityGroup(id string) CloudOption {
return func(cloud *awsCloud) {
cloud.cloudConfig[WorkerSecurityGroupIDKey] = id
}
}

func WithPublicSubnetList(id []string) CloudOption {
return func(cloud *awsCloud) {
cloud.cloudConfig[PublicSubnetListKey] = id
}
}

func WithVPCName(name string) CloudOption {
return func(cloud *awsCloud) {
cloud.cloudConfig[VPCIDKey] = name
}
}

type awsCloud struct {
client awsClient.Interface
infraID string
region string
nodeSGSuffix string
controlPlaneSGSuffix string
cloudConfig map[string]interface{}
}

// NewCloud creates a new api.Cloud instance which can prepare AWS for Submariner to be deployed on it.
func NewCloud(client awsClient.Interface, infraID, region string) api.Cloud {
return &awsCloud{
client: client,
infraID: infraID,
region: region,
func NewCloud(client awsClient.Interface, infraID, region string, opts ...CloudOption) api.Cloud {
cloud := &awsCloud{
client: client,
infraID: infraID,
region: region,
cloudConfig: make(map[string]interface{}),
}

for _, opt := range opts {
opt(cloud)
}

return cloud
}

// NewCloudFromConfig creates a new api.Cloud instance based on an AWS configuration
// which can prepare AWS for Submariner to be deployed on it.
func NewCloudFromConfig(cfg *aws.Config, infraID, region string) api.Cloud {
return &awsCloud{
client: ec2.NewFromConfig(*cfg),
infraID: infraID,
region: region,
func NewCloudFromConfig(cfg *aws.Config, infraID, region string, opts ...CloudOption) api.Cloud {
cloud := &awsCloud{
client: ec2.NewFromConfig(*cfg),
infraID: infraID,
region: region,
cloudConfig: make(map[string]interface{}),
}

for _, opt := range opts {
opt(cloud)
}

return cloud
}

// NewCloudFromSettings creates a new api.Cloud instance using the given credentials file and profile
// which can prepare AWS for Submariner to be deployed on it.
func NewCloudFromSettings(credentialsFile, profile, infraID, region string) (api.Cloud, error) {
func NewCloudFromSettings(credentialsFile, profile, infraID, region string, opts ...CloudOption) (api.Cloud, error) {
options := []func(*config.LoadOptions) error{config.WithRegion(region), config.WithSharedConfigProfile(profile)}
if credentialsFile != DefaultCredentialsFile() {
options = append(options, config.WithSharedCredentialsFiles([]string{credentialsFile}))
Expand All @@ -80,7 +129,7 @@ func NewCloudFromSettings(credentialsFile, profile, infraID, region string) (api
return nil, errors.Wrap(err, "error loading default config")
}

return NewCloudFromConfig(&cfg, infraID, region), nil
return NewCloudFromConfig(&cfg, infraID, region, opts...), nil
}

// DefaultCredentialsFile returns the default credentials file name.
Expand All @@ -98,13 +147,30 @@ func (ac *awsCloud) setSuffixes(vpcID string) error {
return nil
}

publicSubnets, err := ac.findPublicSubnets(vpcID, ac.filterByName("{infraID}*-public-{region}*"))
if err != nil {
return errors.Wrapf(err, "unable to find the public subnet")
}
var publicSubnets []types.Subnet

if subnets, exists := ac.cloudConfig[PublicSubnetListKey]; exists {
if subnetIDs, ok := subnets.([]string); ok && len(subnetIDs) > 0 {
for _, id := range subnetIDs {
subnet, err := ac.getSubnetByID(id)
if err != nil {
return errors.Wrapf(err, "unable to find subnet with ID %s", id)
}

publicSubnets = append(publicSubnets, *subnet)
}
} else {
return errors.New("Subnet IDs must be a valid non-empty slice of strings")
}
} else {
publicSubnets, err := ac.findPublicSubnets(vpcID, ac.filterByName("{infraID}*-public-{region}*"))
if err != nil {
return errors.Wrapf(err, "unable to find the public subnet")
}

if len(publicSubnets) == 0 {
return errors.New("no public subnet found")
if len(publicSubnets) == 0 {
return errors.New("no public subnet found")
}
}

pattern := fmt.Sprintf(`%s.*-subnet-public-%s.*`, regexp.QuoteMeta(ac.infraID), regexp.QuoteMeta(ac.region))
Expand Down Expand Up @@ -137,9 +203,11 @@ func (ac *awsCloud) OpenPorts(ports []api.PortSpec, status reporter.Interface) e
return status.Error(err, "unable to retrieve the VPC ID")
}

err = ac.setSuffixes(vpcID)
if err != nil {
return status.Error(err, "unable to retrieve the security group names")
if _, found := ac.cloudConfig[VPCIDKey]; !found {
err = ac.setSuffixes(vpcID)
if err != nil {
return status.Error(err, "unable to retrieve the security group names")
}
}

status.Success(messageRetrievedVPCID, vpcID)
Expand Down Expand Up @@ -180,9 +248,11 @@ func (ac *awsCloud) ClosePorts(status reporter.Interface) error {
return status.Error(err, "unable to retrieve the VPC ID")
}

err = ac.setSuffixes(vpcID)
if err != nil {
return status.Error(err, "unable to retrieve the security group names")
if _, found := ac.cloudConfig[VPCIDKey]; !found {
err = ac.setSuffixes(vpcID)
if err != nil {
return status.Error(err, "unable to retrieve the security group names")
}
}

status.Success(messageRetrievedVPCID, vpcID)
Expand Down
2 changes: 0 additions & 2 deletions pkg/aws/aws_cloud_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ func testOpenPorts() {

JustBeforeEach(func() {
t.expectDescribeVpcs(t.vpcID)
t.expectDescribeVpcsSigs(t.vpcID)
t.expectDescribePublicSubnets(t.subnets...)

retError = t.cloud.OpenPorts([]api.PortSpec{
Expand Down Expand Up @@ -118,7 +117,6 @@ func testClosePorts() {
JustBeforeEach(func() {
t.expectDescribeVpcs(t.vpcID)
t.expectDescribePublicSubnets(t.subnets...)
t.expectDescribeVpcsSigs(t.vpcID)
t.expectDescribePublicSubnetsSigs(t.subnets...)

retError = t.cloud.ClosePorts(reporter.Stdout())
Expand Down
21 changes: 3 additions & 18 deletions pkg/aws/aws_suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ const (
masterSGName = infraID + "-master-sg"
workerSGName = infraID + "-worker-sg"
gatewaySGName = infraID + "-submariner-gw-sg"
providerAWSTagPrefix = "tag:sigs.k8s.io/cluster-api-provider-aws/cluster/"
clusterFilterTagName = "tag:kubernetes.io/cluster/" + infraID
clusterFilterTagNameSigs = "tag:sigs.k8s.io/cluster-api-provider-aws/cluster/" + infraID
clusterFilterTagNameSigs = providerAWSTagPrefix + infraID
)

var internalTrafficDesc = fmt.Sprintf("Should contain %q", internalTraffic)
Expand Down Expand Up @@ -110,24 +111,8 @@ func (f *fakeAWSClientBase) expectDescribeVpcs(vpcID string) {
}, {
Name: ptr.To(clusterFilterTagName),
Values: []string{"owned"},
}}}).Matches))).Return(&ec2.DescribeVpcsOutput{Vpcs: vpcs}, nil).Maybe()
}

func (f *fakeAWSClientBase) expectDescribeVpcsSigs(vpcID string) {
var vpcs []types.Vpc
if vpcID != "" {
vpcs = []types.Vpc{
{
VpcId: ptr.To(vpcID),
},
}
}

f.awsClient.EXPECT().DescribeVpcs(mock.Anything, mock.MatchedBy(((&filtersMatcher{expectedFilters: []types.Filter{{
Name: ptr.To("tag:Name"),
Values: []string{infraID + "-vpc"},
}, {
Name: ptr.To(clusterFilterTagNameSigs),
Name: ptr.To(providerAWSTagPrefix + infraID),
Values: []string{"owned"},
}}}).Matches))).Return(&ec2.DescribeVpcsOutput{Vpcs: vpcs}, nil).Maybe()
}
Expand Down
74 changes: 59 additions & 15 deletions pkg/aws/ocpgwdeployer.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,16 +69,35 @@ func (d *ocpGatewayDeployer) Deploy(input api.GatewayDeployInput, status reporte

status.Success(messageRetrievedVPCID, vpcID)

err = d.aws.setSuffixes(vpcID)
if err != nil {
return status.Error(err, "unable to retrieve the security group names")
if _, found := d.aws.cloudConfig[VPCIDKey]; !found {
err = d.aws.setSuffixes(vpcID)
if err != nil {
return status.Error(err, "unable to retrieve the security group names")
}
}

status.Start(messageValidatePrerequisites)

publicSubnets, err := d.aws.findPublicSubnets(vpcID, d.aws.filterByName("{infraID}*-public-{region}*"))
if err != nil {
return status.Error(err, "unable to find public subnets")
var publicSubnets []types.Subnet

if subnets, exists := d.aws.cloudConfig[PublicSubnetListKey]; exists {
if subnetIDs, ok := subnets.([]string); ok && len(subnetIDs) > 0 {
for _, id := range subnetIDs {
subnet, err := d.aws.getSubnetByID(id)
if err != nil {
return errors.Wrapf(err, "unable to find subnet with ID %s", id)
}

publicSubnets = append(publicSubnets, *subnet)
}
} else {
return errors.New("Subnet IDs must be a valid non-empty slice of strings")
}
} else {
publicSubnets, err = d.aws.findPublicSubnets(vpcID, d.aws.filterByName("{infraID}*-public-{region}*"))
if err != nil {
return status.Error(err, "unable to find public subnets")
}
}

err = d.validateDeployPrerequisites(vpcID, input, publicSubnets)
Expand All @@ -97,9 +116,15 @@ func (d *ocpGatewayDeployer) Deploy(input api.GatewayDeployInput, status reporte

status.Success("Created Submariner gateway security group %s", gatewaySG)

return d.processSubnets(vpcID, gatewaySG, publicSubnets, input, status)
}

func (d *ocpGatewayDeployer) processSubnets(vpcID, gatewaySG string, publicSubnets []types.Subnet,
input api.GatewayDeployInput, status reporter.Interface,
) error {
subnets, err := d.aws.getSubnetsSupportingInstanceType(publicSubnets, d.instanceType)
if err != nil {
return status.Error(err, "unable to create security group")
return status.Error(err, "unable to get subnets supporting instance type")
}

taggedSubnets, _ := filterSubnets(subnets, func(subnet *types.Subnet) (bool, error) {
Expand Down Expand Up @@ -313,9 +338,11 @@ func (d *ocpGatewayDeployer) Cleanup(status reporter.Interface) error {

status.Success(messageRetrievedVPCID, vpcID)

err = d.aws.setSuffixes(vpcID)
if err != nil {
return status.Error(err, "unable to retrieve the security group names")
if _, found := d.aws.cloudConfig[VPCIDKey]; !found {
err = d.aws.setSuffixes(vpcID)
if err != nil {
return status.Error(err, "unable to retrieve the security group names")
}
}

status.Start(messageValidatePrerequisites)
Expand All @@ -327,13 +354,30 @@ func (d *ocpGatewayDeployer) Cleanup(status reporter.Interface) error {

status.Success(messageValidatedPrerequisites)

subnets, err := d.aws.getTaggedPublicSubnets(vpcID)
if err != nil {
return err
var publicSubnets []types.Subnet

if subnets, exists := d.aws.cloudConfig[PublicSubnetListKey]; exists {
if subnetIDs, ok := subnets.([]string); ok && len(subnetIDs) > 0 {
for _, id := range subnetIDs {
subnet, err := d.aws.getSubnetByID(id)
if err != nil {
return errors.Wrapf(err, "unable to find subnet with ID %s", id)
}

publicSubnets = append(publicSubnets, *subnet)
}
} else {
return errors.New("Subnet IDs must be a valid non-empty slice of strings")
}
} else {
publicSubnets, err = d.aws.getTaggedPublicSubnets(vpcID)
if err != nil {
return err
}
}

for i := range subnets {
subnet := &subnets[i]
for i := range publicSubnets {
subnet := &publicSubnets[i]
subnetName := extractName(subnet.Tags)

status.Start("Removing gateway node for public subnet %s", subnetName)
Expand Down
1 change: 0 additions & 1 deletion pkg/aws/ocpgwdeployer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,6 @@ func newGatewayDeployerTestDriver() *gatewayDeployerTestDriver {
t.expectDescribeInstances(instanceImageID)
t.expectDescribeSecurityGroups(workerSGName, workerGroupID)
t.expectDescribePublicSubnets(t.subnets...)
t.expectDescribeVpcsSigs(t.vpcID)
t.expectDescribePublicSubnetsSigs(t.subnets...)

var err error
Expand Down
Loading

0 comments on commit e1aa569

Please sign in to comment.