forked from sashabaranov/go-openai
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #17 from sashabaranov/master
extract and split integration tests (sashabaranov#389)
- Loading branch information
Showing
5 changed files
with
357 additions
and
353 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
package openai_test | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"io" | ||
"os" | ||
"testing" | ||
|
||
. "github.com/sashabaranov/go-openai" | ||
"github.com/sashabaranov/go-openai/internal/test/checks" | ||
) | ||
|
||
func TestAPI(t *testing.T) { | ||
apiToken := os.Getenv("OPENAI_TOKEN") | ||
if apiToken == "" { | ||
t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.") | ||
} | ||
|
||
var err error | ||
c := NewClient(apiToken) | ||
ctx := context.Background() | ||
_, err = c.ListEngines(ctx) | ||
checks.NoError(t, err, "ListEngines error") | ||
|
||
_, err = c.GetEngine(ctx, "davinci") | ||
checks.NoError(t, err, "GetEngine error") | ||
|
||
fileRes, err := c.ListFiles(ctx) | ||
checks.NoError(t, err, "ListFiles error") | ||
|
||
if len(fileRes.Files) > 0 { | ||
_, err = c.GetFile(ctx, fileRes.Files[0].ID) | ||
checks.NoError(t, err, "GetFile error") | ||
} // else skip | ||
|
||
embeddingReq := EmbeddingRequest{ | ||
Input: []string{ | ||
"The food was delicious and the waiter", | ||
"Other examples of embedding request", | ||
}, | ||
Model: AdaSearchQuery, | ||
} | ||
_, err = c.CreateEmbeddings(ctx, embeddingReq) | ||
checks.NoError(t, err, "Embedding error") | ||
|
||
_, err = c.CreateChatCompletion( | ||
ctx, | ||
ChatCompletionRequest{ | ||
Model: GPT3Dot5Turbo, | ||
Messages: []ChatCompletionMessage{ | ||
{ | ||
Role: ChatMessageRoleUser, | ||
Content: "Hello!", | ||
}, | ||
}, | ||
}, | ||
) | ||
|
||
checks.NoError(t, err, "CreateChatCompletion (without name) returned error") | ||
|
||
_, err = c.CreateChatCompletion( | ||
ctx, | ||
ChatCompletionRequest{ | ||
Model: GPT3Dot5Turbo, | ||
Messages: []ChatCompletionMessage{ | ||
{ | ||
Role: ChatMessageRoleUser, | ||
Name: "John_Doe", | ||
Content: "Hello!", | ||
}, | ||
}, | ||
}, | ||
) | ||
checks.NoError(t, err, "CreateChatCompletion (with name) returned error") | ||
|
||
stream, err := c.CreateCompletionStream(ctx, CompletionRequest{ | ||
Prompt: "Ex falso quodlibet", | ||
Model: GPT3Ada, | ||
MaxTokens: 5, | ||
Stream: true, | ||
}) | ||
checks.NoError(t, err, "CreateCompletionStream returned error") | ||
defer stream.Close() | ||
|
||
counter := 0 | ||
for { | ||
_, err = stream.Recv() | ||
if err != nil { | ||
if errors.Is(err, io.EOF) { | ||
break | ||
} | ||
t.Errorf("Stream error: %v", err) | ||
} else { | ||
counter++ | ||
} | ||
} | ||
if counter == 0 { | ||
t.Error("Stream did not return any responses") | ||
} | ||
} | ||
|
||
func TestAPIError(t *testing.T) { | ||
apiToken := os.Getenv("OPENAI_TOKEN") | ||
if apiToken == "" { | ||
t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.") | ||
} | ||
|
||
var err error | ||
c := NewClient(apiToken + "_invalid") | ||
ctx := context.Background() | ||
_, err = c.ListEngines(ctx) | ||
checks.HasError(t, err, "ListEngines should fail with an invalid key") | ||
|
||
var apiErr *APIError | ||
if !errors.As(err, &apiErr) { | ||
t.Fatalf("Error is not an APIError: %+v", err) | ||
} | ||
|
||
if apiErr.HTTPStatusCode != 401 { | ||
t.Fatalf("Unexpected API error status code: %d", apiErr.HTTPStatusCode) | ||
} | ||
|
||
switch v := apiErr.Code.(type) { | ||
case string: | ||
if v != "invalid_api_key" { | ||
t.Fatalf("Unexpected API error code: %s", v) | ||
} | ||
default: | ||
t.Fatalf("Unexpected API error code type: %T", v) | ||
} | ||
|
||
if apiErr.Error() == "" { | ||
t.Fatal("Empty error message occurred") | ||
} | ||
} |
Oops, something went wrong.