diff --git a/pkg/providers/anthropic/chat_stream.go b/pkg/providers/anthropic/chat_stream.go
index bdcf5483..7ca861c1 100644
--- a/pkg/providers/anthropic/chat_stream.go
+++ b/pkg/providers/anthropic/chat_stream.go
@@ -11,11 +11,6 @@ func (c *Client) SupportChatStream() bool {
 	return false
 }
 
-func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest) <-chan *clients.ChatStreamResult {
-	streamResultC := make(chan *clients.ChatStreamResult)
-
-	streamResultC <- clients.NewChatStreamResult(nil, clients.ErrChatStreamNotImplemented)
-	close(streamResultC)
-
-	return streamResultC
+func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest) (clients.ChatStream, error) {
+	return nil, clients.ErrChatStreamNotImplemented
 }
diff --git a/pkg/providers/azureopenai/chat_stream.go b/pkg/providers/azureopenai/chat_stream.go
index 2aa0e12b..6facb899 100644
--- a/pkg/providers/azureopenai/chat_stream.go
+++ b/pkg/providers/azureopenai/chat_stream.go
@@ -11,11 +11,6 @@ func (c *Client) SupportChatStream() bool {
 	return false
 }
 
-func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest) <-chan *clients.ChatStreamResult {
-	streamResultC := make(chan *clients.ChatStreamResult)
-
-	streamResultC <- clients.NewChatStreamResult(nil, clients.ErrChatStreamNotImplemented)
-	close(streamResultC)
-
-	return streamResultC
+func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest) (clients.ChatStream, error) {
+	return nil, clients.ErrChatStreamNotImplemented
 }
diff --git a/pkg/providers/bedrock/chat_stream.go b/pkg/providers/bedrock/chat_stream.go
index cc541744..3ae8498c 100644
--- a/pkg/providers/bedrock/chat_stream.go
+++ b/pkg/providers/bedrock/chat_stream.go
@@ -11,11 +11,6 @@ func (c *Client) SupportChatStream() bool {
 	return false
 }
 
-func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest) <-chan *clients.ChatStreamResult {
-	streamResultC := make(chan *clients.ChatStreamResult)
-
-	streamResultC <- clients.NewChatStreamResult(nil, clients.ErrChatStreamNotImplemented)
-	close(streamResultC)
-
-	return streamResultC
+func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest) (clients.ChatStream, error) {
+	return nil, clients.ErrChatStreamNotImplemented
 }
diff --git a/pkg/providers/clients/stream.go b/pkg/providers/clients/stream.go
index 19beaae0..ff150e8f 100644
--- a/pkg/providers/clients/stream.go
+++ b/pkg/providers/clients/stream.go
@@ -4,6 +4,12 @@ import (
 	"glide/pkg/api/schemas"
 )
 
+type ChatStream interface {
+	Open() error
+	Recv() (*schemas.ChatStreamChunk, error)
+	Close() error
+}
+
 type ChatStreamResult struct {
 	chunk *schemas.ChatStreamChunk
 	err   error
diff --git a/pkg/providers/cohere/chat_stream.go b/pkg/providers/cohere/chat_stream.go
index 94414324..f4d0e8e2 100644
--- a/pkg/providers/cohere/chat_stream.go
+++ b/pkg/providers/cohere/chat_stream.go
@@ -11,11 +11,6 @@ func (c *Client) SupportChatStream() bool {
 	return false
 }
 
-func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest) <-chan *clients.ChatStreamResult {
-	streamResultC := make(chan *clients.ChatStreamResult)
-
-	streamResultC <- clients.NewChatStreamResult(nil, clients.ErrChatStreamNotImplemented)
-	close(streamResultC)
-
-	return streamResultC
+func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest) (clients.ChatStream, error) {
+	return nil, clients.ErrChatStreamNotImplemented
 }
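The `ChatStream` interface added to `clients/stream.go` above is the heart of this change: providers now return a pull-based stream with an explicit `Open`/`Recv`/`Close` lifecycle instead of pushing results into a channel. A minimal consumer would drive it roughly like this (a sketch, assuming `Recv` reports a finished stream with `io.EOF`, which is how the OpenAI implementation later in this diff behaves; `process` is a hypothetical callback):

```go
// consume drives a clients.ChatStream to completion.
// Assumes the standard "io" import plus the glide clients/schemas packages.
func consume(stream clients.ChatStream, process func(*schemas.ChatStreamChunk)) error {
	if err := stream.Open(); err != nil {
		return err // connection or status-code failure
	}
	defer stream.Close()

	for {
		chunk, err := stream.Recv()
		if err == io.EOF {
			return nil // the stream ended normally
		}

		if err != nil {
			return err
		}

		process(chunk)
	}
}
```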
diff --git a/pkg/providers/lang.go b/pkg/providers/lang.go
index 20efe815..48eed994 100644
--- a/pkg/providers/lang.go
+++ b/pkg/providers/lang.go
@@ -2,6 +2,7 @@ package providers
 
 import (
 	"context"
+	"io"
 	"time"
 
 	"glide/pkg/routers/health"
@@ -18,12 +19,14 @@ type LangProvider interface {
 	SupportChatStream() bool
 
 	Chat(ctx context.Context, req *schemas.ChatRequest) (*schemas.ChatResponse, error)
-	ChatStream(ctx context.Context, req *schemas.ChatRequest) <-chan *clients.ChatStreamResult
+	ChatStream(ctx context.Context, req *schemas.ChatRequest) (clients.ChatStream, error)
 }
 
 type LangModel interface {
-	LangProvider
 	Model
+	Provider() string
+	Chat(ctx context.Context, req *schemas.ChatRequest) (*schemas.ChatResponse, error)
+	ChatStream(ctx context.Context, req *schemas.ChatRequest) (<-chan *clients.ChatStreamResult, error)
 }
 
 // LanguageModel wraps provider client and extends it with health & latency tracking
@@ -99,35 +102,65 @@ func (m *LanguageModel) Chat(ctx context.Context, request *schemas.ChatRequest)
 	return resp, err
 }
 
-func (m *LanguageModel) ChatStream(ctx context.Context, req *schemas.ChatRequest) <-chan *clients.ChatStreamResult {
+func (m *LanguageModel) ChatStream(ctx context.Context, req *schemas.ChatRequest) (<-chan *clients.ChatStreamResult, error) {
+	stream, err := m.client.ChatStream(ctx, req)
+	if err != nil {
+		return nil, err
+	}
+
 	streamResultC := make(chan *clients.ChatStreamResult)
-	resultC := m.client.ChatStream(ctx, req)
 
 	go func() {
 		defer close(streamResultC)
 
-		var chunkLatency *time.Duration
+		startedAt := time.Now()
+		err = stream.Open()
+		chunkLatency := time.Since(startedAt)
 
-		for chunkResult := range resultC {
-			if chunkResult.Error() == nil {
-				streamResultC <- chunkResult
+		// record the first chunk latency
+		m.chatStreamLatency.Add(float64(chunkLatency))
 
-				chunkLatency = chunkResult.Chunk().Latency
+		if err != nil {
+			streamResultC <- clients.NewChatStreamResult(nil, err)
+
+			m.healthTracker.TrackErr(err)
+
+			return
+		}
 
-				if chunkLatency != nil {
-					m.chatStreamLatency.Add(float64(*chunkLatency))
+		defer stream.Close()
+
+		for {
+			startedAt = time.Now()
+			chunk, err := stream.Recv()
+			chunkLatency = time.Since(startedAt)
+
+			if err != nil {
+				if err == io.EOF {
+					// end of the stream
+					return
 				}
 
-				continue
+				streamResultC <- clients.NewChatStreamResult(nil, err)
+
+				m.healthTracker.TrackErr(err)
+
+				return
 			}
 
-			m.healthTracker.TrackErr(chunkResult.Error())
+			streamResultC <- clients.NewChatStreamResult(chunk, nil)
 
-			streamResultC <- chunkResult
+			if chunkLatency > 1*time.Millisecond {
+				// All events are read in bigger chunks of bytes, so one chunk may contain more than one event.
+				// Each byte chunk is then parsed, so there is no easy way to precisely measure latency per chunk.
+				// We assume that if we spent more than 1ms waiting for a chunk, it's likely
+				// we were actually reading from the connection (otherwise, it would take nanoseconds).
+				m.chatStreamLatency.Add(float64(chunkLatency))
+			}
 		}
 	}()
 
-	return streamResultC
+	return streamResultC, nil
 }
 
 func (m *LanguageModel) Provider() string {
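`LanguageModel.ChatStream` now adapts the pull-based provider stream back into the channel-based `ChatStreamResult` API the router consumes, folding health tracking and latency sampling into the adapter goroutine. Note the asymmetry this creates: failures to create the stream are returned synchronously as an `error`, while mid-stream failures arrive on the channel. A consumer therefore checks both paths, roughly like the router does at the end of this diff (a sketch; `handleChunk` is an illustrative parameter, not part of the PR):

```go
// pipeStream shows the two failure paths of the wrapped stream API.
func pipeStream(ctx context.Context, model providers.LangModel, req *schemas.ChatRequest, handleChunk func(*schemas.ChatStreamChunk)) error {
	streamC, err := model.ChatStream(ctx, req)
	if err != nil {
		return err // the stream could not even be created
	}

	for result := range streamC {
		if result.Error() != nil {
			return result.Error() // the stream broke mid-flight
		}

		handleChunk(result.Chunk())
	}

	return nil // channel closed: the stream ended normally
}
```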
diff --git a/pkg/providers/octoml/chat_stream.go b/pkg/providers/octoml/chat_stream.go
index 3b6107c1..d2418e93 100644
--- a/pkg/providers/octoml/chat_stream.go
+++ b/pkg/providers/octoml/chat_stream.go
@@ -11,11 +11,6 @@ func (c *Client) SupportChatStream() bool {
 	return false
 }
 
-func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest) <-chan *clients.ChatStreamResult {
-	streamResultC := make(chan *clients.ChatStreamResult)
-
-	streamResultC <- clients.NewChatStreamResult(nil, clients.ErrChatStreamNotImplemented)
-	close(streamResultC)
-
-	return streamResultC
+func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest) (clients.ChatStream, error) {
+	return nil, clients.ErrChatStreamNotImplemented
 }
diff --git a/pkg/providers/ollama/chat_stream.go b/pkg/providers/ollama/chat_stream.go
index c6a444ea..bb15180a 100644
--- a/pkg/providers/ollama/chat_stream.go
+++ b/pkg/providers/ollama/chat_stream.go
@@ -11,11 +11,6 @@ func (c *Client) SupportChatStream() bool {
 	return false
 }
 
-func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest) <-chan *clients.ChatStreamResult {
-	streamResultC := make(chan *clients.ChatStreamResult)
-
-	streamResultC <- clients.NewChatStreamResult(nil, clients.ErrChatStreamNotImplemented)
-	close(streamResultC)
-
-	return streamResultC
+func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest) (clients.ChatStream, error) {
+	return nil, clients.ErrChatStreamNotImplemented
 }
diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go
index 3da0479a..31bd3a6e 100644
--- a/pkg/providers/openai/chat.go
+++ b/pkg/providers/openai/chat.go
@@ -7,9 +7,6 @@ import (
 	"fmt"
 	"io"
 	"net/http"
-	"time"
-
-	"glide/pkg/providers/clients"
 
 	"glide/pkg/api/schemas"
 	"go.uber.org/zap"
@@ -106,7 +103,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) {
 	defer resp.Body.Close()
 
 	if resp.StatusCode != http.StatusOK {
-		return nil, c.handleChatReqErrs(resp)
+		return nil, c.errMapper.Map(resp)
 	}
 
 	// Read the response body into a byte slice
@@ -161,39 +158,3 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) {
 
 	return &response, nil
 }
-
-func (c *Client) handleChatReqErrs(resp *http.Response) error {
-	bodyBytes, err := io.ReadAll(resp.Body)
-	if err != nil {
-		c.tel.Logger.Error(
-			"Failed to unmarshal chat response error",
-			zap.String("provider", c.Provider()),
-			zap.Error(err),
-			zap.ByteString("rawResponse", bodyBytes),
-		)
-	}
-
-	c.tel.Logger.Error(
-		"Chat request failed",
-		zap.String("provider", c.Provider()),
-		zap.Int("statusCode", resp.StatusCode),
-		zap.String("response", string(bodyBytes)),
-		zap.Any("headers", resp.Header),
-	)
-
-	if resp.StatusCode == http.StatusTooManyRequests {
-		// Read the value of the "Retry-After" header to get the cooldown delay
-		retryAfter := resp.Header.Get("Retry-After")
-
-		// Parse the value to get the duration
-		cooldownDelay, err := time.ParseDuration(retryAfter)
-		if err != nil {
-			return fmt.Errorf("failed to parse cooldown delay from headers: %w", err)
-		}
-
-		return clients.NewRateLimitError(&cooldownDelay)
-	}
-
-	// Server & client errors result in the same error to keep gateway resilient
-	return clients.ErrProviderUnavailable
-}
diff --git a/pkg/providers/openai/chat_stream.go b/pkg/providers/openai/chat_stream.go
index a985668e..d20879d1 100644
--- a/pkg/providers/openai/chat_stream.go
+++ b/pkg/providers/openai/chat_stream.go
@@ -7,10 +7,10 @@ import (
 	"fmt"
 	"io"
 	"net/http"
-	"time"
 
 	"github.com/r3labs/sse/v2"
 
 	"glide/pkg/providers/clients"
+	"glide/pkg/telemetry"
 
 	"go.uber.org/zap"
@@ -19,83 +19,84 @@ import (
 
 var streamDoneMarker = []byte("[DONE]")
 
-func (c *Client) SupportChatStream() bool {
-	return true
+// ChatStream represents an OpenAI chat stream for a specific request
+type ChatStream struct {
+	tel       *telemetry.Telemetry
+	client    *http.Client
+	req       *http.Request
+	resp      *http.Response
+	reader    *sse.EventStreamReader
+	errMapper *ErrorMapper
 }
 
-func (c *Client) ChatStream(ctx context.Context, req *schemas.ChatRequest) <-chan *clients.ChatStreamResult {
-	streamResultC := make(chan *clients.ChatStreamResult)
-
-	go c.streamChat(ctx, req, streamResultC)
-
-	return streamResultC
+func NewChatStream(tel *telemetry.Telemetry, client *http.Client, req *http.Request, errMapper *ErrorMapper) *ChatStream {
+	return &ChatStream{
+		tel:       tel,
+		client:    client,
+		req:       req,
+		errMapper: errMapper,
+	}
 }
 
-func (c *Client) streamChat(ctx context.Context, request *schemas.ChatRequest, resultC chan *clients.ChatStreamResult) {
-	// Create a new chat request
-	resp, err := c.initChatStream(ctx, request)
-
-	defer close(resultC)
-
+func (s *ChatStream) Open() error {
+	resp, err := s.client.Do(s.req) //nolint:bodyclose
 	if err != nil {
-		resultC <- clients.NewChatStreamResult(nil, err)
-
-		return
+		return err
 	}
 
-	defer resp.Body.Close()
-
 	if resp.StatusCode != http.StatusOK {
-		resultC <- clients.NewChatStreamResult(nil, c.handleChatReqErrs(resp))
+		return s.errMapper.Map(resp)
 	}
 
-	reader := sse.NewEventStreamReader(resp.Body, 4096) // TODO: should we expose maxBufferSize?
+	s.resp = resp
+	s.reader = sse.NewEventStreamReader(resp.Body, 4096) // TODO: should we expose maxBufferSize?
+
+	return nil
+}
 
+func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) {
 	var completionChunk ChatCompletionChunk
 
 	for {
-		started_at := time.Now()
-		rawEvent, err := reader.ReadEvent()
-		chunkLatency := time.Since(started_at)
-
+		rawEvent, err := s.reader.ReadEvent()
 		if err != nil {
 			if err == io.EOF {
-				c.tel.L().Debug("Chat stream is over", zap.String("provider", c.Provider()))
+				s.tel.L().Debug("Chat stream is over", zap.String("provider", providerName))
+
+				// TODO: This should probably be treated as an error (unexpected stream end)
 
-				return
+				return nil, io.EOF
 			}
 
-			c.tel.L().Warn(
+			s.tel.L().Warn(
 				"Chat stream is unexpectedly disconnected",
-				zap.String("provider", c.Provider()),
+				zap.String("provider", providerName),
+				zap.Error(err),
 			)
 
-			resultC <- clients.NewChatStreamResult(nil, clients.ErrProviderUnavailable)
-
-			return
+			return nil, clients.ErrProviderUnavailable
 		}
 
-		c.tel.L().Debug(
+		s.tel.L().Debug(
 			"Raw chat stream chunk",
-			zap.String("provider", c.Provider()),
+			zap.String("provider", providerName),
 			zap.ByteString("rawChunk", rawEvent),
 		)
 
 		event, err := clients.ParseSSEvent(rawEvent)
 
 		if bytes.Equal(event.Data, streamDoneMarker) {
-			return
+			return nil, io.EOF
 		}
 
 		if err != nil {
-			resultC <- clients.NewChatStreamResult(nil, fmt.Errorf("failed to parse chat stream message: %v", err))
-			return
+			return nil, fmt.Errorf("failed to parse chat stream message: %v", err)
 		}
 
 		if !event.HasContent() {
-			c.tel.L().Debug(
+			s.tel.L().Debug(
 				"Received an empty message in chat stream, skipping it",
-				zap.String("provider", c.Provider()),
+				zap.String("provider", providerName),
 				zap.Any("msg", event),
 			)
@@ -104,18 +105,11 @@ func (c *Client) streamChat(ctx context.Context, request *schemas.ChatRequest, resultC chan *clients.ChatStreamResult) {
 
 		err = json.Unmarshal(event.Data, &completionChunk)
 		if err != nil {
-			resultC <- clients.NewChatStreamResult(nil, fmt.Errorf("failed to unmarshal chat stream chunk: %v", err))
-			return
+			return nil, fmt.Errorf("failed to unmarshal chat stream chunk: %v", err)
 		}
 
-		c.tel.L().Debug(
-			"Chat response chunk",
-			zap.String("provider", c.Provider()),
-			zap.Any("chunk", completionChunk),
-		)
-
 		// TODO: use objectpool here
-		chatRespChunk := schemas.ChatStreamChunk{
+		return &schemas.ChatStreamChunk{
 			ID:       completionChunk.ID,
 			Created:  completionChunk.Created,
 			Provider: providerName,
@@ -130,19 +124,39 @@ func (c *Client) streamChat(ctx context.Context, request *schemas.ChatRequest, resultC chan *clients.ChatStreamResult) {
 					Content: completionChunk.Choices[0].Delta.Content,
 				},
 			},
-			Latency: &chunkLatency,
 			// TODO: Pass info if this is the final message
-		}
+		}, nil
+	}
+}
 
-		resultC <- clients.NewChatStreamResult(
-			&chatRespChunk,
-			nil,
-		)
+func (s *ChatStream) Close() error {
+	if s.resp != nil {
+		return s.resp.Body.Close()
 	}
+
+	return nil
 }
 
-// initChatStream establishes a new chat stream
-func (c *Client) initChatStream(ctx context.Context, request *schemas.ChatRequest) (*http.Response, error) {
+func (c *Client) SupportChatStream() bool {
+	return true
+}
+
+func (c *Client) ChatStream(ctx context.Context, req *schemas.ChatRequest) (clients.ChatStream, error) {
+	// Create a new chat request
+	request, err := c.makeStreamReq(ctx, req)
+	if err != nil {
+		return nil, err
+	}
+
+	return NewChatStream(
+		c.tel,
+		c.httpClient,
+		request,
+		c.errMapper,
+	), nil
+}
+
+func (c *Client) makeStreamReq(ctx context.Context, request *schemas.ChatRequest) (*http.Request, error) {
 	chatRequest := *c.createChatRequestSchema(request)
 
 	chatRequest.Stream = true
@@ -169,10 +183,5 @@ func (c *Client) initChatStream(ctx context.Context, request *schemas.ChatRequest) (*http.Response, error) {
 		zap.Any("payload", chatRequest),
 	)
 
-	resp, err := c.httpClient.Do(req)
-	if err != nil {
-		return nil, fmt.Errorf("failed to send OpenAI stream chat request: %w", err)
-	}
-
-	return resp, nil
+	return req, nil
 }
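Two design choices in the OpenAI implementation are worth calling out. First, `ChatStream` only builds the `*http.Request`; the network round trip is deferred to `Open`, which is what lets `LanguageModel` above time the connection phase separately from per-chunk reads. Second, `Open` deliberately leaves `resp.Body` open (hence the `//nolint:bodyclose`), since `Recv` keeps reading SSE events from it until `Close` releases it. In sketch form (a hypothetical helper for illustration, not part of the PR):

```go
// openTimed illustrates the split lifecycle: building the request is cheap,
// while Open performs the actual HTTP round trip and keeps the body open.
func openTimed(ctx context.Context, c *Client, req *schemas.ChatRequest) (clients.ChatStream, time.Duration, error) {
	stream, err := c.ChatStream(ctx, req) // builds the HTTP request only
	if err != nil {
		return nil, 0, err
	}

	startedAt := time.Now()

	if err := stream.Open(); err != nil { // sends the request, maps bad status codes
		return nil, 0, err
	}

	return stream, time.Since(startedAt), nil
}
```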
diff --git a/pkg/providers/openai/chat_stream_test.go b/pkg/providers/openai/chat_stream_test.go
index 60ea8384..31373579 100644
--- a/pkg/providers/openai/chat_stream_test.go
+++ b/pkg/providers/openai/chat_stream_test.go
@@ -76,11 +76,20 @@ func TestOpenAIClient_ChatStreamRequest(t *testing.T) {
 			Content: "What's the capital of the United Kingdom?",
 		}}
 
-		resultC := client.ChatStream(ctx, &req)
+		stream, err := client.ChatStream(ctx, &req)
+		require.NoError(t, err)
+
+		err = stream.Open()
+		require.NoError(t, err)
+
+		for {
+			chunk, err := stream.Recv()
+			if err == io.EOF {
+				break
+			}
 
-		for chunkResult := range resultC {
-			require.NoError(t, chunkResult.Error())
-			require.NotNil(t, chunkResult.Chunk().ModelResponse.Message.Content)
+			require.NoError(t, err)
+			require.NotNil(t, chunk)
 		}
 	})
 }
diff --git a/pkg/providers/openai/client.go b/pkg/providers/openai/client.go
index 9e0b6fb9..22d68d45 100644
--- a/pkg/providers/openai/client.go
+++ b/pkg/providers/openai/client.go
@@ -23,6 +23,7 @@ type Client struct {
 	baseURL             string
 	chatURL             string
 	chatRequestTemplate *ChatRequest
+	errMapper           *ErrorMapper
 	config              *Config
 	httpClient          *http.Client
 	tel                 *telemetry.Telemetry
@@ -40,6 +41,7 @@ func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *
 		chatURL:             chatURL,
 		config:              providerConfig,
 		chatRequestTemplate: NewChatRequestFromConfig(providerConfig),
+		errMapper:           NewErrorMapper(tel),
 		httpClient: &http.Client{
 			Timeout: *clientConfig.Timeout,
 			// TODO: use values from the config
diff --git a/pkg/providers/openai/errors.go b/pkg/providers/openai/errors.go
new file mode 100644
index 00000000..49fdc412
--- /dev/null
+++ b/pkg/providers/openai/errors.go
@@ -0,0 +1,60 @@
+package openai
+
+import (
+	"fmt"
+	"io"
+	"net/http"
+	"time"
+
+	"glide/pkg/providers/clients"
+	"glide/pkg/telemetry"
+	"go.uber.org/zap"
+)
+
+type ErrorMapper struct {
+	tel *telemetry.Telemetry
+}
+
+func NewErrorMapper(tel *telemetry.Telemetry) *ErrorMapper {
+	return &ErrorMapper{
+		tel: tel,
+	}
+}
+
+func (m *ErrorMapper) Map(resp *http.Response) error {
+	bodyBytes, err := io.ReadAll(resp.Body)
+	if err != nil {
+		m.tel.Logger.Error(
+			"Failed to read chat response error",
+			zap.String("provider", providerName),
+			zap.Error(err),
+			zap.ByteString("rawResponse", bodyBytes),
+		)
+
+		return clients.ErrProviderUnavailable
+	}
+
+	m.tel.Logger.Error(
+		"Chat request failed",
+		zap.String("provider", providerName),
+		zap.Int("statusCode", resp.StatusCode),
+		zap.String("response", string(bodyBytes)),
+		zap.Any("headers", resp.Header),
+	)
+
+	if resp.StatusCode == http.StatusTooManyRequests {
+		// Read the value of the "Retry-After" header to get the cooldown delay
+		retryAfter := resp.Header.Get("Retry-After")
+
+		// Parse the value to get the duration
+		cooldownDelay, err := time.ParseDuration(retryAfter)
+		if err != nil {
+			return fmt.Errorf("failed to parse cooldown delay from headers: %w", err)
+		}
+
+		return clients.NewRateLimitError(&cooldownDelay)
+	}
+
+	// Server & client errors result in the same error to keep gateway resilient
+	return clients.ErrProviderUnavailable
+}
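One caveat in the relocated error-mapping code: `time.ParseDuration` only accepts values with a unit suffix (e.g. `"10s"`), while `Retry-After` is defined by RFC 7231 as either delay-seconds (`"10"`) or an HTTP-date, so the parse will fail for the most common form of the header and the 429 will surface as a generic parse error instead of a rate-limit error. A more forgiving parser might look like this (a sketch of a possible follow-up, not part of this PR; assumes the `strconv` and `time` imports):

```go
// parseRetryAfter interprets Retry-After as plain delay-seconds first,
// falling back to Go duration syntax. Hypothetical helper.
func parseRetryAfter(header string) (time.Duration, error) {
	if secs, err := strconv.Atoi(header); err == nil && secs >= 0 {
		return time.Duration(secs) * time.Second, nil
	}

	return time.ParseDuration(header)
}
```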
diff --git a/pkg/providers/testing.go b/pkg/providers/testing.go
index 2893108e..282cf7e8 100644
--- a/pkg/providers/testing.go
+++ b/pkg/providers/testing.go
@@ -73,24 +73,9 @@ func (c *ProviderMock) SupportChatStream() bool {
 	return c.supportStreaming
 }
 
-func (c *ProviderMock) ChatStream(_ context.Context, _ *schemas.ChatRequest) <-chan *clients.ChatStreamResult {
-	streamResultC := make(chan *clients.ChatStreamResult)
-
-	response := c.responses[c.idx]
-	c.idx++
-
-	go func() {
-		defer close(streamResultC)
-
-		if response.Err != nil {
-			streamResultC <- clients.NewChatStreamResult(nil, *response.Err)
-			return
-		}
-
-		streamResultC <- clients.NewChatStreamResult(response.RespChunk(), nil)
-	}()
-
-	return streamResultC
+func (c *ProviderMock) ChatStream(_ context.Context, _ *schemas.ChatRequest) (clients.ChatStream, error) {
+	// TODO: implement
+	return nil, nil
 }
 
 func (c *ProviderMock) Provider() string {
diff --git a/pkg/routers/router.go b/pkg/routers/router.go
index 49306a99..e6cf9791 100644
--- a/pkg/routers/router.go
+++ b/pkg/routers/router.go
@@ -151,7 +151,18 @@ func (r *LangRouter) ChatStream(
 		}
 
 		langModel := model.(providers.LangModel)
-		modelRespC := langModel.ChatStream(ctx, req)
+		modelRespC, err := langModel.ChatStream(ctx, req)
+		if err != nil {
+			r.tel.L().Error(
+				"Lang model failed to create streaming chat request",
+				zap.String("routerID", r.ID()),
+				zap.String("modelID", langModel.ID()),
+				zap.String("provider", langModel.Provider()),
+				zap.Error(err),
+			)
+
+			continue
+		}
 
 		for chunkResult := range modelRespC {
 			err = chunkResult.Error()