Skip to content

Commit

Permalink
improve code
Browse files Browse the repository at this point in the history
  • Loading branch information
Ehco1996 committed Nov 13, 2022
1 parent c907242 commit c72e497
Show file tree
Hide file tree
Showing 6 changed files with 231 additions and 194 deletions.
105 changes: 13 additions & 92 deletions internal/relay/relay.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,17 @@
package relay

import (
"crypto/tls"
"fmt"
"net"
"net/http"
"sync"
"time"

"github.com/Ehco1996/ehco/internal/config"
"github.com/Ehco1996/ehco/internal/constant"
"github.com/Ehco1996/ehco/internal/lb"
mytls "github.com/Ehco1996/ehco/internal/tls"
"github.com/Ehco1996/ehco/internal/transporter"
"github.com/Ehco1996/ehco/internal/web"
"github.com/Ehco1996/ehco/pkg/log"
"github.com/gorilla/mux"
"go.uber.org/zap"
)

Expand Down Expand Up @@ -195,116 +191,41 @@ func (r *Relay) RunLocalUDPServer() error {
}

func (r *Relay) RunLocalMTCPServer() error {
mTCPServer := transporter.NewMTCPServer(r.L.Named("MTCPServer"), r.LocalTCPAddr)
tp := r.TP.(*transporter.Raw)
mTCPServer := transporter.NewMTCPServer(r.LocalTCPAddr.String(), tp, r.L.Named("MTCPServer"))
r.closeTcpF = func() error {
return mTCPServer.Close()
}

go func() {
r.L.Infof("Start MTCP relay server %s", r.Name)
mTCPServer.ListenAndServe()
}()

tp := r.TP.(*transporter.Raw)
for {
conn, e := mTCPServer.Accept()
if e != nil {
return e
}
go func(c net.Conn) {
remote := tp.GetRemote()
web.CurConnectionCount.WithLabelValues(remote.Label, web.METRIC_CONN_TCP).Inc()
defer web.CurConnectionCount.WithLabelValues(remote.Label, web.METRIC_CONN_TCP).Dec()
defer c.Close()
if err := tp.HandleTCPConn(c, remote); err != nil {
r.L.Errorf("HandleTCPConn meet error from:%s to:%s err:%s", c.RemoteAddr(), remote.Address, err)
}
}(conn)
}
r.L.Infof("Start MTCP relay server %s", r.Name)
return mTCPServer.ListenAndServe()
}

func (r *Relay) RunLocalWSServer() error {
tp := r.TP.(*transporter.Raw)
mux := mux.NewRouter()
mux.HandleFunc("/", web.MakeIndexF(r.L))
mux.HandleFunc("/ws/", tp.HandleWsRequest)
server := &http.Server{
Addr: r.LocalTCPAddr.String(),
ReadHeaderTimeout: 30 * time.Second,
Handler: mux,
}
lis, err := net.Listen("tcp", r.LocalTCPAddr.String())
if err != nil {
return err
}
defer lis.Close()
wsServer := transporter.NewWSServer(r.LocalTCPAddr.String(), tp, r.L.Named("WSServer"))
r.closeTcpF = func() error {
return lis.Close()
return wsServer.Close()
}
r.L.Infof("Start WS relay Server %s", r.Name)
return server.Serve(lis)
return wsServer.ListenAndServe()
}

func (r *Relay) RunLocalWSSServer() error {
tp := r.TP.(*transporter.Raw)
mux := mux.NewRouter()
mux.HandleFunc("/", web.MakeIndexF(r.L))
mux.HandleFunc("/wss/", tp.HandleWssRequest)

server := &http.Server{
Addr: r.LocalTCPAddr.String(),
TLSConfig: mytls.DefaultTLSConfig,
ReadHeaderTimeout: 30 * time.Second,
Handler: mux,
}
lis, err := net.Listen("tcp", r.LocalTCPAddr.String())
if err != nil {
return err
}
defer lis.Close()
wssServer := transporter.NewWSSServer(r.LocalTCPAddr.String(), tp, r.L.Named("NewWSSServer"))
r.closeTcpF = func() error {
return lis.Close()
return wssServer.Close()
}
r.L.Infof("Start WSS relay Server %s", r.Name)
return server.Serve(tls.NewListener(lis, server.TLSConfig))
return wssServer.ListenAndServe()
}

