@@ -18,20 +18,25 @@ package filter
18
18
19
19
import (
20
20
"context"
21
+ "strings"
21
22
22
23
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
23
24
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
24
25
)
25
26
26
27
const (
28
+ // headerTestEppEndPointSelectionKey is the header used for testing purposes to make EPP behavior controllable.
29
+ // The header value should be a comma-separated list of endpoint IP addresses.
30
+ // E.g., "test-epp-endpoint-selection": "10.0.0.7,10.0.0.8"
31
+ // The returned order is the same as the order provided in the header.
27
32
headerTestEppEndPointSelectionKey = "test-epp-endpoint-selection"
28
33
)
29
34
30
35
// compile-time type assertion
31
36
var _ framework.Filter = & HeaderBasedTestingFilter {}
32
37
33
- // NewHeaderBasedTestingFilter initializes a new HeaderBasedTestingFilter and returns its pointer .
34
- // This should be only used in testing purpose .
38
+ // NewHeaderBasedTestingFilter initializes a new HeaderBasedTestingFilter.
39
+ // This should only be used for testing purposes .
35
40
func NewHeaderBasedTestingFilter () * HeaderBasedTestingFilter {
36
41
return & HeaderBasedTestingFilter {}
37
42
}
@@ -41,20 +46,26 @@ type HeaderBasedTestingFilter struct{}
41
46
42
47
// Name returns the name of the filter.
43
48
func (f * HeaderBasedTestingFilter ) Name () string {
44
- return "test- header-based"
49
+ return "header-based-testing "
45
50
}
46
51
47
- // Filter filters out pods that doesn't meet the filter criteria .
52
+ // Filter selects pods that match the IP addresses specified in the request header .
48
53
func (f * HeaderBasedTestingFilter ) Filter (_ context.Context , request * types.LLMRequest , _ * types.CycleState , pods []types.Pod ) []types.Pod {
49
- filteredPods := []types.Pod {}
50
-
51
- endPointInReqeust , found := request .Headers [headerTestEppEndPointSelectionKey ]
52
- if ! found {
53
- return filteredPods
54
+ headerValue , ok := request .Headers [headerTestEppEndPointSelectionKey ]
55
+ if ! ok || headerValue == "" {
56
+ return []types.Pod {}
54
57
}
55
58
59
+ podAddressMap := make (map [string ]types.Pod , len (pods ))
56
60
for _ , pod := range pods {
57
- if pod .GetPod ().Address == endPointInReqeust {
61
+ podAddressMap [pod .GetPod ().Address ] = pod
62
+ }
63
+
64
+ endpoints := strings .Split (headerValue , "," )
65
+ filteredPods := make ([]types.Pod , 0 , len (endpoints ))
66
+ for _ , endpoint := range endpoints {
67
+ trimmedEndpoint := strings .TrimSpace (endpoint )
68
+ if pod , found := podAddressMap [trimmedEndpoint ]; found {
58
69
filteredPods = append (filteredPods , pod )
59
70
}
60
71
}
0 commit comments