diff --git a/libwebsocketd/handler.go b/libwebsocketd/handler.go index 9c1ff28a..9b208ef7 100644 --- a/libwebsocketd/handler.go +++ b/libwebsocketd/handler.go @@ -11,7 +11,7 @@ import ( "strings" "time" - "golang.org/x/net/websocket" + "github.com/gorilla/websocket" ) var ScriptNotFoundError = errors.New("script not found") @@ -57,13 +57,6 @@ func NewWebsocketdHandler(s *WebsocketdServer, req *http.Request, log *LogScope) return wsh, nil } -// wshandler returns function that executes code with given log context -func (wsh *WebsocketdHandler) wshandler(log *LogScope) websocket.Handler { - return websocket.Handler(func(ws *websocket.Conn) { - wsh.accept(ws, log) - }) -} - func (wsh *WebsocketdHandler) accept(ws *websocket.Conn, log *LogScope) { defer ws.Close() diff --git a/libwebsocketd/http.go b/libwebsocketd/http.go index 1711b0c2..4f2e61d0 100644 --- a/libwebsocketd/http.go +++ b/libwebsocketd/http.go @@ -19,7 +19,7 @@ import ( "regexp" "strings" - "golang.org/x/net/websocket" + "github.com/gorilla/websocket" ) var ForkNotAllowedError = errors.New("too many forks active") diff --git a/libwebsocketd/websocket_endpoint.go b/libwebsocketd/websocket_endpoint.go index 111bb516..f2d053a5 100644 --- a/libwebsocketd/websocket_endpoint.go +++ b/libwebsocketd/websocket_endpoint.go @@ -7,8 +7,9 @@ package libwebsocketd import ( "io" + "io/ioutil" - "golang.org/x/net/websocket" + "github.com/gorilla/websocket" ) // CONVERT GORILLA @@ -19,16 +20,20 @@ type WebSocketEndpoint struct { ws *websocket.Conn output chan []byte log *LogScope - bin bool + mtype int } func NewWebSocketEndpoint(ws *websocket.Conn, bin bool, log *LogScope) *WebSocketEndpoint { - return &WebSocketEndpoint{ + endpoint := &WebSocketEndpoint{ ws: ws, output: make(chan []byte), log: log, - bin: bin, + mtype: websocket.TextMessage, } + if bin { + endpoint.mtype = websocket.BinaryMessage + } + return endpoint } func (we *WebSocketEndpoint) Terminate() { @@ -40,53 +45,48 @@ func (we *WebSocketEndpoint) Output() chan []byte { } func (we *WebSocketEndpoint) Send(msg []byte) bool { - var err error - if we.bin { - err = websocket.Message.Send(we.ws, msg) - } else { - err = websocket.Message.Send(we.ws, string(msg)) + w, err := we.ws.NextWriter(we.mtype) + if err == nil { + _, err := w.Write(msg) } + w.Close() // could need error handling + if err != nil { we.log.Trace("websocket", "Cannot send: %s", err) return false } + return true } func (we *WebSocketEndpoint) StartReading() { - if we.bin { - go we.read_binary_frames() - } else { - go we.read_text_frames() - } + go we.read_frames() } -func (we *WebSocketEndpoint) read_text_frames() { +func (we *WebSocketEndpoint) read_frames() { for { - var msg string - err := websocket.Message.Receive(we.ws, &msg) + mtype, rd, err := we.ws.NextReader() if err != nil { - if err != io.EOF { - we.log.Debug("websocket", "Cannot receive: %s", err) - } + we.log.Debug("websocket", "Cannot receive: %s", err) break } - we.output <- append([]byte(msg), '\n') - } - close(we.output) -} + if mtype != we.mtype { + we.log.Debug("websocket", "Received message of type that we did not expect... Ignoring...") + } -func (we *WebSocketEndpoint) read_binary_frames() { - for { - var msg []byte - err := websocket.Message.Receive(we.ws, &msg) - if err != nil { - if err != io.EOF { - we.log.Debug("websocket", "Cannot receive: %s", err) - } + p, err := ioutil.ReadAll(rd) + if err != nil && err != io.EOF { + we.log.Debug("websocket", "Cannot read received message: %s", err) break } - we.output <- msg + switch mtype { + case websocket.TextMessage: + we.output <- append(p, '\n') + case websocket.BinaryMessage: + we.output <- p + default: + we.log.Debug("websocket", "Received message of unknown type: %d", mtype) + } } close(we.output) }