Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🐛 #200: Implemented a custom json per line stream reader to read Cohere chat streams correctly #201

Merged
merged 4 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 31 additions & 31 deletions pkg/providers/cohere/chat_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"io"
"net/http"

"github.com/r3labs/sse/v2"
"glide/pkg/providers/clients"
"glide/pkg/telemetry"

Expand All @@ -17,25 +16,30 @@ import (
"glide/pkg/api/schemas"
)

// SupportedEventType Cohere has other types too:
// Ref: https://docs.cohere.com/reference/chat (see Chat -> Responses -> StreamedChatResponse)
type SupportedEventType = string

var (
TextGenEvent SupportedEventType = "text-generation"
StreamEndEvent SupportedEventType = "stream-end"
StreamStartEvent SupportedEventType = "stream-start"
TextGenEvent SupportedEventType = "text-generation"
StreamEndEvent SupportedEventType = "stream-end"
)

// ChatStream represents cohere chat stream for a specific request
type ChatStream struct {
tel *telemetry.Telemetry
client *http.Client
req *http.Request
reqID string
modelName string
reqMetadata *schemas.Metadata
resp *http.Response
reader *sse.EventStreamReader
generationID string
streamFinished bool
reader *StreamReader
errMapper *ErrorMapper
finishReasonMapper *FinishReasonMapper
tel *telemetry.Telemetry
}

func NewChatStream(
Expand All @@ -56,6 +60,7 @@ func NewChatStream(
modelName: modelName,
reqMetadata: reqMetadata,
errMapper: errMapper,
streamFinished: false,
finishReasonMapper: finishReasonMapper,
}
}
Expand All @@ -70,73 +75,65 @@ func (s *ChatStream) Open() error {
return s.errMapper.Map(resp)
}

s.tel.L().Debug("Resp Headers", zap.Any("headers", resp.Header))

s.resp = resp
s.reader = sse.NewEventStreamReader(resp.Body, 8192) // TODO: should we expose maxBufferSize?
s.reader = NewStreamReader(resp.Body, 8192) // TODO: should we expose maxBufferSize?

return nil
}

