#163: Fixed the stream chat test
roma-glushko committed Mar 11, 2024
1 parent 548ea18 commit 21d1f33
Showing 6 changed files with 64 additions and 28 deletions.
2 changes: 1 addition & 1 deletion pkg/cmd/cli.go
@@ -49,7 +49,7 @@ func NewCLI() *cobra.Command {
if err != nil {
log.Println("⚠️failed to load dotenv file: ", err) // don't have an inited logger at this moment
} else {
log.Printf("🔧dot env file loaded (%v)", dotEnvFile)
log.Printf("🔧dot env file is loaded (%v)", dotEnvFile)
}

_, err = configProvider.Load(cfgFile)
4 changes: 2 additions & 2 deletions pkg/providers/lang.go
@@ -146,9 +146,9 @@ func (m *LanguageModel) Provider() string {
}

func ChatLatency(model Model) *latency.MovingAverage {
- return model.(LanguageModel).chatLatency
+ return model.(LanguageModel).ChatLatency()
}

func ChatStreamLatency(model Model) *latency.MovingAverage {
- return model.(LanguageModel).chatStreamLatency
+ return model.(LanguageModel).ChatStreamLatency()
}
35 changes: 32 additions & 3 deletions pkg/providers/testing.go
@@ -30,6 +30,20 @@ func (m *ResponseMock) Resp() *schemas.ChatResponse {
}
}

+ func (m *ResponseMock) RespChunk() *schemas.ChatStreamChunk {
+ return &schemas.ChatStreamChunk{
+ ID: "rsp0001",
+ ModelResponse: schemas.ModelResponse{
+ SystemID: map[string]string{
+ "ID": "0001",
+ },
+ Message: schemas.ChatMessage{
+ Content: m.Msg,
+ },
+ },
+ }
+ }

