Skip to content

Support creating cluster in existing AWS VPC #1759

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Jan 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions cli/cmd/lib_cluster_config_aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,11 +288,21 @@ func setConfigFieldsFromCached(userClusterConfig *clusterconfig.Config, cachedCl
return clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.TagsKey, s.ObjFlat(cachedClusterConfig.Tags))
}

if len(userClusterConfig.AvailabilityZones) > 0 && !strset.New(userClusterConfig.AvailabilityZones...).IsEqual(strset.New(cachedClusterConfig.AvailabilityZones...)) {
return clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.AvailabilityZonesKey, cachedClusterConfig.AvailabilityZones)
// The user doesn't have to specify AZs in their config
if len(userClusterConfig.AvailabilityZones) > 0 {
if !strset.New(userClusterConfig.AvailabilityZones...).IsEqual(strset.New(cachedClusterConfig.AvailabilityZones...)) {
return clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.AvailabilityZonesKey, cachedClusterConfig.AvailabilityZones)
}
}
userClusterConfig.AvailabilityZones = cachedClusterConfig.AvailabilityZones

if len(userClusterConfig.Subnets) > 0 || len(cachedClusterConfig.Subnets) > 0 {
if !reflect.DeepEqual(userClusterConfig.Subnets, cachedClusterConfig.Subnets) {
return clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.SubnetsKey, cachedClusterConfig.Subnets)
}
}
userClusterConfig.Subnets = cachedClusterConfig.Subnets

if s.Obj(cachedClusterConfig.SSLCertificateARN) != s.Obj(userClusterConfig.SSLCertificateARN) {
return clusterconfig.ErrorConfigCannotBeChangedOnUpdate(clusterconfig.SSLCertificateARNKey, cachedClusterConfig.SSLCertificateARN)
}
Expand Down Expand Up @@ -537,6 +547,10 @@ func confirmInstallClusterConfig(clusterConfig *clusterconfig.Config, awsCreds A
fmt.Print(fmt.Sprintf("warning: you've configured the operator load balancer to be internal; you must configure VPC Peering to connect your CLI to your cluster operator (see https://docs.cortex.dev/v/%s/)\n\n", consts.CortexVersionMinor))
}

if len(clusterConfig.Subnets) > 0 {
fmt.Print("warning: you've configured your cluster to be installed in an existing VPC; if your cluster doesn't spin up or function as expected, please double-check your VPC configuration (here are the requirements: https://eksctl.io/usage/vpc-networking/#use-existing-vpc-other-custom-configuration)\n\n")
}

