|
1 | 1 | package genai
|
2 | 2 |
|
3 | 3 | import (
|
4 |
| - "encoding/json" |
| 4 | + "fmt" |
| 5 | + "strings" |
5 | 6 |
|
6 |
| - "github.com/fasthttp/router" |
7 | 7 | bifrost "github.com/maximhq/bifrost/core"
|
8 |
| - "github.com/maximhq/bifrost/transports/bifrost-http/lib" |
| 8 | + "github.com/maximhq/bifrost/core/schemas" |
| 9 | + "github.com/maximhq/bifrost/transports/bifrost-http/integrations" |
9 | 10 | "github.com/valyala/fasthttp"
|
10 | 11 | )
|
11 | 12 |
|
12 | 13 | // GenAIRouter holds route registrations for genai endpoints.
|
13 | 14 | type GenAIRouter struct {
|
14 |
| - client *bifrost.Bifrost |
| 15 | + *integrations.GenericRouter |
15 | 16 | }
|
16 | 17 |
|
17 | 18 | // NewGenAIRouter creates a new GenAIRouter with the given bifrost client.
|
18 | 19 | func NewGenAIRouter(client *bifrost.Bifrost) *GenAIRouter {
|
19 |
| - return &GenAIRouter{client: client} |
20 |
| -} |
| 20 | + routes := []integrations.RouteConfig{ |
| 21 | + { |
| 22 | + Path: "/genai/v1beta/models/{model}", |
| 23 | + Method: "POST", |
| 24 | + RequestType: &GeminiChatRequest{}, |
| 25 | + RequestConverter: func(req interface{}) *schemas.BifrostRequest { |
| 26 | + if geminiReq, ok := req.(*GeminiChatRequest); ok { |
| 27 | + return geminiReq.ConvertToBifrostRequest() |
| 28 | + } |
| 29 | + return nil |
| 30 | + }, |
| 31 | + ResponseFunc: func(resp *schemas.BifrostResponse) interface{} { |
| 32 | + return DeriveGenAIFromBifrostResponse(resp) |
| 33 | + }, |
| 34 | + PreCallback: extractAndSetModelFromURL, |
| 35 | + }, |
| 36 | + } |
21 | 37 |
|
22 |
| -// RegisterRoutes registers all genai routes on the given router. |
23 |
| -func (g *GenAIRouter) RegisterRoutes(r *router.Router) { |
24 |
| - r.POST("/genai/v1beta/models/{model}", g.handleChatCompletion) |
| 38 | + return &GenAIRouter{ |
| 39 | + GenericRouter: integrations.NewGenericRouter(client, routes), |
| 40 | + } |
25 | 41 | }
|
26 | 42 |
|
27 |
| -// handleChatCompletion handles POST /genai/v1beta/models/{model} |
28 |
| -func (g *GenAIRouter) handleChatCompletion(ctx *fasthttp.RequestCtx) { |
| 43 | +// extractAndSetModelFromURL extracts model from URL and sets it in the request |
| 44 | +func extractAndSetModelFromURL(ctx *fasthttp.RequestCtx, req interface{}) error { |
29 | 45 | model := ctx.UserValue("model")
|
30 | 46 | if model == nil {
|
31 |
| - ctx.SetStatusCode(fasthttp.StatusBadRequest) |
32 |
| - ctx.SetBodyString("Model parameter is required") |
33 |
| - return |
| 47 | + return fmt.Errorf("model parameter is required") |
34 | 48 | }
|
| 49 | + |
35 | 50 | modelStr := model.(string)
|
36 |
| - modelStr = modelStr[:len(modelStr)-len(":generateContent")] |
| 51 | + // Remove :generateContent suffix if present |
| 52 | + modelStr = strings.TrimSuffix(modelStr, ":generateContent") |
| 53 | + // Remove trailing colon if present |
37 | 54 | if len(modelStr) > 0 && modelStr[len(modelStr)-1] == ':' {
|
38 | 55 | modelStr = modelStr[:len(modelStr)-1]
|
39 | 56 | }
|
40 | 57 |
|
41 |
| - var req GeminiChatRequest |
42 |
| - if err := json.Unmarshal(ctx.PostBody(), &req); err != nil { |
43 |
| - ctx.SetStatusCode(fasthttp.StatusBadRequest) |
44 |
| - json.NewEncoder(ctx).Encode(err) |
45 |
| - return |
46 |
| - } |
47 |
| - |
48 |
| - bifrostReq := req.ConvertToBifrostRequest("google/" + modelStr) |
49 |
| - |
50 |
| - bifrostCtx := lib.ConvertToBifrostContext(ctx) |
| 58 | + // Add google/ prefix for Bifrost |
| 59 | + processedModel := "google/" + modelStr |
51 | 60 |
|
52 |
| - result, err := g.client.ChatCompletionRequest(*bifrostCtx, bifrostReq) |
53 |
| - if err != nil { |
54 |
| - ctx.SetStatusCode(fasthttp.StatusInternalServerError) |
55 |
| - json.NewEncoder(ctx).Encode(err) |
56 |
| - return |
| 61 | + // Set the model in the request |
| 62 | + if geminiReq, ok := req.(*GeminiChatRequest); ok { |
| 63 | + geminiReq.Model = processedModel |
| 64 | + return nil |
57 | 65 | }
|
58 | 66 |
|
59 |
| - genAIResponse := DeriveGenAIFromBifrostResponse(result) |
60 |
| - ctx.SetStatusCode(fasthttp.StatusOK) |
61 |
| - ctx.SetContentType("application/json") |
62 |
| - json.NewEncoder(ctx).Encode(genAIResponse) |
| 67 | + return fmt.Errorf("invalid request type for GenAI") |
63 | 68 | }
|
0 commit comments