Skip to content

Commit

Permalink
Merge pull request shadowsocks#244 from imgk/master
Browse files Browse the repository at this point in the history
replace string with  netip.AddrPort
  • Loading branch information
riobard authored Oct 20, 2024
2 parents 43c8294 + c2cf969 commit e82e917
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 23 deletions.
63 changes: 63 additions & 0 deletions shadowaead/packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"io"
"net"
"net/netip"
"sync"

"github.com/shadowsocks/go-shadowsocks2/internal"
Expand Down Expand Up @@ -73,6 +74,9 @@ type packetConn struct {
// NewPacketConn wraps a net.PacketConn with cipher
func NewPacketConn(c net.PacketConn, ciph Cipher) net.PacketConn {
const maxPacketSize = 64 * 1024
if cc, ok := c.(*net.UDPConn); ok {
return &udpConn{UDPConn: cc, Cipher: ciph, buf: make([]byte, maxPacketSize)}
}
return &packetConn{PacketConn: c, Cipher: ciph, buf: make([]byte, maxPacketSize)}
}

Expand Down Expand Up @@ -101,3 +105,62 @@ func (c *packetConn) ReadFrom(b []byte) (int, net.Addr, error) {
copy(b, bb)
return len(bb), addr, err
}

type udpConn struct {
*net.UDPConn
Cipher
sync.Mutex
buf []byte // write lock
}

// WriteTo encrypts b and write to addr using the embedded UDPConn.
func (c *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
c.Lock()
defer c.Unlock()
buf, err := Pack(c.buf, b, c)
if err != nil {
return 0, err
}
_, err = c.UDPConn.WriteTo(buf, addr)
return len(b), err
}

// ReadFrom reads from the embedded UDPConn and decrypts into b.
func (c *udpConn) ReadFrom(b []byte) (int, net.Addr, error) {
n, addr, err := c.UDPConn.ReadFrom(b)
if err != nil {
return n, addr, err
}
bb, err := Unpack(b[c.Cipher.SaltSize():], b[:n], c)
if err != nil {
return n, addr, err
}
copy(b, bb)
return len(bb), addr, err
}

// WriteToUDPAddrPort encrypts b and write to addr using the embedded PacketConn.
func (c *udpConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) {
c.Lock()
defer c.Unlock()
buf, err := Pack(c.buf, b, c)
if err != nil {
return 0, err
}
_, err = c.UDPConn.WriteToUDPAddrPort(buf, addr)
return len(b), err
}

// ReadFromUDPAddrPort reads from the embedded UDPConn and decrypts into b.
func (c *udpConn) ReadFromUDPAddrPort(b []byte) (int, netip.AddrPort, error) {
n, addr, err := c.UDPConn.ReadFromUDPAddrPort(b)
if err != nil {
return n, addr, err
}
bb, err := Unpack(b[c.Cipher.SaltSize():], b[:n], c)
if err != nil {
return n, addr, err
}
copy(b, bb)
return len(bb), addr, err
}
70 changes: 47 additions & 23 deletions udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main
import (
"fmt"
"net"
"net/netip"
"sync"
"time"

Expand Down Expand Up @@ -34,7 +35,13 @@ func udpLocal(laddr, server, target string, shadow func(net.PacketConn) net.Pack
return
}

c, err := net.ListenPacket("udp", laddr)
lnAddr, err := net.ResolveUDPAddr("udp", laddr)
if err != nil {
logf("UDP listen address error: %v", err)
return
}

c, err := net.ListenUDP("udp", lnAddr)
if err != nil {
logf("UDP local listen error: %v", err)
return
Expand All @@ -47,13 +54,13 @@ func udpLocal(laddr, server, target string, shadow func(net.PacketConn) net.Pack

logf("UDP tunnel %s <-> %s <-> %s", laddr, server, target)
for {
n, raddr, err := c.ReadFrom(buf[len(tgt):])
n, raddr, err := c.ReadFromUDPAddrPort(buf[len(tgt):])
if err != nil {
logf("UDP local read error: %v", err)
continue
}

pc := nm.Get(raddr.String())
pc := nm.Get(raddr)
if pc == nil {
pc, err = net.ListenPacket("udp", "")
if err != nil {
Expand Down Expand Up @@ -81,7 +88,13 @@ func udpSocksLocal(laddr, server string, shadow func(net.PacketConn) net.PacketC
return
}

c, err := net.ListenPacket("udp", laddr)
lnAddr, err := net.ResolveUDPAddr("udp", laddr)
if err != nil {
logf("UDP listen address error: %v", err)
return
}

c, err := net.ListenUDP("udp", lnAddr)
if err != nil {
logf("UDP local listen error: %v", err)
return
Expand All @@ -92,13 +105,13 @@ func udpSocksLocal(laddr, server string, shadow func(net.PacketConn) net.PacketC
buf := make([]byte, udpBufSize)

for {
n, raddr, err := c.ReadFrom(buf)
n, raddr, err := c.ReadFromUDPAddrPort(buf)
if err != nil {
logf("UDP local read error: %v", err)
continue
}

pc := nm.Get(raddr.String())
pc := nm.Get(raddr)
if pc == nil {
pc, err = net.ListenPacket("udp", "")
if err != nil {
Expand All @@ -118,22 +131,33 @@ func udpSocksLocal(laddr, server string, shadow func(net.PacketConn) net.PacketC
}
}

type UDPConn interface {
net.PacketConn
ReadFromUDPAddrPort([]byte) (int, netip.AddrPort, error)
WriteToUDPAddrPort([]byte, netip.AddrPort) (int, error)
}

// Listen on addr for encrypted packets and basically do UDP NAT.
func udpRemote(addr string, shadow func(net.PacketConn) net.PacketConn) {
c, err := net.ListenPacket("udp", addr)
nAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
logf("UDP server address error: %v", err)
return
}
cc, err := net.ListenUDP("udp", nAddr)
if err != nil {
logf("UDP remote listen error: %v", err)
return
}
defer c.Close()
c = shadow(c)
defer cc.Close()
c := shadow(cc).(UDPConn)

nm := newNATmap(config.UDPTimeout)
buf := make([]byte, udpBufSize)

logf("listening UDP on %s", addr)
for {
n, raddr, err := c.ReadFrom(buf)
n, raddr, err := c.ReadFromUDPAddrPort(buf)
if err != nil {
logf("UDP remote read error: %v", err)
continue
Expand All @@ -153,7 +177,7 @@ func udpRemote(addr string, shadow func(net.PacketConn) net.PacketConn) {

payload := buf[len(tgtAddr):n]

pc := nm.Get(raddr.String())
pc := nm.Get(raddr)
if pc == nil {
pc, err = net.ListenPacket("udp", "")
if err != nil {
Expand All @@ -175,31 +199,31 @@ func udpRemote(addr string, shadow func(net.PacketConn) net.PacketConn) {
// Packet NAT table
type natmap struct {
sync.RWMutex
m map[string]net.PacketConn
m map[netip.AddrPort]net.PacketConn
timeout time.Duration
}

func newNATmap(timeout time.Duration) *natmap {
m := &natmap{}
m.m = make(map[string]net.PacketConn)
m.m = make(map[netip.AddrPort]net.PacketConn)
m.timeout = timeout
return m
}

func (m *natmap) Get(key string) net.PacketConn {
func (m *natmap) Get(key netip.AddrPort) net.PacketConn {
m.RLock()
defer m.RUnlock()
return m.m[key]
}

func (m *natmap) Set(key string, pc net.PacketConn) {
func (m *natmap) Set(key netip.AddrPort, pc net.PacketConn) {
m.Lock()
defer m.Unlock()

m.m[key] = pc
}

func (m *natmap) Del(key string) net.PacketConn {
func (m *natmap) Del(key netip.AddrPort) net.PacketConn {
m.Lock()
defer m.Unlock()

Expand All @@ -211,19 +235,19 @@ func (m *natmap) Del(key string) net.PacketConn {
return nil
}

func (m *natmap) Add(peer net.Addr, dst, src net.PacketConn, role mode) {
m.Set(peer.String(), src)
func (m *natmap) Add(peer netip.AddrPort, dst UDPConn, src net.PacketConn, role mode) {
m.Set(peer, src)

go func() {
timedCopy(dst, peer, src, m.timeout, role)
if pc := m.Del(peer.String()); pc != nil {
if pc := m.Del(peer); pc != nil {
pc.Close()
}
}()
}

// copy from src to dst at target with read timeout
func timedCopy(dst net.PacketConn, target net.Addr, src net.PacketConn, timeout time.Duration, role mode) error {
func timedCopy(dst UDPConn, target netip.AddrPort, src net.PacketConn, timeout time.Duration, role mode) error {
buf := make([]byte, udpBufSize)

for {
Expand All @@ -238,12 +262,12 @@ func timedCopy(dst net.PacketConn, target net.Addr, src net.PacketConn, timeout
srcAddr := socks.ParseAddr(raddr.String())
copy(buf[len(srcAddr):], buf[:n])
copy(buf, srcAddr)
_, err = dst.WriteTo(buf[:len(srcAddr)+n], target)
_, err = dst.WriteToUDPAddrPort(buf[:len(srcAddr)+n], target)
case relayClient: // client -> user: strip original packet source
srcAddr := socks.SplitAddr(buf[:n])
_, err = dst.WriteTo(buf[len(srcAddr):n], target)
_, err = dst.WriteToUDPAddrPort(buf[len(srcAddr):n], target)
case socksClient: // client -> socks5 program: just set RSV and FRAG = 0
_, err = dst.WriteTo(append([]byte{0, 0, 0}, buf[:n]...), target)
_, err = dst.WriteToUDPAddrPort(append([]byte{0, 0, 0}, buf[:n]...), target)
}

if err != nil {
Expand Down

0 comments on commit e82e917

Please sign in to comment.