Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cmd/epp/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ func run() error {
scorers,
picker.NewMaxScorePicker(),
[]plugins.PostSchedule{},
[]plugins.PostResponse{},
schedConfigOpts...)
scheduler = scheduling.NewSchedulerWithConfig(datastore, schedulerConfig)
}
Expand Down
59 changes: 56 additions & 3 deletions pkg/epp/handlers/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"encoding/json"
"strings"

configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
"sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
Expand Down Expand Up @@ -98,6 +99,58 @@ func (s *StreamingServer) HandleResponseBodyModelStreaming(
}
}

// HandleResponseHeaders records the upstream response headers on the request
// context and forwards the context to the director for response handling.
func (s *StreamingServer) HandleResponseHeaders(ctx context.Context, reqCtx *RequestContext, resp *extProcPb.ProcessingRequest_ResponseHeaders) (*RequestContext, error) {
	for _, h := range resp.ResponseHeaders.Headers.Headers {
		// Envoy may deliver the value either as raw bytes or as a string.
		value := h.Value
		if h.RawValue != nil {
			value = string(h.RawValue)
		}
		reqCtx.Response.Headers[h.Key] = value
	}

	return s.director.HandleResponse(ctx, reqCtx)
}

// generateResponseHeaderResponse builds the ext_proc ProcessingResponse that
// applies the header mutations produced by generateResponseHeaders.
func (s *StreamingServer) generateResponseHeaderResponse(reqCtx *RequestContext) *extProcPb.ProcessingResponse {
	mutation := &extProcPb.HeaderMutation{
		SetHeaders: s.generateResponseHeaders(reqCtx),
	}
	return &extProcPb.ProcessingResponse{
		Response: &extProcPb.ProcessingResponse_ResponseHeaders{
			ResponseHeaders: &extProcPb.HeadersResponse{
				Response: &extProcPb.CommonResponse{
					HeaderMutation: mutation,
				},
			},
		},
	}
}

// generateResponseHeaders assembles the header mutations to apply to the
// upstream response: a debug marker header plus every header previously
// recorded on the request context.
func (s *StreamingServer) generateResponseHeaders(reqCtx *RequestContext) []*configPb.HeaderValueOption {
	// can likely refactor these two bespoke headers to be updated in PostDispatch, to centralize logic.
	// Pre-size for the debug header plus all recorded response headers to
	// avoid repeated slice growth.
	headers := make([]*configPb.HeaderValueOption, 0, len(reqCtx.Response.Headers)+1)
	headers = append(headers, &configPb.HeaderValueOption{
		Header: &configPb.HeaderValue{
			// This is for debugging purposes only.
			Key:      "x-went-into-resp-headers",
			RawValue: []byte("true"),
		},
	})

	// Include all headers captured from the upstream response.
	for key, value := range reqCtx.Response.Headers {
		headers = append(headers, &configPb.HeaderValueOption{
			Header: &configPb.HeaderValue{
				Key:      key,
				RawValue: []byte(value),
			},
		})
	}
	return headers
}

