Skip to content

Commit 9fb812f

Browse files
shmuelknirrozenbaum
authored and committed
feat: Add support to invoke PostResponse plugins (kubernetes-sigs#800)
* Added the LLMResponse struct and RequestId to LLMRequest Signed-off-by: Shmuel Kallner <kallner@il.ibm.com> * Updates due to NewSchedulerContext API change Signed-off-by: Shmuel Kallner <kallner@il.ibm.com> * Populate the RequestId field of LLMRequest Signed-off-by: Shmuel Kallner <kallner@il.ibm.com> * Updates to tests Signed-off-by: Shmuel Kallner <kallner@il.ibm.com> * Added PostResponse plugins to scheduler config Signed-off-by: Shmuel Kallner <kallner@il.ibm.com> * Added scheduler.OnResponse to handle responses Signed-off-by: Shmuel Kallner <kallner@il.ibm.com> * Added dispatcher.HandleResponse to handle responses Signed-off-by: Shmuel Kallner <kallner@il.ibm.com> * Refactored server response header handling to invoke PostResponse plugins Signed-off-by: Shmuel Kallner <kallner@il.ibm.com> * Added simple test for PostResponse plugins Signed-off-by: Shmuel Kallner <kallner@il.ibm.com> * Setup the logger in the SchedulerContext appropriately for responses Signed-off-by: Shmuel Kallner <kallner@il.ibm.com> * Updates due to rebase issues * merge functions in env utils (kubernetes-sigs#819) Signed-off-by: Nir Rozenbaum <nirro@il.ibm.com> * generalize scheduling cycle state concept (kubernetes-sigs#818) * generalize scheduling cycle state concept Signed-off-by: Nir Rozenbaum <nirro@il.ibm.com> * typo Signed-off-by: Nir Rozenbaum <nirro@il.ibm.com> * make linter happy Signed-off-by: Nir Rozenbaum <nirro@il.ibm.com> * make prefix state struct internal to package instead of public Signed-off-by: Nir Rozenbaum <nirro@il.ibm.com> --------- Signed-off-by: Nir Rozenbaum <nirro@il.ibm.com> * remove Model field from LLMRequest (kubernetes-sigs#782) * remove Model field from LLMRequest Signed-off-by: Nir Rozenbaum <nirro@il.ibm.com> * rebase handling Signed-off-by: Nir Rozenbaum <nirro@il.ibm.com> --------- Signed-off-by: Nir Rozenbaum <nirro@il.ibm.com> * Added the LLMResponse struct and RequestId to LLMRequest Signed-off-by: Shmuel Kallner <kallner@il.ibm.com> * Ensure that 
wanted response header messages have all of the response headers in them Signed-off-by: Shmuel Kallner <kallner@il.ibm.com> --------- Signed-off-by: Shmuel Kallner <kallner@il.ibm.com> Signed-off-by: Nir Rozenbaum <nirro@il.ibm.com> Co-authored-by: Nir Rozenbaum <nirro@il.ibm.com>
1 parent 27ad10a commit 9fb812f

File tree

14 files changed

+258
-35
lines changed

14 files changed

+258
-35
lines changed

cmd/epp/main.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ func run() error {
211211
scorers,
212212
picker.NewMaxScorePicker(),
213213
[]plugins.PostSchedule{},
214+
[]plugins.PostResponse{},
214215
schedConfigOpts...)
215216
scheduler = scheduling.NewSchedulerWithConfig(datastore, schedulerConfig)
216217
}

pkg/epp/handlers/response.go

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"encoding/json"
2222
"strings"
2323

24+
configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
2425
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
2526
"sigs.k8s.io/controller-runtime/pkg/log"
2627
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
@@ -98,6 +99,58 @@ func (s *StreamingServer) HandleResponseBodyModelStreaming(
9899
}
99100
}
100101

