diff --git a/pkg/controller/jobframework/constants.go b/pkg/controller/jobframework/constants.go index 8fbbec5a1d..775f2ee2c1 100644 --- a/pkg/controller/jobframework/constants.go +++ b/pkg/controller/jobframework/constants.go @@ -37,5 +37,13 @@ const ( // node selectors are recorded upon a workload admission. This information // will be used to restore them when the job is suspended. // The content is a json marshaled slice of selectors. + // + // Deprecated: Use OriginalPodSetsInfoAnnotation. OriginalNodeSelectorsAnnotation = "kueue.x-k8s.io/original-node-selectors" + + // OriginalPodSetsInfoAnnotation is the annotation in which the original + // node selectors and podSet counts are recorded upon a workload admission. + // This information will be used to restore them when the job is suspended. + // The content is a json marshaled slice of PodSetInfo. + OriginalPodSetsInfoAnnotation = "kueue.x-k8s.io/original-pod-sets-info" ) diff --git a/pkg/controller/jobframework/interface.go b/pkg/controller/jobframework/interface.go index 2f6dfba821..d6590d7c46 100644 --- a/pkg/controller/jobframework/interface.go +++ b/pkg/controller/jobframework/interface.go @@ -31,10 +31,10 @@ type GenericJob interface { // ResetStatus will reset the job status to the original state. // If true, status is modified, if not, status is as it was. ResetStatus() bool - // RunWithNodeAffinity will inject the node affinity extracting from workload to job and unsuspend the job. - RunWithNodeAffinity(nodeSelectors []PodSetNodeSelector) - // RestoreNodeAffinity will restore the original node affinity of job. - RestoreNodeAffinity(nodeSelectors []PodSetNodeSelector) + // RunWithPodSetsInfo will inject the node affinity extracting from workload to job and unsuspend the job. + RunWithPodSetsInfo(nodeSelectors []PodSetInfo) + // RestorePodSetsInfo will restore the original node affinity of job. 
+ RestorePodSetsInfo(nodeSelectors []PodSetInfo) // Finished means whether the job is completed/failed or not, // condition represents the workload finished condition. Finished() (condition metav1.Condition, finished bool) diff --git a/pkg/controller/jobframework/reconciler.go b/pkg/controller/jobframework/reconciler.go index 777a16ec86..f8997927b1 100644 --- a/pkg/controller/jobframework/reconciler.go +++ b/pkg/controller/jobframework/reconciler.go @@ -16,6 +16,7 @@ package jobframework import ( "context" "encoding/json" + "errors" "fmt" corev1 "k8s.io/api/core/v1" @@ -32,11 +33,13 @@ import ( kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1" "sigs.k8s.io/kueue/pkg/constants" utilpriority "sigs.k8s.io/kueue/pkg/util/priority" + utilslice "sigs.k8s.io/kueue/pkg/util/slice" "sigs.k8s.io/kueue/pkg/workload" ) var ( - errNodeSelectorsNotFound = fmt.Errorf("annotation %s not found", OriginalNodeSelectorsAnnotation) + errPodSetsInfoNotFound = fmt.Errorf("annotation %s or %s not found", OriginalNodeSelectorsAnnotation, OriginalPodSetsInfoAnnotation) + errUnknownPodSetName = errors.New("unknown podSet name") ) // JobReconciler reconciles a GenericJob object @@ -332,17 +335,17 @@ func (r *JobReconciler) equivalentToWorkload(job GenericJob, object client.Objec // startJob will unsuspend the job, and also inject the node affinity. 
func (r *JobReconciler) startJob(ctx context.Context, job GenericJob, object client.Object, wl *kueue.Workload) error { - //get the original selectors and store them in the job object - originalSelectors := r.getNodeSelectorsFromPodSets(wl) - if err := setNodeSelectorsInAnnotation(object, originalSelectors); err != nil { - return fmt.Errorf("startJob, record original node selectors: %w", err) + //get the original podSetsInfo and store them in the job object + originalPodSetsInfo := r.getPodSetsInfoFromSpec(wl) + if err := setNodeSelectorsInAnnotation(object, originalPodSetsInfo); err != nil { + return fmt.Errorf("startJob, record original podSetsInfo: %w", err) } - nodeSelectors, err := r.getNodeSelectorsFromAdmission(ctx, wl) + info, err := r.getPodSetsInfoFromAdmission(ctx, wl) if err != nil { return err } - job.RunWithNodeAffinity(nodeSelectors) + job.RunWithPodSetsInfo(info) if err := r.client.Update(ctx, object); err != nil { return err @@ -372,12 +375,12 @@ func (r *JobReconciler) stopJob(ctx context.Context, job GenericJob, object clie } } - log.V(3).Info("restore node selectors from annotation") - selectors, err := getNodeSelectorsFromObjectAnnotation(object) + log.V(3).Info("restore podSetsInfo from annotation") + info, err := getPodSetsInfoFromObjectAnnotation(object, wl.Spec.PodSets) if err != nil { - log.V(3).Error(err, "Unable to get original node selectors") + log.V(3).Error(err, "Unable to get original podSetsInfo") } else { - job.RestoreNodeAffinity(selectors) + job.RestorePodSetsInfo(info) return r.client.Update(ctx, object) } @@ -412,24 +415,26 @@ func (r *JobReconciler) constructWorkload(ctx context.Context, job GenericJob, o return wl, nil } -type PodSetNodeSelector struct { +type PodSetInfo struct { Name string `json:"name"` NodeSelector map[string]string `json:"nodeSelector"` + Count int32 `json:"count"` } -// getNodeSelectorsFromAdmission will extract node selectors from admitted workloads. 
-func (r *JobReconciler) getNodeSelectorsFromAdmission(ctx context.Context, w *kueue.Workload) ([]PodSetNodeSelector, error) { +// getPodSetsInfoFromAdmission will extract podSetsInfo and podSets count from admitted workloads. +func (r *JobReconciler) getPodSetsInfoFromAdmission(ctx context.Context, w *kueue.Workload) ([]PodSetInfo, error) { if len(w.Status.Admission.PodSetAssignments) == 0 { return nil, nil } - nodeSelectors := make([]PodSetNodeSelector, len(w.Status.Admission.PodSetAssignments)) + nodeSelectors := make([]PodSetInfo, len(w.Status.Admission.PodSetAssignments)) for i, podSetFlavor := range w.Status.Admission.PodSetAssignments { processedFlvs := sets.NewString() - nodeSelector := PodSetNodeSelector{ + nodeSelector := PodSetInfo{ Name: podSetFlavor.Name, NodeSelector: make(map[string]string), + Count: podSetFlavor.Count, } for _, flvRef := range podSetFlavor.Flavors { flvName := string(flvRef) @@ -452,18 +457,19 @@ func (r *JobReconciler) getNodeSelectorsFromAdmission(ctx context.Context, w *ku return nodeSelectors, nil } -// getNodeSelectorsFromPodSets will extract node selectors from a workload's podSets. -func (r *JobReconciler) getNodeSelectorsFromPodSets(w *kueue.Workload) []PodSetNodeSelector { +// getPodSetsInfoFromSpec will extract podSetsInfo and podSet's counts from a workload's spec. 
+func (r *JobReconciler) getPodSetsInfoFromSpec(w *kueue.Workload) []PodSetInfo { podSets := w.Spec.PodSets if len(podSets) == 0 { return nil } - ret := make([]PodSetNodeSelector, len(podSets)) + ret := make([]PodSetInfo, len(podSets)) for psi := range podSets { ps := &podSets[psi] - ret[psi] = PodSetNodeSelector{ + ret[psi] = PodSetInfo{ Name: ps.Name, NodeSelector: cloneNodeSelector(ps.Template.Spec.NodeSelector), + Count: ps.Count, } } return ret @@ -519,34 +525,51 @@ func cloneNodeSelector(src map[string]string) map[string]string { return ret } -// getNodeSelectorsFromObjectAnnotation tries to retrieve a node selectors slice from the +// getPodSetsInfoFromObjectAnnotation tries to retrieve a podSetsInfo slice from the // object's annotations fails if it's not found or is unable to unmarshal -func getNodeSelectorsFromObjectAnnotation(obj client.Object) ([]PodSetNodeSelector, error) { - str, found := obj.GetAnnotations()[OriginalNodeSelectorsAnnotation] +func getPodSetsInfoFromObjectAnnotation(obj client.Object, spec []kueue.PodSet) ([]PodSetInfo, error) { + hasCounts := true + str, found := obj.GetAnnotations()[OriginalPodSetsInfoAnnotation] if !found { - return nil, errNodeSelectorsNotFound + hasCounts = false + str, found = obj.GetAnnotations()[OriginalNodeSelectorsAnnotation] + if !found { + return nil, errPodSetsInfoNotFound + } } // unmarshal - ret := []PodSetNodeSelector{} + ret := []PodSetInfo{} if err := json.Unmarshal([]byte(str), &ret); err != nil { return nil, err } + + if !hasCounts { + psMap := utilslice.ToRefMap(spec, func(ps *kueue.PodSet) string { return ps.Name }) + for i := range ret { + info := &ret[i] + ps, found := psMap[info.Name] + if !found { + return nil, fmt.Errorf("%w: %s", errUnknownPodSetName, info.Name) + } + info.Count = ps.Count + } + } return ret, nil } -// setNodeSelectorsInAnnotation - sets an annotation containing the provided node selectors into +// setNodeSelectorsInAnnotation - sets an annotation containing the provided 
podSetsInfo into // a job object, even if very unlikely it could return an error related to json.marshaling -func setNodeSelectorsInAnnotation(obj client.Object, nodeSelectors []PodSetNodeSelector) error { - nodeSelectorsBytes, err := json.Marshal(nodeSelectors) +func setNodeSelectorsInAnnotation(obj client.Object, info []PodSetInfo) error { + nodeSelectorsBytes, err := json.Marshal(info) if err != nil { return err } annotations := obj.GetAnnotations() if annotations == nil { - annotations = map[string]string{OriginalNodeSelectorsAnnotation: string(nodeSelectorsBytes)} + annotations = map[string]string{OriginalPodSetsInfoAnnotation: string(nodeSelectorsBytes)} } else { - annotations[OriginalNodeSelectorsAnnotation] = string(nodeSelectorsBytes) + annotations[OriginalPodSetsInfoAnnotation] = string(nodeSelectorsBytes) } obj.SetAnnotations(annotations) return nil diff --git a/pkg/controller/jobframework/validation.go b/pkg/controller/jobframework/validation.go index 4a745777c9..74b763f31c 100644 --- a/pkg/controller/jobframework/validation.go +++ b/pkg/controller/jobframework/validation.go @@ -29,6 +29,7 @@ var ( queueNameLabelPath = labelsPath.Key(QueueLabel) originalNodeSelectorsWorkloadKeyPath = annotationsPath.Key(OriginalNodeSelectorsAnnotation) + originalPodSetsInfosWorkloadKeyPath = annotationsPath.Key(OriginalPodSetsInfoAnnotation) ) func ValidateCreateForQueueName(job GenericJob) field.ErrorList { @@ -83,10 +84,26 @@ func ValidateUpdateForOriginalNodeSelectors(oldJob, newJob GenericJob) field.Err allErrs = append(allErrs, field.Forbidden(originalNodeSelectorsWorkloadKeyPath, "this annotation is immutable while the job is not changing its suspended state")) } } else if av, found := newJob.Object().GetAnnotations()[OriginalNodeSelectorsAnnotation]; found { - out := []PodSetNodeSelector{} + out := []PodSetInfo{} if err := json.Unmarshal([]byte(av), &out); err != nil { allErrs = append(allErrs, field.Invalid(originalNodeSelectorsWorkloadKeyPath, av, err.Error())) 
} } return allErrs } + +func ValidateUpdateForOriginalPodSetsInfo(oldJob, newJob GenericJob) field.ErrorList { + var allErrs field.ErrorList + if oldJob.IsSuspended() == newJob.IsSuspended() { + if errList := apivalidation.ValidateImmutableField(oldJob.Object().GetAnnotations()[OriginalPodSetsInfoAnnotation], + newJob.Object().GetAnnotations()[OriginalPodSetsInfoAnnotation], originalPodSetsInfosWorkloadKeyPath); len(errList) > 0 { + allErrs = append(allErrs, field.Forbidden(originalPodSetsInfosWorkloadKeyPath, "this annotation is immutable while the job is not changing its suspended state")) + } + } else if av, found := newJob.Object().GetAnnotations()[OriginalPodSetsInfoAnnotation]; found { + out := []PodSetInfo{} + if err := json.Unmarshal([]byte(av), &out); err != nil { + allErrs = append(allErrs, field.Invalid(originalPodSetsInfosWorkloadKeyPath, av, err.Error())) + } + } + return allErrs +} diff --git a/pkg/controller/jobs/job/job_controller.go b/pkg/controller/jobs/job/job_controller.go index c134317c4e..ceaa46de87 100644 --- a/pkg/controller/jobs/job/job_controller.go +++ b/pkg/controller/jobs/job/job_controller.go @@ -180,7 +180,7 @@ func (j *Job) PodSets() []kueue.PodSet { } } -func (j *Job) RunWithNodeAffinity(nodeSelectors []jobframework.PodSetNodeSelector) { +func (j *Job) RunWithPodSetsInfo(nodeSelectors []jobframework.PodSetInfo) { j.Spec.Suspend = pointer.Bool(false) if len(nodeSelectors) == 0 { return @@ -195,7 +195,7 @@ func (j *Job) RunWithNodeAffinity(nodeSelectors []jobframework.PodSetNodeSelecto } } -func (j *Job) RestoreNodeAffinity(nodeSelectors []jobframework.PodSetNodeSelector) { +func (j *Job) RestorePodSetsInfo(nodeSelectors []jobframework.PodSetInfo) { if len(nodeSelectors) == 0 || equality.Semantic.DeepEqual(j.Spec.Template.Spec.NodeSelector, nodeSelectors[0].NodeSelector) { return } diff --git a/pkg/controller/jobs/job/job_webhook.go b/pkg/controller/jobs/job/job_webhook.go index f3d4d30c68..5d63896e38 100644 --- 
a/pkg/controller/jobs/job/job_webhook.go +++ b/pkg/controller/jobs/job/job_webhook.go @@ -106,6 +106,7 @@ func (w *JobWebhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime. func validateUpdate(oldJob, newJob jobframework.GenericJob) field.ErrorList { allErrs := validateCreate(newJob) allErrs = append(allErrs, jobframework.ValidateUpdateForParentWorkload(oldJob, newJob)...) + allErrs = append(allErrs, jobframework.ValidateUpdateForOriginalPodSetsInfo(oldJob, newJob)...) allErrs = append(allErrs, jobframework.ValidateUpdateForOriginalNodeSelectors(oldJob, newJob)...) allErrs = append(allErrs, jobframework.ValidateUpdateForQueueName(oldJob, newJob)...) return allErrs diff --git a/pkg/controller/jobs/mpijob/mpijob_controller.go b/pkg/controller/jobs/mpijob/mpijob_controller.go index 856d6801a0..4012304de4 100644 --- a/pkg/controller/jobs/mpijob/mpijob_controller.go +++ b/pkg/controller/jobs/mpijob/mpijob_controller.go @@ -121,7 +121,7 @@ func (j *MPIJob) PodSets() []kueue.PodSet { return podSets } -func (j *MPIJob) RunWithNodeAffinity(nodeSelectors []jobframework.PodSetNodeSelector) { +func (j *MPIJob) RunWithPodSetsInfo(nodeSelectors []jobframework.PodSetInfo) { j.Spec.RunPolicy.Suspend = pointer.Bool(false) if len(nodeSelectors) == 0 { return @@ -144,7 +144,7 @@ func (j *MPIJob) RunWithNodeAffinity(nodeSelectors []jobframework.PodSetNodeSele } } -func (j *MPIJob) RestoreNodeAffinity(nodeSelectors []jobframework.PodSetNodeSelector) { +func (j *MPIJob) RestorePodSetsInfo(nodeSelectors []jobframework.PodSetInfo) { orderedReplicaTypes := orderedReplicaTypes(&j.Spec) for index, nodeSelector := range nodeSelectors { replicaType := orderedReplicaTypes[index] diff --git a/pkg/controller/jobs/mpijob/mpijob_webhook.go b/pkg/controller/jobs/mpijob/mpijob_webhook.go index b66227157b..b14851c71a 100644 --- a/pkg/controller/jobs/mpijob/mpijob_webhook.go +++ b/pkg/controller/jobs/mpijob/mpijob_webhook.go @@ -88,6 +88,7 @@ func (w *MPIJobWebhook) ValidateUpdate(ctx 
context.Context, oldObj, newObj runti log := ctrl.LoggerFrom(ctx).WithName("job-webhook") log.Info("Validating update", "job", klog.KObj(newJob)) allErrs := jobframework.ValidateUpdateForQueueName(oldGenJob, newGenJob) + allErrs = append(allErrs, jobframework.ValidateUpdateForOriginalPodSetsInfo(oldGenJob, newGenJob)...) allErrs = append(allErrs, jobframework.ValidateUpdateForOriginalNodeSelectors(oldGenJob, newGenJob)...) return allErrs.ToAggregate() }