Skip to content

Commit

Permalink
Adding gorilla websocket lib support
Browse files Browse the repository at this point in the history
  • Loading branch information
asergeyev committed Dec 14, 2017
1 parent 5b7b1d4 commit 853821f
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 45 deletions.
2 changes: 2 additions & 0 deletions libwebsocketd/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ type Config struct {
ServerSoftware string // Value to pass to SERVER_SOFTWARE environment variable (e.g. websocketd/1.2.3).
CloseMs uint // Milliseconds to start sending signals

HandshakeTimeout time.Duration // time to finish handshake (default 1500ms)

// settings
Binary bool // Use binary communication (send data in chunks they are read from process)
ReverseLookup bool // Perform reverse DNS lookups on hostnames (useful, but slower).
Expand Down
72 changes: 33 additions & 39 deletions libwebsocketd/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,24 +43,6 @@ func NewWebsocketdServer(config *Config, log *LogScope, maxforks int) *Websocket
return mux
}

// wshandshake returns closure to verify websocket origin header according to configured rules
func (h *WebsocketdServer) wshandshake(log *LogScope) func(*websocket.Config, *http.Request) error {
// CONVERT GORILLA
// this is going away mostly, except that feature of passing pre-configured headers to client
// should be implemented before gorilla's upgrader takes the control

return func(wsconf *websocket.Config, req *http.Request) error {
if len(h.Config.Headers)+len(h.Config.HeadersWs) > 0 {
if wsconf.Header == nil {
wsconf.Header = http.Header(make(map[string][]string))
}
pushHeaders(wsconf.Header, h.Config.Headers)
pushHeaders(wsconf.Header, h.Config.HeadersWs)
}
return checkOrigin(wsconf, req, h.Config, log)
}
}

