From 3c5ef1c1ba5f15ab450f51dacf32e762cd34f16e Mon Sep 17 00:00:00 2001 From: nightfury1204 Date: Thu, 29 Aug 2024 22:52:02 +0100 Subject: [PATCH] Fix proxy websocket --- go.mod | 2 +- go.sum | 2 + pkg/api/controllers.go | 10 +- pkg/helpers/stream.go | 61 ---- pkg/helpers/stream_test.go | 18 -- provider/aws/proxy.go | 7 +- provider/k8s/proxy.go | 3 +- vendor/github.com/convox/stdsdk/client.go | 65 +--- .../github.com/convox/stdsdk/stream_helper.go | 304 ++++++++++++++++++ vendor/modules.txt | 2 +- 10 files changed, 324 insertions(+), 150 deletions(-) delete mode 100644 pkg/helpers/stream.go delete mode 100644 pkg/helpers/stream_test.go create mode 100644 vendor/github.com/convox/stdsdk/stream_helper.go diff --git a/go.mod b/go.mod index b5039d0e9a..0fcdafe97a 100644 --- a/go.mod +++ b/go.mod @@ -14,7 +14,7 @@ require ( github.com/convox/logger v0.0.0-20180522214415-e39179955b52 github.com/convox/stdapi v1.1.3-0.20221110171947-8d98f61e61ed github.com/convox/stdcli v0.0.0-20230203181735-23ed17b69b51 - github.com/convox/stdsdk v0.0.0-20190422120437-3e80a397e377 + github.com/convox/stdsdk v0.0.2 github.com/convox/version v0.0.0-20160822184233-ffefa0d565d2 github.com/docker/docker v1.13.1 github.com/docker/go-units v0.3.2 diff --git a/go.sum b/go.sum index 07254b7f79..e3deaa015d 100644 --- a/go.sum +++ b/go.sum @@ -57,6 +57,8 @@ github.com/convox/stdcli v0.0.0-20230203181735-23ed17b69b51 h1:03Yia3LZwsHUovVEP github.com/convox/stdcli v0.0.0-20230203181735-23ed17b69b51/go.mod h1:kLknwv4KTN9f9fZ2DDvG/L9ndXkcijaexIZ0enACvSc= github.com/convox/stdsdk v0.0.0-20190422120437-3e80a397e377 h1:PuSJ72MD0mYsMCTvTQ1YydIbQUWtEykNHyweI/vA0PY= github.com/convox/stdsdk v0.0.0-20190422120437-3e80a397e377/go.mod h1:y1vtmkDKBkWSQ6e2gPXAyz1NCuWZ2x3vrP/SFeDDNco= +github.com/convox/stdsdk v0.0.2 h1:WPJ697SzzawzUBz5hW8i9VbWqwN9DM7aSfusb9USCi8= +github.com/convox/stdsdk v0.0.2/go.mod h1:y1vtmkDKBkWSQ6e2gPXAyz1NCuWZ2x3vrP/SFeDDNco= github.com/convox/version v0.0.0-20160822184233-ffefa0d565d2 h1:tdp/1KHBnbne0yT1yuKnAdOTBHRue9yQ4oON8rzGgZc= github.com/convox/version v0.0.0-20160822184233-ffefa0d565d2/go.mod h1:s8HHEf4LLsmPppeubX/A5bz1JpLYkDXbu+ciuYMTk8A= github.com/coreos/bbolt v1.3.2/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkEiiKk= diff --git a/pkg/api/controllers.go b/pkg/api/controllers.go index d3e3bbe96a..fc0ea9887b 100644 --- a/pkg/api/controllers.go +++ b/pkg/api/controllers.go @@ -9,6 +9,7 @@ import ( "github.com/convox/rack/pkg/structs" "github.com/convox/stdapi" + "github.com/convox/stdsdk" ) func (s *Server) AppCancel(c *stdapi.Context) error { @@ -579,14 +580,13 @@ func (s *Server) InstanceShell(c *stdapi.Context) error { } id := c.Var("id") - rw := c var opts structs.InstanceShellOptions if err := stdapi.UnmarshalOptions(c.Request(), &opts); err != nil { return err } - v, err := s.provider(c).WithContext(c.Context()).InstanceShell(id, rw, opts) + v, err := s.provider(c).WithContext(c.Context()).InstanceShell(id, stdsdk.NewAdapterWs(c.Websocket()), opts) if err != nil { return err } @@ -731,14 +731,13 @@ func (s *Server) ProcessExec(c *stdapi.Context) error { app := c.Var("app") pid := c.Var("pid") command := c.Value("command") - rw := c var opts structs.ProcessExecOptions if err := stdapi.UnmarshalOptions(c.Request(), &opts); err != nil { return err } - v, err := s.provider(c).WithContext(c.Context()).ProcessExec(app, pid, command, rw, opts) + v, err := s.provider(c).WithContext(c.Context()).ProcessExec(app, pid, command, stdsdk.NewAdapterWs(c.Websocket()), opts) if err != nil { return err } @@ -874,7 +873,6 @@ func (s *Server) Proxy(c *stdapi.Context) error { } host := c.Var("host") - rw := c port, cerr := strconv.Atoi(c.Var("port")) if cerr != nil { @@ -886,7 +884,7 @@ func (s *Server) Proxy(c *stdapi.Context) error { return err } - err := s.provider(c).WithContext(c.Context()).Proxy(host, port, rw, opts) + err := s.provider(c).WithContext(c.Context()).Proxy(host, port, stdsdk.NewAdapterWs(c.Websocket()), opts) if err != nil { return err } diff --git a/pkg/helpers/stream.go b/pkg/helpers/stream.go deleted file mode 100644 index 2dcd8109d4..0000000000 --- a/pkg/helpers/stream.go +++ /dev/null @@ -1,61 +0,0 @@ -package helpers - -import ( - "io" - "net/http" -) - -type ReadWriter struct { - io.Reader - io.Writer -} - -func Pipe(a, b io.ReadWriter) error { - ch := make(chan error) - - go halfPipe(a, b, ch) - go halfPipe(b, a, ch) - - if err := <-ch; err != nil { - return err - } - - if err := <-ch; err != nil { - return err - } - - return nil -} - -func halfPipe(w io.Writer, r io.Reader, ch chan error) { - ch <- Stream(w, r) - - if c, ok := w.(io.Closer); ok { - c.Close() - } -} - -func Stream(w io.Writer, r io.Reader) error { - buf := make([]byte, 1024) - - for { - n, err := r.Read(buf) - if n > 0 { - if _, err := w.Write(buf[0:n]); err != nil { - if err == io.ErrClosedPipe { - return nil - } - return err - } - if f, ok := w.(http.Flusher); ok { - f.Flush() - } - } - if err == io.EOF { - return nil - } - if err != nil { - return err - } - } -} diff --git a/pkg/helpers/stream_test.go b/pkg/helpers/stream_test.go deleted file mode 100644 index 0abce807a5..0000000000 --- a/pkg/helpers/stream_test.go +++ /dev/null @@ -1,18 +0,0 @@ -package helpers_test - -import ( - "bytes" - "testing" - - "github.com/convox/rack/pkg/helpers" - "github.com/stretchr/testify/require" -) - -func TestStream(t *testing.T) { - text := "hello world" - w := &bytes.Buffer{} - r := bytes.NewReader([]byte(text)) - err := helpers.Stream(w, r) - require.NoError(t, err) - require.Equal(t, text, w.String()) -} diff --git a/provider/aws/proxy.go b/provider/aws/proxy.go index 02e2243824..79145028fa 100644 --- a/provider/aws/proxy.go +++ b/provider/aws/proxy.go @@ -7,12 +7,12 @@ import ( "net" "time" - "github.com/convox/rack/pkg/helpers" "github.com/convox/rack/pkg/structs" + "github.com/convox/stdsdk" ) func (p *Provider) Proxy(host string, port int, rw io.ReadWriter, opts structs.ProxyOptions) error { - cn, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", host, port), 3*time.Second) + cn, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", host, port), 5*time.Second) if err != nil { return err } @@ -21,7 +21,8 @@ func (p *Provider) Proxy(host string, port int, rw io.ReadWriter, opts structs.P cn = tls.Client(cn, &tls.Config{}) } - if err := helpers.Pipe(cn, rw); err != nil { + if err := stdsdk.CopyStreamToEachOther(cn, rw); err != nil { + p.log.Errorf("proxy %s", err) return err } diff --git a/provider/k8s/proxy.go b/provider/k8s/proxy.go index 893569a24f..ba5d3d58ab 100644 --- a/provider/k8s/proxy.go +++ b/provider/k8s/proxy.go @@ -9,6 +9,7 @@ import ( "github.com/convox/rack/pkg/helpers" "github.com/convox/rack/pkg/structs" + "github.com/convox/stdsdk" ) func (p *Provider) Proxy(host string, port int, rw io.ReadWriter, opts structs.ProxyOptions) error { @@ -21,7 +22,7 @@ func (p *Provider) Proxy(host string, port int, rw io.ReadWriter, opts structs.P cn = tls.Client(cn, &tls.Config{}) } - if err := helpers.Pipe(cn, rw); err != nil { + if err := stdsdk.CopyStreamToEachOther(cn, rw); err != nil { return err } diff --git a/vendor/github.com/convox/stdsdk/client.go b/vendor/github.com/convox/stdsdk/client.go index e1b8ba0682..211ea0a113 100644 --- a/vendor/github.com/convox/stdsdk/client.go +++ b/vendor/github.com/convox/stdsdk/client.go @@ -179,9 +179,7 @@ func (c *Client) Delete(path string, opts RequestOptions, out interface{}) error } func (c *Client) Websocket(path string, opts RequestOptions) (io.ReadCloser, error) { - var u url.URL - - u = *c.Endpoint + u := *c.Endpoint u.Scheme = "wss" @@ -218,64 +216,13 @@ func (c *Client) Websocket(path string, opts RequestOptions) (io.ReadCloser, err h.Set("Content-Type", ct) - go copyToWebsocket(c.ctx, ws, or) - go copyFromWebsocket(c.ctx, w, ws) + adapterWs := NewAdapterWs(ws) - return r, nil -} + go copyToWS(c.ctx, adapterWs, or) + go copyFromWS(c.ctx, adapterWs, w) + go WsKeepAlivePing(c.ctx, adapterWs) -func copyToWebsocket(ctx context.Context, ws *websocket.Conn, r io.Reader) { - if r == nil { - return - } - - // used as eof - defer ws.WriteMessage(websocket.BinaryMessage, []byte{}) - - buf := make([]byte, 1024) - - for { - select { - case <-ctx.Done(): - return - default: - n, err := r.Read(buf) - switch err { - case io.EOF: - return - case nil: - ws.WriteMessage(websocket.TextMessage, buf[0:n]) - default: - return - } - } - } -} - -func copyFromWebsocket(ctx context.Context, w io.WriteCloser, ws *websocket.Conn) { - defer w.Close() - - for { - select { - case <-ctx.Done(): - return - default: - code, data, err := ws.ReadMessage() - switch err { - case io.EOF: - return - case nil: - switch code { - case websocket.TextMessage: - w.Write(data) - case websocket.BinaryMessage: // interpreted as eof - return - } - default: - return - } - } - } + return r, nil } func (c *Client) Request(method, path string, opts RequestOptions) (*http.Request, error) { diff --git a/vendor/github.com/convox/stdsdk/stream_helper.go b/vendor/github.com/convox/stdsdk/stream_helper.go new file mode 100644 index 0000000000..9a50a39792 --- /dev/null +++ b/vendor/github.com/convox/stdsdk/stream_helper.go @@ -0,0 +1,304 @@ +package stdsdk + +import ( + "context" + "fmt" + "io" + "net" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +// some caveats apply: https://github.com/gorilla/websocket/issues/441 +type AdapterWs struct { + conn *websocket.Conn + readMutex sync.Mutex + writeMutex sync.Mutex + reader io.Reader +} + +func NewAdapterWs(conn *websocket.Conn) *AdapterWs { + return &AdapterWs{ + conn: conn, + readMutex: sync.Mutex{}, + writeMutex: sync.Mutex{}, + } +} + +func (a *AdapterWs) Read(b []byte) (int, error) { + // Read() can be called concurrently, and we mutate some internal state here + a.readMutex.Lock() + defer a.readMutex.Unlock() + + if a.reader == nil { + messageType, reader, err := a.conn.NextReader() + if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { + return 0, io.EOF + } + if err != nil { + return 0, err + } + + if messageType == websocket.BinaryMessage { + return 0, io.EOF + } + + if messageType != websocket.TextMessage { + return 0, nil + } + + a.reader = reader + } + + bytesRead, err := a.reader.Read(b) + if err != nil { + a.reader = nil + + // EOF for the current Websocket frame, more will probably come so.. + if err == io.EOF { + // .. we must hide this from the caller since our semantics are a + // stream of bytes across many frames + err = nil + } + } + + return bytesRead, err +} + +func (a *AdapterWs) ReadMessage() (int, []byte, error) { + a.readMutex.Lock() + defer a.readMutex.Unlock() + return a.conn.ReadMessage() +} + +func (a *AdapterWs) Write(b []byte) (int, error) { + a.writeMutex.Lock() + defer a.writeMutex.Unlock() + + nextWriter, err := a.conn.NextWriter(websocket.TextMessage) + if err != nil { + return 0, err + } + + bytesWritten, err := nextWriter.Write(b) + nextWriter.Close() + + return bytesWritten, err +} + +func (a *AdapterWs) WriteMessage(messageType int, data []byte) error { + a.writeMutex.Lock() + defer a.writeMutex.Unlock() + return a.conn.WriteMessage(messageType, data) +} + +func (a *AdapterWs) Close() error { + return a.conn.Close() +} + +func (a *AdapterWs) LocalAddr() net.Addr { + return a.conn.LocalAddr() +} + +func (a *AdapterWs) RemoteAddr() net.Addr { + return a.conn.RemoteAddr() +} + +func (a *AdapterWs) SetDeadline(t time.Time) error { + if err := a.SetReadDeadline(t); err != nil { + return err + } + + return a.SetWriteDeadline(t) +} + +func (a *AdapterWs) SetReadDeadline(t time.Time) error { + return a.conn.SetReadDeadline(t) +} + +func (a *AdapterWs) SetWriteDeadline(t time.Time) error { + return a.conn.SetWriteDeadline(t) +} + +func chanFromReader(r io.Reader) (chan []byte, chan error) { + c := make(chan []byte) + errCh := make(chan error) + + go func() { + b := make([]byte, 1024) + + for { + n, err := r.Read(b) + if n > 0 { + res := make([]byte, n) + // Copy the buffer so it doesn't get changed while read by the recipient. + copy(res, b[:n]) + c <- res + } + if err != nil { + if err != io.EOF { + errCh <- err + } + c <- nil + return + } + } + }() + + return c, errCh +} + +// CopyFromToWsTcp accepts a websocket connection and TCP connection and copies data between them +func CopyFromToWsTcp(wsConn *AdapterWs, tcpConn net.Conn) error { + wsChan, wsErrChan := chanFromReader(wsConn) + tcpChan, tcpErrChan := chanFromReader(tcpConn) + + defer wsConn.Close() + defer tcpConn.Close() + for { + select { + case wsData := <-wsChan: + if wsData == nil { + return fmt.Errorf("TCP connection closed: D: %s, S: %s", tcpConn.LocalAddr().String(), wsConn.RemoteAddr().String()) + } else { + _, err := tcpConn.Write(wsData) + if err == io.ErrClosedPipe { + return fmt.Errorf("TCP connection closed: D: %s, S: %s", tcpConn.LocalAddr().String(), wsConn.RemoteAddr().String()) + } + } + case tcpData := <-tcpChan: + if tcpData == nil { + return fmt.Errorf("TCP connection closed: D: %s, S: %s", tcpConn.LocalAddr().String(), wsConn.LocalAddr().String()) + } else { + _, err := wsConn.Write(tcpData) + if err != nil { + return fmt.Errorf("TCP connection closed: D: %s, S: %s", tcpConn.LocalAddr().String(), wsConn.LocalAddr().String()) + } + } + case err := <-wsErrChan: + return err + case err := <-tcpErrChan: + return err + } + } +} + +func CopyStreamToEachOther(fromConn io.ReadWriter, toConn io.ReadWriter) error { + fromChan, fromErrChan := chanFromReader(fromConn) + toChan, toErrChan := chanFromReader(toConn) + + if xc, ok := toConn.(io.Closer); ok { + defer xc.Close() + } + + if yc, ok := fromConn.(io.Closer); ok { + defer yc.Close() + } + + for { + select { + case toData := <-toChan: + if toData == nil { + return fmt.Errorf("TCP connection closed from destination") + } else { + _, err := fromConn.Write(toData) + if err != nil { + if err == io.ErrClosedPipe { + return nil + } + return err + } + } + case fromData := <-fromChan: + if fromData == nil { + return fmt.Errorf("TCP connection closed from source") + } else { + _, err := toConn.Write(fromData) + if err != nil { + if err == io.ErrClosedPipe { + return nil + } + return err + } + } + case err := <-toErrChan: + return err + case err := <-fromErrChan: + return err + } + } +} + +func WsKeepAlivePing(ctx context.Context, ws *AdapterWs) { + t := time.NewTicker(5 * time.Second) + defer t.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-t.C: + ws.WriteMessage(websocket.PingMessage, []byte{}) + } + } +} + +func copyToWS(ctx context.Context, ws *AdapterWs, r io.Reader) error { + if r == nil { + return nil + } + rChan, rErrChan := chanFromReader(r) + + // used as eof + defer ws.WriteMessage(websocket.BinaryMessage, []byte{}) + + for { + select { + case <-ctx.Done(): + return nil + case data := <-rChan: + if data == nil { + return fmt.Errorf("TCP connection closed from destination") + } else { + _, err := ws.Write(data) + if err != nil { + if err == io.ErrClosedPipe { + return nil + } + return err + } + } + case err := <-rErrChan: + return err + } + } +} + +func copyFromWS(ctx context.Context, ws *AdapterWs, w io.WriteCloser) error { + wsChan, wsErrChan := chanFromReader(ws) + + defer w.Close() + + for { + select { + case <-ctx.Done(): + return nil + case data := <-wsChan: + if data == nil { + return fmt.Errorf("TCP connection closed from destination") + } else { + _, err := w.Write(data) + if err != nil { + if err == io.ErrClosedPipe { + return nil + } + return err + } + } + case err := <-wsErrChan: + return err + } + } +} diff --git a/vendor/modules.txt b/vendor/modules.txt index 489cac3528..0f0908301b 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -138,7 +138,7 @@ github.com/convox/stdapi # github.com/convox/stdcli v0.0.0-20230203181735-23ed17b69b51 ## explicit; go 1.13 github.com/convox/stdcli -# github.com/convox/stdsdk v0.0.0-20190422120437-3e80a397e377 +# github.com/convox/stdsdk v0.0.2 ## explicit github.com/convox/stdsdk # github.com/convox/version v0.0.0-20160822184233-ffefa0d565d2