Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions bridge_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2076,6 +2076,35 @@ func newMockServer(ctx context.Context, t *testing.T, files archiveFileMap, requ
defer r.Body.Close()
require.NoError(t, err)

// Validate request body based on endpoint.
var validationErr error
if strings.Contains(r.URL.Path, "/chat/completions") {
validationErr = validateOpenAIChatCompletionRequest(body)
} else if strings.Contains(r.URL.Path, "/responses") {
validationErr = validateOpenAIResponsesRequest(body)
} else if strings.Contains(r.URL.Path, "/messages") {
validationErr = validateAnthropicMessagesRequest(body)
}

// If validation failed, return error response
if validationErr != nil {
// Return HTTP error response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
errResp := map[string]any{
"error": map[string]any{
"message": fmt.Sprintf("Request #%d validation failed: %v", ms.callCount.Load(), validationErr),
"type": "invalid_request_error",
},
}
json.NewEncoder(w).Encode(errResp)

// Mark test as failed with detailed message
t.Errorf("Request #%d validation failed: %v\n\nRequest body:\n%s",
ms.callCount.Load(), validationErr, string(body))
return
}

type msg struct {
Stream bool `json:"stream"`
}
Expand Down Expand Up @@ -2135,6 +2164,72 @@ func newMockServer(ctx context.Context, t *testing.T, files archiveFileMap, requ
return ms
}

// validateOpenAIChatCompletionRequest validates that an OpenAI chat completion request
// has all required fields. Returns an error if validation fails.
// validateOpenAIChatCompletionRequest validates that an OpenAI chat completion request
// has all required fields. Returns an error if validation fails.
func validateOpenAIChatCompletionRequest(body []byte) error {
	var req openai.ChatCompletionNewParams
	if err := json.Unmarshal(body, &req); err != nil {
		return fmt.Errorf("request should unmarshal into ChatCompletionNewParams: %w", err)
	}

	// Gather every missing required field so a single error reports them all.
	var problems []string
	if req.Model == "" {
		problems = append(problems, "model field is required but empty")
	}
	if len(req.Messages) == 0 {
		problems = append(problems, "messages field is required but empty")
	}
	if len(problems) == 0 {
		return nil
	}
	return fmt.Errorf("validation failed: %s", strings.Join(problems, "; "))
}

// validateOpenAIResponsesRequest validates that an OpenAI responses request
// has all required fields. Returns an error if validation fails.
// validateOpenAIResponsesRequest validates that an OpenAI responses request
// has all required fields. Returns an error if validation fails.
func validateOpenAIResponsesRequest(body []byte) error {
	var reqBody map[string]any
	if err := json.Unmarshal(body, &reqBody); err != nil {
		return fmt.Errorf("request should be valid JSON: %w", err)
	}

	// Verify required fields for OpenAI responses.
	// Note: using a map here since there's no specific SDK type wired up for
	// the responses endpoint; a follow-up could decode into
	// responses.ResponseNewParams instead.
	// The type assertion also rejects a JSON null or non-string "model",
	// which a bare `model == ""` comparison on an `any` value would miss.
	model, ok := reqBody["model"].(string)
	if !ok || model == "" {
		return fmt.Errorf("model field is required but missing or empty")
	}
	return nil
}

// validateAnthropicMessagesRequest validates that an Anthropic messages request
// has all required fields. Returns an error if validation fails.
// validateAnthropicMessagesRequest validates that an Anthropic messages request
// has all required fields. Returns an error if validation fails.
func validateAnthropicMessagesRequest(body []byte) error {
	var req anthropic.MessageNewParams
	if err := json.Unmarshal(body, &req); err != nil {
		return fmt.Errorf("request should unmarshal into MessageNewParams: %w", err)
	}

	// Gather every missing required field so a single error reports them all.
	var problems []string
	if req.Model == "" {
		problems = append(problems, "model field is required but empty")
	}
	if len(req.Messages) == 0 {
		problems = append(problems, "messages field is required but empty")
	}
	if req.MaxTokens == 0 {
		problems = append(problems, "max_tokens field is required but zero")
	}
	if len(problems) == 0 {
		return nil
	}
	return fmt.Errorf("validation failed: %s", strings.Join(problems, "; "))
}

const mockToolName = "coder_list_workspaces"

// callAccumulator tracks all tool invocations by name and each instance's arguments.
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -94,4 +94,4 @@ require (
replace github.com/anthropics/anthropic-sdk-go v1.13.0 => github.com/dannykopping/anthropic-sdk-go v0.0.0-20251230111224-88a4315810bd

// https://github.com/openai/openai-go/pull/602
replace github.com/openai/openai-go/v3 => github.com/SasSwart/openai-go/v3 v3.0.0-20260202093810-72af3b857f95
replace github.com/openai/openai-go/v3 => github.com/SasSwart/openai-go/v3 v3.0.0-20260204134041-fb987b42a728
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ cloud.google.com/go/logging v1.8.1 h1:26skQWPeYhvIasWKm48+Eq7oUqdcdbwsCVwz5Ys0Fv
cloud.google.com/go/logging v1.8.1/go.mod h1:TJjR+SimHwuC8MZ9cjByQulAMgni+RkXeI3wwctHJEI=
cloud.google.com/go/longrunning v0.5.1 h1:Fr7TXftcqTudoyRJa113hyaqlGdiBQkp0Gq7tErFDWI=
cloud.google.com/go/longrunning v0.5.1/go.mod h1:spvimkwdz6SPWKEt/XBij79E9fiTkHSQl/fRUUQJYJc=
github.com/SasSwart/openai-go/v3 v3.0.0-20260202093810-72af3b857f95 h1:HVJp3FanNaeFAlwg0/lkdkSnwFemHnwwjXBM8KRj540=
github.com/SasSwart/openai-go/v3 v3.0.0-20260202093810-72af3b857f95/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo=
github.com/SasSwart/openai-go/v3 v3.0.0-20260204134041-fb987b42a728 h1:FOjd3xOH+arcrtz1e5P6WZ/VtRD5KQHHRg4kc4BZers=
github.com/SasSwart/openai-go/v3 v3.0.0-20260204134041-fb987b42a728/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo=
github.com/aws/aws-sdk-go-v2 v1.30.3 h1:jUeBtG0Ih+ZIFH0F4UkmL9w3cSpaMv9tYYDbzILP8dY=
github.com/aws/aws-sdk-go-v2 v1.30.3/go.mod h1:nIQjQVp5sfpQcTc9mPSr1B0PaWK5ByX9MOoDadSN4lc=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.3 h1:tW1/Rkad38LA15X4UQtjXZXNKsCgkshC3EbmcUmghTg=
Expand Down
12 changes: 11 additions & 1 deletion intercept/chatcompletions/streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,16 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re
opts = append(opts, intercept.ActorHeadersAsOpenAIOpts(actor)...)
}

// We take control of request body here and pass it to the SDK as a raw byte slice.
// This is because the SDK's serialization applies hidden request options that result in
// unexpected, breaking behaviour. See https://github.com/coder/aibridge/pull/164
body, err := json.Marshal(i.req.ChatCompletionNewParams)
if err != nil {
return fmt.Errorf("marshal request body: %w", err)
}
opts = append(opts, option.WithRequestBody("application/json", body))
opts = append(opts, option.WithJSONSet("stream", true))

stream = i.newStream(streamCtx, svc, opts)
processor := newStreamProcessor(streamCtx, i.logger.Named("stream-processor"), i.getInjectedToolByName)

Expand Down Expand Up @@ -380,7 +390,7 @@ func (i *StreamingInterception) newStream(ctx context.Context, svc openai.ChatCo
_, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...))
defer span.End()

return svc.NewStreaming(ctx, i.req.ChatCompletionNewParams, opts...)
return svc.NewStreaming(ctx, openai.ChatCompletionNewParams{}, opts...)
}

type streamProcessor struct {
Expand Down