Skip to content

Commit

Permalink
#163: Extracted health tracking logic into a separate component
Browse files Browse the repository at this point in the history
  • Loading branch information
roma-glushko committed Mar 11, 2024
1 parent 081f5b9 commit ba56d3f
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 49 deletions.
46 changes: 13 additions & 33 deletions pkg/providers/lang.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@ package providers

import (
"context"
"errors"
"glide/pkg/routers/health"
"time"

"glide/pkg/api/schemas"
"glide/pkg/providers/clients"
"glide/pkg/routers/health"
"glide/pkg/routers/latency"
)

Expand All @@ -34,8 +33,7 @@ type LanguageModel struct {
modelID string
weight int
client LangProvider
rateLimit *health.RateLimitTracker
errBudget *health.TokenBucket
healthTracker *health.HealthTracker
chatLatency *latency.MovingAverage
chatStreamLatency *latency.MovingAverage
latencyUpdateInterval *time.Duration
Expand All @@ -45,8 +43,7 @@ func NewLangModel(modelID string, client LangProvider, budget health.ErrorBudget
return &LanguageModel{
modelID: modelID,
client: client,
rateLimit: health.NewRateLimitTracker(),
errBudget: health.NewTokenBucket(budget.TimePerTokenMicro(), budget.Budget()),
healthTracker: health.NewHealthTracker(budget),
chatLatency: latency.NewMovingAverage(latencyConfig.Decay, latencyConfig.WarmupSamples),
chatStreamLatency: latency.NewMovingAverage(latencyConfig.Decay, latencyConfig.WarmupSamples),
latencyUpdateInterval: latencyConfig.UpdateInterval,
Expand All @@ -58,6 +55,10 @@ func (m LanguageModel) ID() string {
return m.modelID
}

func (m LanguageModel) Healthy() bool {
return m.healthTracker.Healthy()
}

func (m LanguageModel) Weight() int {
return m.weight
}
Expand All @@ -66,6 +67,10 @@ func (m LanguageModel) LatencyUpdateInterval() *time.Duration {
return m.latencyUpdateInterval
}

func (m *LanguageModel) SupportChatStream() bool {
return m.client.SupportChatStream()
}

func (m LanguageModel) ChatLatency() *latency.MovingAverage {
return m.chatLatency
}
Expand All @@ -74,10 +79,6 @@ func (m LanguageModel) ChatStreamLatency() *latency.MovingAverage {
return m.chatStreamLatency
}

func (m LanguageModel) Healthy() bool {
return !m.rateLimit.Limited() && m.errBudget.HasTokens()
}

func (m *LanguageModel) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) {
startedAt := time.Now()
resp, err := m.client.Chat(ctx, request)
Expand All @@ -92,15 +93,7 @@ func (m *LanguageModel) Chat(ctx context.Context, request *schemas.ChatRequest)
return resp, err
}

var rateLimitErr *clients.RateLimitError

if errors.As(err, &rateLimitErr) {
m.rateLimit.SetLimited(rateLimitErr.UntilReset())

return resp, err
}

_ = m.errBudget.Take(1)
m.healthTracker.TrackErr(err)

return resp, err
}
Expand All @@ -119,28 +112,15 @@ func (m *LanguageModel) ChatStream(ctx context.Context, req *schemas.ChatRequest
continue
}

var rateLimitErr *clients.RateLimitError

if errors.As(chunkResult.Error(), &rateLimitErr) {
m.rateLimit.SetLimited(rateLimitErr.UntilReset())
m.healthTracker.TrackErr(chunkResult.Error())

streamResultC <- chunkResult

continue
}

_ = m.errBudget.Take(1)
streamResultC <- chunkResult
}
}()

return streamResultC
}

func (m *LanguageModel) SupportChatStream() bool {
return m.client.SupportChatStream()
}

func (m *LanguageModel) Provider() string {
return m.client.Provider()
}
Expand Down
34 changes: 19 additions & 15 deletions pkg/providers/testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,45 +98,49 @@ func (c *ProviderMock) Provider() string {
}

type LangModelMock struct {
modelID string
healthy bool
latency *latency.MovingAverage
weight int
modelID string
healthy bool
chatLatency *latency.MovingAverage
weight int
}

func NewLangModelMock(ID string, healthy bool, avgLatency float64, weight int) *LangModelMock {
movingAverage := latency.NewMovingAverage(0.06, 3)
chatLatency := latency.NewMovingAverage(0.06, 3)

if avgLatency > 0.0 {
movingAverage.Set(avgLatency)
chatLatency.Set(avgLatency)
}

return &LangModelMock{
modelID: ID,
healthy: healthy,
latency: movingAverage,
weight: weight,
modelID: ID,
healthy: healthy,
chatLatency: chatLatency,
weight: weight,
}
}

func (m *LangModelMock) ID() string {
func (m LangModelMock) ID() string {
return m.modelID
}

func (m *LangModelMock) Healthy() bool {
func (m LangModelMock) Healthy() bool {
return m.healthy
}

func (m *LangModelMock) ChatLatency() *latency.MovingAverage {
return m.latency
return m.chatLatency
}

func (m *LangModelMock) LatencyUpdateInterval() *time.Duration {
func (m LangModelMock) LatencyUpdateInterval() *time.Duration {
updateInterval := 30 * time.Second

return &updateInterval
}

func (m *LangModelMock) Weight() int {
func (m LangModelMock) Weight() int {
return m.weight
}

func ChatMockLatency(model Model) *latency.MovingAverage {
return model.(LangModelMock).chatLatency
}
34 changes: 34 additions & 0 deletions pkg/routers/health/tracker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package health

import (
"errors"
"glide/pkg/providers/clients"
)

type HealthTracker struct {

Check warning on line 8 in pkg/routers/health/tracker.go

View workflow job for this annotation

GitHub Actions / Static Checks

exported: type name will be used as health.HealthTracker by other packages, and that stutters; consider calling this Tracker (revive)
errBudget *TokenBucket
rateLimit *RateLimitTracker
}

func NewHealthTracker(budget ErrorBudget) *HealthTracker {
return &HealthTracker{
rateLimit: NewRateLimitTracker(),
errBudget: NewTokenBucket(budget.TimePerTokenMicro(), budget.Budget()),
}
}

func (t *HealthTracker) Healthy() bool {
return !t.rateLimit.Limited() && t.errBudget.HasTokens()
}

func (t *HealthTracker) TrackErr(err error) {
var rateLimitErr *clients.RateLimitError

if errors.As(err, &rateLimitErr) {
t.rateLimit.SetLimited(rateLimitErr.UntilReset())

return
}

_ = t.errBudget.Take(1)
}
2 changes: 1 addition & 1 deletion pkg/routers/routing/least_latency_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func TestLeastLatencyRouting_Warmup(t *testing.T) {
models = append(models, providers.NewLangModelMock(model.modelID, model.healthy, model.latency, 1))
}

routing := NewLeastLatencyRouting(providers.ChatLatency, models)
routing := NewLeastLatencyRouting(providers.ChatMockLatency, models)
iterator := routing.Iterator()

// loop three times over the whole pool to check if we return back to the begging of the list
Expand Down

0 comments on commit ba56d3f

Please sign in to comment.