From d6ab1b3a4f86d82a06ab601b3d3db85c1662a939 Mon Sep 17 00:00:00 2001 From: Liu Shuang Date: Wed, 19 Apr 2023 20:05:00 +0800 Subject: [PATCH] fix: chat stream resp error (#259) --- api_test.go | 17 ++++++------ chat_stream.go | 4 +++ chat_stream_test.go | 51 ++++++++++++++++++++++++++++++++++++ client.go | 26 +++++++++++-------- error.go | 16 ++++++------ error_accumulator_test.go | 7 ++++- stream.go | 4 +++ stream_test.go | 54 ++++++++++++++++++++++++++++++++++++--- 8 files changed, 146 insertions(+), 33 deletions(-) diff --git a/api_test.go b/api_test.go index d6ad78932..ecba25625 100644 --- a/api_test.go +++ b/api_test.go @@ -1,16 +1,15 @@ package openai_test import ( - "encoding/json" - - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" - "context" + "encoding/json" "errors" "io" "os" "testing" + + . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" ) func TestAPI(t *testing.T) { @@ -119,8 +118,8 @@ func TestAPIError(t *testing.T) { t.Fatalf("Error is not an APIError: %+v", err) } - if apiErr.StatusCode != 401 { - t.Fatalf("Unexpected API error status code: %d", apiErr.StatusCode) + if apiErr.HTTPStatusCode != 401 { + t.Fatalf("Unexpected API error status code: %d", apiErr.HTTPStatusCode) } switch v := apiErr.Code.(type) { @@ -239,8 +238,8 @@ func TestRequestError(t *testing.T) { t.Fatalf("Error is not a RequestError: %+v", err) } - if reqErr.StatusCode != 418 { - t.Fatalf("Unexpected request error status code: %d", reqErr.StatusCode) + if reqErr.HTTPStatusCode != 418 { + t.Fatalf("Unexpected request error status code: %d", reqErr.HTTPStatusCode) } if reqErr.Unwrap() == nil { diff --git a/chat_stream.go b/chat_stream.go index 821129295..b5257ccc4 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -3,6 +3,7 @@ package openai import ( "bufio" "context" + "net/http" ) type ChatCompletionStreamChoiceDelta struct { @@ -53,6 +54,9 @@ func (c *Client) CreateChatCompletionStream( if err != nil { return } + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest { + return nil, c.handleErrorResp(resp) + } stream = &ChatCompletionStream{ streamReader: &streamReader[ChatCompletionStreamResponse]{ diff --git a/chat_stream_test.go b/chat_stream_test.go index 24046db6c..afcb86d5e 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -204,6 +204,57 @@ func TestCreateChatCompletionStreamError(t *testing.T) { t.Logf("%+v\n", apiErr) } +func TestCreateChatCompletionStreamRateLimitError(t *testing.T) { + server := test.NewTestServer() + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(429) + + // Send test responses + dataBytes := []byte(`{"error":{` + + `"message": "You are sending requests too quickly.",` + + `"type":"rate_limit_reached",` + + `"param":null,` + + `"code":"rate_limit_reached"}}`) + + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + }) + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + // Client portion of the test + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + config.HTTPClient.Transport = &tokenRoundTripper{ + test.GetTestToken(), + http.DefaultTransport, + } + + client := NewClientWithConfig(config) + ctx := context.Background() + + request := ChatCompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + } + + var apiErr *APIError + _, err := client.CreateChatCompletionStream(ctx, request) + if !errors.As(err, &apiErr) { + t.Errorf("TestCreateChatCompletionStreamRateLimitError did not return APIError") + } + t.Logf("%+v\n", apiErr) +} + // Helper funcs. func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool { if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model { diff --git a/client.go b/client.go index c1f76d7d7..b15a18ae1 100644 --- a/client.go +++ b/client.go @@ -72,17 +72,7 @@ func (c *Client) sendRequest(req *http.Request, v interface{}) error { defer res.Body.Close() if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusBadRequest { - var errRes ErrorResponse - err = json.NewDecoder(res.Body).Decode(&errRes) - if err != nil || errRes.Error == nil { - reqErr := RequestError{ - StatusCode: res.StatusCode, - Err: err, - } - return fmt.Errorf("error, %w", &reqErr) - } - errRes.Error.StatusCode = res.StatusCode - return fmt.Errorf("error, status code: %d, message: %w", res.StatusCode, errRes.Error) + return c.handleErrorResp(res) } if v != nil { @@ -132,3 +122,17 @@ func (c *Client) newStreamRequest( } return req, nil } + +func (c *Client) handleErrorResp(resp *http.Response) error { + var errRes ErrorResponse + err := json.NewDecoder(resp.Body).Decode(&errRes) + if err != nil || errRes.Error == nil { + reqErr := RequestError{ + HTTPStatusCode: resp.StatusCode, + Err: err, + } + return fmt.Errorf("error, %w", &reqErr) + } + errRes.Error.HTTPStatusCode = resp.StatusCode + return fmt.Errorf("error, status code: %d, message: %w", resp.StatusCode, errRes.Error) +} diff --git a/error.go b/error.go index 32ffa6cc8..8aee6708b 100644 --- a/error.go +++ b/error.go @@ -7,17 +7,17 @@ import ( // APIError provides error information returned by the OpenAI API. type APIError struct { - Code any `json:"code,omitempty"` - Message string `json:"message"` - Param *string `json:"param,omitempty"` - Type string `json:"type"` - StatusCode int `json:"-"` + Code any `json:"code,omitempty"` + Message string `json:"message"` + Param *string `json:"param,omitempty"` + Type string `json:"type"` + HTTPStatusCode int `json:"-"` } // RequestError provides informations about generic request errors. type RequestError struct { - StatusCode int - Err error + HTTPStatusCode int + Err error } type ErrorResponse struct { @@ -73,7 +73,7 @@ func (e *RequestError) Error() string { if e.Err != nil { return e.Err.Error() } - return fmt.Sprintf("status code %d", e.StatusCode) + return fmt.Sprintf("status code %d", e.HTTPStatusCode) } func (e *RequestError) Unwrap() error { diff --git a/error_accumulator_test.go b/error_accumulator_test.go index 637bf3678..ecf954d58 100644 --- a/error_accumulator_test.go +++ b/error_accumulator_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "errors" + "net/http" "testing" "github.com/sashabaranov/go-openai/internal/test" @@ -71,7 +72,11 @@ func TestErrorByteWriteErrors(t *testing.T) { func TestErrorAccumulatorWriteErrors(t *testing.T) { var err error - ts := test.NewTestServer().OpenAITestServer() + server := test.NewTestServer() + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "error", 200) + }) + ts := server.OpenAITestServer() ts.Start() defer ts.Close() diff --git a/stream.go b/stream.go index 64688cdce..95662db6d 100644 --- a/stream.go +++ b/stream.go @@ -4,6 +4,7 @@ import ( "bufio" "context" "errors" + "net/http" ) var ( @@ -43,6 +44,9 @@ func (c *Client) CreateCompletionStream( if err != nil { return } + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest { + return nil, c.handleErrorResp(resp) + } stream = &CompletionStream{ streamReader: &streamReader[CompletionResponse]{ diff --git a/stream_test.go b/stream_test.go index a80504d24..a5c591fde 100644 --- a/stream_test.go +++ b/stream_test.go @@ -1,16 +1,16 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" - "github.com/sashabaranov/go-openai/internal/test/checks" - "context" "errors" "io" "net/http" "net/http/httptest" "testing" + + . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test" + "github.com/sashabaranov/go-openai/internal/test/checks" ) func TestCompletionsStreamWrongModel(t *testing.T) { @@ -171,6 +171,52 @@ func TestCreateCompletionStreamError(t *testing.T) { t.Logf("%+v\n", apiErr) } +func TestCreateCompletionStreamRateLimitError(t *testing.T) { + server := test.NewTestServer() + server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(429) + + // Send test responses + dataBytes := []byte(`{"error":{` + + `"message": "You are sending requests too quickly.",` + + `"type":"rate_limit_reached",` + + `"param":null,` + + `"code":"rate_limit_reached"}}`) + + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + }) + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + // Client portion of the test + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + config.HTTPClient.Transport = &tokenRoundTripper{ + test.GetTestToken(), + http.DefaultTransport, + } + + client := NewClientWithConfig(config) + ctx := context.Background() + + request := CompletionRequest{ + MaxTokens: 5, + Model: GPT3Ada, + Prompt: "Hello!", + Stream: true, + } + + var apiErr *APIError + _, err := client.CreateCompletionStream(ctx, request) + if !errors.As(err, &apiErr) { + t.Errorf("TestCreateCompletionStreamRateLimitError did not return APIError") + } + t.Logf("%+v\n", apiErr) +} + // A "tokenRoundTripper" is a struct that implements the RoundTripper // interface, specifically to handle the authentication token by adding a token // to the request header. We need this because the API requires that each