Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 56 additions & 105 deletions core/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -514,33 +514,31 @@ func (bifrost *Bifrost) TextCompletionRequest(ctx context.Context, req *schemas.
func (bifrost *Bifrost) tryTextCompletion(req *schemas.BifrostRequest, ctx context.Context) (*schemas.BifrostResponse, *schemas.BifrostError) {
queue, err := bifrost.getProviderQueue(req.Provider)
if err != nil {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Message: err.Error(),
},
}
return nil, newBifrostError(err)
}

for _, plugin := range bifrost.plugins {
req, err = plugin.PreHook(&ctx, req)
var resp *schemas.BifrostResponse
var processedPluginCount int
for i, plugin := range bifrost.plugins {
req, resp, err = plugin.PreHook(&ctx, req)
if err != nil {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Message: err.Error(),
},
return nil, newBifrostError(err)
}
processedPluginCount = i + 1
if resp != nil {
// Run post-hooks in reverse order for plugins that had PreHook executed
for j := processedPluginCount - 1; j >= 0; j-- {
resp, err = bifrost.plugins[j].PostHook(&ctx, resp)
if err != nil {
return nil, newBifrostError(err)
}
}
return resp, nil
}
}

if req == nil {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Message: "bifrost request after plugin hooks cannot be nil",
},
}
return nil, newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil")
}

// Get a ChannelMessage from the pool
Expand All @@ -554,23 +552,13 @@ func (bifrost *Bifrost) tryTextCompletion(req *schemas.BifrostRequest, ctx conte
case <-ctx.Done():
// Request was cancelled by caller
bifrost.releaseChannelMessage(msg)
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Message: "request cancelled while waiting for queue space",
},
}
return nil, newBifrostErrorFromMsg("request cancelled while waiting for queue space")
default:
if bifrost.dropExcessRequests {
// Drop request immediately if configured to do so
bifrost.releaseChannelMessage(msg)
bifrost.logger.Warn("Request dropped: queue is full, please increase the queue size or set dropExcessRequests to false")
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Message: "request dropped: queue is full",
},
}
return nil, newBifrostErrorFromMsg("request dropped: queue is full")
}

// If not dropping excess requests, wait with context
Expand All @@ -582,12 +570,7 @@ func (bifrost *Bifrost) tryTextCompletion(req *schemas.BifrostRequest, ctx conte
// Message was sent successfully
case <-ctx.Done():
bifrost.releaseChannelMessage(msg)
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Message: "request cancelled while waiting for queue space",
},
}
return nil, newBifrostErrorFromMsg("request cancelled while waiting for queue space")
}
}

Expand All @@ -600,12 +583,7 @@ func (bifrost *Bifrost) tryTextCompletion(req *schemas.BifrostRequest, ctx conte
result, err = bifrost.plugins[i].PostHook(&ctx, result)
if err != nil {
bifrost.releaseChannelMessage(msg)
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Message: err.Error(),
},
}
return nil, newBifrostError(err)
}
}
case err := <-msg.Err:
Expand All @@ -623,30 +601,15 @@ func (bifrost *Bifrost) tryTextCompletion(req *schemas.BifrostRequest, ctx conte
// If the primary provider fails, it will try each fallback provider in order until one succeeds.
func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) {
if req == nil {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Message: "bifrost request cannot be nil",
},
}
return nil, newBifrostErrorFromMsg("bifrost request cannot be nil")
}

if req.Provider == "" {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Message: "provider is required",
},
}
return nil, newBifrostErrorFromMsg("provider is required")
}

if req.Model == "" {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Message: "model is required",
},
}
return nil, newBifrostErrorFromMsg("model is required")
}

// Try the primary provider first
Expand Down Expand Up @@ -688,33 +651,31 @@ func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas.
func (bifrost *Bifrost) tryChatCompletion(req *schemas.BifrostRequest, ctx context.Context) (*schemas.BifrostResponse, *schemas.BifrostError) {
queue, err := bifrost.getProviderQueue(req.Provider)
if err != nil {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Message: err.Error(),
},
}
return nil, newBifrostError(err)
}

for _, plugin := range bifrost.plugins {
req, err = plugin.PreHook(&ctx, req)
var resp *schemas.BifrostResponse
var processedPluginCount int
for i, plugin := range bifrost.plugins {
req, resp, err = plugin.PreHook(&ctx, req)
if err != nil {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Message: err.Error(),
},
return nil, newBifrostError(err)
}
processedPluginCount = i + 1
if resp != nil {
// Run post-hooks in reverse order for plugins that had PreHook executed
for j := processedPluginCount - 1; j >= 0; j-- {
resp, err = bifrost.plugins[j].PostHook(&ctx, resp)
if err != nil {
return nil, newBifrostError(err)
}
}
return resp, nil
}
}

if req == nil {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Message: "bifrost request after plugin hooks cannot be nil",
},
}
return nil, newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil")
}

