diff --git a/pkg/api/http/handlers.go b/pkg/api/http/handlers.go index 1bbf5f41..4c2c68b7 100644 --- a/pkg/api/http/handlers.go +++ b/pkg/api/http/handlers.go @@ -123,7 +123,8 @@ func LangStreamChatHandler(tel *telemetry.Telemetry, routerManager *routers.Rout wg sync.WaitGroup ) - chatResponseC := make(chan schemas.ChatResponse) + chunkResultC := make(chan *schemas.ChatStreamResult) + router, _ := routerManager.GetLangRouter(routerID) defer c.Conn.Close() @@ -133,8 +134,16 @@ func LangStreamChatHandler(tel *telemetry.Telemetry, routerManager *routers.Rout go func() { defer wg.Done() - for response := range chatResponseC { - if err = c.WriteJSON(response); err != nil { + for chunkResult := range chunkResultC { + if chunkResult.Error() != nil { + if err = c.WriteJSON(chunkResult.Error()); err != nil { + break + } + + continue + } + + if err = c.WriteJSON(chunkResult.Chunk()); err != nil { break } } @@ -157,13 +166,12 @@ func LangStreamChatHandler(tel *telemetry.Telemetry, routerManager *routers.Rout go func(chatRequest schemas.ChatRequest) { defer wg.Done() - if err = router.ChatStream(context.Background(), &chatRequest, chatResponseC); err != nil { - tel.L().Error("Failed to process streaming chat request", zap.Error(err), zap.String("routerID", routerID)) - } + router.ChatStream(context.Background(), &chatRequest, chunkResultC) }(chatRequest) } - close(chatResponseC) + close(chunkResultC) + wg.Wait() }) } diff --git a/pkg/api/schemas/language.go b/pkg/api/schemas/chat.go similarity index 100% rename from pkg/api/schemas/language.go rename to pkg/api/schemas/chat.go diff --git a/pkg/api/schemas/chat_stream.go b/pkg/api/schemas/chat_stream.go new file mode 100644 index 00000000..1ee4f0f3 --- /dev/null +++ b/pkg/api/schemas/chat_stream.go @@ -0,0 +1,53 @@ +package schemas + +// ChatStreamRequest defines a message that requests a new streaming chat +type ChatStreamRequest struct { + // TODO: implement +} + +// ChatStreamChunk defines a message for a chunk of streaming 
chat response +type ChatStreamChunk struct { + // TODO: modify according to the streaming chat needs + ID string `json:"id,omitempty"` + Created int `json:"created,omitempty"` + Provider string `json:"provider,omitempty"` + RouterID string `json:"router,omitempty"` + ModelID string `json:"model_id,omitempty"` + ModelName string `json:"model,omitempty"` + Cached bool `json:"cached,omitempty"` + ModelResponse ModelResponse `json:"modelResponse,omitempty"` + // TODO: add chat request-specific context +} + +type ChatStreamError struct { + // TODO: add chat request-specific context + Reason string `json:"reason"` + Message string `json:"message"` +} + +type ChatStreamResult struct { + chunk *ChatStreamChunk + err *ChatStreamError +} + +func (r *ChatStreamResult) Chunk() *ChatStreamChunk { + return r.chunk +} + +func (r *ChatStreamResult) Error() *ChatStreamError { + return r.err +} + +func NewChatStreamResult(chunk *ChatStreamChunk) *ChatStreamResult { + return &ChatStreamResult{ + chunk: chunk, + err: nil, + } +} + +func NewChatStreamErrorResult(err *ChatStreamError) *ChatStreamResult { + return &ChatStreamResult{ + chunk: nil, + err: err, + } +} diff --git a/pkg/providers/anthropic/chat_stream.go b/pkg/providers/anthropic/chat_stream.go index 64cebb37..bdcf5483 100644 --- a/pkg/providers/anthropic/chat_stream.go +++ b/pkg/providers/anthropic/chat_stream.go @@ -11,6 +11,11 @@ func (c *Client) SupportChatStream() bool { return false } -func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest, _ chan<- schemas.ChatResponse) error { - return clients.ErrChatStreamNotImplemented +func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest) <-chan *clients.ChatStreamResult { + streamResultC := make(chan *clients.ChatStreamResult, 1) + + streamResultC <- clients.NewChatStreamResult(nil, clients.ErrChatStreamNotImplemented) + close(streamResultC) + + return streamResultC } diff --git a/pkg/providers/azureopenai/chat_stream.go 
b/pkg/providers/azureopenai/chat_stream.go index fc596a85..2aa0e12b 100644 --- a/pkg/providers/azureopenai/chat_stream.go +++ b/pkg/providers/azureopenai/chat_stream.go @@ -11,6 +11,11 @@ func (c *Client) SupportChatStream() bool { return false } -func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest, _ chan<- schemas.ChatResponse) error { - return clients.ErrChatStreamNotImplemented +func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest) <-chan *clients.ChatStreamResult { + streamResultC := make(chan *clients.ChatStreamResult, 1) + + streamResultC <- clients.NewChatStreamResult(nil, clients.ErrChatStreamNotImplemented) + close(streamResultC) + + return streamResultC } diff --git a/pkg/providers/bedrock/chat_stream.go b/pkg/providers/bedrock/chat_stream.go index d4118977..cc541744 100644 --- a/pkg/providers/bedrock/chat_stream.go +++ b/pkg/providers/bedrock/chat_stream.go @@ -11,6 +11,11 @@ func (c *Client) SupportChatStream() bool { return false } -func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest, _ chan<- schemas.ChatResponse) error { - return clients.ErrChatStreamNotImplemented +func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest) <-chan *clients.ChatStreamResult { + streamResultC := make(chan *clients.ChatStreamResult, 1) + + streamResultC <- clients.NewChatStreamResult(nil, clients.ErrChatStreamNotImplemented) + close(streamResultC) + + return streamResultC } diff --git a/pkg/providers/clients/stream.go b/pkg/providers/clients/stream.go new file mode 100644 index 00000000..0d5d3f22 --- /dev/null +++ b/pkg/providers/clients/stream.go @@ -0,0 +1,23 @@ +package clients + +import "glide/pkg/api/schemas" + +type ChatStreamResult struct { + chunk *schemas.ChatStreamChunk + err error +} + +func (r *ChatStreamResult) Chunk() *schemas.ChatStreamChunk { + return r.chunk +} + +func (r *ChatStreamResult) Error() error { + return r.err +} + +func NewChatStreamResult(chunk *schemas.ChatStreamChunk, err error) 
*ChatStreamResult { + return &ChatStreamResult{ + chunk: chunk, + err: err, + } +} diff --git a/pkg/providers/cohere/chat_stream.go b/pkg/providers/cohere/chat_stream.go index 81872e07..94414324 100644 --- a/pkg/providers/cohere/chat_stream.go +++ b/pkg/providers/cohere/chat_stream.go @@ -11,6 +11,11 @@ func (c *Client) SupportChatStream() bool { return false } -func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest, _ chan<- schemas.ChatResponse) error { - return clients.ErrChatStreamNotImplemented +func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest) <-chan *clients.ChatStreamResult { + streamResultC := make(chan *clients.ChatStreamResult, 1) + + streamResultC <- clients.NewChatStreamResult(nil, clients.ErrChatStreamNotImplemented) + close(streamResultC) + + return streamResultC } diff --git a/pkg/providers/lang.go b/pkg/providers/lang.go index c1250e72..3a95ce58 100644 --- a/pkg/providers/lang.go +++ b/pkg/providers/lang.go @@ -17,8 +17,8 @@ type LangProvider interface { SupportChatStream() bool - Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) - ChatStream(ctx context.Context, request *schemas.ChatRequest, responseC chan<- schemas.ChatResponse) error + Chat(ctx context.Context, req *schemas.ChatRequest) (*schemas.ChatResponse, error) + ChatStream(ctx context.Context, req *schemas.ChatRequest) <-chan *clients.ChatStreamResult } type LangModel interface { @@ -105,24 +105,36 @@ func (m *LanguageModel) Chat(ctx context.Context, request *schemas.ChatRequest) return resp, err } -func (m *LanguageModel) ChatStream(ctx context.Context, request *schemas.ChatRequest, responseC chan<- schemas.ChatResponse) error { - err := m.client.ChatStream(ctx, request, responseC) +func (m *LanguageModel) ChatStream(ctx context.Context, req *schemas.ChatRequest) <-chan *clients.ChatStreamResult { + streamResultC := make(chan *clients.ChatStreamResult) + resultC := m.client.ChatStream(ctx, req) - if err == nil { - 
return err - } + go func() { + defer close(streamResultC) - var rateLimitErr *clients.RateLimitError + for chunkResult := range resultC { + if chunkResult.Error() == nil { + streamResultC <- chunkResult + // TODO: calculate latency + continue + } - if errors.As(err, &rateLimitErr) { - m.rateLimit.SetLimited(rateLimitErr.UntilReset()) + var rateLimitErr *clients.RateLimitError - return err - } + if errors.As(chunkResult.Error(), &rateLimitErr) { + m.rateLimit.SetLimited(rateLimitErr.UntilReset()) - _ = m.errBudget.Take(1) + streamResultC <- chunkResult + + continue + } + + _ = m.errBudget.Take(1) + streamResultC <- chunkResult + } + }() - return err + return streamResultC } func (m *LanguageModel) SupportChatStream() bool { diff --git a/pkg/providers/octoml/chat_stream.go b/pkg/providers/octoml/chat_stream.go index 46b4d7d7..3b6107c1 100644 --- a/pkg/providers/octoml/chat_stream.go +++ b/pkg/providers/octoml/chat_stream.go @@ -11,6 +11,11 @@ func (c *Client) SupportChatStream() bool { return false } -func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest, _ chan<- schemas.ChatResponse) error { - return clients.ErrChatStreamNotImplemented +func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest) <-chan *clients.ChatStreamResult { + streamResultC := make(chan *clients.ChatStreamResult, 1) + + streamResultC <- clients.NewChatStreamResult(nil, clients.ErrChatStreamNotImplemented) + close(streamResultC) + + return streamResultC } diff --git a/pkg/providers/ollama/chat_stream.go b/pkg/providers/ollama/chat_stream.go index 864321d7..c6a444ea 100644 --- a/pkg/providers/ollama/chat_stream.go +++ b/pkg/providers/ollama/chat_stream.go @@ -11,6 +11,11 @@ func (c *Client) SupportChatStream() bool { return false } -func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest, _ chan<- schemas.ChatResponse) error { - return clients.ErrChatStreamNotImplemented +func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest) <-chan 
*clients.ChatStreamResult { + streamResultC := make(chan *clients.ChatStreamResult, 1) + + streamResultC <- clients.NewChatStreamResult(nil, clients.ErrChatStreamNotImplemented) + close(streamResultC) + + return streamResultC } diff --git a/pkg/providers/openai/chat_stream.go b/pkg/providers/openai/chat_stream.go index afed1401..6fd89bc5 100644 --- a/pkg/providers/openai/chat_stream.go +++ b/pkg/providers/openai/chat_stream.go @@ -10,6 +10,7 @@ import ( "github.com/r3labs/sse/v2" "glide/pkg/providers/clients" + "go.uber.org/zap" "glide/pkg/api/schemas" @@ -21,17 +22,30 @@ func (c *Client) SupportChatStream() bool { return true } -func (c *Client) ChatStream(ctx context.Context, request *schemas.ChatRequest, responseC chan<- schemas.ChatResponse) error { +func (c *Client) ChatStream(ctx context.Context, req *schemas.ChatRequest) <-chan *clients.ChatStreamResult { + streamResultC := make(chan *clients.ChatStreamResult) + + go c.streamChat(ctx, req, streamResultC) + + return streamResultC +} + +func (c *Client) streamChat(ctx context.Context, request *schemas.ChatRequest, resultC chan *clients.ChatStreamResult) { // Create a new chat request resp, err := c.initChatStream(ctx, request) + + defer close(resultC) + if err != nil { - return err + resultC <- clients.NewChatStreamResult(nil, err) + + return } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return c.handleChatReqErrs(resp) + resultC <- clients.NewChatStreamResult(nil, c.handleChatReqErrs(resp)); return // error status: don't fall through and parse the error body as an SSE stream } reader := sse.NewEventStreamReader(resp.Body, 4096) // TODO: should we expose maxBufferSize? 
@@ -44,7 +58,7 @@ func (c *Client) ChatStream(ctx context.Context, request *schemas.ChatRequest, r if err == io.EOF { c.tel.L().Debug("Chat stream is over", zap.String("provider", c.Provider())) - return nil + return } c.tel.L().Warn( @@ -52,7 +66,9 @@ func (c *Client) ChatStream(ctx context.Context, request *schemas.ChatRequest, r zap.String("provider", c.Provider()), ) - return clients.ErrProviderUnavailable + resultC <- clients.NewChatStreamResult(nil, clients.ErrProviderUnavailable) + + return } c.tel.L().Debug( @@ -64,15 +80,16 @@ func (c *Client) ChatStream(ctx context.Context, request *schemas.ChatRequest, r event, err := clients.ParseSSEvent(rawEvent) if bytes.Equal(event.Data, streamDoneMarker) { - return nil + return } if err != nil { - return fmt.Errorf("failed to parse chat stream message: %v", err) + resultC <- clients.NewChatStreamResult(nil, fmt.Errorf("failed to parse chat stream message: %v", err)) + return } if !event.HasContent() { - c.tel.Logger.Debug( + c.tel.L().Debug( "Received an empty message in chat stream, skipping it", zap.String("provider", c.Provider()), zap.Any("msg", event), @@ -83,7 +100,8 @@ func (c *Client) ChatStream(ctx context.Context, request *schemas.ChatRequest, r err = json.Unmarshal(event.Data, &completionChunk) if err != nil { - return fmt.Errorf("failed to unmarshal chat stream message: %v", err) + resultC <- clients.NewChatStreamResult(nil, fmt.Errorf("failed to unmarshal chat stream chunk: %v", err)) + return } c.tel.L().Debug( @@ -93,7 +111,7 @@ func (c *Client) ChatStream(ctx context.Context, request *schemas.ChatRequest, r ) // TODO: use objectpool here - chatResponse := schemas.ChatResponse{ + chatRespChunk := schemas.ChatStreamChunk{ ID: completionChunk.ID, Created: completionChunk.Created, Provider: providerName, @@ -111,7 +129,7 @@ func (c *Client) ChatStream(ctx context.Context, request *schemas.ChatRequest, r // TODO: Pass info if this is the final message } - responseC <- chatResponse + resultC <- 
clients.NewChatStreamResult(&chatRespChunk, nil) } } @@ -137,7 +155,7 @@ func (c *Client) initChatStream(ctx context.Context, request *schemas.ChatReques req.Header.Set("Connection", "keep-alive") // TODO: this could leak information from messages which may not be a desired thing to have - c.tel.Logger.Debug( + c.tel.L().Debug( "Stream chat request", zap.String("chatURL", c.chatURL), zap.Any("payload", chatRequest), diff --git a/pkg/providers/testing.go b/pkg/providers/testing.go index 3abf459e..451f22c9 100644 --- a/pkg/providers/testing.go +++ b/pkg/providers/testing.go @@ -4,6 +4,8 @@ import ( "context" "time" + "glide/pkg/providers/clients" + "glide/pkg/routers/latency" "glide/pkg/api/schemas" @@ -57,7 +59,7 @@ func (c *ProviderMock) SupportChatStream() bool { return c.supportStreaming } -func (c *ProviderMock) ChatStream(_ context.Context, _ *schemas.ChatRequest, _ chan<- schemas.ChatResponse) error { +func (c *ProviderMock) ChatStream(_ context.Context, _ *schemas.ChatRequest) <-chan *clients.ChatStreamResult { // TODO: implement return nil } diff --git a/pkg/routers/router.go b/pkg/routers/router.go index 27948c92..3cef2324 100644 --- a/pkg/routers/router.go +++ b/pkg/routers/router.go @@ -59,7 +59,7 @@ func (r *LangRouter) ID() string { return r.routerID } -func (r *LangRouter) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) { +func (r *LangRouter) Chat(ctx context.Context, req *schemas.ChatRequest) (*schemas.ChatResponse, error) { if len(r.chatModels) == 0 { return nil, ErrNoModels } @@ -80,14 +80,14 @@ func (r *LangRouter) Chat(ctx context.Context, request *schemas.ChatRequest) (*s langModel := model.(providers.LangModel) // Check if there is an override in the request - if request.Override != nil { + if req.Override != nil { // Override the message if the language model ID matches the override model ID - if langModel.ID() == request.Override.Model { - request.Message = request.Override.Message + if langModel.ID() == 
req.Override.Model { + req.Message = req.Override.Message } } - resp, err := langModel.Chat(ctx, request) + resp, err := langModel.Chat(ctx, req) if err != nil { r.telemetry.L().Warn( "Lang model failed processing chat request", @@ -122,9 +122,18 @@ func (r *LangRouter) Chat(ctx context.Context, request *schemas.ChatRequest) (*s return nil, ErrNoModelAvailable } -func (r *LangRouter) ChatStream(ctx context.Context, request *schemas.ChatRequest, responseC chan<- schemas.ChatResponse) error { +func (r *LangRouter) ChatStream( + ctx context.Context, + req *schemas.ChatRequest, + respC chan<- *schemas.ChatStreamResult, +) { if len(r.chatStreamModels) == 0 { - return ErrNoModels + respC <- schemas.NewChatStreamErrorResult(&schemas.ChatStreamError{ + Reason: "noModels", + Message: ErrNoModels.Error(), + }) + + return } retryIterator := r.retry.Iterator() @@ -132,6 +141,7 @@ func (r *LangRouter) ChatStream(ctx context.Context, request *schemas.ChatReques for retryIterator.HasNext() { modelIterator := r.chatStreamRouting.Iterator() + NextModel: for { model, err := modelIterator.Next() @@ -141,21 +151,33 @@ func (r *LangRouter) ChatStream(ctx context.Context, request *schemas.ChatReques } langModel := model.(providers.LangModel) + modelRespC := langModel.ChatStream(ctx, req) + + for chunkResult := range modelRespC { + if chunkResult.Error() != nil { + r.telemetry.L().Warn( + "Lang model failed processing streaming chat request", + zap.String("routerID", r.ID()), + zap.String("modelID", langModel.ID()), + zap.String("provider", langModel.Provider()), + zap.Error(chunkResult.Error()), + ) + + // It's challenging to hide an error in case of streaming chat as consumer apps + // may have already used all chunks we streamed this far (e.g. 
showed them to their users like OpenAI UI does), + // so we cannot easily restart that process from scratch + respC <- schemas.NewChatStreamErrorResult(&schemas.ChatStreamError{ + Reason: "modelUnavailable", + Message: chunkResult.Error().Error(), + }) + + continue NextModel + } - err = langModel.ChatStream(ctx, request, responseC) - if err != nil { - r.telemetry.L().Warn( - "Lang model failed processing streaming chat request", - zap.String("routerID", r.ID()), - zap.String("modelID", langModel.ID()), - zap.String("provider", langModel.Provider()), - zap.Error(err), - ) - - continue + respC <- schemas.NewChatStreamResult(chunkResult.Chunk()) } - return nil + return } // no providers were available to handle the request, @@ -165,7 +187,12 @@ func (r *LangRouter) ChatStream(ctx context.Context, request *schemas.ChatReques err := retryIterator.WaitNext(ctx) if err != nil { // something has cancelled the context - return err + respC <- schemas.NewChatStreamErrorResult(&schemas.ChatStreamError{ + Reason: "other", + Message: err.Error(), + }) + + return } } @@ -175,5 +202,8 @@ func (r *LangRouter) ChatStream(ctx context.Context, request *schemas.ChatReques zap.String("routerID", r.ID()), ) - return ErrNoModelAvailable + respC <- schemas.NewChatStreamErrorResult(&schemas.ChatStreamError{ + Reason: "noModelAvailable", + Message: ErrNoModelAvailable.Error(), + }) }