Skip to content

Commit

Permalink
Refactor/internal testing (sashabaranov#194)
Browse files Browse the repository at this point in the history
* added NoError check

* corrected NoError

* has error checks

* replace more checks

* Used checks test helper

* Used checks test helper

* remove duplicate import

* fixed lint issues regarding length of messages

---------

Co-authored-by: Rex Posadas <rposadas@redwoodlogistics.com>
  • Loading branch information
rexposadas and Rex Posadas authored Mar 24, 2023
1 parent 479dab3 commit 8e3a046
Show file tree
Hide file tree
Showing 15 changed files with 115 additions and 140 deletions.
42 changes: 11 additions & 31 deletions api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package openai_test

import (
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"

"context"
"errors"
Expand All @@ -20,25 +21,17 @@ func TestAPI(t *testing.T) {
c := NewClient(apiToken)
ctx := context.Background()
_, err = c.ListEngines(ctx)
if err != nil {
t.Fatalf("ListEngines error: %v", err)
}
checks.NoError(t, err, "ListEngines error")

_, err = c.GetEngine(ctx, "davinci")
if err != nil {
t.Fatalf("GetEngine error: %v", err)
}
checks.NoError(t, err, "GetEngine error")

fileRes, err := c.ListFiles(ctx)
if err != nil {
t.Fatalf("ListFiles error: %v", err)
}
checks.NoError(t, err, "ListFiles error")

if len(fileRes.Files) > 0 {
_, err = c.GetFile(ctx, fileRes.Files[0].ID)
if err != nil {
t.Fatalf("GetFile error: %v", err)
}
checks.NoError(t, err, "GetFile error")
} // else skip

embeddingReq := EmbeddingRequest{
Expand All @@ -49,9 +42,7 @@ func TestAPI(t *testing.T) {
Model: AdaSearchQuery,
}
_, err = c.CreateEmbeddings(ctx, embeddingReq)
if err != nil {
t.Fatalf("Embedding error: %v", err)
}
checks.NoError(t, err, "Embedding error")

_, err = c.CreateChatCompletion(
ctx,
Expand All @@ -66,9 +57,7 @@ func TestAPI(t *testing.T) {
},
)

if err != nil {
t.Errorf("CreateChatCompletion (without name) returned error: %v", err)
}
checks.NoError(t, err, "CreateChatCompletion (without name) returned error")

_, err = c.CreateChatCompletion(
ctx,
Expand All @@ -83,20 +72,15 @@ func TestAPI(t *testing.T) {
},
},
)

if err != nil {
t.Errorf("CreateChatCompletion (with name) returned error: %v", err)
}
checks.NoError(t, err, "CreateChatCompletion (with name) returned error")

stream, err := c.CreateCompletionStream(ctx, CompletionRequest{
Prompt: "Ex falso quodlibet",
Model: GPT3Ada,
MaxTokens: 5,
Stream: true,
})
if err != nil {
t.Errorf("CreateCompletionStream returned error: %v", err)
}
checks.NoError(t, err, "CreateCompletionStream returned error")
defer stream.Close()

counter := 0
Expand Down Expand Up @@ -126,9 +110,7 @@ func TestAPIError(t *testing.T) {
c := NewClient(apiToken + "_invalid")
ctx := context.Background()
_, err = c.ListEngines(ctx)
if err == nil {
t.Fatal("ListEngines did not fail")
}
checks.NoError(t, err, "ListEngines did not fail")

var apiErr *APIError
if !errors.As(err, &apiErr) {
Expand All @@ -154,9 +136,7 @@ func TestRequestError(t *testing.T) {
c := NewClientWithConfig(config)
ctx := context.Background()
_, err = c.ListEngines(ctx)
if err == nil {
t.Fatal("ListEngines request did not fail")
}
checks.HasError(t, err, "ListEngines did not fail")

var reqErr *RequestError
if !errors.As(err, &reqErr) {
Expand Down
18 changes: 6 additions & 12 deletions audio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"

"context"
"testing"
Expand Down Expand Up @@ -62,9 +63,7 @@ func TestAudio(t *testing.T) {
Model: "whisper-3",
}
_, err = tc.createFn(ctx, req)
if err != nil {
t.Fatalf("audio API error: %v", err)
}
checks.NoError(t, err, "audio API error")
})
}
}
Expand Down Expand Up @@ -115,19 +114,16 @@ func TestAudioWithOptionalArgs(t *testing.T) {
Language: "zh",
}
_, err = tc.createFn(ctx, req)
if err != nil {
t.Fatalf("audio API error: %v", err)
}
checks.NoError(t, err, "audio API error")
})
}
}

