Skip to content

Commit

Permalink
#163: Fixed issues found by linter
Browse files Browse the repository at this point in the history
  • Loading branch information
roma-glushko committed Mar 13, 2024
1 parent 3c11462 commit 8894ab0
Show file tree
Hide file tree
Showing 15 changed files with 226 additions and 183 deletions.
9 changes: 2 additions & 7 deletions pkg/providers/anthropic/chat_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,6 @@ func (c *Client) SupportChatStream() bool {
return false
}

func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest) <-chan *clients.ChatStreamResult {
streamResultC := make(chan *clients.ChatStreamResult)

streamResultC <- clients.NewChatStreamResult(nil, clients.ErrChatStreamNotImplemented)
close(streamResultC)

return streamResultC
// ChatStream is not implemented for the Anthropic client; it always returns
// clients.ErrChatStreamNotImplemented (SupportChatStream reports false above).
func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest) (clients.ChatStream, error) {
	return nil, clients.ErrChatStreamNotImplemented
}
9 changes: 2 additions & 7 deletions pkg/providers/azureopenai/chat_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,6 @@ func (c *Client) SupportChatStream() bool {
return false
}

func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest) <-chan *clients.ChatStreamResult {
streamResultC := make(chan *clients.ChatStreamResult)

streamResultC <- clients.NewChatStreamResult(nil, clients.ErrChatStreamNotImplemented)
close(streamResultC)

return streamResultC
// ChatStream is not implemented for the Azure OpenAI client; it always returns
// clients.ErrChatStreamNotImplemented (SupportChatStream reports false above).
func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest) (clients.ChatStream, error) {
	return nil, clients.ErrChatStreamNotImplemented
}
9 changes: 2 additions & 7 deletions pkg/providers/bedrock/chat_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,6 @@ func (c *Client) SupportChatStream() bool {
return false
}

func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest) <-chan *clients.ChatStreamResult {
streamResultC := make(chan *clients.ChatStreamResult)

streamResultC <- clients.NewChatStreamResult(nil, clients.ErrChatStreamNotImplemented)
close(streamResultC)

return streamResultC
// ChatStream is not implemented for the Bedrock client; it always returns
// clients.ErrChatStreamNotImplemented (SupportChatStream reports false above).
func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest) (clients.ChatStream, error) {
	return nil, clients.ErrChatStreamNotImplemented
}
6 changes: 6 additions & 0 deletions pkg/providers/clients/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@ import (
"glide/pkg/api/schemas"
)

// ChatStream is a stateful streaming chat session with a provider.
// Open establishes the underlying stream, Recv returns the next chunk
// (io.EOF signals the end of the stream), and Close releases the
// stream's resources. Callers are expected to call Open before Recv
// and Close when done.
type ChatStream interface {
	Open() error
	Recv() (*schemas.ChatStreamChunk, error)
	Close() error
}

type ChatStreamResult struct {
chunk *schemas.ChatStreamChunk
err error
Expand Down
9 changes: 2 additions & 7 deletions pkg/providers/cohere/chat_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,6 @@ func (c *Client) SupportChatStream() bool {
return false
}

func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest) <-chan *clients.ChatStreamResult {
streamResultC := make(chan *clients.ChatStreamResult)

streamResultC <- clients.NewChatStreamResult(nil, clients.ErrChatStreamNotImplemented)
close(streamResultC)

return streamResultC
// ChatStream is not implemented for the Cohere client; it always returns
// clients.ErrChatStreamNotImplemented (SupportChatStream reports false above).
func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest) (clients.ChatStream, error) {
	return nil, clients.ErrChatStreamNotImplemented
}
63 changes: 48 additions & 15 deletions pkg/providers/lang.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package providers

import (
"context"
"io"
"time"

"glide/pkg/routers/health"
Expand All @@ -18,12 +19,14 @@ type LangProvider interface {
SupportChatStream() bool

Chat(ctx context.Context, req *schemas.ChatRequest) (*schemas.ChatResponse, error)
ChatStream(ctx context.Context, req *schemas.ChatRequest) <-chan *clients.ChatStreamResult
ChatStream(ctx context.Context, req *schemas.ChatRequest) (clients.ChatStream, error)
}

type LangModel interface {
LangProvider
Model
Provider() string
Chat(ctx context.Context, req *schemas.ChatRequest) (*schemas.ChatResponse, error)
ChatStream(ctx context.Context, req *schemas.ChatRequest) (<-chan *clients.ChatStreamResult, error)
}

// LanguageModel wraps provider client and expend it with health & latency tracking
Expand Down Expand Up @@ -99,35 +102,65 @@ func (m *LanguageModel) Chat(ctx context.Context, request *schemas.ChatRequest)
return resp, err
}