func (r *Relay) RunLocalMWSSServer() error {
tp := r.TP.(*transporter.Raw)
mwssServer := transporter.NewMWSSServer(r.L.Named("MWSSServer"))
mux := mux.NewRouter()
mux.Handle("/", web.MakeIndexF(r.L))
mux.Handle("/mwss/", http.HandlerFunc(mwssServer.Upgrade))
httpServer := &http.Server{
Addr: r.LocalTCPAddr.String(),
Handler: mux,
TLSConfig: mytls.DefaultTLSConfig,
ReadHeaderTimeout: 30 * time.Second,
}
mwssServer.Server = httpServer

lis, err := net.Listen("tcp", r.LocalTCPAddr.String())
if err != nil {
return err
}
defer lis.Close()
mwssServer := transporter.NewMWSSServer(r.LocalTCPAddr.String(), tp, r.L.Named("MWSSServer"))
r.closeTcpF = func() error {
return lis.Close()
return mwssServer.Close()
}
r.L.Infof("Start MWSS relay Server %s", r.Name)
go func() {
err := httpServer.Serve(tls.NewListener(lis, httpServer.TLSConfig))
if err != nil {
mwssServer.ErrChan <- err
}
close(mwssServer.ErrChan)
}()

for {
conn, e := mwssServer.Accept()
if e != nil {
return e
}
go tp.HandleMWssRequest(conn)
}
return mwssServer.ListenAndServe()
}
63 changes: 41 additions & 22 deletions internal/transporter/mtcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"net"

"github.com/Ehco1996/ehco/internal/lb"
"github.com/Ehco1996/ehco/internal/web"
"github.com/xtaci/smux"
"go.uber.org/zap"
)
Expand Down Expand Up @@ -38,20 +39,22 @@ func (s *MTCP) GetRemote() *lb.Node {
}

type MTCPServer struct {
listenAddr *net.TCPAddr
listener *net.TCPListener
raw *Raw
listenAddr string
listener net.Listener
L *zap.SugaredLogger

ConnChan chan net.Conn
ErrChan chan error
L *zap.SugaredLogger
errChan chan error
connChan chan net.Conn
}

