Skip to content

Commit 58dfdd3

Browse files
chore: plugin support extensions
1 parent bc74f7d commit 58dfdd3

File tree

4 files changed

+138
-120
lines changed

4 files changed

+138
-120
lines changed

core/bifrost.go

Lines changed: 56 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -514,33 +514,31 @@ func (bifrost *Bifrost) TextCompletionRequest(ctx context.Context, req *schemas.
514514
func (bifrost *Bifrost) tryTextCompletion(req *schemas.BifrostRequest, ctx context.Context) (*schemas.BifrostResponse, *schemas.BifrostError) {
515515
queue, err := bifrost.getProviderQueue(req.Provider)
516516
if err != nil {
517-
return nil, &schemas.BifrostError{
518-
IsBifrostError: false,
519-
Error: schemas.ErrorField{
520-
Message: err.Error(),
521-
},
522-
}
517+
return nil, newBifrostError(err)
523518
}
524519

525-
for _, plugin := range bifrost.plugins {
526-
req, err = plugin.PreHook(&ctx, req)
520+
var resp *schemas.BifrostResponse
521+
var processedPluginCount int
522+
for i, plugin := range bifrost.plugins {
523+
req, resp, err = plugin.PreHook(&ctx, req)
527524
if err != nil {
528-
return nil, &schemas.BifrostError{
529-
IsBifrostError: false,
530-
Error: schemas.ErrorField{
531-
Message: err.Error(),
532-
},
525+
return nil, newBifrostError(err)
526+
}
527+
processedPluginCount = i + 1
528+
if resp != nil {
529+
// Run post-hooks in reverse order for plugins that had PreHook executed
530+
for j := processedPluginCount - 1; j >= 0; j-- {
531+
resp, err = bifrost.plugins[j].PostHook(&ctx, resp)
532+
if err != nil {
533+
return nil, newBifrostError(err)
534+
}
533535
}
536+
return resp, nil
534537
}
535538
}
536539

537540
if req == nil {
538-
return nil, &schemas.BifrostError{
539-
IsBifrostError: false,
540-
Error: schemas.ErrorField{
541-
Message: "bifrost request after plugin hooks cannot be nil",
542-
},
543-
}
541+
return nil, newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil")
544542
}
545543

546544
// Get a ChannelMessage from the pool
@@ -554,23 +552,13 @@ func (bifrost *Bifrost) tryTextCompletion(req *schemas.BifrostRequest, ctx conte
554552
case <-ctx.Done():
555553
// Request was cancelled by caller
556554
bifrost.releaseChannelMessage(msg)
557-
return nil, &schemas.BifrostError{
558-
IsBifrostError: false,
559-
Error: schemas.ErrorField{
560-
Message: "request cancelled while waiting for queue space",
561-
},
562-
}
555+
return nil, newBifrostErrorFromMsg("request cancelled while waiting for queue space")
563556
default:
564557
if bifrost.dropExcessRequests {
565558
// Drop request immediately if configured to do so
566559
bifrost.releaseChannelMessage(msg)
567560
bifrost.logger.Warn("Request dropped: queue is full, please increase the queue size or set dropExcessRequests to false")
568-
return nil, &schemas.BifrostError{
569-
IsBifrostError: false,
570-
Error: schemas.ErrorField{
571-
Message: "request dropped: queue is full",
572-
},
573-
}
561+
return nil, newBifrostErrorFromMsg("request dropped: queue is full")
574562
}
575563

576564
// If not dropping excess requests, wait with context
@@ -582,12 +570,7 @@ func (bifrost *Bifrost) tryTextCompletion(req *schemas.BifrostRequest, ctx conte
582570
// Message was sent successfully
583571
case <-ctx.Done():
584572
bifrost.releaseChannelMessage(msg)
585-
return nil, &schemas.BifrostError{
586-
IsBifrostError: false,
587-
Error: schemas.ErrorField{
588-
Message: "request cancelled while waiting for queue space",
589-
},
590-
}
573+
return nil, newBifrostErrorFromMsg("request cancelled while waiting for queue space")
591574
}
592575
}
593576

@@ -600,12 +583,7 @@ func (bifrost *Bifrost) tryTextCompletion(req *schemas.BifrostRequest, ctx conte
600583
result, err = bifrost.plugins[i].PostHook(&ctx, result)
601584
if err != nil {
602585
bifrost.releaseChannelMessage(msg)
603-
return nil, &schemas.BifrostError{
604-
IsBifrostError: false,
605-
Error: schemas.ErrorField{
606-
Message: err.Error(),
607-
},
608-
}
586+
return nil, newBifrostError(err)
609587
}
610588
}
611589
case err := <-msg.Err:
@@ -623,30 +601,15 @@ func (bifrost *Bifrost) tryTextCompletion(req *schemas.BifrostRequest, ctx conte
623601
// If the primary provider fails, it will try each fallback provider in order until one succeeds.
624602
func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) {
625603
if req == nil {
626-
return nil, &schemas.BifrostError{
627-
IsBifrostError: false,
628-
Error: schemas.ErrorField{
629-
Message: "bifrost request cannot be nil",
630-
},
631-
}
604+
return nil, newBifrostErrorFromMsg("bifrost request cannot be nil")
632605
}
633606

634607
if req.Provider == "" {
635-
return nil, &schemas.BifrostError{
636-
IsBifrostError: false,
637-
Error: schemas.ErrorField{
638-
Message: "provider is required",
639-
},
640-
}
608+
return nil, newBifrostErrorFromMsg("provider is required")
641609
}
642610

643611
if req.Model == "" {
644-
return nil, &schemas.BifrostError{
645-
IsBifrostError: false,
646-
Error: schemas.ErrorField{
647-
Message: "model is required",
648-
},
649-
}
612+
return nil, newBifrostErrorFromMsg("model is required")
650613
}
651614

652615
// Try the primary provider first
@@ -688,33 +651,31 @@ func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas.
688651
func (bifrost *Bifrost) tryChatCompletion(req *schemas.BifrostRequest, ctx context.Context) (*schemas.BifrostResponse, *schemas.BifrostError) {
689652
queue, err := bifrost.getProviderQueue(req.Provider)
690653
if err != nil {
691-
return nil, &schemas.BifrostError{
692-
IsBifrostError: false,
693-
Error: schemas.ErrorField{
694-
Message: err.Error(),
695-
},
696-
}
654+
return nil, newBifrostError(err)
697655
}
698656

699-
for _, plugin := range bifrost.plugins {
700-
req, err = plugin.PreHook(&ctx, req)
657+
var resp *schemas.BifrostResponse
658+
var processedPluginCount int
659+
for i, plugin := range bifrost.plugins {
660+
req, resp, err = plugin.PreHook(&ctx, req)
701661
if err != nil {
702-
return nil, &schemas.BifrostError{
703-
IsBifrostError: false,
704-
Error: schemas.ErrorField{
705-
Message: err.Error(),
706-
},
662+
return nil, newBifrostError(err)
663+
}
664+
processedPluginCount = i + 1
665+
if resp != nil {
666+
// Run post-hooks in reverse order for plugins that had PreHook executed
667+
for j := processedPluginCount - 1; j >= 0; j-- {
668+
resp, err = bifrost.plugins[j].PostHook(&ctx, resp)
669+
if err != nil {
670+
return nil, newBifrostError(err)
671+
}
707672
}
673+
return resp, nil
708674
}
709675
}
710676

711677
if req == nil {
712-
return nil, &schemas.BifrostError{
713-
IsBifrostError: false,
714-
Error: schemas.ErrorField{
715-
Message: "bifrost request after plugin hooks cannot be nil",
716-
},
717-
}
678+
return nil, newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil")
718679
}
719680

720681
// Get a ChannelMessage from the pool
@@ -728,23 +689,13 @@ func (bifrost *Bifrost) tryChatCompletion(req *schemas.BifrostRequest, ctx conte
728689
case <-ctx.Done():
729690
// Request was cancelled by caller
730691
bifrost.releaseChannelMessage(msg)
731-
return nil, &schemas.BifrostError{
732-
IsBifrostError: false,
733-
Error: schemas.ErrorField{
734-
Message: "request cancelled while waiting for queue space",
735-
},
736-
}
692+
return nil, newBifrostErrorFromMsg("request cancelled while waiting for queue space")
737693
default:
738694
if bifrost.dropExcessRequests {
739695
// Drop request immediately if configured to do so
740696
bifrost.releaseChannelMessage(msg)
741697
bifrost.logger.Warn("Request dropped: queue is full, please increase the queue size or set dropExcessRequests to false")
742-
return nil, &schemas.BifrostError{
743-
IsBifrostError: false,
744-
Error: schemas.ErrorField{
745-
Message: "request dropped: queue is full",
746-
},
747-
}
698+
return nil, newBifrostErrorFromMsg("request dropped: queue is full")
748699
}
749700
// If not dropping excess requests, wait with context
750701
if ctx == nil {
@@ -755,12 +706,7 @@ func (bifrost *Bifrost) tryChatCompletion(req *schemas.BifrostRequest, ctx conte
755706
// Message was sent successfully
756707
case <-ctx.Done():
757708
bifrost.releaseChannelMessage(msg)
758-
return nil, &schemas.BifrostError{
759-
IsBifrostError: false,
760-
Error: schemas.ErrorField{
761-
Message: "request cancelled while waiting for queue space",
762-
},
763-
}
709+
return nil, newBifrostErrorFromMsg("request cancelled while waiting for queue space")
764710
}
765711
}
766712

@@ -773,12 +719,7 @@ func (bifrost *Bifrost) tryChatCompletion(req *schemas.BifrostRequest, ctx conte
773719
result, err = bifrost.plugins[i].PostHook(&ctx, result)
774720
if err != nil {
775721
bifrost.releaseChannelMessage(msg)
776-
return nil, &schemas.BifrostError{
777-
IsBifrostError: false,
778-
Error: schemas.ErrorField{
779-
Message: err.Error(),
780-
},
781-
}
722+
return nil, newBifrostError(err)
782723
}
783724
}
784725
case err := <-msg.Err:
@@ -794,7 +735,7 @@ func (bifrost *Bifrost) tryChatCompletion(req *schemas.BifrostRequest, ctx conte
794735
// Cleanup gracefully stops all workers when triggered.
795736
// It closes all request channels and waits for workers to exit.
796737
func (bifrost *Bifrost) Cleanup() {
797-
bifrost.logger.Info("[BIFROST] Graceful Cleanup Initiated - Closing all request channels...")
738+
bifrost.logger.Info("Graceful Cleanup Initiated - Closing all request channels...")
798739

799740
// Close all provider queues to signal workers to stop
800741
for _, queue := range bifrost.requestQueues {
@@ -805,4 +746,14 @@ func (bifrost *Bifrost) Cleanup() {
805746
for _, waitGroup := range bifrost.waitGroups {
806747
waitGroup.Wait()
807748
}
749+
750+
// Cleanup plugins
751+
for _, plugin := range bifrost.plugins {
752+
err := plugin.Cleanup()
753+
if err != nil {
754+
bifrost.logger.Warn(fmt.Sprintf("Error cleaning up plugin: %s", err.Error()))
755+
}
756+
}
757+
758+
bifrost.logger.Info("Graceful Cleanup Completed")
808759
}

core/schemas/plugin.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,23 @@ import "context"
1616
// - Logging
1717
// - Monitoring
1818

19+
// No Plugin errors are returned to the caller, they are logged as warnings by the Bifrost instance.
20+
1921
type Plugin interface {
2022
// PreHook is called before a request is processed by a provider.
2123
// It allows plugins to modify the request before it is sent to the provider.
2224
// The context parameter can be used to maintain state across plugin calls.
23-
// Returns the modified request and any error that occurred during processing.
24-
PreHook(ctx *context.Context, req *BifrostRequest) (*BifrostRequest, error)
25+
// Returns the modified request, an optional response (if the plugin wants to short-circuit the provider call), and any error that occurred during processing.
26+
// If a response is returned, the provider call is skipped and only the PostHook methods of plugins that had their PreHook executed are called in reverse order.
27+
PreHook(ctx *context.Context, req *BifrostRequest) (*BifrostRequest, *BifrostResponse, error)
2528

2629
// PostHook is called after a response is received from a provider.
2730
// It allows plugins to modify the response before it is returned to the caller.
2831
// Returns the modified response and any error that occurred during processing.
2932
PostHook(ctx *context.Context, result *BifrostResponse) (*BifrostResponse, error)
33+
34+
// Cleanup is called on bifrost shutdown.
35+
// It allows plugins to clean up any resources they have allocated.
36+
// Returns any error that occurred during cleanup, which will be logged as a warning by the Bifrost instance.
37+
Cleanup() error
3038
}

core/utils.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,30 @@
11
package bifrost
22

3+
import schemas "github.com/maximhq/bifrost/core/schemas"
4+
35
func Ptr[T any](v T) *T {
46
return &v
57
}
8+
9+
// newBifrostError wraps a standard error into a BifrostError with IsBifrostError set to false.
10+
// This helper function reduces code duplication when handling non-Bifrost errors.
11+
func newBifrostError(err error) *schemas.BifrostError {
12+
return &schemas.BifrostError{
13+
IsBifrostError: false,
14+
Error: schemas.ErrorField{
15+
Message: err.Error(),
16+
Error: err,
17+
},
18+
}
19+
}
20+
21+
// newBifrostErrorFromMsg creates a BifrostError with a custom message.
22+
// This helper function is used for static error messages.
23+
func newBifrostErrorFromMsg(message string) *schemas.BifrostError {
24+
return &schemas.BifrostError{
25+
IsBifrostError: false,
26+
Error: schemas.ErrorField{
27+
Message: message,
28+
},
29+
}
30+
}

0 commit comments

Comments
 (0)