Skip to content

Commit

Permalink
Merge pull request joewalnes#61 from asergeyev/originchecks
Browse files Browse the repository at this point in the history
Originchecks per joewalnes#20, closes joewalnes#58 too...

--origin and --sameorigin flags added, code refactored to separate some logical parts
  • Loading branch information
asergeyev committed May 16, 2014
2 parents 7102858 + bdc04fc commit 59d46aa
Show file tree
Hide file tree
Showing 17 changed files with 595 additions and 262 deletions.
7 changes: 7 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ func parseCommandLine() *Config {
cgiDirFlag := flag.String("cgidir", "", "Serve CGI scripts from this directory over HTTP")
devConsoleFlag := flag.Bool("devconsole", false, "Enable development console (cannot be used in conjunction with --staticdir)")
passEnvFlag := flag.String("passenv", defaultPassEnv[runtime.GOOS], "List of envvars to pass to subprocesses (others will be cleaned out)")
sameOriginFlag := flag.Bool("sameorigin", false, "Restrict upgrades if origin and host headers differ")
allowOriginsFlag := flag.String("origin", "", "Restrict upgrades if origin does not match the list")

err := flag.CommandLine.Parse(os.Args[1:])
if err != nil {
Expand Down Expand Up @@ -176,6 +178,11 @@ func parseCommandLine() *Config {
}
}

if *allowOriginsFlag != "" {
config.AllowOrigins = strings.Split(*allowOriginsFlag, ",")
}
config.SameOrigin = *sameOriginFlag

args := flag.Args()
if len(args) < 1 && config.ScriptDir == "" && config.StaticDir == "" && config.CgiDir == "" {
fmt.Fprintf(os.Stderr, "Please specify COMMAND or provide --dir, --staticdir or --cgidir argument.\n")
Expand Down
12 changes: 12 additions & 0 deletions examples/bash/send-receive.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#!/bin/bash


while true; do
cnt=0
while read -t 0.01 line; do
cnt=$(($cnt + 1))
done

echo `date` "($cnt line(s) received)"
sleep $((RANDOM % 10 + 1)) & wait
done
11 changes: 11 additions & 0 deletions help.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,17 @@ Options:
Use square brackets to specify IPv6 address.
Default: "" (all)
--sameorigin={true,false} Restrict (HTTP 403) protocol upgrades if the
Origin header does not match to requested HTTP
Host. Default: false.
--origin=host[:port][,host[:port]...]
Restrict (HTTP 403) protocol upgrades if the
Origin header does not match to one of the host
and port combinations listed. If the port is not
specified, any port number will match.
Default: "" (allow any origin)
--ssl Listen for HTTPS socket instead of HTTP.
--sslcert=FILE All three options must be used or all of
--sslkey=FILE them should be omitted.
Expand Down
2 changes: 2 additions & 0 deletions libwebsocketd/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,6 @@ type Config struct {
ServerSoftware string // Value to pass to SERVER_SOFTWARE environment variable (e.g. websocketd/1.2.3).
Env []string // Additional environment variables to pass to process ("key=value").
ParentEnv []string // Variables kept from os.Environ() before sanitizing it for subprocess.
AllowOrigins []string // List of allowed origin addresses for websocket upgrade.
SameOrigin bool // If set, requires websocket upgrades to be performed from same origin only.
}
4 changes: 3 additions & 1 deletion libwebsocketd/console.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ package libwebsocketd
// We can get by without jQuery or Bootstrap for this one ;).

const (
ConsoleContent = `
defaultConsoleContent = `
<!--
websocketd console
Expand Down Expand Up @@ -330,3 +330,5 @@ Full documentation at http://websocketd.com/
`
)

var ConsoleContent = defaultConsoleContent
23 changes: 22 additions & 1 deletion libwebsocketd/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,28 @@
package libwebsocketd

type Endpoint interface {
StartReading()
Terminate()
Output() chan string
Send(msg string) bool
Send(string) bool
}

func PipeEndpoints(e1, e2 Endpoint) {
e1.StartReading()
e2.StartReading()

defer e1.Terminate()
defer e2.Terminate()
for {
select {
case msgOne, ok := <-e1.Output():
if !ok || !e2.Send(msgOne) {
return
}
case msgTwo, ok := <-e2.Output():
if !ok || !e1.Send(msgTwo) {
return
}
}
}
}
73 changes: 73 additions & 0 deletions libwebsocketd/endpoint_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package libwebsocketd

import (
"strconv"
"testing"
"time"
)

var eol_tests = []string{
"", "\n", "\r\n", "ok\n", "ok\n",
"quite long string for our test\n",
"quite long string for our test\r\n",
}

var eol_answers = []string{
"", "", "", "ok", "ok",
"quite long string for our test", "quite long string for our test",
}

func TestTrimEOL(t *testing.T) {
for n := 0; n < len(eol_tests); n++ {
answ := trimEOL(eol_tests[n])
if answ != eol_answers[n] {
t.Errorf("Answer '%s' did not match predicted '%s'", answ, eol_answers[n])
}
}
}

func BenchmarkTrimEOL(b *testing.B) {
for n := 0; n < b.N; n++ {
trimEOL(eol_tests[n%len(eol_tests)])
}
}

type TestEndpoint struct {
limit int
prefix string
c chan string
result []string
}

func (e *TestEndpoint) StartReading() {
go func() {
for i := 0; i < e.limit; i++ {
e.c <- e.prefix + strconv.Itoa(i)
}
time.Sleep(time.Millisecond) // should be enough for smaller channel to catch up with long one
close(e.c)
}()
}

func (e *TestEndpoint) Terminate() {
}

func (e *TestEndpoint) Output() chan string {
return e.c
}

func (e *TestEndpoint) Send(msg string) bool {
e.result = append(e.result, msg)
return true
}

func TestEndpointPipe(t *testing.T) {
one := &TestEndpoint{2, "one:", make(chan string), make([]string, 0)}
two := &TestEndpoint{4, "two:", make(chan string), make([]string, 0)}
PipeEndpoints(one, two)
if len(one.result) != 4 || len(two.result) != 2 {
t.Errorf("Invalid lengths, should be 4 and 2: %v %v", one.result, two.result)
} else if one.result[0] != "two:0" || two.result[0] != "one:0" {
t.Errorf("Invalid first results, should be two:0 and one:0: %#v %#v", one.result[0], two.result[0])
}
}
80 changes: 19 additions & 61 deletions libwebsocketd/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,8 @@ package libwebsocketd

import (
"fmt"
"net"
"net/http"
"strconv"
"strings"
"time"
)

const (
Expand All @@ -21,86 +18,47 @@ const (
var headerNewlineToSpace = strings.NewReplacer("\n", " ", "\r", " ")
var headerDashToUnderscore = strings.NewReplacer("-", "_")

func generateId() string {
return strconv.FormatInt(time.Now().UnixNano(), 10)
}

func remoteDetails(req *http.Request, config *Config) (string, string, string, error) {
remoteAddr, remotePort, err := net.SplitHostPort(req.RemoteAddr)
if err != nil {
return "", "", "", err
}

var remoteHost string
if config.ReverseLookup {
remoteHosts, err := net.LookupAddr(remoteAddr)
if err != nil || len(remoteHosts) == 0 {
remoteHost = remoteAddr
} else {
remoteHost = remoteHosts[0]
}
} else {
remoteHost = remoteAddr
}

return remoteAddr, remoteHost, remotePort, nil
}

func createEnv(req *http.Request, config *Config, urlInfo *URLInfo, id string, log *LogScope) ([]string, error) {
func createEnv(handler *WebsocketdHandler, req *http.Request, log *LogScope) []string {
headers := req.Header

url := req.URL

remoteAddr, remoteHost, remotePort, err := remoteDetails(req, config)
if err != nil {
return nil, err
}

serverName, serverPort, err := net.SplitHostPort(req.Host)
serverName, serverPort, err := tellHostPort(req.Host, handler.server.Config.Ssl)
if err != nil {
// Without hijacking socket connection we cannot know port for sure.
if addrerr, ok := err.(*net.AddrError); ok && strings.Contains(addrerr.Err, "missing port") {
serverName = req.Host
if config.Ssl {
serverPort = "443"
} else {
serverPort = "80"
}
} else {
// this does mean that we cannot detect port from Host: header... Just keep going with ""
serverPort = ""
}
// This does mean that we cannot detect port from Host: header... Just keep going with "", guessing is bad.
log.Debug("env", "Host port detection error: %s", err)
serverPort = ""
}

standardEnvCount := 20
if config.Ssl {
if handler.server.Config.Ssl {
standardEnvCount += 1
}

parentLen := len(config.ParentEnv)
env := make([]string, 0, len(headers)+standardEnvCount+parentLen+len(config.Env))
parentLen := len(handler.server.Config.ParentEnv)
env := make([]string, 0, len(headers)+standardEnvCount+parentLen+len(handler.server.Config.Env))

// This variable could be rewritten from outside
env = appendEnv(env, "SERVER_SOFTWARE", config.ServerSoftware)
env = appendEnv(env, "SERVER_SOFTWARE", handler.server.Config.ServerSoftware)

parentStarts := len(env)
for _, v := range config.ParentEnv {
for _, v := range handler.server.Config.ParentEnv {
env = append(env, v)
}

// IMPORTANT ---> Adding a header? Make sure standardHeaderCount (above) is up to date.

// Standard CGI specification headers.
// As defined in http://tools.ietf.org/html/rfc3875
env = appendEnv(env, "REMOTE_ADDR", remoteAddr)
env = appendEnv(env, "REMOTE_HOST", remoteHost)
env = appendEnv(env, "REMOTE_ADDR", handler.RemoteInfo.Addr)
env = appendEnv(env, "REMOTE_HOST", handler.RemoteInfo.Host)
env = appendEnv(env, "SERVER_NAME", serverName)
env = appendEnv(env, "SERVER_PORT", serverPort)
env = appendEnv(env, "SERVER_PROTOCOL", req.Proto)
env = appendEnv(env, "GATEWAY_INTERFACE", gatewayInterface)
env = appendEnv(env, "REQUEST_METHOD", req.Method)
env = appendEnv(env, "SCRIPT_NAME", urlInfo.ScriptPath)
env = appendEnv(env, "PATH_INFO", urlInfo.PathInfo)
env = appendEnv(env, "SCRIPT_NAME", handler.URLInfo.ScriptPath)
env = appendEnv(env, "PATH_INFO", handler.URLInfo.PathInfo)
env = appendEnv(env, "PATH_TRANSLATED", url.Path)
env = appendEnv(env, "QUERY_STRING", url.RawQuery)

Expand All @@ -112,8 +70,8 @@ func createEnv(req *http.Request, config *Config, urlInfo *URLInfo, id string, l
env = appendEnv(env, "REMOTE_USER", "")

// Non standard, but commonly used headers.
env = appendEnv(env, "UNIQUE_ID", id) // Based on Apache mod_unique_id.
env = appendEnv(env, "REMOTE_PORT", remotePort)
env = appendEnv(env, "UNIQUE_ID", handler.Id) // Based on Apache mod_unique_id.
env = appendEnv(env, "REMOTE_PORT", handler.RemoteInfo.Port)
env = appendEnv(env, "REQUEST_URI", url.RequestURI()) // e.g. /foo/blah?a=b

// The following variables are part of the CGI specification, but are optional
Expand All @@ -128,7 +86,7 @@ func createEnv(req *http.Request, config *Config, urlInfo *URLInfo, id string, l
// SSL_*
// -- SSL variables are not supported, HTTPS=on added for websocketd running with --ssl

if config.Ssl {
if handler.server.Config.Ssl {
env = appendEnv(env, "HTTPS", "on")
}

Expand All @@ -148,12 +106,12 @@ func createEnv(req *http.Request, config *Config, urlInfo *URLInfo, id string, l
log.Debug("env", "Header variable %s", env[len(env)-1])
}

for _, v := range config.Env {
for _, v := range handler.server.Config.Env {
env = append(env, v)
log.Debug("env", "External variable: %s", v)
}

return env, nil
return env
}

// Adapted from net/http/header.go
Expand Down
Loading

0 comments on commit 59d46aa

Please sign in to comment.