func NewMTCPServer(l *zap.SugaredLogger, listenAddr *net.TCPAddr) *MTCPServer {
func NewMTCPServer(listenAddr string, raw *Raw, l *zap.SugaredLogger) *MTCPServer {
return &MTCPServer{
ConnChan: make(chan net.Conn, 1024),
ErrChan: make(chan error, 1),
L: l,
raw: raw,
listenAddr: listenAddr,
errChan: make(chan error, 1),
connChan: make(chan net.Conn, 1024),
}
}

Expand All @@ -77,7 +80,7 @@ func (s *MTCPServer) mux(conn net.Conn) {
break
}
select {
case s.ConnChan <- stream:
case s.connChan <- stream:
default:
stream.Close()
s.L.Infof("%s - %s: connection queue is full", conn.RemoteAddr(), conn.LocalAddr())
Expand All @@ -87,29 +90,45 @@ func (s *MTCPServer) mux(conn net.Conn) {

func (s *MTCPServer) Accept() (conn net.Conn, err error) {
select {
case conn = <-s.ConnChan:
case err = <-s.ErrChan:
case conn = <-s.connChan:
case err = <-s.errChan:
}
return
}

func (s *MTCPServer) ListenAndServe() {
lis, err := net.ListenTCP("tcp", s.listenAddr)
func (s *MTCPServer) ListenAndServe() error {
lis, err := net.Listen("tcp", s.listenAddr)
if err != nil {
s.ErrChan <- err
return
return err
}
s.listener = lis
for {
c, err := lis.AcceptTCP()
if err != nil {
s.ErrChan <- err
continue

go func() {
for {
c, err := lis.Accept()
if err != nil {
s.errChan <- err
continue
}
go s.mux(c)
}
}()

go s.mux(c)
for {
conn, e := s.Accept()
if e != nil {
return e
}
go func(c net.Conn) {
remote := s.raw.GetRemote()
web.CurConnectionCount.WithLabelValues(remote.Label, web.METRIC_CONN_TCP).Inc()
defer web.CurConnectionCount.WithLabelValues(remote.Label, web.METRIC_CONN_TCP).Dec()
defer c.Close()
if err := s.raw.HandleTCPConn(c, remote); err != nil {
s.L.Errorf("HandleTCPConn meet error from:%s to:%s err:%s", c.RemoteAddr(), remote.Address, err)
}
}(conn)
}

}

func (s *MTCPServer) Close() error {
Expand Down
76 changes: 60 additions & 16 deletions internal/transporter/mwss.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@ package transporter

import (
"context"
"crypto/tls"
"net"
"net/http"
"time"

"github.com/Ehco1996/ehco/internal/constant"
"github.com/Ehco1996/ehco/internal/lb"
mytls "github.com/Ehco1996/ehco/internal/tls"
"github.com/Ehco1996/ehco/internal/web"
"github.com/gobwas/ws"
"github.com/gorilla/mux"
"github.com/xtaci/smux"
"go.uber.org/zap"
)
Expand Down Expand Up @@ -42,21 +46,61 @@ func (s *Mwss) GetRemote() *lb.Node {
}

type MWSSServer struct {
Server *http.Server
ConnChan chan net.Conn
ErrChan chan error
L *zap.SugaredLogger
raw *Raw
httpServer *http.Server
L *zap.SugaredLogger

connChan chan net.Conn
errChan chan error
}

func NewMWSSServer(l *zap.SugaredLogger) *MWSSServer {
return &MWSSServer{
ConnChan: make(chan net.Conn, 1024),
ErrChan: make(chan error, 1),
func NewMWSSServer(listenAddr string, raw *Raw, l *zap.SugaredLogger) *MWSSServer {
s := &MWSSServer{
raw: raw,
L: l,
errChan: make(chan error, 1),
connChan: make(chan net.Conn, 1024),
}

mux := mux.NewRouter()
mux.Handle("/", web.MakeIndexF(l))
mux.Handle("/mwss/", http.HandlerFunc(s.HandleRequest))
s.httpServer = &http.Server{
Addr: listenAddr,
Handler: mux,
TLSConfig: mytls.DefaultTLSConfig,
ReadHeaderTimeout: 30 * time.Second,
}
return s
}

func (s *MWSSServer) ListenAndServe() error {
lis, err := net.Listen("tcp", s.httpServer.Addr)
if err != nil {
return err
}
go func() {
s.errChan <- s.httpServer.Serve(tls.NewListener(lis, s.httpServer.TLSConfig))
}()

for {
conn, e := s.Accept()
if e != nil {
return e
}
go func(c net.Conn) {
remote := s.raw.GetRemote()
web.CurConnectionCount.WithLabelValues(remote.Label, web.METRIC_CONN_TCP).Inc()
defer web.CurConnectionCount.WithLabelValues(remote.Label, web.METRIC_CONN_TCP).Dec()
defer c.Close()
if err := s.raw.HandleTCPConn(c, remote); err != nil {
s.L.Errorf("HandleTCPConn meet error from:%s to:%s err:%s", c.RemoteAddr(), remote.Address, err)
}
}(conn)
}
}

func (s *MWSSServer) Upgrade(w http.ResponseWriter, r *http.Request) {
func (s *MWSSServer) HandleRequest(w http.ResponseWriter, r *http.Request) {
conn, _, _, err := ws.UpgradeHTTP(r, w)
if err != nil {
s.L.Error(err)
Expand All @@ -72,13 +116,13 @@ func (s *MWSSServer) mux(conn net.Conn) {
cfg.KeepAliveDisabled = true
session, err := smux.Server(conn, cfg)
if err != nil {
s.L.Debugf("server err %s - %s : %s", conn.RemoteAddr(), s.Server.Addr, err)
s.L.Debugf("server err %s - %s : %s", conn.RemoteAddr(), s.httpServer.Addr, err)
return
}
defer session.Close()

s.L.Debugf("session init %s %s", conn.RemoteAddr(), s.Server.Addr)
defer s.L.Debugf("session close %s >-< %s", conn.RemoteAddr(), s.Server.Addr)
s.L.Debugf("session init %s %s", conn.RemoteAddr(), s.httpServer.Addr)
defer s.L.Debugf("session close %s >-< %s", conn.RemoteAddr(), s.httpServer.Addr)

for {
stream, err := session.AcceptStream()
Expand All @@ -87,7 +131,7 @@ func (s *MWSSServer) mux(conn net.Conn) {
break
}
select {
case s.ConnChan <- stream:
case s.connChan <- stream:
default:
stream.Close()
s.L.Infof("%s - %s: connection queue is full", conn.RemoteAddr(), conn.LocalAddr())
Expand All @@ -97,14 +141,14 @@ func (s *MWSSServer) mux(conn net.Conn) {

func (s *MWSSServer) Accept() (conn net.Conn, err error) {
select {
case conn = <-s.ConnChan:
case err = <-s.ErrChan:
case conn = <-s.connChan:
case err = <-s.errChan:
}
return
}

func (s *MWSSServer) Close() error {
return s.Server.Close()
return s.httpServer.Close()
}

type MWSSClient struct {
Expand Down
Loading

0 comments on commit c72e497

Please sign in to comment.