Skip to content

Commit

Permalink
chore: experiment config slots to comply with constraint max slots (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
amandavialva01 authored Oct 15, 2024
1 parent 1d5c984 commit 3fc9fed
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 15 deletions.
6 changes: 3 additions & 3 deletions master/internal/api_config_policies_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -861,7 +861,7 @@ invariant_config:
// Additional NTSC combinatory tests (YAML).
{
"YAML NTSC valid config valid constraints", model.NTSCType,
validNTSCConfigPolicyYAML + validConstraintsPolicyYAML, fmt.Errorf("invalid ntsc config policy"),
validNTSCConfigPolicyYAML + validConstraintsPolicyYAML, nil,
},
{
"YAML NTSC valid constraints invalid constraints", model.NTSCType,
Expand Down Expand Up @@ -1192,8 +1192,8 @@ func TestValidatePoliciesAndWorkloadTypeJSON(t *testing.T) {

// Additional NTSC combinatory tests (JSON).
{
"JSON NTSC valid config invalid constraints", model.NTSCType,
"{" + validNTSCConfigPolicyJSON + "," + validConstraintsPolicyJSON + "}", fmt.Errorf("invalid ntsc config policy"),
"JSON NTSC valid config valid constraints", model.NTSCType,
"{" + validNTSCConfigPolicyJSON + "," + validConstraintsPolicyJSON + "}", nil,
},
{
"JSON NTSC valid constraints invalid constraints", model.NTSCType,
Expand Down
29 changes: 20 additions & 9 deletions master/internal/configpolicy/task_config_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func CheckNTSCConstraints(
}

if constraints.ResourceConstraints != nil && constraints.ResourceConstraints.MaxSlots != nil {
if err = checkSlotsConstraint(*constraints.ResourceConstraints.MaxSlots, workloadConfig.Resources.Slots,
if err = checkSlotsConstraint(*constraints.ResourceConstraints.MaxSlots, &workloadConfig.Resources.Slots,
workloadConfig.Resources.MaxSlots); err != nil {
return err
}
Expand Down Expand Up @@ -88,10 +88,19 @@ func CheckExperimentConstraints(

if constraints.ResourceConstraints != nil && constraints.ResourceConstraints.MaxSlots != nil {
// users cannot specify number of slots for an experiment
slotsRequest := *constraints.ResourceConstraints.MaxSlots
if err = checkSlotsConstraint(*constraints.ResourceConstraints.MaxSlots, slotsRequest,
workloadConfig.Resources().MaxSlots()); err != nil {
return err
if workloadConfig.RawResources != nil {
slotsRequest := workloadConfig.RawResources.RawSlotsPerTrial
if err = checkSlotsConstraint(*constraints.ResourceConstraints.MaxSlots,
slotsRequest,
workloadConfig.Resources().MaxSlots()); err != nil {
return err
}
slotsRequest = workloadConfig.RawResources.RawMaxSlots
if err = checkSlotsConstraint(*constraints.ResourceConstraints.MaxSlots,
slotsRequest,
workloadConfig.Resources().MaxSlots()); err != nil {
return err
}
}
}

Expand Down Expand Up @@ -121,10 +130,12 @@ func checkPriorityConstraint(smallerHigher bool, priorityLimit *int, priorityReq
return nil
}

func checkSlotsConstraint(slotsLimit int, slotsRequest int, maxSlotsRequest *int) error {
if slotsLimit < slotsRequest {
return fmt.Errorf("requested resources.slots [%d] exceeds limit set by admin [%d]: %w",
slotsRequest, slotsLimit, errResourceConstraintFailure)
func checkSlotsConstraint(slotsLimit int, slotsRequest *int, maxSlotsRequest *int) error {
if slotsRequest != nil {
if slotsLimit < *slotsRequest {
return fmt.Errorf("requested resources.slots [%d] exceeds limit set by admin [%d]: %w",
slotsRequest, slotsLimit, errResourceConstraintFailure)
}
}

if maxSlotsRequest != nil {
Expand Down
6 changes: 3 additions & 3 deletions master/internal/configpolicy/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,9 @@ func checkConstraintConflicts(constraints *model.Constraints, maxSlots, slots, p
if maxSlots != nil && *constraints.ResourceConstraints.MaxSlots != *maxSlots {
return fmt.Errorf("invariant config & constraints are trying to set the max slots")
}
if slots != nil && *constraints.ResourceConstraints.MaxSlots > *slots {
return fmt.Errorf("invariant config & constraints are attempting to set an invalid max slot123: %v vs %v",
*constraints.ResourceConstraints.MaxSlots, *slots)
if slots != nil && *constraints.ResourceConstraints.MaxSlots < *slots {
return fmt.Errorf("invariant config has %v slots per trial. violates constraints max slots of %v",
*slots, *constraints.ResourceConstraints.MaxSlots)
}

return nil
Expand Down
43 changes: 43 additions & 0 deletions master/internal/configpolicy/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"gotest.tools/assert"

"github.com/determined-ai/determined/master/pkg/model"
"github.com/determined-ai/determined/master/pkg/ptrs"
"github.com/determined-ai/determined/master/pkg/schemas/expconf"
)

Expand Down Expand Up @@ -309,3 +310,45 @@ func TestUnmarshalJSONNTSC(t *testing.T) {
})
}
}

func TestCheckConstraintsConflicts(t *testing.T) {
constraints := &model.Constraints{
ResourceConstraints: &model.ResourceConstraints{
MaxSlots: ptrs.Ptr(10),
},
PriorityLimit: ptrs.Ptr(50),
}
t.Run("max_slots differs to high", func(t *testing.T) {
err := checkConstraintConflicts(constraints, ptrs.Ptr(11), ptrs.Ptr(5), nil)
require.Error(t, err)
})
t.Run("max_slots differs to low", func(t *testing.T) {
err := checkConstraintConflicts(constraints, ptrs.Ptr(9), ptrs.Ptr(5), nil)
require.Error(t, err)
})

t.Run("slots_per_trial too high", func(t *testing.T) {
err := checkConstraintConflicts(constraints, ptrs.Ptr(5), ptrs.Ptr(11), nil)
require.Error(t, err)
})

t.Run("slots_per_trial within range", func(t *testing.T) {
err := checkConstraintConflicts(constraints, ptrs.Ptr(10), ptrs.Ptr(8), nil)
require.NoError(t, err)
})

t.Run("priority differs too high", func(t *testing.T) {
err := checkConstraintConflicts(constraints, nil, nil, ptrs.Ptr(100))
require.Error(t, err)
})

t.Run("priority differs too low", func(t *testing.T) {
err := checkConstraintConflicts(constraints, nil, nil, ptrs.Ptr(10))
require.Error(t, err)
})

t.Run("all comply", func(t *testing.T) {
err := checkConstraintConflicts(constraints, ptrs.Ptr(10), ptrs.Ptr(10), ptrs.Ptr(50))
require.NoError(t, err)
})
}

0 comments on commit 3fc9fed

Please sign in to comment.