Commit 075d458

kaushikmitr authored and BenjaminBraunDev committed
emit predicted and actual ttft tpot in body
1 parent 506083f commit 075d458

7 files changed: +162 -78 lines changed


config/manifests/inferencepool-resources.yaml

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@ data:
   LATENCY_TPOT_MODEL_PATH: "/models/tpot.joblib"
   LATENCY_TTFT_SCALER_PATH: "/models/ttft_scaler.joblib"
   LATENCY_TPOT_SCALER_PATH: "/models/tpot_scaler.joblib"
+  LATENCY_MAX_TRAINING_DATA_SIZE_PER_BUCKET: "5000"
 
 ---
 apiVersion: inference.networking.k8s.io/v1
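The new key appears to cap how many training samples the latency predictor retains per bucket. As a rough, standalone sketch of what such a per-bucket cap implies (the bucket type, the env-var parsing, and the default below are assumptions for illustration, not code from this repository):

// Illustrative only: a bounded per-bucket training buffer of the kind
// LATENCY_MAX_TRAINING_DATA_SIZE_PER_BUCKET suggests.
package main

import (
	"fmt"
	"os"
	"strconv"
)

type sampleBucket struct {
	max     int
	samples []float64
}

func (b *sampleBucket) add(v float64) {
	b.samples = append(b.samples, v)
	if len(b.samples) > b.max {
		// Drop the oldest samples so the bucket never exceeds its cap.
		b.samples = b.samples[len(b.samples)-b.max:]
	}
}

func maxPerBucket() int {
	if v, err := strconv.Atoi(os.Getenv("LATENCY_MAX_TRAINING_DATA_SIZE_PER_BUCKET")); err == nil && v > 0 {
		return v
	}
	return 5000 // default matching the manifest value
}

func main() {
	b := &sampleBucket{max: maxPerBucket()}
	for i := 0; i < 6000; i++ {
		b.add(float64(i))
	}
	fmt.Println("retained samples:", len(b.samples)) // capped at the configured size
}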

pkg/epp/handlers/request.go

Lines changed: 2 additions & 0 deletions
@@ -108,7 +108,9 @@ func (s *StreamingServer) generateRequestHeaderResponse(reqCtx *RequestContext)
           SetHeaders: s.generateHeaders(reqCtx),
         },
       },
+
     },
+
   },
   DynamicMetadata: s.generateMetadata(reqCtx.TargetEndpoint),
 }

pkg/epp/handlers/response.go

Lines changed: 91 additions & 20 deletions
@@ -22,7 +22,10 @@ import (
   "strings"
 
   configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
+  filterPb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/ext_proc/v3"
   extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
+  "github.com/go-logr/logr"
+
   "sigs.k8s.io/controller-runtime/pkg/log"
 
   "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
@@ -60,7 +63,7 @@ func (s *StreamingServer) HandleResponseBody(ctx context.Context, reqCtx *Reques
   // will add the processing for streaming case.
   reqCtx.ResponseComplete = true
 
-  reqCtx.respBodyResp = generateResponseBodyResponses(responseBytes, true)
+  reqCtx.respBodyResp = generateResponseBodyResponses(responseBytes, true, reqCtx, logger)
   return reqCtx, nil
 }
 
@@ -75,12 +78,11 @@ func (s *StreamingServer) HandleResponseBodyModelStreaming(ctx context.Context,
   s.director.HandleResponseBodyChunk(ctx, reqCtx)
 }
 
-
 // The function is to handle streaming response if the modelServer is streaming.
 func (s *StreamingServer) HandleResponseTrailers(
   ctx context.Context,
   reqCtx *RequestContext,
-) (*RequestContext, error) {
+) (*RequestContext, error) {
 
   return s.director.HandleResponseTrailers(ctx, reqCtx)
 }
@@ -110,6 +112,9 @@ func (s *StreamingServer) generateResponseHeaderResponse(reqCtx *RequestContext)
         },
       },
     },
+    ModeOverride: &filterPb.ProcessingMode{
+      ResponseTrailerMode: filterPb.ProcessingMode_SEND,
+    },
   }
 }
 
@@ -118,29 +123,95 @@ func (s *StreamingServer) generateResponseTrailerResponse(reqCtx *RequestContext
   return &extProcPb.ProcessingResponse{
     Response: &extProcPb.ProcessingResponse_ResponseTrailers{
       ResponseTrailers: &extProcPb.TrailersResponse{
-        HeaderMutation: &extProcPb.HeaderMutation{
-          // Correct field or remove if unnecessary
-          SetHeaders: s.generateResponseTrailers(reqCtx),
-        },
+        HeaderMutation: &extProcPb.HeaderMutation{
+          // Correct field or remove if unnecessary
+          SetHeaders: s.generateResponseTrailers(reqCtx),
         },
       },
-  }
+    },
   }
+}
+
+func generateResponseBodyResponses(
+  responseBodyBytes []byte,
+  setEoS bool,
+  reqCtx *RequestContext,
+  logger logr.Logger,
+) []*extProcPb.ProcessingResponse {
+  if reqCtx != nil && reqCtx.ModelServerStreaming {
+
+    raw := string(responseBodyBytes)
+    events := strings.Split(raw, "\n\n")
 
-func generateResponseBodyResponses(responseBodyBytes []byte, setEoS bool) []*extProcPb.ProcessingResponse {
-  commonResponses := buildCommonResponses(responseBodyBytes, bodyByteLimit, setEoS)
-  responses := []*extProcPb.ProcessingResponse{}
-  for _, commonResp := range commonResponses {
-    resp := &extProcPb.ProcessingResponse{
-      Response: &extProcPb.ProcessingResponse_ResponseBody{
-        ResponseBody: &extProcPb.BodyResponse{
-          Response: commonResp,
+    var rebuilt strings.Builder
+    for _, ev := range events {
+      if !strings.HasPrefix(ev, "data: ") {
+        continue
+      }
+      payload := strings.TrimPrefix(ev, "data: ")
+      if payload == "[DONE]" {
+        rebuilt.WriteString("data: [DONE]\n\n")
+        continue
+      }
+
+      // Try to unmarshal only the JSON
+      var obj map[string]interface{}
+      if err := json.Unmarshal([]byte(payload), &obj); err != nil {
+        logger.Error(err, "failed to unmarshal SSE payload", "payload", payload)
+      } else {
+        if usage, ok := obj["usage"].(map[string]interface{}); ok && usage != nil {
+          usage["ttft_ms"] = reqCtx.TTFT
+          usage["predicted_ttft_ms"] = reqCtx.PredictedTTFT
+          usage["tpot_observations_ms"] = reqCtx.TPOTObservations
+          usage["predicted_tpot_observations_ms"] = reqCtx.PredictedTPOTObservations
+          usage["avg_tpot_ms"] = reqCtx.AvgTPOT
+          usage["avg_predicted_tpot_ms"] = reqCtx.AvgPredictedTPOT
+        }
+        if mod, err := json.Marshal(obj); err != nil {
+          logger.Error(err, "failed to re-marshal modified JSON", "obj", obj)
+        } else {
+          payload = string(mod)
+        }
+      }
+
+      // Re-attach SSE prefix
+      rebuilt.WriteString("data: ")
+      rebuilt.WriteString(payload)
+      rebuilt.WriteString("\n\n")
+    }
+
+    // Feed into your existing chunker
+    modified := []byte(rebuilt.String())
+    commonResponses := buildCommonResponses(modified, bodyByteLimit, setEoS)
+
+    // Wrap as ProcessingResponses
+    out := make([]*extProcPb.ProcessingResponse, 0, len(commonResponses))
+    for _, cr := range commonResponses {
+      out = append(out, &extProcPb.ProcessingResponse{
+        Response: &extProcPb.ProcessingResponse_ResponseBody{
+          ResponseBody: &extProcPb.BodyResponse{
+            Response: cr,
+          },
         },
-      },
+      })
     }
-    responses = append(responses, resp)
+    return out
+  } else {
+    commonResponses := buildCommonResponses(responseBodyBytes, bodyByteLimit, setEoS)
+    responses := []*extProcPb.ProcessingResponse{}
+    for _, commonResp := range commonResponses {
+      resp := &extProcPb.ProcessingResponse{
+        Response: &extProcPb.ProcessingResponse_ResponseBody{
+          ResponseBody: &extProcPb.BodyResponse{
+            Response: commonResp,
+          },
+        },
+      }
+      responses = append(responses, resp)
+    }
+    return responses
   }
-  return responses
+
 }
 
 func (s *StreamingServer) generateResponseHeaders(reqCtx *RequestContext) []*configPb.HeaderValueOption {
@@ -180,7 +251,7 @@ func (s *StreamingServer) generateResponseTrailers(reqCtx *RequestContext) []*co
   }
 
   // include all headers
-  for key, value := range reqCtx.Response.Trailers{
+  for key, value := range reqCtx.Response.Trailers {
     trailers = append(trailers, &configPb.HeaderValueOption{
       Header: &configPb.HeaderValue{
         Key: key,
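The streaming branch added above rewrites each SSE "data:" event so that the usage object in the final chunk carries the observed and predicted TTFT/TPOT values. A minimal standalone sketch of that rewrite pattern, using the field names from the diff (the sample event, function name, and latency numbers are invented for illustration, not the repository's code):

package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// injectLatency mimics the pattern in the streaming branch above: split the
// body on blank lines, patch the JSON after "data: ", and rebuild the stream.
func injectLatency(body string, ttft, avgTPOT float64) string {
	var rebuilt strings.Builder
	for _, ev := range strings.Split(body, "\n\n") {
		if !strings.HasPrefix(ev, "data: ") {
			continue
		}
		payload := strings.TrimPrefix(ev, "data: ")
		if payload != "[DONE]" {
			var obj map[string]interface{}
			if err := json.Unmarshal([]byte(payload), &obj); err == nil {
				if usage, ok := obj["usage"].(map[string]interface{}); ok && usage != nil {
					usage["ttft_ms"] = ttft
					usage["avg_tpot_ms"] = avgTPOT
				}
				if mod, err := json.Marshal(obj); err == nil {
					payload = string(mod)
				}
			}
		}
		rebuilt.WriteString("data: " + payload + "\n\n")
	}
	return rebuilt.String()
}

func main() {
	// Hypothetical final chunk of an OpenAI-style stream; values are made up.
	chunk := `data: {"choices":[],"usage":{"completion_tokens":42}}` + "\n\n" + "data: [DONE]\n\n"
	fmt.Print(injectLatency(chunk, 123.4, 9.8))
}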

pkg/epp/handlers/server.go

Lines changed: 27 additions & 26 deletions
@@ -106,11 +106,13 @@ type RequestContext struct {
   RequestState StreamRequestState
   ModelServerStreaming bool
 
-  TTFT float64
-  PredictedTTFT float64
-  PredictedTPOTObservations []float64
+  TTFT float64
+  PredictedTTFT float64
 
-  TPOTObservations []float64
+  PredictedTPOTObservations []float64
+  TPOTObservations []float64
+  AvgTPOT float64
+  AvgPredictedTPOT float64
 
   TokenSampler *requtil.TokenSampler
 
@@ -298,18 +300,21 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
           metrics.RecordResponseSizes(reqCtx.IncomingModelName, reqCtx.TargetModelName, reqCtx.ResponseSize)
 
           if s.director.IsPredictorAvailable() {
-            var sumActual, sumPred float64
-            for _, actual := range reqCtx.TPOTObservations {
-              sumActual += actual
+            // var sumActual, sumPred float64
+            // for _, actual := range reqCtx.TPOTObservations {
+            // sumActual += actual
 
-            }
-            for _, prediction := range reqCtx.PredictedTPOTObservations {
-              sumPred += prediction
+            // }
+            // for _, prediction := range reqCtx.PredictedTPOTObservations {
+            // sumPred += prediction
 
-            }
+            // }
 
-            avgActual := sumActual / float64(len(reqCtx.TPOTObservations))
-            avgPred := sumPred / float64(len(reqCtx.PredictedTPOTObservations))
+            // avgActual := sumActual / float64(len(reqCtx.TPOTObservations))
+            // avgPred := sumPred / float64(len(reqCtx.PredictedTPOTObservations))
+
+            // reqCtx.AvgTPOT = avgActual
+            // reqCtx.AvgPredictedTPOT = avgPred
 
             // Compute MAPE for TTFT
             mapeTTFT := 0.0
@@ -324,19 +329,19 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
             }
 
             mapeTPOT := 0.0
-            if avgActual > 0 {
-              mapeTPOT = math.Abs((avgActual-avgPred)/avgActual) * 100
-              logger.V(logutil.DEBUG).Info("Averages calculated", "avgActualTPOT", avgActual, "avgPredictedTPOT", avgPred)
+            if reqCtx.AvgTPOT > 0 {
+              mapeTPOT = math.Abs((reqCtx.AvgTPOT-reqCtx.AvgPredictedTPOT)/reqCtx.AvgTPOT) * 100
+              logger.V(logutil.DEBUG).Info("Averages calculated", "avgActualTPOT", reqCtx.AvgTPOT, "avgPredictedTPOT", reqCtx.AvgPredictedTPOT)
               logger.V(logutil.DEBUG).Info("MAPE TPOT computed", "mapeTPOT%", mapeTPOT)
-              metrics.RecordRequestTPOT(ctx, reqCtx.Model, reqCtx.ResolvedTargetModel, avgActual/1000)
-              metrics.RecordRequestPredictedTPOT(ctx, reqCtx.Model, reqCtx.ResolvedTargetModel, avgPred/1000)
+              metrics.RecordRequestTPOT(ctx, reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.AvgTPOT/1000)
+              metrics.RecordRequestPredictedTPOT(ctx, reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.AvgPredictedTPOT/1000)
               metrics.RecordRequestTPOTPredictionMape(ctx, reqCtx.Model, reqCtx.ResolvedTargetModel, mapeTPOT)
             }
           }
 
         }
 
-        reqCtx.respBodyResp = generateResponseBodyResponses(v.ResponseBody.Body, v.ResponseBody.EndOfStream)
+        reqCtx.respBodyResp = generateResponseBodyResponses(v.ResponseBody.Body, v.ResponseBody.EndOfStream, reqCtx, logger)
       } else {
         body = append(body, v.ResponseBody.Body...)
 
@@ -349,12 +354,8 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
         var responseErr error
         responseErr = json.Unmarshal(body, &responseBody)
         if responseErr != nil {
-          if logger.V(logutil.DEBUG).Enabled() {
-            logger.V(logutil.DEBUG).Error(responseErr, "Error unmarshalling request body", "body", string(body))
-          } else {
-            logger.V(logutil.DEFAULT).Error(responseErr, "Error unmarshalling request body", "body", string(body))
-          }
-          reqCtx.respBodyResp = generateResponseBodyResponses(body, true)
+          logger.V(logutil.DEFAULT).Error(responseErr, "Error unmarshaling request body", "body", string(body))
+          reqCtx.respBodyResp = generateResponseBodyResponses(body, true, reqCtx, logger)
           break
         }
 
@@ -375,7 +376,7 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
         }
       }
     case *extProcPb.ProcessingRequest_ResponseTrailers:
-      logger.V(logutil.DEBUG).Info("Processing response trailers", "trailers", v.ResponseTrailers.Trailers)
+      logger.V(logutil.DEFAULT).Info("Processing response trailers", "trailers", v.ResponseTrailers.Trailers)
       if reqCtx.ModelServerStreaming {
 
         var trailerErr error
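The metrics block above reports MAPE (mean absolute percentage error) between the averaged observed TTFT/TPOT and their predicted values, now reading AvgTPOT/AvgPredictedTPOT from RequestContext instead of recomputing the sums inline. A self-contained sketch of the same arithmetic (the helper names below are illustrative, not the repository's API):

package main

import (
	"fmt"
	"math"
)

// mape returns the mean absolute percentage error between an actual and a
// predicted value, matching the formula used for mapeTTFT/mapeTPOT above.
func mape(actual, predicted float64) float64 {
	if actual <= 0 {
		return 0
	}
	return math.Abs((actual-predicted)/actual) * 100
}

// avg is a plain mean over per-token observations, as collected in RequestContext.
func avg(xs []float64) float64 {
	if len(xs) == 0 {
		return 0
	}
	sum := 0.0
	for _, x := range xs {
		sum += x
	}
	return sum / float64(len(xs))
}

func main() {
	// Invented observations in milliseconds.
	actualTPOT := avg([]float64{10, 12, 11, 13})
	predictedTPOT := avg([]float64{9, 11, 12, 12})
	fmt.Printf("avg actual %.2fms, avg predicted %.2fms, MAPE %.1f%%\n",
		actualTPOT, predictedTPOT, mape(actualTPOT, predictedTPOT))
}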

pkg/epp/latencypredictorasync/latencypredictor_async_test.go

Lines changed: 6 additions & 13 deletions
@@ -281,7 +281,7 @@ func testPredictionPerformance(t *testing.T, ctx context.Context, predictor *Pre
 
   // Test multiple predictions and measure time
   const numTests = 10
-  const maxDurationMs = 500
+  const avgDurationMs = 250
 
   var totalDuration time.Duration
   var maxSingleDuration time.Duration
@@ -314,10 +314,6 @@ func testPredictionPerformance(t *testing.T, ctx context.Context, predictor *Pre
     t.Logf("Prediction %d: %.2fms - TTFT: %.1fms, TPOT: %.1fms",
       i+1, durationMs, response.TTFT, response.TPOT)
 
-    // Check if this prediction exceeded the target
-    if durationMs > maxDurationMs {
-      t.Errorf("Prediction %d took %.2fms, exceeded target of %dms", i+1, durationMs, maxDurationMs)
-    }
   }
 
   // Calculate statistics
@@ -330,13 +326,13 @@ func testPredictionPerformance(t *testing.T, ctx context.Context, predictor *Pre
   t.Logf(" Average: %.2fms", avgMs)
   t.Logf(" Minimum: %.2fms", minMs)
   t.Logf(" Maximum: %.2fms", maxMs)
-  t.Logf(" Target: < %dms", maxDurationMs)
+  t.Logf(" Target: < %dms", avgDurationMs)
 
   // Overall performance check
-  if avgMs > maxDurationMs {
-    t.Errorf("Average prediction time %.2fms exceeded target of %dms", avgMs, maxDurationMs)
+  if avgMs > avgDurationMs {
+    t.Errorf("Average prediction time %.2fms exceeded target of %dms", avgMs, avgDurationMs)
   } else {
-    t.Logf("✅ Performance target met: avg %.2fms < %dms", avgMs, maxDurationMs)
+    t.Logf("✅ Performance target met: avg %.2fms < %dms", avgMs, avgDurationMs)
   }
 
   // Check for consistency (max shouldn't be too much higher than average)
@@ -417,7 +413,7 @@ func testHTTPOnlyPerformance(t *testing.T, ctx context.Context) {
 
   // Performance test
   const numTests = 15
-  const targetMs = 500
+  const targetMs = 250
 
   var durations []time.Duration
   var successful int
@@ -441,9 +437,6 @@ func testHTTPOnlyPerformance(t *testing.T, ctx context.Context) {
     durationMs := float64(duration.Nanoseconds()) / 1e6
 
     status := "✅"
-    if durationMs > targetMs {
-      status = "❌"
-    }
 
     t.Logf("%s Test %d: %.1fms (TTFT: %.0fms, TPOT: %.0fms)",
       status, i+1, durationMs, response.TTFT, response.TPOT)
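The test changes drop the per-call latency assertions and keep only the check on the average across the test runs, against the tighter 250 ms target. A standalone sketch of that style of check (the helper below is illustrative and not the repository's test code):

package main

import (
	"fmt"
	"time"
)

// checkAverageLatency runs fn numTests times and reports whether the average
// duration stays under target -- the style of check the updated test keeps,
// instead of failing on any single slow call.
func checkAverageLatency(numTests int, target time.Duration, fn func()) (time.Duration, bool) {
	var total time.Duration
	for i := 0; i < numTests; i++ {
		start := time.Now()
		fn()
		total += time.Since(start)
	}
	avg := total / time.Duration(numTests)
	return avg, avg <= target
}

func main() {
	avg, ok := checkAverageLatency(10, 250*time.Millisecond, func() {
		time.Sleep(5 * time.Millisecond) // stand-in for a prediction call
	})
	fmt.Printf("average %v, within target: %v\n", avg, ok)
}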
