From b5e6315bb6f6a4912ded68c8e16afa5ffe167d9f Mon Sep 17 00:00:00 2001 From: Amanda Vialva <144278621+amandavialva01@users.noreply.github.com> Date: Thu, 31 Oct 2024 09:56:48 -0400 Subject: [PATCH] fix: set max slots and checkpoint gc policy should comply with config policies (#10140) (cherry picked from commit 06b8f48f6ac9d27fe4dbbaec1de21a82475dc012) --- master/internal/api_experiment.go | 44 +++- master/internal/api_experiment_intg_test.go | 233 ++++++++++++++++++ .../postgres_task_config_policy.go | 71 +++--- .../postgres_task_config_policy_intg_test.go | 84 ++++--- master/internal/configpolicy/utils.go | 36 ++- master/internal/configpolicy/utils_test.go | 29 +-- master/internal/experiment.go | 10 +- 7 files changed, 383 insertions(+), 124 deletions(-) diff --git a/master/internal/api_experiment.go b/master/internal/api_experiment.go index 5fa1914f172..aa7aec53931 100644 --- a/master/internal/api_experiment.go +++ b/master/internal/api_experiment.go @@ -1132,21 +1132,38 @@ func (a *apiServer) PatchExperiment( } enforcedChkptConf, err := configpolicy.GetConfigPolicyField[expconf.CheckpointStorageConfig]( - ctx, &w.ID, "invariant_config", "checkpoint_storage", + ctx, &w.ID, []string{"checkpoint_storage"}, "invariant_config", model.ExperimentType) if err != nil { return nil, fmt.Errorf("unable to fetch task config policies: %w", err) } + if enforcedChkptConf != nil { - activeConfig.SetCheckpointStorage(*enforcedChkptConf) + enforcedSaveExpBest := enforcedChkptConf.RawSaveExperimentBest + enforcedSaveTrialBest := enforcedChkptConf.RawSaveTrialBest + enforcedSaveTrialLatest := enforcedChkptConf.RawSaveTrialLatest + + if enforcedSaveExpBest != nil && + int(newCheckpointStorage.SaveExperimentBest) != *enforcedSaveExpBest { + return nil, + fmt.Errorf("save_experiment_best is enforced as an invariant config policy of %d", + *enforcedSaveExpBest) + } + if enforcedSaveTrialBest != nil && + int(newCheckpointStorage.SaveTrialBest) != *enforcedSaveTrialBest { + return nil, + fmt.Errorf("save_trial_best is enforced as an invariant config policy of %d", + *enforcedSaveTrialBest) + } + if enforcedSaveTrialLatest != nil && + int(newCheckpointStorage.SaveTrialLatest) != *enforcedSaveTrialLatest { + return nil, + fmt.Errorf("save_trial_latest is enforced as an invariant config policy of %d", + *enforcedSaveTrialLatest) + } } } - // `patch` represents the allowed mutations that can be performed on an experiment, in JSON - if err := a.m.db.SaveExperimentConfig(modelExp.ID, activeConfig); err != nil { - return nil, errors.Wrapf(err, "patching experiment %d", modelExp.ID) - } - if newResources != nil { e, ok := experiment.ExperimentRegistry.Load(int(exp.Id)) if !ok { @@ -1155,6 +1172,15 @@ func (a *apiServer) PatchExperiment( if newResources.MaxSlots != nil { msg := sproto.SetGroupMaxSlots{MaxSlots: ptrs.Ptr(int(*newResources.MaxSlots))} + w, err := getWorkspaceByConfig(activeConfig) + if err != nil { + return nil, status.Errorf(codes.Internal, err.Error()) + } + + err = configpolicy.CanSetMaxSlots(msg.MaxSlots, w.ID) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, err.Error()) + } e.SetGroupMaxSlots(msg) } if newResources.Weight != nil { @@ -1171,6 +1197,10 @@ func (a *apiServer) PatchExperiment( } } + if err := a.m.db.SaveExperimentConfig(modelExp.ID, activeConfig); err != nil { + return nil, errors.Wrapf(err, "patching experiment %d", modelExp.ID) + } + if newCheckpointStorage != nil { checkpoints, err := experiment.ExperimentCheckpointsToGCRaw( ctx, diff --git a/master/internal/api_experiment_intg_test.go b/master/internal/api_experiment_intg_test.go index 38c79146854..2eac388bfa4 100644 --- a/master/internal/api_experiment_intg_test.go +++ b/master/internal/api_experiment_intg_test.go @@ -36,6 +36,7 @@ import ( apiPkg "github.com/determined-ai/determined/master/internal/api" authz2 "github.com/determined-ai/determined/master/internal/authz" + "github.com/determined-ai/determined/master/internal/configpolicy" "github.com/determined-ai/determined/master/internal/db" expauth "github.com/determined-ai/determined/master/internal/experiment" "github.com/determined-ai/determined/master/internal/mocks" @@ -47,6 +48,7 @@ import ( "github.com/determined-ai/determined/master/pkg/schemas" "github.com/determined-ai/determined/master/pkg/schemas/expconf" "github.com/determined-ai/determined/master/test/olddata" + "github.com/determined-ai/determined/master/test/testutils" "github.com/determined-ai/determined/proto/pkg/apiv1" "github.com/determined-ai/determined/proto/pkg/commonv1" "github.com/determined-ai/determined/proto/pkg/experimentv1" @@ -2348,3 +2350,234 @@ func TestGetWorkspaceByConfig(t *testing.T) { require.Equal(t, *wkspName, w.Name) }) } + +func TestPatchExperiment(t *testing.T) { + mockRM := MockRM() + testutils.MustLoadLicenseAndKeyFromFilesystem("../../") + + api, _, ctx := setupAPITest(t, nil, mockRM) + conf := ` +entrypoint: test +searcher: + metric: loss + name: single + max_length: 10 +resources: + resource_pool: kubernetes +checkpoint_storage: + type: shared_fs + host_path: /etc + storage_path: determined-integration-checkpoints +` + createReq := &apiv1.CreateExperimentRequest{ + ModelDefinition: []*utilv1.File{{Content: []byte{1}}}, + Config: conf, + ParentId: 0, + Activate: false, + ProjectId: 1, + } + + mockRM.On("SmallerValueIsHigherPriority", mock.Anything).Return(true, nil) + expResp, err := api.CreateExperiment(ctx, createReq) + require.NoError(t, err) + + // Create global invariant config policy with checkpoint storage. + _, err = api.PutGlobalConfigPolicies(ctx, &apiv1.PutGlobalConfigPoliciesRequest{ + WorkloadType: model.ExperimentType, + ConfigPolicies: ` +invariant_config: + checkpoint_storage: + type: shared_fs + host_path: /tmp + storage_path: determined-integration-checkpoints + + save_experiment_best: 10 + save_trial_best: 11 + save_trial_latest: 12 +`, + }) + require.NoError(t, err) + + t.Run("save exp best config differs", func(t *testing.T) { + _, err = api.PatchExperiment(ctx, &apiv1.PatchExperimentRequest{ + Experiment: &experimentv1.PatchExperiment{ + Id: expResp.Experiment.Id, + CheckpointStorage: &experimentv1.PatchExperiment_PatchCheckpointStorage{ + SaveExperimentBest: 1, + SaveTrialBest: 11, + SaveTrialLatest: 12, + }, + }, + }) + require.ErrorContains(t, err, "invariant config policy") + }) + + t.Run("save trial best config differs", func(t *testing.T) { + _, err = api.PatchExperiment(ctx, &apiv1.PatchExperimentRequest{ + Experiment: &experimentv1.PatchExperiment{ + Id: expResp.Experiment.Id, + CheckpointStorage: &experimentv1.PatchExperiment_PatchCheckpointStorage{ + SaveExperimentBest: 10, + SaveTrialBest: 1, + SaveTrialLatest: 12, + }, + }, + }) + require.ErrorContains(t, err, "invariant config policy") + }) + + t.Run("save trial latest config differs", func(t *testing.T) { + _, err = api.PatchExperiment(ctx, &apiv1.PatchExperimentRequest{ + Experiment: &experimentv1.PatchExperiment{ + Id: expResp.Experiment.Id, + CheckpointStorage: &experimentv1.PatchExperiment_PatchCheckpointStorage{ + SaveExperimentBest: 10, + SaveTrialBest: 11, + SaveTrialLatest: 1, + }, + }, + }) + require.ErrorContains(t, err, "invariant config policy") + }) + + t.Run("chkpt config matches invariant config", func(t *testing.T) { + _, err = api.PatchExperiment(ctx, &apiv1.PatchExperimentRequest{ + Experiment: &experimentv1.PatchExperiment{ + Id: expResp.Experiment.Id, + CheckpointStorage: &experimentv1.PatchExperiment_PatchCheckpointStorage{ + SaveExperimentBest: 10, + SaveTrialBest: 11, + SaveTrialLatest: 12, + }, + }, + }) + require.NoError(t, err) + }) + + // Set global invariant config policy with resources.max_slots. + _, err = api.PutGlobalConfigPolicies(ctx, &apiv1.PutGlobalConfigPoliciesRequest{ + WorkloadType: model.ExperimentType, + ConfigPolicies: ` +invariant_config: + resources: + max_slots: 23 +`, + }) + require.NoError(t, err) + t.Run("max slots differs", func(t *testing.T) { + _, err = api.PatchExperiment(ctx, &apiv1.PatchExperimentRequest{ + Experiment: &experimentv1.PatchExperiment{ + Id: expResp.Experiment.Id, + Resources: &experimentv1.PatchExperiment_PatchResources{ + MaxSlots: ptrs.Ptr[int32](20), + }, + }, + }) + require.ErrorContains(t, err, configpolicy.SlotsAlreadySetErr) + }) + + t.Run("max slots matches", func(t *testing.T) { + _, err = api.PatchExperiment(ctx, &apiv1.PatchExperimentRequest{ + Experiment: &experimentv1.PatchExperiment{ + Id: expResp.Experiment.Id, + Resources: &experimentv1.PatchExperiment_PatchResources{ + MaxSlots: ptrs.Ptr[int32](23), + }, + }, + }) + require.NoError(t, err) + }) + + // Set global constraints policy with resources.max_slots. + _, err = api.PutGlobalConfigPolicies(ctx, &apiv1.PutGlobalConfigPoliciesRequest{ + WorkloadType: model.ExperimentType, + ConfigPolicies: ` +constraints: + resources: + max_slots: 23 +`, + }) + require.NoError(t, err) + + t.Run("max slots violates constraint", func(t *testing.T) { + _, err = api.PatchExperiment(ctx, &apiv1.PatchExperimentRequest{ + Experiment: &experimentv1.PatchExperiment{ + Id: expResp.Experiment.Id, + Resources: &experimentv1.PatchExperiment_PatchResources{ + MaxSlots: ptrs.Ptr[int32](30), + }, + }, + }) + require.ErrorContains(t, err, configpolicy.SlotsReqTooHighErr) + }) + + t.Run("max slots complies with constraint", func(t *testing.T) { + _, err = api.PatchExperiment(ctx, &apiv1.PatchExperimentRequest{ + Experiment: &experimentv1.PatchExperiment{ + Id: expResp.Experiment.Id, + Resources: &experimentv1.PatchExperiment_PatchResources{ + MaxSlots: ptrs.Ptr[int32](10), + }, + }, + }) + require.NoError(t, err) + }) + + // Set global invariant config policy with resources.weight. + _, err = api.PutGlobalConfigPolicies(ctx, &apiv1.PutGlobalConfigPoliciesRequest{ + WorkloadType: model.ExperimentType, + ConfigPolicies: ` +invariant_config: + resources: + weight: 23 +`, + }) + require.NoError(t, err) + + t.Run("weight config differs", func(t *testing.T) { + _, err = api.PatchExperiment(ctx, &apiv1.PatchExperimentRequest{ + Experiment: &experimentv1.PatchExperiment{ + Id: expResp.Experiment.Id, + Resources: &experimentv1.PatchExperiment_PatchResources{ + Weight: ptrs.Ptr[float64](30), + }, + }, + }) + require.ErrorContains(t, err, "invariant config policy") + }) + + t.Run("weight config matches", func(t *testing.T) { + _, err = api.PatchExperiment(ctx, &apiv1.PatchExperimentRequest{ + Experiment: &experimentv1.PatchExperiment{ + Id: expResp.Experiment.Id, + Resources: &experimentv1.PatchExperiment_PatchResources{ + Weight: ptrs.Ptr[float64](23), + }, + }, + }) + require.NoError(t, err) + }) + + t.Run("no config policies", func(t *testing.T) { + _, err = api.DeleteGlobalConfigPolicies(ctx, &apiv1.DeleteGlobalConfigPoliciesRequest{ + WorkloadType: model.ExperimentType, + }) + require.NoError(t, err) + + _, err = api.PatchExperiment(ctx, &apiv1.PatchExperimentRequest{ + Experiment: &experimentv1.PatchExperiment{ + Id: expResp.Experiment.Id, + Resources: &experimentv1.PatchExperiment_PatchResources{ + MaxSlots: ptrs.Ptr[int32](5), + Weight: ptrs.Ptr[float64](20), + }, + CheckpointStorage: &experimentv1.PatchExperiment_PatchCheckpointStorage{ + SaveExperimentBest: 1, + SaveTrialBest: 2, + SaveTrialLatest: 3, + }, + }, + }) + require.NoError(t, err) + }) +} diff --git a/master/internal/configpolicy/postgres_task_config_policy.go b/master/internal/configpolicy/postgres_task_config_policy.go index 11cf96ff24b..1cba6131f00 100644 --- a/master/internal/configpolicy/postgres_task_config_policy.go +++ b/master/internal/configpolicy/postgres_task_config_policy.go @@ -111,53 +111,60 @@ func DeleteConfigPolicies(ctx context.Context, return nil } -// GetConfigPolicyField fetches the field from an invariant_config or constraints policyType, in order -// of precedence. Global scope has highest precedence, then workspace. Returns nil if none is found. -// **NOTE** The field arguments are wrapped in bun.Safe, so you must specify the "raw" string -// exactly as you wish for it to be accessed in the database. For example, if you want to access -// resources.max_slots, the field argument should be "'resources' -> 'max_slots'" NOT -// "resources -> max_slots". +// GetConfigPolicyField fetches the accessField from an invariant_config or constraints policy +// (determined by policyType) in order of precedence. Global policies takes precedence over workspace +// policies. Returns nil if the accessField is not set at either scope. +// **NOTE** The accessField elements are to be specified in the "order of access", meaning that the +// most nested config field should be the last element of accessField while the outermost +// config field should be the first element of accessField. +// For example, if you want to access resources.max_slots, accessField should be +// []string{"resources", "max_slots"}. If you just want to access the entire resources config, then +// accessField should be []string{"resources"}. // **NOTE**When using this function to retrieve an object of Kind Pointer, set T as the Type of // object that the Pointer wraps. For example, if we want an object of type *int, set T to int, so -// that when its pointer is returned, you get an object of type *int. -func GetConfigPolicyField[T any](ctx context.Context, wkspID *int, policyType, field, workloadType string) (*T, +// that when its pointer is returned, we get an object of type *int. +func GetConfigPolicyField[T any](ctx context.Context, wkspID *int, accessField []string, policyType, + workloadType string) (*T, error, ) { if policyType != "invariant_config" && policyType != "constraints" { return nil, fmt.Errorf("%s :%s", invalidPolicyTypeErr, policyType) } + field := "'" + strings.Join(accessField, "' -> '") + "'" var confBytes []byte var conf T - err := db.Bun().RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { - var globalBytes []byte - err := tx.NewSelect().Table("task_config_policies"). - ColumnExpr("? -> ?", bun.Safe(policyType), bun.Safe(field)). - Where("workspace_id IS NULL"). - Where("workload_type = ?", workloadType).Scan(ctx, &globalBytes) - if err == nil && len(globalBytes) > 0 { - confBytes = globalBytes - } - if err != nil && err != sql.ErrNoRows { - return err + var globalBytes []byte + err := db.Bun().NewSelect().Table("task_config_policies"). + ColumnExpr("? -> ?", bun.Safe(policyType), bun.Safe(field)). + Where("workspace_id IS NULL"). + Where("workload_type = ?", workloadType).Scan(ctx, &globalBytes) + if err == nil && len(globalBytes) > 0 { + err = json.Unmarshal(globalBytes, &conf) + if err != nil { + return nil, fmt.Errorf("error unmarshaling config field: %w", err) } + return &conf, nil + } + if err != nil && err != sql.ErrNoRows { + return nil, err + } - var wkspBytes []byte - err = tx.NewSelect().Table("task_config_policies"). - ColumnExpr("? -> ?", bun.Safe(policyType), bun.Safe(field)). - Where("workspace_id = ?", wkspID). - Where("workload_type = ?", workloadType).Scan(ctx, &wkspBytes) - if err == nil && len(globalBytes) == 0 { - confBytes = wkspBytes - } - return err - }) + var wkspBytes []byte + err = db.Bun().NewSelect().Table("task_config_policies"). + ColumnExpr("? -> ?", bun.Safe(policyType), bun.Safe(field)). + Where("workspace_id = ?", wkspID). + Where("workload_type = ?", workloadType).Scan(ctx, &wkspBytes) + if err != nil && err != sql.ErrNoRows { + return nil, fmt.Errorf("error getting config field %s: %w", field, err) + } + if len(globalBytes) == 0 { + confBytes = wkspBytes + } if err == sql.ErrNoRows || len(confBytes) == 0 { + // The field is not enforced as a config policy. Should not be an error. return nil, nil } - if err != nil { - return nil, fmt.Errorf("error getting config field %s: %w", field, err) - } err = json.Unmarshal(confBytes, &conf) if err != nil { diff --git a/master/internal/configpolicy/postgres_task_config_policy_intg_test.go b/master/internal/configpolicy/postgres_task_config_policy_intg_test.go index 9a35dfee6e8..347d89b8d9f 100644 --- a/master/internal/configpolicy/postgres_task_config_policy_intg_test.go +++ b/master/internal/configpolicy/postgres_task_config_policy_intg_test.go @@ -664,8 +664,8 @@ func requireEqualTaskPolicy(t *testing.T, exp *model.TaskConfigPolicies, act *mo func TestGetEnforcedConfig(t *testing.T) { ctx := context.Background() require.NoError(t, etc.SetRootPath(db.RootFromDB)) - pgDB, _ := db.MustResolveNewPostgresDatabase(t) - // defer cleanup() + pgDB, cleanup := db.MustResolveNewPostgresDatabase(t) + defer cleanup() db.MustMigrateTestPostgres(t, pgDB, db.MigrationsFromDB) user := db.RequireMockUser(t, pgDB) @@ -682,19 +682,9 @@ func TestGetEnforcedConfig(t *testing.T) { "container_path": "global_container_path" } } -` - wkspConf := ` -{ - "checkpoint_storage": { - "type": "shared_fs", - "host_path": "wksp_host_path", - "container_path": "wksp_container_path", - "checkpoint_path": "wksp_checkpoint_path" - } -} ` - t.Run("checkpoint storage config", func(t *testing.T) { + t.Run("checkpoint storage config just global", func(t *testing.T) { err = SetTaskConfigPolicies(ctx, &model.TaskConfigPolicies{ WorkloadType: model.ExperimentType, LastUpdatedBy: user.ID, @@ -702,6 +692,31 @@ func TestGetEnforcedConfig(t *testing.T) { }) require.NoError(t, err) + checkpointStorage, err := GetConfigPolicyField[expconf.CheckpointStorageConfig](ctx, &w.ID, + []string{"checkpoint_storage"}, "invariant_config", model.ExperimentType) + require.NoError(t, err) + require.NotNil(t, checkpointStorage) + + // global config enforced? + require.Equal(t, expconf.CheckpointStorageConfigV0{ + RawSharedFSConfig: &expconf.SharedFSConfigV0{ + RawHostPath: ptrs.Ptr("global_host_path"), + RawContainerPath: ptrs.Ptr("global_container_path"), + }, + }, *checkpointStorage) + }) + + wkspConf := ` + { + "checkpoint_storage": { + "type": "shared_fs", + "host_path": "wksp_host_path", + "container_path": "wksp_container_path", + "checkpoint_path": "wksp_checkpoint_path" + } + } + ` + t.Run("checkpoint storage config global and wksp", func(t *testing.T) { err = SetTaskConfigPolicies(ctx, &model.TaskConfigPolicies{ WorkspaceID: &w.ID, WorkloadType: model.ExperimentType, @@ -711,7 +726,7 @@ func TestGetEnforcedConfig(t *testing.T) { require.NoError(t, err) checkpointStorage, err := GetConfigPolicyField[expconf.CheckpointStorageConfig](ctx, &w.ID, - "invariant_config", "'checkpoint_storage'", model.ExperimentType) + []string{"checkpoint_storage"}, "invariant_config", model.ExperimentType) require.NoError(t, err) require.NotNil(t, checkpointStorage) @@ -751,8 +766,8 @@ func TestGetEnforcedConfig(t *testing.T) { }) require.NoError(t, err) - maxSlots, err := GetConfigPolicyField[int](ctx, &w.ID, "invariant_config", - "'resources' -> 'max_slots'", model.ExperimentType) + maxSlots, err := GetConfigPolicyField[int](ctx, &w.ID, + []string{"resources", "max_slots"}, "invariant_config", model.ExperimentType) require.NoError(t, err) require.NotNil(t, maxSlots) @@ -767,16 +782,14 @@ func TestGetEnforcedConfig(t *testing.T) { } } ` - wkspConstraints := ` - { - "resources": { - "max_slots": 20 - } +{ + "resources": { + "max_slots": 20 } +} ` - - t.Run("max slots constraints", func(t *testing.T) { + t.Run("max slots constraints just global", func(t *testing.T) { err = SetTaskConfigPolicies(ctx, &model.TaskConfigPolicies{ WorkloadType: model.ExperimentType, LastUpdatedBy: user.ID, @@ -792,8 +805,8 @@ func TestGetEnforcedConfig(t *testing.T) { }) require.NoError(t, err) - maxSlots, err := GetConfigPolicyField[int](ctx, &w.ID, "constraints", - "'resources' -> 'max_slots'", model.ExperimentType) + maxSlots, err := GetConfigPolicyField[int](ctx, &w.ID, + []string{"resources", "max_slots"}, "constraints", model.ExperimentType) require.NoError(t, err) require.NotNil(t, maxSlots) @@ -812,7 +825,6 @@ func TestGetEnforcedConfig(t *testing.T) { "priority_limit": 50 } ` - t.Run("priority constraints", func(t *testing.T) { err = SetTaskConfigPolicies(ctx, &model.TaskConfigPolicies{ WorkloadType: model.ExperimentType, @@ -829,8 +841,8 @@ func TestGetEnforcedConfig(t *testing.T) { }) require.NoError(t, err) - priority, err := GetConfigPolicyField[int](ctx, &w.ID, "constraints", - "'priority_limit'", model.ExperimentType) + priority, err := GetConfigPolicyField[int](ctx, &w.ID, + []string{"priority_limit"}, "constraints", model.ExperimentType) require.NoError(t, err) require.NotNil(t, priority) @@ -851,8 +863,8 @@ func TestGetEnforcedConfig(t *testing.T) { }) require.NoError(t, err) - priority, err := GetConfigPolicyField[int](ctx, &w.ID, "constraints", - "'priority_limit'", model.ExperimentType) + priority, err := GetConfigPolicyField[int](ctx, &w.ID, []string{"priority_limit"}, + "constraints", model.ExperimentType) require.NoError(t, err) require.NotNil(t, priority) @@ -861,22 +873,22 @@ func TestGetEnforcedConfig(t *testing.T) { }) t.Run("field not set in config", func(t *testing.T) { - maxRestarts, err := GetConfigPolicyField[int](ctx, &w.ID, "invariant_config", - "'max_restarts'", model.ExperimentType) + maxRestarts, err := GetConfigPolicyField[int](ctx, &w.ID, + []string{"max_restarts"}, "invariant_config", model.ExperimentType) require.NoError(t, err) require.Nil(t, maxRestarts) }) t.Run("nonexistent constraints field", func(t *testing.T) { - maxRestarts, err := GetConfigPolicyField[int](ctx, &w.ID, "constraints", - "'max_restarts'", model.ExperimentType) + maxRestarts, err := GetConfigPolicyField[int](ctx, &w.ID, + []string{"max_restarts"}, "constraints", model.ExperimentType) require.NoError(t, err) require.Nil(t, maxRestarts) }) t.Run("invalid policy type", func(t *testing.T) { - _, err := GetConfigPolicyField[int](ctx, &w.ID, "bad policy", - "'debug'", model.ExperimentType) + _, err := GetConfigPolicyField[int](ctx, &w.ID, + []string{"debug"}, "bad policy", model.ExperimentType) require.ErrorContains(t, err, invalidPolicyTypeErr) }) } diff --git a/master/internal/configpolicy/utils.go b/master/internal/configpolicy/utils.go index 7f08edd51ae..86953523db7 100644 --- a/master/internal/configpolicy/utils.go +++ b/master/internal/configpolicy/utils.go @@ -31,6 +31,8 @@ const ( // SlotsReqTooHighErr is the error reported when the requested slots violates the max slots // constraint. SlotsReqTooHighErr = "requested slots is violates max slots constraint" + // SlotsAlreadySetErr is the error reported when slots are already set in an invariant config. + SlotsAlreadySetErr = "max slots is already set in an invariant config policy" ) // ConfigPolicyWarning logs a warning for the configuration policy component. @@ -303,39 +305,31 @@ func configPolicyOverlap(config1, config2 interface{}) { } } -// CanSetMaxSlots returns true if the slots requested don't violate a constraint. It returns the -// enforced max slots for the workspace if that's set as an invariant config, and returns the -// requested max slots otherwise. Returns an error when max slots is not set as an invariant config -// and the requested max slots violates the constriant. -func CanSetMaxSlots(slotsReq *int, wkspID int) (*int, error) { +// CanSetMaxSlots returns an error if slotsReq differs from an invariant config or violates a +// constraint. Otherwise, it returns nil. +func CanSetMaxSlots(slotsReq *int, wkspID int) error { if slotsReq == nil { - return slotsReq, nil + return nil } enforcedMaxSlots, err := GetConfigPolicyField[int](context.TODO(), &wkspID, - "invariant_config", - "'resources' -> 'max_slots'", model.ExperimentType) + []string{"resources", "max_slots"}, "invariant_config", model.ExperimentType) if err != nil { - return nil, err + return err } - if enforcedMaxSlots != nil { - return enforcedMaxSlots, nil + if enforcedMaxSlots != nil && *slotsReq != *enforcedMaxSlots { + return fmt.Errorf(SlotsAlreadySetErr+":max_slots of %d is enforced", *enforcedMaxSlots) } maxSlotsLimit, err := GetConfigPolicyField[int](context.TODO(), &wkspID, - "constraints", - "'resources' -> 'max_slots'", model.ExperimentType) + []string{"resources", "max_slots"}, "constraints", model.ExperimentType) if err != nil { - return nil, err + return err } - var canSetReqSlots bool - if maxSlotsLimit == nil || *slotsReq <= *maxSlotsLimit { - canSetReqSlots = true - } - if !canSetReqSlots { - return nil, fmt.Errorf(SlotsReqTooHighErr+": %d > %d", *slotsReq, *maxSlotsLimit) + if maxSlotsLimit != nil && *slotsReq > *maxSlotsLimit { + return fmt.Errorf(SlotsReqTooHighErr+": %d > %d", *slotsReq, *maxSlotsLimit) } - return slotsReq, nil + return nil } diff --git a/master/internal/configpolicy/utils_test.go b/master/internal/configpolicy/utils_test.go index dd5f80820d2..762d7948f03 100644 --- a/master/internal/configpolicy/utils_test.go +++ b/master/internal/configpolicy/utils_test.go @@ -624,9 +624,8 @@ func TestCanSetMaxSlots(t *testing.T) { ctx := context.Background() w := createWorkspaceWithUser(ctx, t, user.ID) t.Run("nil slots request", func(t *testing.T) { - slots, err := CanSetMaxSlots(nil, w.ID) + err := CanSetMaxSlots(nil, w.ID) require.NoError(t, err) - require.Nil(t, slots) }) err := SetTaskConfigPolicies(ctx, &model.TaskConfigPolicies{ @@ -639,29 +638,18 @@ func TestCanSetMaxSlots(t *testing.T) { "max_slots": 13 } } -`), - Constraints: ptrs.Ptr(` -{ - "resources": { - "max_slots": 13 - } -} `), }) require.NoError(t, err) t.Run("slots different than config higher", func(t *testing.T) { - slots, err := CanSetMaxSlots(ptrs.Ptr(15), w.ID) - require.NoError(t, err) - require.NotNil(t, slots) - require.Equal(t, 13, *slots) + err := CanSetMaxSlots(ptrs.Ptr(15), w.ID) + require.ErrorContains(t, err, SlotsAlreadySetErr) }) t.Run("slots different than config lower", func(t *testing.T) { - slots, err := CanSetMaxSlots(ptrs.Ptr(10), w.ID) - require.NoError(t, err) - require.NotNil(t, slots) - require.Equal(t, 13, *slots) + err := CanSetMaxSlots(ptrs.Ptr(10), w.ID) + require.ErrorContains(t, err, SlotsAlreadySetErr) }) t.Run("just constraints slots higher", func(t *testing.T) { @@ -679,9 +667,8 @@ func TestCanSetMaxSlots(t *testing.T) { }) require.NoError(t, err) - slots, err := CanSetMaxSlots(ptrs.Ptr(25), w.ID) + err = CanSetMaxSlots(ptrs.Ptr(25), w.ID) require.ErrorContains(t, err, SlotsReqTooHighErr) - require.Nil(t, slots) }) t.Run("just constraints slots lower", func(t *testing.T) { @@ -699,9 +686,7 @@ func TestCanSetMaxSlots(t *testing.T) { }) require.NoError(t, err) - slots, err := CanSetMaxSlots(ptrs.Ptr(20), w.ID) + err = CanSetMaxSlots(ptrs.Ptr(20), w.ID) require.NoError(t, err) - require.NotNil(t, slots) - require.Equal(t, 20, *slots) }) } diff --git a/master/internal/experiment.go b/master/internal/experiment.go index 2cee92b62fc..0b533aa2503 100644 --- a/master/internal/experiment.go +++ b/master/internal/experiment.go @@ -374,13 +374,12 @@ func (e *internalExperiment) SetGroupMaxSlots(msg sproto.SetGroupMaxSlots) { return } - slots, err := configpolicy.CanSetMaxSlots(msg.MaxSlots, w.ID) + err = configpolicy.CanSetMaxSlots(msg.MaxSlots, w.ID) if err != nil { log.Warnf("unable to set max slots: %s", err.Error()) return } - msg.MaxSlots = slots resources := e.activeConfig.Resources() resources.SetMaxSlots(msg.MaxSlots) e.activeConfig.SetResources(resources) @@ -945,13 +944,12 @@ func (e *internalExperiment) setWeight(weight float64) error { return fmt.Errorf("error getting workspace: %w", err) } enforcedWeight, err := configpolicy.GetConfigPolicyField[float64](context.TODO(), &w.ID, - "invariant_config", - "'resources' -> 'weight'", model.ExperimentType) + []string{"resources", "weight"}, "invariant_config", model.ExperimentType) if err != nil { return fmt.Errorf("error checking against config policies: %w", err) } - if enforcedWeight != nil { - weight = *enforcedWeight + if enforcedWeight != nil && weight != *enforcedWeight { + return fmt.Errorf("weight is enforced as an invariant config policy of %v", *enforcedWeight) } resources := e.activeConfig.Resources()