type ProviderMock struct {
idx int
responses []ResponseMock
@@ -60,8 +74,23 @@ func (c *ProviderMock) SupportChatStream() bool {
}

func (c *ProviderMock) ChatStream(_ context.Context, _ *schemas.ChatRequest) <-chan *clients.ChatStreamResult {
- // TODO: implement
- return nil
+ streamResultC := make(chan *clients.ChatStreamResult)
+
+ response := c.responses[c.idx]
+ c.idx++
+
+ go func() {
+ defer close(streamResultC)
+
+ if response.Err != nil {
+ streamResultC <- clients.NewChatStreamResult(nil, *response.Err)
+ return
+ }
+
+ streamResultC <- clients.NewChatStreamResult(response.RespChunk(), nil)
+ }()
+
+ return streamResultC
}

func (c *ProviderMock) Provider() string {
@@ -98,7 +127,7 @@ func (m *LangModelMock) Healthy() bool {
return m.healthy
}

- func (m *LangModelMock) Latency() *latency.MovingAverage {
+ func (m *LangModelMock) ChatLatency() *latency.MovingAverage {
return m.latency
}

4 changes: 2 additions & 2 deletions pkg/routers/manager.go
@@ -10,7 +10,7 @@ var ErrRouterNotFound = errors.New("no router found with given ID")

type RouterManager struct {
Config *Config
- telemetry *telemetry.Telemetry
+ tel *telemetry.Telemetry
langRouterMap *map[string]*LangRouter
langRouters []*LangRouter
}
@@ -30,7 +30,7 @@ func NewManager(cfg *Config, tel *telemetry.Telemetry) (*RouterManager, error) {

manager := RouterManager{
Config: cfg,
- telemetry: tel,
+ tel: tel,
langRouters: langRouters,
langRouterMap: &langRouterMap,
}
21 changes: 11 additions & 10 deletions pkg/routers/router.go
@@ -27,7 +27,7 @@ type LangRouter struct {
chatRouting routing.LangModelRouting
chatStreamRouting routing.LangModelRouting
retry *retry.ExpRetry
- telemetry *telemetry.Telemetry
+ tel *telemetry.Telemetry
}

func NewLangRouter(cfg *LangRouterConfig, tel *telemetry.Telemetry) (*LangRouter, error) {
@@ -49,7 +49,7 @@ func NewLangRouter(cfg *LangRouterConfig, tel *telemetry.Telemetry) (*LangRouter
retry: cfg.BuildRetry(),
chatRouting: chatRouting,
chatStreamRouting: chatStreamRouting,
- telemetry: tel,
+ tel: tel,
}

return router, err
@@ -89,7 +89,7 @@ func (r *LangRouter) Chat(ctx context.Context, req *schemas.ChatRequest) (*schem

resp, err := langModel.Chat(ctx, req)
if err != nil {
- r.telemetry.L().Warn(
+ r.tel.L().Warn(
"Lang model failed processing chat request",
zap.String("routerID", r.ID()),
zap.String("modelID", langModel.ID()),
@@ -107,7 +107,7 @@ func (r *LangRouter) Chat(ctx context.Context, req *schemas.ChatRequest) (*schem

// no providers were available to handle the request,
// so we have to wait a bit with a hope there is some available next time
r.telemetry.L().Warn("No healthy model found to serve chat request, wait and retry", zap.String("routerID", r.ID()))
r.tel.L().Warn("No healthy model found to serve chat request, wait and retry", zap.String("routerID", r.ID()))

err := retryIterator.WaitNext(ctx)
if err != nil {
@@ -117,7 +117,7 @@ func (r *LangRouter) Chat(ctx context.Context, req *schemas.ChatRequest) (*schem
}

// if we reach this part, then we are in trouble
- r.telemetry.L().Error("No model was available to handle chat request", zap.String("routerID", r.ID()))
+ r.tel.L().Error("No model was available to handle chat request", zap.String("routerID", r.ID()))

return nil, ErrNoModelAvailable
}
@@ -154,8 +154,9 @@ func (r *LangRouter) ChatStream(
modelRespC := langModel.ChatStream(ctx, req)

for chunkResult := range modelRespC {
- if chunkResult.Error() != nil {
- r.telemetry.L().Warn(
+ err = chunkResult.Error()
+ if err != nil {
+ r.tel.L().Warn(
"Lang model failed processing streaming chat request",
zap.String("routerID", r.ID()),
zap.String("modelID", langModel.ID()),
@@ -182,7 +183,7 @@ func (r *LangRouter) ChatStream(

// no providers were available to handle the request,
// so we have to wait a bit with a hope there is some available next time
r.telemetry.L().Warn("No healthy model found to serve streaming chat request, wait and retry", zap.String("routerID", r.ID()))
r.tel.L().Warn("No healthy model found to serve streaming chat request, wait and retry", zap.String("routerID", r.ID()))

err := retryIterator.WaitNext(ctx)
if err != nil {
@@ -197,13 +198,13 @@ func (r *LangRouter) ChatStream(
}

// if we reach this part, then we are in trouble
- r.telemetry.L().Error(
+ r.tel.L().Error(
"No model was available to handle streaming chat request. Try to configure more fallback models to avoid this",
zap.String("routerID", r.ID()),
)

respC <- schemas.NewChatStreamErrorResult(&schemas.ChatStreamError{
Reason: "noModelAvailable",
Reason: "allModelsUnavailable",
Message: ErrNoModelAvailable.Error(),
})
}
26 changes: 16 additions & 10 deletions pkg/routers/router_test.go
@@ -51,7 +51,7 @@ func TestLangRouter_Chat_PickFistHealthy(t *testing.T) {
chatRouting: routing.NewPriority(models),
chatModels: langModels,
chatStreamModels: langModels,
- telemetry: telemetry.NewTelemetryMock(),
+ tel: telemetry.NewTelemetryMock(),
}

ctx := context.Background()
@@ -108,7 +108,7 @@ func TestLangRouter_Chat_PickThirdHealthy(t *testing.T) {
chatStreamRouting: routing.NewPriority(models),
chatModels: langModels,
chatStreamModels: langModels,
- telemetry: telemetry.NewTelemetryMock(),
+ tel: telemetry.NewTelemetryMock(),
}

ctx := context.Background()
@@ -156,7 +156,7 @@ func TestLangRouter_Chat_SuccessOnRetry(t *testing.T) {
chatStreamRouting: routing.NewPriority(models),
chatModels: langModels,
chatStreamModels: langModels,
- telemetry: telemetry.NewTelemetryMock(),
+ tel: telemetry.NewTelemetryMock(),
}

resp, err := router.Chat(context.Background(), schemas.NewChatFromStr("tell me a dad joke"))
@@ -199,7 +199,7 @@ func TestLangRouter_Chat_UnhealthyModelInThePool(t *testing.T) {
chatModels: langModels,
chatStreamModels: langModels,
chatStreamRouting: routing.NewPriority(models),
- telemetry: telemetry.NewTelemetryMock(),
+ tel: telemetry.NewTelemetryMock(),
}

for i := 0; i < 2; i++ {
@@ -244,7 +244,7 @@ func TestLangRouter_Chat_AllModelsUnavailable(t *testing.T) {
chatModels: langModels,
chatStreamModels: langModels,
chatStreamRouting: routing.NewPriority(models),
- telemetry: telemetry.NewTelemetryMock(),
+ tel: telemetry.NewTelemetryMock(),
}

_, err := router.Chat(context.Background(), schemas.NewChatFromStr("tell me a dad joke"))
@@ -286,16 +286,22 @@ func TestLangRouter_ChatStream(t *testing.T) {
chatModels: langModels,
chatStreamRouting: routing.NewPriority(models),
chatStreamModels: langModels,
- telemetry: telemetry.NewTelemetryMock(),
+ tel: telemetry.NewTelemetryMock(),
}

ctx := context.Background()
req := schemas.NewChatFromStr("tell me a dad joke")
- respC := make(chan schemas.ChatResponse)
+ respC := make(chan *schemas.ChatStreamResult)

for i := 0; i < 2; i++ {
err := router.ChatStream(ctx, req, respC)
defer close(respC)

- require.NoError(t, err)
+ go router.ChatStream(ctx, req, respC)

+ select {
+ case chunkResult := <-respC:
+ require.Nil(t, chunkResult.Error())
+ require.NotNil(t, chunkResult.Chunk().ModelResponse.Message.Content)
+ case <-time.Tick(5 * time.Second):
+ t.Error("Timeout while waiting for stream chat chunk")
}
}