func (m *LanguageModel) ChatStream(ctx context.Context, req *schemas.ChatRequest) <-chan *clients.ChatStreamResult {
func (m *LanguageModel) ChatStream(ctx context.Context, req *schemas.ChatRequest) (<-chan *clients.ChatStreamResult, error) {
stream, err := m.client.ChatStream(ctx, req)
if err != nil {
return nil, err
}

streamResultC := make(chan *clients.ChatStreamResult)
resultC := m.client.ChatStream(ctx, req)

go func() {
defer close(streamResultC)

var chunkLatency *time.Duration
startedAt := time.Now()
err = stream.Open()
chunkLatency := time.Since(startedAt)

for chunkResult := range resultC {
if chunkResult.Error() == nil {
streamResultC <- chunkResult
// the first chunk latency
m.chatStreamLatency.Add(float64(chunkLatency))

chunkLatency = chunkResult.Chunk().Latency
if err != nil {
streamResultC <- clients.NewChatStreamResult(nil, err)

m.healthTracker.TrackErr(err)

return
}

if chunkLatency != nil {
m.chatStreamLatency.Add(float64(*chunkLatency))
defer stream.Close()

for {
startedAt = time.Now()
chunk, err := stream.Recv()
chunkLatency = time.Since(startedAt)

if err != nil {
if err == io.EOF {
// end of the stream
return
}

continue
streamResultC <- clients.NewChatStreamResult(nil, err)

m.healthTracker.TrackErr(err)

return
}

m.healthTracker.TrackErr(chunkResult.Error())
streamResultC <- clients.NewChatStreamResult(chunk, nil)

streamResultC <- chunkResult
if chunkLatency > 1*time.Millisecond {
// All events are read in a bigger chunks of bytes, so one chunk may contain more than one event.
// Each byte chunk is then parsed, so there is no easy way to precisely guess latency per chunk,
// So we assume that if we spent more than 1ms waiting for a chunk it's likely
// we were trying to read from the connection (otherwise, it would take nanoseconds)
m.chatStreamLatency.Add(float64(chunkLatency))
}
}
}()

return streamResultC
return streamResultC, nil
}

func (m *LanguageModel) Provider() string {
Expand Down
9 changes: 2 additions & 7 deletions pkg/providers/octoml/chat_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,6 @@ func (c *Client) SupportChatStream() bool {
return false
}

func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest) <-chan *clients.ChatStreamResult {
streamResultC := make(chan *clients.ChatStreamResult)

streamResultC <- clients.NewChatStreamResult(nil, clients.ErrChatStreamNotImplemented)
close(streamResultC)

return streamResultC
// ChatStream is not implemented for the OctoML client; it always returns
// clients.ErrChatStreamNotImplemented (SupportChatStream reports false above).
func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest) (clients.ChatStream, error) {
	return nil, clients.ErrChatStreamNotImplemented
}
9 changes: 2 additions & 7 deletions pkg/providers/ollama/chat_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,6 @@ func (c *Client) SupportChatStream() bool {
return false
}

func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest) <-chan *clients.ChatStreamResult {
streamResultC := make(chan *clients.ChatStreamResult)

streamResultC <- clients.NewChatStreamResult(nil, clients.ErrChatStreamNotImplemented)
close(streamResultC)

return streamResultC
// ChatStream is not implemented for the Ollama client; it always returns
// clients.ErrChatStreamNotImplemented (SupportChatStream reports false above).
func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest) (clients.ChatStream, error) {
	return nil, clients.ErrChatStreamNotImplemented
}
41 changes: 1 addition & 40 deletions pkg/providers/openai/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@ import (
"fmt"
"io"
"net/http"
"time"

"glide/pkg/providers/clients"

"glide/pkg/api/schemas"
"go.uber.org/zap"
Expand Down Expand Up @@ -106,7 +103,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return nil, c.handleChatReqErrs(resp)
return nil, c.errMapper.Map(resp)
}

// Read the response body into a byte slice
Expand Down Expand Up @@ -161,39 +158,3 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche

return &response, nil
}

// handleChatReqErrs maps a non-200 chat response into a client error.
// It logs the raw response body and headers for diagnostics, converts
// HTTP 429 into a rate-limit error carrying the Retry-After cooldown,
// and collapses every other status into ErrProviderUnavailable.
// (Removed by this commit in favor of c.errMapper.Map.)
func (c *Client) handleChatReqErrs(resp *http.Response) error {
	bodyBytes, err := io.ReadAll(resp.Body)
	if err != nil {
		// Best-effort read: log the failure but keep going with whatever bytes we got.
		c.tel.Logger.Error(
			"Failed to unmarshal chat response error",
			zap.String("provider", c.Provider()),
			zap.Error(err),
			zap.ByteString("rawResponse", bodyBytes),
		)
	}

	c.tel.Logger.Error(
		"Chat request failed",
		zap.String("provider", c.Provider()),
		zap.Int("statusCode", resp.StatusCode),
		zap.String("response", string(bodyBytes)),
		zap.Any("headers", resp.Header),
	)

	if resp.StatusCode == http.StatusTooManyRequests {
		// Read the value of the "Retry-After" header to get the cooldown delay
		retryAfter := resp.Header.Get("Retry-After")

		// Parse the value to get the duration
		// NOTE(review): time.ParseDuration expects Go duration syntax ("2s", "500ms"),
		// but Retry-After is typically plain delta-seconds ("120") or an HTTP-date —
		// confirm the provider's actual header format; this likely fails on standard values.
		cooldownDelay, err := time.ParseDuration(retryAfter)
		if err != nil {
			return fmt.Errorf("failed to parse cooldown delay from headers: %w", err)
		}

		return clients.NewRateLimitError(&cooldownDelay)
	}

	// Server & client errors result in the same error to keep gateway resilient
	return clients.ErrProviderUnavailable
}
Loading

0 comments on commit 8894ab0

Please sign in to comment.