Skip to content

Commit 9660d65

Browse files
committed
Add subsetting logic for epp
1 parent 4888da5 commit 9660d65

File tree

10 files changed

+337
-23
lines changed

10 files changed

+337
-23
lines changed

cmd/epp/runner/runner.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ import (
4444
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/saturationdetector"
4545
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling"
4646
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
47+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/filter"
4748
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/prefix"
4849
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/picker"
4950
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/profile"
@@ -270,11 +271,13 @@ func (r *Runner) initializeScheduler(datastore datastore.Datastore) (*scheduling
270271
if schedulerV2 {
271272
queueScorerWeight := envutil.GetEnvInt("QUEUE_SCORE_WEIGHT", scorer.DefaultQueueScorerWeight, setupLog)
272273
kvCacheScorerWeight := envutil.GetEnvInt("KV_CACHE_SCORE_WEIGHT", scorer.DefaultKVCacheScorerWeight, setupLog)
274+
endpointSubsetFilter := filter.NewSubsetFilter()
273275

274276
schedulerProfile := framework.NewSchedulerProfile().
275277
WithScorers(framework.NewWeightedScorer(&scorer.QueueScorer{}, queueScorerWeight),
276278
framework.NewWeightedScorer(&scorer.KVCacheScorer{}, kvCacheScorerWeight)).
277-
WithPicker(picker.NewMaxScorePicker())
279+
WithPicker(picker.NewMaxScorePicker()).
280+
WithFilters(endpointSubsetFilter)
278281

279282
if prefixCacheScheduling {
280283
prefixScorerWeight := envutil.GetEnvInt("PREFIX_CACHE_SCORE_WEIGHT", prefix.DefaultScorerWeight, setupLog)

pkg/epp/handlers/server.go

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import (
2828
"github.com/go-logr/logr"
2929
"google.golang.org/grpc/codes"
3030
"google.golang.org/grpc/status"
31+
"google.golang.org/protobuf/types/known/structpb"
3132
"sigs.k8s.io/controller-runtime/pkg/log"
3233
"sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2"
3334
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
@@ -111,8 +112,9 @@ type RequestContext struct {
111112
}
112113

113114
type Request struct {
114-
Headers map[string]string
115-
Body map[string]interface{}
115+
Headers map[string]string
116+
Body map[string]interface{}
117+
FilterMetadata map[string]*structpb.Struct
116118
}
117119
type Response struct {
118120
Headers map[string]string
@@ -141,8 +143,9 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
141143
reqCtx := &RequestContext{
142144
RequestState: RequestReceived,
143145
Request: &Request{
144-
Headers: make(map[string]string),
145-
Body: make(map[string]interface{}),
146+
Headers: make(map[string]string),
147+
Body: make(map[string]interface{}),
148+
FilterMetadata: make(map[string]*structpb.Struct),
146149
},
147150
Response: &Response{
148151
Headers: make(map[string]string),
@@ -185,6 +188,8 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
185188
return status.Errorf(codes.Unknown, "cannot receive stream request: %v", err)
186189
}
187190

191+
reqCtx.Request.FilterMetadata = req.GetMetadataContext().GetFilterMetadata()
192+
188193
switch v := req.Request.(type) {
189194
case *extProcPb.ProcessingRequest_RequestHeaders:
190195
if requestId := requtil.ExtractHeaderValue(v, requtil.RequestIdHeaderKey); len(requestId) > 0 {

pkg/epp/requestcontrol/director.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,11 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
112112

113113
// Prepare LLMRequest (needed for both saturation detection and Scheduler)
114114
reqCtx.SchedulingRequest = &schedulingtypes.LLMRequest{
115-
RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
116-
TargetModel: reqCtx.ResolvedTargetModel,
117-
Prompt: prompt,
118-
Headers: reqCtx.Request.Headers,
115+
RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
116+
TargetModel: reqCtx.ResolvedTargetModel,
117+
Prompt: prompt,
118+
Headers: reqCtx.Request.Headers,
119+
FilterMetadata: reqCtx.Request.FilterMetadata,
119120
}
120121
logger = logger.WithValues(
121122
"model", reqCtx.Model,

pkg/epp/scheduling/framework/plugins/filter/filter_test.go

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222

2323
"github.com/google/go-cmp/cmp"
2424
"github.com/google/uuid"
25+
"google.golang.org/protobuf/types/known/structpb"
2526
k8stypes "k8s.io/apimachinery/pkg/types"
2627
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
2728
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
@@ -247,3 +248,117 @@ func TestLoRASoftAffinityDistribution(t *testing.T) {
247248
actualAvailablePercent, availableLowerBound, availableUpperBound)
248249
}
249250
}
251+
252+
func TestSubsettingFilter(t *testing.T) {
253+
var makeFilterMetadata = func(data interface{}) map[string]*structpb.Struct {
254+
structVal, _ := structpb.NewStruct(map[string]interface{}{
255+
"x-gateway-destination-endpoint-subset": data,
256+
})
257+
258+
return map[string]*structpb.Struct{
259+
"envoy.lb.subset_hint": structVal,
260+
}
261+
}
262+
263+
tests := []struct {
264+
name string
265+
req *types.LLMRequest
266+
filter framework.Filter
267+
input []types.Pod
268+
output []types.Pod
269+
}{
270+
{
271+
name: "SubsetFilter, filter not present — return all pods",
272+
req: &types.LLMRequest{
273+
Headers: map[string]string{},
274+
FilterMetadata: map[string]*structpb.Struct{},
275+
},
276+
filter: &SubsetFilter{},
277+
input: []types.Pod{
278+
&types.PodMetrics{
279+
Pod: &backend.Pod{Address: "10.0.0.1"},
280+
},
281+
&types.PodMetrics{
282+
Pod: &backend.Pod{Address: "10.0.0.2"},
283+
},
284+
},
285+
output: []types.Pod{
286+
&types.PodMetrics{
287+
Pod: &backend.Pod{Address: "10.0.0.1"},
288+
},
289+
&types.PodMetrics{
290+
Pod: &backend.Pod{Address: "10.0.0.2"},
291+
},
292+
},
293+
},
294+
{
295+
name: "SubsetFilter, subset with one matching pod",
296+
req: &types.LLMRequest{
297+
FilterMetadata: makeFilterMetadata([]interface{}{"10.0.0.1"}),
298+
},
299+
filter: &SubsetFilter{},
300+
input: []types.Pod{
301+
&types.PodMetrics{
302+
Pod: &backend.Pod{Address: "10.0.0.1"},
303+
},
304+
&types.PodMetrics{
305+
Pod: &backend.Pod{Address: "10.0.0.2"},
306+
},
307+
},
308+
output: []types.Pod{
309+
&types.PodMetrics{
310+
Pod: &backend.Pod{Address: "10.0.0.1"},
311+
},
312+
},
313+
},
314+
{
315+
name: "SubsetFilter, subset with multiple matching pods",
316+
req: &types.LLMRequest{
317+
FilterMetadata: makeFilterMetadata([]interface{}{"10.0.0.1", "10.0.0.2", "10.0.0.3"}),
318+
},
319+
filter: &SubsetFilter{},
320+
input: []types.Pod{
321+
&types.PodMetrics{
322+
Pod: &backend.Pod{Address: "10.0.0.1"},
323+
},
324+
&types.PodMetrics{
325+
Pod: &backend.Pod{Address: "10.0.0.2"},
326+
},
327+
},
328+
output: []types.Pod{
329+
&types.PodMetrics{
330+
Pod: &backend.Pod{Address: "10.0.0.1"},
331+
},
332+
&types.PodMetrics{
333+
Pod: &backend.Pod{Address: "10.0.0.2"},
334+
},
335+
},
336+
},
337+
{
338+
name: "SubsetFilter, subset with no matching pods",
339+
req: &types.LLMRequest{
340+
FilterMetadata: makeFilterMetadata([]interface{}{"10.0.0.3"}),
341+
},
342+
filter: &SubsetFilter{},
343+
input: []types.Pod{
344+
&types.PodMetrics{
345+
Pod: &backend.Pod{Address: "10.0.0.1"},
346+
},
347+
&types.PodMetrics{
348+
Pod: &backend.Pod{Address: "10.0.0.2"},
349+
},
350+
},
351+
output: []types.Pod{},
352+
},
353+
}
354+
355+
for _, test := range tests {
356+
t.Run(test.name, func(t *testing.T) {
357+
got := test.filter.Filter(context.Background(), test.req, types.NewCycleState(), test.input)
358+
359+
if diff := cmp.Diff(test.output, got); diff != "" {
360+
t.Errorf("Unexpected output (-want +got): %v", diff)
361+
}
362+
})
363+
}
364+
}
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
/*
2+
Copyright 2025 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package filter
18+
19+
import (
20+
"context"
21+
"strings"
22+
23+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
24+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
25+
)
26+
27+
const (
28+
subsetHintNamespace = "envoy.lb.subset_hint"
29+
subsetHintKey = "x-gateway-destination-endpoint-subset"
30+
)
31+
32+
// compile-time type assertion
33+
var _ framework.Filter = &SubsetFilter{}
34+
35+
// NewSubsetFilter initializes a new SubsetFilter.
36+
func NewSubsetFilter() *SubsetFilter {
37+
return &SubsetFilter{}
38+
}
39+
40+
// SubsetFilter filters Pods based on the subset hint provided by the proxy via filterMetadata.
41+
type SubsetFilter struct{}
42+
43+
// Name returns the name of the filter.
44+
func (f *SubsetFilter) Name() string {
45+
return "subset-hint"
46+
}
47+
48+
// Filter filters out pods that are not in the subset provided in filterMetadata.
49+
func (f *SubsetFilter) Filter(_ context.Context, request *types.LLMRequest, _ *types.CycleState, pods []types.Pod) []types.Pod {
50+
// Check if envoy.lb.subset_hint is present in the metadata map
51+
subsetNamespace, found := request.FilterMetadata[subsetHintNamespace]
52+
if !found {
53+
return pods
54+
}
55+
56+
subsetMap := subsetNamespace.AsMap()
57+
endpointSubsetList, ok := subsetMap[subsetHintKey]
58+
if !ok {
59+
return pods
60+
}
61+
62+
// Assume endpoint list is of type list and must have at least one endpoint
63+
subsetList, ok := endpointSubsetList.([]interface{})
64+
if !ok || len(subsetList) == 0 {
65+
return pods
66+
}
67+
68+
// Create map of pod addys for easy lookup
69+
podAddresses := make(map[string]types.Pod)
70+
for _, pod := range pods {
71+
podAddresses[pod.GetPod().Address] = pod
72+
}
73+
74+
// Filter based on address
75+
filteredPods := []types.Pod{}
76+
for _, endpoint := range subsetList {
77+
epStr := strings.Split(endpoint.(string), ":")[0]
78+
if pod, allowed := podAddresses[epStr]; allowed {
79+
filteredPods = append(filteredPods, pod)
80+
}
81+
}
82+
83+
return filteredPods
84+
}

pkg/epp/scheduling/scheduler.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ func NewScheduler(datastore Datastore) *Scheduler {
3838
// When the scheduler is initialized with NewScheduler function, thw below config will be used as default.
3939
// it's possible to call NewSchedulerWithConfig to pass a different scheduler config.
4040
// For build time plugins changes, it's recommended to call in main.go to NewSchedulerWithConfig.
41+
endpointSubsetFilter := filter.NewSubsetFilter()
4142
loraAffinityFilter := filter.NewLoraAffinityFilter()
4243
leastQueueFilter := filter.NewLeastQueueFilter()
4344
leastKvCacheFilter := filter.NewLeastKVCacheFilter()
@@ -65,7 +66,7 @@ func NewScheduler(datastore Datastore) *Scheduler {
6566
}
6667

6768
defaultProfile := framework.NewSchedulerProfile().
68-
WithFilters(lowLatencyFilter).
69+
WithFilters(endpointSubsetFilter, lowLatencyFilter).
6970
WithPicker(&picker.RandomPicker{})
7071

7172
profilePicker := profilepicker.NewSingleProfileHandler()

pkg/epp/scheduling/types/types.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package types
1919
import (
2020
"fmt"
2121

22+
"google.golang.org/protobuf/types/known/structpb"
2223
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
2324
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
2425
)
@@ -33,10 +34,12 @@ type LLMRequest struct {
3334
Prompt string
3435
// Headers is a map of the request headers.
3536
Headers map[string]string
37+
// FilterMetadata is a map of metadata in the request
38+
FilterMetadata map[string]*structpb.Struct
3639
}
3740

3841
func (r *LLMRequest) String() string {
39-
return fmt.Sprintf("RequestID: %s, TargetModel: %s, PromptLength: %d, Headers: %v", r.RequestId, r.TargetModel, len(r.Prompt), r.Headers)
42+
return fmt.Sprintf("RequestID: %s, TargetModel: %s, PromptLength: %d, Headers: %v, FilterMetadata: %v", r.RequestId, r.TargetModel, len(r.Prompt), r.Headers, r.FilterMetadata)
4043
}
4144

4245
type Pod interface {

test/integration/bbr/hermetic_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ func TestBodyBasedRouting(t *testing.T) {
4545
}{
4646
{
4747
name: "success adding model parameter to header",
48-
req: integrationutils.GenerateRequest(logger, "test", "llama"),
48+
req: integrationutils.GenerateRequest(logger, "test", "llama", nil),
4949
wantHeaders: []*configPb.HeaderValueOption{
5050
{
5151
Header: &configPb.HeaderValue{
@@ -58,7 +58,7 @@ func TestBodyBasedRouting(t *testing.T) {
5858
},
5959
{
6060
name: "no model parameter",
61-
req: integrationutils.GenerateRequest(logger, "test1", ""),
61+
req: integrationutils.GenerateRequest(logger, "test1", "", nil),
6262
wantHeaders: []*configPb.HeaderValueOption{},
6363
wantErr: false,
6464
},
@@ -107,7 +107,7 @@ func TestFullDuplexStreamed_BodyBasedRouting(t *testing.T) {
107107
}{
108108
{
109109
name: "success adding model parameter to header",
110-
reqs: integrationutils.GenerateStreamedRequestSet(logger, "test", "foo"),
110+
reqs: integrationutils.GenerateStreamedRequestSet(logger, "test", "foo", nil),
111111
wantResponses: []*extProcPb.ProcessingResponse{
112112
{
113113
Response: &extProcPb.ProcessingResponse_RequestHeaders{
@@ -212,7 +212,7 @@ func TestFullDuplexStreamed_BodyBasedRouting(t *testing.T) {
212212
},
213213
{
214214
name: "no model parameter",
215-
reqs: integrationutils.GenerateStreamedRequestSet(logger, "test", ""),
215+
reqs: integrationutils.GenerateStreamedRequestSet(logger, "test", "", nil),
216216
wantResponses: []*extProcPb.ProcessingResponse{
217217
{
218218
Response: &extProcPb.ProcessingResponse_RequestHeaders{

0 commit comments

Comments
 (0)