Skip to content

Commit 456b465

Browse files
authored
Support creating cluster in existing AWS VPC (#1759)
1 parent 38b038b commit 456b465

File tree

10 files changed

+219
-32
lines changed

10 files changed

+219
-32
lines changed

cli/cmd/lib_cluster_config_aws.go

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -288,11 +288,21 @@ func setConfigFieldsFromCached(userClusterConfig *clusterconfig.Config, cachedCl
288288
return clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.TagsKey, s.ObjFlat(cachedClusterConfig.Tags))
289289
}
290290

291-
if len(userClusterConfig.AvailabilityZones) > 0 && !strset.New(userClusterConfig.AvailabilityZones...).IsEqual(strset.New(cachedClusterConfig.AvailabilityZones...)) {
292-
return clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.AvailabilityZonesKey, cachedClusterConfig.AvailabilityZones)
291+
// The user doesn't have to specify AZs in their config
292+
if len(userClusterConfig.AvailabilityZones) > 0 {
293+
if !strset.New(userClusterConfig.AvailabilityZones...).IsEqual(strset.New(cachedClusterConfig.AvailabilityZones...)) {
294+
return clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.AvailabilityZonesKey, cachedClusterConfig.AvailabilityZones)
295+
}
293296
}
294297
userClusterConfig.AvailabilityZones = cachedClusterConfig.AvailabilityZones
295298

299+
if len(userClusterConfig.Subnets) > 0 || len(cachedClusterConfig.Subnets) > 0 {
300+
if !reflect.DeepEqual(userClusterConfig.Subnets, cachedClusterConfig.Subnets) {
301+
return clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.SubnetsKey, cachedClusterConfig.Subnets)
302+
}
303+
}
304+
userClusterConfig.Subnets = cachedClusterConfig.Subnets
305+
296306
if s.Obj(cachedClusterConfig.SSLCertificateARN) != s.Obj(userClusterConfig.SSLCertificateARN) {
297307
return clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.SSLCertificateARNKey, cachedClusterConfig.SSLCertificateARN)
298308
}
@@ -537,6 +547,10 @@ func confirmInstallClusterConfig(clusterConfig *clusterconfig.Config, awsCreds A
537547
fmt.Print(fmt.Sprintf("warning: you've configured the operator load balancer to be internal; you must configure VPC Peering to connect your CLI to your cluster operator (see https://docs.cortex.dev/v/%s/)\n\n", consts.CortexVersionMinor))
538548
}
539549

