diff --git a/internal/cloud/state.go b/internal/cloud/state.go index 37da63936532..e7670a665eeb 100644 --- a/internal/cloud/state.go +++ b/internal/cloud/state.go @@ -9,9 +9,12 @@ import ( "errors" "fmt" "log" + "net/http" "os" + "strconv" "strings" "sync" + "time" "github.com/zclconf/go-cty/cty" "github.com/zclconf/go-cty/cty/gocty" @@ -54,6 +57,15 @@ type State struct { stateUploadErr bool forcePush bool lockInfo *statemgr.LockInfo + + // The server can optionally return an X-Terraform-Snapshot-Interval header + // in its response to the "Create State Version" operation, which specifies + // a number of seconds the server would prefer us to wait before trying + // to write a new snapshot. If this is non-zero then we'll wait at least + // this long before allowing another intermediate snapshot. This does + // not effect final snapshots after an operation, which will always + // be written to the remote API. + stateSnapshotInterval time.Duration } var ErrStateVersionUnauthorizedUpgradeState = errors.New(strings.TrimSpace(` @@ -224,11 +236,20 @@ func (s *State) PersistState(schemas *terraform.Schemas) error { } // ShouldPersistIntermediateState implements local.IntermediateStateConditionalPersister -func (*State) ShouldPersistIntermediateState(info *local.IntermediateStatePersistInfo) bool { - // We currently don't create intermediate snapshots for Terraform Cloud or - // Terraform Enterprise at all, to avoid extra storage costs for Terraform - // Enterprise customers. - return false +func (s *State) ShouldPersistIntermediateState(info *local.IntermediateStatePersistInfo) bool { + if info.ForcePersist { + return true + } + + // Our persist interval is the largest of either the caller's requested + // interval or the server's requested interval. + wantInterval := info.RequestedPersistInterval + if s.stateSnapshotInterval > wantInterval { + wantInterval = s.stateSnapshotInterval + } + + currentInterval := time.Since(info.LastPersist) + return currentInterval >= wantInterval } func (s *State) uploadState(lineage string, serial uint64, isForcePush bool, state, jsonState, jsonStateOutputs []byte) error { @@ -250,6 +271,30 @@ func (s *State) uploadState(lineage string, serial uint64, isForcePush bool, sta if runID != "" { options.Run = &tfe.Run{ID: runID} } + + // The server is allowed to dynamically request a different time interval + // than we'd normally use, for example if it's currently under heavy load + // and needs clients to backoff for a while. + ctx = tfe.ContextWithResponseHeaderHook(ctx, func(status int, header http.Header) { + intervalStr := header.Get("x-terraform-snapshot-interval") + + if intervalSecs, err := strconv.ParseInt(intervalStr, 10, 64); err == nil { + if intervalSecs > 3600 { + // More than an hour is an unreasonable delay, so we'll just + // saturate at one hour. + intervalSecs = 3600 + } else if intervalSecs < 0 { + intervalSecs = 0 + } + s.stateSnapshotInterval = time.Duration(intervalSecs) * time.Second + } else { + // If the header field is either absent or invalid then we'll + // just choose zero, which effectively means that we'll just use + // the caller's requested interval instead. + s.stateSnapshotInterval = time.Duration(0) + } + }) + // Create the new state. _, err := s.tfeClient.StateVersions.Create(ctx, s.workspace.ID, options) return err diff --git a/internal/cloud/state_test.go b/internal/cloud/state_test.go index f03bd15c53e5..0168272012df 100644 --- a/internal/cloud/state_test.go +++ b/internal/cloud/state_test.go @@ -3,12 +3,21 @@ package cloud import ( "bytes" "context" + "encoding/json" "io/ioutil" + "net/http" + "net/http/httptest" + "strconv" "testing" + "time" tfe "github.com/hashicorp/go-tfe" + "github.com/hashicorp/terraform/internal/addrs" + "github.com/hashicorp/terraform/internal/backend/local" + "github.com/hashicorp/terraform/internal/states" "github.com/hashicorp/terraform/internal/states/statefile" "github.com/hashicorp/terraform/internal/states/statemgr" + "github.com/zclconf/go-cty/cty" ) func TestState_impl(t *testing.T) { @@ -252,9 +261,9 @@ func TestDelete_SafeDelete(t *testing.T) { } func TestState_PersistState(t *testing.T) { - cloudState := testCloudState(t) - t.Run("Initial PersistState", func(t *testing.T) { + cloudState := testCloudState(t) + if cloudState.readState != nil { t.Fatal("expected nil initial readState") } @@ -269,4 +278,224 @@ func TestState_PersistState(t *testing.T) { t.Fatalf("expected initial state readSerial to be %d, got %d", expectedSerial, cloudState.readSerial) } }) + + t.Run("Snapshot Interval Backpressure Header", func(t *testing.T) { + // The "Create a State Version" API is allowed to return a special + // HTTP response header X-Terraform-Snapshot-Interval, in which case + // we should remember the number of seconds it specifies and delay + // creating any more intermediate state snapshots for that many seconds. + + cloudState := testCloudState(t) + + if cloudState.stateSnapshotInterval != 0 { + t.Error("state manager already has a nonzero snapshot interval") + } + + // For this test we'll use a real client talking to a fake server, + // since HTTP-level concerns like headers are out of scope for the + // mock client we typically use in other tests in this package, which + // aim to abstract away HTTP altogether. + var serverURL string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Log(r.Method, r.URL.String()) + + if r.URL.Path == "/state-json" { + t.Log("pretending to be Archivist") + fakeState := states.NewState() + fakeStateFile := statefile.New(fakeState, "boop", 1) + var buf bytes.Buffer + statefile.Write(fakeStateFile, &buf) + respBody := buf.Bytes() + w.Header().Set("content-type", "application/json") + w.Header().Set("content-length", strconv.FormatInt(int64(len(respBody)), 10)) + w.WriteHeader(http.StatusOK) + w.Write(respBody) + return + } + if r.URL.Path == "/api/ping" { + t.Log("pretending to be Ping") + w.WriteHeader(http.StatusNoContent) + return + } + + fakeBody := map[string]any{ + "data": map[string]any{ + "type": "state-versions", + "attributes": map[string]any{ + "hosted-state-download-url": serverURL + "/state-json", + }, + }, + } + fakeBodyRaw, err := json.Marshal(fakeBody) + if err != nil { + t.Fatal(err) + } + + w.Header().Set("content-type", "application/json") + w.Header().Set("content-length", strconv.FormatInt(int64(len(fakeBodyRaw)), 10)) + + switch r.Method { + case "POST": + t.Log("pretending to be Create a State Version") + w.Header().Set("x-terraform-snapshot-interval", "300") + w.WriteHeader(http.StatusAccepted) + case "GET": + t.Log("pretending to be Fetch the Current State Version for a Workspace") + w.WriteHeader(http.StatusOK) + default: + t.Fatal("don't know what API operation this was supposed to be") + } + + w.WriteHeader(http.StatusOK) + w.Write(fakeBodyRaw) + })) + serverURL = server.URL + cfg := &tfe.Config{ + Address: server.URL, + BasePath: "api", + Token: "placeholder", + } + client, err := tfe.NewClient(cfg) + if err != nil { + t.Fatal(err) + } + cloudState.tfeClient = client + + err = cloudState.RefreshState() + if err != nil { + t.Fatal(err) + } + cloudState.WriteState(states.BuildState(func(s *states.SyncState) { + s.SetOutputValue( + addrs.OutputValue{Name: "boop"}.Absolute(addrs.RootModuleInstance), + cty.StringVal("beep"), false, + ) + })) + + err = cloudState.PersistState(nil) + if err != nil { + t.Fatal(err) + } + + // The PersistState call above should have sent a request to the test + // server and got back the x-terraform-snapshot-interval header, whose + // value should therefore now be recorded in the relevant field. + if got, want := cloudState.stateSnapshotInterval, 300*time.Second; got != want { + t.Errorf("wrong state snapshot interval after PersistState\ngot: %s\nwant: %s", got, want) + } + }) +} + +func TestState_ShouldPersistIntermediateState(t *testing.T) { + cloudState := testCloudState(t) + + // We'll specify a normal interval and a server-supplied interval that + // have enough time between them that we can be confident that the + // fake timestamps we'll pass into ShouldPersistIntermediateState are + // either too soon for normal, long enough for normal but not for server, + // or too long for server. + shortServerInterval := 5 * time.Second + normalInterval := 60 * time.Second + longServerInterval := 120 * time.Second + beforeNormalInterval := 20 * time.Second + betweenInterval := 90 * time.Second + afterLongServerInterval := 200 * time.Second + + // Before making any requests the state manager should just respect the + // normal interval, because it hasn't yet heard a request from the server. + { + should := cloudState.ShouldPersistIntermediateState(&local.IntermediateStatePersistInfo{ + RequestedPersistInterval: normalInterval, + LastPersist: time.Now().Add(-beforeNormalInterval), + }) + if should { + t.Errorf("indicated that should persist before normal interval") + } + } + { + should := cloudState.ShouldPersistIntermediateState(&local.IntermediateStatePersistInfo{ + RequestedPersistInterval: normalInterval, + LastPersist: time.Now().Add(-betweenInterval), + }) + if !should { + t.Errorf("indicated that should not persist after normal interval") + } + } + + // After making a request to the "Create a State Version" operation, the + // server might return a header that causes us to set this field: + cloudState.stateSnapshotInterval = shortServerInterval + + // The short server interval is shorter than the normal interval, so the + // normal interval takes priority here. + { + should := cloudState.ShouldPersistIntermediateState(&local.IntermediateStatePersistInfo{ + RequestedPersistInterval: normalInterval, + LastPersist: time.Now().Add(-beforeNormalInterval), + }) + if should { + t.Errorf("indicated that should persist before normal interval") + } + } + { + should := cloudState.ShouldPersistIntermediateState(&local.IntermediateStatePersistInfo{ + RequestedPersistInterval: normalInterval, + LastPersist: time.Now().Add(-betweenInterval), + }) + if !should { + t.Errorf("indicated that should not persist after normal interval") + } + } + + // The server might instead request a longer interval. + cloudState.stateSnapshotInterval = longServerInterval + { + should := cloudState.ShouldPersistIntermediateState(&local.IntermediateStatePersistInfo{ + RequestedPersistInterval: normalInterval, + LastPersist: time.Now().Add(-beforeNormalInterval), + }) + if should { + t.Errorf("indicated that should persist before server interval") + } + } + { + should := cloudState.ShouldPersistIntermediateState(&local.IntermediateStatePersistInfo{ + RequestedPersistInterval: normalInterval, + LastPersist: time.Now().Add(-betweenInterval), + }) + if should { + t.Errorf("indicated that should persist before server interval") + } + } + { + should := cloudState.ShouldPersistIntermediateState(&local.IntermediateStatePersistInfo{ + RequestedPersistInterval: normalInterval, + LastPersist: time.Now().Add(-afterLongServerInterval), + }) + if !should { + t.Errorf("indicated that should not persist after server interval") + } + } + + // The "force" mode should always win, regardless of how much time has passed. + { + should := cloudState.ShouldPersistIntermediateState(&local.IntermediateStatePersistInfo{ + RequestedPersistInterval: normalInterval, + LastPersist: time.Now().Add(-beforeNormalInterval), + ForcePersist: true, + }) + if !should { + t.Errorf("ignored ForcePersist") + } + } + { + should := cloudState.ShouldPersistIntermediateState(&local.IntermediateStatePersistInfo{ + RequestedPersistInterval: normalInterval, + LastPersist: time.Now().Add(-betweenInterval), + ForcePersist: true, + }) + if !should { + t.Errorf("ignored ForcePersist") + } + } }