diff --git a/cmd/create/machinepool/cmd.go b/cmd/create/machinepool/cmd.go index 00b5301c3b..2e0353eab2 100644 --- a/cmd/create/machinepool/cmd.go +++ b/cmd/create/machinepool/cmd.go @@ -24,6 +24,7 @@ import ( "strings" cmv1 "github.com/openshift-online/ocm-sdk-go/clustersmgmt/v1" + "github.com/openshift/rosa/pkg/aws" "github.com/openshift/rosa/pkg/interactive/confirm" "github.com/openshift/rosa/pkg/rosa" "github.com/spf13/cobra" @@ -50,6 +51,7 @@ var args struct { spotMaxPrice string multiAvailabilityZone bool availabilityZone string + subnet string } var Cmd = &cobra.Command{ @@ -165,6 +167,12 @@ func init() { "", "Select availability zone to create a single AZ machine pool for a multi-AZ cluster") + flags.StringVar( + &args.subnet, + "subnet", + "", + "Select subnet to create a single AZ machine pool for BYOVPC cluster") + interactive.AddFlag(flags) } @@ -192,11 +200,29 @@ func run(cmd *cobra.Command, _ []string) { os.Exit(1) } + // Validate flags that are only allowed for BYOVPC cluster + isSubnetSet := cmd.Flags().Changed("subnet") + if !isBYOVPC(cluster) && isSubnetSet { + r.Reporter.Errorf("Setting the `subnet` flag is only allowed for BYOVPC clusters") + os.Exit(1) + } + + if isSubnetSet && isAvailabilityZoneSet { + r.Reporter.Errorf("Setting both `subnet` and `availability-zone` flag is not supported." + + " Please select `subnet` or `availability-zone` to create a single availability zone machine pool") + os.Exit(1) + } + + // Validate `subnet` or `availability-zone` flags are set for a single AZ machine pool if isAvailabilityZoneSet && isMultiAvailabilityZoneSet && args.multiAvailabilityZone { r.Reporter.Errorf("Setting the `availability-zone` flag is only supported for creating a single AZ " + "machine pool in a multi-AZ cluster") os.Exit(1) } + if isSubnetSet && isMultiAvailabilityZoneSet && args.multiAvailabilityZone { + r.Reporter.Errorf("Setting the `subnet` flag is only supported for creating a single AZ machine pool") + os.Exit(1) + } var err error // Machine pool name: @@ -225,6 +251,12 @@ func run(cmd *cobra.Command, _ []string) { os.Exit(1) } + // Allow the user to select subnet for a single AZ BYOVPC cluster + var subnet string + if !cluster.MultiAZ() && isBYOVPC(cluster) { + subnet = getSubnetFromUser(cmd, r, isSubnetSet, cluster.AWS().SubnetIDs()[0]) + } + // Single AZ machine pool for a multi-AZ cluster var multiAZMachinePool bool var availabilityZone string @@ -251,28 +283,35 @@ func run(cmd *cobra.Command, _ []string) { } if !multiAZMachinePool { - availabilityZone = cluster.Nodes().AvailabilityZones()[0] - - if !isAvailabilityZoneSet && interactive.Enabled() { - availabilityZone, err = interactive.GetOption(interactive.Input{ - Question: "AWS availability zone", - Help: cmd.Flags().Lookup("availability-zone").Usage, - Options: cluster.Nodes().AvailabilityZones(), - Default: availabilityZone, - Required: true, - }) - if err != nil { - r.Reporter.Errorf("Expected a valid AWS availability zone: %s", err) - os.Exit(1) - } - } else if isAvailabilityZoneSet { - availabilityZone = args.availabilityZone + // Allow to create a single AZ machine pool providing the subnet + if isBYOVPC(cluster) { + subnet = getSubnetFromUser(cmd, r, isSubnetSet, cluster.AWS().SubnetIDs()[0]) } - if !helper.Contains(cluster.Nodes().AvailabilityZones(), availabilityZone) { - r.Reporter.Errorf("Availability zone '%s' doesn't belong to the cluster's availability zones", - availabilityZone) - os.Exit(1) + // Select availability zone if the user didn't select subnet + if subnet == "" { + availabilityZone = cluster.Nodes().AvailabilityZones()[0] + if !isAvailabilityZoneSet && interactive.Enabled() { + availabilityZone, err = interactive.GetOption(interactive.Input{ + Question: "AWS availability zone", + Help: cmd.Flags().Lookup("availability-zone").Usage, + Options: cluster.Nodes().AvailabilityZones(), + Default: availabilityZone, + Required: true, + }) + if err != nil { + r.Reporter.Errorf("Expected a valid AWS availability zone: %s", err) + os.Exit(1) + } + } else if isAvailabilityZoneSet { + availabilityZone = args.availabilityZone + } + + if !helper.Contains(cluster.Nodes().AvailabilityZones(), availabilityZone) { + r.Reporter.Errorf("Availability zone '%s' doesn't belong to the cluster's availability zones", + availabilityZone) + os.Exit(1) + } } } } @@ -527,10 +566,15 @@ func run(cmd *cobra.Command, _ []string) { } // Create a single AZ machine pool for a multi-AZ cluster - if cluster.MultiAZ() && !multiAZMachinePool { + if cluster.MultiAZ() && !multiAZMachinePool && availabilityZone != "" { mpBuilder.AvailabilityZones(availabilityZone) } + // Create a single AZ machine pool for a BYOVPC cluster + if subnet != "" { + mpBuilder.Subnets(subnet) + } + machinePool, err := mpBuilder.Build() if err != nil { r.Reporter.Errorf("Failed to create machine pool for cluster '%s': %v", clusterKey, err) @@ -650,3 +694,69 @@ func parseTaints(taints string) ([]*cmv1.TaintBuilder, error) { } return taintBuilders, nil } + +func isBYOVPC(cluster *cmv1.Cluster) bool { + return len(cluster.AWS().SubnetIDs()) > 0 +} + +func getSubnetFromUser(cmd *cobra.Command, r *rosa.Runtime, isSubnetSet bool, + clusterSubnetID string) string { + var selectSubnet bool + var subnet string + var err error + + if !isSubnetSet && interactive.Enabled() { + selectSubnet, err = interactive.GetBool(interactive.Input{ + Question: "Select subnet for a single AZ machine pool", + Help: cmd.Flags().Lookup("subnet").Usage, + Default: false, + Required: false, + }) + if err != nil { + r.Reporter.Errorf("Expected a valid value for select subnet for a single AZ machine pool") + os.Exit(1) + } + } else { + subnet = args.subnet + } + + if selectSubnet { + subnetOptions, err := getSubnetOptions(r.AWSClient, clusterSubnetID) + if err != nil { + r.Reporter.Errorf("%s", err) + os.Exit(1) + } + + subnetOption, err := interactive.GetOption(interactive.Input{ + Question: "Subnet ID", + Help: cmd.Flags().Lookup("subnet").Usage, + Options: subnetOptions, + Default: subnetOptions[0], + Required: true, + }) + if err != nil { + r.Reporter.Errorf("Expected a valid AWS subnet: %s", err) + os.Exit(1) + } + subnet = aws.ParseSubnet(subnetOption) + } + + return subnet +} + +// getSubnetOptions gets one of the cluster subnets and returns a slice of formatted VPC's private subnets. +func getSubnetOptions(awsClient aws.Client, clusterSubnetID string) ([]string, error) { + // Fetch VPC's subnets + privateSubnets, err := awsClient.GetVPCPrivateSubnets(clusterSubnetID) + if err != nil { + return nil, err + } + + // Format subnet options + var subnetOptions []string + for _, subnet := range privateSubnets { + subnetOptions = append(subnetOptions, aws.SetSubnetOption(*subnet.SubnetId, *subnet.AvailabilityZone)) + } + + return subnetOptions, nil +} diff --git a/pkg/aws/client.go b/pkg/aws/client.go index 57e1df0933..1ca9cac630 100644 --- a/pkg/aws/client.go +++ b/pkg/aws/client.go @@ -89,6 +89,7 @@ type Client interface { GetCreator() (*Creator, error) ValidateSCP(*string, map[string]string) (bool, error) GetSubnetIDs() ([]*ec2.Subnet, error) + GetVPCPrivateSubnets(subnetID string) ([]*ec2.Subnet, error) ValidateQuota() (bool, error) TagUserRegion(username string, region string) error GetClusterRegionTagForUser(username string) (string, error) @@ -353,6 +354,132 @@ func (c *awsClient) GetSubnetIDs() ([]*ec2.Subnet, error) { return c.getSubnetIDs(&ec2.DescribeSubnetsInput{}) } +func (c *awsClient) GetVPCPrivateSubnets(subnetID string) ([]*ec2.Subnet, error) { + subnets, err := c.getVPCSubnets(subnetID) + if err != nil { + return nil, err + } + + return c.filterVPCsPrivateSubnets(subnets) +} + +// getVPCSubnets gets a subnet ID and fetches all the subnets that belong to the same VPC as the provided subnet. +func (c *awsClient) getVPCSubnets(subnetID string) ([]*ec2.Subnet, error) { + // Fetch the subnet details + subnets, err := c.getSubnetIDs(&ec2.DescribeSubnetsInput{ + Filters: []*ec2.Filter{ + { + Name: aws.String("subnet-id"), + Values: []*string{aws.String(subnetID)}, + }, + }, + }) + if err != nil { + return nil, err + } + if len(subnets) < 1 { + return nil, fmt.Errorf("Failed to get subnet with ID '%s'", subnetID) + } + + // Fetch VPC's subnets + vpcID := subnets[0].VpcId + subnets, err = c.getSubnetIDs(&ec2.DescribeSubnetsInput{ + Filters: []*ec2.Filter{ + { + Name: aws.String("vpc-id"), + Values: []*string{vpcID}, + }, + }, + }) + if err != nil { + return nil, err + } + if len(subnets) < 1 { + return nil, fmt.Errorf("Failed to get the subnets of VPC with ID '%s'", *vpcID) + } + + return subnets, nil +} + +// FilterPrivateSubnets gets a slice of subnets that belongs to the same VPC and filters the private subnets. +// Assumption: subnets - non-empty slice. +func (c *awsClient) filterVPCsPrivateSubnets(subnets []*ec2.Subnet) ([]*ec2.Subnet, error) { + // Fetch VPC route tables + vpcID := subnets[0].VpcId + describeRouteTablesOutput, err := c.ec2Client.DescribeRouteTables(&ec2.DescribeRouteTablesInput{ + Filters: []*ec2.Filter{ + { + Name: aws.String("vpc-id"), + Values: []*string{vpcID}, + }, + }, + }) + if err != nil { + return nil, err + } + if len(describeRouteTablesOutput.RouteTables) < 1 { + return nil, fmt.Errorf("Failed to find VPC '%s' route table", *vpcID) + } + + var privateSubnets []*ec2.Subnet + for _, subnet := range subnets { + isPublic, err := c.isPublicSubnet(subnet.SubnetId, describeRouteTablesOutput.RouteTables) + if err != nil { + return nil, err + } + if !isPublic { + privateSubnets = append(privateSubnets, subnet) + } + } + + if len(privateSubnets) < 1 { + return nil, fmt.Errorf("Failed to find private subnets associated with VPC '%s'", *subnets[0].VpcId) + } + + return privateSubnets, nil +} + +// isPublicSubnet a public subnet is a subnet that's associated with a route table that has a route to an +// internet gateway +func (c *awsClient) isPublicSubnet(subnetID *string, routeTables []*ec2.RouteTable) (bool, error) { + subnetRouteTable, err := c.getSubnetRouteTable(subnetID, routeTables) + if err != nil { + return false, err + } + + for _, route := range subnetRouteTable.Routes { + if strings.Contains(aws.StringValue(route.GatewayId), "igw") { + return true, nil + } + } + + return false, nil +} + +func (c *awsClient) getSubnetRouteTable(subnetID *string, routeTables []*ec2.RouteTable) (*ec2.RouteTable, error) { + // Subnet route table — A route table that's associated with a subnet + for _, routeTable := range routeTables { + for _, association := range routeTable.Associations { + if aws.StringValue(association.SubnetId) == aws.StringValue(subnetID) { + return routeTable, nil + } + } + } + + // A subnet can be explicitly associated with custom route table, or implicitly or explicitly associated with the + // main route table. + for _, routeTable := range routeTables { + for _, association := range routeTable.Associations { + if aws.BoolValue(association.Main) { + return routeTable, nil + } + } + } + + // Each subnet in the VPC must be associated with a route table + return nil, fmt.Errorf("Failed to find subnet '%s' route table", *subnetID) +} + // getSubnetIDs will return the list of subnetsIDs supported for the region picked. // It is possible to pass non-empty `describeSubnetsInput` to filter results. func (c *awsClient) getSubnetIDs(describeSubnetsInput *ec2.DescribeSubnetsInput) ([]*ec2.Subnet, error) {