func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) {
if s.streamFinished {
return nil, io.EOF
}

var responseChunk ChatCompletionChunk

for {
rawEvent, err := s.reader.ReadEvent()
rawChunk, err := s.reader.ReadEvent()
if err != nil {
s.tel.L().Warn(
"Chat stream is unexpectedly disconnected",
zap.String("provider", providerName),
zap.Error(err),
)

if err == io.EOF {
return nil, io.EOF
}

// if err is io.EOF, this still means that the stream is interrupted unexpectedly
// because the normal stream termination is done via finding out streamDoneMarker
// if io.EOF occurred in the middle of the stream, then the stream was interrupted

return nil, clients.ErrProviderUnavailable
}

s.tel.L().Debug(
"Raw chat stream chunk",
zap.String("provider", providerName),
zap.ByteString("rawChunk", rawEvent),
zap.ByteString("rawChunk", rawChunk),
)

event, err := clients.ParseSSEvent(rawEvent)
err = json.Unmarshal(rawChunk, &responseChunk)
if err != nil {
return nil, fmt.Errorf("failed to parse chat stream message: %v", err)
return nil, fmt.Errorf("failed to unmarshal chat stream chunk: %v", err)
}

if !event.HasContent() {
s.tel.L().Debug(
"Received an empty message in chat stream, skipping it",
zap.String("provider", providerName),
zap.Any("msg", event),
)
if responseChunk.EventType == StreamStartEvent {
s.generationID = *responseChunk.GenerationID

continue
}

rawChunk := event.Data

err = json.Unmarshal(rawChunk, &responseChunk)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal chat stream chunk: %v", err)
}

if responseChunk.EventType != TextGenEvent && responseChunk.EventType != StreamEndEvent {
s.tel.L().Debug(
"Unsupported stream chunk type, skipping it",
zap.String("provider", providerName),
zap.Any("chunk", string(rawChunk)),
zap.ByteString("chunk", rawChunk),
)

continue
}

if responseChunk.IsFinished {
s.streamFinished = true

// TODO: use objectpool here
return &schemas.ChatStreamChunk{
ID: s.reqID,
Expand All @@ -146,7 +143,7 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) {
Metadata: s.reqMetadata,
ModelResponse: schemas.ModelChunkResponse{
Metadata: &schemas.Metadata{
"generationId": responseChunk.Response.GenerationID,
"generationId": s.generationID,
"responseId": responseChunk.Response.ResponseID,
},
Message: schemas.ChatMessage{
Expand All @@ -166,6 +163,9 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) {
ModelName: s.modelName,
Metadata: s.reqMetadata,
ModelResponse: schemas.ModelChunkResponse{
Metadata: &schemas.Metadata{
"generationId": s.generationID,
},
Message: schemas.ChatMessage{
Role: "model",
Content: responseChunk.Text,
Expand Down
15 changes: 8 additions & 7 deletions pkg/providers/cohere/chat_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func TestCohere_ChatStreamRequest(t *testing.T) {
t.Errorf("error reading cohere chat mock response: %v", err)
}

w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Content-Type", "application/stream+json")

_, err = w.Write(chatResponse)
if err != nil {
Expand Down Expand Up @@ -94,7 +94,7 @@ func TestCohere_ChatStreamRequest(t *testing.T) {

func TestCohere_ChatStreamRequestInterrupted(t *testing.T) {
tests := map[string]string{
"success stream, but with empty event": "./testdata/chat_stream.empty.txt",
"interrupted stream": "./testdata/chat_stream.interrupted.txt",
}

for name, streamFile := range tests {
Expand Down Expand Up @@ -141,16 +141,17 @@ func TestCohere_ChatStreamRequestInterrupted(t *testing.T) {
err = stream.Open()
require.NoError(t, err)

for {
for range 5 {
chunk, err := stream.Recv()
if err != nil {
require.ErrorIs(t, err, io.EOF)
return
}

require.NoError(t, err)
require.NotNil(t, chunk)
}

chunk, err := stream.Recv()

require.Error(t, err)
require.Nil(t, chunk)
})
}
}
1 change: 1 addition & 0 deletions pkg/providers/cohere/schemas.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ type ConnectorsResponse struct {
type ChatCompletionChunk struct {
IsFinished bool `json:"is_finished"`
EventType string `json:"event_type"`
GenerationID *string `json:"generation_id"`
Text string `json:"text"`
Response *FinalResponse `json:"response,omitempty"`
FinishReason *string `json:"finish_reason,omitempty"`
Expand Down
73 changes: 73 additions & 0 deletions pkg/providers/cohere/stream_reader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package cohere

import (
	"bufio"
	"bytes"
	"context"
	"errors"
	"io"
)

// StreamReader reads Cohere streaming chat chunks that are formatted
// as one serialized JSON chunk per line (a.k.a. application/stream+json).
type StreamReader struct {
// scanner tokenizes the underlying stream; NewStreamReader installs a
// split function on it that cuts the input at each newline.
scanner *bufio.Scanner
}

// containNewline reports where the first line feed occurs in data.
//
// The first result is the byte offset of that line feed, or -1 when data
// holds none; the second result is the length of the delimiter, which is
// always 1.
func containNewline(data []byte) (int, int) {
	const delimiterLen = 1 // a line feed is a single byte

	return bytes.IndexByte(data, '\n'), delimiterLen
}

// NewStreamReader creates a StreamReader that splits stream into
// newline-delimited chunks.
//
// maxBufferSize caps the size of a single chunk; a longer line makes the
// underlying scanner fail with bufio.ErrTooLong. A non-positive value falls
// back to bufio.MaxScanTokenSize instead of panicking (negative) or producing
// a reader that can never buffer any data (zero).
func NewStreamReader(stream io.Reader, maxBufferSize int) *StreamReader {
	if maxBufferSize <= 0 {
		maxBufferSize = bufio.MaxScanTokenSize
	}

	scanner := bufio.NewScanner(stream)

	// Start with a modest buffer and let the scanner grow it up to maxBufferSize.
	scanner.Buffer(make([]byte, min(4096, maxBufferSize)), maxBufferSize)

	scanner.Split(func(data []byte, atEOF bool) (int, []byte, error) {
		// Nothing left to emit.
		if atEOF && len(data) == 0 {
			return 0, nil, nil
		}

		// A complete line is buffered: emit it without the trailing newline.
		if i, nlen := containNewline(data); i >= 0 {
			return i + nlen, data[0:i], nil
		}

		// The stream ended without a final newline: emit the remainder as is.
		if atEOF {
			return len(data), data, nil
		}

		// Incomplete line: ask the scanner for more data.
		return 0, nil, nil
	})

	return &StreamReader{
		scanner: scanner,
	}
}

// ReadEvent returns the next newline-delimited chunk of the stream.
//
// The returned slice is backed by the scanner's internal buffer and may be
// overwritten by the next call, so callers that keep it must copy it first.
//
// It returns io.EOF once the stream is exhausted, and also when reading
// failed because the context was canceled (including when that error is
// wrapped). Any other scanner error is returned unchanged.
func (r *StreamReader) ReadEvent() ([]byte, error) {
	if r.scanner.Scan() {
		return r.scanner.Bytes(), nil
	}

	err := r.scanner.Err()
	if err == nil || errors.Is(err, context.Canceled) {
		// A canceled request is reported as a regular end of stream.
		return nil, io.EOF
	}

	return nil, err
}
Loading
Loading