func splitMimeHeader(s string) (string, string) {
p := strings.IndexByte(s, ':')
if p < 0 {
Expand All @@ -84,15 +66,9 @@ func pushHeaders(h http.Header, hdrs []string) {

// ServeHTTP muxes between WebSocket handler, CGI handler, DevConsole, Static HTML or 404.
func (h *WebsocketdServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// CONVERT GORILLA
// this is main HTTP response handler it's core that initiates websocket connection upgrade
// and input/output handling needs to be replaced with gorilla-websocket read loop...
// connection should be passed to libwebsocketd handler (see handler.go file, wshandler method)

log := h.Log.NewLevel(h.Log.LogFunc)
log.Associate("url", h.TellURL("http", req.Host, req.RequestURI))

pushHeaders(w.Header(), h.Config.Headers)
if h.Config.CommandName != "" || h.Config.UsingScriptDir {
hdrs := req.Header
upgradeRe := regexp.MustCompile("(?i)(^|[,\\s])Upgrade($|[,\\s])")
Expand All @@ -101,6 +77,7 @@ func (h *WebsocketdServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if h.noteForkCreated() == nil {
defer h.noteForkCompled()

// start figuring out if we even need to upgrade
handler, err := NewWebsocketdHandler(h, req, log)
if err != nil {
if err == ScriptNotFoundError {
Expand All @@ -113,12 +90,32 @@ func (h *WebsocketdServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
return
}

// Now we are ready for connection upgrade dance...
wsServer := &websocket.Server{
Handshake: h.wshandshake(log),
Handler: handler.wshandler(log),
var headers http.Header
if len(h.Config.Headers)+len(h.Config.HeadersWs) > 0 {
headers = http.Header(make(map[string][]string))
pushHeaders(headers, h.Config.Headers)
pushHeaders(headers, h.Config.HeadersWs)
}

upgrader := &websocket.Upgrader{
HandshakeTimeout: h.Config.HandshakeTimeout,
CheckOrigin: func(r *http.Request) bool {
// backporting previous checkorigin for use in gorilla/websocket for now
err := checkOrigin(req, h.Config, log)
return err == nil
},
}
conn, err := upgrader.Upgrade(w, req, headers)
if err != nil {
log.Access("session", "Unable to Upgrade: %s", err)
http.Error(w, "500 Internal Error", 500)
return
}
wsServer.ServeHTTP(w, req)

// old func was used in x/net/websocket style, we reuse it here for gorilla/websocket
handler.accept(conn, log)
return

} else {
log.Error("http", "Max of possible forks already active, upgrade rejected")
http.Error(w, "429 Too Many Requests", 429)
Expand Down Expand Up @@ -234,7 +231,7 @@ func (h *WebsocketdServer) noteForkCompled() {
return
}

func checkOrigin(wsconf *websocket.Config, req *http.Request, config *Config, log *LogScope) (err error) {
func checkOrigin(req *http.Request, config *Config, log *LogScope) (err error) {
// CONVERT GORILLA:
// this is origin checking function, it's called from wshandshake which is from ServeHTTP main handler
// should be trivial to reuse in gorilla's upgrader.CheckOrigin function.
Expand All @@ -248,23 +245,20 @@ func checkOrigin(wsconf *websocket.Config, req *http.Request, config *Config, lo
if origin == "" || (origin == "null" && config.AllowOrigins == nil) {
// we don't want to trust string "null" if there is any
// enforcements are active
req.Header.Set("Origin", "file:")
origin = "file:"
}

wsconf.Origin, err = websocket.Origin(wsconf, req)
if err == nil && wsconf.Origin == nil {
log.Access("session", "rejected null origin")
return fmt.Errorf("null origin not allowed")
}
originParsed, err := url.ParseRequestURI(origin)
if err != nil {
log.Access("session", "Origin parsing error: %s", err)
return err
}
log.Associate("origin", wsconf.Origin.String())

log.Associate("origin", originParsed.String())

// If some origin restrictions are present:
if config.SameOrigin || config.AllowOrigins != nil {
originServer, originPort, err := tellHostPort(wsconf.Origin.Host, wsconf.Origin.Scheme == "https")
originServer, originPort, err := tellHostPort(originParsed.Host, originParsed.Scheme == "https")
if err != nil {
log.Access("session", "Origin hostname parsing error: %s", err)
return err
Expand All @@ -289,7 +283,7 @@ func checkOrigin(wsconf *websocket.Config, req *http.Request, config *Config, lo
if err != nil {
continue // pass bad URLs in origin list
}
if allowedURL.Scheme != wsconf.Origin.Scheme {
if allowedURL.Scheme != originParsed.Scheme {
continue // mismatch
}
allowed = allowed[pos+3:]
Expand Down
7 changes: 2 additions & 5 deletions libwebsocketd/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ import (
"net/http"
"strings"
"testing"

"golang.org/x/net/websocket"
)

var tellHostPortTests = []struct {
Expand Down Expand Up @@ -101,7 +99,6 @@ Sec-WebSocket-Version: 13
log := new(LogScope)
log.LogFunc = func(*LogScope, LogLevel, string, string, string, ...interface{}) {}

wsconf := &websocket.Config{Version: websocket.ProtocolVersionHybi13}
config := new(Config)

if testcase.reqtls == ReqHTTPS { // Fake TLS
Expand All @@ -115,11 +112,11 @@ Sec-WebSocket-Version: 13
config.AllowOrigins = testcase.allowed
}

err = checkOrigin(wsconf, req, config, log)
err = checkOrigin(req, config, log)
if testcase.getsErr == ReturnsError && err == nil {
t.Errorf("Test case %#v did not get an error", testcase.name)
} else if testcase.getsErr == ReturnsPass && err != nil {
t.Errorf("Test case %#v got error while should've", testcase.name)
t.Errorf("Test case %#v got error while expected to pass", testcase.name)
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion libwebsocketd/websocket_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func (we *WebSocketEndpoint) Output() chan []byte {
func (we *WebSocketEndpoint) Send(msg []byte) bool {
w, err := we.ws.NextWriter(we.mtype)
if err == nil {
_, err := w.Write(msg)
_, err = w.Write(msg)
}
w.Close() // could need error handling

Expand Down

0 comments on commit 853821f

Please sign in to comment.