From 64e4c7586667bc56e37390aafba50c958f8fa1ba Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre
Date: Tue, 26 Sep 2023 21:20:51 +0000
Subject: [PATCH] add pod reconciler

---
 main.go                             |  11 ++
 pkg/controllers/pod_controller.go   | 201 ++++++++++++++++++++++++++++
 pkg/util/collections/collections.go |   9 ++
 3 files changed, 221 insertions(+)
 create mode 100644 pkg/controllers/pod_controller.go

diff --git a/main.go b/main.go
index a97e76222..b8bead5ee 100644
--- a/main.go
+++ b/main.go
@@ -130,16 +130,27 @@ func setupControllers(mgr ctrl.Manager, certsReady chan struct{}) {
 	<-certsReady
 	setupLog.Info("certs ready")
 
+	// Set up JobSet controller.
 	jobSetController := controllers.NewJobSetReconciler(mgr.GetClient(), mgr.GetScheme(), mgr.GetEventRecorderFor("jobset"))
 	if err := jobSetController.SetupWithManager(mgr); err != nil {
 		setupLog.Error(err, "unable to create controller", "controller", "JobSet")
 		os.Exit(1)
 	}
+
+	// Set up pod reconciler.
+	podController := controllers.NewPodReconciler(mgr.GetClient(), mgr.GetScheme(), mgr.GetEventRecorderFor("pod"))
+	if err := podController.SetupWithManager(mgr); err != nil {
+		setupLog.Error(err, "unable to create controller", "controller", "Pod")
+		os.Exit(1)
+	}
+
+	// Set up validating/defaulting webhook.
 	if err := (&jobset.JobSet{}).SetupWebhookWithManager(mgr); err != nil {
 		setupLog.Error(err, "unable to create webhook", "webhook", "JobSet")
 		os.Exit(1)
 	}
 
+	// Set up mutating webhook.
 	mutatingWebhook := &jobset.PodAnnotator{Client: mgr.GetClient()}
 	if err := mutatingWebhook.SetupWebhookWithManager(mgr); err != nil {
 		setupLog.Error(err, "unable to create webhook", "webhook", "JobSet")
diff --git a/pkg/controllers/pod_controller.go b/pkg/controllers/pod_controller.go
new file mode 100644
index 000000000..3ac369536
--- /dev/null
+++ b/pkg/controllers/pod_controller.go
@@ -0,0 +1,201 @@
+/*
+Copyright 2023 The Kubernetes Authors.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package controllers
+
+import (
+	"context"
+	"fmt"
+
+	batchv1 "k8s.io/api/batch/v1"
+	corev1 "k8s.io/api/core/v1"
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	"k8s.io/apimachinery/pkg/runtime"
+	"k8s.io/apimachinery/pkg/types"
+	"k8s.io/client-go/tools/record"
+	"k8s.io/klog/v2"
+	ctrl "sigs.k8s.io/controller-runtime"
+	"sigs.k8s.io/controller-runtime/pkg/client"
+	"sigs.k8s.io/jobset/pkg/util/collections"
+
+	jobset "sigs.k8s.io/jobset/api/jobset/v1alpha2"
+)
+
+const (
+	podOwnerKey = ".metadata.controller"
+	nodePoolKey = "cloud.google.com/gke-nodepool"
+)
+
+// PodReconciler watches for leader pods (completion index 0) to be scheduled,
+// then assigns the remaining gated pods of the same job to the leader's node pool.
+type PodReconciler struct {
+	client.Client
+	Scheme *runtime.Scheme
+	Record record.EventRecorder
+}
+
+func NewPodReconciler(client client.Client, scheme *runtime.Scheme, record record.EventRecorder) *PodReconciler {
+	return &PodReconciler{Client: client, Scheme: scheme, Record: record}
+}
+
+// +kubebuilder:rbac:groups="",resources=events,verbs=create;watch;update;patch
+// +kubebuilder:rbac:groups=core,resources=pods,verbs=get;list;watch;update;patch
+// +kubebuilder:rbac:groups=core,resources=nodes,verbs=get;list;watch
+
+// Reconcile is part of the main kubernetes reconciliation loop which aims to
+// move the current state of the cluster closer to the desired state.
+func (r *PodReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) {
+	// Get Pod from apiserver.
+	var pod corev1.Pod
+	if err := r.Get(ctx, req.NamespacedName, &pod); err != nil {
+		// We'll ignore not-found errors, since there is nothing we can do here.
+		return ctrl.Result{}, client.IgnoreNotFound(err)
+	}
+
+	// TODO: remove this logging once we've confirmed the controller works, as it will be too noisy.
+	log := ctrl.LoggerFrom(ctx).WithValues("pod", klog.KObj(&pod))
+	ctx = ctrl.LoggerInto(ctx, log)
+	log.V(2).Info("Reconciling Pod")
+
+	// Check if this is the "leader" pod (completion index 0) and if it has been scheduled.
+	if pod.Annotations[batchv1.JobCompletionIndexAnnotation] == "0" && pod.Spec.NodeName != "" {
+		nodePool, err := r.nodePoolFromNode(ctx, pod.Spec.NodeName)
+		if err != nil {
+			return ctrl.Result{}, err
+		}
+		if nodePool == "" {
+			// The node was deleted or lacks the node pool label; nothing we can do.
+			return ctrl.Result{}, nil
+		}
+		if err := r.assignGatedPodsToNodePool(ctx, &pod, nodePool); err != nil {
+			return ctrl.Result{}, err
+		}
+	}
+	return ctrl.Result{}, nil
+}
+
+// SetupWithManager sets up the controller with the Manager and registers the
+// field index it depends on.
+func (r *PodReconciler) SetupWithManager(mgr ctrl.Manager) error {
+	// Without this index, the List call in assignGatedPodsToNodePool would
+	// fail, since MatchingFields requires a registered index.
+	if err := SetupPodReconcilerIndexes(context.Background(), mgr.GetFieldIndexer()); err != nil {
+		return err
+	}
+	return ctrl.NewControllerManagedBy(mgr).
+		For(&corev1.Pod{}).
+		Complete(r)
+}
+
+// SetupPodReconcilerIndexes indexes pods by the name of their owning job, so
+// that all pods belonging to a given job can be listed efficiently.
+func SetupPodReconcilerIndexes(ctx context.Context, indexer client.FieldIndexer) error {
+	return indexer.IndexField(ctx, &corev1.Pod{}, podOwnerKey, func(obj client.Object) []string {
+		o := obj.(*corev1.Pod)
+		owner := metav1.GetControllerOf(o)
+		if owner == nil {
+			return nil
+		}
+		// Make sure it's a pod owned by a Job.
+		if owner.Kind != "Job" || owner.APIVersion != batchv1.SchemeGroupVersion.String() {
+			return nil
+		}
+		return []string{owner.Name}
+	})
+}
+
+// nodePoolFromNode gets the name of the node pool that the given node belongs to.
+func (r *PodReconciler) nodePoolFromNode(ctx context.Context, nodeName string) (string, error) {
+	log := ctrl.LoggerFrom(ctx)
+
+	var node corev1.Node
+	if err := r.Get(ctx, types.NamespacedName{Name: nodeName}, &node); err != nil {
+		// We'll ignore not-found errors, since there is nothing we can do here.
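+		// Nodes are cluster-scoped, so the lookup key has no namespace; a
+		// not-found error typically means the node was deleted after scheduling.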
+		log.Error(err, fmt.Sprintf("getting node %s", nodeName))
+		return "", client.IgnoreNotFound(err)
+	}
+	nodePool, exists := node.Labels[nodePoolKey]
+	if !exists {
+		log.V(2).Error(nil, fmt.Sprintf("missing node pool label: %s", nodePoolKey))
+	}
+	return nodePool, nil
+}
+
+// assignGatedPodsToNodePool adds a node selector for the given node pool to
+// every gated pod owned by the leader pod's job, then removes the scheduling
+// gate so the scheduler can place them.
+func (r *PodReconciler) assignGatedPodsToNodePool(ctx context.Context, leaderPod *corev1.Pod, nodePool string) error {
+	// Get name of job that owns the leader pod.
+	jobName, err := jobNameFromPod(leaderPod)
+	if err != nil {
+		return err
+	}
+
+	// Get all pods owned by this job.
+	var podList corev1.PodList
+	if err := r.List(ctx, &podList, client.InNamespace(leaderPod.Namespace), client.MatchingFields{podOwnerKey: jobName}); err != nil {
+		return err
+	}
+
+	// For every pod besides the leader (completion index 0), add a node selector
+	// for this node pool, remove the scheduling gate, and persist the update.
+	for i := range podList.Items {
+		pod := &podList.Items[i]
+		if pod.Annotations[batchv1.JobCompletionIndexAnnotation] == "0" {
+			continue
+		}
+		addNodePoolNodeSelector(pod, nodePool)
+		removeSchedulingGate(pod)
+		if err := r.Update(ctx, pod); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// jobNameFromPod returns the name of the job owning the given pod.
+func jobNameFromPod(pod *corev1.Pod) (string, error) {
+	for _, ownerReference := range pod.OwnerReferences {
+		if ownerReference.Kind == "Job" {
+			return ownerReference.Name, nil
+		}
+	}
+	return "", fmt.Errorf("unable to find name of job owning pod %s", pod.Name)
+}
+
+func addNodePoolNodeSelector(pod *corev1.Pod, nodePool string) {
+	// Add node selector for the node pool label. Mutating the node selector of
+	// an existing pod is only permitted while the pod is still gated.
+	if pod.Spec.NodeSelector == nil {
+		pod.Spec.NodeSelector = make(map[string]string)
+	}
+	pod.Spec.NodeSelector[nodePoolKey] = nodePool
+}
+
+func removeSchedulingGate(pod *corev1.Pod) {
+	// Find the exclusive placement scheduling gate, if present.
+	removeIdx := -1
+	for i, gate := range pod.Spec.SchedulingGates {
+		if gate.Name == string(jobset.ExclusiveGate) {
+			removeIdx = i
+			break
+		}
+	}
+	// If the gate has already been removed, there is nothing to do.
+	if removeIdx == -1 {
+		return
+	}
+	pod.Spec.SchedulingGates = collections.Remove(pod.Spec.SchedulingGates, removeIdx)
+}
diff --git a/pkg/util/collections/collections.go b/pkg/util/collections/collections.go
index 8fa2a12a2..ea621617f 100644
--- a/pkg/util/collections/collections.go
+++ b/pkg/util/collections/collections.go
@@ -38,3 +38,12 @@ func Contains[T comparable](slice []T, element T) bool {
 	}
 	return false
 }
+
+// Remove returns the slice with the element at the given index removed,
+// preserving order. Note that it writes into the input slice's backing array.
+func Remove[T any](slice []T, index int) []T {
+	if index < 0 || index >= len(slice) {
+		return slice
+	}
+	return append(slice[:index], slice[index+1:]...)
+}