Skip to content

Commit

Permalink
Add unit tests for pkg/agent/util/net.
Browse files Browse the repository at this point in the history
Signed-off-by: Qiyue Yao <yaoq@vmware.com>
  • Loading branch information
qiyueyao committed Jan 10, 2023
1 parent 114e14f commit 7548847
Show file tree
Hide file tree
Showing 6 changed files with 2,776 additions and 126 deletions.
39 changes: 24 additions & 15 deletions pkg/agent/util/net.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,15 @@ const (
bridgedUplinkSuffix = "~"
)

var (
// netInterfaceByName, netInterfaceByIndex, netInterface, netInterfaceAddrs
// are meant to be overridden for testing.
netInterfaceByName = net.InterfaceByName
netInterfaceByIndex = net.InterfaceByIndex
netInterface = net.Interfaces
netInterfaceAddrs = (*net.Interface).Addrs
)

func generateInterfaceName(key string, name string, useHead bool) string {
hash := sha1.New() // #nosec G401: not used for security purposes
io.WriteString(hash, key)
Expand Down Expand Up @@ -119,7 +128,7 @@ func dialUnix(address string) (net.Conn, error) {
// GetIPNetDeviceFromIP returns local IPs/masks and associated device from IP, and ignores the interfaces which have
// names in the ignoredInterfaces.
func GetIPNetDeviceFromIP(localIPs *ip.DualStackIPs, ignoredInterfaces sets.String) (v4IPNet *net.IPNet, v6IPNet *net.IPNet, iface *net.Interface, err error) {
linkList, err := net.Interfaces()
linkList, err := netInterface()
if err != nil {
return nil, nil, nil, err
}
Expand All @@ -139,7 +148,7 @@ func GetIPNetDeviceFromIP(localIPs *ip.DualStackIPs, ignoredInterfaces sets.Stri
if ignoredInterfaces.Has(link.Name) {
continue
}
addrList, err := link.Addrs()
addrList, err := netInterfaceAddrs(&link)
if err != nil {
continue
}
Expand All @@ -166,11 +175,11 @@ func GetIPNetDeviceFromIP(localIPs *ip.DualStackIPs, ignoredInterfaces sets.Stri
}

func GetIPNetDeviceByName(ifaceName string) (v4IPNet *net.IPNet, v6IPNet *net.IPNet, link *net.Interface, err error) {
link, err = net.InterfaceByName(ifaceName)
link, err = netInterfaceByName(ifaceName)
if err != nil {
return nil, nil, nil, err
}
addrList, err := link.Addrs()
addrList, err := netInterfaceAddrs(link)
if err != nil {
return nil, nil, nil, err
}
Expand Down Expand Up @@ -212,12 +221,12 @@ func GetIPNetDeviceByCIDRs(cidrsList []string) (v4IPNet, v6IPNet *net.IPNet, lin
return nil, nil, nil, fmt.Errorf("length of cidrs is %v more than max allowed of 2", len(cidrs))
}

ifaces, err := net.Interfaces()
ifaces, err := netInterface()
if err != nil {
return nil, nil, nil, err
}
for _, i := range ifaces {
addresses, err := i.Addrs()
for i := range ifaces {
addresses, err := netInterfaceAddrs(&ifaces[i])
if err != nil {
return nil, nil, nil, err
}
Expand All @@ -238,19 +247,19 @@ func GetIPNetDeviceByCIDRs(cidrsList []string) (v4IPNet, v6IPNet *net.IPNet, lin
}
}
if v4IPNet != nil || v6IPNet != nil {
return v4IPNet, v6IPNet, &i, nil
return v4IPNet, v6IPNet, &ifaces[i], nil
}
}
return nil, nil, nil, fmt.Errorf("unable to find local IP and device")
}

func GetAllIPNetsByName(ifaceName string) ([]*net.IPNet, error) {
ips := []*net.IPNet{}
adapter, err := net.InterfaceByName(ifaceName)
adapter, err := netInterfaceByName(ifaceName)
if err != nil {
return nil, err
}
addrs, _ := adapter.Addrs()
addrs, _ := netInterfaceAddrs(adapter)
for _, addr := range addrs {
if ip, ipNet, err := net.ParseCIDR(addr.String()); err != nil {
klog.Warningf("Unable to parse addr %+v, err=%+v", addr, err)
Expand Down Expand Up @@ -343,22 +352,22 @@ func GetAllNodeAddresses(excludeDevices []string) ([]net.IP, []net.IP, error) {
_, ipv6LinkLocalNet, _ := net.ParseCIDR("fe80::/64")

// Get all interfaces.
interfaces, err := net.Interfaces()
interfaces, err := netInterface()
if err != nil {
return nil, nil, err
}

// Transform excludeDevices to a set
excludeDevicesSet := sets.NewString(excludeDevices...)

for _, itf := range interfaces {
for i := range interfaces {
// If the device is in excludeDevicesSet, skip it.
if excludeDevicesSet.Has(itf.Name) {
if excludeDevicesSet.Has(interfaces[i].Name) {
continue
}

// Get all IPs of every interface
addrs, err := itf.Addrs()
addrs, err := netInterfaceAddrs(&interfaces[i])
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -413,7 +422,7 @@ func GenerateRandomMAC() net.HardwareAddr {
}

func GetIPNetsByLink(link *net.Interface) ([]*net.IPNet, error) {
addrList, err := link.Addrs()
addrList, err := netInterfaceAddrs(link)
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit 7548847

Please sign in to comment.