-
Notifications
You must be signed in to change notification settings - Fork 333
/
Copy pathvtc_router.go
69 lines (52 loc) · 2.03 KB
/
vtc_router.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
/*
Copyright 2024 The Aibrix Team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package vtc implements the Virtual Token Counter (VTC) routing algorithms, which focus on fairness and utilization.
package vtc
import (
"context"
"github.com/vllm-project/aibrix/pkg/types"
)
// RouterVTCBasic is the routing algorithm identifier ("vtc-basic") for the
// basic VTC variant.
const RouterVTCBasic types.RoutingAlgorithm = "vtc-basic"
// TODO: add other variants - "vtc-fair", "vtc-max-fair", "vtc-pred-50"
// TokenTracker is the contract for components that keep per-user token
// usage. Every method takes a context as its first argument.
type TokenTracker interface {
	// GetTokenCount reports the token count held for user.
	GetTokenCount(ctx context.Context, user string) (float64, error)

	// UpdateTokenCount records inputTokens and outputTokens against user.
	// NOTE(review): how the two values are combined or weighted is up to the
	// implementation and is not visible from this interface.
	UpdateTokenCount(ctx context.Context, user string, inputTokens float64, outputTokens float64) error

	// GetMinTokenCount reports a minimum token count — presumably the
	// smallest count across tracked users; confirm against implementations.
	GetMinTokenCount(ctx context.Context) (float64, error)

	// GetMaxTokenCount reports a maximum token count — presumably the
	// largest count across tracked users; confirm against implementations.
	GetMaxTokenCount(ctx context.Context) (float64, error)
}
// TokenEstimator is the contract for components that turn a message string
// into an estimated token count.
type TokenEstimator interface {
	// EstimateInputTokens returns the estimated number of input tokens for
	// message.
	EstimateInputTokens(message string) float64

	// EstimateOutputTokens returns the estimated number of output tokens for
	// message.
	// NOTE(review): the estimation method is implementation-defined and not
	// visible from this interface.
	EstimateOutputTokens(message string) float64
}
// VTCConfig bundles the settings handed to the VTC router and its token
// tracker.
type VTCConfig struct {
	// Variant names the routing algorithm variant; DefaultVTCConfig sets it
	// to RouterVTCBasic.
	Variant types.RoutingAlgorithm

	// InputTokenWeight and OutputTokenWeight are the weights associated with
	// input and output tokens respectively.
	// NOTE(review): how the weights are applied is decided by the consumers
	// of this config and is not visible here.
	InputTokenWeight, OutputTokenWeight float64
}
// DefaultVTCConfig returns a VTCConfig for the basic variant
// (RouterVTCBasic) whose token weights are copied from the package-level
// inputTokenWeight and outputTokenWeight variables.
//
// Per the original note on this function, those variables are loaded from
// the environment; that loading happens elsewhere in the package.
func DefaultVTCConfig() VTCConfig {
	var cfg VTCConfig
	cfg.Variant = RouterVTCBasic
	cfg.InputTokenWeight = inputTokenWeight
	cfg.OutputTokenWeight = outputTokenWeight
	return cfg
}
// NewVTCBasicRouter assembles a basic VTC router from default parts: the
// configuration from DefaultVTCConfig, the estimator from
// NewSimpleTokenEstimator, and the tracker from
// NewInMemorySlidingWindowTokenTracker. The tracker and the router receive a
// pointer to the same config value.
//
// It returns whatever NewBasicVTCRouter returns, including its error.
func NewVTCBasicRouter() (types.Router, error) {
	cfg := DefaultVTCConfig()

	// Declared with explicit interface types so the router is wired against
	// the TokenEstimator/TokenTracker contracts rather than concrete types.
	var (
		estimator TokenEstimator = NewSimpleTokenEstimator()
		tracker   TokenTracker   = NewInMemorySlidingWindowTokenTracker(&cfg)
	)

	return NewBasicVTCRouter(tracker, estimator, &cfg)
}