diff --git a/api_internal_test.go b/api_internal_test.go index 214b627bf..0fb0f8993 100644 --- a/api_internal_test.go +++ b/api_internal_test.go @@ -94,7 +94,7 @@ func TestRequestAuthHeader(t *testing.T) { az.OrgID = c.OrgID cli := NewClientWithConfig(az) - req, err := cli.newStreamRequest(context.Background(), "POST", "/chat/completions", nil, "") + req, err := cli.newRequest(context.Background(), "POST", "/chat/completions") if err != nil { t.Errorf("Failed to create request: %v", err) } diff --git a/audio.go b/audio.go index adfc52766..9f469159d 100644 --- a/audio.go +++ b/audio.go @@ -95,11 +95,11 @@ func (c *Client) callAudioAPI( } urlSuffix := fmt.Sprintf("/audio/%s", endpointSuffix) - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), &formBody) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), + withBody(&formBody), withContentType(builder.FormDataContentType())) if err != nil { return AudioResponse{}, err } - req.Header.Add("Content-Type", builder.FormDataContentType()) if request.HasJSONResponse() { err = c.sendRequest(req, &response) diff --git a/chat.go b/chat.go index f99af2735..b74720d38 100644 --- a/chat.go +++ b/chat.go @@ -152,7 +152,7 @@ func (c *Client) CreateChatCompletion( return } - req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request)) if err != nil { return } diff --git a/chat_stream.go b/chat_stream.go index 75aa6858a..9f4e80cff 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -1,10 +1,8 @@ package openai import ( - "bufio" "context" - - utils "github.com/sashabaranov/go-openai/internal" + "net/http" ) type ChatCompletionStreamChoiceDelta struct { @@ -48,27 +46,17 @@ func (c *Client) CreateChatCompletionStream( } request.Stream = true - req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request, request.Model) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request)) if err != nil { - return + return nil, err } - resp, err := c.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close() + resp, err := sendRequestStream[ChatCompletionStreamResponse](c, req) if err != nil { return } - if isFailureStatusCode(resp) { - return nil, c.handleErrorResp(resp) - } - stream = &ChatCompletionStream{ - streamReader: &streamReader[ChatCompletionStreamResponse]{ - emptyMessagesLimit: c.config.EmptyMessagesLimit, - reader: bufio.NewReader(resp.Body), - response: resp, - errAccumulator: utils.NewErrorAccumulator(), - unmarshaler: &utils.JSONUnmarshaler{}, - }, + streamReader: resp, } return } diff --git a/client.go b/client.go index f38c1dfc3..5779a8e1c 100644 --- a/client.go +++ b/client.go @@ -1,6 +1,7 @@ package openai import ( + "bufio" "context" "encoding/json" "fmt" @@ -45,6 +46,42 @@ func NewOrgClient(authToken, org string) *Client { return NewClientWithConfig(config) } +type requestOptions struct { + body any + header http.Header +} + +type requestOption func(*requestOptions) + +func withBody(body any) requestOption { + return func(args *requestOptions) { + args.body = body + } +} + +func withContentType(contentType string) requestOption { + return func(args *requestOptions) { + args.header.Set("Content-Type", contentType) + } +} + +func (c *Client) newRequest(ctx context.Context, method, url string, setters ...requestOption) (*http.Request, error) { + // Default Options + args := &requestOptions{ + body: nil, + header: make(http.Header), + } + for _, setter := range setters { + setter(args) + } + req, err := c.requestBuilder.Build(ctx, method, url, args.body, args.header) + if err != nil { + return nil, err + } + c.setCommonHeaders(req) + return req, nil +} + func (c *Client) sendRequest(req *http.Request, v any) error { req.Header.Set("Accept", "application/json; charset=utf-8") @@ -55,8 +92,6 @@ func (c *Client) sendRequest(req *http.Request, v any) error { req.Header.Set("Content-Type", "application/json; charset=utf-8") } - c.setCommonHeaders(req) - res, err := c.config.HTTPClient.Do(req) if err != nil { return err @@ -71,6 +106,41 @@ func (c *Client) sendRequest(req *http.Request, v any) error { return decodeResponse(res.Body, v) } +func (c *Client) sendRequestRaw(req *http.Request) (body io.ReadCloser, err error) { + resp, err := c.config.HTTPClient.Do(req) + if err != nil { + return + } + + if isFailureStatusCode(resp) { + err = c.handleErrorResp(resp) + return + } + return resp.Body, nil +} + +func sendRequestStream[T streamable](client *Client, req *http.Request) (*streamReader[T], error) { + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("Cache-Control", "no-cache") + req.Header.Set("Connection", "keep-alive") + + resp, err := client.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close() + if err != nil { + return new(streamReader[T]), err + } + if isFailureStatusCode(resp) { + return new(streamReader[T]), client.handleErrorResp(resp) + } + return &streamReader[T]{ + emptyMessagesLimit: client.config.EmptyMessagesLimit, + reader: bufio.NewReader(resp.Body), + response: resp, + errAccumulator: utils.NewErrorAccumulator(), + unmarshaler: &utils.JSONUnmarshaler{}, + }, nil +} + func (c *Client) setCommonHeaders(req *http.Request) { // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication // Azure API Key authentication @@ -138,26 +208,6 @@ func (c *Client) fullURL(suffix string, args ...any) string { return fmt.Sprintf("%s%s", c.config.BaseURL, suffix) } -func (c *Client) newStreamRequest( - ctx context.Context, - method string, - urlSuffix string, - body any, - model string) (*http.Request, error) { - req, err := c.requestBuilder.Build(ctx, method, c.fullURL(urlSuffix, model), body) - if err != nil { - return nil, err - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "text/event-stream") - req.Header.Set("Cache-Control", "no-cache") - req.Header.Set("Connection", "keep-alive") - - c.setCommonHeaders(req) - return req, nil -} - func (c *Client) handleErrorResp(resp *http.Response) error { var errRes ErrorResponse err := json.NewDecoder(resp.Body).Decode(&errRes) diff --git a/client_test.go b/client_test.go index 00b66feae..29d84edfa 100644 --- a/client_test.go +++ b/client_test.go @@ -16,7 +16,7 @@ var errTestRequestBuilderFailed = errors.New("test request builder failed") type failingRequestBuilder struct{} -func (*failingRequestBuilder) Build(_ context.Context, _, _ string, _ any) (*http.Request, error) { +func (*failingRequestBuilder) Build(_ context.Context, _, _ string, _ any, _ http.Header) (*http.Request, error) { return nil, errTestRequestBuilderFailed } @@ -41,9 +41,10 @@ func TestDecodeResponse(t *testing.T) { stringInput := "" testCases := []struct { - name string - value interface{} - body io.Reader + name string + value interface{} + body io.Reader + hasError bool }{ { name: "nil input", @@ -60,18 +61,32 @@ func TestDecodeResponse(t *testing.T) { value: &map[string]interface{}{}, body: bytes.NewReader([]byte(`{"test": "test"}`)), }, + { + name: "reader return error", + value: &stringInput, + body: &errorReader{err: errors.New("dummy")}, + hasError: true, + }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { err := decodeResponse(tc.body, tc.value) - if err != nil { + if (err != nil) != tc.hasError { t.Errorf("Unexpected error: %v", err) } }) } } +type errorReader struct { + err error +} + +func (e *errorReader) Read(_ []byte) (n int, err error) { + return 0, e.err +} + func TestHandleErrorResp(t *testing.T) { // var errRes *ErrorResponse var errRes ErrorResponse diff --git a/completion.go b/completion.go index e0571b007..b3b3abd1c 100644 --- a/completion.go +++ b/completion.go @@ -165,7 +165,7 @@ func (c *Client) CreateCompletion( return } - req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request)) if err != nil { return } diff --git a/edits.go b/edits.go index 23b1a64f0..3d3fc8950 100644 --- a/edits.go +++ b/edits.go @@ -32,7 +32,7 @@ type EditsResponse struct { // Perform an API call to the Edits endpoint. func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) { - req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL("/edits", fmt.Sprint(request.Model)), request) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/edits", fmt.Sprint(request.Model)), withBody(request)) if err != nil { return } diff --git a/embeddings.go b/embeddings.go index 942f3ea3a..ba327ce77 100644 --- a/embeddings.go +++ b/embeddings.go @@ -132,7 +132,7 @@ type EmbeddingRequest struct { // CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|. // https://beta.openai.com/docs/api-reference/embeddings/create func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) { - req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL("/embeddings", request.Model.String()), request) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", request.Model.String()), withBody(request)) if err != nil { return } diff --git a/engines.go b/engines.go index ac01a00ed..adf6025c2 100644 --- a/engines.go +++ b/engines.go @@ -22,7 +22,7 @@ type EnginesList struct { // ListEngines Lists the currently available engines, and provides basic // information about each option such as the owner and availability. func (c *Client) ListEngines(ctx context.Context) (engines EnginesList, err error) { - req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/engines"), nil) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL("/engines")) if err != nil { return } @@ -38,7 +38,7 @@ func (c *Client) GetEngine( engineID string, ) (engine Engine, err error) { urlSuffix := fmt.Sprintf("/engines/%s", engineID) - req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) if err != nil { return } diff --git a/files.go b/files.go index fb9937bea..ea1f50a73 100644 --- a/files.go +++ b/files.go @@ -57,21 +57,19 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File return } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL("/files"), &b) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/files"), + withBody(&b), withContentType(builder.FormDataContentType())) if err != nil { return } - req.Header.Set("Content-Type", builder.FormDataContentType()) - err = c.sendRequest(req, &file) - return } // DeleteFile deletes an existing file. func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) { - req, err := c.requestBuilder.Build(ctx, http.MethodDelete, c.fullURL("/files/"+fileID), nil) + req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL("/files/"+fileID)) if err != nil { return } @@ -83,7 +81,7 @@ func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) { // ListFiles Lists the currently available files, // and provides basic information about each file such as the file name and purpose. func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) { - req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/files"), nil) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL("/files")) if err != nil { return } @@ -96,7 +94,7 @@ func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) { // such as the file name and purpose. func (c *Client) GetFile(ctx context.Context, fileID string) (file File, err error) { urlSuffix := fmt.Sprintf("/files/%s", fileID) - req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) if err != nil { return } @@ -107,23 +105,11 @@ func (c *Client) GetFile(ctx context.Context, fileID string) (file File, err err func (c *Client) GetFileContent(ctx context.Context, fileID string) (content io.ReadCloser, err error) { urlSuffix := fmt.Sprintf("/files/%s/content", fileID) - req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) - if err != nil { - return - } - - c.setCommonHeaders(req) - - res, err := c.config.HTTPClient.Do(req) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) if err != nil { return } - if isFailureStatusCode(res) { - err = c.handleErrorResp(res) - return - } - - content = res.Body + content, err = c.sendRequestRaw(req) return } diff --git a/fine_tunes.go b/fine_tunes.go index 069ddccfd..96e731d51 100644 --- a/fine_tunes.go +++ b/fine_tunes.go @@ -68,7 +68,7 @@ type FineTuneDeleteResponse struct { func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (response FineTune, err error) { urlSuffix := "/fine-tunes" - req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix), request) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request)) if err != nil { return } @@ -79,7 +79,7 @@ func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (r // CancelFineTune cancel a fine-tune job. func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) { - req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel"), nil) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel")) if err != nil { return } @@ -89,7 +89,7 @@ func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (respons } func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err error) { - req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/fine-tunes"), nil) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL("/fine-tunes")) if err != nil { return } @@ -100,7 +100,7 @@ func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) { urlSuffix := fmt.Sprintf("/fine-tunes/%s", fineTuneID) - req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) if err != nil { return } @@ -110,7 +110,7 @@ func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response F } func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (response FineTuneDeleteResponse, err error) { - req, err := c.requestBuilder.Build(ctx, http.MethodDelete, c.fullURL("/fine-tunes/"+fineTuneID), nil) + req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL("/fine-tunes/"+fineTuneID)) if err != nil { return } @@ -120,7 +120,7 @@ func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (respons } func (c *Client) ListFineTuneEvents(ctx context.Context, fineTuneID string) (response FineTuneEventList, err error) { - req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/fine-tunes/"+fineTuneID+"/events"), nil) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL("/fine-tunes/"+fineTuneID+"/events")) if err != nil { return } diff --git a/image.go b/image.go index df7363865..cb96f4f5e 100644 --- a/image.go +++ b/image.go @@ -44,7 +44,7 @@ type ImageResponseDataInner struct { // CreateImage - API call to create an image. This is the main endpoint of the DALL-E API. func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) { urlSuffix := "/images/generations" - req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix), request) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request)) if err != nil { return } @@ -107,13 +107,12 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) return } - urlSuffix := "/images/edits" - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), body) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/images/edits"), + withBody(body), withContentType(builder.FormDataContentType())) if err != nil { return } - req.Header.Set("Content-Type", builder.FormDataContentType()) err = c.sendRequest(req, &response) return } @@ -158,14 +157,12 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest) return } - //https://platform.openai.com/docs/api-reference/images/create-variation - urlSuffix := "/images/variations" - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), body) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/images/variations"), + withBody(body), withContentType(builder.FormDataContentType())) if err != nil { return } - req.Header.Set("Content-Type", builder.FormDataContentType()) err = c.sendRequest(req, &response) return } diff --git a/internal/request_builder.go b/internal/request_builder.go index 0a9eabfde..5699f6b18 100644 --- a/internal/request_builder.go +++ b/internal/request_builder.go @@ -3,11 +3,12 @@ package openai import ( "bytes" "context" + "io" "net/http" ) type RequestBuilder interface { - Build(ctx context.Context, method, url string, request any) (*http.Request, error) + Build(ctx context.Context, method, url string, body any, header http.Header) (*http.Request, error) } type HTTPRequestBuilder struct { @@ -20,21 +21,32 @@ func NewRequestBuilder() *HTTPRequestBuilder { } } -func (b *HTTPRequestBuilder) Build(ctx context.Context, method, url string, request any) (*http.Request, error) { - if request == nil { - return http.NewRequestWithContext(ctx, method, url, nil) +func (b *HTTPRequestBuilder) Build( + ctx context.Context, + method string, + url string, + body any, + header http.Header, +) (req *http.Request, err error) { + var bodyReader io.Reader + if body != nil { + if v, ok := body.(io.Reader); ok { + bodyReader = v + } else { + var reqBytes []byte + reqBytes, err = b.marshaller.Marshal(body) + if err != nil { + return + } + bodyReader = bytes.NewBuffer(reqBytes) + } } - - var reqBytes []byte - reqBytes, err := b.marshaller.Marshal(request) + req, err = http.NewRequestWithContext(ctx, method, url, bodyReader) if err != nil { - return nil, err + return } - - return http.NewRequestWithContext( - ctx, - method, - url, - bytes.NewBuffer(reqBytes), - ) + if header != nil { + req.Header = header + } + return } diff --git a/internal/request_builder_test.go b/internal/request_builder_test.go index e47d0f6ca..e26022a6b 100644 --- a/internal/request_builder_test.go +++ b/internal/request_builder_test.go @@ -22,7 +22,7 @@ func TestRequestBuilderReturnsMarshallerErrors(t *testing.T) { marshaller: &failingMarshaller{}, } - _, err := builder.Build(context.Background(), "", "", struct{}{}) + _, err := builder.Build(context.Background(), "", "", struct{}{}, nil) if !errors.Is(err, errTestMarshallerFailed) { t.Fatalf("Did not return error when marshaller failed: %v", err) } @@ -38,7 +38,7 @@ func TestRequestBuilderReturnsRequest(t *testing.T) { reqBytes, _ = b.marshaller.Marshal(request) want, _ = http.NewRequestWithContext(ctx, method, url, bytes.NewBuffer(reqBytes)) ) - got, _ := b.Build(ctx, method, url, request) + got, _ := b.Build(ctx, method, url, request, nil) if !reflect.DeepEqual(got.Body, want.Body) || !reflect.DeepEqual(got.URL, want.URL) || !reflect.DeepEqual(got.Method, want.Method) { @@ -54,7 +54,7 @@ func TestRequestBuilderReturnsRequestWhenRequestOfArgsIsNil(t *testing.T) { want, _ = http.NewRequestWithContext(ctx, method, url, nil) ) b := NewRequestBuilder() - got, _ := b.Build(ctx, method, url, nil) + got, _ := b.Build(ctx, method, url, nil, nil) if !reflect.DeepEqual(got, want) { t.Errorf("Build() got = %v, want %v", got, want) } diff --git a/models.go b/models.go index b3d458366..560402e3f 100644 --- a/models.go +++ b/models.go @@ -41,7 +41,7 @@ type ModelsList struct { // ListModels Lists the currently available models, // and provides basic information about each model such as the model id and parent. func (c *Client) ListModels(ctx context.Context) (models ModelsList, err error) { - req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/models"), nil) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL("/models")) if err != nil { return } @@ -54,7 +54,7 @@ func (c *Client) ListModels(ctx context.Context) (models ModelsList, err error) // the model such as the owner and permissioning. func (c *Client) GetModel(ctx context.Context, modelID string) (model Model, err error) { urlSuffix := fmt.Sprintf("/models/%s", modelID) - req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) if err != nil { return } diff --git a/models_test.go b/models_test.go index 0b4daf4a8..59b4f5ef7 100644 --- a/models_test.go +++ b/models_test.go @@ -1,6 +1,9 @@ package openai_test import ( + "os" + "time" + . "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test/checks" @@ -56,3 +59,22 @@ func handleGetModelEndpoint(w http.ResponseWriter, _ *http.Request) { resBytes, _ := json.Marshal(Model{}) fmt.Fprintln(w, string(resBytes)) } + +func TestGetModelReturnTimeoutError(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/models/text-davinci-003", func(w http.ResponseWriter, r *http.Request) { + time.Sleep(10 * time.Nanosecond) + }) + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, time.Nanosecond) + defer cancel() + + _, err := client.GetModel(ctx, "text-davinci-003") + if err == nil { + t.Fatal("Did not return error") + } + if !os.IsTimeout(err) { + t.Fatal("Did not return timeout error") + } +} diff --git a/moderation.go b/moderation.go index bae788035..a58d759c0 100644 --- a/moderation.go +++ b/moderation.go @@ -63,7 +63,7 @@ type ModerationResponse struct { // Moderations — perform a moderation api call over a string. // Input can be an array or slice but a string will reduce the complexity. func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (response ModerationResponse, err error) { - req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL("/moderations", request.Model), request) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/moderations", request.Model), withBody(&request)) if err != nil { return } diff --git a/stream.go b/stream.go index 94cc0a0a2..b277f3c29 100644 --- a/stream.go +++ b/stream.go @@ -1,11 +1,8 @@ package openai import ( - "bufio" "context" "errors" - - utils "github.com/sashabaranov/go-openai/internal" ) var ( @@ -36,27 +33,17 @@ func (c *Client) CreateCompletionStream( } request.Stream = true - req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request, request.Model) + req, err := c.newRequest(ctx, "POST", c.fullURL(urlSuffix, request.Model), withBody(request)) if err != nil { - return + return nil, err } - resp, err := c.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close() + resp, err := sendRequestStream[CompletionResponse](c, req) if err != nil { return } - if isFailureStatusCode(resp) { - return nil, c.handleErrorResp(resp) - } - stream = &CompletionStream{ - streamReader: &streamReader[CompletionResponse]{ - emptyMessagesLimit: c.config.EmptyMessagesLimit, - reader: bufio.NewReader(resp.Body), - response: resp, - errAccumulator: utils.NewErrorAccumulator(), - unmarshaler: &utils.JSONUnmarshaler{}, - }, + streamReader: resp, } return } diff --git a/stream_test.go b/stream_test.go index 5997f27e8..f3f8f85cd 100644 --- a/stream_test.go +++ b/stream_test.go @@ -6,7 +6,9 @@ import ( "errors" "io" "net/http" + "os" "testing" + "time" . "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test/checks" @@ -300,6 +302,30 @@ func TestCreateCompletionStreamBrokenJSONError(t *testing.T) { } } +func TestCreateCompletionStreamReturnTimeoutError(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) { + time.Sleep(10 * time.Nanosecond) + }) + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, time.Nanosecond) + defer cancel() + + _, err := client.CreateCompletionStream(ctx, CompletionRequest{ + Prompt: "Ex falso quodlibet", + Model: "text-davinci-002", + MaxTokens: 10, + Stream: true, + }) + if err == nil { + t.Fatal("Did not return error") + } + if !os.IsTimeout(err) { + t.Fatal("Did not return timeout error") + } +} + // Helper funcs. func compareResponses(r1, r2 CompletionResponse) bool { if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {