diff --git a/api_test.go b/api_test.go index 78fd5cc6d..083b67412 100644 --- a/api_test.go +++ b/api_test.go @@ -6,7 +6,6 @@ import ( "errors" "io" "net/http" - "net/http/httptest" "os" "testing" @@ -226,18 +225,13 @@ func TestAPIErrorUnmarshalJSONInvalidMessage(t *testing.T) { } func TestRequestError(t *testing.T) { - var err error - - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/engines", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusTeapot) - })) - defer ts.Close() + }) - config := DefaultConfig("dummy") - config.BaseURL = ts.URL - c := NewClientWithConfig(config) - ctx := context.Background() - _, err = c.ListEngines(ctx) + _, err := client.ListEngines(context.Background()) checks.HasError(t, err, "ListEngines did not fail") var reqErr *RequestError diff --git a/audio_api_test.go b/audio_api_test.go new file mode 100644 index 000000000..aad7a225a --- /dev/null +++ b/audio_api_test.go @@ -0,0 +1,162 @@ +package openai_test + +import ( + "bytes" + "context" + "errors" + "io" + "mime" + "mime/multipart" + "net/http" + "path/filepath" + "strings" + "testing" + + . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +// TestAudio Tests the transcription and translation endpoints of the API using the mocked server. +func TestAudio(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint) + server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint) + + testcases := []struct { + name string + createFn func(context.Context, AudioRequest) (AudioResponse, error) + }{ + { + "transcribe", + client.CreateTranscription, + }, + { + "translate", + client.CreateTranslation, + }, + } + + ctx := context.Background() + + dir, cleanup := test.CreateTestDirectory(t) + defer cleanup() + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + path := filepath.Join(dir, "fake.mp3") + test.CreateTestFile(t, path) + + req := AudioRequest{ + FilePath: path, + Model: "whisper-3", + } + _, err := tc.createFn(ctx, req) + checks.NoError(t, err, "audio API error") + }) + + t.Run(tc.name+" (with reader)", func(t *testing.T) { + req := AudioRequest{ + FilePath: "fake.webm", + Reader: bytes.NewBuffer([]byte(`some webm binary data`)), + Model: "whisper-3", + } + _, err := tc.createFn(ctx, req) + checks.NoError(t, err, "audio API error") + }) + } +} + +func TestAudioWithOptionalArgs(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint) + server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint) + + testcases := []struct { + name string + createFn func(context.Context, AudioRequest) (AudioResponse, error) + }{ + { + "transcribe", + client.CreateTranscription, + }, + { + "translate", + client.CreateTranslation, + }, + } + + ctx := context.Background() + + dir, cleanup := test.CreateTestDirectory(t) + defer cleanup() + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + path := filepath.Join(dir, "fake.mp3") + test.CreateTestFile(t, path) + + req := AudioRequest{ + FilePath: path, + Model: "whisper-3", + Prompt: "用简体中文", + Temperature: 0.5, + Language: "zh", + Format: AudioResponseFormatSRT, + } + _, err := tc.createFn(ctx, req) + checks.NoError(t, err, "audio API error") + }) + } +} + +// handleAudioEndpoint Handles the completion endpoint by the test server. +func handleAudioEndpoint(w http.ResponseWriter, r *http.Request) { + var err error + + // audio endpoints only accept POST requests + if r.Method != "POST" { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + } + + mediaType, params, err := mime.ParseMediaType(r.Header.Get("Content-Type")) + if err != nil { + http.Error(w, "failed to parse media type", http.StatusBadRequest) + return + } + + if !strings.HasPrefix(mediaType, "multipart") { + http.Error(w, "request is not multipart", http.StatusBadRequest) + } + + boundary, ok := params["boundary"] + if !ok { + http.Error(w, "no boundary in params", http.StatusBadRequest) + return + } + + fileData := &bytes.Buffer{} + mr := multipart.NewReader(r.Body, boundary) + part, err := mr.NextPart() + if err != nil && errors.Is(err, io.EOF) { + http.Error(w, "error accessing file", http.StatusBadRequest) + return + } + if _, err = io.Copy(fileData, part); err != nil { + http.Error(w, "failed to copy file", http.StatusInternalServerError) + return + } + + if len(fileData.Bytes()) == 0 { + w.WriteHeader(http.StatusInternalServerError) + http.Error(w, "received empty file data", http.StatusBadRequest) + return + } + + if _, err = w.Write([]byte(`{"body": "hello"}`)); err != nil { + http.Error(w, "failed to write body", http.StatusInternalServerError) + return + } +} diff --git a/audio_test.go b/audio_test.go index 6452e2eb7..e19a873f3 100644 --- a/audio_test.go +++ b/audio_test.go @@ -2,182 +2,16 @@ package openai //nolint:testpackage // testing private field import ( "bytes" - "context" - "errors" "fmt" "io" - "mime" - "mime/multipart" - "net/http" "os" "path/filepath" - "strings" "testing" "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" ) -// TestAudio Tests the transcription and translation endpoints of the API using the mocked server. -func TestAudio(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint) - server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - - testcases := []struct { - name string - createFn func(context.Context, AudioRequest) (AudioResponse, error) - }{ - { - "transcribe", - client.CreateTranscription, - }, - { - "translate", - client.CreateTranslation, - }, - } - - ctx := context.Background() - - dir, cleanup := test.CreateTestDirectory(t) - defer cleanup() - - for _, tc := range testcases { - t.Run(tc.name, func(t *testing.T) { - path := filepath.Join(dir, "fake.mp3") - test.CreateTestFile(t, path) - - req := AudioRequest{ - FilePath: path, - Model: "whisper-3", - } - _, err = tc.createFn(ctx, req) - checks.NoError(t, err, "audio API error") - }) - - t.Run(tc.name+" (with reader)", func(t *testing.T) { - req := AudioRequest{ - FilePath: "fake.webm", - Reader: bytes.NewBuffer([]byte(`some webm binary data`)), - Model: "whisper-3", - } - _, err = tc.createFn(ctx, req) - checks.NoError(t, err, "audio API error") - }) - } -} - -func TestAudioWithOptionalArgs(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint) - server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - - testcases := []struct { - name string - createFn func(context.Context, AudioRequest) (AudioResponse, error) - }{ - { - "transcribe", - client.CreateTranscription, - }, - { - "translate", - client.CreateTranslation, - }, - } - - ctx := context.Background() - - dir, cleanup := test.CreateTestDirectory(t) - defer cleanup() - - for _, tc := range testcases { - t.Run(tc.name, func(t *testing.T) { - path := filepath.Join(dir, "fake.mp3") - test.CreateTestFile(t, path) - - req := AudioRequest{ - FilePath: path, - Model: "whisper-3", - Prompt: "用简体中文", - Temperature: 0.5, - Language: "zh", - Format: AudioResponseFormatSRT, - } - _, err = tc.createFn(ctx, req) - checks.NoError(t, err, "audio API error") - }) - } -} - -// handleAudioEndpoint Handles the completion endpoint by the test server. -func handleAudioEndpoint(w http.ResponseWriter, r *http.Request) { - var err error - - // audio endpoints only accept POST requests - if r.Method != "POST" { - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - } - - mediaType, params, err := mime.ParseMediaType(r.Header.Get("Content-Type")) - if err != nil { - http.Error(w, "failed to parse media type", http.StatusBadRequest) - return - } - - if !strings.HasPrefix(mediaType, "multipart") { - http.Error(w, "request is not multipart", http.StatusBadRequest) - } - - boundary, ok := params["boundary"] - if !ok { - http.Error(w, "no boundary in params", http.StatusBadRequest) - return - } - - fileData := &bytes.Buffer{} - mr := multipart.NewReader(r.Body, boundary) - part, err := mr.NextPart() - if err != nil && errors.Is(err, io.EOF) { - http.Error(w, "error accessing file", http.StatusBadRequest) - return - } - if _, err = io.Copy(fileData, part); err != nil { - http.Error(w, "failed to copy file", http.StatusInternalServerError) - return - } - - if len(fileData.Bytes()) == 0 { - w.WriteHeader(http.StatusInternalServerError) - http.Error(w, "received empty file data", http.StatusBadRequest) - return - } - - if _, err = w.Write([]byte(`{"body": "hello"}`)); err != nil { - http.Error(w, "failed to write body", http.StatusInternalServerError) - return - } -} - func TestAudioWithFailingFormBuilder(t *testing.T) { dir, cleanup := test.CreateTestDirectory(t) defer cleanup() diff --git a/chat_stream_test.go b/chat_stream_test.go index 19c2e3cd0..c3cb9f3f7 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -1,8 +1,7 @@ -package openai //nolint:testpackage // testing private field +package openai_test import ( - utils "github.com/sashabaranov/go-openai/internal" - "github.com/sashabaranov/go-openai/internal/test" + . "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test/checks" "context" @@ -10,7 +9,6 @@ import ( "errors" "io" "net/http" - "net/http/httptest" "testing" ) @@ -37,7 +35,9 @@ func TestChatCompletionsStreamWrongModel(t *testing.T) { } func TestCreateChatCompletionStream(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") // Send test responses @@ -57,21 +57,9 @@ func TestCreateChatCompletionStream(t *testing.T) { _, err := w.Write(dataBytes) checks.NoError(t, err, "Write error") - })) - defer server.Close() - - // Client portion of the test - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = server.URL + "/v1" - config.HTTPClient.Transport = &test.TokenRoundTripper{ - Token: test.GetTestToken(), - Fallback: http.DefaultTransport, - } - - client := NewClientWithConfig(config) - ctx := context.Background() + }) - request := ChatCompletionRequest{ + stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ MaxTokens: 5, Model: GPT3Dot5Turbo, Messages: []ChatCompletionMessage{ @@ -81,9 +69,7 @@ func TestCreateChatCompletionStream(t *testing.T) { }, }, Stream: true, - } - - stream, err := client.CreateChatCompletionStream(ctx, request) + }) checks.NoError(t, err, "CreateCompletionStream returned error") defer stream.Close() @@ -143,7 +129,9 @@ func TestCreateChatCompletionStream(t *testing.T) { } func TestCreateChatCompletionStreamError(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") // Send test responses @@ -164,21 +152,9 @@ func TestCreateChatCompletionStreamError(t *testing.T) { _, err := w.Write(dataBytes) checks.NoError(t, err, "Write error") - })) - defer server.Close() - - // Client portion of the test - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = server.URL + "/v1" - config.HTTPClient.Transport = &test.TokenRoundTripper{ - Token: test.GetTestToken(), - Fallback: http.DefaultTransport, - } - - client := NewClientWithConfig(config) - ctx := context.Background() + }) - request := ChatCompletionRequest{ + stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ MaxTokens: 5, Model: GPT3Dot5Turbo, Messages: []ChatCompletionMessage{ @@ -188,9 +164,7 @@ func TestCreateChatCompletionStreamError(t *testing.T) { }, }, Stream: true, - } - - stream, err := client.CreateChatCompletionStream(ctx, request) + }) checks.NoError(t, err, "CreateCompletionStream returned error") defer stream.Close() @@ -205,7 +179,8 @@ func TestCreateChatCompletionStreamError(t *testing.T) { } func TestCreateChatCompletionStreamRateLimitError(t *testing.T) { - server := test.NewTestServer() + client, server, teardown := setupOpenAITestServer() + defer teardown() server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(429) @@ -220,22 +195,7 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) { _, err := w.Write(dataBytes) checks.NoError(t, err, "Write error") }) - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - // Client portion of the test - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - config.HTTPClient.Transport = &test.TokenRoundTripper{ - Token: test.GetTestToken(), - Fallback: http.DefaultTransport, - } - - client := NewClientWithConfig(config) - ctx := context.Background() - - request := ChatCompletionRequest{ + _, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ MaxTokens: 5, Model: GPT3Dot5Turbo, Messages: []ChatCompletionMessage{ @@ -245,10 +205,8 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) { }, }, Stream: true, - } - + }) var apiErr *APIError - _, err := client.CreateChatCompletionStream(ctx, request) if !errors.As(err, &apiErr) { t.Errorf("TestCreateChatCompletionStreamRateLimitError did not return APIError") } @@ -262,7 +220,8 @@ func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) { "Please retry after 20 seconds. " + "Please go here: https://aka.ms/oai/quotaincrease if you would like to further increase the default rate limit." - server := test.NewTestServer() + client, server, teardown := setupAzureTestServer() + defer teardown() server.RegisterHandler("/openai/deployments/gpt-35-turbo/chat/completions", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") @@ -273,17 +232,9 @@ func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) { checks.NoError(t, err, "Write error") }) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultAzureConfig(test.GetTestToken(), ts.URL) - client := NewClientWithConfig(config) - ctx := context.Background() - request := ChatCompletionRequest{ + apiErr := &APIError{} + _, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ MaxTokens: 5, Model: GPT3Dot5Turbo, Messages: []ChatCompletionMessage{ @@ -293,10 +244,7 @@ func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) { }, }, Stream: true, - } - - apiErr := &APIError{} - _, err = client.CreateChatCompletionStream(ctx, request) + }) if !errors.As(err, &apiErr) { t.Errorf("Did not return APIError: %+v\n", apiErr) return @@ -316,33 +264,6 @@ func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) { } } -func TestCreateChatCompletionStreamErrorAccumulatorWriteErrors(t *testing.T) { - var err error - server := test.NewTestServer() - server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { - http.Error(w, "error", 200) - }) - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - - ctx := context.Background() - - stream, err := client.CreateChatCompletionStream(ctx, ChatCompletionRequest{}) - checks.NoError(t, err) - - stream.errAccumulator = &utils.DefaultErrorAccumulator{ - Buffer: &test.FailingErrorBuffer{}, - } - - _, err = stream.Recv() - checks.ErrorIs(t, err, test.ErrTestErrorAccumulatorWriteFailed, "Did not return error when Write failed", err.Error()) -} - // Helper funcs. func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool { if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model { diff --git a/chat_test.go b/chat_test.go index ce302a69f..ebe29f9eb 100644 --- a/chat_test.go +++ b/chat_test.go @@ -2,7 +2,6 @@ package openai_test import ( . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" "context" @@ -52,20 +51,10 @@ func TestChatCompletionsWithStream(t *testing.T) { // TestCompletions Tests the completions endpoint of the API using the mocked server. func TestChatCompletions(t *testing.T) { - server := test.NewTestServer() + client, server, teardown := setupOpenAITestServer() + defer teardown() server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - req := ChatCompletionRequest{ + _, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ MaxTokens: 5, Model: GPT3Dot5Turbo, Messages: []ChatCompletionMessage{ @@ -74,8 +63,7 @@ func TestChatCompletions(t *testing.T) { Content: "Hello!", }, }, - } - _, err = client.CreateChatCompletion(ctx, req) + }) checks.NoError(t, err, "CreateChatCompletion error") } diff --git a/client_test.go b/client_test.go index 7724dcec0..bc62a62b3 100644 --- a/client_test.go +++ b/client_test.go @@ -170,16 +170,9 @@ func TestHandleErrorResp(t *testing.T) { } func TestClientReturnsRequestBuilderErrors(t *testing.T) { - var err error - ts := test.NewTestServer().OpenAITestServer() - ts.Start() - defer ts.Close() - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" client := NewClientWithConfig(config) client.requestBuilder = &failingRequestBuilder{} - ctx := context.Background() type TestCase struct { @@ -257,7 +250,7 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { } for _, testCase := range testCases { - _, err = testCase.TestFunc() + _, err := testCase.TestFunc() if !errors.Is(err, errTestRequestBuilderFailed) { t.Fatalf("%s did not return error when request builder failed: %v", testCase.Name, err) } @@ -265,23 +258,14 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { } func TestClientReturnsRequestBuilderErrorsAddtion(t *testing.T) { - var err error - ts := test.NewTestServer().OpenAITestServer() - ts.Start() - defer ts.Close() - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" client := NewClientWithConfig(config) client.requestBuilder = &failingRequestBuilder{} - ctx := context.Background() - - _, err = client.CreateCompletion(ctx, CompletionRequest{Prompt: 1}) + _, err := client.CreateCompletion(ctx, CompletionRequest{Prompt: 1}) if !errors.Is(err, ErrCompletionRequestPromptTypeNotSupported) { t.Fatalf("Did not return error when request builder failed: %v", err) } - _, err = client.CreateCompletionStream(ctx, CompletionRequest{Prompt: 1}) if !errors.Is(err, ErrCompletionRequestPromptTypeNotSupported) { t.Fatalf("Did not return error when request builder failed: %v", err) diff --git a/completion_test.go b/completion_test.go index 2e302591a..aeddcfca1 100644 --- a/completion_test.go +++ b/completion_test.go @@ -2,7 +2,6 @@ package openai_test import ( . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" "context" @@ -48,25 +47,15 @@ func TestCompletionWithStream(t *testing.T) { // TestCompletions Tests the completions endpoint of the API using the mocked server. func TestCompletions(t *testing.T) { - server := test.NewTestServer() + client, server, teardown := setupOpenAITestServer() + defer teardown() server.RegisterHandler("/v1/completions", handleCompletionEndpoint) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - req := CompletionRequest{ MaxTokens: 5, Model: "ada", + Prompt: "Lorem ipsum", } - req.Prompt = "Lorem ipsum" - _, err = client.CreateCompletion(ctx, req) + _, err := client.CreateCompletion(context.Background(), req) checks.NoError(t, err, "CreateCompletion error") } diff --git a/edits_test.go b/edits_test.go index fa6c12825..c0bb84392 100644 --- a/edits_test.go +++ b/edits_test.go @@ -2,7 +2,6 @@ package openai_test import ( . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" "context" @@ -16,19 +15,9 @@ import ( // TestEdits Tests the edits endpoint of the API using the mocked server. func TestEdits(t *testing.T) { - server := test.NewTestServer() + client, server, teardown := setupOpenAITestServer() + defer teardown() server.RegisterHandler("/v1/edits", handleEditEndpoint) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - // create an edit request model := "ada" editReq := EditsRequest{ @@ -40,7 +29,7 @@ func TestEdits(t *testing.T) { Instruction: "test instruction", N: 3, } - response, err := client.Edits(ctx, editReq) + response, err := client.Edits(context.Background(), editReq) checks.NoError(t, err, "Edits error") if len(response.Choices) != editReq.N { t.Fatalf("edits does not properly return the correct number of choices") diff --git a/embeddings_test.go b/embeddings_test.go index 252f7a5a0..d7892cd5d 100644 --- a/embeddings_test.go +++ b/embeddings_test.go @@ -2,7 +2,6 @@ package openai_test import ( . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" "bytes" @@ -67,7 +66,8 @@ func TestEmbeddingModel(t *testing.T) { } func TestEmbeddingEndpoint(t *testing.T) { - server := test.NewTestServer() + client, server, teardown := setupOpenAITestServer() + defer teardown() server.RegisterHandler( "/v1/embeddings", func(w http.ResponseWriter, r *http.Request) { @@ -75,17 +75,6 @@ func TestEmbeddingEndpoint(t *testing.T) { fmt.Fprintln(w, string(resBytes)) }, ) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - _, err = client.CreateEmbeddings(ctx, EmbeddingRequest{}) + _, err := client.CreateEmbeddings(context.Background(), EmbeddingRequest{}) checks.NoError(t, err, "CreateEmbeddings error") } diff --git a/engines_test.go b/engines_test.go index dfa3187cf..2beb333b3 100644 --- a/engines_test.go +++ b/engines_test.go @@ -8,27 +8,29 @@ import ( "testing" . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" ) // TestGetEngine Tests the retrieve engine endpoint of the API using the mocked server. func TestGetEngine(t *testing.T) { - server := test.NewTestServer() + client, server, teardown := setupOpenAITestServer() + defer teardown() server.RegisterHandler("/v1/engines/text-davinci-003", func(w http.ResponseWriter, r *http.Request) { resBytes, _ := json.Marshal(Engine{}) fmt.Fprintln(w, string(resBytes)) }) - // create the test server - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - _, err := client.GetEngine(ctx, "text-davinci-003") + _, err := client.GetEngine(context.Background(), "text-davinci-003") checks.NoError(t, err, "GetEngine error") } + +// TestListEngines Tests the list engines endpoint of the API using the mocked server. +func TestListEngines(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/engines", func(w http.ResponseWriter, r *http.Request) { + resBytes, _ := json.Marshal(EnginesList{}) + fmt.Fprintln(w, string(resBytes)) + }) + _, err := client.ListEngines(context.Background()) + checks.NoError(t, err, "ListEngines error") +} diff --git a/files_api_test.go b/files_api_test.go new file mode 100644 index 000000000..f0a08764d --- /dev/null +++ b/files_api_test.go @@ -0,0 +1,183 @@ +package openai_test + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "os" + "strconv" + "testing" + "time" + + . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +func TestFileUpload(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/files", handleCreateFile) + req := FileRequest{ + FileName: "test.go", + FilePath: "client.go", + Purpose: "fine-tune", + } + _, err := client.CreateFile(context.Background(), req) + checks.NoError(t, err, "CreateFile error") +} + +// handleCreateFile Handles the images endpoint by the test server. +func handleCreateFile(w http.ResponseWriter, r *http.Request) { + var err error + var resBytes []byte + + // edits only accepts POST requests + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + err = r.ParseMultipartForm(1024 * 1024 * 1024) + if err != nil { + http.Error(w, "file is more than 1GB", http.StatusInternalServerError) + return + } + + values := r.Form + var purpose string + for key, value := range values { + if key == "purpose" { + purpose = value[0] + } + } + file, header, err := r.FormFile("file") + if err != nil { + return + } + defer file.Close() + + var fileReq = File{ + Bytes: int(header.Size), + ID: strconv.Itoa(int(time.Now().Unix())), + FileName: header.Filename, + Purpose: purpose, + CreatedAt: time.Now().Unix(), + Object: "test-objecct", + Owner: "test-owner", + } + + resBytes, _ = json.Marshal(fileReq) + fmt.Fprint(w, string(resBytes)) +} + +func TestDeleteFile(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/files/deadbeef", func(w http.ResponseWriter, r *http.Request) {}) + err := client.DeleteFile(context.Background(), "deadbeef") + checks.NoError(t, err, "DeleteFile error") +} + +func TestListFile(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/files", func(w http.ResponseWriter, r *http.Request) { + resBytes, _ := json.Marshal(FilesList{}) + fmt.Fprintln(w, string(resBytes)) + }) + _, err := client.ListFiles(context.Background()) + checks.NoError(t, err, "ListFiles error") +} + +func TestGetFile(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/files/deadbeef", func(w http.ResponseWriter, r *http.Request) { + resBytes, _ := json.Marshal(File{}) + fmt.Fprintln(w, string(resBytes)) + }) + _, err := client.GetFile(context.Background(), "deadbeef") + checks.NoError(t, err, "GetFile error") +} + +func TestGetFileContent(t *testing.T) { + wantRespJsonl := `{"prompt": "foo", "completion": "foo"} +{"prompt": "bar", "completion": "bar"} +{"prompt": "baz", "completion": "baz"} +` + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, r *http.Request) { + // edits only accepts GET requests + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + fmt.Fprint(w, wantRespJsonl) + }) + + content, err := client.GetFileContent(context.Background(), "deadbeef") + checks.NoError(t, err, "GetFileContent error") + defer content.Close() + + actual, _ := io.ReadAll(content) + if string(actual) != wantRespJsonl { + t.Errorf("Expected %s, got %s", wantRespJsonl, string(actual)) + } +} + +func TestGetFileContentReturnError(t *testing.T) { + wantMessage := "To help mitigate abuse, downloading of fine-tune training files is disabled for free accounts." + wantType := "invalid_request_error" + wantErrorResp := `{ + "error": { + "message": "` + wantMessage + `", + "type": "` + wantType + `", + "param": null, + "code": null + } +}` + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + fmt.Fprint(w, wantErrorResp) + }) + + _, err := client.GetFileContent(context.Background(), "deadbeef") + if err == nil { + t.Fatal("Did not return error") + } + + apiErr := &APIError{} + if !errors.As(err, &apiErr) { + t.Fatalf("Did not return APIError: %+v\n", apiErr) + } + if apiErr.Message != wantMessage { + t.Fatalf("Expected %s Message, got = %s\n", wantMessage, apiErr.Message) + return + } + if apiErr.Type != wantType { + t.Fatalf("Expected %s Type, got = %s\n", wantType, apiErr.Type) + return + } +} + +func TestGetFileContentReturnTimeoutError(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, r *http.Request) { + time.Sleep(10 * time.Nanosecond) + }) + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, time.Nanosecond) + defer cancel() + + _, err := client.GetFileContent(ctx, "deadbeef") + if err == nil { + t.Fatal("Did not return error") + } + if !os.IsTimeout(err) { + t.Fatal("Did not return timeout error") + } +} diff --git a/files_test.go b/files_test.go index 8e8934935..df6eaef7b 100644 --- a/files_test.go +++ b/files_test.go @@ -2,86 +2,15 @@ package openai //nolint:testpackage // testing private field import ( utils "github.com/sashabaranov/go-openai/internal" - "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" "context" - "encoding/json" - "errors" "fmt" "io" - "net/http" "os" - "strconv" "testing" - "time" ) -func TestFileUpload(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/v1/files", handleCreateFile) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - req := FileRequest{ - FileName: "test.go", - FilePath: "client.go", - Purpose: "fine-tune", - } - _, err = client.CreateFile(ctx, req) - checks.NoError(t, err, "CreateFile error") -} - -// handleCreateFile Handles the images endpoint by the test server. -func handleCreateFile(w http.ResponseWriter, r *http.Request) { - var err error - var resBytes []byte - - // edits only accepts POST requests - if r.Method != "POST" { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - } - err = r.ParseMultipartForm(1024 * 1024 * 1024) - if err != nil { - http.Error(w, "file is more than 1GB", http.StatusInternalServerError) - return - } - - values := r.Form - var purpose string - for key, value := range values { - if key == "purpose" { - purpose = value[0] - } - } - file, header, err := r.FormFile("file") - if err != nil { - return - } - defer file.Close() - - var fileReq = File{ - Bytes: int(header.Size), - ID: strconv.Itoa(int(time.Now().Unix())), - FileName: header.Filename, - Purpose: purpose, - CreatedAt: time.Now().Unix(), - Object: "test-objecct", - Owner: "test-owner", - } - - resBytes, _ = json.Marshal(fileReq) - fmt.Fprint(w, string(resBytes)) -} - func TestFileUploadWithFailingFormBuilder(t *testing.T) { config := DefaultConfig("") config.BaseURL = "" @@ -142,168 +71,3 @@ func TestFileUploadWithNonExistentPath(t *testing.T) { _, err := client.CreateFile(ctx, req) checks.ErrorIs(t, err, os.ErrNotExist, "CreateFile should return error if file does not exist") } - -func TestDeleteFile(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/v1/files/deadbeef", func(w http.ResponseWriter, r *http.Request) { - - }) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - err = client.DeleteFile(ctx, "deadbeef") - checks.NoError(t, err, "DeleteFile error") -} - -func TestListFile(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/v1/files", func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, "{}") - }) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - _, err = client.ListFiles(ctx) - checks.NoError(t, err, "ListFiles error") -} - -func TestGetFile(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/v1/files/deadbeef", func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, "{}") - }) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - _, err = client.GetFile(ctx, "deadbeef") - checks.NoError(t, err, "GetFile error") -} - -func TestGetFileContent(t *testing.T) { - wantRespJsonl := `{"prompt": "foo", "completion": "foo"} -{"prompt": "bar", "completion": "bar"} -{"prompt": "baz", "completion": "baz"} -` - server := test.NewTestServer() - server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, r *http.Request) { - // edits only accepts GET requests - if r.Method != http.MethodGet { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - } - fmt.Fprint(w, wantRespJsonl) - }) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - content, err := client.GetFileContent(ctx, "deadbeef") - checks.NoError(t, err, "GetFileContent error") - defer content.Close() - - actual, _ := io.ReadAll(content) - if string(actual) != wantRespJsonl { - t.Errorf("Expected %s, got %s", wantRespJsonl, string(actual)) - } -} - -func TestGetFileContentReturnError(t *testing.T) { - wantMessage := "To help mitigate abuse, downloading of fine-tune training files is disabled for free accounts." - wantType := "invalid_request_error" - wantErrorResp := `{ - "error": { - "message": "` + wantMessage + `", - "type": "` + wantType + `", - "param": null, - "code": null - } -}` - server := test.NewTestServer() - server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusBadRequest) - fmt.Fprint(w, wantErrorResp) - }) - // create the test server - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - _, err := client.GetFileContent(ctx, "deadbeef") - if err == nil { - t.Fatal("Did not return error") - } - - apiErr := &APIError{} - if !errors.As(err, &apiErr) { - t.Fatalf("Did not return APIError: %+v\n", apiErr) - } - if apiErr.Message != wantMessage { - t.Fatalf("Expected %s Message, got = %s\n", wantMessage, apiErr.Message) - return - } - if apiErr.Type != wantType { - t.Fatalf("Expected %s Type, got = %s\n", wantType, apiErr.Type) - return - } -} - -func TestGetFileContentReturnTimeoutError(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, r *http.Request) { - time.Sleep(10 * time.Nanosecond) - }) - // create the test server - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - ctx, cancel := context.WithTimeout(ctx, time.Nanosecond) - defer cancel() - - _, err := client.GetFileContent(ctx, "deadbeef") - if err == nil { - t.Fatal("Did not return error") - } - if !os.IsTimeout(err) { - t.Fatal("Did not return timeout error") - } -} diff --git a/fine_tunes_test.go b/fine_tunes_test.go index c60254993..67f681d97 100644 --- a/fine_tunes_test.go +++ b/fine_tunes_test.go @@ -2,7 +2,6 @@ package openai_test import ( . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" "context" @@ -16,7 +15,8 @@ const testFineTuneID = "fine-tune-id" // TestFineTunes Tests the fine tunes endpoint of the API using the mocked server. func TestFineTunes(t *testing.T) { - server := test.NewTestServer() + client, server, teardown := setupOpenAITestServer() + defer teardown() server.RegisterHandler( "/v1/fine-tunes", func(w http.ResponseWriter, r *http.Request) { @@ -59,18 +59,9 @@ func TestFineTunes(t *testing.T) { }, ) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) ctx := context.Background() - _, err = client.ListFineTunes(ctx) + _, err := client.ListFineTunes(ctx) checks.NoError(t, err, "ListFineTunes error") _, err = client.CreateFineTune(ctx, FineTuneRequest{}) diff --git a/image_api_test.go b/image_api_test.go new file mode 100644 index 000000000..b472eb04a --- /dev/null +++ b/image_api_test.go @@ -0,0 +1,223 @@ +package openai_test + +import ( + . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" + + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "testing" + "time" +) + +func TestImages(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/images/generations", handleImageEndpoint) + _, err := client.CreateImage(context.Background(), ImageRequest{ + Prompt: "Lorem ipsum", + }) + checks.NoError(t, err, "CreateImage error") +} + +// handleImageEndpoint Handles the images endpoint by the test server. +func handleImageEndpoint(w http.ResponseWriter, r *http.Request) { + var err error + var resBytes []byte + + // imagess only accepts POST requests + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + var imageReq ImageRequest + if imageReq, err = getImageBody(r); err != nil { + http.Error(w, "could not read request", http.StatusInternalServerError) + return + } + res := ImageResponse{ + Created: time.Now().Unix(), + } + for i := 0; i < imageReq.N; i++ { + imageData := ImageResponseDataInner{} + switch imageReq.ResponseFormat { + case CreateImageResponseFormatURL, "": + imageData.URL = "https://example.com/image.png" + case CreateImageResponseFormatB64JSON: + // This decodes to "{}" in base64. + imageData.B64JSON = "e30K" + default: + http.Error(w, "invalid response format", http.StatusBadRequest) + return + } + res.Data = append(res.Data, imageData) + } + resBytes, _ = json.Marshal(res) + fmt.Fprintln(w, string(resBytes)) +} + +// getImageBody Returns the body of the request to create a image. +func getImageBody(r *http.Request) (ImageRequest, error) { + image := ImageRequest{} + // read the request body + reqBody, err := io.ReadAll(r.Body) + if err != nil { + return ImageRequest{}, err + } + err = json.Unmarshal(reqBody, &image) + if err != nil { + return ImageRequest{}, err + } + return image, nil +} + +func TestImageEdit(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint) + + origin, err := os.Create("image.png") + if err != nil { + t.Error("open origin file error") + return + } + + mask, err := os.Create("mask.png") + if err != nil { + t.Error("open mask file error") + return + } + + defer func() { + mask.Close() + origin.Close() + os.Remove("mask.png") + os.Remove("image.png") + }() + + _, err = client.CreateEditImage(context.Background(), ImageEditRequest{ + Image: origin, + Mask: mask, + Prompt: "There is a turtle in the pool", + N: 3, + Size: CreateImageSize1024x1024, + ResponseFormat: CreateImageResponseFormatURL, + }) + checks.NoError(t, err, "CreateImage error") +} + +func TestImageEditWithoutMask(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint) + + origin, err := os.Create("image.png") + if err != nil { + t.Error("open origin file error") + return + } + + defer func() { + origin.Close() + os.Remove("image.png") + }() + + _, err = client.CreateEditImage(context.Background(), ImageEditRequest{ + Image: origin, + Prompt: "There is a turtle in the pool", + N: 3, + Size: CreateImageSize1024x1024, + ResponseFormat: CreateImageResponseFormatURL, + }) + checks.NoError(t, err, "CreateImage error") +} + +// handleEditImageEndpoint Handles the images endpoint by the test server. +func handleEditImageEndpoint(w http.ResponseWriter, r *http.Request) { + var resBytes []byte + + // imagess only accepts POST requests + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + + responses := ImageResponse{ + Created: time.Now().Unix(), + Data: []ImageResponseDataInner{ + { + URL: "test-url1", + B64JSON: "", + }, + { + URL: "test-url2", + B64JSON: "", + }, + { + URL: "test-url3", + B64JSON: "", + }, + }, + } + + resBytes, _ = json.Marshal(responses) + fmt.Fprintln(w, string(resBytes)) +} + +func TestImageVariation(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/images/variations", handleVariateImageEndpoint) + + origin, err := os.Create("image.png") + if err != nil { + t.Error("open origin file error") + return + } + + defer func() { + origin.Close() + os.Remove("image.png") + }() + + _, err = client.CreateVariImage(context.Background(), ImageVariRequest{ + Image: origin, + N: 3, + Size: CreateImageSize1024x1024, + ResponseFormat: CreateImageResponseFormatURL, + }) + checks.NoError(t, err, "CreateImage error") +} + +// handleVariateImageEndpoint Handles the images endpoint by the test server. +func handleVariateImageEndpoint(w http.ResponseWriter, r *http.Request) { + var resBytes []byte + + // imagess only accepts POST requests + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + + responses := ImageResponse{ + Created: time.Now().Unix(), + Data: []ImageResponseDataInner{ + { + URL: "test-url1", + B64JSON: "", + }, + { + URL: "test-url2", + B64JSON: "", + }, + { + URL: "test-url3", + B64JSON: "", + }, + }, + } + + resBytes, _ = json.Marshal(responses) + fmt.Fprintln(w, string(resBytes)) +} diff --git a/image_test.go b/image_test.go index 521720f78..81fff6cba 100644 --- a/image_test.go +++ b/image_test.go @@ -1,343 +1,16 @@ package openai //nolint:testpackage // testing private field import ( - "bytes" - "math/rand" - "strings" - utils "github.com/sashabaranov/go-openai/internal" - "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" "context" - "encoding/json" "fmt" "io" - "net/http" "os" "testing" - "time" ) -func TestImages(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/v1/images/generations", handleImageEndpoint) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - req := ImageRequest{} - req.Prompt = "Lorem ipsum" - _, err = client.CreateImage(ctx, req) - checks.NoError(t, err, "CreateImage error") -} - -func TestAzureCreateImage(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/openai/images/generations:submit", handleImageEndpoint) - server.RegisterHandler("/openai/operations/images/request-id", handleImageCallbackEndpoint) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultAzureConfig(test.GetTestToken(), "https://dummylab.openai.azure.com/") - config.BaseURL = ts.URL - client := NewClientWithConfig(config) - ctx := context.Background() - - req := ImageRequest{} - req.Prompt = "Lorem ipsum" - req.ResponseFormat = CreateImageResponseFormatURL - req.N = 2 - _, err = client.CreateImage(ctx, req) - checks.NoError(t, err, "CreateImage error") -} - -// handleImageEndpoint Handles the images endpoint by the test server. -func handleImageEndpoint(w http.ResponseWriter, r *http.Request) { - var err error - var resBytes []byte - - // images only accepts POST requests - if r.Method != "POST" { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - return - } - var imageReq ImageRequest - if imageReq, err = getImageBody(r); err != nil { - http.Error(w, "could not read request", http.StatusInternalServerError) - return - } - // Azure Image Generation request - respond with callback Header only & HTTP accepted status. - if strings.Contains(r.RequestURI, "/openai/images/generations:submit") { - w.Header().Add("Operation-Location", "http://"+r.Host+"/openai/operations/images/request-id") - w.WriteHeader(http.StatusAccepted) - return - } - res := ImageResponse{ - Created: time.Now().Unix(), - } - for i := 0; i < imageReq.N; i++ { - imageData := ImageResponseDataInner{} - switch imageReq.ResponseFormat { - case CreateImageResponseFormatURL, "": - imageData.URL = "https://example.com/image.png" - case CreateImageResponseFormatB64JSON: - // This decodes to "{}" in base64. - imageData.B64JSON = "e30K" - default: - http.Error(w, "invalid response format", http.StatusBadRequest) - return - } - res.Data = append(res.Data, imageData) - } - resBytes, _ = json.Marshal(res) - fmt.Fprintln(w, string(resBytes)) -} - -// handleImageCallbackEndpoint Handles the callback endpoint by the test server. -func handleImageCallbackEndpoint(w http.ResponseWriter, r *http.Request) { - var err error - - // image callback only accepts GET requests - if r.Method != "GET" { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - return - } - - // Randomly set the status to Succeeded or running - status := "" - rand.Seed(time.Now().UnixNano()) - switch rand.Intn(3) { - case 0: - status = "Succeeded" - case 1: - status = "running" - case 2: - status = "notRunning" - } - - cbResponse := CallBackResponse{ - Created: time.Now().Unix(), - Status: status, - Result: CBResult{ - Data: CBData{ - {URL: "http://example.com/image1"}, - {URL: "http://example.com/image2"}, - }, - }, - } - cbResponseBytes := new(bytes.Buffer) - err = json.NewEncoder(cbResponseBytes).Encode(cbResponse) - if err != nil { - http.Error(w, "could not write repsonse", http.StatusInternalServerError) - return - } - fmt.Fprintln(w, cbResponseBytes.String()) -} - -// getImageBody Returns the body of the request to create a image. -func getImageBody(r *http.Request) (ImageRequest, error) { - image := ImageRequest{} - // read the request body - reqBody, err := io.ReadAll(r.Body) - if err != nil { - return ImageRequest{}, err - } - err = json.Unmarshal(reqBody, &image) - if err != nil { - return ImageRequest{}, err - } - return image, nil -} - -func TestImageEdit(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - origin, err := os.Create("image.png") - if err != nil { - t.Error("open origin file error") - return - } - - mask, err := os.Create("mask.png") - if err != nil { - t.Error("open mask file error") - return - } - - defer func() { - mask.Close() - origin.Close() - os.Remove("mask.png") - os.Remove("image.png") - }() - - req := ImageEditRequest{ - Image: origin, - Mask: mask, - Prompt: "There is a turtle in the pool", - N: 3, - Size: CreateImageSize1024x1024, - ResponseFormat: CreateImageResponseFormatURL, - } - _, err = client.CreateEditImage(ctx, req) - checks.NoError(t, err, "CreateImage error") -} - -func TestImageEditWithoutMask(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - origin, err := os.Create("image.png") - if err != nil { - t.Error("open origin file error") - return - } - - defer func() { - origin.Close() - os.Remove("image.png") - }() - - req := ImageEditRequest{ - Image: origin, - Prompt: "There is a turtle in the pool", - N: 3, - Size: CreateImageSize1024x1024, - ResponseFormat: CreateImageResponseFormatURL, - } - _, err = client.CreateEditImage(ctx, req) - checks.NoError(t, err, "CreateImage error") -} - -// handleEditImageEndpoint Handles the images endpoint by the test server. -func handleEditImageEndpoint(w http.ResponseWriter, r *http.Request) { - var resBytes []byte - - // imagess only accepts POST requests - if r.Method != "POST" { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - } - - responses := ImageResponse{ - Created: time.Now().Unix(), - Data: []ImageResponseDataInner{ - { - URL: "test-url1", - B64JSON: "", - }, - { - URL: "test-url2", - B64JSON: "", - }, - { - URL: "test-url3", - B64JSON: "", - }, - }, - } - - resBytes, _ = json.Marshal(responses) - fmt.Fprintln(w, string(resBytes)) -} - -func TestImageVariation(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/v1/images/variations", handleVariateImageEndpoint) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - origin, err := os.Create("image.png") - if err != nil { - t.Error("open origin file error") - return - } - - defer func() { - origin.Close() - os.Remove("image.png") - }() - - req := ImageVariRequest{ - Image: origin, - N: 3, - Size: CreateImageSize1024x1024, - ResponseFormat: CreateImageResponseFormatURL, - } - _, err = client.CreateVariImage(ctx, req) - checks.NoError(t, err, "CreateImage error") -} - -// handleVariateImageEndpoint Handles the images endpoint by the test server. -func handleVariateImageEndpoint(w http.ResponseWriter, r *http.Request) { - var resBytes []byte - - // imagess only accepts POST requests - if r.Method != "POST" { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - } - - responses := ImageResponse{ - Created: time.Now().Unix(), - Data: []ImageResponseDataInner{ - { - URL: "test-url1", - B64JSON: "", - }, - { - URL: "test-url2", - B64JSON: "", - }, - { - URL: "test-url3", - B64JSON: "", - }, - }, - } - - resBytes, _ = json.Marshal(responses) - fmt.Fprintln(w, string(resBytes)) -} - type mockFormBuilder struct { mockCreateFormFile func(string, *os.File) error mockCreateFormFileReader func(string, io.Reader, string) error diff --git a/models_test.go b/models_test.go index 834c849c4..0b4daf4a8 100644 --- a/models_test.go +++ b/models_test.go @@ -2,7 +2,6 @@ package openai_test import ( . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" "context" @@ -12,85 +11,47 @@ import ( "testing" ) -// TestListModels Tests the models endpoint of the API using the mocked server. +// TestListModels Tests the list models endpoint of the API using the mocked server. func TestListModels(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/v1/models", handleModelsEndpoint) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - _, err = client.ListModels(ctx) + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/models", handleListModelsEndpoint) + _, err := client.ListModels(context.Background()) checks.NoError(t, err, "ListModels error") } func TestAzureListModels(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/openai/models", handleModelsEndpoint) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultAzureConfig(test.GetTestToken(), "https://dummylab.openai.azure.com/") - config.BaseURL = ts.URL - client := NewClientWithConfig(config) - ctx := context.Background() - - _, err = client.ListModels(ctx) + client, server, teardown := setupAzureTestServer() + defer teardown() + server.RegisterHandler("/openai/models", handleListModelsEndpoint) + _, err := client.ListModels(context.Background()) checks.NoError(t, err, "ListModels error") } -// handleModelsEndpoint Handles the models endpoint by the test server. -func handleModelsEndpoint(w http.ResponseWriter, _ *http.Request) { +// handleListModelsEndpoint Handles the list models endpoint by the test server. +func handleListModelsEndpoint(w http.ResponseWriter, _ *http.Request) { resBytes, _ := json.Marshal(ModelsList{}) fmt.Fprintln(w, string(resBytes)) } // TestGetModel Tests the retrieve model endpoint of the API using the mocked server. func TestGetModel(t *testing.T) { - server := test.NewTestServer() + client, server, teardown := setupOpenAITestServer() + defer teardown() server.RegisterHandler("/v1/models/text-davinci-003", handleGetModelEndpoint) - // create the test server - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - _, err := client.GetModel(ctx, "text-davinci-003") + _, err := client.GetModel(context.Background(), "text-davinci-003") checks.NoError(t, err, "GetModel error") } func TestAzureGetModel(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/openai/models/text-davinci-003", handleModelsEndpoint) - // create the test server - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultAzureConfig(test.GetTestToken(), "https://dummylab.openai.azure.com/") - config.BaseURL = ts.URL - client := NewClientWithConfig(config) - ctx := context.Background() - - _, err := client.GetModel(ctx, "text-davinci-003") + client, server, teardown := setupAzureTestServer() + defer teardown() + server.RegisterHandler("/openai/models/text-davinci-003", handleGetModelEndpoint) + _, err := client.GetModel(context.Background(), "text-davinci-003") checks.NoError(t, err, "GetModel error") } -// handleModelsEndpoint Handles the models endpoint by the test server. +// handleGetModelsEndpoint Handles the get model endpoint by the test server. func handleGetModelEndpoint(w http.ResponseWriter, _ *http.Request) { resBytes, _ := json.Marshal(Model{}) fmt.Fprintln(w, string(resBytes)) diff --git a/moderation_test.go b/moderation_test.go index 2c1145627..4e756137e 100644 --- a/moderation_test.go +++ b/moderation_test.go @@ -2,7 +2,6 @@ package openai_test import ( . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" "context" @@ -18,26 +17,13 @@ import ( // TestModeration Tests the moderations endpoint of the API using the mocked server. func TestModerations(t *testing.T) { - server := test.NewTestServer() + client, server, teardown := setupOpenAITestServer() + defer teardown() server.RegisterHandler("/v1/moderations", handleModerationEndpoint) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - // create an edit request - model := "text-moderation-stable" - moderationReq := ModerationRequest{ - Model: model, + _, err := client.Moderations(context.Background(), ModerationRequest{ + Model: ModerationTextStable, Input: "I want to kill them.", - } - _, err = client.Moderations(ctx, moderationReq) + }) checks.NoError(t, err, "Moderation error") } diff --git a/openai_test.go b/openai_test.go new file mode 100644 index 000000000..a5e7b64ee --- /dev/null +++ b/openai_test.go @@ -0,0 +1,28 @@ +package openai_test + +import ( + . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test" +) + +func setupOpenAITestServer() (client *Client, server *test.ServerTest, teardown func()) { + server = test.NewTestServer() + ts := server.OpenAITestServer() + ts.Start() + teardown = ts.Close + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client = NewClientWithConfig(config) + return +} + +func setupAzureTestServer() (client *Client, server *test.ServerTest, teardown func()) { + server = test.NewTestServer() + ts := server.OpenAITestServer() + ts.Start() + teardown = ts.Close + config := DefaultAzureConfig(test.GetTestToken(), "https://dummylab.openai.azure.com/") + config.BaseURL = ts.URL + client = NewClientWithConfig(config) + return +} diff --git a/stream_reader_test.go b/stream_reader_test.go index 0e45c0b73..cd6e46eff 100644 --- a/stream_reader_test.go +++ b/stream_reader_test.go @@ -7,6 +7,8 @@ import ( "testing" utils "github.com/sashabaranov/go-openai/internal" + "github.com/sashabaranov/go-openai/internal/test" + "github.com/sashabaranov/go-openai/internal/test/checks" ) var errTestUnmarshalerFailed = errors.New("test unmarshaler failed") @@ -47,7 +49,17 @@ func TestStreamReaderReturnsErrTooManyEmptyStreamMessages(t *testing.T) { unmarshaler: &utils.JSONUnmarshaler{}, } _, err := stream.Recv() - if !errors.Is(err, ErrTooManyEmptyStreamMessages) { - t.Fatalf("Did not return error when recv failed: %v", err) + checks.ErrorIs(t, err, ErrTooManyEmptyStreamMessages, "Did not return error when recv failed", err.Error()) +} + +func TestStreamReaderReturnsErrTestErrorAccumulatorWriteFailed(t *testing.T) { + stream := &streamReader[ChatCompletionStreamResponse]{ + reader: bufio.NewReader(bytes.NewReader([]byte("\n"))), + errAccumulator: &utils.DefaultErrorAccumulator{ + Buffer: &test.FailingErrorBuffer{}, + }, + unmarshaler: &utils.JSONUnmarshaler{}, } + _, err := stream.Recv() + checks.ErrorIs(t, err, test.ErrTestErrorAccumulatorWriteFailed, "Did not return error when write failed", err.Error()) } diff --git a/stream_test.go b/stream_test.go index 0faa21222..5997f27e8 100644 --- a/stream_test.go +++ b/stream_test.go @@ -6,11 +6,9 @@ import ( "errors" "io" "net/http" - "net/http/httptest" "testing" . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" ) @@ -32,7 +30,9 @@ func TestCompletionsStreamWrongModel(t *testing.T) { } func TestCreateCompletionStream(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") // Send test responses @@ -52,28 +52,14 @@ func TestCreateCompletionStream(t *testing.T) { _, err := w.Write(dataBytes) checks.NoError(t, err, "Write error") - })) - defer server.Close() - - // Client portion of the test - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = server.URL + "/v1" - config.HTTPClient.Transport = &test.TokenRoundTripper{ - Token: test.GetTestToken(), - Fallback: http.DefaultTransport, - } - - client := NewClientWithConfig(config) - ctx := context.Background() + }) - request := CompletionRequest{ + stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{ Prompt: "Ex falso quodlibet", Model: "text-davinci-002", MaxTokens: 10, Stream: true, - } - - stream, err := client.CreateCompletionStream(ctx, request) + }) checks.NoError(t, err, "CreateCompletionStream returned error") defer stream.Close() @@ -116,7 +102,9 @@ func TestCreateCompletionStream(t *testing.T) { } func TestCreateCompletionStreamError(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") // Send test responses @@ -137,28 +125,14 @@ func TestCreateCompletionStreamError(t *testing.T) { _, err := w.Write(dataBytes) checks.NoError(t, err, "Write error") - })) - defer server.Close() - - // Client portion of the test - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = server.URL + "/v1" - config.HTTPClient.Transport = &test.TokenRoundTripper{ - Token: test.GetTestToken(), - Fallback: http.DefaultTransport, - } - - client := NewClientWithConfig(config) - ctx := context.Background() + }) - request := CompletionRequest{ + stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{ MaxTokens: 5, Model: GPT3TextDavinci003, Prompt: "Hello!", Stream: true, - } - - stream, err := client.CreateCompletionStream(ctx, request) + }) checks.NoError(t, err, "CreateCompletionStream returned error") defer stream.Close() @@ -173,7 +147,8 @@ func TestCreateCompletionStreamError(t *testing.T) { } func TestCreateCompletionStreamRateLimitError(t *testing.T) { - server := test.NewTestServer() + client, server, teardown := setupOpenAITestServer() + defer teardown() server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(429) @@ -188,30 +163,14 @@ func TestCreateCompletionStreamRateLimitError(t *testing.T) { _, err := w.Write(dataBytes) checks.NoError(t, err, "Write error") }) - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - // Client portion of the test - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - config.HTTPClient.Transport = &test.TokenRoundTripper{ - Token: test.GetTestToken(), - Fallback: http.DefaultTransport, - } - - client := NewClientWithConfig(config) - ctx := context.Background() - request := CompletionRequest{ + var apiErr *APIError + _, err := client.CreateCompletionStream(context.Background(), CompletionRequest{ MaxTokens: 5, Model: GPT3Ada, Prompt: "Hello!", Stream: true, - } - - var apiErr *APIError - _, err := client.CreateCompletionStream(ctx, request) + }) if !errors.As(err, &apiErr) { t.Errorf("TestCreateCompletionStreamRateLimitError did not return APIError") } @@ -219,7 +178,9 @@ func TestCreateCompletionStreamRateLimitError(t *testing.T) { } func TestCreateCompletionStreamTooManyEmptyStreamMessagesError(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") // Send test responses @@ -244,28 +205,14 @@ func TestCreateCompletionStreamTooManyEmptyStreamMessagesError(t *testing.T) { _, err := w.Write(dataBytes) checks.NoError(t, err, "Write error") - })) - defer server.Close() - - // Client portion of the test - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = server.URL + "/v1" - config.HTTPClient.Transport = &test.TokenRoundTripper{ - Token: test.GetTestToken(), - Fallback: http.DefaultTransport, - } - - client := NewClientWithConfig(config) - ctx := context.Background() + }) - request := CompletionRequest{ + stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{ Prompt: "Ex falso quodlibet", Model: "text-davinci-002", MaxTokens: 10, Stream: true, - } - - stream, err := client.CreateCompletionStream(ctx, request) + }) checks.NoError(t, err, "CreateCompletionStream returned error") defer stream.Close() @@ -277,7 +224,9 @@ func TestCreateCompletionStreamTooManyEmptyStreamMessagesError(t *testing.T) { } func TestCreateCompletionStreamUnexpectedTerminatedError(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") // Send test responses @@ -291,28 +240,14 @@ func TestCreateCompletionStreamUnexpectedTerminatedError(t *testing.T) { _, err := w.Write(dataBytes) checks.NoError(t, err, "Write error") - })) - defer server.Close() - - // Client portion of the test - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = server.URL + "/v1" - config.HTTPClient.Transport = &test.TokenRoundTripper{ - Token: test.GetTestToken(), - Fallback: http.DefaultTransport, - } - - client := NewClientWithConfig(config) - ctx := context.Background() + }) - request := CompletionRequest{ + stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{ Prompt: "Ex falso quodlibet", Model: "text-davinci-002", MaxTokens: 10, Stream: true, - } - - stream, err := client.CreateCompletionStream(ctx, request) + }) checks.NoError(t, err, "CreateCompletionStream returned error") defer stream.Close() @@ -324,7 +259,9 @@ func TestCreateCompletionStreamUnexpectedTerminatedError(t *testing.T) { } func TestCreateCompletionStreamBrokenJSONError(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") // Send test responses @@ -344,28 +281,14 @@ func TestCreateCompletionStreamBrokenJSONError(t *testing.T) { _, err := w.Write(dataBytes) checks.NoError(t, err, "Write error") - })) - defer server.Close() - - // Client portion of the test - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = server.URL + "/v1" - config.HTTPClient.Transport = &test.TokenRoundTripper{ - Token: test.GetTestToken(), - Fallback: http.DefaultTransport, - } - - client := NewClientWithConfig(config) - ctx := context.Background() + }) - request := CompletionRequest{ + stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{ Prompt: "Ex falso quodlibet", Model: "text-davinci-002", MaxTokens: 10, Stream: true, - } - - stream, err := client.CreateCompletionStream(ctx, request) + }) checks.NoError(t, err, "CreateCompletionStream returned error") defer stream.Close()