// createTestFile creates a fake file with "hello" as the content.
func createTestFile(t *testing.T, path string) {
file, err := os.Create(path)
if err != nil {
t.Fatalf("failed to create file %v", err)
}
checks.NoError(t, err, "failed to create file")

if _, err = file.WriteString("hello"); err != nil {
t.Fatalf("failed to write to file %v", err)
}
Expand All @@ -139,9 +135,7 @@ func createTestDirectory(t *testing.T) (path string, cleanup func()) {
t.Helper()

path, err := os.MkdirTemp(os.TempDir(), "")
if err != nil {
t.Fatal(err)
}
checks.NoError(t, err)

return path, func() { os.RemoveAll(path) }
}
Expand Down
28 changes: 10 additions & 18 deletions chat_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package openai_test
import (
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"

"context"
"encoding/json"
Expand Down Expand Up @@ -55,9 +56,7 @@ func TestCreateChatCompletionStream(t *testing.T) {
dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...)

_, err := w.Write(dataBytes)
if err != nil {
t.Errorf("Write error: %s", err)
}
checks.NoError(t, err, "Write error")
}))
defer server.Close()

Expand Down Expand Up @@ -85,9 +84,7 @@ func TestCreateChatCompletionStream(t *testing.T) {
}

stream, err := client.CreateChatCompletionStream(ctx, request)
if err != nil {
t.Errorf("CreateCompletionStream returned error: %v", err)
}
checks.NoError(t, err, "CreateCompletionStream returned error")
defer stream.Close()

expectedResponses := []ChatCompletionStreamResponse{
Expand Down Expand Up @@ -126,9 +123,7 @@ func TestCreateChatCompletionStream(t *testing.T) {
t.Logf("%d: %s", ix, string(b))

receivedResponse, streamErr := stream.Recv()
if streamErr != nil {
t.Errorf("stream.Recv() failed: %v", streamErr)
}
checks.NoError(t, streamErr, "stream.Recv() failed")
if !compareChatResponses(expectedResponse, receivedResponse) {
t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse)
}
Expand All @@ -140,6 +135,8 @@ func TestCreateChatCompletionStream(t *testing.T) {
}

_, streamErr = stream.Recv()

checks.ErrorIs(t, streamErr, io.EOF, "stream.Recv() did not return EOF when the stream is finished")
if !errors.Is(streamErr, io.EOF) {
t.Errorf("stream.Recv() did not return EOF when the stream is finished: %v", streamErr)
}
Expand All @@ -166,9 +163,7 @@ func TestCreateChatCompletionStreamError(t *testing.T) {
}

_, err := w.Write(dataBytes)
if err != nil {
t.Errorf("Write error: %s", err)
}
checks.NoError(t, err, "Write error")
}))
defer server.Close()

Expand Down Expand Up @@ -196,15 +191,12 @@ func TestCreateChatCompletionStreamError(t *testing.T) {
}

stream, err := client.CreateChatCompletionStream(ctx, request)
if err != nil {
t.Errorf("CreateCompletionStream returned error: %v", err)
}
checks.NoError(t, err, "CreateCompletionStream returned error")
defer stream.Close()

_, streamErr := stream.Recv()
if streamErr == nil {
t.Errorf("stream.Recv() did not return error")
}
checks.HasError(t, streamErr, "stream.Recv() did not return error")

var apiErr *APIError
if !errors.As(streamErr, &apiErr) {
t.Errorf("stream.Recv() did not return APIError")
Expand Down
15 changes: 5 additions & 10 deletions chat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ package openai_test
import (
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"

"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
Expand All @@ -33,9 +33,8 @@ func TestChatCompletionsWrongModel(t *testing.T) {
},
}
_, err := client.CreateChatCompletion(ctx, req)
if !errors.Is(err, ErrChatCompletionInvalidModel) {
t.Fatalf("CreateChatCompletion should return ErrChatCompletionInvalidModel, but returned: %v", err)
}
msg := fmt.Sprintf("CreateChatCompletion should return wrong model error, returned: %s", err)
checks.ErrorIs(t, err, ErrChatCompletionInvalidModel, msg)
}

