Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Process routes before peers #2105

Merged
merged 15 commits into from
Jun 19, 2024
6 changes: 3 additions & 3 deletions client/internal/dns/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ func TestUpdateDNSServer(t *testing.T) {
if err != nil {
t.Fatal(err)
}
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), 33100, privKey.String(), iface.DefaultMTU, newNet, nil)
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -343,7 +343,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
}

privKey, _ := wgtypes.GeneratePrivateKey()
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", 33100, privKey.String(), iface.DefaultMTU, newNet, nil)
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil)
if err != nil {
t.Errorf("build interface wireguard: %v", err)
return
Expand Down Expand Up @@ -801,7 +801,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
}

privKey, _ := wgtypes.GeneratePrivateKey()
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", 33100, privKey.String(), iface.DefaultMTU, newNet, nil)
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil)
if err != nil {
t.Fatalf("build interface wireguard: %v", err)
return nil, err
Expand Down
45 changes: 31 additions & 14 deletions client/internal/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/internal/rosenpass"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/wgproxy"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
Expand Down Expand Up @@ -735,6 +736,20 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
return nil
}

protoRoutes := networkMap.GetRoutes()
if protoRoutes == nil {
protoRoutes = []*mgmProto.Route{}
}

_, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes))
if err != nil {
log.Errorf("failed to update clientRoutes, err: %v", err)
}

e.clientRoutesMu.Lock()
e.clientRoutes = clientRoutes
e.clientRoutesMu.Unlock()

log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))

e.updateOfflinePeers(networkMap.GetOfflinePeers())
Expand Down Expand Up @@ -776,19 +791,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
}
}
}
protoRoutes := networkMap.GetRoutes()
if protoRoutes == nil {
protoRoutes = []*mgmProto.Route{}
}

_, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes))
if err != nil {
log.Errorf("failed to update clientRoutes, err: %v", err)
}

e.clientRoutesMu.Lock()
e.clientRoutes = clientRoutes
e.clientRoutesMu.Unlock()

protoDNSConfig := networkMap.GetDNSConfig()
if protoDNSConfig == nil {
Expand Down Expand Up @@ -1287,7 +1289,7 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) {
default:
}

return iface.NewWGIFace(e.config.WgIfaceName, e.config.WgAddr, e.config.WgPort, e.config.WgPrivateKey.String(), iface.DefaultMTU, transportNet, mArgs)
return iface.NewWGIFace(e.config.WgIfaceName, e.config.WgAddr, e.config.WgPort, e.config.WgPrivateKey.String(), iface.DefaultMTU, transportNet, mArgs, e.addrViaRoutes)
}

func (e *Engine) wgInterfaceCreate() (err error) {
Expand Down Expand Up @@ -1485,6 +1487,21 @@ func (e *Engine) startNetworkMonitor() {
}()
}

func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) {
var vpnRoutes []netip.Prefix
for _, routes := range e.GetClientRoutes() {
if len(routes) > 0 && routes[0] != nil {
vpnRoutes = append(vpnRoutes, routes[0].Network)
}
}

if isVpn, prefix := systemops.IsAddrRouted(addr, vpnRoutes); isVpn {
return true, prefix, nil
}

return false, netip.Prefix{}, nil
}

// isChecksEqual checks if two slices of checks are equal.
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {
Expand Down
6 changes: 3 additions & 3 deletions client/internal/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
if err != nil {
t.Fatal(err)
}
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil)
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil, nil)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -574,7 +574,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
if err != nil {
t.Fatal(err)
}
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil)
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil, nil)
assert.NoError(t, err, "shouldn't return error")
input := struct {
inputSerial uint64
Expand Down Expand Up @@ -745,7 +745,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
if err != nil {
t.Fatal(err)
}
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, 33100, key.String(), iface.DefaultMTU, newNet, nil)
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, 33100, key.String(), iface.DefaultMTU, newNet, nil, nil)
assert.NoError(t, err, "shouldn't return error")

mockRouteManager := &routemanager.MockManager{
Expand Down
28 changes: 0 additions & 28 deletions client/internal/peer/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"fmt"
"net"
"net/netip"
"runtime"
"strings"
"sync"
Expand All @@ -15,7 +14,6 @@ import (
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"

"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/internal/wgproxy"
"github.com/netbirdio/netbird/iface"
Expand Down Expand Up @@ -763,10 +761,6 @@ func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMa
return
}

if candidateViaRoutes(candidate, haRoutes) {
return
}

err := conn.agent.AddRemoteCandidate(candidate)
if err != nil {
log.Errorf("error while handling remote candidate from peer %s", conn.config.Key)
Expand Down Expand Up @@ -797,25 +791,3 @@ func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive
RelPort: relatedAdd.Port,
})
}

func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool {
var vpnRoutes []netip.Prefix
for _, routes := range clientRoutes {
if len(routes) > 0 && routes[0] != nil {
vpnRoutes = append(vpnRoutes, routes[0].Network)
}
}

addr, err := netip.ParseAddr(candidate.Address())
if err != nil {
log.Errorf("Failed to parse IP address %s: %v", candidate.Address(), err)
return false
}

if isVpn, prefix := systemops.IsAddrRouted(addr, vpnRoutes); isVpn {
log.Debugf("Ignoring candidate [%s], its address is routed to network %s", candidate.String(), prefix)
return true
}

return false
}
2 changes: 1 addition & 1 deletion client/internal/routemanager/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
if err != nil {
t.Fatal(err)
}
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil)
require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close()

