Skip to content

Commit

Permalink
[batch/job] Partial admission
Browse files Browse the repository at this point in the history
  • Loading branch information
trasc committed May 18, 2023
1 parent 6ad794e commit d5c958c
Show file tree
Hide file tree
Showing 7 changed files with 218 additions and 17 deletions.
7 changes: 4 additions & 3 deletions pkg/controller/jobframework/reconciler.go
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ func (r *JobReconciler) stopJob(ctx context.Context, job GenericJob, object clie
}

log.V(3).Info("restore podSetsInfo from annotation")
info, err := getPodSetsInfoFromObjectAnnotation(object, wl.Spec.PodSets)
info, err := getPodSetsInfoFromObjectAnnotation(object, job)
if err != nil {
log.V(3).Error(err, "Unable to get original podSetsInfo")
} else {
Expand Down Expand Up @@ -527,7 +527,7 @@ func cloneNodeSelector(src map[string]string) map[string]string {

// getPodSetsInfoFromObjectAnnotation tries to retrieve a podSetsInfo slice from the
// object's annotations fails if it's not found or is unable to unmarshal
func getPodSetsInfoFromObjectAnnotation(obj client.Object, spec []kueue.PodSet) ([]PodSetInfo, error) {
func getPodSetsInfoFromObjectAnnotation(obj client.Object, job GenericJob) ([]PodSetInfo, error) {
hasCounts := true
str, found := obj.GetAnnotations()[OriginalPodSetsInfoAnnotation]
if !found {
Expand All @@ -544,7 +544,8 @@ func getPodSetsInfoFromObjectAnnotation(obj client.Object, spec []kueue.PodSet)
}

if !hasCounts {
psMap := utilslice.ToRefMap(spec, func(ps *kueue.PodSet) string { return ps.Name })
podSets := job.PodSets()
psMap := utilslice.ToRefMap(podSets, func(ps *kueue.PodSet) string { return ps.Name })
for i := range ret {
info := &ret[i]
ps, found := psMap[info.Name]
Expand Down
56 changes: 45 additions & 11 deletions pkg/controller/jobs/job/job_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package job

import (
"context"
"strconv"

batchv1 "k8s.io/api/batch/v1"
corev1 "k8s.io/api/core/v1"
Expand Down Expand Up @@ -48,6 +49,10 @@ var (
FrameworkName = "batch/job"
)

const (
JobMinParallelismAnnotation = "kueue.x-k8s.io/job-min-parallelism"
)

func init() {
utilruntime.Must(jobframework.RegisterIntegration(FrameworkName, jobframework.IntegrationCallbacks{
SetupIndexes: SetupIndexes,
Expand Down Expand Up @@ -173,36 +178,47 @@ func (j *Job) ReclaimablePods() []kueue.ReclaimablePod {
func (j *Job) PodSets() []kueue.PodSet {
return []kueue.PodSet{
{
Name: kueue.DefaultPodSetName,
Template: *j.Spec.Template.DeepCopy(),
Count: j.podsCount(),
Name: kueue.DefaultPodSetName,
Template: *j.Spec.Template.DeepCopy(),
Count: j.podsCount(),
MinimumCount: j.minPodsCount(),
},
}
}

func (j *Job) RunWithPodSetsInfo(nodeSelectors []jobframework.PodSetInfo) {
func (j *Job) RunWithPodSetsInfo(podSetsInfo []jobframework.PodSetInfo) {
j.Spec.Suspend = pointer.Bool(false)
if len(nodeSelectors) == 0 {
if len(podSetsInfo) == 0 {
return
}

if j.Spec.Template.Spec.NodeSelector == nil {
j.Spec.Template.Spec.NodeSelector = nodeSelectors[0].NodeSelector
j.Spec.Template.Spec.NodeSelector = podSetsInfo[0].NodeSelector
} else {
for k, v := range nodeSelectors[0].NodeSelector {
for k, v := range podSetsInfo[0].NodeSelector {
j.Spec.Template.Spec.NodeSelector[k] = v
}
}
j.Spec.Parallelism = pointer.Int32(podSetsInfo[0].Count)
}

func (j *Job) RestorePodSetsInfo(nodeSelectors []jobframework.PodSetInfo) {
if len(nodeSelectors) == 0 || equality.Semantic.DeepEqual(j.Spec.Template.Spec.NodeSelector, nodeSelectors[0].NodeSelector) {
func (j *Job) RestorePodSetsInfo(podSetsInfo []jobframework.PodSetInfo) {
if len(podSetsInfo) == 0 {
return
}

// if partial admission is enabled
if j.minPodsCount() != nil {
j.Spec.Parallelism = pointer.Int32(podSetsInfo[0].Count)
}

if equality.Semantic.DeepEqual(j.Spec.Template.Spec.NodeSelector, podSetsInfo[0].NodeSelector) {
return
}

j.Spec.Template.Spec.NodeSelector = map[string]string{}

for k, v := range nodeSelectors[0].NodeSelector {
for k, v := range podSetsInfo[0].NodeSelector {
j.Spec.Template.Spec.NodeSelector[k] = v
}
}
Expand Down Expand Up @@ -237,7 +253,16 @@ func (j *Job) EquivalentToWorkload(wl kueue.Workload) bool {
return false
}

if *j.Spec.Parallelism != wl.Spec.PodSets[0].Count {
ps0 := &wl.Spec.PodSets[0]
if mpc := j.minPodsCount(); mpc != nil {
if pointer.Int32Deref(ps0.MinimumCount, -1) != *mpc {
return false
}

if j.IsSuspended() && j.podsCount() != ps0.Count {
return false
}
} else if j.podsCount() != ps0.Count {
return false
}

Expand Down Expand Up @@ -269,6 +294,15 @@ func (j *Job) podsCount() int32 {
return podsCount
}

func (j *Job) minPodsCount() *int32 {
if strVal, found := j.GetAnnotations()[JobMinParallelismAnnotation]; found {
if iVal, err := strconv.Atoi(strVal); err == nil {
return pointer.Int32(int32(iVal))
}
}
return nil
}

// SetupWithManager sets up the controller with the Manager. It indexes workloads
// based on the owning jobs.
func (r *JobReconciler) SetupWithManager(mgr ctrl.Manager) error {
Expand Down
38 changes: 36 additions & 2 deletions pkg/controller/jobs/job/job_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,25 @@ package job

import (
"context"
"fmt"
"strconv"

batchv1 "k8s.io/api/batch/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/util/validation/field"
"k8s.io/klog/v2"
"k8s.io/utils/pointer"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/webhook"

"sigs.k8s.io/kueue/pkg/controller/jobframework"
)

var (
minPodsCountAnnotationsPath = field.NewPath("metadata", "annotations").Key(JobMinParallelismAnnotation)
)

type JobWebhook struct {
manageJobsWithoutQueueName bool
}
Expand Down Expand Up @@ -87,10 +94,26 @@ func (w *JobWebhook) ValidateCreate(ctx context.Context, obj runtime.Object) err
return validateCreate(&Job{job}).ToAggregate()
}

func validateCreate(job jobframework.GenericJob) field.ErrorList {
func validateCreate(job *Job) field.ErrorList {
var allErrs field.ErrorList
allErrs = append(allErrs, jobframework.ValidateAnnotationAsCRDName(job, jobframework.ParentWorkloadAnnotation)...)
allErrs = append(allErrs, jobframework.ValidateCreateForQueueName(job)...)
allErrs = append(allErrs, validatePartialAdmissionCreate(job)...)
return allErrs
}

func validatePartialAdmissionCreate(job *Job) field.ErrorList {
var allErrs field.ErrorList
if strVal, found := job.Annotations[JobMinParallelismAnnotation]; found {
v, err := strconv.Atoi(strVal)
if err != nil {
allErrs = append(allErrs, field.Invalid(minPodsCountAnnotationsPath, job.Annotations[JobMinParallelismAnnotation], err.Error()))
} else {
if int32(v) >= job.podsCount() || v <= 0 {
allErrs = append(allErrs, field.Invalid(minPodsCountAnnotationsPath, v, fmt.Sprintf("should be between 0 and %d", job.podsCount()-1)))
}
}
}
return allErrs
}

Expand All @@ -103,12 +126,23 @@ func (w *JobWebhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.
return validateUpdate(&Job{oldJob}, &Job{newJob}).ToAggregate()
}

func validateUpdate(oldJob, newJob jobframework.GenericJob) field.ErrorList {
func validateUpdate(oldJob, newJob *Job) field.ErrorList {
allErrs := validateCreate(newJob)
allErrs = append(allErrs, jobframework.ValidateUpdateForParentWorkload(oldJob, newJob)...)
allErrs = append(allErrs, jobframework.ValidateUpdateForOriginalPodSetsInfo(oldJob, newJob)...)
allErrs = append(allErrs, jobframework.ValidateUpdateForOriginalNodeSelectors(oldJob, newJob)...)
allErrs = append(allErrs, jobframework.ValidateUpdateForQueueName(oldJob, newJob)...)
allErrs = append(allErrs, validatePartialAdmissionUpdate(oldJob, newJob)...)
return allErrs
}

func validatePartialAdmissionUpdate(oldJob, newJob *Job) field.ErrorList {
var allErrs field.ErrorList
if _, found := oldJob.Annotations[JobMinParallelismAnnotation]; found {
if !oldJob.IsSuspended() && pointer.Int32Deref(oldJob.Spec.Parallelism, 1) != pointer.Int32Deref(newJob.Spec.Parallelism, 1) {
allErrs = append(allErrs, field.Forbidden(field.NewPath("spec", "parallelism"), "cannot change when partial admission is enabled and the job is not suspended"))
}
}
return allErrs
}

Expand Down
63 changes: 63 additions & 0 deletions pkg/controller/jobs/job/job_webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,37 @@ func TestValidateCreate(t *testing.T) {
field.Invalid(queueNameLabelPath, "queue name", invalidRFC1123Message),
},
},
{
name: "invalid partial admission annotation (format)",
job: testingutil.MakeJob("job", "default").
Parallelism(4).
Completions(6).
SetAnnotation(JobMinParallelismAnnotation, "NaN").
Obj(),
wantErr: field.ErrorList{
field.Invalid(minPodsCountAnnotationsPath, "NaN", "strconv.Atoi: parsing \"NaN\": invalid syntax"),
},
},
{
name: "invalid partial admission annotation (badValue)",
job: testingutil.MakeJob("job", "default").
Parallelism(4).
Completions(6).
SetAnnotation(JobMinParallelismAnnotation, "5").
Obj(),
wantErr: field.ErrorList{
field.Invalid(minPodsCountAnnotationsPath, 5, "should be between 0 and 3"),
},
},
{
name: "partial admission annotation valid",
job: testingutil.MakeJob("job", "default").
Parallelism(4).
Completions(6).
SetAnnotation(JobMinParallelismAnnotation, "3").
Obj(),
wantErr: nil,
},
}

for _, tc := range testcases {
Expand Down Expand Up @@ -199,6 +230,38 @@ func TestValidateUpdate(t *testing.T) {
field.Forbidden(originalNodeSelectorsKeyPath, "this annotation is immutable while the job is not changing its suspended state"),
},
},
{
name: "immutable parallelism while unsuspended with partial admission enabled",
oldJob: testingutil.MakeJob("job", "default").
Suspend(false).
Parallelism(4).
Completions(6).
SetAnnotation(JobMinParallelismAnnotation, "3").
Obj(),
newJob: testingutil.MakeJob("job", "default").
Suspend(false).
Parallelism(5).
Completions(6).
SetAnnotation(JobMinParallelismAnnotation, "3").
Obj(),
wantErr: field.ErrorList{
field.Forbidden(field.NewPath("spec", "parallelism"), "cannot change when partial admission is enabled and the job is not suspended"),
},
},
{
name: "mutable parallelism while suspended with partial admission enabled",
oldJob: testingutil.MakeJob("job", "default").
Parallelism(4).
Completions(6).
SetAnnotation(JobMinParallelismAnnotation, "3").
Obj(),
newJob: testingutil.MakeJob("job", "default").
Parallelism(5).
Completions(6).
SetAnnotation(JobMinParallelismAnnotation, "3").
Obj(),
wantErr: nil,
},
}

for _, tc := range testcases {
Expand Down
2 changes: 1 addition & 1 deletion pkg/scheduler/flavorassigner/podSetReducer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
utiltesting "sigs.k8s.io/kueue/pkg/util/testing"
)

func TestPodSetReducer(t *testing.T) {
func TestSearch(t *testing.T) {
cases := map[string]struct {
podSets []kueue.PodSet
countLimit int32
Expand Down
5 changes: 5 additions & 0 deletions pkg/util/testingjobs/job/wrappers.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,11 @@ func (j *JobWrapper) OriginalNodeSelectorsAnnotation(content string) *JobWrapper
return j
}

func (j *JobWrapper) SetAnnotation(key, content string) *JobWrapper {
j.Annotations[key] = content
return j
}

// Toleration adds a toleration to the job.
func (j *JobWrapper) Toleration(t corev1.Toleration) *JobWrapper {
j.Spec.Template.Spec.Tolerations = append(j.Spec.Template.Spec.Tolerations, t)
Expand Down
64 changes: 64 additions & 0 deletions test/integration/controller/job/job_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (

kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1"
"sigs.k8s.io/kueue/pkg/controller/jobframework"
"sigs.k8s.io/kueue/pkg/controller/jobs/job"
workloadjob "sigs.k8s.io/kueue/pkg/controller/jobs/job"
"sigs.k8s.io/kueue/pkg/util/pointer"
"sigs.k8s.io/kueue/pkg/util/testing"
Expand Down Expand Up @@ -1027,4 +1028,67 @@ var _ = ginkgo.Describe("Job controller interacting with scheduler", func() {
})
})

ginkgo.It("Should schedule jobs when partial admission is enabled", func() {
prodLocalQ = testing.MakeLocalQueue("prod-queue", ns.Name).ClusterQueue(prodClusterQ.Name).Obj()
job1 := testingjob.MakeJob("job1", ns.Name).
Queue(prodLocalQ.Name).
Parallelism(5).
Completions(6).
Request(corev1.ResourceCPU, "2").
Obj()
jobKey := types.NamespacedName{Name: job1.Name, Namespace: job1.Namespace}
wlKey := types.NamespacedName{Name: workloadjob.GetWorkloadNameForJob(job1.Name), Namespace: job1.Namespace}

ginkgo.By("creating localQueues")
gomega.Expect(k8sClient.Create(ctx, prodLocalQ)).Should(gomega.Succeed())

ginkgo.By("creating the job")
gomega.Expect(k8sClient.Create(ctx, job1)).Should(gomega.Succeed())

createdJob := &batchv1.Job{}
ginkgo.By("the job should stay suspended", func() {
gomega.Consistently(func() *bool {
gomega.Expect(k8sClient.Get(ctx, jobKey, createdJob)).Should(gomega.Succeed())
return createdJob.Spec.Suspend
}, util.ConsistentDuration, util.Interval).Should(gomega.Equal(pointer.Bool(true)))
})

ginkgo.By("enable partial admission", func() {
gomega.Expect(k8sClient.Get(ctx, jobKey, createdJob)).Should(gomega.Succeed())
if createdJob.Annotations == nil {
createdJob.Annotations = map[string]string{
job.JobMinParallelismAnnotation: "1",
}
} else {
createdJob.Annotations[job.JobMinParallelismAnnotation] = "1"
}

gomega.Expect(k8sClient.Update(ctx, createdJob)).Should(gomega.Succeed())
})

wl := &kueue.Workload{}
ginkgo.By("the job should be unsuspended with a lower parallelism", func() {
gomega.Eventually(func() *bool {
gomega.Expect(k8sClient.Get(ctx, jobKey, createdJob)).Should(gomega.Succeed())
return createdJob.Spec.Suspend
}, util.Timeout, util.Interval).Should(gomega.Equal(pointer.Bool(false)))
gomega.Expect(*createdJob.Spec.Parallelism).To(gomega.BeEquivalentTo(2))

gomega.Expect(k8sClient.Get(ctx, wlKey, wl)).To(gomega.Succeed())
gomega.Expect(wl.Spec.PodSets[0].MinimumCount).ToNot(gomega.BeNil())
gomega.Expect(*wl.Spec.PodSets[0].MinimumCount).To(gomega.BeEquivalentTo(1))
})

ginkgo.By("changing the min parallelism the job should be suspended and its parallelism restored", func() {
gomega.Expect(k8sClient.Get(ctx, jobKey, createdJob)).Should(gomega.Succeed())
createdJob.Annotations[job.JobMinParallelismAnnotation] = "4"
gomega.Expect(k8sClient.Update(ctx, createdJob)).Should(gomega.Succeed())

gomega.Eventually(func() *bool {
gomega.Expect(k8sClient.Get(ctx, jobKey, createdJob)).Should(gomega.Succeed())
return createdJob.Spec.Suspend
}, util.Timeout, util.Interval).Should(gomega.Equal(pointer.Bool(true)))
gomega.Expect(*createdJob.Spec.Parallelism).To(gomega.BeEquivalentTo(5))
})
})
})

0 comments on commit d5c958c

Please sign in to comment.