102+
func (s *StreamingServer) HandleResponseHeaders(ctx context.Context, reqCtx *RequestContext, resp *extProcPb.ProcessingRequest_ResponseHeaders) (*RequestContext, error) {
103+
for _, header := range resp.ResponseHeaders.Headers.Headers {
104+
if header.RawValue != nil {
105+
reqCtx.Response.Headers[header.Key] = string(header.RawValue)
106+
} else {
107+
reqCtx.Response.Headers[header.Key] = header.Value
108+
}
109+
}
110+
111+
reqCtx, err := s.director.HandleResponse(ctx, reqCtx)
112+
113+
return reqCtx, err
114+
}
115+
116+
func (s *StreamingServer) generateResponseHeaderResponse(reqCtx *RequestContext) *extProcPb.ProcessingResponse {
117+
return &extProcPb.ProcessingResponse{
118+
Response: &extProcPb.ProcessingResponse_ResponseHeaders{
119+
ResponseHeaders: &extProcPb.HeadersResponse{
120+
Response: &extProcPb.CommonResponse{
121+
HeaderMutation: &extProcPb.HeaderMutation{
122+
SetHeaders: s.generateResponseHeaders(reqCtx),
123+
},
124+
},
125+
},
126+
},
127+
}
128+
}
129+
130+
func (s *StreamingServer) generateResponseHeaders(reqCtx *RequestContext) []*configPb.HeaderValueOption {
131+
// can likely refactor these two bespoke headers to be updated in PostDispatch, to centralize logic.
132+
headers := []*configPb.HeaderValueOption{
133+
{
134+
Header: &configPb.HeaderValue{
135+
// This is for debugging purpose only.
136+
Key: "x-went-into-resp-headers",
137+
RawValue: []byte("true"),
138+
},
139+
},
140+
}
141+
142+
// include all headers
143+
for key, value := range reqCtx.Response.Headers {
144+
headers = append(headers, &configPb.HeaderValueOption{
145+
Header: &configPb.HeaderValue{
146+
Key: key,
147+
RawValue: []byte(value),
148+
},
149+
})
150+
}
151+
return headers
152+
}
153+
101154
// Example message if "stream_options": {"include_usage": "true"} is included in the request:
102155
// data: {"id":"...","object":"text_completion","created":1739400043,"model":"food-review-0","choices":[],
103156
// "usage":{"prompt_tokens":7,"total_tokens":17,"completion_tokens":10}}
@@ -112,8 +165,8 @@ func (s *StreamingServer) HandleResponseBodyModelStreaming(
112165
func parseRespForUsage(
113166
ctx context.Context,
114167
responseText string,
115-
) Response {
116-
response := Response{}
168+
) ResponseBody {
169+
response := ResponseBody{}
117170
logger := log.FromContext(ctx)
118171

119172
lines := strings.Split(responseText, "\n")
@@ -136,7 +189,7 @@ func parseRespForUsage(
136189
return response
137190
}
138191

139-
type Response struct {
192+
type ResponseBody struct {
140193
Usage Usage `json:"usage"`
141194
}
142195

pkg/epp/handlers/server.go

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ import (
2323
"strings"
2424
"time"
2525

26-
configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
2726
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
2827
envoyTypePb "github.com/envoyproxy/go-control-plane/envoy/type/v3"
2928
"github.com/go-logr/logr"
@@ -49,6 +48,7 @@ func NewStreamingServer(destinationEndpointHintMetadataNamespace, destinationEnd
4948

5049
type Director interface {
5150
HandleRequest(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error)
51+
HandleResponse(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error)
5252
GetRandomPod() *backend.Pod
5353
}
5454

@@ -91,6 +91,8 @@ type RequestContext struct {
9191
RequestState StreamRequestState
9292
modelServerStreaming bool
9393

94+
Response *Response
95+
9496
reqHeaderResp *extProcPb.ProcessingResponse
9597
reqBodyResp *extProcPb.ProcessingResponse
9698
reqTrailerResp *extProcPb.ProcessingResponse
@@ -104,6 +106,9 @@ type Request struct {
104106
Headers map[string]string
105107
Body map[string]interface{}
106108
}
109+
type Response struct {
110+
Headers map[string]string
111+
}
107112
type StreamRequestState int
108113

109114
const (
@@ -131,6 +136,9 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
131136
Headers: make(map[string]string),
132137
Body: make(map[string]interface{}),
133138
},
139+
Response: &Response{
140+
Headers: make(map[string]string),
141+
},
134142
}
135143

136144
var body []byte
@@ -229,25 +237,13 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
229237
}
230238
}
231239
reqCtx.RequestState = ResponseRecieved
232-
reqCtx.respHeaderResp = &extProcPb.ProcessingResponse{
233-
Response: &extProcPb.ProcessingResponse_ResponseHeaders{
234-
ResponseHeaders: &extProcPb.HeadersResponse{
235-
Response: &extProcPb.CommonResponse{
236-
HeaderMutation: &extProcPb.HeaderMutation{
237-
SetHeaders: []*configPb.HeaderValueOption{
238-
{
239-
Header: &configPb.HeaderValue{
240-
// This is for debugging purpose only.
241-
Key: "x-went-into-resp-headers",
242-
RawValue: []byte("true"),
243-
},
244-
},
245-
},
246-
},
247-
},
248-
},
249-
},
240+
241+
var responseErr error
242+
reqCtx, responseErr = s.HandleResponseHeaders(ctx, reqCtx, v)
243+
if responseErr != nil {
244+
logger.V(logutil.DEFAULT).Error(responseErr, "Failed to process response headers", "request", req)
250245
}
246+
reqCtx.respHeaderResp = s.generateResponseHeaderResponse(reqCtx)
251247

252248
case *extProcPb.ProcessingRequest_ResponseBody:
253249
if reqCtx.modelServerStreaming {

pkg/epp/requestcontrol/director.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,12 @@ import (
3131
schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
3232
errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error"
3333
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
34+
requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request"
3435
)
3536

3637
type Scheduler interface {
3738
Schedule(ctx context.Context, b *schedulingtypes.LLMRequest) (result *schedulingtypes.Result, err error)
39+
OnResponse(ctx context.Context, resp *schedulingtypes.LLMResponse, targetPodName string)
3840
}
3941

4042
type Director struct {
@@ -84,6 +86,7 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
8486

8587
llmReq := &schedulingtypes.LLMRequest{
8688
TargetModel: reqCtx.ResolvedTargetModel,
89+
RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
8790
Critical: modelObj.Spec.Criticality != nil && *modelObj.Spec.Criticality == v1alpha2.Critical,
8891
Prompt: prompt,
8992
Headers: reqCtx.Request.Headers,
@@ -137,6 +140,20 @@ func (d *Director) PostDispatch(ctx context.Context, reqCtx *handlers.RequestCon
137140
return reqCtx, nil
138141
}
139142

143+
func (d *Director) HandleResponse(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
144+
logger := log.FromContext(ctx)
145+
146+
llmResp := &schedulingtypes.LLMResponse{
147+
RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
148+
Headers: reqCtx.Response.Headers,
149+
}
150+
logger.V(logutil.DEBUG).Info("LLM response assembled", "response", llmResp)
151+
152+
d.scheduler.OnResponse(ctx, llmResp, reqCtx.TargetPod)
153+
154+
return reqCtx, nil
155+
}
156+
140157
func (d *Director) GetRandomPod() *backend.Pod {
141158
pods := d.datastore.PodGetAll()
142159
if len(pods) == 0 {

pkg/epp/scheduling/config.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,14 @@ import (
2323

2424
// NewSchedulerConfig creates a new SchedulerConfig object with the given plugins.
2525
func NewSchedulerConfig(preSchedulePlugins []plugins.PreSchedule, filters []plugins.Filter, scorers map[plugins.Scorer]int,
26-
picker plugins.Picker, postSchedulePlugins []plugins.PostSchedule, opts ...ConfigOption) *SchedulerConfig {
26+
picker plugins.Picker, postSchedulePlugins []plugins.PostSchedule, postResponsePlugins []plugins.PostResponse, opts ...ConfigOption) *SchedulerConfig {
2727
config := &SchedulerConfig{
2828
preSchedulePlugins: preSchedulePlugins,
2929
filters: filters,
3030
scorers: scorers,
3131
picker: picker,
3232
postSchedulePlugins: postSchedulePlugins,
33+
postResponsePlugins: postResponsePlugins,
3334
}
3435
for _, opt := range opts {
3536
opt(config)
@@ -44,6 +45,7 @@ type SchedulerConfig struct {
4445
scorers map[plugins.Scorer]int // map from scorer to weight
4546
picker plugins.Picker
4647
postSchedulePlugins []plugins.PostSchedule
48+
postResponsePlugins []plugins.PostResponse
4749
}
4850

4951
type ConfigOption func(*SchedulerConfig)

pkg/epp/scheduling/plugins/filter/filter_test.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"testing"
2222

2323
"github.com/google/go-cmp/cmp"
24+
"github.com/google/uuid"
2425
k8stypes "k8s.io/apimachinery/pkg/types"
2526
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
2627
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
@@ -170,7 +171,7 @@ func TestFilter(t *testing.T) {
170171

171172
for _, test := range tests {
172173
t.Run(test.name, func(t *testing.T) {
173-
ctx := types.NewSchedulingContext(context.Background(), test.req, test.input)
174+
ctx := types.NewSchedulingContext(context.Background(), test.req, nil, test.input)
174175
got := test.filter.Filter(ctx, test.input)
175176

176177
if diff := cmp.Diff(test.output, got); diff != "" {
@@ -205,6 +206,7 @@ func TestLoRASoftAffinityDistribution(t *testing.T) {
205206
// Create a test request and pods
206207
req := &types.LLMRequest{
207208
TargetModel: testAffinityModel,
209+
RequestId: uuid.NewString(),
208210
}
209211

210212
// Test setup: One affinity pod and one available pod
@@ -226,7 +228,7 @@ func TestLoRASoftAffinityDistribution(t *testing.T) {
226228
},
227229
},
228230
}
229-
ctx := types.NewSchedulingContext(context.Background(), req, pods)
231+
ctx := types.NewSchedulingContext(context.Background(), req, nil, pods)
230232

231233
// Run the filter function multiple times and count the results
232234
affinityCount := 0

pkg/epp/scheduling/plugins/prefix/plugin_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ func TestPrefixPlugin(t *testing.T) {
2727
TargetModel: "test-model1",
2828
Prompt: "aaaaaa",
2929
}
30-
ctx := types.NewSchedulingContext(context.Background(), req1, pods)
30+
ctx := types.NewSchedulingContext(context.Background(), req1, nil, pods)
3131
plugin.PreSchedule(ctx)
3232
state, err := plugin.getPrefixState(ctx.CycleState)
3333
assert.NoError(t, err)
@@ -51,7 +51,7 @@ func TestPrefixPlugin(t *testing.T) {
5151
TargetModel: "test-model2",
5252
Prompt: "bbbbbb",
5353
}
54-
ctx = types.NewSchedulingContext(context.Background(), req2, pods)
54+
ctx = types.NewSchedulingContext(context.Background(), req2, nil, pods)
5555
plugin.PreSchedule(ctx)
5656
state, err = plugin.getPrefixState(ctx.CycleState)
5757
assert.NoError(t, err)
@@ -74,7 +74,7 @@ func TestPrefixPlugin(t *testing.T) {
7474
TargetModel: "test-model1",
7575
Prompt: "aaaabbbb",
7676
}
77-
ctx = types.NewSchedulingContext(context.Background(), req3, pods)
77+
ctx = types.NewSchedulingContext(context.Background(), req3, nil, pods)
7878
plugin.PreSchedule(ctx)
7979
state, err = plugin.getPrefixState(ctx.CycleState)
8080
assert.NoError(t, err)
@@ -96,7 +96,7 @@ func TestPrefixPlugin(t *testing.T) {
9696
TargetModel: "test-model-new",
9797
Prompt: "aaaabbbb",
9898
}
99-
ctx = types.NewSchedulingContext(context.Background(), req4, pods)
99+
ctx = types.NewSchedulingContext(context.Background(), req4, nil, pods)
100100
plugin.PreSchedule(ctx)
101101
state, err = plugin.getPrefixState(ctx.CycleState)
102102
assert.NoError(t, err)
@@ -118,7 +118,7 @@ func TestPrefixPlugin(t *testing.T) {
118118
TargetModel: "test-model1",
119119
Prompt: "aaaabbbbcccc",
120120
}
121-
ctx = types.NewSchedulingContext(context.Background(), req5, pods)
121+
ctx = types.NewSchedulingContext(context.Background(), req5, nil, pods)
122122
plugin.PreSchedule(ctx)
123123
state, err = plugin.getPrefixState(ctx.CycleState)
124124
assert.NoError(t, err)

pkg/epp/scheduling/plugins/scorer/kvcache_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ func TestKvCacheScorer(t *testing.T) {
8282

8383
for _, tt := range tests {
8484
t.Run(tt.name, func(t *testing.T) {
85-
ctx := types.NewSchedulingContext(context.Background(), &types.LLMRequest{}, tt.pods)
85+
ctx := types.NewSchedulingContext(context.Background(), &types.LLMRequest{}, nil, tt.pods)
8686
scorer := &KVCacheScorer{}
8787
scores := scorer.Score(ctx, tt.pods)
8888

pkg/epp/scheduling/plugins/scorer/queue_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ func TestQueueScorer(t *testing.T) {
7373

7474
for _, tt := range tests {
7575
t.Run(tt.name, func(t *testing.T) {
76-
ctx := types.NewSchedulingContext(context.Background(), &types.LLMRequest{}, tt.pods)
76+
ctx := types.NewSchedulingContext(context.Background(), &types.LLMRequest{}, nil, tt.pods)
7777
scores := scorer.Score(ctx, tt.pods)
7878

7979
for i, pod := range tt.pods {

pkg/epp/scheduling/scheduler.go

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ func NewSchedulerWithConfig(datastore Datastore, config *SchedulerConfig) *Sched
8484
scorers: config.scorers,
8585
picker: config.picker,
8686
postSchedulePlugins: config.postSchedulePlugins,
87+
postResponsePlugins: config.postResponsePlugins,
8788
}
8889
}
8990

@@ -94,6 +95,7 @@ type Scheduler struct {
9495
scorers map[plugins.Scorer]int // map from scorer to its weight
9596
picker plugins.Picker
9697
postSchedulePlugins []plugins.PostSchedule
98+
postResponsePlugins []plugins.PostResponse
9799
}
98100

99101
type Datastore interface {
@@ -113,7 +115,7 @@ func (s *Scheduler) Schedule(ctx context.Context, req *types.LLMRequest) (*types
113115
// Snapshot pod metrics from the datastore to:
114116
// 1. Reduce concurrent access to the datastore.
115117
// 2. Ensure consistent data during the scheduling operation of a request.
116-
sCtx := types.NewSchedulingContext(ctx, req, types.ToSchedulerPodMetrics(s.datastore.PodGetAll()))
118+
sCtx := types.NewSchedulingContext(ctx, req, nil, types.ToSchedulerPodMetrics(s.datastore.PodGetAll()))
117119
loggerDebug.Info(fmt.Sprintf("Scheduling a request, Metrics: %+v", sCtx.PodsSnapshot))
118120

119121
s.runPreSchedulePlugins(sCtx)
@@ -211,3 +213,32 @@ func (s *Scheduler) runPostSchedulePlugins(ctx *types.SchedulingContext, res *ty
211213
metrics.RecordSchedulerPluginProcessingLatency(plugins.PostSchedulePluginType, plugin.Name(), time.Since(before))
212214
}
213215
}
216+
217+
// OnResponse is invoked during the processing of a response from an inference pod. It will invoke
218+
// any defined plugins that process the response.
219+
func (s *Scheduler) OnResponse(ctx context.Context, resp *types.LLMResponse, targetPodName string) {
220+
// Snapshot pod metrics from the datastore to:
221+
// 1. Reduce concurrent access to the datastore.
222+
// 2. Ensure consistent data during the scheduling operation of a request.
223+
pods := types.ToSchedulerPodMetrics(s.datastore.PodGetAll())
224+
var targetPod types.Pod
225+
for _, pod := range pods {
226+
if pod.GetPod().NamespacedName.String() == targetPodName {
227+
targetPod = pod
228+
break
229+
}
230+
}
231+
232+
sCtx := types.NewSchedulingContext(ctx, nil, resp, pods)
233+
234+
s.runPostResponsePlugins(sCtx, targetPod)
235+
}
236+
237+
func (s *Scheduler) runPostResponsePlugins(ctx *types.SchedulingContext, targetPod types.Pod) {
238+
for _, plugin := range s.postResponsePlugins {
239+
ctx.Logger.V(logutil.DEBUG).Info("Running post-response plugin", "plugin", plugin.Name())
240+
before := time.Now()
241+
plugin.PostResponse(ctx, targetPod)
242+
metrics.RecordSchedulerPluginProcessingLatency(plugins.PostResponsePluginType, plugin.Name(), time.Since(before))
243+
}
244+
}

0 commit comments

Comments (0)