Skip to content

Commit

Permalink
TAS: Introduce scheduling gate utils (#3234)
Browse files Browse the repository at this point in the history
  • Loading branch information
mimowo authored Oct 15, 2024
1 parent 167380e commit c9710e3
Show file tree
Hide file tree
Showing 4 changed files with 278 additions and 35 deletions.
51 changes: 20 additions & 31 deletions pkg/controller/jobs/pod/pod_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,13 @@ import (
"sigs.k8s.io/kueue/pkg/util/kubeversion"
"sigs.k8s.io/kueue/pkg/util/maps"
"sigs.k8s.io/kueue/pkg/util/parallelize"
utilpod "sigs.k8s.io/kueue/pkg/util/pod"
utilslices "sigs.k8s.io/kueue/pkg/util/slices"
)

const (
SchedulingGateName = "kueue.x-k8s.io/admission"
FrameworkName = "pod"
gateNotFound = -1
ConditionTypeTerminationTarget = "TerminationTarget"
errMsgIncorrectGroupRoleCount = "pod group can't include more than 8 roles"
IsGroupWorkloadAnnotationKey = "kueue.x-k8s.io/is-group-workload"
Expand Down Expand Up @@ -177,23 +177,12 @@ func (p *Pod) Object() client.Object {
return &p.pod
}

// gateIndex returns the index of the Kueue scheduling gate for corev1.Pod.
// If the scheduling gate is not found, returns -1.
func gateIndex(p *corev1.Pod) int {
for i := range p.Spec.SchedulingGates {
if p.Spec.SchedulingGates[i].Name == SchedulingGateName {
return i
}
}
return gateNotFound
}

func isPodTerminated(p *corev1.Pod) bool {
return p.Status.Phase == corev1.PodFailed || p.Status.Phase == corev1.PodSucceeded
}

func podSuspended(p *corev1.Pod) bool {
return isPodTerminated(p) || gateIndex(p) != gateNotFound
return isPodTerminated(p) || isGated(p)
}

func isUnretriablePod(pod corev1.Pod) bool {
Expand Down Expand Up @@ -238,18 +227,6 @@ func (p *Pod) Suspend() {
// Not implemented because this is not called when JobWithCustomStop is implemented.
}

// ungatePod removes the kueue scheduling gate from the pod.
// Returns true if the pod has been ungated and false otherwise.
func ungatePod(pod *corev1.Pod) bool {
idx := gateIndex(pod)
if idx != gateNotFound {
pod.Spec.SchedulingGates = append(pod.Spec.SchedulingGates[:idx], pod.Spec.SchedulingGates[idx+1:]...)
return true
}

return false
}

// Run will inject the node affinity and podSet counts extracting from workload to job and unsuspend it.
func (p *Pod) Run(ctx context.Context, c client.Client, podSetsInfo []podset.PodSetInfo, recorder record.EventRecorder, msg string) error {
log := ctrl.LoggerFrom(ctx)
Expand All @@ -259,12 +236,12 @@ func (p *Pod) Run(ctx context.Context, c client.Client, podSetsInfo []podset.Pod
return fmt.Errorf("%w: expecting 1 pod set got %d", podset.ErrInvalidPodsetInfo, len(podSetsInfo))
}

if gateIndex(&p.pod) == gateNotFound {
if !isGated(&p.pod) {
return nil
}

if err := clientutil.Patch(ctx, c, &p.pod, true, func() (bool, error) {
ungatePod(&p.pod)
ungate(&p.pod)
return true, podset.Merge(&p.pod.ObjectMeta, &p.pod.Spec, podSetsInfo[0])
}); err != nil {
return err
Expand All @@ -280,12 +257,12 @@ func (p *Pod) Run(ctx context.Context, c client.Client, podSetsInfo []podset.Pod
return parallelize.Until(ctx, len(p.list.Items), func(i int) error {
pod := &p.list.Items[i]

if gateIndex(pod) == gateNotFound {
if !isGated(pod) {
return nil
}

if err := clientutil.Patch(ctx, c, pod, true, func() (bool, error) {
ungatePod(pod)
ungate(pod)

roleHash, err := getRoleHash(*pod)
if err != nil {
Expand Down Expand Up @@ -854,8 +831,8 @@ func sortActivePods(activePods []corev1.Pod) {
if iFin != jFin {
return iFin
}
iGated := gateIndex(pi) != gateNotFound
jGated := gateIndex(pj) != gateNotFound
iGated := isGated(pi)
jGated := isGated(pj)
// Prefer to keep pods that aren't gated.
if iGated != jGated {
return !iGated
Expand Down Expand Up @@ -1354,3 +1331,15 @@ func IsPodOwnerManagedByKueue(p *Pod) bool {
func GetWorkloadNameForPod(podName string, podUID types.UID) string {
return jobframework.GetWorkloadNameForOwnerWithGVK(podName, podUID, gvk)
}

func isGated(pod *corev1.Pod) bool {
return utilpod.HasGate(pod, SchedulingGateName)
}

func ungate(pod *corev1.Pod) bool {
return utilpod.Ungate(pod, SchedulingGateName)
}

func gate(pod *corev1.Pod) bool {
return utilpod.Gate(pod, SchedulingGateName)
}
5 changes: 1 addition & 4 deletions pkg/controller/jobs/pod/pod_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,7 @@ func (w *PodWebhook) Default(ctx context.Context, obj runtime.Object) error {
}
pod.pod.Labels[ManagedLabelKey] = ManagedLabelValue

if gateIndex(&pod.pod) == gateNotFound {
log.V(5).Info("Adding gate")
pod.pod.Spec.SchedulingGates = append(pod.pod.Spec.SchedulingGates, corev1.PodSchedulingGate{Name: SchedulingGateName})
}
gate(&pod.pod)

if podGroupName(pod.pod) != "" {
if err := pod.addRoleHash(); err != nil {
Expand Down
58 changes: 58 additions & 0 deletions pkg/util/pod/pod.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
CCopyright The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package pod

import (
"slices"

corev1 "k8s.io/api/core/v1"
)

// HasGate checks if the pod has a scheduling gate with a specified name.
func HasGate(pod *corev1.Pod, gateName string) bool {
return gateIndex(pod, gateName) >= 0
}

// Ungate removes scheduling gate from the Pod if present.
// Returns true if the pod has been updated and false otherwise.
func Ungate(pod *corev1.Pod, gateName string) bool {
if idx := gateIndex(pod, gateName); idx >= 0 {
pod.Spec.SchedulingGates = slices.Delete(pod.Spec.SchedulingGates, idx, idx+1)
return true
}
return false
}

// Gate adds scheduling gate from the Pod if present.
// Returns true if the pod has been updated and false otherwise.
func Gate(pod *corev1.Pod, gateName string) bool {
if !HasGate(pod, gateName) {
pod.Spec.SchedulingGates = append(pod.Spec.SchedulingGates, corev1.PodSchedulingGate{
Name: gateName,
})
return true
}
return false
}

// gateIndex returns the index of the Kueue scheduling gate for corev1.Pod.
// If the scheduling gate is not found, returns -1.
func gateIndex(p *corev1.Pod, gateName string) int {
return slices.IndexFunc(p.Spec.SchedulingGates, func(g corev1.PodSchedulingGate) bool {
return g.Name == gateName
})
}
199 changes: 199 additions & 0 deletions pkg/util/pod/pod_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
/*
CCopyright The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package pod

import (
"testing"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
corev1 "k8s.io/api/core/v1"
)

func TestHasGate(t *testing.T) {
testCases := map[string]struct {
gateName string
pod corev1.Pod
want bool
}{
"scheduling gate present": {
gateName: "example.com/gate",
pod: corev1.Pod{
Spec: corev1.PodSpec{
SchedulingGates: []corev1.PodSchedulingGate{
{
Name: "example.com/gate",
},
},
},
},
want: true,
},
"another gate present": {
gateName: "example.com/gate",
pod: corev1.Pod{
Spec: corev1.PodSpec{
SchedulingGates: []corev1.PodSchedulingGate{
{
Name: "example.com/gate2",
},
},
},
},
want: false,
},
"no scheduling gates": {
pod: corev1.Pod{},
want: false,
},
}

for desc, tc := range testCases {
t.Run(desc, func(t *testing.T) {
got := HasGate(&tc.pod, tc.gateName)
if got != tc.want {
t.Errorf("Unexpected result: want=%v, got=%v", tc.want, got)
}
})
}
}

func TestUngate(t *testing.T) {
testCases := map[string]struct {
gateName string
pod corev1.Pod
wantPod corev1.Pod
want bool
}{
"ungate when scheduling gate present": {
gateName: "example.com/gate",
pod: corev1.Pod{
Spec: corev1.PodSpec{
SchedulingGates: []corev1.PodSchedulingGate{
{
Name: "example.com/gate",
},
},
},
},
wantPod: corev1.Pod{},
want: true,
},
"ungate when scheduling gate missing": {
gateName: "example.com/gate",
pod: corev1.Pod{
Spec: corev1.PodSpec{
SchedulingGates: []corev1.PodSchedulingGate{
{
Name: "example.com/gate2",
},
},
},
},
wantPod: corev1.Pod{
Spec: corev1.PodSpec{
SchedulingGates: []corev1.PodSchedulingGate{
{
Name: "example.com/gate2",
},
},
},
},
want: false,
},
}
for desc, tc := range testCases {
t.Run(desc, func(t *testing.T) {
got := Ungate(&tc.pod, tc.gateName)
if got != tc.want {
t.Errorf("Unexpected result: want=%v, got=%v", tc.want, got)
}
if diff := cmp.Diff(tc.wantPod.Spec.SchedulingGates, tc.pod.Spec.SchedulingGates, cmpopts.EquateEmpty()); diff != "" {
t.Errorf("Unexpected scheduling gates\ndiff=%s", diff)
}
})
}
}

func TestGate(t *testing.T) {
testCases := map[string]struct {
gateName string
pod corev1.Pod
wantPod corev1.Pod
want bool
}{
"gate when scheduling gate present": {
gateName: "example.com/gate",
pod: corev1.Pod{
Spec: corev1.PodSpec{
SchedulingGates: []corev1.PodSchedulingGate{
{
Name: "example.com/gate",
},
},
},
},
wantPod: corev1.Pod{
Spec: corev1.PodSpec{
SchedulingGates: []corev1.PodSchedulingGate{
{
Name: "example.com/gate",
},
},
},
},
want: false,
},
"gate when scheduling gate missing": {
gateName: "example.com/gate",
pod: corev1.Pod{
Spec: corev1.PodSpec{
SchedulingGates: []corev1.PodSchedulingGate{
{
Name: "example.com/gate2",
},
},
},
},
wantPod: corev1.Pod{
Spec: corev1.PodSpec{
SchedulingGates: []corev1.PodSchedulingGate{
{
Name: "example.com/gate2",
},
{
Name: "example.com/gate",
},
},
},
},
want: true,
},
}

for desc, tc := range testCases {
t.Run(desc, func(t *testing.T) {
got := Gate(&tc.pod, tc.gateName)
if got != tc.want {
t.Errorf("Unexpected result: want=%v, got=%v", tc.want, got)
}
if diff := cmp.Diff(tc.wantPod.Spec.SchedulingGates, tc.pod.Spec.SchedulingGates); diff != "" {
t.Errorf("Unexpected scheduling gates\ndiff=%s", diff)
}
})
}
}

0 comments on commit c9710e3

Please sign in to comment.