Skip to content

Commit

Permalink
Check routes when checking for network changes (#132)
Browse files Browse the repository at this point in the history
Fixes #127
  • Loading branch information
samuong authored Jul 31, 2024
1 parent fbfe6a9 commit c4e31c8
Show file tree
Hide file tree
Showing 2 changed files with 191 additions and 26 deletions.
58 changes: 51 additions & 7 deletions netmonitor.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2019, 2021 The Alpaca Authors
// Copyright 2019, 2021, 2024 The Alpaca Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -17,6 +17,7 @@ package main
import (
"log"
"net"
"slices"
)

type netMonitor interface {
Expand All @@ -25,11 +26,13 @@ type netMonitor interface {

type netMonitorImpl struct {
addrs map[string]struct{}
routes []net.IP
getAddrs func() ([]net.Addr, error)
dial func(network, addr string) (net.Conn, error)
}

func newNetMonitor() *netMonitorImpl {
return &netMonitorImpl{getAddrs: net.InterfaceAddrs}
return &netMonitorImpl{getAddrs: net.InterfaceAddrs, dial: net.Dial}
}

func (nm *netMonitorImpl) addrsChanged() bool {
Expand All @@ -39,13 +42,24 @@ func (nm *netMonitorImpl) addrsChanged() bool {
return false
}
set := addrSliceToSet(addrs)
if setsAreEqual(set, nm.addrs) {
// Probe for routes to a set of remote addresses. These addresses are
// the same as those used by myIpAddressEx.
// TODO: Cache the results so they don't need to be recalculated in
// myIpAddress (and myIpAddressEx, when implemented).
remotes := []string{
"8.8.8.8", "2001:4860:4860::8888", // public addresses
"10.0.0.0", "172.16.0.0", "192.168.0.0", "FC00::", // private addresses
}
locals := make([]net.IP, len(remotes))
for i, remote := range remotes {
locals[i] = nm.probeRoute(remote, false)
}
if setsAreEqual(set, nm.addrs) && slices.EqualFunc(locals, nm.routes, net.IP.Equal) {
return false
} else {
log.Printf("Network changes detected: %v", addrs)
nm.addrs = set
return true
}
nm.addrs = set
nm.routes = locals
return true
}

func addrSliceToSet(slice []net.Addr) map[string]struct{} {
Expand All @@ -67,3 +81,33 @@ func setsAreEqual(a, b map[string]struct{}) bool {
}
return true
}

// probeRoute creates a UDP "connection" to the remote address, and returns the
// local interface address. This does involve a system call, but does not
// generate any network traffic since UDP is a connectionless protocol.
func (nm *netMonitorImpl) probeRoute(host string, ipv4only bool) net.IP {
var network string
if ipv4only {
network = "udp4"
} else {
network = "udp"
}
conn, err := nm.dial(network, net.JoinHostPort(host, "80"))
if err != nil {
return nil
}
defer conn.Close()
local, ok := conn.LocalAddr().(*net.UDPAddr)
if !ok {
// Since we called dial with network set to "udp4" or "udp", we
// expect this to be a *net.UDPAddr. If this fails, it's a bug
// in Alpaca, and hopefully users will report it. But it's not
// worth panicking over so we won't end the request here.
log.Printf("unexpected: probeRoute host=%q ipv4only=%t: %v", host, ipv4only, err)
return nil
}
if ip := local.IP; ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
return nil
}
return local.IP
}
159 changes: 140 additions & 19 deletions netmonitor_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2019 The Alpaca Authors
// Copyright 2019, 2024 The Alpaca Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -16,12 +16,34 @@ package main

import (
"errors"
"math/rand/v2"
"net"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

// In order to test netMonitor, we use mock implementations of the
// net.InterfaceAddrs() and net.Dial() functions, as well as the net.Addr and
// net.Conn types. The mocks below will implement just enough functionality to
// allow the tests to run, and will panic if unimplemented functions are
// called. We simulate three network states: "offline", "wifi" and "vpn".
//
// In "offline" mode, only the loopback addresses exist, and attempts to dial
// anywhere will result in a "network is unreachable" error.
//
// In "wifi" mode, in addition to the loopback addresses, we've also got an IP
// address in the 192.168.1.0/24 range, which is meant to look like a home wifi
// router, and we simulate a routing table that routes everything through this
// interface.
//
// In "vpn" mode, we've got the same IP addresses as in "wifi" mode, and any
// connection attempts will be routed via an address in the 10.0.0.0/8 range
// (this is meant to look like a private corporate network). Note that our
// 10.0.0.0/8 address does *not* appear in the output of net.InterfaceAddrs()
// because apparently some VPN clients behave like this.

type mockAddr string

func (a mockAddr) Network() string {
Expand All @@ -40,28 +62,127 @@ func toAddrs(ss ...string) []net.Addr {
return addrs
}

type mockConn struct {
localAddr net.Addr
}

var _ net.Conn = mockConn{}

func (c mockConn) Close() error {
return nil
}

func (c mockConn) LocalAddr() net.Addr {
return c.localAddr
}

func (c mockConn) Read(b []byte) (n int, err error) {
panic("unreachable")
}

func (c mockConn) RemoteAddr() net.Addr {
panic("unreachable")
}

func (c mockConn) SetDeadline(t time.Time) error {
panic("unreachable")
}

func (c mockConn) SetReadDeadline(t time.Time) error {
panic("unreachable")
}

func (c mockConn) SetWriteDeadline(t time.Time) error {
panic("unreachable")
}

func (c mockConn) Write(b []byte) (n int, err error) {
panic("unreachable")
}

type mockNet struct {
state string
}

func (n *mockNet) interfaceAddrs() ([]net.Addr, error) {
var addrs []net.Addr
switch n.state {
case "vpn", "wifi":
addrs = append(addrs, toAddrs("192.168.1.2/24", "fe80::fedc:ba98:7654:3210/64")...)
fallthrough
case "offline":
addrs = append(addrs, toAddrs("127.0.0.1/8", "::1/128")...)
default:
panic("interfaceAddrs state=" + n.state)
}
return addrs, nil
}

func (n *mockNet) dial(network, address string) (net.Conn, error) {
if network != "udp" && network != "udp4" {
panic("dial network=" + network)
}
host, _, err := net.SplitHostPort(address)
if err != nil {
panic("dial: " + err.Error())
}
ip := net.ParseIP(host)
if ip == nil {
panic("dial host=" + host)
}
if ipv4 := ip.To4(); ipv4 == nil {
// Pretend we can't route to any IPv6 addresses.
return nil, newDialError(network, address, "connect: no route to host")
}
switch n.state {
case "vpn":
// Pretend we're routing through a corporate VPN.
return newMockConn(10, 0, 0, 3), nil
case "wifi":
// Pretend we're routing through a home WiFi router.
return newMockConn(192, 168, 1, 2), nil
case "offline":
// Pretend we can't route to anywhere.
return nil, newDialError(network, address, "connect: network is unreachable")
default:
panic("dial state=" + n.state)
}
}

func newMockConn(a, b, c, d byte) mockConn {
// Pretend that the operating system has assigned a random port (in the
// range 1024 to 65535) on the outbound connection. This allows us to
// test that Alpaca doesn't think the routing table has changed just
// because the port is different on each call to net.Dial(); we only
// need to consider the outgoing IP address without the port number.
return mockConn{
localAddr: &net.UDPAddr{
IP: net.IPv4(a, b, c, d),
Port: rand.IntN(65535-1024) + 1024,
},
}
}

func newDialError(network, address, text string) *net.OpError {
return &net.OpError{
Op: "dial",
Net: network,
Source: nil,
Addr: mockAddr(address),
Err: errors.New(text),
}
}

func TestNetworkMonitor(t *testing.T) {
var next []net.Addr
nm := &netMonitorImpl{getAddrs: func() ([]net.Addr, error) { return next, nil }}
// Start with just loopback interfaces
next = toAddrs("127.0.0.1/8", "::1/128")
var network mockNet
nm := &netMonitorImpl{getAddrs: network.interfaceAddrs, dial: network.dial}
network.state = "offline"
assert.True(t, nm.addrsChanged())
// Connect to network, and get local IPv4 and IPv6 addresses
next = toAddrs("127.0.0.1/8", "192.168.1.6/24", "::1/128", "fe80::dfd9:fe1d:56d1:1f3a/64")
network.state = "wifi"
assert.True(t, nm.addrsChanged())
// Stay connected, nothing changed
next = toAddrs("127.0.0.1/8", "192.168.1.6/24", "::1/128", "fe80::dfd9:fe1d:56d1:1f3a/64")
assert.False(t, nm.addrsChanged())
// DHCP lease expires, get new addresses
next = toAddrs("127.0.0.1/8", "192.168.1.7/24", "::1/128", "fe80::dfd9:fe1d:56d1:1f3b/64")
network.state = "vpn"
assert.True(t, nm.addrsChanged())
// Disconnect, and go back to having just loopback addresses
next = toAddrs("127.0.0.1/8", "::1/128")
network.state = "offline"
assert.True(t, nm.addrsChanged())
}

func TestFailToGetAddrs(t *testing.T) {
alwaysFail := func() ([]net.Addr, error) { return nil, errors.New("failed") }
nm := &netMonitorImpl{getAddrs: alwaysFail}
assert.False(t, nm.addrsChanged())
}

0 comments on commit c4e31c8

Please sign in to comment.