550+
if len(clusterConfig.Subnets) > 0 {
551+
fmt.Print("warning: you've configured your cluster to be installed in an existing VPC; if your cluster doesn't spin up or function as expected, please double-check your VPC configuration (here are the requirements: https://eksctl.io/usage/vpc-networking/#use-existing-vpc-other-custom-configuration)\n\n")
552+
}
553+
540554
if isSpot && clusterConfig.SpotConfig.OnDemandBackup != nil && !*clusterConfig.SpotConfig.OnDemandBackup {
541555
if *clusterConfig.SpotConfig.OnDemandBaseCapacity == 0 && *clusterConfig.SpotConfig.OnDemandPercentageAboveBaseCapacity == 0 {
542556
fmt.Printf("warning: you've disabled on-demand instances (%s=0 and %s=0); spot instances are not guaranteed to be available so please take that into account for production clusters; see https://docs.cortex.dev/v/%s/ for more information\n\n", clusterconfig.OnDemandBaseCapacityKey, clusterconfig.OnDemandPercentageAboveBaseCapacityKey, consts.CortexVersionMinor)
@@ -573,6 +587,9 @@ func clusterConfigConfirmationStr(clusterConfig clusterconfig.Config, awsCreds A
573587
if len(clusterConfig.AvailabilityZones) > 0 {
574588
items.Add(clusterconfig.AvailabilityZonesUserKey, clusterConfig.AvailabilityZones)
575589
}
590+
for _, subnetConfig := range clusterConfig.Subnets {
591+
items.Add("subnet in "+subnetConfig.AvailabilityZone, subnetConfig.SubnetID)
592+
}
576593
items.Add(clusterconfig.BucketUserKey, clusterConfig.Bucket)
577594
items.Add(clusterconfig.ClusterNameUserKey, clusterConfig.ClusterName)
578595

docs/clusters/aws/install.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,16 @@ api_load_balancer_scheme: internet-facing
6262
# note: if using "internal", you must configure VPC Peering to connect your CLI to your cluster operator
6363
operator_load_balancer_scheme: internet-facing
6464

65+
# to install Cortex in an existing VPC, you can provide a list of subnets for your cluster to use
66+
# subnet_visibility (specified above in this file) must match your subnets' visibility
67+
# this is an advanced feature (not recommended for first-time users) and requires your VPC to be configured correctly; see https://eksctl.io/usage/vpc-networking/#use-existing-vpc-other-custom-configuration
68+
# here is an example:
69+
# subnets:
70+
# - availability_zone: us-west-2a
71+
# subnet_id: subnet-060f3961c876872ae
72+
# - availability_zone: us-west-2b
73+
# subnet_id: subnet-0faed05adf6042ab7
74+
6575
# additional tags to assign to AWS resources (all resources will automatically be tagged with cortex.dev/cluster-name: <cluster_name>)
6676
tags: # <string>: <string> map of key/value pairs
6777

manager/generate_eks.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ def apply_worker_settings(nodegroup):
6363
def apply_clusterconfig(nodegroup, config):
6464
clusterconfig_settings = {
6565
"instanceType": config["instance_type"],
66-
"availabilityZones": config["availability_zones"],
6766
"volumeSize": config["instance_volume_size"],
6867
"minSize": config["min_instances"],
6968
"maxSize": config["max_instances"],
@@ -161,7 +160,6 @@ def generate_eks(cluster_config_path):
161160
operator_settings = {
162161
"name": "ng-cortex-operator",
163162
"instanceType": "t3.medium",
164-
"availabilityZones": cluster_config["availability_zones"],
165163
"minSize": 1,
166164
"maxSize": 1,
167165
"desiredCapacity": 1,
@@ -198,10 +196,27 @@ def generate_eks(cluster_config_path):
198196
"tags": cluster_config["tags"],
199197
},
200198
"vpc": {"nat": {"gateway": nat_gateway}},
201-
"availabilityZones": cluster_config["availability_zones"],
202199
"nodeGroups": [operator_nodegroup, worker_nodegroup],
203200
}
204201

202+
if (
203+
len(cluster_config.get("availability_zones", [])) > 0
204+
and len(cluster_config.get("subnets", [])) == 0
205+
):
206+
eks["availabilityZones"] = cluster_config["availability_zones"]
207+
208+
if len(cluster_config.get("subnets", [])) > 0:
209+
eks_subnet_configs = {}
210+
for subnet_config in cluster_config["subnets"]:
211+
eks_subnet_configs[subnet_config["availability_zone"]] = {
212+
"id": subnet_config["subnet_id"]
213+
}
214+
215+
if cluster_config.get("subnet_visibility", "public") == "private":
216+
eks["vpc"]["subnets"] = {"private": eks_subnet_configs}
217+
else:
218+
eks["vpc"]["subnets"] = {"public": eks_subnet_configs}
219+
205220
if cluster_config.get("vpc_cidr", "") != "":
206221
eks["vpc"]["cidr"] = cluster_config["vpc_cidr"]
207222

manager/refresh_cluster_config.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,13 @@ def refresh_yaml(configmap_yaml_path, output_yaml_path):
115115
)
116116
)
117117
asg = asgs[0]
118+
118119
cluster_config["min_instances"] = asg["MinSize"]
119120
cluster_config["max_instances"] = asg["MaxSize"]
120-
cluster_config["availability_zones"] = asg["AvailabilityZones"]
121+
122+
if len(cluster_config.get("subnets", [])) == 0:
123+
cluster_config["availability_zones"] = asg["AvailabilityZones"]
124+
121125
if asg.get("MixedInstancesPolicy") is not None:
122126
launch_template = get_launch_template(
123127
asg["MixedInstancesPolicy"]["LaunchTemplate"]["LaunchTemplateSpecification"][

pkg/lib/configreader/reader.go

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,11 @@ import (
3737
)
3838

3939
type StructFieldValidation struct {
40-
Key string // Required, defaults to json key or "StructField"
41-
StructField string // Required
42-
DefaultField string // Optional. Will set the default to the runtime value of this field
43-
DefaultFieldFunc func(interface{}) interface{} // Optional. Will call the func with the value of DefaultField
40+
Key string // Required, defaults to json key or "StructField"
41+
StructField string // Required
42+
DefaultField string // Optional. Will set the default to the runtime value of this field
43+
DefaultDependentFields []string // Optional. Will be passed in to DefaultDependentFieldsFunc. Dependent fields must be listed first in the `[]*cr.StructFieldValidation`.
44+
DefaultDependentFieldsFunc func([]interface{}) interface{} // Optional. Will be called with DefaultDependentFields
4445

4546
// Provide one of the following:
4647
StringValidation *StringValidation
@@ -94,6 +95,9 @@ type StructListValidation struct {
9495
Required bool
9596
AllowExplicitNull bool
9697
TreatNullAsEmpty bool // If explicit null or if it's top level and the file is empty, treat as empty map
98+
MinLength int
99+
MaxLength int
100+
InvalidLengths []int
97101
CantBeSpecifiedErrStr *string
98102
ShortCircuit bool
99103
}
@@ -399,6 +403,22 @@ func StructList(dest interface{}, inter interface{}, v *StructListValidation) (i
399403
return nil, []error{ErrorInvalidPrimitiveType(inter, PrimTypeList)}
400404
}
401405

406+
if v.MinLength != 0 {
407+
if len(interSlice) < v.MinLength {
408+
return nil, []error{ErrorTooFewElements(v.MinLength)}
409+
}
410+
}
411+
if v.MaxLength != 0 {
412+
if len(interSlice) > v.MaxLength {
413+
return nil, []error{ErrorTooManyElements(v.MaxLength)}
414+
}
415+
}
416+
for _, invalidLength := range v.InvalidLengths {
417+
if len(interSlice) == invalidLength {
418+
return nil, []error{ErrorWrongNumberOfElements(v.InvalidLengths)}
419+
}
420+
}
421+
402422
errs := []error{}
403423
for i, interItem := range interSlice {
404424
val := reflect.New(reflect.ValueOf(dest).Type().Elem().Elem()).Interface()
@@ -538,10 +558,14 @@ func InterfaceStructList(dest interface{}, inter interface{}, v *InterfaceStruct
538558
func updateValidation(validation interface{}, dest interface{}, structFieldValidation *StructFieldValidation) {
539559
if structFieldValidation.DefaultField != "" {
540560
runtimeVal := reflect.ValueOf(dest).Elem().FieldByName(structFieldValidation.DefaultField).Interface()
541-
if structFieldValidation.DefaultFieldFunc != nil {
542-
runtimeVal = structFieldValidation.DefaultFieldFunc(runtimeVal)
543-
}
544561
setField(runtimeVal, validation, "Default")
562+
} else if structFieldValidation.DefaultDependentFieldsFunc != nil {
563+
runtimeVals := make([]interface{}, len(structFieldValidation.DefaultDependentFields))
564+
for i, fieldName := range structFieldValidation.DefaultDependentFields {
565+
runtimeVals[i] = reflect.ValueOf(dest).Elem().FieldByName(fieldName).Interface()
566+
}
567+
val := structFieldValidation.DefaultDependentFieldsFunc(runtimeVals)
568+
setField(val, validation, "Default")
545569
}
546570
}
547571

pkg/lib/configreader/reader_test.go

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -912,10 +912,10 @@ func TestDefaultField(t *testing.T) {
912912
StringValidation: &StringValidation{},
913913
},
914914
{
915-
StructField: "Key3",
916-
DefaultField: "Key2",
917-
DefaultFieldFunc: func(val interface{}) interface{} {
918-
return val.(string) + ".py"
915+
StructField: "Key3",
916+
DefaultDependentFields: []string{"Key2"},
917+
DefaultDependentFieldsFunc: func(vals []interface{}) interface{} {
918+
return vals[0].(string) + ".py"
919919
},
920920
StringValidation: &StringValidation{},
921921
},
@@ -945,10 +945,10 @@ func TestDefaultField(t *testing.T) {
945945
StringValidation: &StringValidation{},
946946
},
947947
{
948-
StructField: "Key3",
949-
DefaultField: "Key1",
950-
DefaultFieldFunc: func(val interface{}) interface{} {
951-
if val.(bool) {
948+
StructField: "Key3",
949+
DefaultDependentFields: []string{"Key1"},
950+
DefaultDependentFieldsFunc: func(vals []interface{}) interface{} {
951+
if vals[0].(bool) {
952952
return "It was true"
953953
}
954954
return "It was false"
@@ -981,10 +981,10 @@ func TestDefaultField(t *testing.T) {
981981
StringValidation: &StringValidation{},
982982
},
983983
{
984-
StructField: "Key1",
985-
DefaultField: "Key2",
986-
DefaultFieldFunc: func(val interface{}) interface{} {
987-
if val.(string) == "key2" {
984+
StructField: "Key1",
985+
DefaultDependentFields: []string{"Key2"},
986+
DefaultDependentFieldsFunc: func(vals []interface{}) interface{} {
987+
if vals[0].(string) == "key2" {
988988
return true
989989
}
990990
return false

pkg/types/clusterconfig/availability_zones.go

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,14 @@ package clusterconfig
1818

1919
import (
2020
"github.com/cortexlabs/cortex/pkg/lib/aws"
21+
"github.com/cortexlabs/cortex/pkg/lib/errors"
2122
"github.com/cortexlabs/cortex/pkg/lib/sets/strset"
23+
s "github.com/cortexlabs/cortex/pkg/lib/strings"
2224
)
2325

2426
var _azBlacklist = strset.New("us-east-1e")
2527

26-
func (cc *Config) validateAvailabilityZones(awsClient *aws.Client) error {
28+
func (cc *Config) setAvailabilityZones(awsClient *aws.Client) error {
2729
if len(cc.AvailabilityZones) == 0 {
2830
if err := cc.setDefaultAvailabilityZones(awsClient); err != nil {
2931
return err
@@ -92,3 +94,28 @@ func (cc *Config) validateUserAvailabilityZones(awsClient *aws.Client, extraInst
9294

9395
return nil
9496
}
97+
98+
// validateSubnets checks the availability zones of the user-provided subnets:
// every zone must be one of the zones returned by
// awsClient.ListAvailabilityZonesInRegion, and no zone may be used by more
// than one subnet. It returns nil without checking anything when no subnets
// are configured, or when the availability zones cannot be listed.
func (cc *Config) validateSubnets(awsClient *aws.Client) error {
99+
if len(cc.Subnets) == 0 {
100+
return nil
101+
}
102+
103+
allZones, err := awsClient.ListAvailabilityZonesInRegion()
104+
if err != nil {
105+
return nil // Skip validation: the check is best-effort, so a failure to list zones is not reported as a config error
106+
}
107+
108+
userZones := strset.New() // availability zones already used by an earlier subnet in the list
109+
110+
for i, subnetConfig := range cc.Subnets {
111+
if !allZones.Has(subnetConfig.AvailabilityZone) {
112+
return errors.Wrap(ErrorInvalidAvailabilityZone(subnetConfig.AvailabilityZone, allZones, *cc.Region), s.Index(i), AvailabilityZoneKey) // wrapped with the subnet's index and the availability zone key
113+
}
114+
if userZones.Has(subnetConfig.AvailabilityZone) {
115+
return ErrorAvailabilityZoneSpecifiedTwice(subnetConfig.AvailabilityZone) // NOTE(review): unlike the error above, this one is not wrapped with the subnet index; confirm whether that is intentional
116+
}
117+
userZones.Add(subnetConfig.AvailabilityZone)
118+
}
119+
120+
return nil
121+
}

pkg/types/clusterconfig/cluster_config_aws.go

Lines changed: 59 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ type Config struct {
6868
SSLCertificateARN *string `json:"ssl_certificate_arn,omitempty" yaml:"ssl_certificate_arn,omitempty"`
6969
Bucket string `json:"bucket" yaml:"bucket"`
7070
SubnetVisibility SubnetVisibility `json:"subnet_visibility" yaml:"subnet_visibility"`
71+
Subnets []*Subnet `json:"subnets,omitempty" yaml:"subnets,omitempty"`
7172
NATGateway NATGateway `json:"nat_gateway" yaml:"nat_gateway"`
7273
APILoadBalancerScheme LoadBalancerScheme `json:"api_load_balancer_scheme" yaml:"api_load_balancer_scheme"`
7374
OperatorLoadBalancerScheme LoadBalancerScheme `json:"operator_load_balancer_scheme" yaml:"operator_load_balancer_scheme"`
@@ -97,6 +98,11 @@ type SpotConfig struct {
9798
OnDemandBackup *bool `json:"on_demand_backup" yaml:"on_demand_backup"`
9899
}
99100

101+
// Subnet is one entry of the cluster config's "subnets" list, which the user
// provides to install the cluster in an existing VPC. It pairs a subnet ID
// with the availability zone that subnet is listed under.
type Subnet struct {
102+
AvailabilityZone string `json:"availability_zone" yaml:"availability_zone"` // zone name, e.g. us-west-2a; used as the key of the generated eksctl subnet map
103+
SubnetID string `json:"subnet_id" yaml:"subnet_id"` // subnet identifier, e.g. subnet-060f3961c876872ae; emitted as the eksctl subnet "id"
104+
}
105+
100106
type InternalConfig struct {
101107
Config
102108

@@ -303,6 +309,25 @@ var UserValidation = &cr.StructValidation{
303309
return SubnetVisibilityFromString(str), nil
304310
},
305311
},
312+
{
313+
StructField: "Subnets",
314+
StructListValidation: &cr.StructListValidation{
315+
AllowExplicitNull: true,
316+
MinLength: 2,
317+
StructValidation: &cr.StructValidation{
318+
StructFieldValidations: []*cr.StructFieldValidation{
319+
{
320+
StructField: "AvailabilityZone",
321+
StringValidation: &cr.StringValidation{},
322+
},
323+
{
324+
StructField: "SubnetID",
325+
StringValidation: &cr.StringValidation{},
326+
},
327+
},
328+
},
329+
},
330+
},
306331
{
307332
StructField: "NATGateway",
308333
StringValidation: &cr.StringValidation{
@@ -311,9 +336,15 @@ var UserValidation = &cr.StructValidation{
311336
Parser: func(str string) (interface{}, error) {
312337
return NATGatewayFromString(str), nil
313338
},
314-
DefaultField: "SubnetVisibility",
315-
DefaultFieldFunc: func(val interface{}) interface{} {
316-
if val.(SubnetVisibility) == PublicSubnetVisibility {
339+
DefaultDependentFields: []string{"SubnetVisibility", "Subnets"},
340+
DefaultDependentFieldsFunc: func(vals []interface{}) interface{} {
341+
subnetVisibility := vals[0].(SubnetVisibility)
342+
subnets := vals[1].([]*Subnet)
343+
344+
if len(subnets) > 0 {
345+
return NoneNATGateway.String()
346+
}
347+
if subnetVisibility == PublicSubnetVisibility {
317348
return NoneNATGateway.String()
318349
}
319350
return SingleNATGateway.String()
@@ -519,7 +550,15 @@ func (cc *Config) Validate(awsClient *aws.Client) error {
519550
return ErrorMinInstancesGreaterThanMax(*cc.MinInstances, *cc.MaxInstances)
520551
}
521552

522-
if cc.SubnetVisibility == PrivateSubnetVisibility && cc.NATGateway == NoneNATGateway {
553+
if len(cc.AvailabilityZones) > 0 && len(cc.Subnets) > 0 {
554+
return ErrorSpecifyOneOrNone(AvailabilityZonesKey, SubnetsKey)
555+
}
556+
557+
if len(cc.Subnets) > 0 && cc.NATGateway != NoneNATGateway {
558+
return ErrorNoNATGatewayWithSubnets()
559+
}
560+
561+
if cc.SubnetVisibility == PrivateSubnetVisibility && cc.NATGateway == NoneNATGateway && len(cc.Subnets) == 0 {
523562
return ErrorNATRequiredWithPrivateSubnetVisibility()
524563
}
525564

@@ -597,8 +636,14 @@ func (cc *Config) Validate(awsClient *aws.Client) error {
597636
}
598637
cc.Tags[ClusterNameTag] = cc.ClusterName
599638

600-
if err := cc.validateAvailabilityZones(awsClient); err != nil {
601-
return errors.Wrap(err, AvailabilityZonesKey)
639+
if len(cc.Subnets) > 0 {
640+
if err := cc.validateSubnets(awsClient); err != nil {
641+
return errors.Wrap(err, SubnetsKey)
642+
}
643+
} else {
644+
if err := cc.setAvailabilityZones(awsClient); err != nil {
645+
return errors.Wrap(err, AvailabilityZonesKey)
646+
}
602647
}
603648

604649
if cc.Spot != nil && *cc.Spot {
@@ -1095,6 +1140,9 @@ func (cc *Config) UserTable() table.KeyValuePairs {
10951140
if len(cc.AvailabilityZones) > 0 {
10961141
items.Add(AvailabilityZonesUserKey, cc.AvailabilityZones)
10971142
}
1143+
for _, subnetConfig := range cc.Subnets {
1144+
items.Add("subnet in "+subnetConfig.AvailabilityZone, subnetConfig.SubnetID)
1145+
}
10981146
items.Add(BucketUserKey, cc.Bucket)
10991147
items.Add(InstanceTypeUserKey, *cc.InstanceType)
11001148
items.Add(MinInstancesUserKey, *cc.MinInstances)
@@ -1184,6 +1232,11 @@ func (cc *Config) TelemetryEvent() map[string]interface{} {
11841232
event["availability_zones._len"] = len(cc.AvailabilityZones)
11851233
event["availability_zones"] = cc.AvailabilityZones
11861234
}
1235+
if len(cc.Subnets) > 0 {
1236+
event["subnets._is_defined"] = true
1237+
event["subnets._len"] = len(cc.Subnets)
1238+
event["subnets"] = cc.Subnets
1239+
}
11871240
if cc.SSLCertificateARN != nil {
11881241
event["ssl_certificate_arn._is_defined"] = true
11891242
}

0 commit comments

Comments
 (0)