Skip to content

Commit

Permalink
Merge pull request openshift#765 from oriAdler/select-subnet-for-sing…
Browse files Browse the repository at this point in the history
…le-az-machine-pool

Select a single subnet for a single AZ machine pool - BYOVPC clusters
  • Loading branch information
openshift-ci[bot] authored Jun 30, 2022
2 parents d8e69bb + 8c06c00 commit 8b162b0
Show file tree
Hide file tree
Showing 2 changed files with 258 additions and 21 deletions.
152 changes: 131 additions & 21 deletions cmd/create/machinepool/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"strings"

cmv1 "github.com/openshift-online/ocm-sdk-go/clustersmgmt/v1"
"github.com/openshift/rosa/pkg/aws"
"github.com/openshift/rosa/pkg/interactive/confirm"
"github.com/openshift/rosa/pkg/rosa"
"github.com/spf13/cobra"
Expand All @@ -50,6 +51,7 @@ var args struct {
spotMaxPrice string
multiAvailabilityZone bool
availabilityZone string
subnet string
}

var Cmd = &cobra.Command{
Expand Down Expand Up @@ -165,6 +167,12 @@ func init() {
"",
"Select availability zone to create a single AZ machine pool for a multi-AZ cluster")

flags.StringVar(
&args.subnet,
"subnet",
"",
"Select subnet to create a single AZ machine pool for BYOVPC cluster")

interactive.AddFlag(flags)
}

Expand Down Expand Up @@ -192,11 +200,29 @@ func run(cmd *cobra.Command, _ []string) {
os.Exit(1)
}

// Validate flags that are only allowed for BYOVPC cluster
isSubnetSet := cmd.Flags().Changed("subnet")
if !isBYOVPC(cluster) && isSubnetSet {
r.Reporter.Errorf("Setting the `subnet` flag is only allowed for BYOVPC clusters")
os.Exit(1)
}

if isSubnetSet && isAvailabilityZoneSet {
r.Reporter.Errorf("Setting both `subnet` and `availability-zone` flag is not supported." +
" Please select `subnet` or `availability-zone` to create a single availability zone machine pool")
os.Exit(1)
}

// Validate `subnet` or `availability-zone` flags are set for a single AZ machine pool
if isAvailabilityZoneSet && isMultiAvailabilityZoneSet && args.multiAvailabilityZone {
r.Reporter.Errorf("Setting the `availability-zone` flag is only supported for creating a single AZ " +
"machine pool in a multi-AZ cluster")
os.Exit(1)
}
if isSubnetSet && isMultiAvailabilityZoneSet && args.multiAvailabilityZone {
r.Reporter.Errorf("Setting the `subnet` flag is only supported for creating a single AZ machine pool")
os.Exit(1)
}

var err error
// Machine pool name:
Expand Down Expand Up @@ -225,6 +251,12 @@ func run(cmd *cobra.Command, _ []string) {
os.Exit(1)
}

// Allow the user to select subnet for a single AZ BYOVPC cluster
var subnet string
if !cluster.MultiAZ() && isBYOVPC(cluster) {
subnet = getSubnetFromUser(cmd, r, isSubnetSet, cluster.AWS().SubnetIDs()[0])
}

// Single AZ machine pool for a multi-AZ cluster
var multiAZMachinePool bool
var availabilityZone string
Expand All @@ -251,28 +283,35 @@ func run(cmd *cobra.Command, _ []string) {
}

