diff --git a/staging/src/k8s.io/legacy-cloud-providers/aws/aws.go b/staging/src/k8s.io/legacy-cloud-providers/aws/aws.go index ab8d6d92d21a0..4f42c944794c7 100644 --- a/staging/src/k8s.io/legacy-cloud-providers/aws/aws.go +++ b/staging/src/k8s.io/legacy-cloud-providers/aws/aws.go @@ -3665,6 +3665,27 @@ func buildListener(port v1.ServicePort, annotations map[string]string, sslPorts return listener, nil } +func (c *Cloud) getSubnetCidrs(subnetIDs []string) ([]string, error) { + request := &ec2.DescribeSubnetsInput{} + for _, subnetID := range subnetIDs { + request.SubnetIds = append(request.SubnetIds, aws.String(subnetID)) + } + + subnets, err := c.ec2.DescribeSubnets(request) + if err != nil { + return nil, fmt.Errorf("error querying Subnet for ELB: %q", err) + } + if len(subnets) != len(subnetIDs) { + return nil, fmt.Errorf("error querying Subnet for ELB, got %d subnets for %v", len(subnets), subnetIDs) + } + + cidrs := make([]string, 0, len(subnets)) + for _, subnet := range subnets { + cidrs = append(cidrs, aws.StringValue(subnet.CidrBlock)) + } + return cidrs, nil +} + // EnsureLoadBalancer implements LoadBalancer.EnsureLoadBalancer func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiService *v1.Service, nodes []*v1.Node) (*v1.LoadBalancerStatus, error) { annotations := apiService.Annotations @@ -3796,6 +3817,12 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS return nil, err } + subnetCidrs, err := c.getSubnetCidrs(subnetIDs) + if err != nil { + klog.Errorf("Error getting subnet cidrs: %q", err) + return nil, err + } + sourceRangeCidrs := []string{} for cidr := range sourceRanges { sourceRangeCidrs = append(sourceRangeCidrs, cidr) @@ -3804,7 +3831,7 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS sourceRangeCidrs = append(sourceRangeCidrs, "0.0.0.0/0") } - err = c.updateInstanceSecurityGroupsForNLB(loadBalancerName, instances, sourceRangeCidrs, v2Mappings) + err = c.updateInstanceSecurityGroupsForNLB(loadBalancerName, instances, subnetCidrs, sourceRangeCidrs, v2Mappings) if err != nil { klog.Warningf("Error opening ingress rules for the load balancer to the instances: %q", err) return nil, err @@ -4381,7 +4408,7 @@ func (c *Cloud) EnsureLoadBalancerDeleted(ctx context.Context, clusterName strin } } - return c.updateInstanceSecurityGroupsForNLB(loadBalancerName, nil, nil, nil) + return c.updateInstanceSecurityGroupsForNLB(loadBalancerName, nil, nil, nil, nil) } lb, err := c.describeLoadBalancer(loadBalancerName) diff --git a/staging/src/k8s.io/legacy-cloud-providers/aws/aws_loadbalancer.go b/staging/src/k8s.io/legacy-cloud-providers/aws/aws_loadbalancer.go index d02fdceb31567..8c45fe6a521a2 100644 --- a/staging/src/k8s.io/legacy-cloud-providers/aws/aws_loadbalancer.go +++ b/staging/src/k8s.io/legacy-cloud-providers/aws/aws_loadbalancer.go @@ -723,30 +723,9 @@ func (c *Cloud) ensureTargetGroup(targetGroup *elbv2.TargetGroup, serviceName ty return targetGroup, nil } -func (c *Cloud) getVpcCidrBlocks() ([]string, error) { - vpcs, err := c.ec2.DescribeVpcs(&ec2.DescribeVpcsInput{ - VpcIds: []*string{aws.String(c.vpcID)}, - }) - if err != nil { - return nil, fmt.Errorf("error querying VPC for ELB: %q", err) - } - if len(vpcs.Vpcs) != 1 { - return nil, fmt.Errorf("error querying VPC for ELB, got %d vpcs for %s", len(vpcs.Vpcs), c.vpcID) - } - - cidrBlocks := make([]string, 0, len(vpcs.Vpcs[0].CidrBlockAssociationSet)) - for _, cidr := range vpcs.Vpcs[0].CidrBlockAssociationSet { - if aws.StringValue(cidr.CidrBlockState.State) != ec2.VpcCidrBlockStateCodeAssociated { - continue - } - cidrBlocks = append(cidrBlocks, aws.StringValue(cidr.CidrBlock)) - } - return cidrBlocks, nil -} - // updateInstanceSecurityGroupsForNLB will adjust securityGroup's settings to allow inbound traffic into instances from clientCIDRs and portMappings. // TIP: if either instances or clientCIDRs or portMappings are nil, then the securityGroup rules for lbName are cleared. -func (c *Cloud) updateInstanceSecurityGroupsForNLB(lbName string, instances map[InstanceID]*ec2.Instance, clientCIDRs []string, portMappings []nlbPortMapping) error { +func (c *Cloud) updateInstanceSecurityGroupsForNLB(lbName string, instances map[InstanceID]*ec2.Instance, subnetCIDRs []string, clientCIDRs []string, portMappings []nlbPortMapping) error { if c.cfg.Global.DisableSecurityGroupIngress { return nil } @@ -794,14 +773,10 @@ func (c *Cloud) updateInstanceSecurityGroupsForNLB(lbName string, instances map[ } clientRuleAnnotation := fmt.Sprintf("%s=%s", NLBClientRuleDescription, lbName) healthRuleAnnotation := fmt.Sprintf("%s=%s", NLBHealthCheckRuleDescription, lbName) - vpcCIDRs, err := c.getVpcCidrBlocks() - if err != nil { - return err - } for sgID, sg := range clusterSGs { sgPerms := NewIPPermissionSet(sg.IpPermissions...).Ungroup() if desiredSGIDs.Has(sgID) { - if err := c.updateInstanceSecurityGroupForNLBTraffic(sgID, sgPerms, healthRuleAnnotation, "tcp", healthCheckPorts, vpcCIDRs); err != nil { + if err := c.updateInstanceSecurityGroupForNLBTraffic(sgID, sgPerms, healthRuleAnnotation, "tcp", healthCheckPorts, subnetCIDRs); err != nil { return err } if err := c.updateInstanceSecurityGroupForNLBTraffic(sgID, sgPerms, clientRuleAnnotation, clientProtocol, clientPorts, clientCIDRs); err != nil {