// Get a ChannelMessage from the pool
Expand All @@ -728,23 +689,13 @@ func (bifrost *Bifrost) tryChatCompletion(req *schemas.BifrostRequest, ctx conte
case <-ctx.Done():
// Request was cancelled by caller
bifrost.releaseChannelMessage(msg)
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Message: "request cancelled while waiting for queue space",
},
}
return nil, newBifrostErrorFromMsg("request cancelled while waiting for queue space")
default:
if bifrost.dropExcessRequests {
// Drop request immediately if configured to do so
bifrost.releaseChannelMessage(msg)
bifrost.logger.Warn("Request dropped: queue is full, please increase the queue size or set dropExcessRequests to false")
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Message: "request dropped: queue is full",
},
}
return nil, newBifrostErrorFromMsg("request dropped: queue is full")
}
// If not dropping excess requests, wait with context
if ctx == nil {
Expand All @@ -755,12 +706,7 @@ func (bifrost *Bifrost) tryChatCompletion(req *schemas.BifrostRequest, ctx conte
// Message was sent successfully
case <-ctx.Done():
bifrost.releaseChannelMessage(msg)
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Message: "request cancelled while waiting for queue space",
},
}
return nil, newBifrostErrorFromMsg("request cancelled while waiting for queue space")
}
}

Expand All @@ -773,12 +719,7 @@ func (bifrost *Bifrost) tryChatCompletion(req *schemas.BifrostRequest, ctx conte
result, err = bifrost.plugins[i].PostHook(&ctx, result)
if err != nil {
bifrost.releaseChannelMessage(msg)
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Message: err.Error(),
},
}
return nil, newBifrostError(err)
}
}
case err := <-msg.Err:
Expand All @@ -794,7 +735,7 @@ func (bifrost *Bifrost) tryChatCompletion(req *schemas.BifrostRequest, ctx conte
// Cleanup gracefully stops all workers when triggered.
// It closes all request channels and waits for workers to exit.
func (bifrost *Bifrost) Cleanup() {
bifrost.logger.Info("[BIFROST] Graceful Cleanup Initiated - Closing all request channels...")
bifrost.logger.Info("Graceful Cleanup Initiated - Closing all request channels...")

// Close all provider queues to signal workers to stop
for _, queue := range bifrost.requestQueues {
Expand All @@ -805,4 +746,14 @@ func (bifrost *Bifrost) Cleanup() {
for _, waitGroup := range bifrost.waitGroups {
waitGroup.Wait()
}

// Cleanup plugins
for _, plugin := range bifrost.plugins {
err := plugin.Cleanup()
if err != nil {
bifrost.logger.Warn(fmt.Sprintf("Error cleaning up plugin: %s", err.Error()))
}
}

bifrost.logger.Info("Graceful Cleanup Completed")
}
12 changes: 10 additions & 2 deletions core/schemas/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,23 @@ import "context"
// - Logging
// - Monitoring

// No Plugin errors are returned to the caller, they are logged as warnings by the Bifrost instance.

type Plugin interface {
// PreHook is called before a request is processed by a provider.
// It allows plugins to modify the request before it is sent to the provider.
// The context parameter can be used to maintain state across plugin calls.
// Returns the modified request and any error that occurred during processing.
PreHook(ctx *context.Context, req *BifrostRequest) (*BifrostRequest, error)
// Returns the modified request, an optional response (if the plugin wants to short-circuit the provider call), and any error that occurred during processing.
// If a response is returned, the provider call is skipped and only the PostHook methods of plugins that had their PreHook executed are called in reverse order.
PreHook(ctx *context.Context, req *BifrostRequest) (*BifrostRequest, *BifrostResponse, error)

// PostHook is called after a response is received from a provider.
// It allows plugins to modify the response before it is returned to the caller.
// Returns the modified response and any error that occurred during processing.
PostHook(ctx *context.Context, result *BifrostResponse) (*BifrostResponse, error)

// Cleanup is called on bifrost shutdown.
// It allows plugins to clean up any resources they have allocated.
// Returns any error that occurred during cleanup, which will be logged as a warning by the Bifrost instance.
Cleanup() error
}
25 changes: 25 additions & 0 deletions core/utils.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,30 @@
package bifrost

import schemas "github.com/maximhq/bifrost/core/schemas"

func Ptr[T any](v T) *T {
return &v
}

// newBifrostError wraps a standard error into a BifrostError with IsBifrostError set to false.
// This helper function reduces code duplication when handling non-Bifrost errors.
func newBifrostError(err error) *schemas.BifrostError {
return &schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Message: err.Error(),
Error: err,
},
}
}

// newBifrostErrorFromMsg creates a BifrostError with a custom message.
// This helper function is used for static error messages.
func newBifrostErrorFromMsg(message string) *schemas.BifrostError {
return &schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Message: message,
},
}
}
Loading