Skip to content

Commit

Permalink
Sanitize SSH server hostnames
Browse files Browse the repository at this point in the history
Prevents any invalid and malicious hostnames, but replacing them with
known valid data already associated with the host. This was chosen
instead of rejecting to persist the server resource in an attempt to
continue providing access to the host in order to remedy the invalid
hostname.

Any servers that represent a Teleport ssh_service with an invalid
hostname will be replaced by the host UUID. Any static OpenSSH servers
will have invalid hostnames replaced with the address. This will continue
to allow the hosts to be dialable. In order to make these hosts
discoverable, the invalid hostname will be set in the
"teleport.internal/invalid-hostname" label.

Updates gravitational/teleport-private#1676.
  • Loading branch information
rosstimothy committed Nov 14, 2024
1 parent 903c1ad commit 2779014
Show file tree
Hide file tree
Showing 4 changed files with 209 additions and 12 deletions.
76 changes: 76 additions & 0 deletions lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ import (
"log/slog"
"math/big"
insecurerand "math/rand"
"net"
"os"
"regexp"
"slices"
"sort"
"strconv"
Expand Down Expand Up @@ -1470,6 +1472,25 @@ func (a *Server) runPeriodicOperations() {
if services.NodeHasMissedKeepAlives(srv) {
missedKeepAliveCount++
}

// TODO(tross) DELETE in v20.0.0 - all invalid hostnames should have been sanitized by then.
if !validServerHostname(srv.GetHostname()) {
if srv.GetSubKind() != types.SubKindOpenSSHNode {
return false, nil
}

// Any existing static hosts will not have their
// hostname sanitized since they don't heartbeat.
if err := sanitizeHostname(srv); err != nil {
a.logger.WarnContext(a.closeCtx, "failed to sanitize static SSH server hostname", "error", err, "server", srv.GetName())
return false, nil
}

if _, err := a.Services.UpsertNode(a.closeCtx, srv); err != nil {
a.logger.WarnContext(a.closeCtx, "failed to update SSH server hostname", "error", err, "server", srv.GetName())
}
}

return false, nil
},
req,
Expand Down Expand Up @@ -5618,9 +5639,64 @@ func (a *Server) KeepAliveServer(ctx context.Context, h types.KeepAlive) error {
return nil
}

const (
serverHostnameMaxLen = 256
serverHostnameRegexPattern = `^[a-zA-Z0-9]([\.-]?[a-zA-Z0-9]+)*$`
replacedHostnameLabel = types.TeleportInternalLabelPrefix + "invalid-hostname"
)

var serverHostnameRegex = regexp.MustCompile(serverHostnameRegexPattern)

// validServerHostname returns false if the hostname is longer than 256 characters or
// does not entirely consist of alphanumeric characters as well as '-' and '.'. A valid hostname also
// cannot begin with a symbol, and a symbol cannot be followed immediately by another symbol.
func validServerHostname(hostname string) bool {
return len(hostname) <= serverHostnameMaxLen && serverHostnameRegex.MatchString(hostname)
}

func sanitizeHostname(server types.Server) error {
invalidHostname := server.GetHostname()

replacedHostname := server.GetName()
if server.GetSubKind() == types.SubKindOpenSSHNode {
host, _, err := net.SplitHostPort(server.GetAddr())
if err != nil || !validServerHostname(host) {
id, err := uuid.NewRandom()
if err != nil {
return trace.Wrap(err)
}

host = id.String()
}

replacedHostname = host
}

switch s := server.(type) {
case *types.ServerV2:
s.Spec.Hostname = replacedHostname

if s.Metadata.Labels == nil {
s.Metadata.Labels = map[string]string{}
}

s.Metadata.Labels[replacedHostnameLabel] = invalidHostname
default:
return trace.BadParameter("invalid server provided")
}

return nil
}

// UpsertNode implements [services.Presence] by delegating to [Server.Services]
// and potentially emitting a [usagereporter] event.
func (a *Server) UpsertNode(ctx context.Context, server types.Server) (*types.KeepAlive, error) {
if !validServerHostname(server.GetHostname()) {
if err := sanitizeHostname(server); err != nil {
return nil, trace.Wrap(err)
}
}

lease, err := a.Services.UpsertNode(ctx, server)
if err != nil {
return nil, trace.Wrap(err)
Expand Down
138 changes: 128 additions & 10 deletions lib/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package auth

import (
"cmp"
"context"
"crypto/rand"
"crypto/x509"
Expand All @@ -34,7 +35,7 @@ import (
"testing"
"time"

"github.com/google/go-cmp/cmp"
gocmp "github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/google/uuid"
"github.com/gravitational/license"
Expand Down Expand Up @@ -307,7 +308,7 @@ func TestSessions(t *testing.T) {
require.NoError(t, err)
assert.Empty(t, out.GetSSHPriv())
assert.Empty(t, out.GetTLSPriv())
assert.Empty(t, cmp.Diff(ws, out,
assert.Empty(t, gocmp.Diff(ws, out,
cmpopts.IgnoreFields(types.Metadata{}, "Revision"),
cmpopts.IgnoreFields(types.WebSessionSpecV2{}, "Priv", "TLSPriv")))

Expand Down Expand Up @@ -1655,7 +1656,7 @@ func TestServer_AugmentContextUserCertificates(t *testing.T) {
AssetTag: test.opts.DeviceExtensions.AssetTag,
CredentialId: test.opts.DeviceExtensions.CredentialID,
}
if diff := cmp.Diff(want, got); diff != "" {
if diff := gocmp.Diff(want, got); diff != "" {
t.Errorf("certEvent.Identity.DeviceExtensions mismatch (-want +got)\n%s", diff)
}
}
Expand Down Expand Up @@ -2301,12 +2302,12 @@ func TestServer_ExtendWebSession_deviceExtensions(t *testing.T) {
// Assert TLS extensions.
_, newIdentity := parseX509PEMAndIdentity(t, newSession.GetTLSCert())
wantExts := tlsca.DeviceExtensions(*deviceExts)
if diff := cmp.Diff(wantExts, newIdentity.DeviceExtensions); diff != "" {
if diff := gocmp.Diff(wantExts, newIdentity.DeviceExtensions); diff != "" {
t.Errorf("newSession.TLSCert DeviceExtensions mismatch (-want +got)\n%s", diff)
}

// Assert SSH extensions.
if diff := cmp.Diff(wantExts, parseSSHDeviceExtensions(t, newSession.GetPub())); diff != "" {
if diff := gocmp.Diff(wantExts, parseSSHDeviceExtensions(t, newSession.GetPub())); diff != "" {
t.Errorf("newSession.Pub DeviceExtensions mismatch (-want +got)\n%s", diff)
}
})
Expand Down Expand Up @@ -2545,7 +2546,7 @@ func TestGenerateUserCertWithCertExtension(t *testing.T) {
// Validate audit event.
lastEvent := p.mockEmitter.LastEvent()
require.IsType(t, &apievents.CertificateCreate{}, lastEvent)
require.Empty(t, cmp.Diff(
require.Empty(t, gocmp.Diff(
&apievents.CertificateCreate{
Metadata: apievents.Metadata{
Type: events.CertificateCreateEvent,
Expand Down Expand Up @@ -3801,15 +3802,15 @@ func compareDevices(t *testing.T, ignoreUpdateAndCounter bool, got []*types.MFAD
}

// Ignore LastUsed and SignatureCounter?
var opts []cmp.Option
var opts []gocmp.Option
if ignoreUpdateAndCounter {
opts = append(opts, cmp.FilterPath(func(path cmp.Path) bool {
opts = append(opts, gocmp.FilterPath(func(path gocmp.Path) bool {
p := path.String()
return p == "LastUsed" || p == "Device.Webauthn.SignatureCounter"
}, cmp.Ignore()))
}, gocmp.Ignore()))
}

if diff := cmp.Diff(want, got, opts...); diff != "" {
if diff := gocmp.Diff(want, got, opts...); diff != "" {
t.Errorf("compareDevices mismatch (-want +got):\n%s", diff)
}
}
Expand Down Expand Up @@ -4444,3 +4445,120 @@ func newGlobalNotificationWithExpiry(t *testing.T, title string, expires *timest

return &notification
}

// TestServerHostnameSanitization tests that persisting servers with
// "invalid" hostnames results in the hostname being sanitized and the
// illegal name being placed in a label.
func TestServerHostnameSanitization(t *testing.T) {
t.Parallel()
ctx := context.Background()
srv, err := NewTestAuthServer(TestAuthServerConfig{Dir: t.TempDir()})
require.NoError(t, err)

cases := []struct {
name string
hostname string
addr string
invalidHostname bool
invalidAddr bool
}{
{
name: "valid dns hostname",
hostname: "llama.example.com",
},
{
name: "valid friendly hostname",
hostname: "llama",
},
{
name: "uuid hostname",
hostname: uuid.NewString(),
},
{
name: "uuid dns hostname",
hostname: uuid.NewString() + ".example.com",
},
{
name: "empty hostname",
hostname: "",
invalidHostname: true,
},
{
name: "exceptionally long hostname",
hostname: strings.Repeat("a", serverHostnameMaxLen*2),
invalidHostname: true,
},
{
name: "invalid dns hostname",
hostname: "llama..example.com",
invalidHostname: true,
},
{
name: "spaces in hostname",
hostname: "the quick brown fox jumps over the lazy dog",
invalidHostname: true,
},
{
name: "invalid addr",
hostname: "..",
addr: "..:2345",
invalidHostname: true,
invalidAddr: true,
},
}

for _, test := range cases {
t.Run(test.name, func(t *testing.T) {
for _, subKind := range []string{types.KindNode, types.SubKindOpenSSHNode} {
t.Run(subKind, func(t *testing.T) {
server := &types.ServerV2{
Kind: types.KindNode,
SubKind: subKind,
Metadata: types.Metadata{
Name: uuid.NewString(),
},
Spec: types.ServerSpecV2{
Hostname: test.hostname,
Addr: cmp.Or(test.addr, "abcd:1234"),
},
}
if subKind == types.KindNode {
server.SubKind = ""
}

_, err = srv.AuthServer.UpsertNode(ctx, server)
require.NoError(t, err)

replacedValue, _ := server.GetLabel("teleport.internal/invalid-hostname")
if !test.invalidHostname {
assert.Equal(t, test.hostname, server.GetHostname())
assert.Empty(t, replacedValue)
return
}

assert.Equal(t, test.hostname, replacedValue)
switch subKind {
case types.SubKindOpenSSHNode:
host, _, err := net.SplitHostPort(server.GetAddr())
assert.NoError(t, err)
if !test.invalidAddr {
// If the address is valid, then the hostname should be set
// to the host of the addr field.
assert.Equal(t, host, server.GetHostname())
} else {
// If the address is not valid, then the hostname should be
// set to a UUID.
assert.NotEqual(t, host, server.GetHostname())
assert.NotEqual(t, server.GetName(), server.GetHostname())

_, err := uuid.Parse(server.GetHostname())
require.NoError(t, err)
}
default:
assert.Equal(t, server.GetName(), server.GetHostname())
}
})
}
})
}
}
4 changes: 2 additions & 2 deletions lib/auth/grpcserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2955,9 +2955,9 @@ func TestNodesCRUD(t *testing.T) {
require.NoError(t, err)

// node1 and node2 will be added to default namespace
node1, err := types.NewServerWithLabels("node1", types.KindNode, types.ServerSpecV2{}, nil)
node1, err := types.NewServerWithLabels("node1", types.KindNode, types.ServerSpecV2{Hostname: "node1"}, nil)
require.NoError(t, err)
node2, err := types.NewServerWithLabels("node2", types.KindNode, types.ServerSpecV2{}, nil)
node2, err := types.NewServerWithLabels("node2", types.KindNode, types.ServerSpecV2{Hostname: "node2"}, nil)
require.NoError(t, err)

t.Run("CreateNode", func(t *testing.T) {
Expand Down
3 changes: 3 additions & 0 deletions lib/services/suite/suite.go
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,7 @@ func (s *ServicesTestSuite) ServerCRUD(t *testing.T) {
require.Empty(t, out)

srv := NewServer(types.KindNode, "srv1", "127.0.0.1:2022", apidefaults.Namespace)
srv.Spec.Hostname = "llama"
_, err = s.PresenceS.UpsertNode(ctx, srv)
require.NoError(t, err)

Expand All @@ -513,6 +514,7 @@ func (s *ServicesTestSuite) ServerCRUD(t *testing.T) {
require.Empty(t, out)

proxy := NewServer(types.KindProxy, "proxy1", "127.0.0.1:2023", apidefaults.Namespace)
proxy.Spec.Hostname = "proxy.llama"
require.NoError(t, s.PresenceS.UpsertProxy(ctx, proxy))

out, err = s.PresenceS.GetProxies()
Expand All @@ -533,6 +535,7 @@ func (s *ServicesTestSuite) ServerCRUD(t *testing.T) {
require.Empty(t, out)

auth := NewServer(types.KindAuthServer, "auth1", "127.0.0.1:2025", apidefaults.Namespace)
auth.Spec.Hostname = "auth.llama"
require.NoError(t, s.PresenceS.UpsertAuthServer(ctx, auth))

out, err = s.PresenceS.GetAuthServers()
Expand Down

0 comments on commit 2779014

Please sign in to comment.