
Commit

#163: Changed the way we interact with clients so that clients manage responseChunk channels. The channel can contain both success and error messages/chunks. Separated stream and sync chat schemas
roma-glushko committed Mar 11, 2024
1 parent 9657416 commit 548ea18
Showing 14 changed files with 243 additions and 67 deletions.
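
At a glance, the new contract inverts channel ownership: instead of the caller passing a response channel into ChatStream, each provider returns a receive-only channel of results, and every result carries either a chunk or an error. The sketch below shows the consumer side of that contract; it relies only on the types added in this commit (pkg/providers/clients and pkg/api/schemas), and the drain helper itself is illustrative.

package example

import (
	"fmt"
	"log"

	"glide/pkg/providers/clients"
)

// drain is an illustrative consumer, not part of the commit: chunks and errors
// arrive on the same channel, so it checks Error() before touching Chunk().
func drain(resultC <-chan *clients.ChatStreamResult) {
	for result := range resultC {
		if err := result.Error(); err != nil {
			log.Printf("chat stream error: %v", err)
			continue // keep draining; the provider closes the channel when it is done
		}

		chunk := result.Chunk() // *schemas.ChatStreamChunk
		fmt.Printf("chunk %s from model %s\n", chunk.ID, chunk.ModelName)
	}
}
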
22 changes: 15 additions & 7 deletions pkg/api/http/handlers.go
@@ -123,7 +123,8 @@ func LangStreamChatHandler(tel *telemetry.Telemetry, routerManager *routers.Rout
wg sync.WaitGroup
)

chatResponseC := make(chan schemas.ChatResponse)
chunkResultC := make(chan *schemas.ChatStreamResult)

router, _ := routerManager.GetLangRouter(routerID)

defer c.Conn.Close()
@@ -133,8 +134,16 @@ func LangStreamChatHandler(tel *telemetry.Telemetry, routerManager *routers.Rout
go func() {
defer wg.Done()

for response := range chatResponseC {
if err = c.WriteJSON(response); err != nil {
for chunkResult := range chunkResultC {
if chunkResult.Error() != nil {
if err = c.WriteJSON(chunkResult.Error()); err != nil {
break
}

continue
}

if err = c.WriteJSON(chunkResult.Chunk()); err != nil {
break
}
}
Expand All @@ -157,13 +166,12 @@ func LangStreamChatHandler(tel *telemetry.Telemetry, routerManager *routers.Rout
go func(chatRequest schemas.ChatRequest) {
defer wg.Done()

if err = router.ChatStream(context.Background(), &chatRequest, chatResponseC); err != nil {
tel.L().Error("Failed to process streaming chat request", zap.Error(err), zap.String("routerID", routerID))
}
router.ChatStream(context.Background(), &chatRequest, chunkResultC)
}(chatRequest)
}

close(chatResponseC)
close(chunkResultC)

wg.Wait()
})
}
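
Because chunks and errors now travel over the same WebSocket connection, a client has to tell the two frame shapes apart. Below is a minimal sketch of one way to do that, assuming only the schemas added later in this commit; how the raw frame was read is left out, and the probing strategy is illustrative.

package example

import (
	"encoding/json"

	"glide/pkg/api/schemas"
)

// decodeFrame (illustrative) separates an error frame, {"reason": ..., "message": ...},
// from a chunk frame by probing for a non-empty reason field first.
func decodeFrame(raw []byte) (any, error) {
	var streamErr schemas.ChatStreamError
	if err := json.Unmarshal(raw, &streamErr); err == nil && streamErr.Reason != "" {
		return streamErr, nil
	}

	var chunk schemas.ChatStreamChunk
	if err := json.Unmarshal(raw, &chunk); err != nil {
		return nil, err
	}

	return chunk, nil
}
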
File renamed without changes.
53 changes: 53 additions & 0 deletions pkg/api/schemas/chat_stream.go
@@ -0,0 +1,53 @@
package schemas

// ChatStreamRequest defines a message that requests a new streaming chat
type ChatStreamRequest struct {
// TODO: implement
}

// ChatStreamChunk defines a message for a chunk of streaming chat response
type ChatStreamChunk struct {
// TODO: modify according to the streaming chat needs
ID string `json:"id,omitempty"`
Created int `json:"created,omitempty"`
Provider string `json:"provider,omitempty"`
RouterID string `json:"router,omitempty"`
ModelID string `json:"model_id,omitempty"`
ModelName string `json:"model,omitempty"`
Cached bool `json:"cached,omitempty"`
ModelResponse ModelResponse `json:"modelResponse,omitempty"`
// TODO: add chat request-specific context
}

type ChatStreamError struct {
// TODO: add chat request-specific context
Reason string `json:"reason"`
Message string `json:"message"`
}

type ChatStreamResult struct {
chunk *ChatStreamChunk
err *ChatStreamError
}

func (r *ChatStreamResult) Chunk() *ChatStreamChunk {
return r.chunk
}

func (r *ChatStreamResult) Error() *ChatStreamError {
return r.err
}

func NewChatStreamResult(chunk *ChatStreamChunk) *ChatStreamResult {
return &ChatStreamResult{
chunk: chunk,
err: nil,
}
}

func NewChatStreamErrorResult(err *ChatStreamError) *ChatStreamResult {
return &ChatStreamResult{
chunk: nil,
err: err,
}
}
9 changes: 7 additions & 2 deletions pkg/providers/anthropic/chat_stream.go
@@ -11,6 +11,11 @@ func (c *Client) SupportChatStream() bool {
return false
}

func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest, _ chan<- schemas.ChatResponse) error {
return clients.ErrChatStreamNotImplemented
func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest) <-chan *clients.ChatStreamResult {
streamResultC := make(chan *clients.ChatStreamResult)

streamResultC <- clients.NewChatStreamResult(nil, clients.ErrChatStreamNotImplemented)
close(streamResultC)

return streamResultC
}
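
The stub above, repeated for the other non-streaming providers below, sends ErrChatStreamNotImplemented on an unbuffered channel before the channel is returned, so the send can only complete once something is already receiving. A variant that never blocks the caller, sketched against the clients package added in this commit, buffers the single result:

package example

import (
	"context"

	"glide/pkg/api/schemas"
	"glide/pkg/providers/clients"
)

// chatStreamNotSupported is an illustrative alternative, not part of the commit:
// the one-slot buffer lets the send complete before the channel is handed back.
func chatStreamNotSupported(_ context.Context, _ *schemas.ChatRequest) <-chan *clients.ChatStreamResult {
	streamResultC := make(chan *clients.ChatStreamResult, 1)

	streamResultC <- clients.NewChatStreamResult(nil, clients.ErrChatStreamNotImplemented)
	close(streamResultC)

	return streamResultC
}
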
9 changes: 7 additions & 2 deletions pkg/providers/azureopenai/chat_stream.go
@@ -11,6 +11,11 @@ func (c *Client) SupportChatStream() bool {
return false
}

func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest, _ chan<- schemas.ChatResponse) error {
return clients.ErrChatStreamNotImplemented
func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest) <-chan *clients.ChatStreamResult {
streamResultC := make(chan *clients.ChatStreamResult)

streamResultC <- clients.NewChatStreamResult(nil, clients.ErrChatStreamNotImplemented)
close(streamResultC)

return streamResultC
}
9 changes: 7 additions & 2 deletions pkg/providers/bedrock/chat_stream.go
@@ -11,6 +11,11 @@ func (c *Client) SupportChatStream() bool {
return false
}

func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest, _ chan<- schemas.ChatResponse) error {
return clients.ErrChatStreamNotImplemented
func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest) <-chan *clients.ChatStreamResult {
streamResultC := make(chan *clients.ChatStreamResult)

streamResultC <- clients.NewChatStreamResult(nil, clients.ErrChatStreamNotImplemented)
close(streamResultC)

return streamResultC
}
23 changes: 23 additions & 0 deletions pkg/providers/clients/stream.go
@@ -0,0 +1,23 @@
package clients

import "glide/pkg/api/schemas"

type ChatStreamResult struct {
chunk *schemas.ChatStreamChunk
err error
}

func (r *ChatStreamResult) Chunk() *schemas.ChatStreamChunk {
return r.chunk
}

func (r *ChatStreamResult) Error() error {
return r.err
}

func NewChatStreamResult(chunk *schemas.ChatStreamChunk, err error) *ChatStreamResult {
return &ChatStreamResult{
chunk: chunk,
err: err,
}
}
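
There are now two result types with the same shape: the provider-level clients.ChatStreamResult above, whose error is a plain error, and the API-level schemas.ChatStreamResult, whose error is a structured ChatStreamError. The routing layer sits between the two, and its changes are not among the files shown here; the sketch below only illustrates the kind of translation it needs to perform, with an assumed reason code.

package example

import (
	"glide/pkg/api/schemas"
	"glide/pkg/providers/clients"
)

// toSchemaResult is an illustrative mapping from a provider-level result to the
// API-level result the HTTP handler writes out; "provider_error" is assumed.
func toSchemaResult(r *clients.ChatStreamResult) *schemas.ChatStreamResult {
	if err := r.Error(); err != nil {
		return schemas.NewChatStreamErrorResult(&schemas.ChatStreamError{
			Reason:  "provider_error",
			Message: err.Error(),
		})
	}

	return schemas.NewChatStreamResult(r.Chunk())
}
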
9 changes: 7 additions & 2 deletions pkg/providers/cohere/chat_stream.go
@@ -11,6 +11,11 @@ func (c *Client) SupportChatStream() bool {
return false
}

func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest, _ chan<- schemas.ChatResponse) error {
return clients.ErrChatStreamNotImplemented
func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest) <-chan *clients.ChatStreamResult {
streamResultC := make(chan *clients.ChatStreamResult)

streamResultC <- clients.NewChatStreamResult(nil, clients.ErrChatStreamNotImplemented)
close(streamResultC)

return streamResultC
}
40 changes: 26 additions & 14 deletions pkg/providers/lang.go
@@ -17,8 +17,8 @@ type LangProvider interface {

SupportChatStream() bool

Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error)
ChatStream(ctx context.Context, request *schemas.ChatRequest, responseC chan<- schemas.ChatResponse) error
Chat(ctx context.Context, req *schemas.ChatRequest) (*schemas.ChatResponse, error)
ChatStream(ctx context.Context, req *schemas.ChatRequest) <-chan *clients.ChatStreamResult
}

type LangModel interface {
@@ -105,24 +105,36 @@ func (m *LanguageModel) Chat(ctx context.Context, request *schemas.ChatRequest)
return resp, err
}

func (m *LanguageModel) ChatStream(ctx context.Context, request *schemas.ChatRequest, responseC chan<- schemas.ChatResponse) error {
err := m.client.ChatStream(ctx, request, responseC)
func (m *LanguageModel) ChatStream(ctx context.Context, req *schemas.ChatRequest) <-chan *clients.ChatStreamResult {
streamResultC := make(chan *clients.ChatStreamResult)
resultC := m.client.ChatStream(ctx, req)

if err == nil {
return err
}
go func() {
defer close(streamResultC)

var rateLimitErr *clients.RateLimitError
for chunkResult := range resultC {
if chunkResult.Error() == nil {
streamResultC <- chunkResult
// TODO: calculate latency
continue
}

if errors.As(err, &rateLimitErr) {
m.rateLimit.SetLimited(rateLimitErr.UntilReset())
var rateLimitErr *clients.RateLimitError

return err
}
if errors.As(chunkResult.Error(), &rateLimitErr) {
m.rateLimit.SetLimited(rateLimitErr.UntilReset())

_ = m.errBudget.Take(1)
streamResultC <- chunkResult

continue
}

_ = m.errBudget.Take(1)
streamResultC <- chunkResult
}
}()

return err
return streamResultC
}

func (m *LanguageModel) SupportChatStream() bool {
9 changes: 7 additions & 2 deletions pkg/providers/octoml/chat_stream.go
@@ -11,6 +11,11 @@ func (c *Client) SupportChatStream() bool {
return false
}

func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest, _ chan<- schemas.ChatResponse) error {
return clients.ErrChatStreamNotImplemented
func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest) <-chan *clients.ChatStreamResult {
streamResultC := make(chan *clients.ChatStreamResult)

streamResultC <- clients.NewChatStreamResult(nil, clients.ErrChatStreamNotImplemented)
close(streamResultC)

return streamResultC
}
9 changes: 7 additions & 2 deletions pkg/providers/ollama/chat_stream.go
@@ -11,6 +11,11 @@ func (c *Client) SupportChatStream() bool {
return false
}

func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest, _ chan<- schemas.ChatResponse) error {
return clients.ErrChatStreamNotImplemented
func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatRequest) <-chan *clients.ChatStreamResult {
streamResultC := make(chan *clients.ChatStreamResult)

streamResultC <- clients.NewChatStreamResult(nil, clients.ErrChatStreamNotImplemented)
close(streamResultC)

return streamResultC
}
42 changes: 30 additions & 12 deletions pkg/providers/openai/chat_stream.go
@@ -10,6 +10,7 @@ import (

"github.com/r3labs/sse/v2"
"glide/pkg/providers/clients"

"go.uber.org/zap"

"glide/pkg/api/schemas"
@@ -21,17 +22,30 @@ func (c *Client) SupportChatStream() bool {
return true
}

func (c *Client) ChatStream(ctx context.Context, request *schemas.ChatRequest, responseC chan<- schemas.ChatResponse) error {
func (c *Client) ChatStream(ctx context.Context, req *schemas.ChatRequest) <-chan *clients.ChatStreamResult {
streamResultC := make(chan *clients.ChatStreamResult)

go c.streamChat(ctx, req, streamResultC)

return streamResultC
}

func (c *Client) streamChat(ctx context.Context, request *schemas.ChatRequest, resultC chan *clients.ChatStreamResult) {
// Create a new chat request
resp, err := c.initChatStream(ctx, request)

defer close(resultC)

if err != nil {
return err
resultC <- clients.NewChatStreamResult(nil, err)

return
}

defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return c.handleChatReqErrs(resp)
resultC <- clients.NewChatStreamResult(nil, c.handleChatReqErrs(resp))
}

reader := sse.NewEventStreamReader(resp.Body, 4096) // TODO: should we expose maxBufferSize?
@@ -44,15 +58,17 @@ func (c *Client) ChatStream(ctx context.Context, request *schemas.ChatRequest, r
if err == io.EOF {
c.tel.L().Debug("Chat stream is over", zap.String("provider", c.Provider()))

return nil
return
}

c.tel.L().Warn(
"Chat stream is unexpectedly disconnected",
zap.String("provider", c.Provider()),
)

return clients.ErrProviderUnavailable
resultC <- clients.NewChatStreamResult(nil, clients.ErrProviderUnavailable)

return
}

c.tel.L().Debug(
Expand All @@ -64,15 +80,16 @@ func (c *Client) ChatStream(ctx context.Context, request *schemas.ChatRequest, r
event, err := clients.ParseSSEvent(rawEvent)

if bytes.Equal(event.Data, streamDoneMarker) {
return nil
return
}

if err != nil {
return fmt.Errorf("failed to parse chat stream message: %v", err)
resultC <- clients.NewChatStreamResult(nil, fmt.Errorf("failed to parse chat stream message: %v", err))
return
}

if !event.HasContent() {
c.tel.Logger.Debug(
c.tel.L().Debug(
"Received an empty message in chat stream, skipping it",
zap.String("provider", c.Provider()),
zap.Any("msg", event),
@@ -83,7 +100,8 @@ func (c *Client) ChatStream(ctx context.Context, request *schemas.ChatRequest, r

err = json.Unmarshal(event.Data, &completionChunk)
if err != nil {
return fmt.Errorf("failed to unmarshal chat stream message: %v", err)
resultC <- clients.NewChatStreamResult(nil, fmt.Errorf("failed to unmarshal chat stream chunk: %v", err))
return
}

c.tel.L().Debug(
Expand All @@ -93,7 +111,7 @@ func (c *Client) ChatStream(ctx context.Context, request *schemas.ChatRequest, r
)

// TODO: use objectpool here
chatResponse := schemas.ChatResponse{
chatRespChunk := schemas.ChatStreamChunk{
ID: completionChunk.ID,
Created: completionChunk.Created,
Provider: providerName,
Expand All @@ -111,7 +129,7 @@ func (c *Client) ChatStream(ctx context.Context, request *schemas.ChatRequest, r
// TODO: Pass info if this is the final message
}

responseC <- chatResponse
resultC <- clients.NewChatStreamResult(&chatRespChunk, nil)
}
}
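
One property of this design worth keeping in mind: streamChat sends on an unbuffered resultC, so each send blocks until the consumer reads it, and a consumer that stops reading early would park the goroutine. A context-aware send helper is a common way to guard against that; the sketch below is a general Go pattern, not something this commit adds.

package example

import (
	"context"

	"glide/pkg/providers/clients"
)

// trySend (illustrative) delivers a result unless the caller's context is
// cancelled first, so an abandoned stream cannot leak the producing goroutine.
func trySend(ctx context.Context, resultC chan<- *clients.ChatStreamResult, r *clients.ChatStreamResult) bool {
	select {
	case resultC <- r:
		return true
	case <-ctx.Done():
		return false
	}
}
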

@@ -137,7 +155,7 @@ func (c *Client) initChatStream(ctx context.Context, request *schemas.ChatReques
req.Header.Set("Connection", "keep-alive")

// TODO: this could leak information from messages which may not be a desired thing to have
c.tel.Logger.Debug(
c.tel.L().Debug(
"Stream chat request",
zap.String("chatURL", c.chatURL),
zap.Any("payload", chatRequest),