Skip to content

Commit 7dbda44

Browse files
authored
Improve availability zone selection and validation (#885)
1 parent 5a072d7 commit 7dbda44

File tree

6 files changed

+268
-50
lines changed

6 files changed

+268
-50
lines changed

cli/cmd/lib_cluster_config.go

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -152,10 +152,6 @@ func getInstallClusterConfig(awsCreds AWSCredentials) (*clusterconfig.Config, er
152152
return nil, err
153153
}
154154

155-
if clusterConfig.Spot != nil && *clusterConfig.Spot {
156-
clusterConfig.AutoFillSpot(awsClient)
157-
}
158-
159155
err = clusterConfig.Validate(awsClient)
160156
if err != nil {
161157
if _flagClusterConfig != "" {

pkg/lib/aws/ec2.go

Lines changed: 73 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ import (
2323
"github.com/aws/aws-sdk-go/aws"
2424
"github.com/aws/aws-sdk-go/service/ec2"
2525
"github.com/cortexlabs/cortex/pkg/lib/errors"
26+
"github.com/cortexlabs/cortex/pkg/lib/parallel"
27+
"github.com/cortexlabs/cortex/pkg/lib/sets/strset"
2628
s "github.com/cortexlabs/cortex/pkg/lib/strings"
2729
)
2830

@@ -64,19 +66,85 @@ func (c *Client) SpotInstancePrice(region string, instanceType string) (float64,
6466
return min, nil
6567
}
6668

67-
func (c *Client) GetAvailabilityZones() ([]string, error) {
68-
input := &ec2.DescribeAvailabilityZonesInput{}
69+
func (c *Client) ListAvailabilityZones() (strset.Set, error) {
70+
input := &ec2.DescribeAvailabilityZonesInput{
71+
Filters: []*ec2.Filter{
72+
{
73+
Name: aws.String("region-name"),
74+
Values: []*string{aws.String(c.Region)},
75+
},
76+
{
77+
Name: aws.String("state"),
78+
Values: []*string{aws.String(ec2.AvailabilityZoneStateAvailable)},
79+
},
80+
},
81+
}
82+
6983
result, err := c.EC2().DescribeAvailabilityZones(input)
7084
if err != nil {
7185
return nil, errors.WithStack(err)
7286
}
7387

74-
availabilityZones := []string{}
88+
zones := strset.New()
7589
for _, az := range result.AvailabilityZones {
7690
if az.ZoneName != nil {
77-
availabilityZones = append(availabilityZones, *az.ZoneName)
91+
zones.Add(*az.ZoneName)
7892
}
7993
}
8094

81-
return availabilityZones, nil
95+
return zones, nil
96+
}
97+
98+
func (c *Client) listSupportedAvailabilityZonesSingle(instanceType string) (strset.Set, error) {
99+
input := &ec2.DescribeReservedInstancesOfferingsInput{
100+
InstanceType: &instanceType,
101+
IncludeMarketplace: aws.Bool(false),
102+
Filters: []*ec2.Filter{
103+
{
104+
Name: aws.String("scope"),
105+
Values: []*string{aws.String(ec2.ScopeAvailabilityZone)},
106+
},
107+
},
108+
}
109+
110+
zones := strset.New()
111+
err := c.EC2().DescribeReservedInstancesOfferingsPages(input, func(output *ec2.DescribeReservedInstancesOfferingsOutput, lastPage bool) bool {
112+
for _, offering := range output.ReservedInstancesOfferings {
113+
if offering.AvailabilityZone != nil {
114+
zones.Add(*offering.AvailabilityZone)
115+
}
116+
}
117+
return true
118+
})
119+
120+
if err != nil {
121+
return nil, errors.WithStack(err)
122+
}
123+
124+
return zones, nil
125+
}
126+
127+
func (c *Client) ListSupportedAvailabilityZones(instanceType string, instanceTypes ...string) (strset.Set, error) {
128+
allInstanceTypes := append(instanceTypes, instanceType)
129+
zoneSets := make([]strset.Set, len(allInstanceTypes))
130+
fns := make([]func() error, len(allInstanceTypes))
131+
132+
for i := range allInstanceTypes {
133+
localIdx := i
134+
fns[i] = func() error {
135+
zones, err := c.listSupportedAvailabilityZonesSingle(allInstanceTypes[localIdx])
136+
if err != nil {
137+
return err
138+
}
139+
zoneSets[localIdx] = zones
140+
return nil
141+
}
142+
}
143+
144+
err := parallel.RunFirstErr(fns[0], fns[1:]...)
145+
if err != nil {
146+
return nil, err
147+
}
148+
149+
return strset.Intersection(zoneSets...), nil
82150
}

pkg/lib/sets/strset/strset.go

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package strset
1919
import (
2020
"fmt"
2121
"math"
22+
"sort"
2223
"strings"
2324
)
2425

@@ -159,7 +160,7 @@ func (s Set) String() string {
159160
for item := range s {
160161
v = append(v, fmt.Sprintf("%v", item))
161162
}
162-
return fmt.Sprintf("[\"%s\"]", strings.Join(v, ", "))
163+
return fmt.Sprintf("[%s]", strings.Join(v, ", "))
163164
}
164165

165166
// Slice returns a slice of all items.
@@ -171,6 +172,13 @@ func (s Set) Slice() []string {
171172
return v
172173
}
173174

175+
// List returns a sorted slice of all items.
176+
func (s Set) SliceSorted() []string {
177+
v := s.Slice()
178+
sort.Strings(v)
179+
return v
180+
}
181+
174182
// Merge is like Union, however it modifies the current Set it's applied on
175183
// with the given t Set.
176184
func (s Set) Merge(sets ...Set) {
@@ -181,7 +189,7 @@ func (s Set) Merge(sets ...Set) {
181189
}
182190
}
183191

184-
// Subtract removes the Set items containing in sets from Set s
192+
// Subtract removes the Set items contained in sets from Set s
185193
func (s Set) Subtract(sets ...Set) {
186194
for _, set := range sets {
187195
for item := range set {
@@ -190,6 +198,13 @@ func (s Set) Subtract(sets ...Set) {
190198
}
191199
}
192200

201+
// Remove items until len(s) <= targetLen
202+
func (s Set) Shrink(targetLen int) {
203+
for len(s) > targetLen {
204+
s.Pop()
205+
}
206+
}
207+
193208
// Union is the merger of multiple sets. It returns a new set with all the
194209
// elements present in all the sets that are passed.
195210
func Union(sets ...Set) Set {
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
/*
2+
Copyright 2020 Cortex Labs, Inc.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package clusterconfig
18+
19+
import (
20+
"github.com/cortexlabs/cortex/pkg/lib/aws"
21+
"github.com/cortexlabs/cortex/pkg/lib/sets/strset"
22+
)
23+
24+
var _azBlacklist = strset.New("us-east-1e")
25+
26+
func (cc *Config) validateAvailabilityZones(awsClient *aws.Client) error {
27+
var extraInstances []string
28+
if cc.Spot != nil && *cc.Spot && len(cc.SpotConfig.InstanceDistribution) >= 0 {
29+
for _, instanceType := range cc.SpotConfig.InstanceDistribution {
30+
if instanceType != *cc.InstanceType {
31+
extraInstances = append(extraInstances, instanceType)
32+
}
33+
}
34+
}
35+
36+
if len(cc.AvailabilityZones) == 0 {
37+
if err := cc.setDefaultAvailabilityZones(awsClient, extraInstances...); err != nil {
38+
return err
39+
}
40+
return nil
41+
}
42+
43+
if err := cc.validateUserAvailabilityZones(awsClient, extraInstances...); err != nil {
44+
return err
45+
}
46+
47+
return nil
48+
}
49+
50+
func (cc *Config) setDefaultAvailabilityZones(awsClient *aws.Client, extraInstances ...string) error {
51+
zones, err := awsClient.ListSupportedAvailabilityZones(*cc.InstanceType, extraInstances...)
52+
if err != nil {
53+
// Try again without checking instance types
54+
zones, err = awsClient.ListAvailabilityZones()
55+
if err != nil {
56+
return nil // Let eksctl choose the availability zones
57+
}
58+
}
59+
60+
zones.Subtract(_azBlacklist)
61+
62+
if len(zones) < 2 {
63+
return ErrorNotEnoughDefaultSupportedZones(awsClient.Region, zones, *cc.InstanceType, extraInstances...)
64+
}
65+
66+
// See https://github.com/weaveworks/eksctl/blob/master/pkg/eks/api.go
67+
if awsClient.Region == "us-east-1" {
68+
zones.Shrink(2)
69+
} else {
70+
zones.Shrink(3)
71+
}
72+
73+
cc.AvailabilityZones = zones.SliceSorted()
74+
75+
return nil
76+
}
77+
78+
func (cc *Config) validateUserAvailabilityZones(awsClient *aws.Client, extraInstances ...string) error {
79+
allZones, err := awsClient.ListAvailabilityZones()
80+
if err != nil {
81+
return nil // Skip validation
82+
}
83+
84+
for _, userZone := range cc.AvailabilityZones {
85+
if !allZones.Has(userZone) {
86+
return ErrorInvalidAvailabilityZone(userZone, allZones)
87+
}
88+
}
89+
90+
supportedZones, err := awsClient.ListSupportedAvailabilityZones(*cc.InstanceType, extraInstances...)
91+
if err != nil {
92+
// Skip validation instance-based validation
93+
supportedZones = strset.Difference(allZones, _azBlacklist)
94+
}
95+
96+
for _, userZone := range cc.AvailabilityZones {
97+
if !supportedZones.Has(userZone) {
98+
return ErrorUnsupportedAvailabilityZone(userZone, *cc.InstanceType, extraInstances...)
99+
}
100+
}
101+
102+
return nil
103+
}

pkg/types/clusterconfig/clusterconfig.go

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License.
1717
package clusterconfig
1818

1919
import (
20+
"fmt"
2021
"regexp"
2122
"sort"
2223
"strings"
@@ -30,7 +31,6 @@ import (
3031
"github.com/cortexlabs/cortex/pkg/lib/hash"
3132
"github.com/cortexlabs/cortex/pkg/lib/pointer"
3233
"github.com/cortexlabs/cortex/pkg/lib/prompt"
33-
"github.com/cortexlabs/cortex/pkg/lib/sets/strset"
3434
s "github.com/cortexlabs/cortex/pkg/lib/strings"
3535
"github.com/cortexlabs/cortex/pkg/lib/table"
3636
)
@@ -461,6 +461,8 @@ func (cc *Config) ToAccessConfig() AccessConfig {
461461
}
462462

463463
func (cc *Config) Validate(awsClient *aws.Client) error {
464+
fmt.Print("verifying your configuration...\n\n")
465+
464466
if *cc.MinInstances > *cc.MaxInstances {
465467
return ErrorMinInstancesGreaterThanMax(*cc.MinInstances, *cc.MaxInstances)
466468
}
@@ -481,21 +483,24 @@ func (cc *Config) Validate(awsClient *aws.Client) error {
481483
}
482484
}
483485

484-
if len(cc.AvailabilityZones) > 0 {
485-
zones, err := awsClient.GetAvailabilityZones()
486-
if err != nil {
487-
return err
488-
}
489-
zoneSet := strset.New(zones...)
490-
491-
for _, az := range cc.AvailabilityZones {
492-
if !zoneSet.Has(az) {
493-
return errors.Wrap(ErrorInvalidAvailabilityZone(az, zones), AvailabilityZonesKey)
486+
// instance_distribution cleanup must be performed before availability_zone cleanup
487+
if cc.Spot != nil && *cc.Spot && len(cc.SpotConfig.InstanceDistribution) >= 0 {
488+
cleanedDistribution := []string{*cc.InstanceType}
489+
for _, instanceType := range cc.SpotConfig.InstanceDistribution {
490+
if instanceType != *cc.InstanceType {
491+
cleanedDistribution = append(cleanedDistribution, instanceType)
494492
}
495493
}
494+
cc.SpotConfig.InstanceDistribution = cleanedDistribution
495+
}
496+
497+
if err := cc.validateAvailabilityZones(awsClient); err != nil {
498+
return errors.Wrap(err, AvailabilityZonesKey)
496499
}
497500

498501
if cc.Spot != nil && *cc.Spot {
502+
cc.AutoFillSpot(awsClient)
503+
499504
chosenInstance := aws.InstanceMetadatas[*cc.Region][*cc.InstanceType]
500505
compatibleSpots := CompatibleSpotInstances(awsClient, chosenInstance, cc.SpotConfig.MaxPrice, _spotInstanceDistributionLength)
501506
if len(compatibleSpots) == 0 {
@@ -635,9 +640,7 @@ func CompatibleSpotInstances(awsClient *aws.Client, targetInstance aws.InstanceM
635640

636641
func AutoGenerateSpotConfig(awsClient *aws.Client, spotConfig *SpotConfig, region string, instanceType string) error {
637642
chosenInstance := aws.InstanceMetadatas[region][instanceType]
638-
if len(spotConfig.InstanceDistribution) == 0 {
639-
spotConfig.InstanceDistribution = append(spotConfig.InstanceDistribution, chosenInstance.Type)
640-
643+
if len(spotConfig.InstanceDistribution) == 1 {
641644
compatibleSpots := CompatibleSpotInstances(awsClient, chosenInstance, spotConfig.MaxPrice, _spotInstanceDistributionLength)
642645
if len(compatibleSpots) == 0 {
643646
return errors.Wrap(ErrorNoCompatibleSpotInstanceFound(chosenInstance.Type), InstanceTypeKey)
@@ -646,11 +649,8 @@ func AutoGenerateSpotConfig(awsClient *aws.Client, spotConfig *SpotConfig, regio
646649
for _, instance := range compatibleSpots {
647650
spotConfig.InstanceDistribution = append(spotConfig.InstanceDistribution, instance.Type)
648651
}
649-
} else {
650-
instanceDistributionSet := strset.New(spotConfig.InstanceDistribution...)
651-
instanceDistributionSet.Remove(instanceType)
652-
spotConfig.InstanceDistribution = append([]string{instanceType}, instanceDistributionSet.Slice()...)
653652
}
653+
654654
if spotConfig.MaxPrice == nil {
655655
spotConfig.MaxPrice = &chosenInstance.Price
656656
}

0 commit comments

Comments
 (0)