if isSpot && clusterConfig.SpotConfig.OnDemandBackup != nil && !*clusterConfig.SpotConfig.OnDemandBackup {
if *clusterConfig.SpotConfig.OnDemandBaseCapacity == 0 && *clusterConfig.SpotConfig.OnDemandPercentageAboveBaseCapacity == 0 {
fmt.Printf("warning: you've disabled on-demand instances (%s=0 and %s=0); spot instances are not guaranteed to be available so please take that into account for production clusters; see https://docs.cortex.dev/v/%s/ for more information\n\n", clusterconfig.OnDemandBaseCapacityKey, clusterconfig.OnDemandPercentageAboveBaseCapacityKey, consts.CortexVersionMinor)
Expand Down Expand Up @@ -573,6 +587,9 @@ func clusterConfigConfirmationStr(clusterConfig clusterconfig.Config, awsCreds A
if len(clusterConfig.AvailabilityZones) > 0 {
items.Add(clusterconfig.AvailabilityZonesUserKey, clusterConfig.AvailabilityZones)
}
for _, subnetConfig := range clusterConfig.Subnets {
items.Add("subnet in "+subnetConfig.AvailabilityZone, subnetConfig.SubnetID)
}
items.Add(clusterconfig.BucketUserKey, clusterConfig.Bucket)
items.Add(clusterconfig.ClusterNameUserKey, clusterConfig.ClusterName)

Expand Down
10 changes: 10 additions & 0 deletions docs/clusters/aws/install.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,16 @@ api_load_balancer_scheme: internet-facing
# note: if using "internal", you must configure VPC Peering to connect your CLI to your cluster operator
operator_load_balancer_scheme: internet-facing

# to install Cortex in an existing VPC, you can provide a list of subnets for your cluster to use
# subnet_visibility (specified above in this file) must match your subnets' visibility
# this is an advanced feature (not recommended for first-time users) and requires your VPC to be configured correctly; see https://eksctl.io/usage/vpc-networking/#use-existing-vpc-other-custom-configuration
# here is an example:
# subnets:
#   - availability_zone: us-west-2a
#     subnet_id: subnet-060f3961c876872ae
#   - availability_zone: us-west-2b
#     subnet_id: subnet-0faed05adf6042ab7

# additional tags to assign to AWS resources (all resources will automatically be tagged with cortex.dev/cluster-name: <cluster_name>)
tags: # <string>: <string> map of key/value pairs

Expand Down
21 changes: 18 additions & 3 deletions manager/generate_eks.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ def apply_worker_settings(nodegroup):
def apply_clusterconfig(nodegroup, config):
clusterconfig_settings = {
"instanceType": config["instance_type"],
"availabilityZones": config["availability_zones"],
"volumeSize": config["instance_volume_size"],
"minSize": config["min_instances"],
"maxSize": config["max_instances"],
Expand Down Expand Up @@ -161,7 +160,6 @@ def generate_eks(cluster_config_path):
operator_settings = {
"name": "ng-cortex-operator",
"instanceType": "t3.medium",
"availabilityZones": cluster_config["availability_zones"],
"minSize": 1,
"maxSize": 1,
"desiredCapacity": 1,
Expand Down Expand Up @@ -198,10 +196,27 @@ def generate_eks(cluster_config_path):
"tags": cluster_config["tags"],
},
"vpc": {"nat": {"gateway": nat_gateway}},
"availabilityZones": cluster_config["availability_zones"],
"nodeGroups": [operator_nodegroup, worker_nodegroup],
}

if (
len(cluster_config.get("availability_zones", [])) > 0
and len(cluster_config.get("subnets", [])) == 0
):
eks["availabilityZones"] = cluster_config["availability_zones"]

if len(cluster_config.get("subnets", [])) > 0:
eks_subnet_configs = {}
for subnet_config in cluster_config["subnets"]:
eks_subnet_configs[subnet_config["availability_zone"]] = {
"id": subnet_config["subnet_id"]
}

if cluster_config.get("subnet_visibility", "public") == "private":
eks["vpc"]["subnets"] = {"private": eks_subnet_configs}
else:
eks["vpc"]["subnets"] = {"public": eks_subnet_configs}

if cluster_config.get("vpc_cidr", "") != "":
eks["vpc"]["cidr"] = cluster_config["vpc_cidr"]

Expand Down
6 changes: 5 additions & 1 deletion manager/refresh_cluster_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,13 @@ def refresh_yaml(configmap_yaml_path, output_yaml_path):
)
)
asg = asgs[0]

cluster_config["min_instances"] = asg["MinSize"]
cluster_config["max_instances"] = asg["MaxSize"]
cluster_config["availability_zones"] = asg["AvailabilityZones"]

if len(cluster_config.get("subnets", [])) == 0:
cluster_config["availability_zones"] = asg["AvailabilityZones"]

if asg.get("MixedInstancesPolicy") is not None:
launch_template = get_launch_template(
asg["MixedInstancesPolicy"]["LaunchTemplate"]["LaunchTemplateSpecification"][
Expand Down
38 changes: 31 additions & 7 deletions pkg/lib/configreader/reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,11 @@ import (
)

type StructFieldValidation struct {
Key string // Required, defaults to json key or "StructField"
StructField string // Required
DefaultField string // Optional. Will set the default to the runtime value of this field
DefaultFieldFunc func(interface{}) interface{} // Optional. Will call the func with the value of DefaultField
Key string // Required, defaults to json key or "StructField"
StructField string // Required
DefaultField string // Optional. Will set the default to the runtime value of this field
DefaultDependentFields []string // Optional. Will be passed in to DefaultDependentFieldsFunc. Dependent fields must be listed first in the `[]*cr.StructFieldValidation`.
DefaultDependentFieldsFunc func([]interface{}) interface{} // Optional. Will be called with DefaultDependentFields

// Provide one of the following:
StringValidation *StringValidation
Expand Down Expand Up @@ -94,6 +95,9 @@ type StructListValidation struct {
Required bool
AllowExplicitNull bool
TreatNullAsEmpty bool // If explicit null or if it's top level and the file is empty, treat as empty map
MinLength int
MaxLength int
InvalidLengths []int
CantBeSpecifiedErrStr *string
ShortCircuit bool
}
Expand Down Expand Up @@ -399,6 +403,22 @@ func StructList(dest interface{}, inter interface{}, v *StructListValidation) (i
return nil, []error{ErrorInvalidPrimitiveType(inter, PrimTypeList)}
}

if v.MinLength != 0 {
if len(interSlice) < v.MinLength {
return nil, []error{ErrorTooFewElements(v.MinLength)}
}
}
if v.MaxLength != 0 {
if len(interSlice) > v.MaxLength {
return nil, []error{ErrorTooManyElements(v.MaxLength)}
}
}
for _, invalidLength := range v.InvalidLengths {
if len(interSlice) == invalidLength {
return nil, []error{ErrorWrongNumberOfElements(v.InvalidLengths)}
}
}

errs := []error{}
for i, interItem := range interSlice {
val := reflect.New(reflect.ValueOf(dest).Type().Elem().Elem()).Interface()
Expand Down Expand Up @@ -538,10 +558,14 @@ func InterfaceStructList(dest interface{}, inter interface{}, v *InterfaceStruct
// updateValidation computes a runtime default for a field and hands it to
// setField under the name "Default" on the given validation object.
//
// Two mutually exclusive sources are consulted, in this order:
//  1. DefaultField: the current value of that named field on dest is used
//     (optionally transformed by DefaultFieldFunc when it is non-nil).
//  2. DefaultDependentFieldsFunc: the current values of every field named in
//     DefaultDependentFields are collected, in the listed order, and passed to
//     the func; its return value becomes the default.
//
// If neither is configured the validation is left untouched.
//
// dest is accessed via reflect.ValueOf(dest).Elem().FieldByName(...), so it is
// presumably a pointer to the struct being populated — TODO confirm at callers.
//
// NOTE(review): this listing looks like it interleaves removed and added diff
// lines; the DefaultFieldFunc branch may be pre-change code — confirm against
// the merged file.
func updateValidation(validation interface{}, dest interface{}, structFieldValidation *StructFieldValidation) {
if structFieldValidation.DefaultField != "" {
// Read the value that has already been set on dest for the referenced field.
runtimeVal := reflect.ValueOf(dest).Elem().FieldByName(structFieldValidation.DefaultField).Interface()
if structFieldValidation.DefaultFieldFunc != nil {
runtimeVal = structFieldValidation.DefaultFieldFunc(runtimeVal)
}
setField(runtimeVal, validation, "Default")
} else if structFieldValidation.DefaultDependentFieldsFunc != nil {
// Gather the dependent field values positionally so the func can index
// them in the same order as DefaultDependentFields.
runtimeVals := make([]interface{}, len(structFieldValidation.DefaultDependentFields))
for i, fieldName := range structFieldValidation.DefaultDependentFields {
runtimeVals[i] = reflect.ValueOf(dest).Elem().FieldByName(fieldName).Interface()
}
val := structFieldValidation.DefaultDependentFieldsFunc(runtimeVals)
setField(val, validation, "Default")
}
}

Expand Down
24 changes: 12 additions & 12 deletions pkg/lib/configreader/reader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -912,10 +912,10 @@ func TestDefaultField(t *testing.T) {
StringValidation: &StringValidation{},
},
{
StructField: "Key3",
DefaultField: "Key2",
DefaultFieldFunc: func(val interface{}) interface{} {
return val.(string) + ".py"
StructField: "Key3",
DefaultDependentFields: []string{"Key2"},
DefaultDependentFieldsFunc: func(vals []interface{}) interface{} {
return vals[0].(string) + ".py"
},
StringValidation: &StringValidation{},
},
Expand Down Expand Up @@ -945,10 +945,10 @@ func TestDefaultField(t *testing.T) {
StringValidation: &StringValidation{},
},
{
StructField: "Key3",
DefaultField: "Key1",
DefaultFieldFunc: func(val interface{}) interface{} {
if val.(bool) {
StructField: "Key3",
DefaultDependentFields: []string{"Key1"},
DefaultDependentFieldsFunc: func(vals []interface{}) interface{} {
if vals[0].(bool) {
return "It was true"
}
return "It was false"
Expand Down Expand Up @@ -981,10 +981,10 @@ func TestDefaultField(t *testing.T) {
StringValidation: &StringValidation{},
},
{
StructField: "Key1",
DefaultField: "Key2",
DefaultFieldFunc: func(val interface{}) interface{} {
if val.(string) == "key2" {
StructField: "Key1",
DefaultDependentFields: []string{"Key2"},
DefaultDependentFieldsFunc: func(vals []interface{}) interface{} {
if vals[0].(string) == "key2" {
return true
}
return false
Expand Down
29 changes: 28 additions & 1 deletion pkg/types/clusterconfig/availability_zones.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@ package clusterconfig

import (
"github.com/cortexlabs/cortex/pkg/lib/aws"
"github.com/cortexlabs/cortex/pkg/lib/errors"
"github.com/cortexlabs/cortex/pkg/lib/sets/strset"
s "github.com/cortexlabs/cortex/pkg/lib/strings"
)

var _azBlacklist = strset.New("us-east-1e")

func (cc *Config) validateAvailabilityZones(awsClient *aws.Client) error {
func (cc *Config) setAvailabilityZones(awsClient *aws.Client) error {
if len(cc.AvailabilityZones) == 0 {
if err := cc.setDefaultAvailabilityZones(awsClient); err != nil {
return err
Expand Down Expand Up @@ -92,3 +94,28 @@ func (cc *Config) validateUserAvailabilityZones(awsClient *aws.Client, extraInst

return nil
}

// validateSubnets checks the availability zones named in cc.Subnets.
//
// Each subnet's availability zone must be one of the zones returned by
// awsClient.ListAvailabilityZonesInRegion, and no zone may appear on more than
// one subnet. The first violation found is returned as an error; an unknown
// zone is wrapped with the subnet's index and AvailabilityZoneKey.
//
// Nothing is validated (nil is returned) when no subnets are configured, or
// when the zone listing itself fails.
func (cc *Config) validateSubnets(awsClient *aws.Client) error {
	if len(cc.Subnets) == 0 {
		return nil
	}

	regionZones, err := awsClient.ListAvailabilityZonesInRegion()
	if err != nil {
		// Best effort: if the zones can't be listed, skip validation rather than fail.
		return nil
	}

	seenZones := strset.New()
	for idx := range cc.Subnets {
		zone := cc.Subnets[idx].AvailabilityZone

		if !regionZones.Has(zone) {
			invalidZoneErr := ErrorInvalidAvailabilityZone(zone, regionZones, *cc.Region)
			return errors.Wrap(invalidZoneErr, s.Index(idx), AvailabilityZoneKey)
		}

		// Each availability zone may back at most one subnet entry.
		if seenZones.Has(zone) {
			return ErrorAvailabilityZoneSpecifiedTwice(zone)
		}
		seenZones.Add(zone)
	}

	return nil
}
65 changes: 59 additions & 6 deletions pkg/types/clusterconfig/cluster_config_aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ type Config struct {
SSLCertificateARN *string `json:"ssl_certificate_arn,omitempty" yaml:"ssl_certificate_arn,omitempty"`
Bucket string `json:"bucket" yaml:"bucket"`
SubnetVisibility SubnetVisibility `json:"subnet_visibility" yaml:"subnet_visibility"`
Subnets []*Subnet `json:"subnets,omitempty" yaml:"subnets,omitempty"`
NATGateway NATGateway `json:"nat_gateway" yaml:"nat_gateway"`
APILoadBalancerScheme LoadBalancerScheme `json:"api_load_balancer_scheme" yaml:"api_load_balancer_scheme"`
OperatorLoadBalancerScheme LoadBalancerScheme `json:"operator_load_balancer_scheme" yaml:"operator_load_balancer_scheme"`
Expand Down Expand Up @@ -97,6 +98,11 @@ type SpotConfig struct {
OnDemandBackup *bool `json:"on_demand_backup" yaml:"on_demand_backup"`
}

// Subnet pairs a subnet ID with an availability zone. It is the element type
// of Config.Subnets.
type Subnet struct {
// AvailabilityZone is (de)serialized as "availability_zone" in JSON and YAML.
AvailabilityZone string `json:"availability_zone" yaml:"availability_zone"`
// SubnetID is (de)serialized as "subnet_id" in JSON and YAML.
SubnetID string `json:"subnet_id" yaml:"subnet_id"`
}

type InternalConfig struct {
Config

Expand Down Expand Up @@ -303,6 +309,25 @@ var UserValidation = &cr.StructValidation{
return SubnetVisibilityFromString(str), nil
},
},
{
StructField: "Subnets",
StructListValidation: &cr.StructListValidation{
AllowExplicitNull: true,
MinLength: 2,
StructValidation: &cr.StructValidation{
StructFieldValidations: []*cr.StructFieldValidation{
{
StructField: "AvailabilityZone",
StringValidation: &cr.StringValidation{},
},
{
StructField: "SubnetID",
StringValidation: &cr.StringValidation{},
},
},
},
},
},
{
StructField: "NATGateway",
StringValidation: &cr.StringValidation{
Expand All @@ -311,9 +336,15 @@ var UserValidation = &cr.StructValidation{
Parser: func(str string) (interface{}, error) {
return NATGatewayFromString(str), nil
},
DefaultField: "SubnetVisibility",
DefaultFieldFunc: func(val interface{}) interface{} {
if val.(SubnetVisibility) == PublicSubnetVisibility {
DefaultDependentFields: []string{"SubnetVisibility", "Subnets"},
DefaultDependentFieldsFunc: func(vals []interface{}) interface{} {
subnetVisibility := vals[0].(SubnetVisibility)
subnets := vals[1].([]*Subnet)

if len(subnets) > 0 {
return NoneNATGateway.String()
}
if subnetVisibility == PublicSubnetVisibility {
return NoneNATGateway.String()
}
return SingleNATGateway.String()
Expand Down Expand Up @@ -519,7 +550,15 @@ func (cc *Config) Validate(awsClient *aws.Client) error {
return ErrorMinInstancesGreaterThanMax(*cc.MinInstances, *cc.MaxInstances)
}

if cc.SubnetVisibility == PrivateSubnetVisibility && cc.NATGateway == NoneNATGateway {
if len(cc.AvailabilityZones) > 0 && len(cc.Subnets) > 0 {
return ErrorSpecifyOneOrNone(AvailabilityZonesKey, SubnetsKey)
}

if len(cc.Subnets) > 0 && cc.NATGateway != NoneNATGateway {
return ErrorNoNATGatewayWithSubnets()
}

if cc.SubnetVisibility == PrivateSubnetVisibility && cc.NATGateway == NoneNATGateway && len(cc.Subnets) == 0 {
return ErrorNATRequiredWithPrivateSubnetVisibility()
}

Expand Down Expand Up @@ -597,8 +636,14 @@ func (cc *Config) Validate(awsClient *aws.Client) error {
}
cc.Tags[ClusterNameTag] = cc.ClusterName

if err := cc.validateAvailabilityZones(awsClient); err != nil {
return errors.Wrap(err, AvailabilityZonesKey)
if len(cc.Subnets) > 0 {
if err := cc.validateSubnets(awsClient); err != nil {
return errors.Wrap(err, SubnetsKey)
}
} else {
if err := cc.setAvailabilityZones(awsClient); err != nil {
return errors.Wrap(err, AvailabilityZonesKey)
}
}

if cc.Spot != nil && *cc.Spot {
Expand Down Expand Up @@ -1095,6 +1140,9 @@ func (cc *Config) UserTable() table.KeyValuePairs {
if len(cc.AvailabilityZones) > 0 {
items.Add(AvailabilityZonesUserKey, cc.AvailabilityZones)
}
for _, subnetConfig := range cc.Subnets {
items.Add("subnet in "+subnetConfig.AvailabilityZone, subnetConfig.SubnetID)
}
items.Add(BucketUserKey, cc.Bucket)
items.Add(InstanceTypeUserKey, *cc.InstanceType)
items.Add(MinInstancesUserKey, *cc.MinInstances)
Expand Down Expand Up @@ -1184,6 +1232,11 @@ func (cc *Config) TelemetryEvent() map[string]interface{} {
event["availability_zones._len"] = len(cc.AvailabilityZones)
event["availability_zones"] = cc.AvailabilityZones
}
if len(cc.Subnets) > 0 {
event["subnets._is_defined"] = true
event["subnets._len"] = len(cc.Subnets)
event["subnets"] = cc.Subnets
}
if cc.SSLCertificateARN != nil {
event["ssl_certificate_arn._is_defined"] = true
}
Expand Down
Loading