Skip to content

Commit f1e1d9f

Browse files
wbtholiman
andauthored
node: support expressive origin rules in ws.origins (#21481)
* Only compare hostnames in ws.origins Also using a helper function for ToLower consolidates all preparation steps in one function for more maintainable consistency. Spaces => tabs Remove a semicolon Add space at start of comment Remove parens around conditional Handle case wehre parsed hostname is empty When passing a single word like "localhost" the parsed hostname is an empty string. Handle this and the error-parsing case together as default, and the nonempty hostname case in the conditional. Refactor with new originIsAllowed functions Adds originIsAllowed() & ruleAllowsOrigin(); removes prepOriginForComparison Remove blank line Added tests for simple allowed-orign rule which does not specify a protocol or port, just a hostname Fix copy-paste: `:=` => `=` Remove parens around conditional Remove autoadded whitespace on blank lines Compare scheme, hostname, and port with rule if the rule specifies those portions. Remove one autoadded trailing whitespace Better handle case where only origin host is given e.g. "localhost" Remove parens around conditional Refactor: attemptWebsocketConnectionFromOrigin DRY Include return type on helper function Provide srv obj in helper fn Provide srv to helper fn Remove stray underscore Remove blank line parent 93e666b4c1e7e49b8406dc83ed93f4a02ea49ac1 author wbt <wbt@users.noreply.github.com> 1598559718 -0400 committer Martin Holst Swende <martin@swende.se> 1605602257 +0100 gpgsig -----BEGIN PGP SIGNATURE----- iQFFBAABCAAvFiEEypmrtbNuJK1doP1AaDtDjAWl3fAFAl+zi9ARHG1hcnRpbkBz d2VuZGUuc2UACgkQaDtDjAWl3fDRiwgAoMtzU8dwRV7Q9xkCwWEx9Wz2f3n6jUr2 VWBycDKGKwRkPPOER3oc9kzjGU/P1tFlK07PjfnAKZ9KWzxpDcJZwYM3xCBurG7A 16y4YsQnzgPNONv3xIkdi3RZtDBIiPFFEmdZFFvZ/jKexfI6JIYPngCAoqdTIFb9 On/aPvvVWQn1ExfmarsvvJ7kUDUG77tZipuacEH5FfFsfelBWOEYPe+I9ToUHskv +qO6rOkV1Ojk8eBc6o0R1PnApwCAlEhJs7aM/SEOg4B4ZJJneiFuEXBIG9+0yS2I NOicuDPLGucOB5nBsfIKI3USPeE+3jxdT8go2lN5Nrhm6MimoILDsQ== =sgUp -----END PGP SIGNATURE----- Refactor: drop err var for more concise test lines Add several tests for new WebSocket origin checks Remove autoadded whitespace on blank lines Restore TestWebsocketOrigins originally-named test and rename the others to be helpers rather than full tests Remove autoadded whitespace on blank line Temporarily comment out new test sets Uncomment test around origin rule with scheme Remove tests without scheme on browser origin per https://github.com/ethereum/go-ethereum/pull/21481/files#r479371498 Uncomment tests with port; remove some blank lines Handle when browser does not specify scheme/port Uncomment test for including scheme & port in rule Add IP tests * node: more tests + table-driven, ws origin changes Co-authored-by: Martin Holst Swende <martin@swende.se>
1 parent 2808046 commit f1e1d9f

File tree

2 files changed

+170
-20
lines changed

2 files changed

+170
-20
lines changed

node/rpcstack_test.go

+108-17
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package node
1919
import (
2020
"bytes"
2121
"net/http"
22+
"strings"
2223
"testing"
2324

2425
"github.com/ethereum/go-ethereum/internal/testlog"
@@ -52,25 +53,104 @@ func TestVhosts(t *testing.T) {
5253
assert.Equal(t, resp2.StatusCode, http.StatusForbidden)
5354
}
5455

55-
// TestWebsocketOrigins makes sure the websocket origins are properly handled on the websocket server.
56-
func TestWebsocketOrigins(t *testing.T) {
57-
srv := createAndStartServer(t, httpConfig{}, true, wsConfig{Origins: []string{"test"}})
58-
defer srv.stop()
56+
type originTest struct {
57+
spec string
58+
expOk []string
59+
expFail []string
60+
}
5961

60-
dialer := websocket.DefaultDialer
61-
_, _, err := dialer.Dial("ws://"+srv.listenAddr(), http.Header{
62-
"Content-type": []string{"application/json"},
63-
"Sec-WebSocket-Version": []string{"13"},
64-
"Origin": []string{"test"},
65-
})
66-
assert.NoError(t, err)
62+
// splitAndTrim splits input separated by a comma
63+
// and trims excessive white space from the substrings.
64+
// Copied over from flags.go
65+
func splitAndTrim(input string) (ret []string) {
66+
l := strings.Split(input, ",")
67+
for _, r := range l {
68+
r = strings.TrimSpace(r)
69+
if len(r) > 0 {
70+
ret = append(ret, r)
71+
}
72+
}
73+
return ret
74+
}
6775

68-
_, _, err = dialer.Dial("ws://"+srv.listenAddr(), http.Header{
69-
"Content-type": []string{"application/json"},
70-
"Sec-WebSocket-Version": []string{"13"},
71-
"Origin": []string{"bad"},
72-
})
73-
assert.Error(t, err)
76+
// TestWebsocketOrigins makes sure the websocket origins are properly handled on the websocket server.
77+
func TestWebsocketOrigins(t *testing.T) {
78+
tests := []originTest{
79+
{
80+
spec: "*", // allow all
81+
expOk: []string{"", "http://test", "https://test", "http://test:8540", "https://test:8540",
82+
"http://test.com", "https://foo.test", "http://testa", "http://atestb:8540", "https://atestb:8540"},
83+
},
84+
{
85+
spec: "test",
86+
expOk: []string{"http://test", "https://test", "http://test:8540", "https://test:8540"},
87+
expFail: []string{"http://test.com", "https://foo.test", "http://testa", "http://atestb:8540", "https://atestb:8540"},
88+
},
89+
// scheme tests
90+
{
91+
spec: "https://test",
92+
expOk: []string{"https://test", "https://test:9999"},
93+
expFail: []string{
94+
"test", // no scheme, required by spec
95+
"http://test", // wrong scheme
96+
"http://test.foo", "https://a.test.x", // subdomain variatoins
97+
"http://testx:8540", "https://xtest:8540"},
98+
},
99+
// ip tests
100+
{
101+
spec: "https://12.34.56.78",
102+
expOk: []string{"https://12.34.56.78", "https://12.34.56.78:8540"},
103+
expFail: []string{
104+
"http://12.34.56.78", // wrong scheme
105+
"http://12.34.56.78:443", // wrong scheme
106+
"http://1.12.34.56.78", // wrong 'domain name'
107+
"http://12.34.56.78.a", // wrong 'domain name'
108+
"https://87.65.43.21", "http://87.65.43.21:8540", "https://87.65.43.21:8540"},
109+
},
110+
// port tests
111+
{
112+
spec: "test:8540",
113+
expOk: []string{"http://test:8540", "https://test:8540"},
114+
expFail: []string{
115+
"http://test", "https://test", // spec says port required
116+
"http://test:8541", "https://test:8541", // wrong port
117+
"http://bad", "https://bad", "http://bad:8540", "https://bad:8540"},
118+
},
119+
// scheme and port
120+
{
121+
spec: "https://test:8540",
122+
expOk: []string{"https://test:8540"},
123+
expFail: []string{
124+
"https://test", // missing port
125+
"http://test", // missing port, + wrong scheme
126+
"http://test:8540", // wrong scheme
127+
"http://test:8541", "https://test:8541", // wrong port
128+
"http://bad", "https://bad", "http://bad:8540", "https://bad:8540"},
129+
},
130+
// several allowed origins
131+
{
132+
spec: "localhost,http://127.0.0.1",
133+
expOk: []string{"localhost", "http://localhost", "https://localhost:8443",
134+
"http://127.0.0.1", "http://127.0.0.1:8080"},
135+
expFail: []string{
136+
"https://127.0.0.1", // wrong scheme
137+
"http://bad", "https://bad", "http://bad:8540", "https://bad:8540"},
138+
},
139+
}
140+
for _, tc := range tests {
141+
srv := createAndStartServer(t, httpConfig{}, true, wsConfig{Origins: splitAndTrim(tc.spec)})
142+
for _, origin := range tc.expOk {
143+
if err := attemptWebsocketConnectionFromOrigin(t, srv, origin); err != nil {
144+
t.Errorf("spec '%v', origin '%v': expected ok, got %v", tc.spec, origin, err)
145+
}
146+
}
147+
for _, origin := range tc.expFail {
148+
if err := attemptWebsocketConnectionFromOrigin(t, srv, origin); err == nil {
149+
t.Errorf("spec '%v', origin '%v': expected not to allow, got ok", tc.spec, origin)
150+
}
151+
}
152+
srv.stop()
153+
}
74154
}
75155

76156
// TestIsWebsocket tests if an incoming websocket upgrade request is handled properly.
@@ -103,6 +183,17 @@ func createAndStartServer(t *testing.T, conf httpConfig, ws bool, wsConf wsConfi
103183
return srv
104184
}
105185

186+
func attemptWebsocketConnectionFromOrigin(t *testing.T, srv *httpServer, browserOrigin string) error {
187+
t.Helper()
188+
dialer := websocket.DefaultDialer
189+
_, _, err := dialer.Dial("ws://"+srv.listenAddr(), http.Header{
190+
"Content-type": []string{"application/json"},
191+
"Sec-WebSocket-Version": []string{"13"},
192+
"Origin": []string{browserOrigin},
193+
})
194+
return err
195+
}
196+
106197
func testRequest(t *testing.T, key, value, host string, srv *httpServer) *http.Response {
107198
t.Helper()
108199

rpc/websocket.go

+62-3
Original file line numberDiff line numberDiff line change
@@ -75,14 +75,14 @@ func wsHandshakeValidator(allowedOrigins []string) func(*http.Request) bool {
7575
allowAllOrigins = true
7676
}
7777
if origin != "" {
78-
origins.Add(strings.ToLower(origin))
78+
origins.Add(origin)
7979
}
8080
}
8181
// allow localhost if no allowedOrigins are specified.
8282
if len(origins.ToSlice()) == 0 {
8383
origins.Add("http://localhost")
8484
if hostname, err := os.Hostname(); err == nil {
85-
origins.Add("http://" + strings.ToLower(hostname))
85+
origins.Add("http://" + hostname)
8686
}
8787
}
8888
log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v", origins.ToSlice()))
@@ -97,7 +97,7 @@ func wsHandshakeValidator(allowedOrigins []string) func(*http.Request) bool {
9797
}
9898
// Verify origin against whitelist.
9999
origin := strings.ToLower(req.Header.Get("Origin"))
100-
if allowAllOrigins || origins.Contains(origin) {
100+
if allowAllOrigins || originIsAllowed(origins, origin) {
101101
return true
102102
}
103103
log.Warn("Rejected WebSocket connection", "origin", origin)
@@ -120,6 +120,65 @@ func (e wsHandshakeError) Error() string {
120120
return s
121121
}
122122

123+
func originIsAllowed(allowedOrigins mapset.Set, browserOrigin string) bool {
124+
it := allowedOrigins.Iterator()
125+
for origin := range it.C {
126+
if ruleAllowsOrigin(origin.(string), browserOrigin) {
127+
return true
128+
}
129+
}
130+
return false
131+
}
132+
133+
func ruleAllowsOrigin(allowedOrigin string, browserOrigin string) bool {
134+
var (
135+
allowedScheme, allowedHostname, allowedPort string
136+
browserScheme, browserHostname, browserPort string
137+
err error
138+
)
139+
allowedScheme, allowedHostname, allowedPort, err = parseOriginURL(allowedOrigin)
140+
if err != nil {
141+
log.Warn("Error parsing allowed origin specification", "spec", allowedOrigin, "error", err)
142+
return false
143+
}
144+
browserScheme, browserHostname, browserPort, err = parseOriginURL(browserOrigin)
145+
if err != nil {
146+
log.Warn("Error parsing browser 'Origin' field", "Origin", browserOrigin, "error", err)
147+
return false
148+
}
149+
if allowedScheme != "" && allowedScheme != browserScheme {
150+
return false
151+
}
152+
if allowedHostname != "" && allowedHostname != browserHostname {
153+
return false
154+
}
155+
if allowedPort != "" && allowedPort != browserPort {
156+
return false
157+
}
158+
return true
159+
}
160+
161+
func parseOriginURL(origin string) (string, string, string, error) {
162+
parsedURL, err := url.Parse(strings.ToLower(origin))
163+
if err != nil {
164+
return "", "", "", err
165+
}
166+
var scheme, hostname, port string
167+
if strings.Contains(origin, "://") {
168+
scheme = parsedURL.Scheme
169+
hostname = parsedURL.Hostname()
170+
port = parsedURL.Port()
171+
} else {
172+
scheme = ""
173+
hostname = parsedURL.Scheme
174+
port = parsedURL.Opaque
175+
if hostname == "" {
176+
hostname = origin
177+
}
178+
}
179+
return scheme, hostname, port, nil
180+
}
181+
123182
// DialWebsocketWithDialer creates a new RPC client that communicates with a JSON-RPC server
124183
// that is listening on the given endpoint using the provided dialer.
125184
func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer) (*Client, error) {

0 commit comments

Comments
 (0)