if !multiAZMachinePool {
availabilityZone = cluster.Nodes().AvailabilityZones()[0]

if !isAvailabilityZoneSet && interactive.Enabled() {
availabilityZone, err = interactive.GetOption(interactive.Input{
Question: "AWS availability zone",
Help: cmd.Flags().Lookup("availability-zone").Usage,
Options: cluster.Nodes().AvailabilityZones(),
Default: availabilityZone,
Required: true,
})
if err != nil {
r.Reporter.Errorf("Expected a valid AWS availability zone: %s", err)
os.Exit(1)
}
} else if isAvailabilityZoneSet {
availabilityZone = args.availabilityZone
// Allow to create a single AZ machine pool providing the subnet
if isBYOVPC(cluster) {
subnet = getSubnetFromUser(cmd, r, isSubnetSet, cluster.AWS().SubnetIDs()[0])
}

if !helper.Contains(cluster.Nodes().AvailabilityZones(), availabilityZone) {
r.Reporter.Errorf("Availability zone '%s' doesn't belong to the cluster's availability zones",
availabilityZone)
os.Exit(1)
// Select availability zone if the user didn't select subnet
if subnet == "" {
availabilityZone = cluster.Nodes().AvailabilityZones()[0]
if !isAvailabilityZoneSet && interactive.Enabled() {
availabilityZone, err = interactive.GetOption(interactive.Input{
Question: "AWS availability zone",
Help: cmd.Flags().Lookup("availability-zone").Usage,
Options: cluster.Nodes().AvailabilityZones(),
Default: availabilityZone,
Required: true,
})
if err != nil {
r.Reporter.Errorf("Expected a valid AWS availability zone: %s", err)
os.Exit(1)
}
} else if isAvailabilityZoneSet {
availabilityZone = args.availabilityZone
}

if !helper.Contains(cluster.Nodes().AvailabilityZones(), availabilityZone) {
r.Reporter.Errorf("Availability zone '%s' doesn't belong to the cluster's availability zones",
availabilityZone)
os.Exit(1)
}
}
}
}
Expand Down Expand Up @@ -527,10 +566,15 @@ func run(cmd *cobra.Command, _ []string) {
}

// Create a single AZ machine pool for a multi-AZ cluster
if cluster.MultiAZ() && !multiAZMachinePool {
if cluster.MultiAZ() && !multiAZMachinePool && availabilityZone != "" {
mpBuilder.AvailabilityZones(availabilityZone)
}

// Create a single AZ machine pool for a BYOVPC cluster
if subnet != "" {
mpBuilder.Subnets(subnet)
}

machinePool, err := mpBuilder.Build()
if err != nil {
r.Reporter.Errorf("Failed to create machine pool for cluster '%s': %v", clusterKey, err)
Expand Down Expand Up @@ -650,3 +694,69 @@ func parseTaints(taints string) ([]*cmv1.TaintBuilder, error) {
}
return taintBuilders, nil
}

func isBYOVPC(cluster *cmv1.Cluster) bool {
return len(cluster.AWS().SubnetIDs()) > 0
}

func getSubnetFromUser(cmd *cobra.Command, r *rosa.Runtime, isSubnetSet bool,
clusterSubnetID string) string {
var selectSubnet bool
var subnet string
var err error

if !isSubnetSet && interactive.Enabled() {
selectSubnet, err = interactive.GetBool(interactive.Input{
Question: "Select subnet for a single AZ machine pool",
Help: cmd.Flags().Lookup("subnet").Usage,
Default: false,
Required: false,
})
if err != nil {
r.Reporter.Errorf("Expected a valid value for select subnet for a single AZ machine pool")
os.Exit(1)
}
} else {
subnet = args.subnet
}

if selectSubnet {
subnetOptions, err := getSubnetOptions(r.AWSClient, clusterSubnetID)
if err != nil {
r.Reporter.Errorf("%s", err)
os.Exit(1)
}

subnetOption, err := interactive.GetOption(interactive.Input{
Question: "Subnet ID",
Help: cmd.Flags().Lookup("subnet").Usage,
Options: subnetOptions,
Default: subnetOptions[0],
Required: true,
})
if err != nil {
r.Reporter.Errorf("Expected a valid AWS subnet: %s", err)
os.Exit(1)
}
subnet = aws.ParseSubnet(subnetOption)
}

return subnet
}

// getSubnetOptions gets one of the cluster subnets and returns a slice of formatted VPC's private subnets.
func getSubnetOptions(awsClient aws.Client, clusterSubnetID string) ([]string, error) {
// Fetch VPC's subnets
privateSubnets, err := awsClient.GetVPCPrivateSubnets(clusterSubnetID)
if err != nil {
return nil, err
}

// Format subnet options
var subnetOptions []string
for _, subnet := range privateSubnets {
subnetOptions = append(subnetOptions, aws.SetSubnetOption(*subnet.SubnetId, *subnet.AvailabilityZone))
}

return subnetOptions, nil
}
127 changes: 127 additions & 0 deletions pkg/aws/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ type Client interface {
GetCreator() (*Creator, error)
ValidateSCP(*string, map[string]string) (bool, error)
GetSubnetIDs() ([]*ec2.Subnet, error)
GetVPCPrivateSubnets(subnetID string) ([]*ec2.Subnet, error)
ValidateQuota() (bool, error)
TagUserRegion(username string, region string) error
GetClusterRegionTagForUser(username string) (string, error)
Expand Down Expand Up @@ -353,6 +354,132 @@ func (c *awsClient) GetSubnetIDs() ([]*ec2.Subnet, error) {
return c.getSubnetIDs(&ec2.DescribeSubnetsInput{})
}

func (c *awsClient) GetVPCPrivateSubnets(subnetID string) ([]*ec2.Subnet, error) {
subnets, err := c.getVPCSubnets(subnetID)
if err != nil {
return nil, err
}

return c.filterVPCsPrivateSubnets(subnets)
}

// getVPCSubnets gets a subnet ID and fetches all the subnets that belong to the same VPC as the provided subnet.
func (c *awsClient) getVPCSubnets(subnetID string) ([]*ec2.Subnet, error) {
// Fetch the subnet details
subnets, err := c.getSubnetIDs(&ec2.DescribeSubnetsInput{
Filters: []*ec2.Filter{
{
Name: aws.String("subnet-id"),
Values: []*string{aws.String(subnetID)},
},
},
})
if err != nil {
return nil, err
}
if len(subnets) < 1 {
return nil, fmt.Errorf("Failed to get subnet with ID '%s'", subnetID)
}

// Fetch VPC's subnets
vpcID := subnets[0].VpcId
subnets, err = c.getSubnetIDs(&ec2.DescribeSubnetsInput{
Filters: []*ec2.Filter{
{
Name: aws.String("vpc-id"),
Values: []*string{vpcID},
},
},
})
if err != nil {
return nil, err
}
if len(subnets) < 1 {
return nil, fmt.Errorf("Failed to get the subnets of VPC with ID '%s'", *vpcID)
}

return subnets, nil
}

// FilterPrivateSubnets gets a slice of subnets that belongs to the same VPC and filters the private subnets.
// Assumption: subnets - non-empty slice.
func (c *awsClient) filterVPCsPrivateSubnets(subnets []*ec2.Subnet) ([]*ec2.Subnet, error) {
// Fetch VPC route tables
vpcID := subnets[0].VpcId
describeRouteTablesOutput, err := c.ec2Client.DescribeRouteTables(&ec2.DescribeRouteTablesInput{
Filters: []*ec2.Filter{
{
Name: aws.String("vpc-id"),
Values: []*string{vpcID},
},
},
})
if err != nil {
return nil, err
}
if len(describeRouteTablesOutput.RouteTables) < 1 {
return nil, fmt.Errorf("Failed to find VPC '%s' route table", *vpcID)
}

var privateSubnets []*ec2.Subnet
for _, subnet := range subnets {
isPublic, err := c.isPublicSubnet(subnet.SubnetId, describeRouteTablesOutput.RouteTables)
if err != nil {
return nil, err
}
if !isPublic {
privateSubnets = append(privateSubnets, subnet)
}
}

if len(privateSubnets) < 1 {
return nil, fmt.Errorf("Failed to find private subnets associated with VPC '%s'", *subnets[0].VpcId)
}

return privateSubnets, nil
}

// isPublicSubnet a public subnet is a subnet that's associated with a route table that has a route to an
// internet gateway
func (c *awsClient) isPublicSubnet(subnetID *string, routeTables []*ec2.RouteTable) (bool, error) {
subnetRouteTable, err := c.getSubnetRouteTable(subnetID, routeTables)
if err != nil {
return false, err
}

for _, route := range subnetRouteTable.Routes {
if strings.Contains(aws.StringValue(route.GatewayId), "igw") {
return true, nil
}
}

return false, nil
}

func (c *awsClient) getSubnetRouteTable(subnetID *string, routeTables []*ec2.RouteTable) (*ec2.RouteTable, error) {
// Subnet route table — A route table that's associated with a subnet
for _, routeTable := range routeTables {
for _, association := range routeTable.Associations {
if aws.StringValue(association.SubnetId) == aws.StringValue(subnetID) {
return routeTable, nil
}
}
}

// A subnet can be explicitly associated with custom route table, or implicitly or explicitly associated with the
// main route table.
for _, routeTable := range routeTables {
for _, association := range routeTable.Associations {
if aws.BoolValue(association.Main) {
return routeTable, nil
}
}
}

// Each subnet in the VPC must be associated with a route table
return nil, fmt.Errorf("Failed to find subnet '%s' route table", *subnetID)
}

// getSubnetIDs will return the list of subnetsIDs supported for the region picked.
// It is possible to pass non-empty `describeSubnetsInput` to filter results.
func (c *awsClient) getSubnetIDs(describeSubnetsInput *ec2.DescribeSubnetsInput) ([]*ec2.Subnet, error) {
Expand Down

0 comments on commit 8b162b0

Please sign in to comment.