2 changes: 1 addition & 1 deletion conformance/testing-epp/plugins/filter/filter_test.go
@@ -117,7 +117,7 @@ func TestFilter(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- got := test.filter.Filter(context.Background(), test.req, types.NewCycleState(), test.input)
+ got := test.filter.Filter(context.Background(), types.NewCycleState(), test.req, test.input)

if diff := cmp.Diff(test.output, got); diff != "" {
t.Errorf("Unexpected output (-want +got): %v", diff)
@@ -50,7 +50,7 @@ func (f *HeaderBasedTestingFilter) Type() string {
}

// Filter selects pods that match the IP addresses specified in the request header.
- func (f *HeaderBasedTestingFilter) Filter(_ context.Context, request *types.LLMRequest, _ *types.CycleState, pods []types.Pod) []types.Pod {
+ func (f *HeaderBasedTestingFilter) Filter(_ context.Context, _ *types.CycleState, request *types.LLMRequest, pods []types.Pod) []types.Pod {
headerValue, ok := request.Headers[headerTestEppEndPointSelectionKey]
if !ok || headerValue == "" {
return []types.Pod{}
13 changes: 6 additions & 7 deletions pkg/epp/common/config/configloader_test.go
@@ -468,7 +468,7 @@ func (f *test1) Type() string {
}

// Filter filters out pods that don't meet the filter criteria.
- func (f *test1) Filter(ctx context.Context, request *types.LLMRequest, cycleState *types.CycleState, pods []types.Pod) []types.Pod {
+ func (f *test1) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod {
return pods
}

@@ -482,12 +482,11 @@ func (f *test2) Type() string {
return test2Type
}

- func (m *test2) Score(ctx context.Context, request *types.LLMRequest, cycleState *types.CycleState, pods []types.Pod) map[types.Pod]float64 {
+ func (m *test2) Score(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, _ []types.Pod) map[types.Pod]float64 {
return map[types.Pod]float64{}
}

- func (m *test2) PostCycle(ctx context.Context, cycleState *types.CycleState, res *types.ProfileRunResult) {
- }
+ func (m *test2) PostCycle(_ context.Context, _ *types.CycleState, _ *types.ProfileRunResult) {}

// compile-time type validation
var _ framework.Picker = &testPicker{}
@@ -498,7 +497,7 @@ func (p *testPicker) Type() string {
return testPickerType
}

- func (p *testPicker) Pick(ctx context.Context, cycleState *types.CycleState, scoredPods []*types.ScoredPod) *types.ProfileRunResult {
+ func (p *testPicker) Pick(_ context.Context, _ *types.CycleState, _ []*types.ScoredPod) *types.ProfileRunResult {
return nil
}

@@ -511,11 +510,11 @@ func (p *testProfileHandler) Type() string {
return testProfileHandlerType
}

- func (p *testProfileHandler) Pick(ctx context.Context, request *types.LLMRequest, profiles map[string]*framework.SchedulerProfile, executionResults map[string]*types.ProfileRunResult) map[string]*framework.SchedulerProfile {
+ func (p *testProfileHandler) Pick(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, _ map[string]*framework.SchedulerProfile, _ map[string]*types.ProfileRunResult) map[string]*framework.SchedulerProfile {
return nil
}

- func (p *testProfileHandler) ProcessResults(ctx context.Context, request *types.LLMRequest, profileResults map[string]*types.ProfileRunResult) (*types.SchedulingResult, error) {
+ func (p *testProfileHandler) ProcessResults(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, _ map[string]*types.ProfileRunResult) (*types.SchedulingResult, error) {
return nil, nil
}

10 changes: 6 additions & 4 deletions pkg/epp/scheduling/framework/plugins.go
@@ -38,26 +38,28 @@ type ProfileHandler interface {
plugins.Plugin
// Pick selects the SchedulingProfiles to run from a list of candidate profiles, while taking into consideration the request properties
// and the previously executed SchedulerProfile cycles along with their results.
- Pick(ctx context.Context, request *types.LLMRequest, profiles map[string]*SchedulerProfile, profileResults map[string]*types.ProfileRunResult) map[string]*SchedulerProfile
+ Pick(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, profiles map[string]*SchedulerProfile,
+ profileResults map[string]*types.ProfileRunResult) map[string]*SchedulerProfile

// ProcessResults handles the outcome of the profile runs after all profiles ran.
// It may aggregate results, log test profile outputs, or apply custom logic. It specifies in the SchedulingResult the
// key of the primary profile that should be used to get the request's selected destination.
// When a profile run fails, its result in the profileResults map is nil.
- ProcessResults(ctx context.Context, request *types.LLMRequest, profileResults map[string]*types.ProfileRunResult) (*types.SchedulingResult, error)
+ ProcessResults(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest,
+ profileResults map[string]*types.ProfileRunResult) (*types.SchedulingResult, error)
}

// Filter defines the interface for filtering a list of pods based on context.
type Filter interface {
plugins.Plugin
- Filter(ctx context.Context, request *types.LLMRequest, cycleState *types.CycleState, pods []types.Pod) []types.Pod
+ Filter(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) []types.Pod
}

// Scorer defines the interface for scoring a list of pods based on context.
// Scorers must score pods with a value within the range of [0,1] where 1 is the highest score.
type Scorer interface {
plugins.Plugin
- Score(ctx context.Context, request *types.LLMRequest, cycleState *types.CycleState, pods []types.Pod) map[types.Pod]float64
+ Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64
}

// Picker picks the final pod(s) to send the request to.
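
The interface change above is purely positional for plugin authors: CycleState now comes directly after the context and ahead of the LLMRequest, and both ProfileHandler methods take a CycleState for the first time. The sketch below is illustrative only, showing a minimal out-of-tree Filter and Scorer written against the new signatures; the package name and import paths are assumptions and may not match the real module layout.

package example

import (
	"context"

	// Assumed import paths; adjust to the actual module layout.
	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

// Compile-time checks against the updated interfaces.
var (
	_ framework.Filter = &passThroughFilter{}
	_ framework.Scorer = &constantScorer{}
)

// passThroughFilter ignores the cycle state and the request and keeps every candidate pod.
type passThroughFilter struct{}

func (f *passThroughFilter) Type() string { return "pass-through-filter" }

// Filter uses the new argument order: CycleState precedes the LLMRequest.
func (f *passThroughFilter) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod {
	return pods
}

// constantScorer gives every pod the same score; real scorers must stay within [0,1].
type constantScorer struct{}

func (s *constantScorer) Type() string { return "constant-scorer" }

func (s *constantScorer) Score(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
	scores := make(map[types.Pod]float64, len(pods))
	for _, pod := range pods {
		scores[pod] = 1.0 // the highest allowed score
	}
	return scores
}
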
@@ -56,9 +56,9 @@ func (f *DecisionTreeFilter) Type() string {
}

// Filter filters out pods that don't meet the filter criteria.
- func (f *DecisionTreeFilter) Filter(ctx context.Context, request *types.LLMRequest, cycleState *types.CycleState, pods []types.Pod) []types.Pod {
+ func (f *DecisionTreeFilter) Filter(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) []types.Pod {
loggerTrace := log.FromContext(ctx).V(logutil.TRACE)
- filteredPod := f.Current.Filter(ctx, request, cycleState, pods)
+ filteredPod := f.Current.Filter(ctx, cycleState, request, pods)

next := f.NextOnSuccessOrFailure
if len(filteredPod) > 0 {
@@ -71,7 +71,7 @@ func (f *DecisionTreeFilter) Filter(ctx context.Context, request *types.LLMReque
}
loggerTrace.Info("Filter succeeded", "filter", f.Type(), "next", next.Type(), "filteredPodCount", len(filteredPod))
// On success, pass the filtered result to the next filter.
- return next.Filter(ctx, request, cycleState, filteredPod)
+ return next.Filter(ctx, cycleState, request, filteredPod)
} else {
if f.NextOnFailure == nil && f.NextOnSuccessOrFailure == nil {
// No succeeding filters to run, return.
@@ -82,6 +82,6 @@ func (f *DecisionTreeFilter) Filter(ctx context.Context, request *types.LLMReque
}
loggerTrace.Info("Filter failed", "filter", f.Type(), "next", next.Type())
// On failure, pass the initial set of pods to the next filter.
- return next.Filter(ctx, request, cycleState, pods)
+ return next.Filter(ctx, cycleState, request, pods)
}
}
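
Because the decision-tree filter simply forwards the reordered arguments to whichever child runs next, composing filters changes only at the call site. A small sketch, continuing the hypothetical example package above and additionally assuming the decision-tree type imports from the plugins/filter package; the DecisionTreeFilter field names come from the code in this hunk, everything else is illustrative.

// filterWithFallback runs primary first. Per the decision-tree logic above, the
// fallback then refines primary's output on success, or re-filters the original
// candidates on failure.
func filterWithFallback(ctx context.Context, primary, fallback framework.Filter,
	state *types.CycleState, req *types.LLMRequest, pods []types.Pod) []types.Pod {
	tree := &filter.DecisionTreeFilter{
		Current:                primary,
		NextOnSuccessOrFailure: &filter.DecisionTreeFilter{Current: fallback},
	}
	// CycleState is now the second argument, ahead of the request.
	return tree.Filter(ctx, state, req, pods)
}
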
6 changes: 3 additions & 3 deletions pkg/epp/scheduling/framework/plugins/filter/filter_test.go
@@ -39,7 +39,7 @@ func (f *filterAll) Type() string {
return "filter-all"
}

- func (f *filterAll) Filter(_ context.Context, _ *types.LLMRequest, _ *types.CycleState, pods []types.Pod) []types.Pod {
+ func (f *filterAll) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod {
return []types.Pod{}
}

@@ -138,7 +138,7 @@ func TestFilter(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- got := test.filter.Filter(context.Background(), test.req, types.NewCycleState(), test.input)
+ got := test.filter.Filter(context.Background(), types.NewCycleState(), test.req, test.input)

if diff := cmp.Diff(test.output, got); diff != "" {
t.Errorf("Unexpected output (-want +got): %v", diff)
@@ -206,7 +206,7 @@ func TestLoRASoftAffinityDistribution(t *testing.T) {
LoraAffinityFilter := NewLoraAffinityFilter(config.Conf.LoraAffinityThreshold)

for range numIterations {
- result := LoraAffinityFilter.Filter(context.Background(), req, types.NewCycleState(), pods)
+ result := LoraAffinityFilter.Filter(context.Background(), types.NewCycleState(), req, pods)

// Check which type of pod was returned
if len(result) != 1 {
@@ -56,7 +56,7 @@ func (f *LeastKVCacheFilter) Type() string {
}

// Filter filters out pods that don't meet the filter criteria.
- func (f *LeastKVCacheFilter) Filter(_ context.Context, _ *types.LLMRequest, _ *types.CycleState, pods []types.Pod) []types.Pod {
+ func (f *LeastKVCacheFilter) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod {
filteredPods := []types.Pod{}

min := math.MaxFloat64
@@ -56,7 +56,7 @@ func (f *LeastQueueFilter) Type() string {
}

// Filter filters out pods that don't meet the filter criteria.
- func (f *LeastQueueFilter) Filter(_ context.Context, _ *types.LLMRequest, _ *types.CycleState, pods []types.Pod) []types.Pod {
+ func (f *LeastQueueFilter) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod {
filteredPods := []types.Pod{}

min := math.MaxInt
@@ -73,7 +73,7 @@ func (f *LoraAffinityFilter) Type() string {
}

// Filter filters out pods that don't meet the filter criteria.
- func (f *LoraAffinityFilter) Filter(_ context.Context, request *types.LLMRequest, _ *types.CycleState, pods []types.Pod) []types.Pod {
+ func (f *LoraAffinityFilter) Filter(_ context.Context, _ *types.CycleState, request *types.LLMRequest, pods []types.Pod) []types.Pod {
// Pre-allocate slices with estimated capacity
filtered_affinity := make([]types.Pod, 0, len(pods))
filtered_available := make([]types.Pod, 0, len(pods))
@@ -67,7 +67,7 @@ func (f *LowQueueFilter) Type() string {
}

// Filter filters out pods that don't meet the filter criteria.
- func (f *LowQueueFilter) Filter(_ context.Context, _ *types.LLMRequest, _ *types.CycleState, pods []types.Pod) []types.Pod {
+ func (f *LowQueueFilter) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod {
filteredPods := []types.Pod{}

for _, pod := range pods {
@@ -158,7 +158,7 @@ func (m *Plugin) Type() string {
}

// Score returns the scoring result for the given list of pods based on context.
- func (m *Plugin) Score(ctx context.Context, request *types.LLMRequest, cycleState *types.CycleState, pods []types.Pod) map[types.Pod]float64 {
+ func (m *Plugin) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
loggerTrace := log.FromContext(ctx).V(logutil.TRACE)
// Pre-score step: hash the prompt and find the longest prefix match.
hashes := hashPrompt(ctx, request, m.HashBlockSize, m.MaxPrefixBlocksToMatch)
12 changes: 6 additions & 6 deletions pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go
@@ -49,7 +49,7 @@ func TestPrefixPlugin(t *testing.T) {
Prompt: "aaaaaa",
}
cycleState1 := types.NewCycleState()
- scores := plugin.Score(context.Background(), req1, cycleState1, pods)
+ scores := plugin.Score(context.Background(), cycleState1, req1, pods)
state, err := plugin.getPrefixState(cycleState1)
assert.NoError(t, err)
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
@@ -70,7 +70,7 @@ func TestPrefixPlugin(t *testing.T) {
Prompt: "bbbbbb",
}
cycleState2 := types.NewCycleState()
- scores = plugin.Score(context.Background(), req2, cycleState2, pods)
+ scores = plugin.Score(context.Background(), cycleState2, req2, pods)
state, err = plugin.getPrefixState(cycleState2)
assert.NoError(t, err)
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
@@ -90,7 +90,7 @@ func TestPrefixPlugin(t *testing.T) {
Prompt: "aaaabbbb",
}
cycleState3 := types.NewCycleState()
- scores = plugin.Score(context.Background(), req3, cycleState3, pods)
+ scores = plugin.Score(context.Background(), cycleState3, req3, pods)
state, err = plugin.getPrefixState(cycleState3)
assert.NoError(t, err)
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
@@ -109,7 +109,7 @@ func TestPrefixPlugin(t *testing.T) {
Prompt: "aaaabbbb",
}
cycleState4 := types.NewCycleState()
- scores = plugin.Score(context.Background(), req4, cycleState4, pods)
+ scores = plugin.Score(context.Background(), cycleState4, req4, pods)
state, err = plugin.getPrefixState(cycleState4)
assert.NoError(t, err)
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
@@ -128,7 +128,7 @@ func TestPrefixPlugin(t *testing.T) {
Prompt: "aaaabbbbcccc",
}
cycleState5 := types.NewCycleState()
- scores = plugin.Score(context.Background(), req5, cycleState5, pods)
+ scores = plugin.Score(context.Background(), cycleState5, req5, pods)
state, err = plugin.getPrefixState(cycleState5)
assert.NoError(t, err)
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
@@ -179,7 +179,7 @@ func BenchmarkPrefixPluginStress(b *testing.B) {

// First cycle: simulate scheduling and insert prefix info into the cache
cycleState := types.NewCycleState()
- plugin.Score(context.Background(), req, cycleState, pods)
+ plugin.Score(context.Background(), cycleState, req, pods)
plugin.PostCycle(context.Background(), cycleState, &types.ProfileRunResult{TargetPod: pod})

// Second cycle: validate internal state
@@ -54,7 +54,7 @@ func (h *SingleProfileHandler) Type() string {

// Pick selects the SchedulingProfiles to run from the list of candidate profiles, while taking into consideration the request properties and the
// previously executed cycles along with their results.
- func (h *SingleProfileHandler) Pick(_ context.Context, request *types.LLMRequest, profiles map[string]*framework.SchedulerProfile,
+ func (h *SingleProfileHandler) Pick(_ context.Context, _ *types.CycleState, request *types.LLMRequest, profiles map[string]*framework.SchedulerProfile,
profileResults map[string]*types.ProfileRunResult) map[string]*framework.SchedulerProfile {
if len(profiles) == len(profileResults) { // all profiles have been executed already in previous call
return map[string]*framework.SchedulerProfile{}
@@ -67,7 +67,7 @@ func (h *SingleProfileHandler) Pick(_ context.Context, request *types.LLMRequest
// It may aggregate results, log test profile outputs, or apply custom logic. It specifies in the SchedulingResult the
// key of the primary profile that should be used to get the request's selected destination.
// When a profile run fails, its result in the profileResults map is nil.
- func (h *SingleProfileHandler) ProcessResults(_ context.Context, _ *types.LLMRequest,
+ func (h *SingleProfileHandler) ProcessResults(_ context.Context, _ *types.CycleState, _ *types.LLMRequest,
profileResults map[string]*types.ProfileRunResult) (*types.SchedulingResult, error) {
if len(profileResults) != 1 {
return nil, errors.New("single profile handler is intended to be used with a single profile, failed to process multiple profiles")
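
ProfileHandler implementations follow the same pattern: CycleState is now the second parameter of both Pick and ProcessResults. Below is a hedged sketch of an alternative handler, continuing the hypothetical example package and imports from the earlier sketches; it runs the candidate profiles one scheduling iteration at a time and leaves the SchedulingResult construction as a stub, since that type's fields are not shown in this diff.

// Compile-time check against the updated ProfileHandler interface.
var _ framework.ProfileHandler = &sequentialProfileHandler{}

// sequentialProfileHandler picks one not-yet-executed profile per scheduling iteration.
type sequentialProfileHandler struct{}

func (h *sequentialProfileHandler) Type() string { return "sequential-profile-handler" }

// Pick accepts the new CycleState argument (unused here) and returns a single
// profile that has no result yet; returning an empty map ends the scheduler's loop.
func (h *sequentialProfileHandler) Pick(_ context.Context, _ *types.CycleState, _ *types.LLMRequest,
	profiles map[string]*framework.SchedulerProfile,
	profileResults map[string]*types.ProfileRunResult) map[string]*framework.SchedulerProfile {
	for name, profile := range profiles {
		if _, ran := profileResults[name]; !ran {
			return map[string]*framework.SchedulerProfile{name: profile}
		}
	}
	return map[string]*framework.SchedulerProfile{}
}

// ProcessResults now also receives the CycleState. A real handler would build a
// types.SchedulingResult here; this stub returns nil because that type's fields
// are not part of this diff.
func (h *sequentialProfileHandler) ProcessResults(_ context.Context, _ *types.CycleState, _ *types.LLMRequest,
	_ map[string]*types.ProfileRunResult) (*types.SchedulingResult, error) {
	return nil, nil
}
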
2 changes: 1 addition & 1 deletion pkg/epp/scheduling/framework/plugins/scorer/kvcache.go
@@ -52,7 +52,7 @@ func (s *KVCacheScorer) Type() string {
}

// Score returns the scoring result for the given list of pods based on context.
- func (s *KVCacheScorer) Score(_ context.Context, _ *types.LLMRequest, _ *types.CycleState, pods []types.Pod) map[types.Pod]float64 {
+ func (s *KVCacheScorer) Score(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
scores := make(map[types.Pod]float64, len(pods))
for _, pod := range pods {
scores[pod] = 1 - pod.GetMetrics().KVCacheUsagePercent
@@ -83,7 +83,7 @@ func TestKvCacheScorer(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
scorer := &KVCacheScorer{}
- scores := scorer.Score(context.Background(), &types.LLMRequest{}, types.NewCycleState(), test.pods)
+ scores := scorer.Score(context.Background(), types.NewCycleState(), &types.LLMRequest{}, test.pods)

for i, pod := range test.pods {
expectedScore := test.expectedScoresPod[i]
2 changes: 1 addition & 1 deletion pkg/epp/scheduling/framework/plugins/scorer/queue.go
@@ -54,7 +54,7 @@ func (s *QueueScorer) Type() string {
}

// Score returns the scoring result for the given list of pods based on context.
- func (s *QueueScorer) Score(_ context.Context, _ *types.LLMRequest, _ *types.CycleState, pods []types.Pod) map[types.Pod]float64 {
+ func (s *QueueScorer) Score(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
minQueueSize := math.MaxInt
maxQueueSize := math.MinInt

2 changes: 1 addition & 1 deletion pkg/epp/scheduling/framework/plugins/scorer/queue_test.go
@@ -73,7 +73,7 @@ func TestQueueScorer(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- scores := scorer.Score(context.Background(), &types.LLMRequest{}, types.NewCycleState(), test.pods)
+ scores := scorer.Score(context.Background(), types.NewCycleState(), &types.LLMRequest{}, test.pods)

for i, pod := range test.pods {
expectedScore := test.expectedScoresPod[i]
4 changes: 2 additions & 2 deletions pkg/epp/scheduling/framework/scheduler_profile.go
@@ -129,7 +129,7 @@ func (p *SchedulerProfile) runFilterPlugins(ctx context.Context, request *types.
for _, filter := range p.filters {
loggerDebug.Info("Running filter plugin", "plugin", filter.Type())
before := time.Now()
- filteredPods = filter.Filter(ctx, request, cycleState, filteredPods)
+ filteredPods = filter.Filter(ctx, cycleState, request, filteredPods)
metrics.RecordSchedulerPluginProcessingLatency(FilterPluginType, filter.Type(), time.Since(before))
loggerDebug.Info("Filter plugin result", "plugin", filter.Type(), "pods", filteredPods)
if len(filteredPods) == 0 {
@@ -153,7 +153,7 @@ func (p *SchedulerProfile) runScorerPlugins(ctx context.Context, request *types.
for _, scorer := range p.scorers {
loggerDebug.Info("Running scorer", "scorer", scorer.Type())
before := time.Now()
- scores := scorer.Score(ctx, request, cycleState, pods)
+ scores := scorer.Score(ctx, cycleState, request, pods)
metrics.RecordSchedulerPluginProcessingLatency(ScorerPluginType, scorer.Type(), time.Since(before))
for pod, score := range scores { // weight is relative to the sum of weights
weightedScorePerPod[pod] += score * float64(scorer.Weight())
4 changes: 2 additions & 2 deletions pkg/epp/scheduling/framework/scheduler_profile_test.go
@@ -210,13 +210,13 @@ type testPlugin struct {

func (tp *testPlugin) Type() string { return tp.TypeRes }

- func (tp *testPlugin) Filter(_ context.Context, _ *types.LLMRequest, _ *types.CycleState, pods []types.Pod) []types.Pod {
+ func (tp *testPlugin) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod {
tp.FilterCallCount++
return findPods(pods, tp.FilterRes...)

}

- func (tp *testPlugin) Score(_ context.Context, _ *types.LLMRequest, _ *types.CycleState, pods []types.Pod) map[types.Pod]float64 {
+ func (tp *testPlugin) Score(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
tp.ScoreCallCount++
scoredPods := make(map[types.Pod]float64, len(pods))
for _, pod := range pods {
4 changes: 2 additions & 2 deletions pkg/epp/scheduling/scheduler.go
@@ -114,7 +114,7 @@ func (s *Scheduler) Schedule(ctx context.Context, request *types.LLMRequest) (*t

for { // get the next set of profiles to run iteratively based on the request and the previous execution results
before := time.Now()
- profiles := s.profileHandler.Pick(ctx, request, s.profiles, profileRunResults)
+ profiles := s.profileHandler.Pick(ctx, cycleState, request, s.profiles, profileRunResults)
metrics.RecordSchedulerPluginProcessingLatency(framework.ProfilePickerType, s.profileHandler.Type(), time.Since(before))
if len(profiles) == 0 { // profile picker didn't pick any profile to run
break
@@ -136,7 +136,7 @@ func (s *Scheduler) Schedule(ctx context.Context, request *types.LLMRequest) (*t
}

before := time.Now()
- result, err := s.profileHandler.ProcessResults(ctx, request, profileRunResults)
+ result, err := s.profileHandler.ProcessResults(ctx, cycleState, request, profileRunResults)
metrics.RecordSchedulerPluginProcessingLatency(framework.ProcessProfilesResultsType, s.profileHandler.Type(), time.Since(before))

return result, err
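
For anyone maintaining out-of-tree plugins or direct callers of these interfaces, the migration is mechanical: swap the request and cycle-state arguments, and pass the cycle state to the profile handler methods. A summary sketch, again reusing the hypothetical example package and imports from the earlier snippets.

// migrateCalls contrasts the old and new call sites; the plugin values are
// assumed to exist and only the argument order matters here.
func migrateCalls(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest,
	f framework.Filter, s framework.Scorer, h framework.ProfileHandler,
	pods []types.Pod, profiles map[string]*framework.SchedulerProfile,
	results map[string]*types.ProfileRunResult) {
	// Before: f.Filter(ctx, request, cycleState, pods)
	_ = f.Filter(ctx, cycleState, request, pods)

	// Before: s.Score(ctx, request, cycleState, pods)
	_ = s.Score(ctx, cycleState, request, pods)

	// Before there was no CycleState argument: h.Pick(ctx, request, profiles, results)
	_ = h.Pick(ctx, cycleState, request, profiles, results)

	// Before there was no CycleState argument: h.ProcessResults(ctx, request, results)
	_, _ = h.ProcessResults(ctx, cycleState, request, results)
}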