Skip to content

Commit 4ad600d

Browse files
feat: openai integration added to http transport
1 parent ac54c86 commit 4ad600d

File tree

6 files changed

+782
-38
lines changed

6 files changed

+782
-38
lines changed
Lines changed: 39 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,68 @@
11
package genai
22

33
import (
4-
"encoding/json"
4+
"fmt"
5+
"strings"
56

6-
"github.com/fasthttp/router"
77
bifrost "github.com/maximhq/bifrost/core"
8-
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
8+
"github.com/maximhq/bifrost/core/schemas"
9+
"github.com/maximhq/bifrost/transports/bifrost-http/integrations"
910
"github.com/valyala/fasthttp"
1011
)
1112

1213
// GenAIRouter holds route registrations for genai endpoints.
1314
type GenAIRouter struct {
14-
client *bifrost.Bifrost
15+
*integrations.GenericRouter
1516
}
1617

1718
// NewGenAIRouter creates a new GenAIRouter with the given bifrost client.
1819
func NewGenAIRouter(client *bifrost.Bifrost) *GenAIRouter {
19-
return &GenAIRouter{client: client}
20-
}
20+
routes := []integrations.RouteConfig{
21+
{
22+
Path: "/genai/v1beta/models/{model}",
23+
Method: "POST",
24+
RequestType: &GeminiChatRequest{},
25+
RequestConverter: func(req interface{}) *schemas.BifrostRequest {
26+
if geminiReq, ok := req.(*GeminiChatRequest); ok {
27+
return geminiReq.ConvertToBifrostRequest()
28+
}
29+
return nil
30+
},
31+
ResponseFunc: func(resp *schemas.BifrostResponse) interface{} {
32+
return DeriveGenAIFromBifrostResponse(resp)
33+
},
34+
PreCallback: extractAndSetModelFromURL,
35+
},
36+
}
2137

22-
// RegisterRoutes registers all genai routes on the given router.
23-
func (g *GenAIRouter) RegisterRoutes(r *router.Router) {
24-
r.POST("/genai/v1beta/models/{model}", g.handleChatCompletion)
38+
return &GenAIRouter{
39+
GenericRouter: integrations.NewGenericRouter(client, routes),
40+
}
2541
}
2642

27-
// handleChatCompletion handles POST /genai/v1beta/models/{model}
28-
func (g *GenAIRouter) handleChatCompletion(ctx *fasthttp.RequestCtx) {
43+
// extractAndSetModelFromURL extracts model from URL and sets it in the request
44+
func extractAndSetModelFromURL(ctx *fasthttp.RequestCtx, req interface{}) error {
2945
model := ctx.UserValue("model")
3046
if model == nil {
31-
ctx.SetStatusCode(fasthttp.StatusBadRequest)
32-
ctx.SetBodyString("Model parameter is required")
33-
return
47+
return fmt.Errorf("model parameter is required")
3448
}
49+
3550
modelStr := model.(string)
36-
modelStr = modelStr[:len(modelStr)-len(":generateContent")]
51+
// Remove :generateContent suffix if present
52+
modelStr = strings.TrimSuffix(modelStr, ":generateContent")
53+
// Remove trailing colon if present
3754
if len(modelStr) > 0 && modelStr[len(modelStr)-1] == ':' {
3855
modelStr = modelStr[:len(modelStr)-1]
3956
}
4057

41-
var req GeminiChatRequest
42-
if err := json.Unmarshal(ctx.PostBody(), &req); err != nil {
43-
ctx.SetStatusCode(fasthttp.StatusBadRequest)
44-
json.NewEncoder(ctx).Encode(err)
45-
return
46-
}
47-
48-
bifrostReq := req.ConvertToBifrostRequest("google/" + modelStr)
49-
50-
bifrostCtx := lib.ConvertToBifrostContext(ctx)
58+
// Add google/ prefix for Bifrost
59+
processedModel := "google/" + modelStr
5160

52-
result, err := g.client.ChatCompletionRequest(*bifrostCtx, bifrostReq)
53-
if err != nil {
54-
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
55-
json.NewEncoder(ctx).Encode(err)
56-
return
61+
// Set the model in the request
62+
if geminiReq, ok := req.(*GeminiChatRequest); ok {
63+
geminiReq.Model = processedModel
64+
return nil
5765
}
5866

59-
genAIResponse := DeriveGenAIFromBifrostResponse(result)
60-
ctx.SetStatusCode(fasthttp.StatusOK)
61-
ctx.SetContentType("application/json")
62-
json.NewEncoder(ctx).Encode(genAIResponse)
67+
return fmt.Errorf("invalid request type for GenAI")
6368
}

transports/bifrost-http/integrations/genai/types.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
var fnTypePtr = bifrost.Ptr(string(schemas.ToolChoiceTypeFunction))
1313

1414
type GeminiChatRequest struct {
15+
Model string `json:"model,omitempty"` // Model field for explicit model specification
1516
Contents []genai_sdk.Content `json:"contents"`
1617
GenerationConfig genai_sdk.GenerationConfig `json:"generationConfig,omitempty"`
1718
SafetySettings []genai_sdk.SafetySetting `json:"safetySettings,omitempty"`
@@ -20,10 +21,10 @@ type GeminiChatRequest struct {
2021
Labels map[string]string `json:"labels,omitempty"`
2122
}
2223

23-
func (r *GeminiChatRequest) ConvertToBifrostRequest(modelStr string) *schemas.BifrostRequest {
24+
func (r *GeminiChatRequest) ConvertToBifrostRequest() *schemas.BifrostRequest {
2425
bifrostReq := &schemas.BifrostRequest{
2526
Provider: schemas.Vertex,
26-
Model: modelStr,
27+
Model: r.Model,
2728
Input: schemas.RequestInput{
2829
ChatCompletionInput: &[]schemas.BifrostMessage{},
2930
},
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package openai
2+
3+
import (
4+
bifrost "github.com/maximhq/bifrost/core"
5+
"github.com/maximhq/bifrost/core/schemas"
6+
"github.com/maximhq/bifrost/transports/bifrost-http/integrations"
7+
)
8+
9+
// OpenAIRouter holds route registrations for OpenAI endpoints.
10+
// It supports standard chat completions and image-enabled vision capabilities.
11+
type OpenAIRouter struct {
12+
*integrations.GenericRouter
13+
}
14+
15+
// NewOpenAIRouter creates a new OpenAIRouter with the given bifrost client.
16+
func NewOpenAIRouter(client *bifrost.Bifrost) *OpenAIRouter {
17+
routes := []integrations.RouteConfig{
18+
{
19+
Path: "/openai/v1/chat/completions",
20+
Method: "POST",
21+
RequestType: &OpenAIChatRequest{},
22+
RequestConverter: func(req interface{}) *schemas.BifrostRequest {
23+
if openaiReq, ok := req.(*OpenAIChatRequest); ok {
24+
return openaiReq.ConvertToBifrostRequest()
25+
}
26+
return nil
27+
},
28+
ResponseFunc: func(resp *schemas.BifrostResponse) interface{} {
29+
return DeriveOpenAIFromBifrostResponse(resp)
30+
},
31+
},
32+
}
33+
34+
return &OpenAIRouter{
35+
GenericRouter: integrations.NewGenericRouter(client, routes),
36+
}
37+
}

0 commit comments

Comments
 (0)