Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion pkg/epp/scheduling/framework/plugins/picker/max_score_picker.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ import (
"context"
"encoding/json"
"fmt"
"math/rand"
"slices"
"time"

"sigs.k8s.io/controller-runtime/pkg/log"

Expand Down Expand Up @@ -58,13 +60,15 @@ func NewMaxScorePicker(maxNumOfEndpoints int) *MaxScorePicker {
return &MaxScorePicker{
typedName: plugins.TypedName{Type: MaxScorePickerType, Name: MaxScorePickerType},
maxNumOfEndpoints: maxNumOfEndpoints,
randomGenerator: rand.New(rand.NewSource(time.Now().UnixNano())),
}
}

// MaxScorePicker picks pod(s) with the maximum score from the list of candidates.
type MaxScorePicker struct {
typedName plugins.TypedName
maxNumOfEndpoints int // maximum number of endpoints to pick
maxNumOfEndpoints int // maximum number of endpoints to pick
randomGenerator *rand.Rand // randomGenerator for randomly pick endpoint on tie-break
}

// WithName sets the picker's name
Expand All @@ -83,6 +87,11 @@ func (p *MaxScorePicker) Pick(ctx context.Context, cycleState *types.CycleState,
log.FromContext(ctx).V(logutil.DEBUG).Info(fmt.Sprintf("Selecting maximum '%d' pods from %d candidates sorted by max score: %+v", p.maxNumOfEndpoints,
len(scoredPods), scoredPods))

// Shuffle in-place - needed for random tie break when scores are equal
p.randomGenerator.Shuffle(len(scoredPods), func(i, j int) {
scoredPods[i], scoredPods[j] = scoredPods[j], scoredPods[i]
})

slices.SortStableFunc(scoredPods, func(i, j *types.ScoredPod) int { // highest score first
if i.Score > j.Score {
return -1
Expand Down
25 changes: 21 additions & 4 deletions pkg/epp/scheduling/framework/plugins/picker/picker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"testing"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
k8stypes "k8s.io/apimachinery/pkg/types"

"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
Expand All @@ -34,10 +35,11 @@ func TestPickMaxScorePicker(t *testing.T) {
pod3 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}}

tests := []struct {
name string
picker framework.Picker
input []*types.ScoredPod
output []types.Pod
name string
picker framework.Picker
input []*types.ScoredPod
output []types.Pod
tieBreakCandidates int // tie break is random, specify how many candidate with max score
}{
{
name: "Single max score",
Expand All @@ -63,6 +65,7 @@ func TestPickMaxScorePicker(t *testing.T) {
&types.ScoredPod{Pod: pod1, Score: 50},
&types.ScoredPod{Pod: pod2, Score: 50},
},
tieBreakCandidates: 2,
},
{
name: "Multiple results sorted by highest score, more pods than needed",
Expand Down Expand Up @@ -104,6 +107,7 @@ func TestPickMaxScorePicker(t *testing.T) {
&types.ScoredPod{Pod: pod3, Score: 30},
&types.ScoredPod{Pod: pod2, Score: 25},
},
tieBreakCandidates: 2,
},
}

Expand All @@ -112,6 +116,19 @@ func TestPickMaxScorePicker(t *testing.T) {
result := test.picker.Pick(context.Background(), types.NewCycleState(), test.input)
got := result.TargetPods

if test.tieBreakCandidates > 0 {
testMaxScoredPods := test.output[:test.tieBreakCandidates]
gotMaxScoredPods := got[:test.tieBreakCandidates]
diff := cmp.Diff(testMaxScoredPods, gotMaxScoredPods, cmpopts.SortSlices(func(a, b types.Pod) bool {
return a.String() < b.String() // predictable order within the pods with equal scores
}))
if diff != "" {
t.Errorf("Unexpected output (-want +got): %v", diff)
}
test.output = test.output[test.tieBreakCandidates:]
got = got[test.tieBreakCandidates:]
}

if diff := cmp.Diff(test.output, got); diff != "" {
t.Errorf("Unexpected output (-want +got): %v", diff)
}
Expand Down
5 changes: 4 additions & 1 deletion pkg/epp/scheduling/framework/plugins/picker/random_picker.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"encoding/json"
"fmt"
"math/rand"
"time"

"sigs.k8s.io/controller-runtime/pkg/log"

Expand Down Expand Up @@ -57,13 +58,15 @@ func NewRandomPicker(maxNumOfEndpoints int) *RandomPicker {
return &RandomPicker{
typedName: plugins.TypedName{Type: RandomPickerType, Name: RandomPickerType},
maxNumOfEndpoints: maxNumOfEndpoints,
randomGenerator: rand.New(rand.NewSource(time.Now().UnixNano())),
}
}

// RandomPicker picks random pod(s) from the list of candidates.
type RandomPicker struct {
typedName plugins.TypedName
maxNumOfEndpoints int
randomGenerator *rand.Rand // randomGenerator for randomly pick endpoint on tie-break
}

// WithName sets the name of the picker.
Expand All @@ -83,7 +86,7 @@ func (p *RandomPicker) Pick(ctx context.Context, _ *types.CycleState, scoredPods
len(scoredPods), scoredPods))

// Shuffle in-place
rand.Shuffle(len(scoredPods), func(i, j int) {
p.randomGenerator.Shuffle(len(scoredPods), func(i, j int) {
scoredPods[i], scoredPods[j] = scoredPods[j], scoredPods[i]
})

Expand Down