Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

replace string with netip.AddrPort #244

Merged
merged 1 commit into from
Oct 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions shadowaead/packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"io"
"net"
"net/netip"
"sync"

"github.com/shadowsocks/go-shadowsocks2/internal"
Expand Down Expand Up @@ -73,6 +74,9 @@ type packetConn struct {
// NewPacketConn wraps a net.PacketConn with cipher
func NewPacketConn(c net.PacketConn, ciph Cipher) net.PacketConn {
const maxPacketSize = 64 * 1024
if cc, ok := c.(*net.UDPConn); ok {
return &udpConn{UDPConn: cc, Cipher: ciph, buf: make([]byte, maxPacketSize)}
}
return &packetConn{PacketConn: c, Cipher: ciph, buf: make([]byte, maxPacketSize)}
}

Expand Down Expand Up @@ -101,3 +105,62 @@ func (c *packetConn) ReadFrom(b []byte) (int, net.Addr, error) {
copy(b, bb)
return len(bb), addr, err
}

type udpConn struct {
*net.UDPConn
Cipher
sync.Mutex
buf []byte // write lock
}

// WriteTo encrypts b and write to addr using the embedded UDPConn.
func (c *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
c.Lock()
defer c.Unlock()
buf, err := Pack(c.buf, b, c)
if err != nil {
return 0, err
}
_, err = c.UDPConn.WriteTo(buf, addr)
return len(b), err
}

// ReadFrom reads from the embedded UDPConn and decrypts into b.
func (c *udpConn) ReadFrom(b []byte) (int, net.Addr, error) {
n, addr, err := c.UDPConn.ReadFrom(b)
if err != nil {
return n, addr, err
}
bb, err := Unpack(b[c.Cipher.SaltSize():], b[:n], c)
if err != nil {
return n, addr, err
}
copy(b, bb)
return len(bb), addr, err
}

// WriteToUDPAddrPort encrypts b and write to addr using the embedded PacketConn.
func (c *udpConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) {
c.Lock()
defer c.Unlock()
buf, err := Pack(c.buf, b, c)
if err != nil {
return 0, err
}
_, err = c.UDPConn.WriteToUDPAddrPort(buf, addr)
return len(b), err
}

// ReadFromUDPAddrPort reads from the embedded UDPConn and decrypts into b.
func (c *udpConn) ReadFromUDPAddrPort(b []byte) (int, netip.AddrPort, error) {
n, addr, err := c.UDPConn.ReadFromUDPAddrPort(b)
if err != nil {
return n, addr, err
}
bb, err := Unpack(b[c.Cipher.SaltSize():], b[:n], c)
if err != nil {
return n, addr, err
}
copy(b, bb)
return len(bb), addr, err
}
70 changes: 47 additions & 23 deletions udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main
import (
"fmt"
"net"
"net/netip"
"sync"
"time"

Expand Down Expand Up @@ -34,7 +35,13 @@ func udpLocal(laddr, server, target string, shadow func(net.PacketConn) net.Pack
return
}

c, err := net.ListenPacket("udp", laddr)
lnAddr, err := net.ResolveUDPAddr("udp", laddr)
if err != nil {
logf("UDP listen address error: %v", err)
return
}

c, err := net.ListenUDP("udp", lnAddr)
if err != nil {
logf("UDP local listen error: %v", err)
return
Expand All @@ -47,13 +54,13 @@ func udpLocal(laddr, server, target string, shadow func(net.PacketConn) net.Pack

logf("UDP tunnel %s <-> %s <-> %s", laddr, server, target)
for {
n, raddr, err := c.ReadFrom(buf[len(tgt):])
n, raddr, err := c.ReadFromUDPAddrPort(buf[len(tgt):])
if err != nil {
logf("UDP local read error: %v", err)
continue
}

pc := nm.Get(raddr.String())
pc := nm.Get(raddr)
if pc == nil {
pc, err = net.ListenPacket("udp", "")
if err != nil {
Expand Down Expand Up @@ -81,7 +88,13 @@ func udpSocksLocal(laddr, server string, shadow func(net.PacketConn) net.PacketC
return
}

c, err := net.ListenPacket("udp", laddr)
lnAddr, err := net.ResolveUDPAddr("udp", laddr)
if err != nil {
logf("UDP listen address error: %v", err)
return
}

c, err := net.ListenUDP("udp", lnAddr)
if err != nil {
logf("UDP local listen error: %v", err)
return
Expand All @@ -92,13 +105,13 @@ func udpSocksLocal(laddr, server string, shadow func(net.PacketConn) net.PacketC
buf := make([]byte, udpBufSize)

for {
n, raddr, err := c.ReadFrom(buf)
n, raddr, err := c.ReadFromUDPAddrPort(buf)
if err != nil {
logf("UDP local read error: %v", err)
continue
}

pc := nm.Get(raddr.String())
pc := nm.Get(raddr)
if pc == nil {
pc, err = net.ListenPacket("udp", "")
if err != nil {
Expand All @@ -118,22 +131,33 @@ func udpSocksLocal(laddr, server string, shadow func(net.PacketConn) net.PacketC
}
}

type UDPConn interface {
net.PacketConn
ReadFromUDPAddrPort([]byte) (int, netip.AddrPort, error)
WriteToUDPAddrPort([]byte, netip.AddrPort) (int, error)
}

// Listen on addr for encrypted packets and basically do UDP NAT.
func udpRemote(addr string, shadow func(net.PacketConn) net.PacketConn) {
c, err := net.ListenPacket("udp", addr)
nAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
logf("UDP server address error: %v", err)
return
}
cc, err := net.ListenUDP("udp", nAddr)
if err != nil {
logf("UDP remote listen error: %v", err)
return
}
defer c.Close()
c = shadow(c)
defer cc.Close()
c := shadow(cc).(UDPConn)

nm := newNATmap(config.UDPTimeout)
buf := make([]byte, udpBufSize)

logf("listening UDP on %s", addr)
for {
n, raddr, err := c.ReadFrom(buf)
n, raddr, err := c.ReadFromUDPAddrPort(buf)
if err != nil {
logf("UDP remote read error: %v", err)
continue
Expand All @@ -153,7 +177,7 @@ func udpRemote(addr string, shadow func(net.PacketConn) net.PacketConn) {

payload := buf[len(tgtAddr):n]

pc := nm.Get(raddr.String())
pc := nm.Get(raddr)
if pc == nil {
pc, err = net.ListenPacket("udp", "")
if err != nil {
Expand All @@ -175,31 +199,31 @@ func udpRemote(addr string, shadow func(net.PacketConn) net.PacketConn) {
// Packet NAT table
type natmap struct {
sync.RWMutex
m map[string]net.PacketConn
m map[netip.AddrPort]net.PacketConn
timeout time.Duration
}

func newNATmap(timeout time.Duration) *natmap {
m := &natmap{}
m.m = make(map[string]net.PacketConn)
m.m = make(map[netip.AddrPort]net.PacketConn)
m.timeout = timeout
return m
}

func (m *natmap) Get(key string) net.PacketConn {
func (m *natmap) Get(key netip.AddrPort) net.PacketConn {
m.RLock()
defer m.RUnlock()
return m.m[key]
}

func (m *natmap) Set(key string, pc net.PacketConn) {
func (m *natmap) Set(key netip.AddrPort, pc net.PacketConn) {
m.Lock()
defer m.Unlock()

m.m[key] = pc
}

func (m *natmap) Del(key string) net.PacketConn {
func (m *natmap) Del(key netip.AddrPort) net.PacketConn {
m.Lock()
defer m.Unlock()

Expand All @@ -211,19 +235,19 @@ func (m *natmap) Del(key string) net.PacketConn {
return nil
}

func (m *natmap) Add(peer net.Addr, dst, src net.PacketConn, role mode) {
m.Set(peer.String(), src)
func (m *natmap) Add(peer netip.AddrPort, dst UDPConn, src net.PacketConn, role mode) {
m.Set(peer, src)

go func() {
timedCopy(dst, peer, src, m.timeout, role)
if pc := m.Del(peer.String()); pc != nil {
if pc := m.Del(peer); pc != nil {
pc.Close()
}
}()
}

// copy from src to dst at target with read timeout
func timedCopy(dst net.PacketConn, target net.Addr, src net.PacketConn, timeout time.Duration, role mode) error {
func timedCopy(dst UDPConn, target netip.AddrPort, src net.PacketConn, timeout time.Duration, role mode) error {
buf := make([]byte, udpBufSize)

for {
Expand All @@ -238,12 +262,12 @@ func timedCopy(dst net.PacketConn, target net.Addr, src net.PacketConn, timeout
srcAddr := socks.ParseAddr(raddr.String())
copy(buf[len(srcAddr):], buf[:n])
copy(buf, srcAddr)
_, err = dst.WriteTo(buf[:len(srcAddr)+n], target)
_, err = dst.WriteToUDPAddrPort(buf[:len(srcAddr)+n], target)
case relayClient: // client -> user: strip original packet source
srcAddr := socks.SplitAddr(buf[:n])
_, err = dst.WriteTo(buf[len(srcAddr):n], target)
_, err = dst.WriteToUDPAddrPort(buf[len(srcAddr):n], target)
case socksClient: // client -> socks5 program: just set RSV and FRAG = 0
_, err = dst.WriteTo(append([]byte{0, 0, 0}, buf[:n]...), target)
_, err = dst.WriteToUDPAddrPort(append([]byte{0, 0, 0}, buf[:n]...), target)
}

if err != nil {
Expand Down