diff --git a/app/proxyman/inbound/worker.go b/app/proxyman/inbound/worker.go index 5e5ed76531..198b8b8fd1 100644 --- a/app/proxyman/inbound/worker.go +++ b/app/proxyman/inbound/worker.go @@ -317,9 +317,16 @@ func (w *udpWorker) removeConn(id connID) { w.Unlock() } +func (w *udpWorker) handlePackets() { + receive := w.hub.Receive() + for payload := range receive { + w.callback(payload.Content, payload.Source, payload.OriginalDestination) + } +} + func (w *udpWorker) Start() error { w.activeConn = make(map[connID]*udpConn, 16) - h, err := udp.ListenUDP(w.address, w.port, w.callback, udp.HubReceiveOriginalDestination(w.recvOrigDest), udp.HubCapacity(256)) + h, err := udp.ListenUDP(w.address, w.port, udp.HubReceiveOriginalDestination(w.recvOrigDest), udp.HubCapacity(256)) if err != nil { return err } @@ -352,6 +359,7 @@ func (w *udpWorker) Start() error { return err } w.hub = h + go w.handlePackets() return nil } diff --git a/transport/internet/kcp/listener.go b/transport/internet/kcp/listener.go index 2265334182..f6c3cfaa7b 100644 --- a/transport/internet/kcp/listener.go +++ b/transport/internet/kcp/listener.go @@ -61,7 +61,7 @@ func NewListener(ctx context.Context, address net.Address, port net.Port, addCon l.tlsConfig = config.GetTLSConfig() } - hub, err := udp.ListenUDP(address, port, l.OnReceive, udp.HubCapacity(1024)) + hub, err := udp.ListenUDP(address, port, udp.HubCapacity(1024)) if err != nil { return nil, err } @@ -69,10 +69,20 @@ func NewListener(ctx context.Context, address net.Address, port net.Port, addCon l.hub = hub l.Unlock() newError("listening on ", address, ":", port).WriteToLog() + + go l.handlePackets() + return l, nil } -func (l *Listener) OnReceive(payload *buf.Buffer, src net.Destination, originalDest net.Destination) { +func (l *Listener) handlePackets() { + receive := l.hub.Receive() + for payload := range receive { + l.OnReceive(payload.Content, payload.Source) + } +} + +func (l *Listener) OnReceive(payload *buf.Buffer, src net.Destination) { segments := l.reader.Read(payload.Bytes()) payload.Release() @@ -81,13 +91,6 @@ func (l *Listener) OnReceive(payload *buf.Buffer, src net.Destination, originalD return } - l.Lock() - defer l.Unlock() - - if l.hub == nil { - return - } - conv := segments[0].Conversation() cmd := segments[0].Command() @@ -96,6 +99,10 @@ func (l *Listener) OnReceive(payload *buf.Buffer, src net.Destination, originalD Port: src.Port, Conv: conv, } + + l.Lock() + defer l.Unlock() + conn, found := l.sessions[id] if !found { diff --git a/transport/internet/udp/hub.go b/transport/internet/udp/hub.go index d92df47d97..b0e0676e1e 100644 --- a/transport/internet/udp/hub.go +++ b/transport/internet/udp/hub.go @@ -7,14 +7,11 @@ import ( // Payload represents a single UDP payload. type Payload struct { - payload *buf.Buffer - source net.Destination - originalDest net.Destination + Content *buf.Buffer + Source net.Destination + OriginalDestination net.Destination } -// PayloadHandler is function to handle Payload. -type PayloadHandler func(payload *buf.Buffer, source net.Destination, originalDest net.Destination) - type HubOption func(h *Hub) func HubCapacity(cap int) HubOption { @@ -31,12 +28,12 @@ func HubReceiveOriginalDestination(r bool) HubOption { type Hub struct { conn *net.UDPConn - callback PayloadHandler + cache chan *Payload capacity int recvOrigDest bool } -func ListenUDP(address net.Address, port net.Port, callback PayloadHandler, options ...HubOption) (*Hub, error) { +func ListenUDP(address net.Address, port net.Port, options ...HubOption) (*Hub, error) { udpConn, err := net.ListenUDP("udp", &net.UDPAddr{ IP: address.IP(), Port: int(port), @@ -48,13 +45,14 @@ func ListenUDP(address net.Address, port net.Port, callback PayloadHandler, opti hub := &Hub{ conn: udpConn, capacity: 256, - callback: callback, recvOrigDest: false, } for _, opt := range options { opt(hub) } + hub.cache = make(chan *Payload, hub.capacity) + if hub.recvOrigDest { rawConn, err := udpConn.SyscallConn() if err != nil { @@ -70,10 +68,7 @@ func ListenUDP(address net.Address, port net.Port, callback PayloadHandler, opti } } - c := make(chan *Payload, hub.capacity) - - go hub.start(c) - go hub.process(c) + go hub.start() return hub, nil } @@ -90,13 +85,8 @@ func (h *Hub) WriteTo(payload []byte, dest net.Destination) (int, error) { }) } -func (h *Hub) process(c <-chan *Payload) { - for p := range c { - h.callback(p.payload, p.source, p.originalDest) - } -} - -func (h *Hub) start(c chan<- *Payload) { +func (h *Hub) start() { + c := h.cache defer close(c) oobBytes := make([]byte, 256) @@ -119,13 +109,13 @@ func (h *Hub) start(c chan<- *Payload) { } payload := &Payload{ - payload: buffer, + Content: buffer, + Source: net.UDPDestination(net.IPAddress(addr.IP), net.Port(addr.Port)), } - payload.source = net.UDPDestination(net.IPAddress(addr.IP), net.Port(addr.Port)) if h.recvOrigDest && noob > 0 { - payload.originalDest = RetrieveOriginalDest(oobBytes[:noob]) - if payload.originalDest.IsValid() { - newError("UDP original destination: ", payload.originalDest).AtDebug().WriteToLog() + payload.OriginalDestination = RetrieveOriginalDest(oobBytes[:noob]) + if payload.OriginalDestination.IsValid() { + newError("UDP original destination: ", payload.OriginalDestination).AtDebug().WriteToLog() } else { newError("failed to read UDP original destination").WriteToLog() } @@ -143,3 +133,7 @@ func (h *Hub) start(c chan<- *Payload) { func (h *Hub) Addr() net.Addr { return h.conn.LocalAddr() } + +func (h *Hub) Receive() <-chan *Payload { + return h.cache +}