@@ -22,7 +22,10 @@ import (
22
22
"strings"
23
23
24
24
configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
25
+ filterPb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/ext_proc/v3"
25
26
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
27
+ "github.com/go-logr/logr"
28
+
26
29
"sigs.k8s.io/controller-runtime/pkg/log"
27
30
28
31
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
@@ -60,7 +63,7 @@ func (s *StreamingServer) HandleResponseBody(ctx context.Context, reqCtx *Reques
60
63
// will add the processing for streaming case.
61
64
reqCtx .ResponseComplete = true
62
65
63
- reqCtx .respBodyResp = generateResponseBodyResponses (responseBytes , true )
66
+ reqCtx .respBodyResp = generateResponseBodyResponses (responseBytes , true , reqCtx , logger )
64
67
return reqCtx , nil
65
68
}
66
69
@@ -75,12 +78,11 @@ func (s *StreamingServer) HandleResponseBodyModelStreaming(ctx context.Context,
75
78
s .director .HandleResponseBodyChunk (ctx , reqCtx )
76
79
}
77
80
78
-
79
81
// The function is to handle streaming response if the modelServer is streaming.
80
82
func (s * StreamingServer ) HandleResponseTrailers (
81
83
ctx context.Context ,
82
84
reqCtx * RequestContext ,
83
- ) (* RequestContext , error ) {
85
+ ) (* RequestContext , error ) {
84
86
85
87
return s .director .HandleResponseTrailers (ctx , reqCtx )
86
88
}
@@ -110,6 +112,9 @@ func (s *StreamingServer) generateResponseHeaderResponse(reqCtx *RequestContext)
110
112
},
111
113
},
112
114
},
115
+ ModeOverride : & filterPb.ProcessingMode {
116
+ ResponseTrailerMode : filterPb .ProcessingMode_SEND ,
117
+ },
113
118
}
114
119
}
115
120
@@ -118,29 +123,95 @@ func (s *StreamingServer) generateResponseTrailerResponse(reqCtx *RequestContext
118
123
return & extProcPb.ProcessingResponse {
119
124
Response : & extProcPb.ProcessingResponse_ResponseTrailers {
120
125
ResponseTrailers : & extProcPb.TrailersResponse {
121
- HeaderMutation : & extProcPb.HeaderMutation {
122
- // Correct field or remove if unnecessary
123
- SetHeaders : s .generateResponseTrailers (reqCtx ),
124
- },
126
+ HeaderMutation : & extProcPb.HeaderMutation {
127
+ // Correct field or remove if unnecessary
128
+ SetHeaders : s .generateResponseTrailers (reqCtx ),
125
129
},
126
130
},
127
- }
131
+ },
128
132
}
133
+ }
134
+
135
+ func generateResponseBodyResponses (
136
+ responseBodyBytes []byte ,
137
+ setEoS bool ,
138
+ reqCtx * RequestContext ,
139
+ logger logr.Logger ,
140
+ ) []* extProcPb.ProcessingResponse {
141
+ if reqCtx != nil && reqCtx .ModelServerStreaming {
142
+
143
+ raw := string (responseBodyBytes )
144
+ events := strings .Split (raw , "\n \n " )
129
145
130
- func generateResponseBodyResponses (responseBodyBytes []byte , setEoS bool ) []* extProcPb.ProcessingResponse {
131
- commonResponses := buildCommonResponses (responseBodyBytes , bodyByteLimit , setEoS )
132
- responses := []* extProcPb.ProcessingResponse {}
133
- for _ , commonResp := range commonResponses {
134
- resp := & extProcPb.ProcessingResponse {
135
- Response : & extProcPb.ProcessingResponse_ResponseBody {
136
- ResponseBody : & extProcPb.BodyResponse {
137
- Response : commonResp ,
146
+ var rebuilt strings.Builder
147
+ for _ , ev := range events {
148
+ if ! strings .HasPrefix (ev , "data: " ) {
149
+ continue
150
+ }
151
+ payload := strings .TrimPrefix (ev , "data: " )
152
+ if payload == "[DONE]" {
153
+ rebuilt .WriteString ("data: [DONE]\n \n " )
154
+ continue
155
+ }
156
+
157
+ // Try to unmarshal only the JSON
158
+ var obj map [string ]interface {}
159
+ if err := json .Unmarshal ([]byte (payload ), & obj ); err != nil {
160
+ logger .Error (err , "failed to unmarshal SSE payload" , "payload" , payload )
161
+ } else {
162
+ if usage , ok := obj ["usage" ].(map [string ]interface {}); ok && usage != nil {
163
+ usage ["ttft_ms" ] = reqCtx .TTFT
164
+ usage ["predicted_ttft_ms" ] = reqCtx .PredictedTTFT
165
+ usage ["tpot_observations_ms" ] = reqCtx .TPOTObservations
166
+ usage ["predicted_tpot_observations_ms" ] = reqCtx .PredictedTPOTObservations
167
+ usage ["avg_tpot_ms" ] = reqCtx .AvgTPOT
168
+ usage ["avg_predicted_tpot_ms" ] = reqCtx .AvgPredictedTPOT
169
+ }
170
+ if mod , err := json .Marshal (obj ); err != nil {
171
+ logger .Error (err , "failed to re-marshal modified JSON" , "obj" , obj )
172
+ } else {
173
+ payload = string (mod )
174
+ }
175
+ }
176
+
177
+ // Re-attach SSE prefix
178
+ rebuilt .WriteString ("data: " )
179
+ rebuilt .WriteString (payload )
180
+ rebuilt .WriteString ("\n \n " )
181
+ }
182
+
183
+ // Feed into your existing chunker
184
+ modified := []byte (rebuilt .String ())
185
+ commonResponses := buildCommonResponses (modified , bodyByteLimit , setEoS )
186
+
187
+ // Wrap as ProcessingResponses
188
+ out := make ([]* extProcPb.ProcessingResponse , 0 , len (commonResponses ))
189
+ for _ , cr := range commonResponses {
190
+ out = append (out , & extProcPb.ProcessingResponse {
191
+ Response : & extProcPb.ProcessingResponse_ResponseBody {
192
+ ResponseBody : & extProcPb.BodyResponse {
193
+ Response : cr ,
194
+ },
138
195
},
139
- },
196
+ })
140
197
}
141
- responses = append (responses , resp )
198
+ return out
199
+ } else {
200
+ commonResponses := buildCommonResponses (responseBodyBytes , bodyByteLimit , setEoS )
201
+ responses := []* extProcPb.ProcessingResponse {}
202
+ for _ , commonResp := range commonResponses {
203
+ resp := & extProcPb.ProcessingResponse {
204
+ Response : & extProcPb.ProcessingResponse_ResponseBody {
205
+ ResponseBody : & extProcPb.BodyResponse {
206
+ Response : commonResp ,
207
+ },
208
+ },
209
+ }
210
+ responses = append (responses , resp )
211
+ }
212
+ return responses
142
213
}
143
- return responses
214
+
144
215
}
145
216
146
217
func (s * StreamingServer ) generateResponseHeaders (reqCtx * RequestContext ) []* configPb.HeaderValueOption {
@@ -180,7 +251,7 @@ func (s *StreamingServer) generateResponseTrailers(reqCtx *RequestContext) []*co
180
251
}
181
252
182
253
// include all headers
183
- for key , value := range reqCtx .Response .Trailers {
254
+ for key , value := range reqCtx .Response .Trailers {
184
255
trailers = append (trailers , & configPb.HeaderValueOption {
185
256
Header : & configPb.HeaderValue {
186
257
Key : key ,
0 commit comments