diff --git a/chat_stream.go b/chat_stream.go
index 26e964c38..26c4bfc15 100644
--- a/chat_stream.go
+++ b/chat_stream.go
@@ -2,11 +2,7 @@ package openai
 
 import (
 	"bufio"
-	"bytes"
 	"context"
-	"encoding/json"
-	"io"
-	"net/http"
 )
 
 type ChatCompletionStreamChoiceDelta struct {
@@ -30,52 +26,7 @@ type ChatCompletionStreamResponse struct {
 // ChatCompletionStream
 // Note: Perhaps it is more elegant to abstract Stream using generics.
 type ChatCompletionStream struct {
-	emptyMessagesLimit uint
-	isFinished         bool
-
-	reader   *bufio.Reader
-	response *http.Response
-}
-
-func (stream *ChatCompletionStream) Recv() (response ChatCompletionStreamResponse, err error) {
-	if stream.isFinished {
-		err = io.EOF
-		return
-	}
-
-	var emptyMessagesCount uint
-
-waitForData:
-	line, err := stream.reader.ReadBytes('\n')
-	if err != nil {
-		return
-	}
-
-	var headerData = []byte("data: ")
-	line = bytes.TrimSpace(line)
-	if !bytes.HasPrefix(line, headerData) {
-		emptyMessagesCount++
-		if emptyMessagesCount > stream.emptyMessagesLimit {
-			err = ErrTooManyEmptyStreamMessages
-			return
-		}
-
-		goto waitForData
-	}
-
-	line = bytes.TrimPrefix(line, headerData)
-	if string(line) == "[DONE]" {
-		stream.isFinished = true
-		err = io.EOF
-		return
-	}
-
-	err = json.Unmarshal(line, &response)
-	return
-}
-
-func (stream *ChatCompletionStream) Close() {
-	stream.response.Body.Close()
+	*streamReader[ChatCompletionStreamResponse]
 }
 
 // CreateChatCompletionStream — API call to create a chat completion w/ streaming
@@ -98,9 +49,13 @@ func (c *Client) CreateChatCompletionStream(
 	}
 
 	stream = &ChatCompletionStream{
-		emptyMessagesLimit: c.config.EmptyMessagesLimit,
-		reader:             bufio.NewReader(resp.Body),
-		response:           resp,
+		streamReader: &streamReader[ChatCompletionStreamResponse]{
+			emptyMessagesLimit: c.config.EmptyMessagesLimit,
+			reader:             bufio.NewReader(resp.Body),
+			response:           resp,
+			errAccumulator:     newErrorAccumulator(),
+			unmarshaler:        &jsonUnmarshaler{},
+		},
 	}
 	return
 }
diff --git a/chat_stream_test.go b/chat_stream_test.go
index e3da2daf7..de604fa8b 100644
--- a/chat_stream_test.go
+++ b/chat_stream_test.go
@@ -123,6 +123,73 @@ func TestCreateChatCompletionStream(t *testing.T) {
 	}
 }
 
+func TestCreateChatCompletionStreamError(t *testing.T) {
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("Content-Type", "text/event-stream")
+
+		// Send test responses
+		dataBytes := []byte{}
+		dataStr := []string{
+			`{`,
+			`"error": {`,
+			`"message": "Incorrect API key provided: sk-***************************************",`,
+			`"type": "invalid_request_error",`,
+			`"param": null,`,
+			`"code": "invalid_api_key"`,
+			`}`,
+			`}`,
+		}
+		for _, str := range dataStr {
+			dataBytes = append(dataBytes, []byte(str+"\n")...)
+		}
+
+		_, err := w.Write(dataBytes)
+		if err != nil {
+			t.Errorf("Write error: %s", err)
+		}
+	}))
+	defer server.Close()
+
+	// Client portion of the test
+	config := DefaultConfig(test.GetTestToken())
+	config.BaseURL = server.URL + "/v1"
+	config.HTTPClient.Transport = &tokenRoundTripper{
+		test.GetTestToken(),
+		http.DefaultTransport,
+	}
+
+	client := NewClientWithConfig(config)
+	ctx := context.Background()
+
+	request := ChatCompletionRequest{
+		MaxTokens: 5,
+		Model:     GPT3Dot5Turbo,
+		Messages: []ChatCompletionMessage{
+			{
+				Role:    ChatMessageRoleUser,
+				Content: "Hello!",
+			},
+		},
+		Stream: true,
+	}
+
+	stream, err := client.CreateChatCompletionStream(ctx, request)
+	if err != nil {
+		t.Errorf("CreateCompletionStream returned error: %v", err)
+	}
+	defer stream.Close()
+
+	_, streamErr := stream.Recv()
+	if streamErr == nil {
+		t.Errorf("stream.Recv() did not return error")
+	}
+	var apiErr *APIError
+	if !errors.As(streamErr, &apiErr) {
+		t.Errorf("stream.Recv() did not return APIError")
+	}
+	t.Logf("%+v\n", apiErr)
+}
+
 // Helper funcs.
 func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool {
 	if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {
diff --git a/error_accumulator.go b/error_accumulator.go
new file mode 100644
index 000000000..e75086e67
--- /dev/null
+++ b/error_accumulator.go
@@ -0,0 +1,51 @@
+package openai
+
+import (
+	"bytes"
+	"fmt"
+	"io"
+)
+
+type errorAccumulator interface {
+	write(p []byte) error
+	unmarshalError() (*ErrorResponse, error)
+}
+
+type errorBuffer interface {
+	io.Writer
+	Len() int
+	Bytes() []byte
+}
+
+type errorAccumulate struct {
+	buffer      errorBuffer
+	unmarshaler unmarshaler
+}
+
+func newErrorAccumulator() errorAccumulator {
+	return &errorAccumulate{
+		buffer:      &bytes.Buffer{},
+		unmarshaler: &jsonUnmarshaler{},
+	}
+}
+
+func (e *errorAccumulate) write(p []byte) error {
+	_, err := e.buffer.Write(p)
+	if err != nil {
+		return fmt.Errorf("error accumulator write error, %w", err)
+	}
+	return nil
+}
+
+func (e *errorAccumulate) unmarshalError() (*ErrorResponse, error) {
+	var err error
+	if e.buffer.Len() > 0 {
+		var errRes ErrorResponse
+		err = e.unmarshaler.unmarshal(e.buffer.Bytes(), &errRes)
+		if err != nil {
+			return nil, err
+		}
+		return &errRes, nil
+	}
+	return nil, err
+}
diff --git a/error_accumulator_test.go b/error_accumulator_test.go
new file mode 100644
index 000000000..d4008c06a
--- /dev/null
+++ b/error_accumulator_test.go
@@ -0,0 +1,90 @@
+package openai //nolint:testpackage // testing private field
+
+import (
+	"bytes"
+	"context"
+	"errors"
+	"testing"
+
+	"github.com/sashabaranov/go-openai/internal/test"
+)
+
+var (
+	errTestUnmarshalerFailed           = errors.New("test unmarshaler failed")
+	errTestErrorAccumulatorWriteFailed = errors.New("test error accumulator failed")
+)
+
+type (
+	failingUnMarshaller struct{}
+	failingErrorBuffer  struct{}
+)
+
+func (b *failingErrorBuffer) Write(_ []byte) (n int, err error) {
+	return 0, errTestErrorAccumulatorWriteFailed
+}
+
+func (b *failingErrorBuffer) Len() int {
+	return 0
+}
+
+func (b *failingErrorBuffer) Bytes() []byte {
+	return []byte{}
+}
+
+func (*failingUnMarshaller) unmarshal(_ []byte, _ any) error {
+	return errTestUnmarshalerFailed
+}
+
+func TestErrorAccumulatorReturnsUnmarshalerErrors(t *testing.T) {
+	accumulator := &errorAccumulate{
+		buffer:      &bytes.Buffer{},
+		unmarshaler: &failingUnMarshaller{},
+	}
+
+	err := accumulator.write([]byte("{"))
+	if err != nil {
+		t.Fatalf("%+v", err)
+	}
+	_, err = accumulator.unmarshalError()
+	if !errors.Is(err, errTestUnmarshalerFailed) {
+		t.Fatalf("Did not return error when unmarshaler failed: %v", err)
+	}
+}
+
+func TestErrorByteWriteErrors(t *testing.T) {
+	accumulator := &errorAccumulate{
+		buffer:      &failingErrorBuffer{},
+		unmarshaler: &jsonUnmarshaler{},
+	}
+	err := accumulator.write([]byte("{"))
+	if !errors.Is(err, errTestErrorAccumulatorWriteFailed) {
+		t.Fatalf("Did not return error when write failed: %v", err)
+	}
+}
+
+func TestErrorAccumulatorWriteErrors(t *testing.T) {
+	var err error
+	ts := test.NewTestServer().OpenAITestServer()
+	ts.Start()
+	defer ts.Close()
+
+	config := DefaultConfig(test.GetTestToken())
+	config.BaseURL = ts.URL + "/v1"
+	client := NewClientWithConfig(config)
+
+	ctx := context.Background()
+
+	stream, err := client.CreateChatCompletionStream(ctx, ChatCompletionRequest{})
+	if err != nil {
+		t.Fatal(err)
+	}
+	stream.errAccumulator = &errorAccumulate{
+		buffer:      &failingErrorBuffer{},
+		unmarshaler: &jsonUnmarshaler{},
+	}
+
+	_, err = stream.Recv()
+	if !errors.Is(err, errTestErrorAccumulatorWriteFailed) {
+		t.Fatalf("Did not return error when write failed: %v", err)
+	}
+}
diff --git a/stream.go b/stream.go
index 0eed4aa73..322d27fb9 100644
--- a/stream.go
+++ b/stream.go
@@ -2,12 +2,8 @@ package openai
 
 import (
 	"bufio"
-	"bytes"
 	"context"
-	"encoding/json"
 	"errors"
-	"io"
-	"net/http"
 )
 
 var (
@@ -15,52 +11,7 @@
 )
 
 type CompletionStream struct {
-	emptyMessagesLimit uint
-	isFinished         bool
-
-	reader   *bufio.Reader
-	response *http.Response
-}
-
-func (stream *CompletionStream) Recv() (response CompletionResponse, err error) {
-	if stream.isFinished {
-		err = io.EOF
-		return
-	}
-
-	var emptyMessagesCount uint
-
-waitForData:
-	line, err := stream.reader.ReadBytes('\n')
-	if err != nil {
-		return
-	}
-
-	var headerData = []byte("data: ")
-	line = bytes.TrimSpace(line)
-	if !bytes.HasPrefix(line, headerData) {
-		emptyMessagesCount++
-		if emptyMessagesCount > stream.emptyMessagesLimit {
-			err = ErrTooManyEmptyStreamMessages
-			return
-		}
-
-		goto waitForData
-	}
-
-	line = bytes.TrimPrefix(line, headerData)
-	if string(line) == "[DONE]" {
-		stream.isFinished = true
-		err = io.EOF
-		return
-	}
-
-	err = json.Unmarshal(line, &response)
-	return
-}
-
-func (stream *CompletionStream) Close() {
-	stream.response.Body.Close()
+	*streamReader[CompletionResponse]
 }
 
 // CreateCompletionStream — API call to create a completion w/ streaming
@@ -83,10 +34,13 @@ func (c *Client) CreateCompletionStream(
 	}
 
 	stream = &CompletionStream{
-		emptyMessagesLimit: c.config.EmptyMessagesLimit,
-
-		reader:   bufio.NewReader(resp.Body),
-		response: resp,
+		streamReader: &streamReader[CompletionResponse]{
+			emptyMessagesLimit: c.config.EmptyMessagesLimit,
+			reader:             bufio.NewReader(resp.Body),
+			response:           resp,
+			errAccumulator:     newErrorAccumulator(),
+			unmarshaler:        &jsonUnmarshaler{},
+		},
 	}
 	return
 }
diff --git a/stream_reader.go b/stream_reader.go
new file mode 100644
index 000000000..07500a5d3
--- /dev/null
+++ b/stream_reader.go
@@ -0,0 +1,71 @@
+package openai
+
+import (
+	"bufio"
+	"bytes"
+	"fmt"
+	"io"
+	"net/http"
+)
+
+type streamable interface {
+	ChatCompletionStreamResponse | CompletionResponse
+}
+
+type streamReader[T streamable] struct {
+	emptyMessagesLimit uint
+	isFinished         bool
+
+	reader         *bufio.Reader
+	response       *http.Response
+	errAccumulator errorAccumulator
+	unmarshaler    unmarshaler
+}
+
+func (stream *streamReader[T]) Recv() (response T, err error) {
+	if stream.isFinished {
+		err = io.EOF
+		return
+	}
+
+	var emptyMessagesCount uint
+
+waitForData:
+	line, err := stream.reader.ReadBytes('\n')
+	if err != nil {
+		if errRes, _ := stream.errAccumulator.unmarshalError(); errRes != nil {
+			err = fmt.Errorf("error, %w", errRes.Error)
+		}
+		return
+	}
+
+	var headerData = []byte("data: ")
+	line = bytes.TrimSpace(line)
+	if !bytes.HasPrefix(line, headerData) {
+		if writeErr := stream.errAccumulator.write(line); writeErr != nil {
+			err = writeErr
+			return
+		}
+		emptyMessagesCount++
+		if emptyMessagesCount > stream.emptyMessagesLimit {
+			err = ErrTooManyEmptyStreamMessages
+			return
+		}
+
+		goto waitForData
+	}
+
+	line = bytes.TrimPrefix(line, headerData)
+	if string(line) == "[DONE]" {
+		stream.isFinished = true
+		err = io.EOF
+		return
+	}
+
+	err = stream.unmarshaler.unmarshal(line, &response)
+	return
+}
+
+func (stream *streamReader[T]) Close() {
+	stream.response.Body.Close()
+}
diff --git a/stream_test.go b/stream_test.go
index 8f89e6b85..ce560c644 100644
--- a/stream_test.go
+++ b/stream_test.go
@@ -100,6 +100,68 @@ func TestCreateCompletionStream(t *testing.T) {
 	}
 }
 
+func TestCreateCompletionStreamError(t *testing.T) {
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("Content-Type", "text/event-stream")
+
+		// Send test responses
+		dataBytes := []byte{}
+		dataStr := []string{
+			`{`,
+			`"error": {`,
+			`"message": "Incorrect API key provided: sk-***************************************",`,
+			`"type": "invalid_request_error",`,
+			`"param": null,`,
+			`"code": "invalid_api_key"`,
+			`}`,
+			`}`,
+		}
+		for _, str := range dataStr {
+			dataBytes = append(dataBytes, []byte(str+"\n")...)
+		}
+
+		_, err := w.Write(dataBytes)
+		if err != nil {
+			t.Errorf("Write error: %s", err)
+		}
+	}))
+	defer server.Close()
+
+	// Client portion of the test
+	config := DefaultConfig(test.GetTestToken())
+	config.BaseURL = server.URL + "/v1"
+	config.HTTPClient.Transport = &tokenRoundTripper{
+		test.GetTestToken(),
+		http.DefaultTransport,
+	}
+
+	client := NewClientWithConfig(config)
+	ctx := context.Background()
+
+	request := CompletionRequest{
+		MaxTokens: 5,
+		Model:     GPT3Dot5Turbo,
+		Prompt:    "Hello!",
+		Stream:    true,
+	}
+
+	stream, err := client.CreateCompletionStream(ctx, request)
+	if err != nil {
+		t.Errorf("CreateCompletionStream returned error: %v", err)
+	}
+	defer stream.Close()
+
+	_, streamErr := stream.Recv()
+	if streamErr == nil {
+		t.Errorf("stream.Recv() did not return error")
+	}
+	var apiErr *APIError
+	if !errors.As(streamErr, &apiErr) {
+		t.Errorf("stream.Recv() did not return APIError")
+	}
+	t.Logf("%+v\n", apiErr)
+}
+
 // A "tokenRoundTripper" is a struct that implements the RoundTripper
 // interface, specifically to handle the authentication token by adding a token
 // to the request header. We need this because the API requires that each
diff --git a/unmarshaler.go b/unmarshaler.go
new file mode 100644
index 000000000..05218f764
--- /dev/null
+++ b/unmarshaler.go
@@ -0,0 +1,15 @@
+package openai
+
+import (
+	"encoding/json"
+)
+
+type unmarshaler interface {
+	unmarshal(data []byte, v any) error
+}
+
+type jsonUnmarshaler struct{}
+
+func (jm *jsonUnmarshaler) unmarshal(data []byte, v any) error {
+	return json.Unmarshal(data, v)
+}
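
Note (not part of the patch itself): a minimal usage sketch of the behaviour this change enables. When the server answers with a plain JSON error body instead of "data: " events, the non-prefixed lines are collected by the new error accumulator and Recv() returns an error wrapping *APIError, which callers can detect with errors.As, just as the tests above do. The client constructor, model constant, and OPENAI_API_KEY environment variable below are illustrative assumptions, not additions made by this diff.

package main

import (
	"context"
	"errors"
	"fmt"
	"io"
	"os"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	// Illustrative sketch only; assumes OPENAI_API_KEY is set in the environment.
	client := openai.NewClient(os.Getenv("OPENAI_API_KEY"))

	stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
		Model: openai.GPT3Dot5Turbo,
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: "Hello!"},
		},
		Stream: true,
	})
	if err != nil {
		panic(err)
	}
	defer stream.Close()

	for {
		resp, err := stream.Recv()
		if errors.Is(err, io.EOF) {
			return // the server sent [DONE]
		}
		var apiErr *openai.APIError
		if errors.As(err, &apiErr) {
			// With this patch, an error payload received on the stream is
			// accumulated and surfaced here instead of being dropped.
			fmt.Println("API error:", apiErr.Message)
			return
		}
		if err != nil {
			panic(err)
		}
		fmt.Print(resp.Choices[0].Delta.Content)
	}
}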