Expand Down
3 changes: 2 additions & 1 deletion client/internal/routemanager/systemops/systemops_generic.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,8 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIfac

// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
if exitNextHop.IP == vpnAddr || exitNextHop.Intf != nil && exitNextHop.Intf.Name == vpnIntf.Name() {
log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix)
log.Debugf("Route for prefix %s is pointing to the VPN interface, using initial next hop %v", prefix, initialNextHop)

exitNextHop = initialNextHop
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func TestAddRemoveRoutes(t *testing.T) {
if err != nil {
t.Fatal(err)
}
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil)
require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close()

Expand Down Expand Up @@ -213,7 +213,7 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
if err != nil {
t.Fatal(err)
}
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil)
require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close()

Expand Down Expand Up @@ -345,7 +345,7 @@ func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listen
newNet, err := stdnet.NewNet()
require.NoError(t, err)

wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil)
require.NoError(t, err, "should create testing WireGuard interface")

err = wgInterface.Create()
Expand Down
10 changes: 7 additions & 3 deletions iface/bind/bind.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,14 @@ type ICEBind struct {

transportNet transport.Net
udpMux *UniversalUDPMuxDefault

filterFn FilterFn
}

func NewICEBind(transportNet transport.Net) *ICEBind {
func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
ib := &ICEBind{
transportNet: transportNet,
filterFn: filterFn,
}

rc := receiverCreator{
Expand All @@ -59,8 +62,9 @@ func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketC

s.udpMux = NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{
UDPConn: conn,
Net: s.transportNet,
UDPConn: conn,
Net: s.transportNet,
FilterFn: s.filterFn,
},
)
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
Expand Down
72 changes: 70 additions & 2 deletions iface/bind/udp_mux_universal.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"context"
"fmt"
"net"
"net/netip"
"sync"
"time"

log "github.com/sirupsen/logrus"
Expand All @@ -17,6 +19,10 @@ import (
"github.com/pion/transport/v3"
)

// FilterFn is a function that filters out candidates based on the address.
// If it returns true, the address is to be filtered. It also returns the prefix of matching route.
type FilterFn func(address netip.Addr) (bool, netip.Prefix, error)

// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn
// It then passes packets to the UDPMux that does the actual connection muxing.
type UniversalUDPMuxDefault struct {
Expand All @@ -34,6 +40,7 @@ type UniversalUDPMuxParams struct {
UDPConn net.PacketConn
XORMappedAddrCacheTTL time.Duration
Net transport.Net
FilterFn FilterFn
}

// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
Expand All @@ -56,6 +63,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
PacketConn: params.UDPConn,
mux: m,
logger: params.Logger,
filterFn: params.FilterFn,
}

// embed UDPMux
Expand Down Expand Up @@ -105,8 +113,68 @@ func (m *UniversalUDPMuxDefault) ReadFromConn(ctx context.Context) {
// udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
type udpConn struct {
net.PacketConn
mux *UniversalUDPMuxDefault
logger logging.LeveledLogger
mux *UniversalUDPMuxDefault
logger logging.LeveledLogger
filterFn FilterFn
// TODO: reset cache on route changes
addrCache sync.Map
}

func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
if u.filterFn == nil {
return u.PacketConn.WriteTo(b, addr)
}

if isRouted, found := u.addrCache.Load(addr.String()); found {
return u.handleCachedAddress(isRouted.(bool), b, addr)
}

return u.handleUncachedAddress(b, addr)
}

func (u *udpConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) {
if isRouted {
return 0, fmt.Errorf("address %s is part of a routed network, refusing to write", addr)
}
return u.PacketConn.WriteTo(b, addr)
}

func (u *udpConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) {
if err := u.performFilterCheck(addr); err != nil {
return 0, err
}
return u.PacketConn.WriteTo(b, addr)
}

func (u *udpConn) performFilterCheck(addr net.Addr) error {
host, err := getHostFromAddr(addr)
if err != nil {
log.Errorf("Failed to get host from address %s: %v", addr, err)
return nil
}

a, err := netip.ParseAddr(host)
if err != nil {
log.Errorf("Failed to parse address %s: %v", addr, err)
return nil
}

if isRouted, prefix, err := u.filterFn(a); err != nil {
log.Errorf("Failed to check if address %s is routed: %v", addr, err)
} else {
u.addrCache.Store(addr.String(), isRouted)
if isRouted {
// Extra log, as the error only shows up with ICE logging enabled
log.Infof("Address %s is part of routed network %s, refusing to write", addr, prefix)
return fmt.Errorf("address %s is part of routed network %s, refusing to write", addr, prefix)
}
}
return nil
}

func getHostFromAddr(addr net.Addr) (string, error) {
host, _, err := net.SplitHostPort(addr.String())
return host, err
}

// GetSharedConn returns the shared udp conn
Expand Down
6 changes: 4 additions & 2 deletions iface/iface_android.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,19 @@ import (
"fmt"

"github.com/pion/transport/v3"

"github.com/netbirdio/netbird/iface/bind"
)

// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments) (*WGIface, error) {
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
wgAddress, err := parseWGAddress(address)
if err != nil {
return nil, err
}

wgIFace := &WGIface{
tun: newTunDevice(wgAddress, wgPort, wgPrivKey, mtu, transportNet, args.TunAdapter),
tun: newTunDevice(wgAddress, wgPort, wgPrivKey, mtu, transportNet, args.TunAdapter, filterFn),
userspaceBind: true,
}
return wgIFace, nil
Expand Down
Loading
Loading