diff --git a/apis/v1beta1/util/validation/gateway.go b/apis/v1beta1/util/validation/gateway.go new file mode 100644 index 0000000000..e25c0a2d14 --- /dev/null +++ b/apis/v1beta1/util/validation/gateway.go @@ -0,0 +1,41 @@ +/* +Copyright 2022 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package validation + +import gatewayv1b1 "sigs.k8s.io/gateway-api/apis/v1beta1" + +// ContainsInProtocolSlice checks whether the provided Protocol +// is in the target Protocol slice. +func ContainsInProtocolSlice(items []gatewayv1b1.ProtocolType, item *gatewayv1b1.ProtocolType) bool { + for _, eachItem := range items { + if eachItem == *item { + return true + } + } + return false +} + +// ContainsInPortSlice checks whether the provided Port +// is in the target Port slice. +func ContainsInPortSlice(items []gatewayv1b1.PortNumber, item *gatewayv1b1.PortNumber) bool { + for _, eachItem := range items { + if eachItem == *item { + return true + } + } + return false +} \ No newline at end of file diff --git a/apis/v1beta1/validation/gateway.go b/apis/v1beta1/validation/gateway.go index c755eae6f9..89f559aaa1 100644 --- a/apis/v1beta1/validation/gateway.go +++ b/apis/v1beta1/validation/gateway.go @@ -22,6 +22,7 @@ import ( "k8s.io/apimachinery/pkg/util/validation/field" gatewayv1b1 "sigs.k8s.io/gateway-api/apis/v1beta1" + utils "sigs.k8s.io/gateway-api/apis/v1beta1/util/validation" ) var ( @@ -68,8 +69,7 @@ func validateGatewayListeners(listeners []gatewayv1b1.Listener, path *field.Path errs = append(errs, ValidateListenerTLSConfig(listeners, path)...) errs = append(errs, validateListenerHostname(listeners, path)...) errs = append(errs, ValidateTLSCertificateRefs(listeners, path)...) - errs = append(errs, validateUniqueProtocol(listeners, path)...) - errs = append(errs, validateUniquePort(listeners, path)...) + errs = append(errs, ValidateHostnameProtocolPort(listeners, path)...) return errs } @@ -120,55 +120,39 @@ func ValidateTLSCertificateRefs(listeners []gatewayv1b1.Listener, path *field.Pa return errs } -// validateUniqueProtocol validates each listener hostname -// should not have the duplicate protocols -func validateUniqueProtocol(listeners []gatewayv1b1.Listener, path *field.Path) field.ErrorList { + +// validateHostnameProtocolPort validates listener protocols or ports +// must be different for the same hostname +func ValidateHostnameProtocolPort(listeners []gatewayv1b1.Listener, path *field.Path) field.ErrorList { var errs field.ErrorList - hostnameProtocolUnique := make(map[gatewayv1b1.Hostname]map[gatewayv1b1.ProtocolType]int) - for i, h := range listeners { - if h.Hostname == nil { - continue + hostnameProtocolMap := make(map[gatewayv1b1.Hostname][]gatewayv1b1.ProtocolType) + hostnamePortMap := make(map[gatewayv1b1.Hostname][]gatewayv1b1.PortNumber) + for i, listener := range listeners { + targetHostname := *new(gatewayv1b1.Hostname) + if listener.Hostname != nil { + targetHostname = *listener.Hostname } - if len(hostnameProtocolUnique) == 0 { - hostnameProtocolUnique[*h.Hostname] = map[gatewayv1b1.ProtocolType]int{h.Protocol: i} - continue + if len(listener.Protocol) == 0 && listener.Port == 0 { + errs = append(errs, field.Forbidden(path.Index(i).Child("hostname"), fmt.Sprintf("Protocol and Port should not all be empty for the same Hostname: %s", targetHostname))) + return errs } - if _, hostnameFound := hostnameProtocolUnique[*h.Hostname]; hostnameFound { - if _, protocolFound := hostnameProtocolUnique[*h.Hostname][h.Protocol]; protocolFound { - errs = append(errs, field.Forbidden(path.Index(i).Child("protocol"), fmt.Sprintf("should be unique in hostname: %v", *h.Hostname))) - return errs - } else { - hostnameProtocolUnique[*h.Hostname][h.Protocol] = i + if len(listener.Protocol) != 0 { + if protocolSlice, ok := hostnameProtocolMap[targetHostname]; ok { + if utils.ContainsInProtocolSlice(protocolSlice, &listener.Protocol) { + errs = append(errs, field.Forbidden(path.Index(i).Child("protocol"), fmt.Sprintf("must be different for the same Hostname: %s", targetHostname))) + return errs + } } - } else { - hostnameProtocolUnique[*h.Hostname] = map[gatewayv1b1.ProtocolType]int{h.Protocol: i} - } - } - return errs -} - -// validateUniquePort validates each listener hostname -// should not have the duplicate ports -func validateUniquePort(listeners []gatewayv1b1.Listener, path *field.Path) field.ErrorList { - var errs field.ErrorList - hostnamePortUnique := make(map[gatewayv1b1.Hostname]map[gatewayv1b1.PortNumber]int) - for i, h := range listeners { - if h.Hostname == nil { - continue - } - if len(hostnamePortUnique) == 0 { - hostnamePortUnique[*h.Hostname] = map[gatewayv1b1.PortNumber]int{h.Port: i} - continue + hostnameProtocolMap[targetHostname] = append(hostnameProtocolMap[targetHostname], listener.Protocol) } - if _, hostnameFound := hostnamePortUnique[*h.Hostname]; hostnameFound { - if _, portFound := hostnamePortUnique[*h.Hostname][h.Port]; portFound { - errs = append(errs, field.Forbidden(path.Index(i).Child("port"), fmt.Sprintf("should be unique in hostname: %v", *h.Hostname))) - return errs - } else { - hostnamePortUnique[*h.Hostname][h.Port] = i + if listener.Port != 0 { + if portSlice, ok := hostnamePortMap[targetHostname]; ok { + if utils.ContainsInPortSlice(portSlice, &listener.Port) { + errs = append(errs, field.Forbidden(path.Index(i).Child("port"), fmt.Sprintf("must be different for the same Hostname: %s", targetHostname))) + return errs + } } - } else { - hostnamePortUnique[*h.Hostname] = map[gatewayv1b1.PortNumber]int{h.Port: i} + hostnamePortMap[targetHostname] = append(hostnamePortMap[targetHostname], listener.Port) } } return errs diff --git a/apis/v1beta1/validation/gateway_test.go b/apis/v1beta1/validation/gateway_test.go index 717f011070..52bc01d39a 100644 --- a/apis/v1beta1/validation/gateway_test.go +++ b/apis/v1beta1/validation/gateway_test.go @@ -29,9 +29,6 @@ func TestValidateGateway(t *testing.T) { { Hostname: nil, }, - { - Hostname: nil, - }, } addresses := []gatewayv1b1.GatewayAddress{ { @@ -137,29 +134,16 @@ func TestValidateGateway(t *testing.T) { }, expectErrsOnFields: []string{"spec.listeners[0].tls.certificateRefs"}, }, - "protocal is not unique in hostname": { - mutate: func(gw *gatewayv1b1.Gateway) { - hostname := gatewayv1b1.Hostname("foo.bar.com") - gw.Spec.Listeners[0].Hostname = &hostname - gw.Spec.Listeners[0].Protocol = gatewayv1b1.HTTPProtocolType - gw.Spec.Listeners[0].Port = gatewayv1b1.PortNumber(80) - gw.Spec.Listeners[1].Hostname = &hostname - gw.Spec.Listeners[1].Protocol = gatewayv1b1.HTTPProtocolType - gw.Spec.Listeners[1].Port = gatewayv1b1.PortNumber(81) - }, - expectErrsOnFields: []string{"spec.listeners[1].protocol"}, - }, - "port is not unique in hostname": { + "ports are not different for the same hostname": { mutate: func(gw *gatewayv1b1.Gateway) { - hostname := gatewayv1b1.Hostname("foo.bar.com") - gw.Spec.Listeners[0].Hostname = &hostname - gw.Spec.Listeners[0].Protocol = gatewayv1b1.HTTPProtocolType - gw.Spec.Listeners[0].Port = gatewayv1b1.PortNumber(80) - gw.Spec.Listeners[1].Hostname = &hostname - gw.Spec.Listeners[1].Protocol = gatewayv1b1.HTTPSProtocolType - gw.Spec.Listeners[1].Port = gatewayv1b1.PortNumber(80) + moreListeners := []gatewayv1b1.Listener { + { + Hostname: nil, + }, + } + gw.Spec.Listeners = append(gw.Spec.Listeners, moreListeners...) }, - expectErrsOnFields: []string{"spec.listeners[1].port"}, + expectErrsOnFields: []string{"spec.listeners[0].hostname"}, }, }