From c72e497feeada11c76f39a537c4fdb71e6d874a8 Mon Sep 17 00:00:00 2001 From: ehco1996 Date: Sun, 13 Nov 2022 11:42:48 +0800 Subject: [PATCH] improve code --- internal/relay/relay.go | 105 +++++------------------------------ internal/transporter/mtcp.go | 63 +++++++++++++-------- internal/transporter/mwss.go | 76 +++++++++++++++++++------ internal/transporter/raw.go | 64 +-------------------- internal/transporter/ws.go | 52 +++++++++++++++++ internal/transporter/wss.go | 65 +++++++++++++++++++++- 6 files changed, 231 insertions(+), 194 deletions(-) diff --git a/internal/relay/relay.go b/internal/relay/relay.go index 5c7b477cf..72a474683 100644 --- a/internal/relay/relay.go +++ b/internal/relay/relay.go @@ -1,21 +1,17 @@ package relay import ( - "crypto/tls" "fmt" "net" - "net/http" "sync" "time" "github.com/Ehco1996/ehco/internal/config" "github.com/Ehco1996/ehco/internal/constant" "github.com/Ehco1996/ehco/internal/lb" - mytls "github.com/Ehco1996/ehco/internal/tls" "github.com/Ehco1996/ehco/internal/transporter" "github.com/Ehco1996/ehco/internal/web" "github.com/Ehco1996/ehco/pkg/log" - "github.com/gorilla/mux" "go.uber.org/zap" ) @@ -195,116 +191,41 @@ func (r *Relay) RunLocalUDPServer() error { } func (r *Relay) RunLocalMTCPServer() error { - mTCPServer := transporter.NewMTCPServer(r.L.Named("MTCPServer"), r.LocalTCPAddr) + tp := r.TP.(*transporter.Raw) + mTCPServer := transporter.NewMTCPServer(r.LocalTCPAddr.String(), tp, r.L.Named("MTCPServer")) r.closeTcpF = func() error { return mTCPServer.Close() } - - go func() { - r.L.Infof("Start MTCP relay server %s", r.Name) - mTCPServer.ListenAndServe() - }() - - tp := r.TP.(*transporter.Raw) - for { - conn, e := mTCPServer.Accept() - if e != nil { - return e - } - go func(c net.Conn) { - remote := tp.GetRemote() - web.CurConnectionCount.WithLabelValues(remote.Label, web.METRIC_CONN_TCP).Inc() - defer web.CurConnectionCount.WithLabelValues(remote.Label, web.METRIC_CONN_TCP).Dec() - defer c.Close() - if err := tp.HandleTCPConn(c, remote); err != nil { - r.L.Errorf("HandleTCPConn meet error from:%s to:%s err:%s", c.RemoteAddr(), remote.Address, err) - } - }(conn) - } + r.L.Infof("Start MTCP relay server %s", r.Name) + return mTCPServer.ListenAndServe() } func (r *Relay) RunLocalWSServer() error { tp := r.TP.(*transporter.Raw) - mux := mux.NewRouter() - mux.HandleFunc("/", web.MakeIndexF(r.L)) - mux.HandleFunc("/ws/", tp.HandleWsRequest) - server := &http.Server{ - Addr: r.LocalTCPAddr.String(), - ReadHeaderTimeout: 30 * time.Second, - Handler: mux, - } - lis, err := net.Listen("tcp", r.LocalTCPAddr.String()) - if err != nil { - return err - } - defer lis.Close() + wsServer := transporter.NewWSServer(r.LocalTCPAddr.String(), tp, r.L.Named("WSServer")) r.closeTcpF = func() error { - return lis.Close() + return wsServer.Close() } r.L.Infof("Start WS relay Server %s", r.Name) - return server.Serve(lis) + return wsServer.ListenAndServe() } func (r *Relay) RunLocalWSSServer() error { tp := r.TP.(*transporter.Raw) - mux := mux.NewRouter() - mux.HandleFunc("/", web.MakeIndexF(r.L)) - mux.HandleFunc("/wss/", tp.HandleWssRequest) - - server := &http.Server{ - Addr: r.LocalTCPAddr.String(), - TLSConfig: mytls.DefaultTLSConfig, - ReadHeaderTimeout: 30 * time.Second, - Handler: mux, - } - lis, err := net.Listen("tcp", r.LocalTCPAddr.String()) - if err != nil { - return err - } - defer lis.Close() + wssServer := transporter.NewWSSServer(r.LocalTCPAddr.String(), tp, r.L.Named("NewWSSServer")) r.closeTcpF = func() error { - return lis.Close() + return wssServer.Close() } r.L.Infof("Start WSS relay Server %s", r.Name) - return server.Serve(tls.NewListener(lis, server.TLSConfig)) + return wssServer.ListenAndServe() } func (r *Relay) RunLocalMWSSServer() error { tp := r.TP.(*transporter.Raw) - mwssServer := transporter.NewMWSSServer(r.L.Named("MWSSServer")) - mux := mux.NewRouter() - mux.Handle("/", web.MakeIndexF(r.L)) - mux.Handle("/mwss/", http.HandlerFunc(mwssServer.Upgrade)) - httpServer := &http.Server{ - Addr: r.LocalTCPAddr.String(), - Handler: mux, - TLSConfig: mytls.DefaultTLSConfig, - ReadHeaderTimeout: 30 * time.Second, - } - mwssServer.Server = httpServer - - lis, err := net.Listen("tcp", r.LocalTCPAddr.String()) - if err != nil { - return err - } - defer lis.Close() + mwssServer := transporter.NewMWSSServer(r.LocalTCPAddr.String(), tp, r.L.Named("MWSSServer")) r.closeTcpF = func() error { - return lis.Close() + return mwssServer.Close() } r.L.Infof("Start MWSS relay Server %s", r.Name) - go func() { - err := httpServer.Serve(tls.NewListener(lis, httpServer.TLSConfig)) - if err != nil { - mwssServer.ErrChan <- err - } - close(mwssServer.ErrChan) - }() - - for { - conn, e := mwssServer.Accept() - if e != nil { - return e - } - go tp.HandleMWssRequest(conn) - } + return mwssServer.ListenAndServe() } diff --git a/internal/transporter/mtcp.go b/internal/transporter/mtcp.go index 0f79ee685..e3c47d74d 100644 --- a/internal/transporter/mtcp.go +++ b/internal/transporter/mtcp.go @@ -5,6 +5,7 @@ import ( "net" "github.com/Ehco1996/ehco/internal/lb" + "github.com/Ehco1996/ehco/internal/web" "github.com/xtaci/smux" "go.uber.org/zap" ) @@ -38,20 +39,22 @@ func (s *MTCP) GetRemote() *lb.Node { } type MTCPServer struct { - listenAddr *net.TCPAddr - listener *net.TCPListener + raw *Raw + listenAddr string + listener net.Listener + L *zap.SugaredLogger - ConnChan chan net.Conn - ErrChan chan error - L *zap.SugaredLogger + errChan chan error + connChan chan net.Conn } -func NewMTCPServer(l *zap.SugaredLogger, listenAddr *net.TCPAddr) *MTCPServer { +func NewMTCPServer(listenAddr string, raw *Raw, l *zap.SugaredLogger) *MTCPServer { return &MTCPServer{ - ConnChan: make(chan net.Conn, 1024), - ErrChan: make(chan error, 1), L: l, + raw: raw, listenAddr: listenAddr, + errChan: make(chan error, 1), + connChan: make(chan net.Conn, 1024), } } @@ -77,7 +80,7 @@ func (s *MTCPServer) mux(conn net.Conn) { break } select { - case s.ConnChan <- stream: + case s.connChan <- stream: default: stream.Close() s.L.Infof("%s - %s: connection queue is full", conn.RemoteAddr(), conn.LocalAddr()) @@ -87,29 +90,45 @@ func (s *MTCPServer) mux(conn net.Conn) { func (s *MTCPServer) Accept() (conn net.Conn, err error) { select { - case conn = <-s.ConnChan: - case err = <-s.ErrChan: + case conn = <-s.connChan: + case err = <-s.errChan: } return } -func (s *MTCPServer) ListenAndServe() { - lis, err := net.ListenTCP("tcp", s.listenAddr) +func (s *MTCPServer) ListenAndServe() error { + lis, err := net.Listen("tcp", s.listenAddr) if err != nil { - s.ErrChan <- err - return + return err } s.listener = lis - for { - c, err := lis.AcceptTCP() - if err != nil { - s.ErrChan <- err - continue + + go func() { + for { + c, err := lis.Accept() + if err != nil { + s.errChan <- err + continue + } + go s.mux(c) } + }() - go s.mux(c) + for { + conn, e := s.Accept() + if e != nil { + return e + } + go func(c net.Conn) { + remote := s.raw.GetRemote() + web.CurConnectionCount.WithLabelValues(remote.Label, web.METRIC_CONN_TCP).Inc() + defer web.CurConnectionCount.WithLabelValues(remote.Label, web.METRIC_CONN_TCP).Dec() + defer c.Close() + if err := s.raw.HandleTCPConn(c, remote); err != nil { + s.L.Errorf("HandleTCPConn meet error from:%s to:%s err:%s", c.RemoteAddr(), remote.Address, err) + } + }(conn) } - } func (s *MTCPServer) Close() error { diff --git a/internal/transporter/mwss.go b/internal/transporter/mwss.go index 2057609f5..6534965c1 100644 --- a/internal/transporter/mwss.go +++ b/internal/transporter/mwss.go @@ -2,13 +2,17 @@ package transporter import ( "context" + "crypto/tls" "net" "net/http" + "time" "github.com/Ehco1996/ehco/internal/constant" "github.com/Ehco1996/ehco/internal/lb" mytls "github.com/Ehco1996/ehco/internal/tls" + "github.com/Ehco1996/ehco/internal/web" "github.com/gobwas/ws" + "github.com/gorilla/mux" "github.com/xtaci/smux" "go.uber.org/zap" ) @@ -42,21 +46,61 @@ func (s *Mwss) GetRemote() *lb.Node { } type MWSSServer struct { - Server *http.Server - ConnChan chan net.Conn - ErrChan chan error - L *zap.SugaredLogger + raw *Raw + httpServer *http.Server + L *zap.SugaredLogger + + connChan chan net.Conn + errChan chan error } -func NewMWSSServer(l *zap.SugaredLogger) *MWSSServer { - return &MWSSServer{ - ConnChan: make(chan net.Conn, 1024), - ErrChan: make(chan error, 1), +func NewMWSSServer(listenAddr string, raw *Raw, l *zap.SugaredLogger) *MWSSServer { + s := &MWSSServer{ + raw: raw, L: l, + errChan: make(chan error, 1), + connChan: make(chan net.Conn, 1024), + } + + mux := mux.NewRouter() + mux.Handle("/", web.MakeIndexF(l)) + mux.Handle("/mwss/", http.HandlerFunc(s.HandleRequest)) + s.httpServer = &http.Server{ + Addr: listenAddr, + Handler: mux, + TLSConfig: mytls.DefaultTLSConfig, + ReadHeaderTimeout: 30 * time.Second, + } + return s +} + +func (s *MWSSServer) ListenAndServe() error { + lis, err := net.Listen("tcp", s.httpServer.Addr) + if err != nil { + return err + } + go func() { + s.errChan <- s.httpServer.Serve(tls.NewListener(lis, s.httpServer.TLSConfig)) + }() + + for { + conn, e := s.Accept() + if e != nil { + return e + } + go func(c net.Conn) { + remote := s.raw.GetRemote() + web.CurConnectionCount.WithLabelValues(remote.Label, web.METRIC_CONN_TCP).Inc() + defer web.CurConnectionCount.WithLabelValues(remote.Label, web.METRIC_CONN_TCP).Dec() + defer c.Close() + if err := s.raw.HandleTCPConn(c, remote); err != nil { + s.L.Errorf("HandleTCPConn meet error from:%s to:%s err:%s", c.RemoteAddr(), remote.Address, err) + } + }(conn) } } -func (s *MWSSServer) Upgrade(w http.ResponseWriter, r *http.Request) { +func (s *MWSSServer) HandleRequest(w http.ResponseWriter, r *http.Request) { conn, _, _, err := ws.UpgradeHTTP(r, w) if err != nil { s.L.Error(err) @@ -72,13 +116,13 @@ func (s *MWSSServer) mux(conn net.Conn) { cfg.KeepAliveDisabled = true session, err := smux.Server(conn, cfg) if err != nil { - s.L.Debugf("server err %s - %s : %s", conn.RemoteAddr(), s.Server.Addr, err) + s.L.Debugf("server err %s - %s : %s", conn.RemoteAddr(), s.httpServer.Addr, err) return } defer session.Close() - s.L.Debugf("session init %s %s", conn.RemoteAddr(), s.Server.Addr) - defer s.L.Debugf("session close %s >-< %s", conn.RemoteAddr(), s.Server.Addr) + s.L.Debugf("session init %s %s", conn.RemoteAddr(), s.httpServer.Addr) + defer s.L.Debugf("session close %s >-< %s", conn.RemoteAddr(), s.httpServer.Addr) for { stream, err := session.AcceptStream() @@ -87,7 +131,7 @@ func (s *MWSSServer) mux(conn net.Conn) { break } select { - case s.ConnChan <- stream: + case s.connChan <- stream: default: stream.Close() s.L.Infof("%s - %s: connection queue is full", conn.RemoteAddr(), conn.LocalAddr()) @@ -97,14 +141,14 @@ func (s *MWSSServer) mux(conn net.Conn) { func (s *MWSSServer) Accept() (conn net.Conn, err error) { select { - case conn = <-s.ConnChan: - case err = <-s.ErrChan: + case conn = <-s.connChan: + case err = <-s.errChan: } return } func (s *MWSSServer) Close() error { - return s.Server.Close() + return s.httpServer.Close() } type MWSSClient struct { diff --git a/internal/transporter/raw.go b/internal/transporter/raw.go index 70a4bb659..d9d52aa4c 100644 --- a/internal/transporter/raw.go +++ b/internal/transporter/raw.go @@ -3,14 +3,12 @@ package transporter import ( "context" "net" - "net/http" "sync" "time" "github.com/Ehco1996/ehco/internal/constant" "github.com/Ehco1996/ehco/internal/lb" "github.com/Ehco1996/ehco/internal/web" - "github.com/gobwas/ws" "go.uber.org/zap" ) @@ -55,7 +53,7 @@ func (raw *Raw) HandleUDPConn(uaddr *net.UDPAddr, local *net.UDPConn) { raw.udpmu.Unlock() }() - raw.L.Infof("[raw] HandleUDPConn from %s to %s", local.LocalAddr().String(), remote.Label) + raw.L.Infof("HandleUDPConn from %s to %s", local.LocalAddr().String(), remote.Label) buf := BufferPool.Get() defer BufferPool.Put(buf) @@ -124,65 +122,7 @@ func (raw *Raw) HandleTCPConn(c net.Conn, remote *lb.Node) error { if err != nil { return err } - raw.L.Infof("[raw] HandleTCPConn from %s to %s", c.RemoteAddr(), remote.Address) + raw.L.Infof("HandleTCPConn from %s to %s", c.RemoteAddr(), remote.Address) defer rc.Close() return transport(rc, c, remote.Label) } - -func (raw *Raw) HandleWsRequest(w http.ResponseWriter, req *http.Request) { - wsc, _, _, err := ws.UpgradeHTTP(req, w) - if err != nil { - return - } - defer wsc.Close() - remote := raw.TCPRemotes.Next() - web.CurConnectionCount.WithLabelValues(remote.Label, web.METRIC_CONN_TCP).Inc() - defer web.CurConnectionCount.WithLabelValues(remote.Label, web.METRIC_CONN_TCP).Dec() - rc, err := raw.DialRemote(remote) - if err != nil { - return - } - defer rc.Close() - raw.L.Infof("[tun] HandleWsRequest from:%s to:%s", wsc.RemoteAddr(), remote.Address) - if err := transport(rc, wsc, remote.Label); err != nil { - raw.L.Infof("[tun] HandleWsRequest meet error from:%s to:%s err:%s", wsc.RemoteAddr(), remote.Address, err.Error()) - } -} - -func (raw *Raw) HandleWssRequest(w http.ResponseWriter, req *http.Request) { - wsc, _, _, err := ws.UpgradeHTTP(req, w) - if err != nil { - return - } - defer wsc.Close() - remote := raw.TCPRemotes.Next() - web.CurConnectionCount.WithLabelValues(remote.Label, web.METRIC_CONN_TCP).Inc() - defer web.CurConnectionCount.WithLabelValues(remote.Label, web.METRIC_CONN_TCP).Dec() - - rc, err := raw.DialRemote(remote) - if err != nil { - return - } - defer rc.Close() - raw.L.Infof("[tun] HandleWssRequest from:%s to:%s", wsc.RemoteAddr(), remote.Address) - if err := transport(rc, wsc, remote.Label); err != nil { - raw.L.Infof("[tun] HandleWssRequest meet error from:%s to:%s err:%s", wsc.LocalAddr(), remote.Label, err.Error()) - } -} - -func (raw *Raw) HandleMWssRequest(wsc net.Conn) { - defer wsc.Close() - remote := raw.TCPRemotes.Next() - web.CurConnectionCount.WithLabelValues(remote.Label, web.METRIC_CONN_TCP).Inc() - defer web.CurConnectionCount.WithLabelValues(remote.Label, web.METRIC_CONN_TCP).Dec() - - rc, err := raw.DialRemote(remote) - if err != nil { - return - } - defer rc.Close() - raw.L.Infof("[tun] HandleMWssRequest from:%s to:%s", wsc.RemoteAddr(), remote.Address) - if err := transport(wsc, rc, remote.Label); err != nil { - raw.L.Infof("[tun] HandleMWssRequest meet error from:%s to:%s err:%s", wsc.RemoteAddr(), remote.Label, err.Error()) - } -} diff --git a/internal/transporter/ws.go b/internal/transporter/ws.go index fa5660126..fb62f451c 100644 --- a/internal/transporter/ws.go +++ b/internal/transporter/ws.go @@ -3,9 +3,14 @@ package transporter import ( "context" "net" + "net/http" + "time" "github.com/Ehco1996/ehco/internal/lb" + "github.com/Ehco1996/ehco/internal/web" "github.com/gobwas/ws" + "github.com/gorilla/mux" + "go.uber.org/zap" ) type Ws struct { @@ -35,3 +40,50 @@ func (s *Ws) HandleTCPConn(c net.Conn, remote *lb.Node) error { func (s *Ws) GetRemote() *lb.Node { return s.raw.GetRemote() } + +type WSServer struct { + raw *Raw + L *zap.SugaredLogger + httpServer *http.Server +} + +func NewWSServer(listenAddr string, raw *Raw, l *zap.SugaredLogger) *WSServer { + s := &WSServer{raw: raw, L: l} + mux := mux.NewRouter() + mux.HandleFunc("/", web.MakeIndexF(l)) + mux.HandleFunc("/ws/", s.HandleRequest) + s.httpServer = &http.Server{ + Addr: listenAddr, + ReadHeaderTimeout: 30 * time.Second, + Handler: mux, + } + return s +} + +func (s *WSServer) ListenAndServe() error { + return s.httpServer.ListenAndServe() +} + +func (s *WSServer) Close() error { + return s.httpServer.Close() +} + +func (s *WSServer) HandleRequest(w http.ResponseWriter, req *http.Request) { + wsc, _, _, err := ws.UpgradeHTTP(req, w) + if err != nil { + return + } + defer wsc.Close() + remote := s.raw.TCPRemotes.Next() + web.CurConnectionCount.WithLabelValues(remote.Label, web.METRIC_CONN_TCP).Inc() + defer web.CurConnectionCount.WithLabelValues(remote.Label, web.METRIC_CONN_TCP).Dec() + rc, err := s.raw.DialRemote(remote) + if err != nil { + return + } + defer rc.Close() + s.L.Infof("HandleRequest from:%s to:%s", wsc.RemoteAddr(), remote.Address) + if err := transport(rc, wsc, remote.Label); err != nil { + s.L.Infof("HandleRequest meet error from:%s to:%s err:%s", wsc.RemoteAddr(), remote.Address, err.Error()) + } +} diff --git a/internal/transporter/wss.go b/internal/transporter/wss.go index 10ceb072e..7a9e93618 100644 --- a/internal/transporter/wss.go +++ b/internal/transporter/wss.go @@ -2,11 +2,17 @@ package transporter import ( "context" + "crypto/tls" "net" + "net/http" + "time" "github.com/Ehco1996/ehco/internal/lb" - "github.com/Ehco1996/ehco/internal/tls" + mytls "github.com/Ehco1996/ehco/internal/tls" + "github.com/Ehco1996/ehco/internal/web" "github.com/gobwas/ws" + "github.com/gorilla/mux" + "go.uber.org/zap" ) type Wss struct { @@ -24,7 +30,7 @@ func (s *Wss) HandleUDPConn(uaddr *net.UDPAddr, local *net.UDPConn) { func (s *Wss) HandleTCPConn(c net.Conn, remote *lb.Node) error { defer c.Close() - d := ws.Dialer{TLSConfig: tls.DefaultTLSConfig} + d := ws.Dialer{TLSConfig: mytls.DefaultTLSConfig} wsc, _, _, err := d.Dial(context.TODO(), remote.Address+"/wss/") if err != nil { return err @@ -37,3 +43,58 @@ func (s *Wss) HandleTCPConn(c net.Conn, remote *lb.Node) error { func (s *Wss) GetRemote() *lb.Node { return s.raw.GetRemote() } + +type WSSServer struct { + raw *Raw + L *zap.SugaredLogger + httpServer *http.Server +} + +func NewWSSServer(listenAddr string, raw *Raw, l *zap.SugaredLogger) *WSSServer { + s := &WSSServer{raw: raw, L: l} + mux := mux.NewRouter() + mux.HandleFunc("/", web.MakeIndexF(l)) + mux.HandleFunc("/wss/", s.HandleRequest) + + s.httpServer = &http.Server{ + Handler: mux, + Addr: listenAddr, + ReadHeaderTimeout: 30 * time.Second, + TLSConfig: mytls.DefaultTLSConfig, + } + return s +} + +func (s *WSSServer) ListenAndServe() error { + lis, err := net.Listen("tcp", s.httpServer.Addr) + if err != nil { + return err + } + defer lis.Close() + return s.httpServer.Serve(tls.NewListener(lis, s.httpServer.TLSConfig)) +} + +func (s *WSSServer) Close() error { + return s.httpServer.Close() +} + +func (s *WSSServer) HandleRequest(w http.ResponseWriter, req *http.Request) { + wsc, _, _, err := ws.UpgradeHTTP(req, w) + if err != nil { + return + } + defer wsc.Close() + remote := s.raw.TCPRemotes.Next() + web.CurConnectionCount.WithLabelValues(remote.Label, web.METRIC_CONN_TCP).Inc() + defer web.CurConnectionCount.WithLabelValues(remote.Label, web.METRIC_CONN_TCP).Dec() + + rc, err := s.raw.DialRemote(remote) + if err != nil { + return + } + defer rc.Close() + s.L.Infof("HandleRequest from:%s to:%s", wsc.RemoteAddr(), remote.Address) + if err := transport(rc, wsc, remote.Label); err != nil { + s.L.Infof("HandleRequest meet error from:%s to:%s err:%s", wsc.LocalAddr(), remote.Label, err.Error()) + } +}