Skip to content

Commit

Permalink
chore(deps): update smux pkg
Browse files Browse the repository at this point in the history
  • Loading branch information
nadoo committed Mar 9, 2023
1 parent 8d0d888 commit 7e80055
Show file tree
Hide file tree
Showing 8 changed files with 103 additions and 48 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ jobs:
type=semver,pattern={{major}}.{{minor}}
- name: Docker - Build and push
uses: docker/build-push-action@v3
uses: docker/build-push-action@v4
with:
context: .
file: .Dockerfile
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ require (
github.com/dgryski/go-camellia v0.0.0-20191119043421-69a8a13fb23d
github.com/dgryski/go-idea v0.0.0-20170306091226-d2fb45a411fb
github.com/dgryski/go-rc2 v0.0.0-20150621095337-8a9021637152
github.com/insomniacslk/dhcp v0.0.0-20230301142404-3e45eea5edd7
github.com/insomniacslk/dhcp v0.0.0-20230307103557-e252950ab961
github.com/nadoo/conflag v0.3.1
github.com/nadoo/ipset v0.5.0
github.com/xtaci/kcp-go/v5 v5.6.2
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE=
github.com/insomniacslk/dhcp v0.0.0-20230301142404-3e45eea5edd7 h1:Fg8rHYs8luh8kCSAHDUIQCNMkn74Gvr1o5YPZdNRgY0=
github.com/insomniacslk/dhcp v0.0.0-20230301142404-3e45eea5edd7/go.mod h1:I9wtoXVkcRwQJ+U9nhxzZytbnT1xjn2DzUjxQ8Qegpc=
github.com/insomniacslk/dhcp v0.0.0-20230307103557-e252950ab961 h1:x/YtdDlmypenG1te/FfH6LVM+3krhXk5CFV8VYNNX5M=
github.com/insomniacslk/dhcp v0.0.0-20230307103557-e252950ab961/go.mod h1:IKrnDWs3/Mqq5n0lI+RxA2sB7MvN/vbMBP3ehXg65UI=
github.com/josharian/native v1.0.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
github.com/josharian/native v1.0.1-0.20221213033349-c1e37c09b531/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=
Expand Down
59 changes: 41 additions & 18 deletions pkg/smux/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@ import (
"container/heap"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"os"
"sync"
"sync/atomic"
"time"
Expand All @@ -17,20 +15,30 @@ import (

const (
defaultAcceptBacklog = 1024
maxShaperSize = 1024
openCloseTimeout = 30 * time.Second // stream open/close timeout
)

// define frame class
type CLASSID int

const (
CLSCTRL CLASSID = iota
CLSDATA
)

var (
ErrInvalidProtocol = errors.New("invalid protocol")
ErrConsumed = errors.New("peer consumed more than sent")
ErrGoAway = errors.New("stream id overflows, should start a new connection")
// ErrTimeout = errors.New("timeout")
ErrTimeout = fmt.Errorf("smux: %w", os.ErrDeadlineExceeded)
ErrWouldBlock = errors.New("operation would block on IO")
ErrTimeout = errors.New("timeout")
ErrWouldBlock = errors.New("operation would block on IO")
)

type writeRequest struct {
prio uint32
class CLASSID
frame Frame
seq uint32
result chan writeResult
}

Expand All @@ -39,10 +47,6 @@ type writeResult struct {
err error
}

type buffersWriter interface {
WriteBuffers(v [][]byte) (n int, err error)
}

// Session defines a multiplexed connection for streams
type Session struct {
conn io.ReadWriteCloser
Expand Down Expand Up @@ -81,8 +85,9 @@ type Session struct {

deadline atomic.Value

shaper chan writeRequest // a shaper for writing
writes chan writeRequest
requestID uint32 // write request monotonic increasing
shaper chan writeRequest // a shaper for writing
writes chan writeRequest
}

func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
Expand Down Expand Up @@ -401,7 +406,7 @@ func (s *Session) keepalive() {
for {
select {
case <-tickerPing.C:
s.writeFrameInternal(newFrame(byte(s.config.Version), cmdNOP, 0), tickerPing.C, 0)
s.writeFrameInternal(newFrame(byte(s.config.Version), cmdNOP, 0), tickerPing.C, CLSCTRL)
s.notifyBucket() // force a signal to the recvLoop
case <-tickerTimeout.C:
if !atomic.CompareAndSwapInt32(&s.dataReady, 1, 0) {
Expand All @@ -423,19 +428,33 @@ func (s *Session) shaperLoop() {
var reqs shaperHeap
var next writeRequest
var chWrite chan writeRequest
var chShaper chan writeRequest

for {
// chWrite is not available until it has packet to send
if len(reqs) > 0 {
chWrite = s.writes
next = heap.Pop(&reqs).(writeRequest)
} else {
chWrite = nil
}

// control heap size, chShaper is not available until packets are less than maximum allowed
if len(reqs) >= maxShaperSize {
chShaper = nil
} else {
chShaper = s.shaper
}

// assertion on non nil
if chShaper == nil && chWrite == nil {
panic("both channel are nil")
}

select {
case <-s.die:
return
case r := <-s.shaper:
case r := <-chShaper:
if chWrite != nil { // next is valid, reshape
heap.Push(&reqs, next)
}
Expand All @@ -451,7 +470,10 @@ func (s *Session) sendLoop() {
var err error
var vec [][]byte // vector for writeBuffers

bw, ok := s.conn.(buffersWriter)
bw, ok := s.conn.(interface {
WriteBuffers(v [][]byte) (n int, err error)
})

if ok {
buf = make([]byte, headerSize)
vec = make([][]byte, 2)
Expand Down Expand Up @@ -503,14 +525,15 @@ func (s *Session) sendLoop() {
// writeFrame writes the frame to the underlying connection
// and returns the number of bytes written if successful
func (s *Session) writeFrame(f Frame) (n int, err error) {
return s.writeFrameInternal(f, nil, 0)
return s.writeFrameInternal(f, time.After(openCloseTimeout), CLSCTRL)
}

// internal writeFrame version to support deadline used in keepalive
func (s *Session) writeFrameInternal(f Frame, deadline <-chan time.Time, prio uint32) (int, error) {
func (s *Session) writeFrameInternal(f Frame, deadline <-chan time.Time, class CLASSID) (int, error) {
req := writeRequest{
prio: prio,
class: class,
frame: f,
seq: atomic.AddUint32(&s.requestID, 1),
result: make(chan writeResult, 1),
}
select {
Expand Down
8 changes: 4 additions & 4 deletions pkg/smux/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -867,7 +867,7 @@ func TestWriteFrameInternal(t *testing.T) {
session.Close()
for i := 0; i < 100; i++ {
f := newFrame(1, byte(rand.Uint32()), rand.Uint32())
session.writeFrameInternal(f, time.After(session.config.KeepAliveTimeout), 0)
session.writeFrameInternal(f, time.After(session.config.KeepAliveTimeout), CLSDATA)
}

// random cmds
Expand All @@ -879,14 +879,14 @@ func TestWriteFrameInternal(t *testing.T) {
session, _ = Client(cli, nil)
for i := 0; i < 100; i++ {
f := newFrame(1, allcmds[rand.Int()%len(allcmds)], rand.Uint32())
session.writeFrameInternal(f, time.After(session.config.KeepAliveTimeout), 0)
session.writeFrameInternal(f, time.After(session.config.KeepAliveTimeout), CLSDATA)
}
//deadline occur
{
c := make(chan time.Time)
close(c)
f := newFrame(1, allcmds[rand.Int()%len(allcmds)], rand.Uint32())
_, err := session.writeFrameInternal(f, c, 0)
_, err := session.writeFrameInternal(f, c, CLSDATA)
if !strings.Contains(err.Error(), "timeout") {
t.Fatal("write frame with deadline failed", err)
}
Expand All @@ -911,7 +911,7 @@ func TestWriteFrameInternal(t *testing.T) {
time.Sleep(time.Second)
close(c)
}()
_, err = session.writeFrameInternal(f, c, 0)
_, err = session.writeFrameInternal(f, c, CLSDATA)
if !strings.Contains(err.Error(), "closed pipe") {
t.Fatal("write frame with to closed conn failed", err)
}
Expand Down
16 changes: 11 additions & 5 deletions pkg/smux/shaper.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,18 @@ func _itimediff(later, earlier uint32) int32 {

type shaperHeap []writeRequest

func (h shaperHeap) Len() int { return len(h) }
func (h shaperHeap) Less(i, j int) bool { return _itimediff(h[j].prio, h[i].prio) > 0 }
func (h shaperHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
func (h *shaperHeap) Push(x any) { *h = append(*h, x.(writeRequest)) }
func (h shaperHeap) Len() int { return len(h) }
func (h shaperHeap) Less(i, j int) bool {
if h[i].class != h[j].class {
return h[i].class < h[j].class
}
return _itimediff(h[j].seq, h[i].seq) > 0
}

func (h shaperHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
func (h *shaperHeap) Push(x interface{}) { *h = append(*h, x.(writeRequest)) }

func (h *shaperHeap) Pop() any {
func (h *shaperHeap) Pop() interface{} {
old := *h
n := len(old)
x := old[n-1]
Expand Down
40 changes: 29 additions & 11 deletions pkg/smux/shaper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ import (
)

func TestShaper(t *testing.T) {
w1 := writeRequest{prio: 10}
w2 := writeRequest{prio: 10}
w3 := writeRequest{prio: 20}
w4 := writeRequest{prio: 100}
w5 := writeRequest{prio: (1 << 32) - 1}
w1 := writeRequest{seq: 1}
w2 := writeRequest{seq: 2}
w3 := writeRequest{seq: 3}
w4 := writeRequest{seq: 4}
w5 := writeRequest{seq: 5}

var reqs shaperHeap
heap.Push(&reqs, w5)
Expand All @@ -19,14 +19,32 @@ func TestShaper(t *testing.T) {
heap.Push(&reqs, w2)
heap.Push(&reqs, w1)

var lastPrio = reqs[0].prio
for len(reqs) > 0 {
w := heap.Pop(&reqs).(writeRequest)
if int32(w.prio-lastPrio) < 0 {
t.Fatal("incorrect shaper priority")
}
t.Log("sid:", w.frame.sid, "seq:", w.seq)
}
}

func TestShaper2(t *testing.T) {
w1 := writeRequest{class: CLSDATA, seq: 1} // stream 0
w2 := writeRequest{class: CLSDATA, seq: 2}
w3 := writeRequest{class: CLSDATA, seq: 3}
w4 := writeRequest{class: CLSDATA, seq: 4}
w5 := writeRequest{class: CLSDATA, seq: 5}
w6 := writeRequest{class: CLSCTRL, seq: 6, frame: Frame{sid: 10}} // ctrl 1
w7 := writeRequest{class: CLSCTRL, seq: 7, frame: Frame{sid: 11}} // ctrl 2

var reqs shaperHeap
heap.Push(&reqs, w6)
heap.Push(&reqs, w5)
heap.Push(&reqs, w4)
heap.Push(&reqs, w3)
heap.Push(&reqs, w2)
heap.Push(&reqs, w1)
heap.Push(&reqs, w7)

t.Log("prio:", w.prio)
lastPrio = w.prio
for len(reqs) > 0 {
w := heap.Pop(&reqs).(writeRequest)
t.Log("sid:", w.frame.sid, "seq:", w.seq)
}
}
20 changes: 14 additions & 6 deletions pkg/smux/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ func (s *Stream) tryReadv2(b []byte) (n int, err error) {

// in an ideal environment:
// if more than half of buffer has consumed, send read ack to peer
// based on round-trip time of ACK, continuous flowing data
// based on round-trip time of ACK, continous flowing data
// won't slow down because of waiting for ACK, as long as the
// consumer keeps on reading data
// s.numRead == n also notify window at the first read
Expand All @@ -156,8 +156,9 @@ func (s *Stream) tryReadv2(b []byte) (n int, err error) {
if notifyConsumed > 0 {
err := s.sendWindowUpdate(notifyConsumed)
return n, err
} else {
return n, nil
}
return n, nil
}

select {
Expand Down Expand Up @@ -256,7 +257,7 @@ func (s *Stream) sendWindowUpdate(consumed uint32) error {
binary.LittleEndian.PutUint32(hdr[:], consumed)
binary.LittleEndian.PutUint32(hdr[4:], uint32(s.sess.config.MaxStreamBuffer))
frame.data = hdr[:]
_, err := s.sess.writeFrameInternal(frame, deadline, 0)
_, err := s.sess.writeFrameInternal(frame, deadline, CLSDATA)
return err
}

Expand All @@ -273,6 +274,12 @@ func (s *Stream) waitRead() error {
case <-s.chReadEvent:
return nil
case <-s.chFinEvent:
// BUG(xtaci): Fix for https://github.com/xtaci/smux/issues/82
s.bufferLock.Lock()
defer s.bufferLock.Unlock()
if len(s.buffers) > 0 {
return nil
}
return io.EOF
case <-s.sess.chSocketReadError:
return s.sess.socketReadError.Load().(error)
Expand Down Expand Up @@ -320,7 +327,7 @@ func (s *Stream) Write(b []byte) (n int, err error) {
}
frame.data = bts[:sz]
bts = bts[sz:]
n, err := s.sess.writeFrameInternal(frame, deadline, s.numWritten)
n, err := s.sess.writeFrameInternal(frame, deadline, CLSDATA)
s.numWritten++
sent += n
if err != nil {
Expand Down Expand Up @@ -388,7 +395,7 @@ func (s *Stream) writeV2(b []byte) (n int, err error) {
}
frame.data = bts[:sz]
bts = bts[sz:]
n, err := s.sess.writeFrameInternal(frame, deadline, atomic.LoadUint32(&s.numWritten))
n, err := s.sess.writeFrameInternal(frame, deadline, CLSDATA)
atomic.AddUint32(&s.numWritten, uint32(sz))
sent += n
if err != nil {
Expand Down Expand Up @@ -432,8 +439,9 @@ func (s *Stream) Close() error {
_, err = s.sess.writeFrame(newFrame(byte(s.sess.config.Version), cmdFIN, s.id))
s.sess.streamClosed(s.id)
return err
} else {
return io.ErrClosedPipe
}
return io.ErrClosedPipe
}

// GetDieCh returns a readonly chan which can be readable
Expand Down

0 comments on commit 7e80055

Please sign in to comment.