Skip to content

Adding allowed http hosts flag #1566

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions api/server/allowed_hosts.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved.
// See the file LICENSE for licensing terms.

package server

import (
"net"
"net/http"
"strings"

"github.com/ava-labs/avalanchego/utils/set"
)

const wildcard = "*"

var _ http.Handler = (*allowedHostsHandler)(nil)

func filterInvalidHosts(
handler http.Handler,
allowed []string,
) http.Handler {
s := set.Set[string]{}

for _, host := range allowed {
if host == wildcard {
// wildcards match all hostnames, so just return the base handler
return handler
}
s.Add(strings.ToLower(host))
}

return &allowedHostsHandler{
handler: handler,
hosts: s,
}
}

// allowedHostsHandler is an implementation of http.Handler that validates the
// http host header of incoming requests. This can prevent DNS rebinding attacks
// which do not utilize CORS-headers. Http request host headers are validated
// against a whitelist to determine whether the request should be dropped or
// not.
type allowedHostsHandler struct {
handler http.Handler
hosts set.Set[string]
}

func (a *allowedHostsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// if the host header is missing we can serve this request because dns
// rebinding attacks rely on this header
if r.Host == "" {
a.handler.ServeHTTP(w, r)
return
}

host, _, err := net.SplitHostPort(r.Host)
if err != nil {
// either invalid (too many colons) or no port specified
host = r.Host
}

if ipAddr := net.ParseIP(host); ipAddr != nil {
// accept requests from ips
a.handler.ServeHTTP(w, r)
return
}

// a specific hostname - we need to check the whitelist to see if we should
// accept this r
if a.hosts.Contains(strings.ToLower(host)) {
a.handler.ServeHTTP(w, r)
return
}

http.Error(w, "invalid host specified", http.StatusForbidden)
}
77 changes: 77 additions & 0 deletions api/server/allowed_hosts_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved.
// See the file LICENSE for licensing terms.

package server

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/require"
)

func TestAllowedHostsHandler_ServeHTTP(t *testing.T) {
tests := []struct {
name string
allowed []string
host string
serve bool
}{
{
name: "no host header",
allowed: []string{"www.foobar.com"},
host: "",
serve: true,
},
{
name: "ip",
allowed: []string{"www.foobar.com"},
host: "192.168.1.1",
serve: true,
},
{
name: "hostname not allowed",
allowed: []string{"www.foobar.com"},
host: "www.evil.com",
},
{
name: "hostname allowed",
allowed: []string{"www.foobar.com"},
host: "www.foobar.com",
serve: true,
},
{
name: "wildcard",
allowed: []string{"*"},
host: "www.foobar.com",
serve: true,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
require := require.New(t)

baseHandler := &testHandler{}

httpAllowedHostsHandler := filterInvalidHosts(
baseHandler,
test.allowed,
)

w := &httptest.ResponseRecorder{}
r := httptest.NewRequest("", "/", nil)
r.Host = test.host

httpAllowedHostsHandler.ServeHTTP(w, r)

if test.serve {
require.True(baseHandler.called)
return
}

require.Equal(http.StatusForbidden, w.Code)
})
}
}
4 changes: 3 additions & 1 deletion api/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ func New(
namespace string,
registerer prometheus.Registerer,
httpConfig HTTPConfig,
allowedHosts []string,
wrappers ...Wrapper,
) (Server, error) {
m, err := newMetrics(namespace, registerer)
Expand All @@ -127,10 +128,11 @@ func New(
}

router := newRouter()
allowedHostsHandler := filterInvalidHosts(router, allowedHosts)
corsHandler := cors.New(cors.Options{
AllowedOrigins: allowedOrigins,
AllowCredentials: true,
}).Handler(router)
}).Handler(allowedHostsHandler)
gzipHandler := gziphandler.GzipHandler(corsHandler)
var handler http.Handler = http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
Expand Down
1 change: 1 addition & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ func getHTTPConfig(v *viper.Viper) (node.HTTPConfig, error) {
HTTPSKey: httpsKey,
HTTPSCert: httpsCert,
APIAllowedOrigins: v.GetStringSlice(HTTPAllowedOrigins),
HTTPAllowedHosts: v.GetStringSlice(HTTPAllowedHostsKey),
ShutdownTimeout: v.GetDuration(HTTPShutdownTimeoutKey),
ShutdownWait: v.GetDuration(HTTPShutdownWaitKey),
}
Expand Down
1 change: 1 addition & 0 deletions config/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ func addNodeFlags(fs *pflag.FlagSet) {
fs.String(HTTPSCertFileKey, "", fmt.Sprintf("TLS certificate file for the HTTPs server. Ignored if %s is specified", HTTPSCertContentKey))
fs.String(HTTPSCertContentKey, "", "Specifies base64 encoded TLS certificate for the HTTPs server")
fs.String(HTTPAllowedOrigins, "*", "Origins to allow on the HTTP port. Defaults to * which allows all origins. Example: https://*.avax.network https://*.avax-test.network")
fs.StringSlice(HTTPAllowedHostsKey, []string{"localhost"}, "List of acceptable host names in API requests. Provide the wildcard ('*') to accept requests from all hosts. API requests where the Host field is empty or an IP address will always be accepted. An API call whose HTTP Host field isn't acceptable will receive a 403 error code")
fs.Duration(HTTPShutdownWaitKey, 0, "Duration to wait after receiving SIGTERM or SIGINT before initiating shutdown. The /health endpoint will return unhealthy during this duration")
fs.Duration(HTTPShutdownTimeoutKey, 10*time.Second, "Maximum duration to wait for existing connections to complete during node shutdown")
fs.Duration(HTTPReadTimeoutKey, 30*time.Second, "Maximum duration for reading the entire request, including the body. A zero or negative value means there will be no timeout")
Expand Down
1 change: 1 addition & 0 deletions config/keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ const (
HTTPSCertFileKey = "http-tls-cert-file"
HTTPSCertContentKey = "http-tls-cert-file-content"
HTTPAllowedOrigins = "http-allowed-origins"
HTTPAllowedHostsKey = "http-allowed-hosts"
HTTPShutdownTimeoutKey = "http-shutdown-timeout"
HTTPShutdownWaitKey = "http-shutdown-wait"
HTTPReadTimeoutKey = "http-read-timeout"
Expand Down
1 change: 1 addition & 0 deletions node/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ type HTTPConfig struct {
HTTPSCert []byte `json:"-"`

APIAllowedOrigins []string `json:"apiAllowedOrigins"`
Copy link
Contributor Author

@joshua-kim joshua-kim Jun 1, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unrelated but I think this json tag is named weirdly because the config flag for this is actually called http-allowed-origins so I would expect it to be httpAllowedOrigins

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah - we should add that in a separate PR

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

HTTPAllowedHosts []string `json:"httpAllowedHosts"`

ShutdownTimeout time.Duration `json:"shutdownTimeout"`
ShutdownWait time.Duration `json:"shutdownWait"`
Expand Down
2 changes: 2 additions & 0 deletions node/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,7 @@ func (n *Node) initAPIServer() error {
"api",
n.MetricsRegisterer,
n.Config.HTTPConfig.HTTPConfig,
n.Config.HTTPAllowedHosts,
)
return err
}
Expand All @@ -618,6 +619,7 @@ func (n *Node) initAPIServer() error {
"api",
n.MetricsRegisterer,
n.Config.HTTPConfig.HTTPConfig,
n.Config.HTTPAllowedHosts,
a,
)
if err != nil {
Expand Down