func TestChatCompletionsWithStream(t *testing.T) {
Expand All @@ -48,9 +47,7 @@ func TestChatCompletionsWithStream(t *testing.T) {
Stream: true,
}
_, err := client.CreateChatCompletion(ctx, req)
if !errors.Is(err, ErrChatCompletionStreamNotSupported) {
t.Fatalf("CreateChatCompletion didn't return ErrChatCompletionStreamNotSupported error")
}
checks.ErrorIs(t, err, ErrChatCompletionStreamNotSupported, "unexpected error")
}

// TestCompletions Tests the completions endpoint of the API using the mocked server.
Expand Down Expand Up @@ -79,9 +76,7 @@ func TestChatCompletions(t *testing.T) {
},
}
_, err = client.CreateChatCompletion(ctx, req)
if err != nil {
t.Fatalf("CreateChatCompletion error: %v", err)
}
checks.NoError(t, err, "CreateChatCompletion error")
}

// handleChatCompletionEndpoint Handles the ChatGPT completion endpoint by the test server.
Expand Down
5 changes: 2 additions & 3 deletions completion_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package openai_test
import (
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"

"context"
"encoding/json"
Expand Down Expand Up @@ -66,9 +67,7 @@ func TestCompletions(t *testing.T) {
}
req.Prompt = "Lorem ipsum"
_, err = client.CreateCompletion(ctx, req)
if err != nil {
t.Fatalf("CreateCompletion error: %v", err)
}
checks.NoError(t, err, "CreateCompletion error")
}

// handleCompletionEndpoint Handles the completion endpoint by the test server.
Expand Down
5 changes: 2 additions & 3 deletions edits_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package openai_test
import (
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"

"context"
"encoding/json"
Expand Down Expand Up @@ -40,9 +41,7 @@ func TestEdits(t *testing.T) {
N: 3,
}
response, err := client.Edits(ctx, editReq)
if err != nil {
t.Fatalf("Edits error: %v", err)
}
checks.NoError(t, err, "Edits error")
if len(response.Choices) != editReq.N {
t.Fatalf("edits does not properly return the correct number of choices")
}
Expand Down
5 changes: 2 additions & 3 deletions embeddings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package openai_test

import (
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"

"bytes"
"encoding/json"
Expand Down Expand Up @@ -38,9 +39,7 @@ func TestEmbedding(t *testing.T) {
// marshal embeddingReq to JSON and confirm that the model field equals
// the AdaSearchQuery type
marshaled, err := json.Marshal(embeddingReq)
if err != nil {
t.Fatalf("Could not marshal embedding request: %v", err)
}
checks.NoError(t, err, "Could not marshal embedding request")
if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) {
t.Fatalf("Expected embedding request to contain model field")
}
Expand Down
10 changes: 4 additions & 6 deletions error_accumulator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"testing"

"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"
)

var (
Expand Down Expand Up @@ -81,16 +82,13 @@ func TestErrorAccumulatorWriteErrors(t *testing.T) {
ctx := context.Background()

stream, err := client.CreateChatCompletionStream(ctx, ChatCompletionRequest{})
if err != nil {
t.Fatal(err)
}
checks.NoError(t, err)

stream.errAccumulator = &defaultErrorAccumulator{
buffer: &failingErrorBuffer{},
unmarshaler: &jsonUnmarshaler{},
}

_, err = stream.Recv()
if !errors.Is(err, errTestErrorAccumulatorWriteFailed) {
t.Fatalf("Did not return error when write failed: %v", err)
}
checks.ErrorIs(t, err, errTestErrorAccumulatorWriteFailed, "Did not return error when write failed", err.Error())
}
5 changes: 2 additions & 3 deletions files_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package openai_test
import (
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"

"context"
"encoding/json"
Expand Down Expand Up @@ -33,9 +34,7 @@ func TestFileUpload(t *testing.T) {
Purpose: "fine-tune",
}
_, err = client.CreateFile(ctx, req)
if err != nil {
t.Fatalf("CreateFile error: %v", err)
}
checks.NoError(t, err, "CreateFile erro")
}

// handleCreateFile Handles the images endpoint by the test server.
Expand Down
Loading

0 comments on commit 8e3a046

Please sign in to comment.