From 195dfca433028887973f5bd82d173d91fe9dab4a Mon Sep 17 00:00:00 2001 From: miagilepner Date: Tue, 29 Oct 2024 10:14:44 +0100 Subject: [PATCH] VAULT-31264: Limit raft joins (#28790) * Switch from an unbounded Map to an LRU, 429 when exceeding its size, and repeat challenges to the same server rather than encrypting new ones * Prune old challenges * Remove from pending only if the answer is correct * Add a unit test that validates 429s, delays, and eviction of old entries * Switch to using a flat token bucket from x/time/rate * remove from LRU on each challenge write * Remove sleep, simplify unit test * improve const names * additional tests * max answer size * add locking to prevent multiple new challenges * remove log line --------- Co-authored-by: Scott G. Miller --- vault/core.go | 5 +- vault/external_tests/raft/raft_test.go | 166 +++++++++++++++++++++++++ vault/logical_system.go | 23 ++-- vault/logical_system_raft.go | 68 +++++++--- vault/raft.go | 16 ++- 5 files changed, 246 insertions(+), 32 deletions(-) diff --git a/vault/core.go b/vault/core.go index 3598b5e9f830..981d2a2fe706 100644 --- a/vault/core.go +++ b/vault/core.go @@ -40,6 +40,7 @@ import ( "github.com/hashicorp/go-secure-stdlib/strutil" "github.com/hashicorp/go-secure-stdlib/tlsutil" "github.com/hashicorp/go-uuid" + lru "github.com/hashicorp/golang-lru/v2" kv "github.com/hashicorp/vault-plugin-secrets-kv" "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/audit" @@ -628,7 +629,9 @@ type Core struct { // Stop channel for raft TLS rotations raftTLSRotationStopCh chan struct{} // Stores the pending peers we are waiting to give answers - pendingRaftPeers *sync.Map + pendingRaftPeers *lru.Cache[string, *raftBootstrapChallenge] + // holds the lock for modifying pendingRaftPeers + pendingRaftPeersLock sync.RWMutex // rawConfig stores the config as-is from the provided server configuration. 
rawConfig *atomic.Value diff --git a/vault/external_tests/raft/raft_test.go b/vault/external_tests/raft/raft_test.go index c0000a633b09..9f36486c3e4f 100644 --- a/vault/external_tests/raft/raft_test.go +++ b/vault/external_tests/raft/raft_test.go @@ -7,6 +7,7 @@ import ( "bytes" "context" "crypto/md5" + "encoding/base64" "errors" "fmt" "io" @@ -16,7 +17,9 @@ import ( "testing" "time" + "github.com/golang/protobuf/proto" "github.com/hashicorp/go-cleanhttp" + wrapping "github.com/hashicorp/go-kms-wrapping/v2" "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/api" credUserpass "github.com/hashicorp/vault/builtin/credential/userpass" @@ -248,6 +251,169 @@ func TestRaft_Retry_Join(t *testing.T) { }) } +// TestRaftChallenge_sameAnswerSameID_concurrent verifies that 15 goroutines +// all requesting a raft challenge with the same ID all return the same answer. +// This is a regression test for a TOCTTOU race found during testing. +func TestRaftChallenge_sameAnswerSameID_concurrent(t *testing.T) { + t.Parallel() + + cluster, _ := raftCluster(t, &RaftClusterOpts{ + DisableFollowerJoins: true, + NumCores: 1, + }) + defer cluster.Cleanup() + client := cluster.Cores[0].Client + + challenges := make(chan string, 15) + wg := sync.WaitGroup{} + for i := 0; i < 15; i++ { + wg.Add(1) + go func() { + defer wg.Done() + res, err := client.Logical().Write("sys/storage/raft/bootstrap/challenge", map[string]interface{}{ + "server_id": "node1", + }) + require.NoError(t, err) + challenges <- res.Data["challenge"].(string) + }() + } + + wg.Wait() + challengeSet := make(map[string]struct{}) + close(challenges) + for challenge := range challenges { + challengeSet[challenge] = struct{}{} + } + + require.Len(t, challengeSet, 1) +} + +// TestRaftChallenge_sameAnswerSameID verifies that repeated bootstrap requests +// with the same node ID return the same challenge, but that a different node ID +// returns a different challenge +func TestRaftChallenge_sameAnswerSameID(t *testing.T) { + 
t.Parallel() + + cluster, _ := raftCluster(t, &RaftClusterOpts{ + DisableFollowerJoins: true, + NumCores: 1, + }) + defer cluster.Cleanup() + client := cluster.Cores[0].Client + res, err := client.Logical().Write("sys/storage/raft/bootstrap/challenge", map[string]interface{}{ + "server_id": "node1", + }) + require.NoError(t, err) + + // querying the same ID returns the same challenge + challenge := res.Data["challenge"] + resSameID, err := client.Logical().Write("sys/storage/raft/bootstrap/challenge", map[string]interface{}{ + "server_id": "node1", + }) + require.NoError(t, err) + require.Equal(t, challenge, resSameID.Data["challenge"]) + + // querying a different ID returns a new challenge + resDiffID, err := client.Logical().Write("sys/storage/raft/bootstrap/challenge", map[string]interface{}{ + "server_id": "node2", + }) + require.NoError(t, err) + require.NotEqual(t, challenge, resDiffID.Data["challenge"]) +} + +// TestRaftChallenge_evicted verifies that a valid answer errors if there have +// been more than 20 challenge requests after it, because our cache of pending +// bootstraps is limited to 20 +func TestRaftChallenge_evicted(t *testing.T) { + t.Parallel() + cluster, _ := raftCluster(t, &RaftClusterOpts{ + DisableFollowerJoins: true, + NumCores: 1, + }) + defer cluster.Cleanup() + firstResponse := map[string]interface{}{} + client := cluster.Cores[0].Client + for i := 0; i < vault.RaftInitialChallengeLimit+1; i++ { + if i == vault.RaftInitialChallengeLimit { + // wait before sending the last request, so we don't get rate + // limited + time.Sleep(2 * time.Second) + } + res, err := client.Logical().Write("sys/storage/raft/bootstrap/challenge", map[string]interface{}{ + "server_id": fmt.Sprintf("node-%d", i), + }) + require.NoError(t, err) + + // save the response from the first challenge + if i == 0 { + firstResponse = res.Data + } + } + + // get the answer to the challenge + challengeRaw, err := 
base64.StdEncoding.DecodeString(firstResponse["challenge"].(string)) + require.NoError(t, err) + eBlob := &wrapping.BlobInfo{} + err = proto.Unmarshal(challengeRaw, eBlob) + require.NoError(t, err) + access := cluster.Cores[0].SealAccess().GetAccess() + multiWrapValue := &vaultseal.MultiWrapValue{ + Generation: access.Generation(), + Slots: []*wrapping.BlobInfo{eBlob}, + } + plaintext, _, err := access.Decrypt(context.Background(), multiWrapValue) + require.NoError(t, err) + + // send the answer + _, err = client.Logical().Write("sys/storage/raft/bootstrap/answer", map[string]interface{}{ + "answer": base64.StdEncoding.EncodeToString(plaintext), + "server_id": "node-0", + "cluster_addr": "127.0.0.1:8200", + "sdk_version": "1.1.1", + "upgrade_version": "1.2.3", + "non_voter": false, + }) + + require.ErrorContains(t, err, "no expected answer for the server id provided") +} + +// TestRaft_ChallengeSpam creates 40 raft bootstrap challenges. The first 20 +// should succeed. After 20 challenges have been created, slow down the requests +// so that there are about 10 occurring per second (twice the configured rate). Some of these will fail, due to +// rate limiting, but others will succeed. +func TestRaft_ChallengeSpam(t *testing.T) { + t.Parallel() + cluster, _ := raftCluster(t, &RaftClusterOpts{ + DisableFollowerJoins: true, + }) + defer cluster.Cleanup() + + // Execute 2 * MaxInFlightRequests, over a period that should allow some to proceed as the token bucket + // refills. 
+ var someLaterFailed bool + var someLaterSucceeded bool + for n := 0; n < 2*vault.RaftInitialChallengeLimit; n++ { + _, err := cluster.Cores[0].Client.Logical().Write("sys/storage/raft/bootstrap/challenge", map[string]interface{}{ + "server_id": fmt.Sprintf("core-%d", n), + }) + // First MaxInFlightRequests should succeed for sure + if n < vault.RaftInitialChallengeLimit { + require.NoError(t, err) + } else { + // slow down to twice the configured rps + time.Sleep((1000 * time.Millisecond) / (2 * time.Duration(vault.RaftChallengesPerSecond))) + if err != nil { + require.Equal(t, 429, err.(*api.ResponseError).StatusCode) + someLaterFailed = true + } else { + someLaterSucceeded = true + } + } + } + require.True(t, someLaterFailed) + require.True(t, someLaterSucceeded) +} + func TestRaft_Join(t *testing.T) { t.Parallel() cluster, _ := raftCluster(t, &RaftClusterOpts{ diff --git a/vault/logical_system.go b/vault/logical_system.go index 49a1e3186212..4c10ef150f9b 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -56,6 +56,7 @@ import ( "github.com/hashicorp/vault/version" "github.com/mitchellh/mapstructure" "golang.org/x/crypto/sha3" + "golang.org/x/time/rate" ) const ( @@ -94,11 +95,12 @@ func NewSystemBackend(core *Core, logger log.Logger, config *logical.BackendConf } b := &SystemBackend{ - Core: core, - db: db, - logger: logger, - mfaBackend: NewPolicyMFABackend(core, logger), - syncBackend: syncBackend, + Core: core, + db: db, + logger: logger, + mfaBackend: NewPolicyMFABackend(core, logger), + syncBackend: syncBackend, + raftChallengeLimiter: rate.NewLimiter(rate.Limit(RaftChallengesPerSecond), RaftInitialChallengeLimit), } b.Backend = &framework.Backend{ @@ -270,11 +272,12 @@ func (b *SystemBackend) rawPaths() []*framework.Path { type SystemBackend struct { *framework.Backend entSystemBackend - Core *Core - db *memdb.MemDB - logger log.Logger - mfaBackend *PolicyMFABackend - syncBackend *SecretsSyncBackend + Core *Core + db *memdb.MemDB + 
logger log.Logger + mfaBackend *PolicyMFABackend + syncBackend *SecretsSyncBackend + raftChallengeLimiter *rate.Limiter } // handleConfigStateSanitized returns the current configuration state. The configuration diff --git a/vault/logical_system_raft.go b/vault/logical_system_raft.go index 08bac60241e2..32576cbb6ae6 100644 --- a/vault/logical_system_raft.go +++ b/vault/logical_system_raft.go @@ -9,6 +9,7 @@ import ( "encoding/base64" "errors" "fmt" + "net/http" "strings" "time" @@ -272,6 +273,10 @@ func (b *SystemBackend) handleRaftRemovePeerUpdate() framework.OperationFunc { } } +const answerSize = 16 + +var answerMaxEncodedSize = base64.StdEncoding.EncodedLen(answerSize) + func (b *SystemBackend) handleRaftBootstrapChallengeWrite(makeSealer func() snapshot.Sealer) framework.OperationFunc { return func(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { serverID := d.Get("server_id").(string) @@ -280,25 +285,42 @@ func (b *SystemBackend) handleRaftBootstrapChallengeWrite(makeSealer func() snap } var answer []byte - answerRaw, ok := b.Core.pendingRaftPeers.Load(serverID) + b.Core.pendingRaftPeersLock.RLock() + challenge, ok := b.Core.pendingRaftPeers.Get(serverID) + b.Core.pendingRaftPeersLock.RUnlock() if !ok { - var err error - answer, err = uuid.GenerateRandomBytes(16) - if err != nil { - return nil, err + if !b.raftChallengeLimiter.Allow() { + return logical.RespondWithStatusCode(logical.ErrorResponse("too many raft challenges in flight"), req, http.StatusTooManyRequests) } - b.Core.pendingRaftPeers.Store(serverID, answer) - } else { - answer = answerRaw.([]byte) - } - sealer := makeSealer() - if sealer == nil { - return nil, errors.New("core has no seal Access to write raft bootstrap challenge") - } - protoBlob, err := sealer.Seal(ctx, answer) - if err != nil { - return nil, err + b.Core.pendingRaftPeersLock.Lock() + defer b.Core.pendingRaftPeersLock.Unlock() + + challenge, ok = b.Core.pendingRaftPeers.Get(serverID) + 
if !ok { + + var err error + answer, err = uuid.GenerateRandomBytes(answerSize) + if err != nil { + return nil, err + } + + sealer := makeSealer() + if sealer == nil { + return nil, errors.New("core has no seal access to write raft bootstrap challenge") + } + protoBlob, err := sealer.Seal(ctx, answer) + if err != nil { + return nil, err + } + + challenge = &raftBootstrapChallenge{ + serverID: serverID, + answer: answer, + challenge: protoBlob, + } + b.Core.pendingRaftPeers.Add(serverID, challenge) + } } sealConfig, err := b.Core.seal.BarrierConfig(ctx) @@ -308,7 +330,7 @@ func (b *SystemBackend) handleRaftBootstrapChallengeWrite(makeSealer func() snap return &logical.Response{ Data: map[string]interface{}{ - "challenge": base64.StdEncoding.EncodeToString(protoBlob), + "challenge": base64.StdEncoding.EncodeToString(challenge.challenge), "seal_config": sealConfig, }, }, nil @@ -330,6 +352,9 @@ func (b *SystemBackend) handleRaftBootstrapAnswerWrite() framework.OperationFunc if len(answerRaw) == 0 { return logical.ErrorResponse("no answer provided"), logical.ErrInvalidRequest } + if len(answerRaw) > answerMaxEncodedSize { + return logical.ErrorResponse("answer is too long"), logical.ErrInvalidRequest + } clusterAddr := d.Get("cluster_addr").(string) if len(clusterAddr) == 0 { return logical.ErrorResponse("no cluster_addr provided"), logical.ErrInvalidRequest @@ -342,14 +367,17 @@ func (b *SystemBackend) handleRaftBootstrapAnswerWrite() framework.OperationFunc return logical.ErrorResponse("could not base64 decode answer"), logical.ErrInvalidRequest } - expectedAnswerRaw, ok := b.Core.pendingRaftPeers.Load(serverID) + b.Core.pendingRaftPeersLock.Lock() + expectedChallenge, ok := b.Core.pendingRaftPeers.Get(serverID) if !ok { + b.Core.pendingRaftPeersLock.Unlock() return logical.ErrorResponse("no expected answer for the server id provided"), logical.ErrInvalidRequest } - b.Core.pendingRaftPeers.Delete(serverID) + b.Core.pendingRaftPeers.Remove(serverID) + 
b.Core.pendingRaftPeersLock.Unlock() - if subtle.ConstantTimeCompare(answer, expectedAnswerRaw.([]byte)) == 0 { + if subtle.ConstantTimeCompare(answer, expectedChallenge.answer) == 0 { return logical.ErrorResponse("invalid answer given"), logical.ErrInvalidRequest } diff --git a/vault/raft.go b/vault/raft.go index 7d34bef99c63..416334868f9b 100644 --- a/vault/raft.go +++ b/vault/raft.go @@ -25,6 +25,7 @@ import ( "github.com/hashicorp/go-secure-stdlib/tlsutil" "github.com/hashicorp/go-uuid" goversion "github.com/hashicorp/go-version" + lru "github.com/hashicorp/golang-lru/v2" "github.com/hashicorp/vault/api" httpPriority "github.com/hashicorp/vault/http/priority" "github.com/hashicorp/vault/physical/raft" @@ -36,6 +37,9 @@ import ( ) const ( + RaftInitialChallengeLimit = 20 // allow an initial burst to 20 + RaftChallengesPerSecond = 5 // equating to an average 200ms min time + // undoLogMonitorInterval is how often the leader checks to see // if all the cluster members it knows about are new enough to support // undo logs. @@ -56,6 +60,12 @@ var ( ErrJoinWithoutAutoloading = errors.New("attempt to join a cluster using autoloaded licenses while not using autoloading ourself") ) +type raftBootstrapChallenge struct { + serverID string + answer []byte // the random answer + challenge []byte // the Sealed answer +} + // GetRaftNodeID returns the raft node ID if there is one, or an empty string if there's not func (c *Core) GetRaftNodeID() string { rb := c.getRaftBackend() @@ -314,7 +324,11 @@ func (c *Core) setupRaftActiveNode(ctx context.Context) error { c.logger.Info("starting raft active node") - c.pendingRaftPeers = &sync.Map{} + var err error + c.pendingRaftPeers, err = lru.New[string, *raftBootstrapChallenge](RaftInitialChallengeLimit) + if err != nil { + return err + } // Reload the raft TLS keys to ensure we are using the latest version. if err := c.checkRaftTLSKeyUpgrades(ctx); err != nil {