Skip to content

Commit

Permalink
#163: Covered the fallback mechanism by tests
Browse files Browse the repository at this point in the history
  • Loading branch information
roma-glushko committed Mar 17, 2024
1 parent 2198aa2 commit 0c158bd
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 11 deletions.
30 changes: 23 additions & 7 deletions pkg/providers/testing/lang.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,27 +44,43 @@ func (m *RespMock) RespChunk() *schemas.ChatStreamChunk {

// RespStreamMock mocks a chat stream
type RespStreamMock struct {
idx int
Chunks []RespMock
idx int
OpenErr error
Chunks *[]RespMock
}

func NewRespStreamMock(chunk []RespMock) RespStreamMock {
func NewRespStreamMock(chunk *[]RespMock) RespStreamMock {
return RespStreamMock{
idx: 0,
Chunks: chunk,
idx: 0,
OpenErr: nil,
Chunks: chunk,
}
}

func NewRespStreamWithOpenErr(openErr error) RespStreamMock {
return RespStreamMock{
idx: 0,
OpenErr: openErr,
Chunks: nil,
}
}

func (m *RespStreamMock) Open() error {
if m.OpenErr != nil {
return m.OpenErr
}

return nil
}

func (m *RespStreamMock) Recv() (*schemas.ChatStreamChunk, error) {
if m.idx >= len(m.Chunks) {
if m.Chunks != nil && m.idx >= len(*m.Chunks) {
return nil, io.EOF
}

chunk := m.Chunks[m.idx]
chunks := *m.Chunks

chunk := chunks[m.idx]
m.idx++

if chunk.Err != nil {
Expand Down
76 changes: 72 additions & 4 deletions pkg/routers/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ func TestLangRouter_ChatStream(t *testing.T) {
providers.NewLangModel(
"first",
ptesting.NewStreamProviderMock([]ptesting.RespStreamMock{
ptesting.NewRespStreamMock([]ptesting.RespMock{
ptesting.NewRespStreamMock(&[]ptesting.RespMock{
{Msg: "Bill"},
{Msg: "Gates"},
{Msg: "entered"},
Expand All @@ -276,7 +276,7 @@ func TestLangRouter_ChatStream(t *testing.T) {
providers.NewLangModel(
"second",
ptesting.NewStreamProviderMock([]ptesting.RespStreamMock{
ptesting.NewRespStreamMock([]ptesting.RespMock{
ptesting.NewRespStreamMock(&[]ptesting.RespMock{
{Msg: "Knock"},
{Msg: "Knock"},
{Msg: "joke"},
Expand Down Expand Up @@ -327,6 +327,74 @@ func TestLangRouter_ChatStream(t *testing.T) {
require.Equal(t, []string{"Bill", "Gates", "entered", "the", "bar"}, chunks)
}

func TestLangRouter_ChatStream_FailOnFirst(t *testing.T) {
budget := health.NewErrorBudget(3, health.SEC)
latConfig := latency.DefaultConfig()

langModels := []*providers.LanguageModel{
providers.NewLangModel(
"first",
ptesting.NewStreamProviderMock(nil),
budget,
*latConfig,
1,
),
providers.NewLangModel(
"second",
ptesting.NewStreamProviderMock([]ptesting.RespStreamMock{
ptesting.NewRespStreamMock(
&[]ptesting.RespMock{
{Msg: "Knock"},
{Msg: "knock"},
{Msg: "joke"},
},
),
}),
budget,
*latConfig,
1,
),
}

models := make([]providers.Model, 0, len(langModels))
for _, model := range langModels {
models = append(models, model)
}

router := LangRouter{
routerID: "test_stream_router",
Config: &LangRouterConfig{},
retry: retry.NewExpRetry(3, 2, 1*time.Second, nil),
chatRouting: routing.NewPriority(models),
chatModels: langModels,
chatStreamRouting: routing.NewPriority(models),
chatStreamModels: langModels,
tel: telemetry.NewTelemetryMock(),
}

ctx := context.Background()
req := schemas.NewChatFromStr("tell me a dad joke")
respC := make(chan *schemas.ChatStreamResult)

defer close(respC)

go router.ChatStream(ctx, req, respC)

chunks := make([]string, 0, 3)

for range 3 {
select { //nolint:gosimple
case chunk := <-respC:
require.Nil(t, chunk.Error())
require.NotNil(t, chunk.Chunk().ModelResponse.Message.Content)

chunks = append(chunks, chunk.Chunk().ModelResponse.Message.Content)
}
}

require.Equal(t, []string{"Knock", "knock", "joke"}, chunks)
}

func TestLangRouter_ChatStream_AllModelsUnavailable(t *testing.T) {
budget := health.NewErrorBudget(1, health.SEC)
latConfig := latency.DefaultConfig()
Expand All @@ -335,7 +403,7 @@ func TestLangRouter_ChatStream_AllModelsUnavailable(t *testing.T) {
providers.NewLangModel(
"first",
ptesting.NewStreamProviderMock([]ptesting.RespStreamMock{
ptesting.NewRespStreamMock([]ptesting.RespMock{
ptesting.NewRespStreamMock(&[]ptesting.RespMock{
{Err: &clients.ErrProviderUnavailable},
}),
}),
Expand All @@ -346,7 +414,7 @@ func TestLangRouter_ChatStream_AllModelsUnavailable(t *testing.T) {
providers.NewLangModel(
"second",
ptesting.NewStreamProviderMock([]ptesting.RespStreamMock{
ptesting.NewRespStreamMock([]ptesting.RespMock{
ptesting.NewRespStreamMock(&[]ptesting.RespMock{
{Err: &clients.ErrProviderUnavailable},
}),
}),
Expand Down

0 comments on commit 0c158bd

Please sign in to comment.