Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 40 additions & 29 deletions proxy/wireguard/bind.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,18 @@ import (
"github.com/xtls/xray-core/transport/internet"
)

const udpBufferSize = 1700 // max MTU for WireGuard

var bufferPool = sync.Pool{
New: func() any {
return make([]byte, udpBufferSize)
},
}

// netReadInfo holds the result of a read operation from a specific endpoint
type netReadInfo struct {
// status
waiter sync.WaitGroup
// param
buff []byte
// result
buff []byte
bytes int
endpoint conn.Endpoint
err error
Expand All @@ -30,8 +36,8 @@ type netBind struct {
dns dns.Client
dnsOption dns.IPOption

workers int
readQueue chan *netReadInfo
workers int
responseRecv chan *netReadInfo // responses from all endpoints flow through here
}

// SetMark implements conn.Bind
Expand Down Expand Up @@ -79,7 +85,7 @@ func (bind *netBind) BatchSize() int {

// Open implements conn.Bind
func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
bind.readQueue = make(chan *netReadInfo)
bind.responseRecv = make(chan *netReadInfo)

fun := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
defer func() {
Expand All @@ -89,13 +95,16 @@ func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
}
}()

r := &netReadInfo{
buff: bufs[0],
r, ok := <-bind.responseRecv
if !ok {
return 0, errors.New("channel closed")
}
r.waiter.Add(1)
bind.readQueue <- r
r.waiter.Wait() // wait read goroutine done, or we will miss the result
sizes[0], eps[0] = r.bytes, r.endpoint

copy(bufs[0], r.buff[:r.bytes])
sizes[0] = r.bytes
eps[0] = r.endpoint
// Return buffer to pool
bufferPool.Put(r.buff)
return 1, r.err
}
workers := bind.workers
Expand All @@ -112,8 +121,8 @@ func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {

// Close implements conn.Bind
func (bind *netBind) Close() error {
if bind.readQueue != nil {
close(bind.readQueue)
if bind.responseRecv != nil {
close(bind.responseRecv)
}
return nil
}
Expand All @@ -133,30 +142,32 @@ func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
}
endpoint.conn = c

go func(readQueue <-chan *netReadInfo, endpoint *netEndpoint) {
go func(responseRecv chan<- *netReadInfo, endpoint *netEndpoint, c net.Conn) {
defer func() {
_ = recover() // gracefully handle send on closed channel
}()
for {
v, ok := <-readQueue
if !ok {
return
}
i, err := c.Read(v.buff)
buff := bufferPool.Get().([]byte)
i, err := c.Read(buff)

if i > 3 {
v.buff[1] = 0
v.buff[2] = 0
v.buff[3] = 0
buff[1] = 0
buff[2] = 0
buff[3] = 0
}

v.bytes = i
v.endpoint = endpoint
v.err = err
v.waiter.Done()
responseRecv <- &netReadInfo{
buff: buff,
bytes: i,
endpoint: endpoint,
err: err,
}
if err != nil {
endpoint.conn = nil
return
}
}
}(bind.readQueue, endpoint)
}(bind.responseRecv, endpoint, c)

return nil
}
Expand Down
20 changes: 12 additions & 8 deletions proxy/wireguard/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,16 +101,20 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con
}

for _, payload := range mpayload {
v, ok := <-s.bindServer.readQueue
if !ok {
data := bufferPool.Get().([]byte)
n, err := payload.Read(data)

select {
case s.bindServer.responseRecv <- &netReadInfo{
buff: data,
bytes: n,
endpoint: nep,
err: err,
}:
case <-ctx.Done():
bufferPool.Put(data) // Return buffer if not sent
return nil
}
i, err := payload.Read(v.buff)

v.bytes = i
v.endpoint = nep
v.err = err
v.waiter.Done()
if err != nil && goerrors.Is(err, io.EOF) {
nep.conn = nil
return nil
Expand Down