// Example message if "stream_options": {"include_usage": "true"} is included in the request:
// data: {"id":"...","object":"text_completion","created":1739400043,"model":"food-review-0","choices":[],
// "usage":{"prompt_tokens":7,"total_tokens":17,"completion_tokens":10}}
Expand All @@ -112,8 +165,8 @@ func (s *StreamingServer) HandleResponseBodyModelStreaming(
func parseRespForUsage(
ctx context.Context,
responseText string,
) Response {
response := Response{}
) ResponseBody {
response := ResponseBody{}
logger := log.FromContext(ctx)

lines := strings.Split(responseText, "\n")
Expand All @@ -136,7 +189,7 @@ func parseRespForUsage(
return response
}

type Response struct {
// ResponseBody is the subset of the model server's JSON response body decoded
// when parsing a response for usage stats; only the usage accounting is kept.
type ResponseBody struct {
	Usage Usage `json:"usage"`
}

Expand Down
34 changes: 15 additions & 19 deletions pkg/epp/handlers/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
"strings"
"time"

configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
envoyTypePb "github.com/envoyproxy/go-control-plane/envoy/type/v3"
"github.com/go-logr/logr"
Expand All @@ -49,6 +48,7 @@ func NewStreamingServer(destinationEndpointHintMetadataNamespace, destinationEnd

// Director is the component the streaming server delegates to for request and
// response handling.
type Director interface {
	// HandleRequest processes an incoming request context and returns it, possibly mutated.
	HandleRequest(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error)
	// HandleResponse processes response data recorded on the request context.
	HandleResponse(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error)
	// GetRandomPod returns an arbitrary backend pod — presumably a fallback
	// when no scheduling result is available; verify against callers.
	GetRandomPod() *backend.Pod
}

Expand Down Expand Up @@ -91,6 +91,8 @@ type RequestContext struct {
RequestState StreamRequestState
modelServerStreaming bool

Response *Response

reqHeaderResp *extProcPb.ProcessingResponse
reqBodyResp *extProcPb.ProcessingResponse
reqTrailerResp *extProcPb.ProcessingResponse
Expand All @@ -104,6 +106,9 @@ type Request struct {
Headers map[string]string
Body map[string]interface{}
}
// Response carries data captured from the model server's response while the
// ext_proc stream is processed; currently only the response headers.
type Response struct {
	// Headers maps response header keys to their string-decoded values.
	Headers map[string]string
}
type StreamRequestState int

const (
Expand Down Expand Up @@ -131,6 +136,9 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
Headers: make(map[string]string),
Body: make(map[string]interface{}),
},
Response: &Response{
Headers: make(map[string]string),
},
}

var body []byte
Expand Down Expand Up @@ -229,25 +237,13 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
}
}
reqCtx.RequestState = ResponseRecieved
reqCtx.respHeaderResp = &extProcPb.ProcessingResponse{
Response: &extProcPb.ProcessingResponse_ResponseHeaders{
ResponseHeaders: &extProcPb.HeadersResponse{
Response: &extProcPb.CommonResponse{
HeaderMutation: &extProcPb.HeaderMutation{
SetHeaders: []*configPb.HeaderValueOption{
{
Header: &configPb.HeaderValue{
// This is for debugging purpose only.
Key: "x-went-into-resp-headers",
RawValue: []byte("true"),
},
},
},
},
},
},
},

var responseErr error
reqCtx, responseErr = s.HandleResponseHeaders(ctx, reqCtx, v)
if responseErr != nil {
logger.V(logutil.DEFAULT).Error(responseErr, "Failed to process response headers", "request", req)
}
reqCtx.respHeaderResp = s.generateResponseHeaderResponse(reqCtx)

case *extProcPb.ProcessingRequest_ResponseBody:
if reqCtx.modelServerStreaming {
Expand Down
17 changes: 17 additions & 0 deletions pkg/epp/requestcontrol/director.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,12 @@ import (
schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request"
)

// Scheduler selects a target for each request and is notified of responses.
type Scheduler interface {
	// Schedule picks a destination for the given LLM request.
	Schedule(ctx context.Context, b *schedulingtypes.LLMRequest) (result *schedulingtypes.Result, err error)
	// OnResponse notifies the scheduler that a response was received from the
	// pod identified by targetPodName.
	OnResponse(ctx context.Context, resp *schedulingtypes.LLMResponse, targetPodName string)
}

type Director struct {
Expand Down Expand Up @@ -84,6 +86,7 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo

llmReq := &schedulingtypes.LLMRequest{
TargetModel: reqCtx.ResolvedTargetModel,
RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
Critical: modelObj.Spec.Criticality != nil && *modelObj.Spec.Criticality == v1alpha2.Critical,
Prompt: prompt,
Headers: reqCtx.Request.Headers,
Expand Down Expand Up @@ -137,6 +140,20 @@ func (d *Director) PostDispatch(ctx context.Context, reqCtx *handlers.RequestCon
return reqCtx, nil
}

// HandleResponse assembles an LLMResponse from the request context and hands
// it to the scheduler so post-response processing can run.
func (d *Director) HandleResponse(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
	response := &schedulingtypes.LLMResponse{
		RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
		Headers:   reqCtx.Response.Headers,
	}
	log.FromContext(ctx).V(logutil.DEBUG).Info("LLM response assembled", "response", response)

	d.scheduler.OnResponse(ctx, response, reqCtx.TargetPod)

	return reqCtx, nil
}

func (d *Director) GetRandomPod() *backend.Pod {
pods := d.datastore.PodGetAll()
if len(pods) == 0 {
Expand Down
4 changes: 3 additions & 1 deletion pkg/epp/scheduling/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,14 @@ import (

// NewSchedulerConfig creates a new SchedulerConfig object with the given plugins.
func NewSchedulerConfig(preSchedulePlugins []plugins.PreSchedule, filters []plugins.Filter, scorers map[plugins.Scorer]int,
picker plugins.Picker, postSchedulePlugins []plugins.PostSchedule, opts ...ConfigOption) *SchedulerConfig {
picker plugins.Picker, postSchedulePlugins []plugins.PostSchedule, postResponsePlugins []plugins.PostResponse, opts ...ConfigOption) *SchedulerConfig {
config := &SchedulerConfig{
preSchedulePlugins: preSchedulePlugins,
filters: filters,
scorers: scorers,
picker: picker,
postSchedulePlugins: postSchedulePlugins,
postResponsePlugins: postResponsePlugins,
}
for _, opt := range opts {
opt(config)
Expand All @@ -44,6 +45,7 @@ type SchedulerConfig struct {
scorers map[plugins.Scorer]int // map from scorer to weight
picker plugins.Picker
postSchedulePlugins []plugins.PostSchedule
postResponsePlugins []plugins.PostResponse
}

type ConfigOption func(*SchedulerConfig)
Expand Down
6 changes: 4 additions & 2 deletions pkg/epp/scheduling/plugins/filter/filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"testing"

"github.com/google/go-cmp/cmp"
"github.com/google/uuid"
k8stypes "k8s.io/apimachinery/pkg/types"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
Expand Down Expand Up @@ -170,7 +171,7 @@ func TestFilter(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
ctx := types.NewSchedulingContext(context.Background(), test.req, test.input)
ctx := types.NewSchedulingContext(context.Background(), test.req, nil, test.input)
got := test.filter.Filter(ctx, test.input)

if diff := cmp.Diff(test.output, got); diff != "" {
Expand Down Expand Up @@ -205,6 +206,7 @@ func TestLoRASoftAffinityDistribution(t *testing.T) {
// Create a test request and pods
req := &types.LLMRequest{
TargetModel: testAffinityModel,
RequestId: uuid.NewString(),
}

// Test setup: One affinity pod and one available pod
Expand All @@ -226,7 +228,7 @@ func TestLoRASoftAffinityDistribution(t *testing.T) {
},
},
}
ctx := types.NewSchedulingContext(context.Background(), req, pods)
ctx := types.NewSchedulingContext(context.Background(), req, nil, pods)

// Run the filter function multiple times and count the results
affinityCount := 0
Expand Down
10 changes: 5 additions & 5 deletions pkg/epp/scheduling/plugins/prefix/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func TestPrefixPlugin(t *testing.T) {
TargetModel: "test-model1",
Prompt: "aaaaaa",
}
ctx := types.NewSchedulingContext(context.Background(), req1, pods)
ctx := types.NewSchedulingContext(context.Background(), req1, nil, pods)
plugin.PreSchedule(ctx)
state, err := plugin.getPrefixState(ctx.CycleState)
assert.NoError(t, err)
Expand All @@ -51,7 +51,7 @@ func TestPrefixPlugin(t *testing.T) {
TargetModel: "test-model2",
Prompt: "bbbbbb",
}
ctx = types.NewSchedulingContext(context.Background(), req2, pods)
ctx = types.NewSchedulingContext(context.Background(), req2, nil, pods)
plugin.PreSchedule(ctx)
state, err = plugin.getPrefixState(ctx.CycleState)
assert.NoError(t, err)
Expand All @@ -74,7 +74,7 @@ func TestPrefixPlugin(t *testing.T) {
TargetModel: "test-model1",
Prompt: "aaaabbbb",
}
ctx = types.NewSchedulingContext(context.Background(), req3, pods)
ctx = types.NewSchedulingContext(context.Background(), req3, nil, pods)
plugin.PreSchedule(ctx)
state, err = plugin.getPrefixState(ctx.CycleState)
assert.NoError(t, err)
Expand All @@ -96,7 +96,7 @@ func TestPrefixPlugin(t *testing.T) {
TargetModel: "test-model-new",
Prompt: "aaaabbbb",
}
ctx = types.NewSchedulingContext(context.Background(), req4, pods)
ctx = types.NewSchedulingContext(context.Background(), req4, nil, pods)
plugin.PreSchedule(ctx)
state, err = plugin.getPrefixState(ctx.CycleState)
assert.NoError(t, err)
Expand All @@ -118,7 +118,7 @@ func TestPrefixPlugin(t *testing.T) {
TargetModel: "test-model1",
Prompt: "aaaabbbbcccc",
}
ctx = types.NewSchedulingContext(context.Background(), req5, pods)
ctx = types.NewSchedulingContext(context.Background(), req5, nil, pods)
plugin.PreSchedule(ctx)
state, err = plugin.getPrefixState(ctx.CycleState)
assert.NoError(t, err)
Expand Down
2 changes: 1 addition & 1 deletion pkg/epp/scheduling/plugins/scorer/kvcache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func TestKvCacheScorer(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := types.NewSchedulingContext(context.Background(), &types.LLMRequest{}, tt.pods)
ctx := types.NewSchedulingContext(context.Background(), &types.LLMRequest{}, nil, tt.pods)
scorer := &KVCacheScorer{}
scores := scorer.Score(ctx, tt.pods)

Expand Down
2 changes: 1 addition & 1 deletion pkg/epp/scheduling/plugins/scorer/queue_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func TestQueueScorer(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := types.NewSchedulingContext(context.Background(), &types.LLMRequest{}, tt.pods)
ctx := types.NewSchedulingContext(context.Background(), &types.LLMRequest{}, nil, tt.pods)
scores := scorer.Score(ctx, tt.pods)

for i, pod := range tt.pods {
Expand Down
33 changes: 32 additions & 1 deletion pkg/epp/scheduling/scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ func NewSchedulerWithConfig(datastore Datastore, config *SchedulerConfig) *Sched
scorers: config.scorers,
picker: config.picker,
postSchedulePlugins: config.postSchedulePlugins,
postResponsePlugins: config.postResponsePlugins,
}
}

Expand All @@ -94,6 +95,7 @@ type Scheduler struct {
scorers map[plugins.Scorer]int // map from scorer to its weight
picker plugins.Picker
postSchedulePlugins []plugins.PostSchedule
postResponsePlugins []plugins.PostResponse
}

type Datastore interface {
Expand All @@ -113,7 +115,7 @@ func (s *Scheduler) Schedule(ctx context.Context, req *types.LLMRequest) (*types
// Snapshot pod metrics from the datastore to:
// 1. Reduce concurrent access to the datastore.
// 2. Ensure consistent data during the scheduling operation of a request.
sCtx := types.NewSchedulingContext(ctx, req, types.ToSchedulerPodMetrics(s.datastore.PodGetAll()))
sCtx := types.NewSchedulingContext(ctx, req, nil, types.ToSchedulerPodMetrics(s.datastore.PodGetAll()))
loggerDebug.Info(fmt.Sprintf("Scheduling a request, Metrics: %+v", sCtx.PodsSnapshot))

s.runPreSchedulePlugins(sCtx)
Expand Down Expand Up @@ -211,3 +213,32 @@ func (s *Scheduler) runPostSchedulePlugins(ctx *types.SchedulingContext, res *ty
metrics.RecordSchedulerPluginProcessingLatency(plugins.PostSchedulePluginType, plugin.Name(), time.Since(before))
}
}

// OnResponse is invoked during the processing of a response from an inference pod. It will invoke
// any defined plugins that process the response.
func (s *Scheduler) OnResponse(ctx context.Context, resp *types.LLMResponse, targetPodName string) {
	// Snapshot pod metrics from the datastore to:
	// 1. Reduce concurrent access to the datastore.
	// 2. Ensure consistent data while handling the response.
	pods := types.ToSchedulerPodMetrics(s.datastore.PodGetAll())
	var targetPod types.Pod
	for _, pod := range pods {
		if pod.GetPod().NamespacedName.String() == targetPodName {
			targetPod = pod
			break
		}
	}
	// NOTE(review): if targetPodName is absent from the snapshot, targetPod
	// stays nil and is passed to the plugins as-is — confirm post-response
	// plugins tolerate a nil target pod.

	// The request argument is nil: this context carries only the response.
	sCtx := types.NewSchedulingContext(ctx, nil, resp, pods)

	s.runPostResponsePlugins(sCtx, targetPod)
}

// runPostResponsePlugins executes every configured post-response plugin,
// recording per-plugin processing latency.
func (s *Scheduler) runPostResponsePlugins(ctx *types.SchedulingContext, targetPod types.Pod) {
	for _, p := range s.postResponsePlugins {
		ctx.Logger.V(logutil.DEBUG).Info("Running post-response plugin", "plugin", p.Name())
		start := time.Now()
		p.PostResponse(ctx, targetPod)
		metrics.RecordSchedulerPluginProcessingLatency(plugins.PostResponsePluginType, p.Name(), time.Since(start))
	}
}
Loading