Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions pkg/body-based-routing/handlers/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,16 +114,16 @@ func (s *Server) processRequestBody(ctx context.Context, body *extProcPb.HttpBod

var requestBody map[string]interface{}
if s.streaming {
streamedBody.body = append(streamedBody.body, body.Body...)
// In the stream case, we can receive multiple request bodies.
if !body.EndOfStream {
streamedBody.body = append(streamedBody.body, body.Body...)
return nil, nil
} else {
if body.EndOfStream {
loggerVerbose.Info("Flushing stream buffer")
err := json.Unmarshal(streamedBody.body, &requestBody)
if err != nil {
logger.V(logutil.DEFAULT).Error(err, "Error unmarshaling request body")
}
} else {
return nil, nil
}
} else {
if err := json.Unmarshal(body.GetBody(), &requestBody); err != nil {
Expand Down
210 changes: 165 additions & 45 deletions test/integration/bbr/hermetic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,19 @@ package bbr

import (
"context"
"encoding/json"
"fmt"
"testing"
"time"

configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
"github.com/go-logr/logr"
"github.com/google/go-cmp/cmp"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/protobuf/testing/protocmp"
runserver "sigs.k8s.io/gateway-api-inference-extension/pkg/body-based-routing/server"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
integrationutils "sigs.k8s.io/gateway-api-inference-extension/test/integration"
)

var logger = logutil.NewTestLogger().V(logutil.VERBOSE)
Expand All @@ -46,7 +45,7 @@ func TestBodyBasedRouting(t *testing.T) {
}{
{
name: "success adding model parameter to header",
req: generateRequest(logger, "llama"),
req: integrationutils.GenerateRequest(logger, "test", "llama"),
wantHeaders: []*configPb.HeaderValueOption{
{
Header: &configPb.HeaderValue{
Expand All @@ -59,15 +58,15 @@ func TestBodyBasedRouting(t *testing.T) {
},
{
name: "no model parameter",
req: generateRequest(logger, ""),
req: integrationutils.GenerateRequest(logger, "test1", ""),
wantHeaders: []*configPb.HeaderValueOption{},
wantErr: false,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
client, cleanup := setUpHermeticServer()
client, cleanup := setUpHermeticServer(false)
t.Cleanup(cleanup)

want := &extProcPb.ProcessingResponse{}
Expand All @@ -88,7 +87,7 @@ func TestBodyBasedRouting(t *testing.T) {
}
}

res, err := sendRequest(t, client, test.req)
res, err := integrationutils.SendRequest(t, client, test.req)
if err != nil && !test.wantErr {
t.Errorf("Unexpected error, got: %v, want error: %v", err, test.wantErr)
}
Expand All @@ -99,12 +98,171 @@ func TestBodyBasedRouting(t *testing.T) {
}
}

func setUpHermeticServer() (client extProcPb.ExternalProcessor_ProcessClient, cleanup func()) {
// TestFullDuplexStreamed_BodyBasedRouting starts the hermetic ext-proc server
// with streaming enabled and, for each case, sends a sequence of processing
// requests and compares the responses received against the expected sequence.
func TestFullDuplexStreamed_BodyBasedRouting(t *testing.T) {
	// modelHeaderResponse builds the expected request-headers response that
	// sets X-Gateway-Model-Name to model and clears the route cache.
	modelHeaderResponse := func(model string) *extProcPb.ProcessingResponse {
		return &extProcPb.ProcessingResponse{
			Response: &extProcPb.ProcessingResponse_RequestHeaders{
				RequestHeaders: &extProcPb.HeadersResponse{
					Response: &extProcPb.CommonResponse{
						ClearRouteCache: true,
						HeaderMutation: &extProcPb.HeaderMutation{
							SetHeaders: []*configPb.HeaderValueOption{
								{
									Header: &configPb.HeaderValue{
										Key:      "X-Gateway-Model-Name",
										RawValue: []byte(model),
									},
								},
							},
						},
					},
				},
			},
		}
	}

	// finalBodyResponse builds the expected request-body response carrying
	// payload as a single streamed chunk marked end-of-stream.
	finalBodyResponse := func(payload string) *extProcPb.ProcessingResponse {
		return &extProcPb.ProcessingResponse{
			Response: &extProcPb.ProcessingResponse_RequestBody{
				RequestBody: &extProcPb.BodyResponse{
					Response: &extProcPb.CommonResponse{
						BodyMutation: &extProcPb.BodyMutation{
							Mutation: &extProcPb.BodyMutation_StreamedResponse{
								StreamedResponse: &extProcPb.StreamedBodyResponse{
									Body:        []byte(payload),
									EndOfStream: true,
								},
							},
						},
					},
				},
			},
		}
	}

	// bodyChunk builds a request-body message holding one chunk of the body.
	bodyChunk := func(data string, last bool) *extProcPb.ProcessingRequest {
		return &extProcPb.ProcessingRequest{
			Request: &extProcPb.ProcessingRequest_RequestBody{
				RequestBody: &extProcPb.HttpBody{Body: []byte(data), EndOfStream: last},
			},
		}
	}

	type streamCase struct {
		name          string
		reqs          []*extProcPb.ProcessingRequest
		wantResponses []*extProcPb.ProcessingResponse
		wantErr       bool
	}

	cases := []streamCase{
		{
			name: "success adding model parameter to header",
			reqs: integrationutils.GenerateStreamedRequestSet(logger, "test", "foo"),
			wantResponses: []*extProcPb.ProcessingResponse{
				modelHeaderResponse("foo"),
				finalBodyResponse("{\"max_tokens\":100,\"model\":\"foo\",\"prompt\":\"test\",\"temperature\":0}"),
			},
		},
		{
			// The JSON body is split in the middle of the model name, across
			// two body messages; only the second is marked end-of-stream.
			name: "success adding model parameter to header with multiple body chunks",
			reqs: []*extProcPb.ProcessingRequest{
				{
					Request: &extProcPb.ProcessingRequest_RequestHeaders{
						RequestHeaders: &extProcPb.HttpHeaders{
							Headers: &configPb.HeaderMap{
								Headers: []*configPb.HeaderValue{
									{
										Key:   "hi",
										Value: "mom",
									},
								},
							},
						},
					},
				},
				bodyChunk("{\"max_tokens\":100,\"model\":\"sql-lo", false),
				bodyChunk("ra-sheddable\",\"prompt\":\"test\",\"temperature\":0}", true),
			},
			wantResponses: []*extProcPb.ProcessingResponse{
				modelHeaderResponse("sql-lora-sheddable"),
				finalBodyResponse("{\"max_tokens\":100,\"model\":\"sql-lora-sheddable\",\"prompt\":\"test\",\"temperature\":0}"),
			},
		},
		{
			// Without a model in the body, the expected headers response
			// carries no mutation at all.
			name: "no model parameter",
			reqs: integrationutils.GenerateStreamedRequestSet(logger, "test", ""),
			wantResponses: []*extProcPb.ProcessingResponse{
				{
					Response: &extProcPb.ProcessingResponse_RequestHeaders{
						RequestHeaders: &extProcPb.HeadersResponse{},
					},
				},
				finalBodyResponse("{\"max_tokens\":100,\"prompt\":\"test\",\"temperature\":0}"),
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			client, cleanup := setUpHermeticServer(true)
			t.Cleanup(cleanup)

			got, err := integrationutils.StreamedRequest(t, client, tc.reqs, len(tc.wantResponses))
			if !tc.wantErr && err != nil {
				t.Errorf("Unexpected error, got: %v, want error: %v", err, tc.wantErr)
			}

			diff := cmp.Diff(tc.wantResponses, got, protocmp.Transform())
			if diff != "" {
				t.Errorf("Unexpected response, (-want +got): %v", diff)
			}
		})
	}
}

func setUpHermeticServer(streaming bool) (client extProcPb.ExternalProcessor_ProcessClient, cleanup func()) {
port := 9004

serverCtx, stopServer := context.WithCancel(context.Background())
serverRunner := runserver.NewDefaultExtProcServerRunner(port, false)
serverRunner.SecureServing = false
serverRunner.Streaming = streaming

go func() {
if err := serverRunner.AsRunnable(logger.WithName("ext-proc")).Start(serverCtx); err != nil {
Expand Down Expand Up @@ -133,41 +291,3 @@ func setUpHermeticServer() (client extProcPb.ExternalProcessor_ProcessClient, cl
time.Sleep(5 * time.Second)
}
}

// generateRequest builds a request-body ProcessingRequest whose body is a
// JSON-encoded completion request with fixed prompt, max_tokens and
// temperature values. The "model" key is included only when model is
// non-empty. A marshaling failure is reported through logutil.Fatal.
func generateRequest(logger logr.Logger, model string) *extProcPb.ProcessingRequest {
	j := map[string]interface{}{
		"prompt":      "test1",
		"max_tokens":  100,
		"temperature": 0,
	}
	if model != "" {
		j["model"] = model
	}

	llmReq, err := json.Marshal(j)
	if err != nil {
		// The failing operation is Marshal, so the message says "marshal".
		logutil.Fatal(logger, err, "Failed to marshal LLM request")
	}
	return &extProcPb.ProcessingRequest{
		Request: &extProcPb.ProcessingRequest_RequestBody{
			RequestBody: &extProcPb.HttpBody{Body: llmReq},
		},
	}
}

// sendRequest sends req on the ext-proc stream and then receives a single
// message from it. It returns the received response, or nil together with
// the error if either the send or the receive fails. Progress and failures
// are logged through t.
func sendRequest(t *testing.T, client extProcPb.ExternalProcessor_ProcessClient, req *extProcPb.ProcessingRequest) (*extProcPb.ProcessingResponse, error) {
	t.Logf("Sending request: %v", req)
	if err := client.Send(req); err != nil {
		t.Logf("Failed to send request %+v: %v", req, err)
		return nil, err
	}

	res, err := client.Recv()
	if err != nil {
		t.Logf("Failed to receive: %v", err)
		return nil, err
	}
	// What was received is the server's response, not a request.
	t.Logf("Received response %+v", res)
	return res, nil
}
8 changes: 5 additions & 3 deletions test/integration/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func SendRequest(t *testing.T, client extProcPb.ExternalProcessor_ProcessClient,
t.Logf("Failed to receive: %v", err)
return nil, err
}
t.Logf("Received request %+v", res)
t.Logf("Received response %+v", res)
return res, err
}

Expand Down Expand Up @@ -71,19 +71,21 @@ func StreamedRequest(t *testing.T, client extProcPb.ExternalProcessor_ProcessCli
t.Logf("Failed to receive: %v", err)
return nil, err
}
t.Logf("Received request %+v", res)
t.Logf("Received response %+v", res)
responses = append(responses, res)
}
return responses, nil
}

func GenerateRequest(logger logr.Logger, prompt, model string) *extProcPb.ProcessingRequest {
j := map[string]interface{}{
"model": model,
"prompt": prompt,
"max_tokens": 100,
"temperature": 0,
}
if model != "" {
j["model"] = model
}

llmReq, err := json.Marshal(j)
if err != nil {
Expand Down