@@ -294,36 +294,87 @@ func (bifrost *Bifrost) TranscriptionStreamRequest(ctx context.Context, req *sch
294
294
return bifrost .handleStreamRequest (ctx , req , schemas .TranscriptionStreamRequest )
295
295
}
296
296
297
+ // RemovePlugin removes a plugin from the server.
298
+ func (bifrost * Bifrost ) RemovePlugin (name string ) error {
299
+
300
+ for {
301
+ oldPlugins := bifrost .plugins .Load ()
302
+ if oldPlugins == nil {
303
+ return nil
304
+ }
305
+ var pluginToCleanup schemas.Plugin
306
+ found := false
307
+ // Create new slice with replaced plugin
308
+ newPlugins := make ([]schemas.Plugin , len (* oldPlugins ))
309
+ copy (newPlugins , * oldPlugins )
310
+ for i , p := range newPlugins {
311
+ if p .GetName () == name {
312
+ pluginToCleanup = p
313
+ bifrost .logger .Debug ("removing plugin %s" , name )
314
+ newPlugins = append (newPlugins [:i ], newPlugins [i + 1 :]... )
315
+ found = true
316
+ break
317
+ }
318
+ }
319
+ if ! found {
320
+ return nil
321
+ }
322
+ if pluginToCleanup != nil {
323
+ // Atomic compare-and-swap
324
+ if bifrost .plugins .CompareAndSwap (oldPlugins , & newPlugins ) {
325
+ // Cleanup the old plugin
326
+ err := pluginToCleanup .Cleanup ()
327
+ if err != nil {
328
+ bifrost .logger .Warn ("failed to cleanup old plugin %s: %v" , pluginToCleanup .GetName (), err )
329
+ }
330
+ return nil
331
+ }
332
+ }
333
+ // Retrying as swapping did not work
334
+ }
335
+ }
336
+
297
337
// ReloadPlugin reloads a plugin with new instance
298
338
// During the reload - it's stop the world phase where we take a global lock on the plugin mutex
299
- func (bifrost * Bifrost ) ReloadPlugin (plugin schemas.Plugin ) error {
339
+ func (bifrost * Bifrost ) ReloadPlugin (plugin schemas.Plugin ) error {
300
340
for {
341
+ var pluginToCleanup schemas.Plugin
342
+ found := false
301
343
oldPlugins := bifrost .plugins .Load ()
302
344
if oldPlugins == nil {
303
345
return nil
304
346
}
305
347
// Create new slice with replaced plugin
306
348
newPlugins := make ([]schemas.Plugin , len (* oldPlugins ))
307
349
copy (newPlugins , * oldPlugins )
308
- found := false
309
350
for i , p := range newPlugins {
310
351
if p .GetName () == plugin .GetName () {
352
+ // Cleaning up old plugin before replacing it
353
+ pluginToCleanup = p
354
+ bifrost .logger .Debug ("replacing plugin %s with new instance" , plugin .GetName ())
311
355
newPlugins [i ] = plugin
312
356
found = true
313
357
break
314
358
}
315
359
}
316
- if ! found {
360
+ if ! found {
317
361
// This means that user is adding a new plugin
362
+ bifrost .logger .Debug ("adding new plugin %s" , plugin .GetName ())
318
363
newPlugins = append (newPlugins , plugin )
319
364
}
320
365
// Atomic compare-and-swap
321
366
if bifrost .plugins .CompareAndSwap (oldPlugins , & newPlugins ) {
367
+ // Cleanup the old plugin
368
+ if found && pluginToCleanup != nil {
369
+ err := pluginToCleanup .Cleanup ()
370
+ if err != nil {
371
+ bifrost .logger .Warn ("failed to cleanup old plugin %s: %v" , pluginToCleanup .GetName (), err )
372
+ }
373
+ }
322
374
return nil
323
375
}
324
376
// Retrying as swapping did not work
325
377
}
326
-
327
378
}
328
379
329
380
// UpdateProviderConcurrency dynamically updates the queue size and concurrency for an existing provider.
@@ -1023,7 +1074,7 @@ func (bifrost *Bifrost) tryRequest(req *schemas.BifrostRequest, ctx context.Cont
1023
1074
1024
1075
msg := bifrost .getChannelMessage (* preReq , requestType )
1025
1076
msg .Context = ctx
1026
-
1077
+ startTime := time . Now ()
1027
1078
select {
1028
1079
case queue <- * msg :
1029
1080
// Message was sent successfully
@@ -1047,9 +1098,14 @@ func (bifrost *Bifrost) tryRequest(req *schemas.BifrostRequest, ctx context.Cont
1047
1098
1048
1099
var result * schemas.BifrostResponse
1049
1100
var resp * schemas.BifrostResponse
1101
+ pluginCount := len (* bifrost .plugins .Load ())
1050
1102
select {
1051
1103
case result = <- msg .Response :
1052
- resp , bifrostErr := pipeline .RunPostHooks (& ctx , result , nil , len (* bifrost .plugins .Load ()))
1104
+ latency := time .Since (startTime ).Milliseconds ()
1105
+ if result .ExtraFields .Latency == nil {
1106
+ result .ExtraFields .Latency = & latency
1107
+ }
1108
+ resp , bifrostErr := pipeline .RunPostHooks (& ctx , result , nil , pluginCount )
1053
1109
if bifrostErr != nil {
1054
1110
bifrost .releaseChannelMessage (msg )
1055
1111
return nil , bifrostErr
@@ -1058,7 +1114,7 @@ func (bifrost *Bifrost) tryRequest(req *schemas.BifrostRequest, ctx context.Cont
1058
1114
return resp , nil
1059
1115
case bifrostErrVal := <- msg .Err :
1060
1116
bifrostErrPtr := & bifrostErrVal
1061
- resp , bifrostErrPtr = pipeline .RunPostHooks (& ctx , nil , bifrostErrPtr , len ( * bifrost . plugins . Load ()) )
1117
+ resp , bifrostErrPtr = pipeline .RunPostHooks (& ctx , nil , bifrostErrPtr , pluginCount )
1062
1118
bifrost .releaseChannelMessage (msg )
1063
1119
if bifrostErrPtr != nil {
1064
1120
return nil , bifrostErrPtr
@@ -1172,7 +1228,7 @@ func (bifrost *Bifrost) tryStreamRequest(req *schemas.BifrostRequest, ctx contex
1172
1228
// Marking final chunk
1173
1229
ctx = context .WithValue (ctx , schemas .BifrostContextKeyStreamEndIndicator , true )
1174
1230
// On error we will complete post-hooks
1175
- recoveredResp , recoveredErr := pipeline .RunPostHooks (& ctx , nil , & bifrostErrVal , len (bifrost .plugins ))
1231
+ recoveredResp , recoveredErr := pipeline .RunPostHooks (& ctx , nil , & bifrostErrVal , len (* bifrost .plugins . Load () ))
1176
1232
bifrost .releaseChannelMessage (msg )
1177
1233
if recoveredErr != nil {
1178
1234
return nil , recoveredErr
@@ -1332,7 +1388,7 @@ func handleProviderRequest(provider schemas.Provider, req *ChannelMessage, key s
1332
1388
case schemas .TextCompletionRequest :
1333
1389
return provider .TextCompletion (req .Context , req .Model , key , * req .Input .TextCompletionInput , req .Params )
1334
1390
case schemas .ChatCompletionRequest :
1335
- return provider .ChatCompletion (req .Context , req .Model , key , * req .Input .ChatCompletionInput , req .Params )
1391
+ return provider .ChatCompletion (req .Context , req .Model , key , req .Input .ChatCompletionInput , req .Params )
1336
1392
case schemas .EmbeddingRequest :
1337
1393
return provider .Embedding (req .Context , req .Model , key , req .Input .EmbeddingInput , req .Params )
1338
1394
case schemas .SpeechRequest :
@@ -1353,7 +1409,7 @@ func handleProviderRequest(provider schemas.Provider, req *ChannelMessage, key s
1353
1409
func handleProviderStreamRequest (provider schemas.Provider , req * ChannelMessage , key schemas.Key , postHookRunner schemas.PostHookRunner , reqType schemas.RequestType ) (chan * schemas.BifrostStream , * schemas.BifrostError ) {
1354
1410
switch reqType {
1355
1411
case schemas .ChatCompletionStreamRequest :
1356
- return provider .ChatCompletionStream (req .Context , postHookRunner , req .Model , key , * req .Input .ChatCompletionInput , req .Params )
1412
+ return provider .ChatCompletionStream (req .Context , postHookRunner , req .Model , key , req .Input .ChatCompletionInput , req .Params )
1357
1413
case schemas .SpeechStreamRequest :
1358
1414
return provider .SpeechStream (req .Context , postHookRunner , req .Model , key , req .Input .SpeechInput , req .Params )
1359
1415
case schemas .TranscriptionStreamRequest :
@@ -1375,6 +1431,7 @@ func (p *PluginPipeline) RunPreHooks(ctx *context.Context, req *schemas.BifrostR
1375
1431
var shortCircuit * schemas.PluginShortCircuit
1376
1432
var err error
1377
1433
for i , plugin := range p .plugins {
1434
+ p .logger .Debug ("running pre-hook for plugin %s" , plugin .GetName ())
1378
1435
req , shortCircuit , err = plugin .PreHook (ctx , req )
1379
1436
if err != nil {
1380
1437
p .preHookErrors = append (p .preHookErrors , err )
@@ -1391,17 +1448,19 @@ func (p *PluginPipeline) RunPreHooks(ctx *context.Context, req *schemas.BifrostR
1391
1448
// RunPostHooks executes PostHooks in reverse order for the plugins whose PreHook ran.
1392
1449
// Accepts the response and error, and allows plugins to transform either (e.g., recover from error, or invalidate a response).
1393
1450
// Returns the final response and error after all hooks. If both are set, error takes precedence unless error is nil.
1394
- func (p * PluginPipeline ) RunPostHooks (ctx * context.Context , resp * schemas.BifrostResponse , bifrostErr * schemas.BifrostError , count int ) (* schemas.BifrostResponse , * schemas.BifrostError ) {
1451
+ // runFrom is the count of plugins whose PreHooks ran; PostHooks will run in reverse from index (runFrom - 1) down to 0
1452
+ func (p * PluginPipeline ) RunPostHooks (ctx * context.Context , resp * schemas.BifrostResponse , bifrostErr * schemas.BifrostError , runFrom int ) (* schemas.BifrostResponse , * schemas.BifrostError ) {
1395
1453
// Defensive: ensure count is within valid bounds
1396
- if count < 0 {
1397
- count = 0
1454
+ if runFrom < 0 {
1455
+ runFrom = 0
1398
1456
}
1399
- if count > len (p .plugins ) {
1400
- count = len (p .plugins )
1457
+ if runFrom > len (p .plugins ) {
1458
+ runFrom = len (p .plugins )
1401
1459
}
1402
1460
var err error
1403
- for i := count - 1 ; i >= 0 ; i -- {
1461
+ for i := runFrom - 1 ; i >= 0 ; i -- {
1404
1462
plugin := p .plugins [i ]
1463
+ p .logger .Debug ("running post-hook for plugin %s" , plugin .GetName ())
1405
1464
resp , bifrostErr , err = plugin .PostHook (ctx , resp , bifrostErr )
1406
1465
if err != nil {
1407
1466
p .postHookErrors = append (p .postHookErrors , err )
@@ -1618,4 +1677,5 @@ func (bifrost *Bifrost) Shutdown() {
1618
1677
bifrost .logger .Warn (fmt .Sprintf ("Error cleaning up plugin: %s" , err .Error ()))
1619
1678
}
1620
1679
}
1680
+ bifrost .logger .Info ("all request channels closed" )
1621
1681
}
0 commit comments