diff --git a/lib/auth/auth.go b/lib/auth/auth.go index dcbcf632b598d..a909608a51af7 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -39,7 +39,9 @@ import ( "log/slog" "math/big" insecurerand "math/rand" + "net" "os" + "regexp" "slices" "sort" "strconv" @@ -1470,6 +1472,25 @@ func (a *Server) runPeriodicOperations() { if services.NodeHasMissedKeepAlives(srv) { missedKeepAliveCount++ } + + // TODO(tross) DELETE in v20.0.0 - all invalid hostnames should have been sanitized by then. + if !validServerHostname(srv.GetHostname()) { + if srv.GetSubKind() != types.SubKindOpenSSHNode { + return false, nil + } + + // Any existing static hosts will not have their + // hostname sanitized since they don't heartbeat. + if err := sanitizeHostname(srv); err != nil { + a.logger.WarnContext(a.closeCtx, "failed to sanitize static SSH server hostname", "error", err, "server", srv.GetName()) + return false, nil + } + + if _, err := a.Services.UpsertNode(a.closeCtx, srv); err != nil { + a.logger.WarnContext(a.closeCtx, "failed to update SSH server hostname", "error", err, "server", srv.GetName()) + } + } + return false, nil }, req, @@ -5618,9 +5639,64 @@ func (a *Server) KeepAliveServer(ctx context.Context, h types.KeepAlive) error { return nil } +const ( + serverHostnameMaxLen = 256 + serverHostnameRegexPattern = `^[a-zA-Z0-9]([\.-]?[a-zA-Z0-9]+)*$` + replacedHostnameLabel = types.TeleportInternalLabelPrefix + "invalid-hostname" +) + +var serverHostnameRegex = regexp.MustCompile(serverHostnameRegexPattern) + +// validServerHostname returns false if the hostname is longer than 256 characters or +// does not entirely consist of alphanumeric characters as well as '-' and '.'. A valid hostname also +// cannot begin with a symbol, and a symbol cannot be followed immediately by another symbol. +func validServerHostname(hostname string) bool { + return len(hostname) <= serverHostnameMaxLen && serverHostnameRegex.MatchString(hostname) +} + +func sanitizeHostname(server types.Server) error { + invalidHostname := server.GetHostname() + + replacedHostname := server.GetName() + if server.GetSubKind() == types.SubKindOpenSSHNode { + host, _, err := net.SplitHostPort(server.GetAddr()) + if err != nil || !validServerHostname(host) { + id, err := uuid.NewRandom() + if err != nil { + return trace.Wrap(err) + } + + host = id.String() + } + + replacedHostname = host + } + + switch s := server.(type) { + case *types.ServerV2: + s.Spec.Hostname = replacedHostname + + if s.Metadata.Labels == nil { + s.Metadata.Labels = map[string]string{} + } + + s.Metadata.Labels[replacedHostnameLabel] = invalidHostname + default: + return trace.BadParameter("invalid server provided") + } + + return nil +} + // UpsertNode implements [services.Presence] by delegating to [Server.Services] // and potentially emitting a [usagereporter] event. func (a *Server) UpsertNode(ctx context.Context, server types.Server) (*types.KeepAlive, error) { + if !validServerHostname(server.GetHostname()) { + if err := sanitizeHostname(server); err != nil { + return nil, trace.Wrap(err) + } + } + lease, err := a.Services.UpsertNode(ctx, server) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/auth/auth_test.go b/lib/auth/auth_test.go index 53150ac3e176b..e4978e32e358a 100644 --- a/lib/auth/auth_test.go +++ b/lib/auth/auth_test.go @@ -19,6 +19,7 @@ package auth import ( + "cmp" "context" "crypto/rand" "crypto/x509" @@ -34,7 +35,7 @@ import ( "testing" "time" - "github.com/google/go-cmp/cmp" + gocmp "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/google/uuid" "github.com/gravitational/license" @@ -307,7 +308,7 @@ func TestSessions(t *testing.T) { require.NoError(t, err) assert.Empty(t, out.GetSSHPriv()) assert.Empty(t, out.GetTLSPriv()) - assert.Empty(t, cmp.Diff(ws, out, + assert.Empty(t, gocmp.Diff(ws, out, cmpopts.IgnoreFields(types.Metadata{}, "Revision"), cmpopts.IgnoreFields(types.WebSessionSpecV2{}, "Priv", "TLSPriv"))) @@ -1655,7 +1656,7 @@ func TestServer_AugmentContextUserCertificates(t *testing.T) { AssetTag: test.opts.DeviceExtensions.AssetTag, CredentialId: test.opts.DeviceExtensions.CredentialID, } - if diff := cmp.Diff(want, got); diff != "" { + if diff := gocmp.Diff(want, got); diff != "" { t.Errorf("certEvent.Identity.DeviceExtensions mismatch (-want +got)\n%s", diff) } } @@ -2301,12 +2302,12 @@ func TestServer_ExtendWebSession_deviceExtensions(t *testing.T) { // Assert TLS extensions. _, newIdentity := parseX509PEMAndIdentity(t, newSession.GetTLSCert()) wantExts := tlsca.DeviceExtensions(*deviceExts) - if diff := cmp.Diff(wantExts, newIdentity.DeviceExtensions); diff != "" { + if diff := gocmp.Diff(wantExts, newIdentity.DeviceExtensions); diff != "" { t.Errorf("newSession.TLSCert DeviceExtensions mismatch (-want +got)\n%s", diff) } // Assert SSH extensions. - if diff := cmp.Diff(wantExts, parseSSHDeviceExtensions(t, newSession.GetPub())); diff != "" { + if diff := gocmp.Diff(wantExts, parseSSHDeviceExtensions(t, newSession.GetPub())); diff != "" { t.Errorf("newSession.Pub DeviceExtensions mismatch (-want +got)\n%s", diff) } }) @@ -2545,7 +2546,7 @@ func TestGenerateUserCertWithCertExtension(t *testing.T) { // Validate audit event. lastEvent := p.mockEmitter.LastEvent() require.IsType(t, &apievents.CertificateCreate{}, lastEvent) - require.Empty(t, cmp.Diff( + require.Empty(t, gocmp.Diff( &apievents.CertificateCreate{ Metadata: apievents.Metadata{ Type: events.CertificateCreateEvent, @@ -3801,15 +3802,15 @@ func compareDevices(t *testing.T, ignoreUpdateAndCounter bool, got []*types.MFAD } // Ignore LastUsed and SignatureCounter? - var opts []cmp.Option + var opts []gocmp.Option if ignoreUpdateAndCounter { - opts = append(opts, cmp.FilterPath(func(path cmp.Path) bool { + opts = append(opts, gocmp.FilterPath(func(path gocmp.Path) bool { p := path.String() return p == "LastUsed" || p == "Device.Webauthn.SignatureCounter" - }, cmp.Ignore())) + }, gocmp.Ignore())) } - if diff := cmp.Diff(want, got, opts...); diff != "" { + if diff := gocmp.Diff(want, got, opts...); diff != "" { t.Errorf("compareDevices mismatch (-want +got):\n%s", diff) } } @@ -4444,3 +4445,120 @@ func newGlobalNotificationWithExpiry(t *testing.T, title string, expires *timest return ¬ification } + +// TestServerHostnameSanitization tests that persisting servers with +// "invalid" hostnames results in the hostname being sanitized and the +// illegal name being placed in a label. +func TestServerHostnameSanitization(t *testing.T) { + t.Parallel() + ctx := context.Background() + srv, err := NewTestAuthServer(TestAuthServerConfig{Dir: t.TempDir()}) + require.NoError(t, err) + + cases := []struct { + name string + hostname string + addr string + invalidHostname bool + invalidAddr bool + }{ + { + name: "valid dns hostname", + hostname: "llama.example.com", + }, + { + name: "valid friendly hostname", + hostname: "llama", + }, + { + name: "uuid hostname", + hostname: uuid.NewString(), + }, + { + name: "uuid dns hostname", + hostname: uuid.NewString() + ".example.com", + }, + { + name: "empty hostname", + hostname: "", + invalidHostname: true, + }, + { + name: "exceptionally long hostname", + hostname: strings.Repeat("a", serverHostnameMaxLen*2), + invalidHostname: true, + }, + { + name: "invalid dns hostname", + hostname: "llama..example.com", + invalidHostname: true, + }, + { + name: "spaces in hostname", + hostname: "the quick brown fox jumps over the lazy dog", + invalidHostname: true, + }, + { + name: "invalid addr", + hostname: "..", + addr: "..:2345", + invalidHostname: true, + invalidAddr: true, + }, + } + + for _, test := range cases { + t.Run(test.name, func(t *testing.T) { + for _, subKind := range []string{types.KindNode, types.SubKindOpenSSHNode} { + t.Run(subKind, func(t *testing.T) { + server := &types.ServerV2{ + Kind: types.KindNode, + SubKind: subKind, + Metadata: types.Metadata{ + Name: uuid.NewString(), + }, + Spec: types.ServerSpecV2{ + Hostname: test.hostname, + Addr: cmp.Or(test.addr, "abcd:1234"), + }, + } + if subKind == types.KindNode { + server.SubKind = "" + } + + _, err = srv.AuthServer.UpsertNode(ctx, server) + require.NoError(t, err) + + replacedValue, _ := server.GetLabel("teleport.internal/invalid-hostname") + if !test.invalidHostname { + assert.Equal(t, test.hostname, server.GetHostname()) + assert.Empty(t, replacedValue) + return + } + + assert.Equal(t, test.hostname, replacedValue) + switch subKind { + case types.SubKindOpenSSHNode: + host, _, err := net.SplitHostPort(server.GetAddr()) + assert.NoError(t, err) + if !test.invalidAddr { + // If the address is valid, then the hostname should be set + // to the host of the addr field. + assert.Equal(t, host, server.GetHostname()) + } else { + // If the address is not valid, then the hostname should be + // set to a UUID. + assert.NotEqual(t, host, server.GetHostname()) + assert.NotEqual(t, server.GetName(), server.GetHostname()) + + _, err := uuid.Parse(server.GetHostname()) + require.NoError(t, err) + } + default: + assert.Equal(t, server.GetName(), server.GetHostname()) + } + }) + } + }) + } +} diff --git a/lib/auth/grpcserver_test.go b/lib/auth/grpcserver_test.go index aef8897492d80..2890b255f7114 100644 --- a/lib/auth/grpcserver_test.go +++ b/lib/auth/grpcserver_test.go @@ -2955,9 +2955,9 @@ func TestNodesCRUD(t *testing.T) { require.NoError(t, err) // node1 and node2 will be added to default namespace - node1, err := types.NewServerWithLabels("node1", types.KindNode, types.ServerSpecV2{}, nil) + node1, err := types.NewServerWithLabels("node1", types.KindNode, types.ServerSpecV2{Hostname: "node1"}, nil) require.NoError(t, err) - node2, err := types.NewServerWithLabels("node2", types.KindNode, types.ServerSpecV2{}, nil) + node2, err := types.NewServerWithLabels("node2", types.KindNode, types.ServerSpecV2{Hostname: "node2"}, nil) require.NoError(t, err) t.Run("CreateNode", func(t *testing.T) { diff --git a/lib/services/suite/suite.go b/lib/services/suite/suite.go index 880005063464d..ad2cb4695b3f2 100644 --- a/lib/services/suite/suite.go +++ b/lib/services/suite/suite.go @@ -488,6 +488,7 @@ func (s *ServicesTestSuite) ServerCRUD(t *testing.T) { require.Empty(t, out) srv := NewServer(types.KindNode, "srv1", "127.0.0.1:2022", apidefaults.Namespace) + srv.Spec.Hostname = "llama" _, err = s.PresenceS.UpsertNode(ctx, srv) require.NoError(t, err) @@ -513,6 +514,7 @@ func (s *ServicesTestSuite) ServerCRUD(t *testing.T) { require.Empty(t, out) proxy := NewServer(types.KindProxy, "proxy1", "127.0.0.1:2023", apidefaults.Namespace) + proxy.Spec.Hostname = "proxy.llama" require.NoError(t, s.PresenceS.UpsertProxy(ctx, proxy)) out, err = s.PresenceS.GetProxies() @@ -533,6 +535,7 @@ func (s *ServicesTestSuite) ServerCRUD(t *testing.T) { require.Empty(t, out) auth := NewServer(types.KindAuthServer, "auth1", "127.0.0.1:2025", apidefaults.Namespace) + auth.Spec.Hostname = "auth.llama" require.NoError(t, s.PresenceS.UpsertAuthServer(ctx, auth)) out, err = s.PresenceS.GetAuthServers()