Skip to content

Commit

Permalink
opt mwss
Browse files Browse the repository at this point in the history
  • Loading branch information
Ehco1996 committed Jul 2, 2022
1 parent 2d02908 commit 927f948
Show file tree
Hide file tree
Showing 8 changed files with 47 additions and 43 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ kill -HUP pid

# 重载成功可以看到如下信息
[cfg-reload] Got A HUP Signal! Now Reloading Conf ...
[cfg] Load Config From file:config.json
Load Config From file:config.json
[cfg-reload] starr new relay name=[At=127.0.0.1:12342 Over=raw TCP-To=[0.0.0.0:5201] UDP-To=[0.0.0.0:5201] Through=raw]
[relay] Close relay [At=127.0.0.1:1234 Over=raw TCP-To=[0.0.0.0:5201] UDP-To=[0.0.0.0:5201] Through=raw]
[relay] Start UDP relay [At=127.0.0.1:12342 Over=raw TCP-To=[0.0.0.0:5201] UDP-To=[0.0.0.0:5201] Through=raw]
Expand Down
12 changes: 6 additions & 6 deletions cmd/ehco/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ func startRelayServers(ctx context.Context, cfg *config.Config) error {
}

func watchAndReloadConfig(ctx context.Context, relayM *sync.Map, errCh chan error) {
cmdLogger.Infof("[cfg] Start to watch config file: %s ", ConfigPath)
cmdLogger.Infof("Start to watch config file: %s ", ConfigPath)

reloadCH := make(chan os.Signal, 1)
signal.Notify(reloadCH, syscall.SIGHUP)
Expand All @@ -271,32 +271,32 @@ func watchAndReloadConfig(ctx context.Context, relayM *sync.Map, errCh chan erro
case <-ctx.Done():
return
case <-reloadCH:
cmdLogger.Info("[cfg] Got A HUP Signal! Now Reloading Conf")
cmdLogger.Info("Got A HUP Signal! Now Reloading Conf")
newCfg, err := loadConfig()
if err != nil {
cmdLogger.Fatalf("[cfg] Reloading Conf meet error: %s ", err)
cmdLogger.Fatalf("Reloading Conf meet error: %s ", err)
}

var newRelayAddrList []string
for idx := range newCfg.RelayConfigs {
r, err := relay.NewRelay(&newCfg.RelayConfigs[idx])
if err != nil {
cmdLogger.Fatalf("[cfg] reload new relay failed err=%s", err.Error())
cmdLogger.Fatalf("reload new relay failed err=%s", err.Error())
}
newRelayAddrList = append(newRelayAddrList, r.Name)

// reload old relay
if oldR, ok := relayM.Load(r.Name); ok {
oldR := oldR.(*relay.Relay)
if oldR.Name != r.Name {
cmdLogger.Infof("[cfg] close old relay name=%s", oldR.Name)
cmdLogger.Infof("close old relay name=%s", oldR.Name)
stopOneRelay(oldR, relayM)
go startOneRelay(r, relayM, errCh)
}
continue // no need to reload
}
// start bread new relay that not in old relayM
cmdLogger.Infof("[cfg] starr new relay name=%s", r.Name)
cmdLogger.Infof("starr new relay name=%s", r.Name)
go startOneRelay(r, relayM, errCh)
}
// closed relay not in new config
Expand Down
4 changes: 2 additions & 2 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func (c *Config) readFromFile() error {
if err != nil {
return err
}
log.InfoLogger.Info("[cfg] Load Config From file: ", c.PATH)
log.InfoLogger.Info("Load Config From file: ", c.PATH)
if err != nil {
return err
}
Expand All @@ -113,7 +113,7 @@ func (c *Config) readFromHttp() error {
return err
}
defer r.Body.Close()
log.InfoLogger.Info("[cfg] Load Config From http:", c.PATH)
log.InfoLogger.Info("Load Config From http:", c.PATH)
return json.NewDecoder(r.Body).Decode(&c)
}

Expand Down
4 changes: 1 addition & 3 deletions internal/constant/constant.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ var (
)

const (
TCP_RATE_LIMIT = 60 // 每秒每个 IP 可以处理 60 个链接 TODO support config this

DialTimeOut = 3 * time.Second

MaxMWSSStreamCnt = 100
Expand All @@ -50,6 +48,6 @@ const (
Transport_MWSS = "mwss"

// todo add udp buffer size
BUFFER_POOL_SIZE = 1024 // suport 512 connections
BUFFER_POOL_SIZE = 1024 // support 512 connections
BUFFER_SIZE = 20 * 1024 // 20KB the maximum packet size of shadowsocks is about 16 KiB
)
2 changes: 1 addition & 1 deletion internal/relay/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ func (r *Relay) RunLocalWSSServer() error {

func (r *Relay) RunLocalMWSSServer() error {
tp := r.TP.(*transporter.Raw)
mwssServer := transporter.NewMWSSServer()
mwssServer := transporter.NewMWSSServer(r.L)
mux := mux.NewRouter()
mux.Handle("/", http.HandlerFunc(web.Index))
mux.Handle("/mwss/", http.HandlerFunc(mwssServer.Upgrade))
Expand Down
15 changes: 10 additions & 5 deletions internal/transporter/buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"errors"
"io"
"net"
"os"
"syscall"
"time"

Expand All @@ -12,6 +13,7 @@ import (
"github.com/Ehco1996/ehco/internal/constant"
"github.com/Ehco1996/ehco/internal/web"
"github.com/Ehco1996/ehco/pkg/log"
"github.com/xtaci/smux"
)

// 全局pool
Expand Down Expand Up @@ -67,9 +69,12 @@ type WriteOnlyWriter struct {
io.Writer
}

// mute broken pipe or connection reset err.
// mute broken pipe connection reset timeout err.
func MuteErr(err error) error {
if errors.Is(err, syscall.ECONNRESET) || errors.Is(err, syscall.EPIPE) || err == nil {
if errors.Is(err, syscall.ECONNRESET) ||
errors.Is(err, syscall.EPIPE) ||
errors.Is(err, smux.ErrTimeout) ||
os.IsTimeout(err) || err == nil {
return nil
}
return err
Expand All @@ -81,17 +86,17 @@ func transport(conn1, conn2 net.Conn, remote string) error {
go func() {
rn, err := io.Copy(WriteOnlyWriter{Writer: conn1}, ReadOnlyReader{Reader: conn2})
web.NetWorkTransmitBytes.WithLabelValues(remote, web.METRIC_CONN_TCP).Add(float64(rn * 2))
conn1.SetReadDeadline(time.Now().Add(constant.IdleTimeOut)) // unblock read on conn1
_ = conn1.SetReadDeadline(time.Now().Add(constant.IdleTimeOut)) // unblock read on conn1
errCH <- err
}()

// conn2 to conn1
rn, err := io.Copy(WriteOnlyWriter{Writer: conn2}, ReadOnlyReader{Reader: conn1})
web.NetWorkTransmitBytes.WithLabelValues(remote, web.METRIC_CONN_TCP).Add(float64(rn * 2))
if err2 := MuteErr(err); err2 != nil {
log.Logger.Errorf("[transport] from:%s to:%s meet error:%s", conn2.LocalAddr(), conn1.RemoteAddr(), err2.Error())
log.Logger.Errorf("from:%s to:%s meet error:%s", conn2.LocalAddr(), conn1.RemoteAddr(), err2.Error())
}
conn2.SetReadDeadline(time.Now().Add(constant.IdleTimeOut)) // unblock read on conn2
_ = conn2.SetReadDeadline(time.Now().Add(constant.IdleTimeOut)) // unblock read on conn2
return MuteErr(<-errCH)
}

Expand Down
49 changes: 25 additions & 24 deletions internal/transporter/conn.go → internal/transporter/mwss_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,25 @@ import (

"github.com/Ehco1996/ehco/internal/constant"
mytls "github.com/Ehco1996/ehco/internal/tls"
"github.com/Ehco1996/ehco/pkg/log"
"github.com/gobwas/ws"
"github.com/xtaci/smux"
"go.uber.org/zap"
)

type mwssTransporter struct {
sessions map[string][]*smux.Session
sessionM map[string][]*smux.Session
sessionMutex sync.Mutex
dialer ws.Dialer
L *zap.SugaredLogger
}

func NewMWSSTransporter() *mwssTransporter {
func NewMWSSTransporter(l *zap.SugaredLogger) *mwssTransporter {
return &mwssTransporter{
sessions: make(map[string][]*smux.Session),
sessionM: make(map[string][]*smux.Session),
dialer: ws.Dialer{
TLSConfig: mytls.DefaultTLSConfig,
Timeout: constant.DialTimeOut},
L: l,
}
}

Expand All @@ -37,7 +39,7 @@ func (tr *mwssTransporter) Dial(addr string) (conn net.Conn, err error) {
var sessions []*smux.Session
var ok bool

sessions, ok = tr.sessions[addr]
sessions, ok = tr.sessionM[addr]
// 找到可以用的session
for sessionIndex, session = range sessions {
if session.NumStreams() >= constant.MaxMWSSStreamCnt {
Expand All @@ -50,7 +52,7 @@ func (tr *mwssTransporter) Dial(addr string) (conn net.Conn, err error) {

// 删除已经关闭的session
if session != nil && session.IsClosed() {
log.Logger.Infof("find closed session %v idx: %d", session, sessionIndex)
tr.L.Infof("find closed idx: %d", sessionIndex)
sessions = append(sessions[:sessionIndex], sessions[sessionIndex+1:]...)
ok = false
}
Expand All @@ -62,20 +64,13 @@ func (tr *mwssTransporter) Dial(addr string) (conn net.Conn, err error) {
return nil, err
}
sessions = append(sessions, session)
} else {
if len(sessions) > 1 {
// close last not used session, but we keep one conn in session pool
if lastSession := sessions[len(sessions)-1]; lastSession.NumStreams() == 0 {
lastSession.Close()
}
}
}
stream, err := session.OpenStream()
if err != nil {
session.Close()
return nil, err
}
tr.sessions[addr] = sessions
tr.sessionM[addr] = sessions
return stream, nil
}

Expand All @@ -85,31 +80,35 @@ func (tr *mwssTransporter) initSession(addr string) (*smux.Session, error) {
return nil, err
}
// stream multiplex
session, err := smux.Client(rc, nil)
cfg := smux.DefaultConfig()
cfg.KeepAliveDisabled = true
session, err := smux.Client(rc, cfg)
if err != nil {
return nil, err
}
log.Logger.Infof("[mwss] Init new session to: %s", rc.RemoteAddr())
tr.L.Infof("Init new session to: %s", rc.RemoteAddr())
return session, nil
}

type MWSSServer struct {
Server *http.Server
ConnChan chan net.Conn
ErrChan chan error
L *zap.SugaredLogger
}

func NewMWSSServer() *MWSSServer {
func NewMWSSServer(l *zap.SugaredLogger) *MWSSServer {
return &MWSSServer{
ConnChan: make(chan net.Conn, 1024),
ErrChan: make(chan error, 1),
L: l,
}
}

func (s *MWSSServer) Upgrade(w http.ResponseWriter, r *http.Request) {
conn, _, _, err := ws.UpgradeHTTP(r, w)
if err != nil {
log.Logger.Error(err)
s.L.Error(err)
return
}
s.mux(conn)
Expand All @@ -118,27 +117,29 @@ func (s *MWSSServer) Upgrade(w http.ResponseWriter, r *http.Request) {
func (s *MWSSServer) mux(conn net.Conn) {
defer conn.Close()

session, err := smux.Server(conn, nil)
cfg := smux.DefaultConfig()
cfg.KeepAliveDisabled = true
session, err := smux.Server(conn, cfg)
if err != nil {
log.Logger.Infof("[mwss] server err %s - %s : %s", conn.RemoteAddr(), s.Server.Addr, err)
s.L.Infof("server err %s - %s : %s", conn.RemoteAddr(), s.Server.Addr, err)
return
}
defer session.Close()

log.Logger.Infof("[mwss] server init %s %s", conn.RemoteAddr(), s.Server.Addr)
defer log.Logger.Infof("[mwss] server close %s >-< %s", conn.RemoteAddr(), s.Server.Addr)
s.L.Infof("server init %s %s", conn.RemoteAddr(), s.Server.Addr)
defer s.L.Infof("server close %s >-< %s", conn.RemoteAddr(), s.Server.Addr)

for {
stream, err := session.AcceptStream()
if err != nil {
log.Logger.Infof("[mwss] accept stream err: %s", err)
s.L.Infof("accept stream err: %s", err)
break
}
select {
case s.ConnChan <- stream:
default:
stream.Close()
log.Logger.Infof("[mwss] %s - %s: connection queue is full", conn.RemoteAddr(), conn.LocalAddr())
s.L.Infof("%s - %s: connection queue is full", conn.RemoteAddr(), conn.LocalAddr())
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion internal/transporter/picker.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func PickTransporter(transType string, tcpRemotes, udpRemotes lb.RoundRobin) Rel
case constant.Transport_WSS:
return &Wss{raw: &raw}
case constant.Transport_MWSS:
return &Mwss{raw: &raw, mtp: NewMWSSTransporter()}
return &Mwss{raw: &raw, mtp: NewMWSSTransporter(raw.L)}
}
return nil
}

0 comments on commit 927f948

Please sign in to comment.