From d378b6ca536dcb9f2c819cf3e56493e319095993 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Mon, 8 Aug 2022 21:34:32 +0800 Subject: [PATCH] Improve wintun read --- tun_windows.go | 21 ++++++++++++++------- tun_windows_gvisor.go | 26 +++++++++++++------------- 2 files changed, 27 insertions(+), 20 deletions(-) diff --git a/tun_windows.go b/tun_windows.go index 9badbe2..3dc80dc 100644 --- a/tun_windows.go +++ b/tun_windows.go @@ -136,26 +136,33 @@ func (t *NativeTun) configure() error { } func (t *NativeTun) Read(p []byte) (n int, err error) { + err = t.ReadFunc(func(b []byte) { + n = copy(p, b) + }) + return +} + +func (t *NativeTun) ReadFunc(block func(b []byte)) error { t.running.Add(1) defer t.running.Done() retry: if atomic.LoadInt32(&t.close) == 1 { - return 0, os.ErrClosed + return os.ErrClosed } start := nanotime() shouldSpin := atomic.LoadUint64(&t.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&t.rate.nextStartTime)) <= rateMeasurementGranularity*2 for { if atomic.LoadInt32(&t.close) == 1 { - return 0, os.ErrClosed + return os.ErrClosed } packet, err := t.session.ReceivePacket() switch err { case nil: packetSize := len(packet) - n = copy(p, packet) + block(packet) t.session.ReleaseReceivePacket(packet) t.rate.update(uint64(packetSize)) - return n, nil + return nil case windows.ERROR_NO_MORE_ITEMS: if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration { windows.WaitForSingleObject(t.readWait, windows.INFINITE) @@ -164,11 +171,11 @@ retry: procyield(1) continue case windows.ERROR_HANDLE_EOF: - return 0, os.ErrClosed + return os.ErrClosed case windows.ERROR_INVALID_DATA: - return 0, errors.New("send ring corrupt") + return errors.New("send ring corrupt") } - return 0, fmt.Errorf("read failed: %w", err) + return fmt.Errorf("read failed: %w", err) } } diff --git a/tun_windows_gvisor.go b/tun_windows_gvisor.go index 6915053..33227e2 100644 --- a/tun_windows_gvisor.go +++ b/tun_windows_gvisor.go @@ -3,9 +3,6 @@ package tun import ( - "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/buf" - "gvisor.dev/gvisor/pkg/bufferv2" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -59,29 +56,32 @@ func (e *WintunEndpoint) Attach(dispatcher stack.NetworkDispatcher) { } func (e *WintunEndpoint) dispatchLoop() { - _buffer := buf.StackNewSize(int(e.tun.mtu)) - defer common.KeepAlive(_buffer) - buffer := common.Dup(_buffer) - defer buffer.Release() - data := buffer.FreeBytes() for { - n, err := e.tun.Read(data) + var buffer bufferv2.Buffer + err := e.tun.ReadFunc(func(b []byte) { + buffer = bufferv2.MakeWithData(b) + }) if err != nil { break } - packet := data[:n] + ihl, ok := buffer.PullUp(0, 1) + if !ok { + buffer.Release() + continue + } var networkProtocol tcpip.NetworkProtocolNumber - switch header.IPVersion(packet) { + switch header.IPVersion(ihl.AsSlice()) { case header.IPv4Version: networkProtocol = header.IPv4ProtocolNumber case header.IPv6Version: networkProtocol = header.IPv6ProtocolNumber default: - e.tun.Write(packet) + e.tun.Write(buffer.Flatten()) + buffer.Release() continue } pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Payload: bufferv2.MakeWithData(packet), + Payload: buffer, IsForwardedPacket: true, }) dispatcher := e.dispatcher