
Commit a2c19cf

remove Model field from LLMRequest
Signed-off-by: Nir Rozenbaum <nirro@il.ibm.com>
1 parent 2b66451 commit a2c19cf

5 files changed: +21 -34 lines changed

pkg/epp/requestcontrol/director.go

Lines changed: 6 additions & 11 deletions

```diff
@@ -79,14 +79,14 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
         if reqCtx.ResolvedTargetModel == "" {
             return reqCtx, errutil.Error{Code: errutil.BadConfiguration, Msg: fmt.Sprintf("error getting target model name for model %v", modelObj.Name)}
         }
+        reqCtx.Request.Body["model"] = reqCtx.ResolvedTargetModel // Update target model in the body.
     }
 
     llmReq := &schedulingtypes.LLMRequest{
-        Model:               reqCtx.Model,
-        ResolvedTargetModel: reqCtx.ResolvedTargetModel,
-        Critical:            modelObj.Spec.Criticality != nil && *modelObj.Spec.Criticality == v1alpha2.Critical,
-        Prompt:              prompt,
-        Headers:             reqCtx.Request.Headers,
+        TargetModel: reqCtx.ResolvedTargetModel,
+        Critical:    modelObj.Spec.Criticality != nil && *modelObj.Spec.Criticality == v1alpha2.Critical,
+        Prompt:      prompt,
+        Headers:     reqCtx.Request.Headers,
     }
     logger.V(logutil.DEBUG).Info("LLM request assembled", "request", llmReq)
     results, err := d.Dispatch(ctx, llmReq)
@@ -129,13 +129,8 @@ func (d *Director) PostDispatch(ctx context.Context, reqCtx *handlers.RequestCon
     }
 
     endpoint := targetPod.Address + ":" + strconv.Itoa(int(pool.Spec.TargetPortNumber))
-    logger.V(logutil.DEFAULT).Info("Request handled",
-        "model", reqCtx.Model, "targetModel", reqCtx.ResolvedTargetModel, "endpoint", targetPod)
+    logger.V(logutil.DEFAULT).Info("Request handled", "model", reqCtx.Model, "targetModel", reqCtx.ResolvedTargetModel, "endpoint", targetPod)
 
-    // Update target models in the body.
-    if reqCtx.Model != reqCtx.ResolvedTargetModel {
-        reqCtx.Request.Body["model"] = reqCtx.ResolvedTargetModel
-    }
     reqCtx.TargetPod = targetPod.NamespacedName.String()
     reqCtx.TargetEndpoint = endpoint
```
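Net effect of these two hunks: the body's "model" field is now rewritten once, in HandleRequest, at the moment the target model is resolved, instead of conditionally in PostDispatch, and the scheduler only ever sees a single TargetModel field. A minimal sketch of the resulting flow; the `Request`/`RequestContext` shapes and the `assembleLLMRequest` helper here are illustrative assumptions, not the repository's real types:

```go
package main

import "fmt"

// Illustrative stand-ins for the real types in the handlers and
// schedulingtypes packages; field names follow the diff, everything
// else is assumed for the sketch.
type Request struct {
	Body    map[string]interface{}
	Headers map[string]string
}

type RequestContext struct {
	Model               string // model name the user sent in the request body
	ResolvedTargetModel string // final model after traffic-split resolution
	Request             *Request
}

type LLMRequest struct {
	TargetModel string
	Critical    bool
	Prompt      string
	Headers     map[string]string
}

// assembleLLMRequest mirrors the post-commit shape of HandleRequest: the
// body's "model" field is rewritten as soon as the target model resolves,
// so PostDispatch no longer needs to touch the body.
func assembleLLMRequest(reqCtx *RequestContext, prompt string, critical bool) (*LLMRequest, error) {
	if reqCtx.ResolvedTargetModel == "" {
		return nil, fmt.Errorf("error getting target model name for model %v", reqCtx.Model)
	}
	reqCtx.Request.Body["model"] = reqCtx.ResolvedTargetModel // update target model in the body

	return &LLMRequest{
		TargetModel: reqCtx.ResolvedTargetModel,
		Critical:    critical,
		Prompt:      prompt,
		Headers:     reqCtx.Request.Headers,
	}, nil
}

func main() {
	reqCtx := &RequestContext{
		Model:               "my-model",
		ResolvedTargetModel: "my-model-v2", // e.g. chosen by a traffic split
		Request:             &Request{Body: map[string]interface{}{"model": "my-model"}},
	}
	llmReq, err := assembleLLMRequest(reqCtx, "hello", true)
	fmt.Println(llmReq, reqCtx.Request.Body["model"], err)
}
```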

pkg/epp/scheduling/plugins/filter/filter_test.go

Lines changed: 1 addition & 2 deletions

```diff
@@ -203,8 +203,7 @@ func TestLoRASoftAffinityDistribution(t *testing.T) {
 
     // Create a test request and pods
     req := &types.LLMRequest{
-        Model:               testAffinityModel,
-        ResolvedTargetModel: testAffinityModel,
+        TargetModel: testAffinityModel,
     }
 
     // Test setup: One affinity pod and one available pod
```

pkg/epp/scheduling/plugins/filter/lora_affinity_filter.go

Lines changed: 2 additions & 2 deletions

```diff
@@ -59,8 +59,8 @@ func (f *LoraAffinityFilter) Filter(ctx *types.SchedulingContext, pods []types.P
 
     // Categorize pods based on affinity and availability
     for _, pod := range pods {
-        _, active := pod.GetMetrics().ActiveModels[ctx.Req.ResolvedTargetModel]
-        _, waiting := pod.GetMetrics().WaitingModels[ctx.Req.ResolvedTargetModel]
+        _, active := pod.GetMetrics().ActiveModels[ctx.Req.TargetModel]
+        _, waiting := pod.GetMetrics().WaitingModels[ctx.Req.TargetModel]
 
         if active || waiting {
             filtered_affinity = append(filtered_affinity, pod)
```
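The filter itself changes only its map key: it now looks up the single TargetModel. For readers unfamiliar with the filter, a self-contained sketch of the categorization step follows; `podMetrics` and the `available` bucket are assumptions (the diff shows only the affinity branch):

```go
package main

import "fmt"

// podMetrics is an illustrative stand-in for what pod.GetMetrics()
// returns in the real filter; only the two model maps matter here.
type podMetrics struct {
	Name          string
	ActiveModels  map[string]int // LoRA adapters currently loaded on the pod
	WaitingModels map[string]int // LoRA adapters queued to be loaded
}

// categorize mirrors the filter's core loop after this commit: a pod has
// LoRA affinity for a request iff the TargetModel is active or waiting on
// it; everything else is merely "available".
func categorize(pods []podMetrics, targetModel string) (affinity, available []podMetrics) {
	for _, pod := range pods {
		_, active := pod.ActiveModels[targetModel]
		_, waiting := pod.WaitingModels[targetModel]
		if active || waiting {
			affinity = append(affinity, pod)
		} else {
			available = append(available, pod)
		}
	}
	return affinity, available
}

func main() {
	pods := []podMetrics{
		{Name: "pod1", ActiveModels: map[string]int{"adapter-a": 1}},
		{Name: "pod2", WaitingModels: map[string]int{"adapter-a": 1}},
		{Name: "pod3"},
	}
	affinity, available := categorize(pods, "adapter-a")
	fmt.Println(len(affinity), len(available)) // 2 1
}
```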

pkg/epp/scheduling/scheduler_test.go

Lines changed: 9 additions & 13 deletions

```diff
@@ -40,19 +40,17 @@ func TestSchedule(t *testing.T) {
         {
             name: "no pods in datastore",
             req: &types.LLMRequest{
-                Model:               "any-model",
-                ResolvedTargetModel: "any-model",
-                Critical:            true,
+                TargetModel: "any-model",
+                Critical:    true,
             },
             input: []*backendmetrics.FakePodMetrics{},
             err:   true,
         },
         {
             name: "critical request",
             req: &types.LLMRequest{
-                Model:               "critical",
-                ResolvedTargetModel: "critical",
-                Critical:            true,
+                TargetModel: "critical",
+                Critical:    true,
             },
             // pod2 will be picked because it has relatively low queue size, with the requested
             // model being active, and has low KV cache.
@@ -114,9 +112,8 @@ func TestSchedule(t *testing.T) {
         {
             name: "sheddable request, accepted",
             req: &types.LLMRequest{
-                Model:               "sheddable",
-                ResolvedTargetModel: "sheddable",
-                Critical:            false,
+                TargetModel: "sheddable",
+                Critical:    false,
             },
             // pod1 will be picked because it has capacity for the sheddable request.
             input: []*backendmetrics.FakePodMetrics{
@@ -177,9 +174,8 @@ func TestSchedule(t *testing.T) {
         {
             name: "sheddable request, dropped",
             req: &types.LLMRequest{
-                Model:               "sheddable",
-                ResolvedTargetModel: "sheddable",
-                Critical:            false,
+                TargetModel: "sheddable",
+                Critical:    false,
             },
             // All pods have higher KV cache than the threshold, so the sheddable request will be
             // dropped.
@@ -356,7 +352,7 @@ func TestSchedulePlugins(t *testing.T) {
             // Initialize the scheduler
             scheduler := NewSchedulerWithConfig(&fakeDataStore{pods: test.input}, &test.config)
 
-            req := &types.LLMRequest{Model: "test-model"}
+            req := &types.LLMRequest{TargetModel: "test-model"}
             got, err := scheduler.Schedule(context.Background(), req)
 
             // Validate error state
```

pkg/epp/scheduling/types/types.go

Lines changed: 3 additions & 6 deletions

```diff
@@ -28,10 +28,8 @@ import (
 
 // LLMRequest is a structured representation of the fields we parse out of the LLMRequest body.
 type LLMRequest struct {
-    // Model is the name of the model that the user specified in the request body.
-    Model string
-    // ResolvedTargetModel is the final target model after traffic split.
-    ResolvedTargetModel string
+    // TargetModel is the final target model after traffic split.
+    TargetModel string
     // Critical is a boolean that specifies if a request is critical or not.
     Critical bool
     // Prompt is the prompt that was sent in the request body.
@@ -41,8 +39,7 @@ type LLMRequest struct {
 }
 
 func (r *LLMRequest) String() string {
-    return fmt.Sprintf("Model: %s, ResolvedTargetModel: %s, Critical: %t, PromptLength: %d, Headers: %v",
-        r.Model, r.ResolvedTargetModel, r.Critical, len(r.Prompt), r.Headers)
+    return fmt.Sprintf("TargetModel: %s, Critical: %t, PromptLength: %d, Headers: %v", r.TargetModel, r.Critical, len(r.Prompt), r.Headers)
 }
 
 type Pod interface {
```
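For reference, the full type after this commit, reconstructed from the hunks above; the Headers declaration sits outside the visible diff context, so its exact type is an assumption:

```go
package types

import "fmt"

// LLMRequest is a structured representation of the fields we parse out of
// the LLMRequest body. Reconstructed from the diff; the Headers field is
// not shown in the hunks, so map[string]string is assumed.
type LLMRequest struct {
	// TargetModel is the final target model after traffic split.
	TargetModel string
	// Critical is a boolean that specifies if a request is critical or not.
	Critical bool
	// Prompt is the prompt that was sent in the request body.
	Prompt string
	// Headers holds the request headers (assumed type).
	Headers map[string]string
}

func (r *LLMRequest) String() string {
	return fmt.Sprintf("TargetModel: %s, Critical: %t, PromptLength: %d, Headers: %v",
		r.TargetModel, r.Critical, len(r.Prompt), r.Headers)
}
```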
