Skip to content

Commit bfaa7f7

Browse files
authored
Add allowed http hosts configuration (ava-labs#1566)
1 parent 8fb8afe commit bfaa7f7

File tree

8 files changed

+162
-1
lines changed

8 files changed

+162
-1
lines changed

api/server/allowed_hosts.go

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved.
2+
// See the file LICENSE for licensing terms.
3+
4+
package server
5+
6+
import (
7+
"net"
8+
"net/http"
9+
"strings"
10+
11+
"github.com/ava-labs/avalanchego/utils/set"
12+
)
13+
14+
const wildcard = "*"
15+
16+
var _ http.Handler = (*allowedHostsHandler)(nil)
17+
18+
func filterInvalidHosts(
19+
handler http.Handler,
20+
allowed []string,
21+
) http.Handler {
22+
s := set.Set[string]{}
23+
24+
for _, host := range allowed {
25+
if host == wildcard {
26+
// wildcards match all hostnames, so just return the base handler
27+
return handler
28+
}
29+
s.Add(strings.ToLower(host))
30+
}
31+
32+
return &allowedHostsHandler{
33+
handler: handler,
34+
hosts: s,
35+
}
36+
}
37+
38+
// allowedHostsHandler is an implementation of http.Handler that validates the
39+
// http host header of incoming requests. This can prevent DNS rebinding attacks
40+
// which do not utilize CORS-headers. Http request host headers are validated
41+
// against a whitelist to determine whether the request should be dropped or
42+
// not.
43+
type allowedHostsHandler struct {
44+
handler http.Handler
45+
hosts set.Set[string]
46+
}
47+
48+
func (a *allowedHostsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
49+
// if the host header is missing we can serve this request because dns
50+
// rebinding attacks rely on this header
51+
if r.Host == "" {
52+
a.handler.ServeHTTP(w, r)
53+
return
54+
}
55+
56+
host, _, err := net.SplitHostPort(r.Host)
57+
if err != nil {
58+
// either invalid (too many colons) or no port specified
59+
host = r.Host
60+
}
61+
62+
if ipAddr := net.ParseIP(host); ipAddr != nil {
63+
// accept requests from ips
64+
a.handler.ServeHTTP(w, r)
65+
return
66+
}
67+
68+
// a specific hostname - we need to check the whitelist to see if we should
69+
// accept this r
70+
if a.hosts.Contains(strings.ToLower(host)) {
71+
a.handler.ServeHTTP(w, r)
72+
return
73+
}
74+
75+
http.Error(w, "invalid host specified", http.StatusForbidden)
76+
}

api/server/allowed_hosts_test.go

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved.
2+
// See the file LICENSE for licensing terms.
3+
4+
package server
5+
6+
import (
7+
"net/http"
8+
"net/http/httptest"
9+
"testing"
10+
11+
"github.com/stretchr/testify/require"
12+
)
13+
14+
func TestAllowedHostsHandler_ServeHTTP(t *testing.T) {
15+
tests := []struct {
16+
name string
17+
allowed []string
18+
host string
19+
serve bool
20+
}{
21+
{
22+
name: "no host header",
23+
allowed: []string{"www.foobar.com"},
24+
host: "",
25+
serve: true,
26+
},
27+
{
28+
name: "ip",
29+
allowed: []string{"www.foobar.com"},
30+
host: "192.168.1.1",
31+
serve: true,
32+
},
33+
{
34+
name: "hostname not allowed",
35+
allowed: []string{"www.foobar.com"},
36+
host: "www.evil.com",
37+
},
38+
{
39+
name: "hostname allowed",
40+
allowed: []string{"www.foobar.com"},
41+
host: "www.foobar.com",
42+
serve: true,
43+
},
44+
{
45+
name: "wildcard",
46+
allowed: []string{"*"},
47+
host: "www.foobar.com",
48+
serve: true,
49+
},
50+
}
51+
52+
for _, test := range tests {
53+
t.Run(test.name, func(t *testing.T) {
54+
require := require.New(t)
55+
56+
baseHandler := &testHandler{}
57+
58+
httpAllowedHostsHandler := filterInvalidHosts(
59+
baseHandler,
60+
test.allowed,
61+
)
62+
63+
w := &httptest.ResponseRecorder{}
64+
r := httptest.NewRequest("", "/", nil)
65+
r.Host = test.host
66+
67+
httpAllowedHostsHandler.ServeHTTP(w, r)
68+
69+
if test.serve {
70+
require.True(baseHandler.called)
71+
return
72+
}
73+
74+
require.Equal(http.StatusForbidden, w.Code)
75+
})
76+
}
77+
}

api/server/server.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ func New(
119119
namespace string,
120120
registerer prometheus.Registerer,
121121
httpConfig HTTPConfig,
122+
allowedHosts []string,
122123
wrappers ...Wrapper,
123124
) (Server, error) {
124125
m, err := newMetrics(namespace, registerer)
@@ -127,10 +128,11 @@ func New(
127128
}
128129

129130
router := newRouter()
131+
allowedHostsHandler := filterInvalidHosts(router, allowedHosts)
130132
corsHandler := cors.New(cors.Options{
131133
AllowedOrigins: allowedOrigins,
132134
AllowCredentials: true,
133-
}).Handler(router)
135+
}).Handler(allowedHostsHandler)
134136
gzipHandler := gziphandler.GzipHandler(corsHandler)
135137
var handler http.Handler = http.HandlerFunc(
136138
func(w http.ResponseWriter, r *http.Request) {

config/config.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,7 @@ func getHTTPConfig(v *viper.Viper) (node.HTTPConfig, error) {
244244
HTTPSKey: httpsKey,
245245
HTTPSCert: httpsCert,
246246
APIAllowedOrigins: v.GetStringSlice(HTTPAllowedOrigins),
247+
HTTPAllowedHosts: v.GetStringSlice(HTTPAllowedHostsKey),
247248
ShutdownTimeout: v.GetDuration(HTTPShutdownTimeoutKey),
248249
ShutdownWait: v.GetDuration(HTTPShutdownWaitKey),
249250
}

config/flags.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,7 @@ func addNodeFlags(fs *pflag.FlagSet) {
221221
fs.String(HTTPSCertFileKey, "", fmt.Sprintf("TLS certificate file for the HTTPs server. Ignored if %s is specified", HTTPSCertContentKey))
222222
fs.String(HTTPSCertContentKey, "", "Specifies base64 encoded TLS certificate for the HTTPs server")
223223
fs.String(HTTPAllowedOrigins, "*", "Origins to allow on the HTTP port. Defaults to * which allows all origins. Example: https://*.avax.network https://*.avax-test.network")
224+
fs.StringSlice(HTTPAllowedHostsKey, []string{"localhost"}, "List of acceptable host names in API requests. Provide the wildcard ('*') to accept requests from all hosts. API requests where the Host field is empty or an IP address will always be accepted. An API call whose HTTP Host field isn't acceptable will receive a 403 error code")
224225
fs.Duration(HTTPShutdownWaitKey, 0, "Duration to wait after receiving SIGTERM or SIGINT before initiating shutdown. The /health endpoint will return unhealthy during this duration")
225226
fs.Duration(HTTPShutdownTimeoutKey, 10*time.Second, "Maximum duration to wait for existing connections to complete during node shutdown")
226227
fs.Duration(HTTPReadTimeoutKey, 30*time.Second, "Maximum duration for reading the entire request, including the body. A zero or negative value means there will be no timeout")

config/keys.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ const (
5454
HTTPSCertFileKey = "http-tls-cert-file"
5555
HTTPSCertContentKey = "http-tls-cert-file-content"
5656
HTTPAllowedOrigins = "http-allowed-origins"
57+
HTTPAllowedHostsKey = "http-allowed-hosts"
5758
HTTPShutdownTimeoutKey = "http-shutdown-timeout"
5859
HTTPShutdownWaitKey = "http-shutdown-wait"
5960
HTTPReadTimeoutKey = "http-read-timeout"

node/config.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ type HTTPConfig struct {
5454
HTTPSCert []byte `json:"-"`
5555

5656
APIAllowedOrigins []string `json:"apiAllowedOrigins"`
57+
HTTPAllowedHosts []string `json:"httpAllowedHosts"`
5758

5859
ShutdownTimeout time.Duration `json:"shutdownTimeout"`
5960
ShutdownWait time.Duration `json:"shutdownWait"`

node/node.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,6 +596,7 @@ func (n *Node) initAPIServer() error {
596596
"api",
597597
n.MetricsRegisterer,
598598
n.Config.HTTPConfig.HTTPConfig,
599+
n.Config.HTTPAllowedHosts,
599600
)
600601
return err
601602
}
@@ -618,6 +619,7 @@ func (n *Node) initAPIServer() error {
618619
"api",
619620
n.MetricsRegisterer,
620621
n.Config.HTTPConfig.HTTPConfig,
622+
n.Config.HTTPAllowedHosts,
621623
a,
622624
)
623625
if err != nil {

0 commit comments

Comments
 (0)