Skip to content

Commit

Permalink
Improve wintun read
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Aug 8, 2022
1 parent 0fd822f commit d378b6c
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 20 deletions.
21 changes: 14 additions & 7 deletions tun_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,26 +136,33 @@ func (t *NativeTun) configure() error {
}

func (t *NativeTun) Read(p []byte) (n int, err error) {
err = t.ReadFunc(func(b []byte) {
n = copy(p, b)
})
return
}

func (t *NativeTun) ReadFunc(block func(b []byte)) error {
t.running.Add(1)
defer t.running.Done()
retry:
if atomic.LoadInt32(&t.close) == 1 {
return 0, os.ErrClosed
return os.ErrClosed
}
start := nanotime()
shouldSpin := atomic.LoadUint64(&t.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&t.rate.nextStartTime)) <= rateMeasurementGranularity*2
for {
if atomic.LoadInt32(&t.close) == 1 {
return 0, os.ErrClosed
return os.ErrClosed
}
packet, err := t.session.ReceivePacket()
switch err {
case nil:
packetSize := len(packet)
n = copy(p, packet)
block(packet)
t.session.ReleaseReceivePacket(packet)
t.rate.update(uint64(packetSize))
return n, nil
return nil
case windows.ERROR_NO_MORE_ITEMS:
if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
windows.WaitForSingleObject(t.readWait, windows.INFINITE)
Expand All @@ -164,11 +171,11 @@ retry:
procyield(1)
continue
case windows.ERROR_HANDLE_EOF:
return 0, os.ErrClosed
return os.ErrClosed
case windows.ERROR_INVALID_DATA:
return 0, errors.New("send ring corrupt")
return errors.New("send ring corrupt")
}
return 0, fmt.Errorf("read failed: %w", err)
return fmt.Errorf("read failed: %w", err)
}
}

Expand Down
26 changes: 13 additions & 13 deletions tun_windows_gvisor.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@
package tun

import (
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"

"gvisor.dev/gvisor/pkg/bufferv2"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
Expand Down Expand Up @@ -59,29 +56,32 @@ func (e *WintunEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
}

func (e *WintunEndpoint) dispatchLoop() {
_buffer := buf.StackNewSize(int(e.tun.mtu))
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release()
data := buffer.FreeBytes()
for {
n, err := e.tun.Read(data)
var buffer bufferv2.Buffer
err := e.tun.ReadFunc(func(b []byte) {
buffer = bufferv2.MakeWithData(b)
})
if err != nil {
break
}
packet := data[:n]
ihl, ok := buffer.PullUp(0, 1)
if !ok {
buffer.Release()
continue
}
var networkProtocol tcpip.NetworkProtocolNumber
switch header.IPVersion(packet) {
switch header.IPVersion(ihl.AsSlice()) {
case header.IPv4Version:
networkProtocol = header.IPv4ProtocolNumber
case header.IPv6Version:
networkProtocol = header.IPv6ProtocolNumber
default:
e.tun.Write(packet)
e.tun.Write(buffer.Flatten())
buffer.Release()
continue
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: bufferv2.MakeWithData(packet),
Payload: buffer,
IsForwardedPacket: true,
})
dispatcher := e.dispatcher
Expand Down

0 comments on commit d378b6c

Please sign in to comment.