From 7fc065e794ceac3e5f478882def03872a3653d53 Mon Sep 17 00:00:00 2001 From: vvatanabe Date: Sun, 18 Jun 2023 05:44:09 +0900 Subject: [PATCH] extract and split integration tests --- api_integration_test.go | 136 ++++++++++++++++ api_test.go | 353 ---------------------------------------- engines_test.go | 11 ++ error_test.go | 201 +++++++++++++++++++++++ openai_test.go | 9 + 5 files changed, 357 insertions(+), 353 deletions(-) create mode 100644 api_integration_test.go delete mode 100644 api_test.go create mode 100644 error_test.go diff --git a/api_integration_test.go b/api_integration_test.go new file mode 100644 index 000000000..3cafa24b4 --- /dev/null +++ b/api_integration_test.go @@ -0,0 +1,136 @@ +package openai_test + +import ( + "context" + "errors" + "io" + "os" + "testing" + + . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +func TestAPI(t *testing.T) { + apiToken := os.Getenv("OPENAI_TOKEN") + if apiToken == "" { + t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.") + } + + var err error + c := NewClient(apiToken) + ctx := context.Background() + _, err = c.ListEngines(ctx) + checks.NoError(t, err, "ListEngines error") + + _, err = c.GetEngine(ctx, "davinci") + checks.NoError(t, err, "GetEngine error") + + fileRes, err := c.ListFiles(ctx) + checks.NoError(t, err, "ListFiles error") + + if len(fileRes.Files) > 0 { + _, err = c.GetFile(ctx, fileRes.Files[0].ID) + checks.NoError(t, err, "GetFile error") + } // else skip + + embeddingReq := EmbeddingRequest{ + Input: []string{ + "The food was delicious and the waiter", + "Other examples of embedding request", + }, + Model: AdaSearchQuery, + } + _, err = c.CreateEmbeddings(ctx, embeddingReq) + checks.NoError(t, err, "Embedding error") + + _, err = c.CreateChatCompletion( + ctx, + ChatCompletionRequest{ + Model: GPT3Dot5Turbo, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }, + ) + + checks.NoError(t, err, "CreateChatCompletion (without name) returned error") + + _, err = c.CreateChatCompletion( + ctx, + ChatCompletionRequest{ + Model: GPT3Dot5Turbo, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Name: "John_Doe", + Content: "Hello!", + }, + }, + }, + ) + checks.NoError(t, err, "CreateChatCompletion (with name) returned error") + + stream, err := c.CreateCompletionStream(ctx, CompletionRequest{ + Prompt: "Ex falso quodlibet", + Model: GPT3Ada, + MaxTokens: 5, + Stream: true, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + counter := 0 + for { + _, err = stream.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + t.Errorf("Stream error: %v", err) + } else { + counter++ + } + } + if counter == 0 { + t.Error("Stream did not return any responses") + } +} + +func TestAPIError(t *testing.T) { + apiToken := os.Getenv("OPENAI_TOKEN") + if apiToken == "" { + t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.") + } + + var err error + c := NewClient(apiToken + "_invalid") + ctx := context.Background() + _, err = c.ListEngines(ctx) + checks.HasError(t, err, "ListEngines should fail with an invalid key") + + var apiErr *APIError + if !errors.As(err, &apiErr) { + t.Fatalf("Error is not an APIError: %+v", err) + } + + if apiErr.HTTPStatusCode != 401 { + t.Fatalf("Unexpected API error status code: %d", apiErr.HTTPStatusCode) + } + + switch v := apiErr.Code.(type) { + case string: + if v != "invalid_api_key" { + t.Fatalf("Unexpected API error code: %s", v) + } + default: + t.Fatalf("Unexpected API error code type: %T", v) + } + + if apiErr.Error() == "" { + t.Fatal("Empty error message occurred") + } +} diff --git a/api_test.go b/api_test.go deleted file mode 100644 index 34173708f..000000000 --- a/api_test.go +++ /dev/null @@ -1,353 +0,0 @@ -package openai_test - -import ( - "context" - "encoding/json" - "errors" - "io" - "net/http" - "os" - "testing" - - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" -) - -func TestAPI(t *testing.T) { - apiToken := os.Getenv("OPENAI_TOKEN") - if apiToken == "" { - t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.") - } - - var err error - c := NewClient(apiToken) - ctx := context.Background() - _, err = c.ListEngines(ctx) - checks.NoError(t, err, "ListEngines error") - - _, err = c.GetEngine(ctx, "davinci") - checks.NoError(t, err, "GetEngine error") - - fileRes, err := c.ListFiles(ctx) - checks.NoError(t, err, "ListFiles error") - - if len(fileRes.Files) > 0 { - _, err = c.GetFile(ctx, fileRes.Files[0].ID) - checks.NoError(t, err, "GetFile error") - } // else skip - - embeddingReq := EmbeddingRequest{ - Input: []string{ - "The food was delicious and the waiter", - "Other examples of embedding request", - }, - Model: AdaSearchQuery, - } - _, err = c.CreateEmbeddings(ctx, embeddingReq) - checks.NoError(t, err, "Embedding error") - - _, err = c.CreateChatCompletion( - ctx, - ChatCompletionRequest{ - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ - { - Role: ChatMessageRoleUser, - Content: "Hello!", - }, - }, - }, - ) - - checks.NoError(t, err, "CreateChatCompletion (without name) returned error") - - _, err = c.CreateChatCompletion( - ctx, - ChatCompletionRequest{ - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ - { - Role: ChatMessageRoleUser, - Name: "John_Doe", - Content: "Hello!", - }, - }, - }, - ) - checks.NoError(t, err, "CreateChatCompletion (with name) returned error") - - stream, err := c.CreateCompletionStream(ctx, CompletionRequest{ - Prompt: "Ex falso quodlibet", - Model: GPT3Ada, - MaxTokens: 5, - Stream: true, - }) - checks.NoError(t, err, "CreateCompletionStream returned error") - defer stream.Close() - - counter := 0 - for { - _, err = stream.Recv() - if err != nil { - if errors.Is(err, io.EOF) { - break - } - t.Errorf("Stream error: %v", err) - } else { - counter++ - } - } - if counter == 0 { - t.Error("Stream did not return any responses") - } -} - -func TestAPIError(t *testing.T) { - apiToken := os.Getenv("OPENAI_TOKEN") - if apiToken == "" { - t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.") - } - - var err error - c := NewClient(apiToken + "_invalid") - ctx := context.Background() - _, err = c.ListEngines(ctx) - checks.HasError(t, err, "ListEngines should fail with an invalid key") - - var apiErr *APIError - if !errors.As(err, &apiErr) { - t.Fatalf("Error is not an APIError: %+v", err) - } - - if apiErr.HTTPStatusCode != 401 { - t.Fatalf("Unexpected API error status code: %d", apiErr.HTTPStatusCode) - } - - switch v := apiErr.Code.(type) { - case string: - if v != "invalid_api_key" { - t.Fatalf("Unexpected API error code: %s", v) - } - default: - t.Fatalf("Unexpected API error code type: %T", v) - } - - if apiErr.Error() == "" { - t.Fatal("Empty error message occurred") - } -} - -func TestAPIErrorUnmarshalJSONMessageField(t *testing.T) { - type testCase struct { - name string - response string - hasError bool - checkFn func(t *testing.T, apiErr APIError) - } - testCases := []testCase{ - { - name: "parse succeeds when the message is string", - response: `{"message":"foo","type":"invalid_request_error","param":null,"code":null}`, - hasError: false, - checkFn: func(t *testing.T, apiErr APIError) { - expected := "foo" - if apiErr.Message != expected { - t.Fatalf("Unexpected API message: %v; expected: %s", apiErr, expected) - } - }, - }, - { - name: "parse succeeds when the message is array with single item", - response: `{"message":["foo"],"type":"invalid_request_error","param":null,"code":null}`, - hasError: false, - checkFn: func(t *testing.T, apiErr APIError) { - expected := "foo" - if apiErr.Message != expected { - t.Fatalf("Unexpected API message: %v; expected: %s", apiErr, expected) - } - }, - }, - { - name: "parse succeeds when the message is array with multiple items", - response: `{"message":["foo", "bar", "baz"],"type":"invalid_request_error","param":null,"code":null}`, - hasError: false, - checkFn: func(t *testing.T, apiErr APIError) { - expected := "foo, bar, baz" - if apiErr.Message != expected { - t.Fatalf("Unexpected API message: %v; expected: %s", apiErr, expected) - } - }, - }, - { - name: "parse succeeds when the message is empty array", - response: `{"message":[],"type":"invalid_request_error","param":null,"code":null}`, - hasError: false, - checkFn: func(t *testing.T, apiErr APIError) { - if apiErr.Message != "" { - t.Fatalf("Unexpected API message: %v; expected: empty", apiErr) - } - }, - }, - { - name: "parse succeeds when the message is null", - response: `{"message":null,"type":"invalid_request_error","param":null,"code":null}`, - hasError: false, - checkFn: func(t *testing.T, apiErr APIError) { - if apiErr.Message != "" { - t.Fatalf("Unexpected API message: %v; expected: empty", apiErr) - } - }, - }, - { - name: "parse failed when the message is object", - response: `{"message":{},"type":"invalid_request_error","param":null,"code":null}`, - hasError: true, - }, - { - name: "parse failed when the message is int", - response: `{"message":1,"type":"invalid_request_error","param":null,"code":null}`, - hasError: true, - }, - { - name: "parse failed when the message is float", - response: `{"message":0.1,"type":"invalid_request_error","param":null,"code":null}`, - hasError: true, - }, - { - name: "parse failed when the message is bool", - response: `{"message":true,"type":"invalid_request_error","param":null,"code":null}`, - hasError: true, - }, - { - name: "parse failed when the message is not exists", - response: `{"type":"invalid_request_error","param":null,"code":null}`, - hasError: true, - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - var apiErr APIError - err := json.Unmarshal([]byte(tc.response), &apiErr) - if (err != nil) != tc.hasError { - t.Errorf("Unexpected error: %v", err) - return - } - if tc.checkFn != nil { - tc.checkFn(t, apiErr) - } - }) - } -} - -func TestAPIErrorUnmarshalJSONInteger(t *testing.T) { - var apiErr APIError - response := `{"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}` - err := json.Unmarshal([]byte(response), &apiErr) - checks.NoError(t, err, "Unexpected Unmarshal API response error") - - switch v := apiErr.Code.(type) { - case int: - if v != 418 { - t.Fatalf("Unexpected API code integer: %d; expected 418", v) - } - default: - t.Fatalf("Unexpected API error code type: %T", v) - } -} - -func TestAPIErrorUnmarshalJSONString(t *testing.T) { - var apiErr APIError - response := `{"code":"teapot","message":"I'm a teapot","param":"prompt","type":"teapot_error"}` - err := json.Unmarshal([]byte(response), &apiErr) - checks.NoError(t, err, "Unexpected Unmarshal API response error") - - switch v := apiErr.Code.(type) { - case string: - if v != "teapot" { - t.Fatalf("Unexpected API code string: %s; expected `teapot`", v) - } - default: - t.Fatalf("Unexpected API error code type: %T", v) - } -} - -func TestAPIErrorUnmarshalJSONNoCode(t *testing.T) { - // test integer code - response := `{"message":"I'm a teapot","param":"prompt","type":"teapot_error"}` - var apiErr APIError - err := json.Unmarshal([]byte(response), &apiErr) - checks.NoError(t, err, "Unexpected Unmarshal API response error") - - switch v := apiErr.Code.(type) { - case nil: - default: - t.Fatalf("Unexpected API error code type: %T", v) - } -} - -func TestAPIErrorUnmarshalInvalidData(t *testing.T) { - apiErr := APIError{} - data := []byte(`--- {"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`) - err := apiErr.UnmarshalJSON(data) - checks.HasError(t, err, "Expected error when unmarshaling invalid data") - - if apiErr.Code != nil { - t.Fatalf("Expected nil code, got %q", apiErr.Code) - } - if apiErr.Message != "" { - t.Fatalf("Expected empty message, got %q", apiErr.Message) - } - if apiErr.Param != nil { - t.Fatalf("Expected nil param, got %q", *apiErr.Param) - } - if apiErr.Type != "" { - t.Fatalf("Expected empty type, got %q", apiErr.Type) - } -} - -func TestAPIErrorUnmarshalJSONInvalidParam(t *testing.T) { - var apiErr APIError - response := `{"code":418,"message":"I'm a teapot","param":true,"type":"teapot_error"}` - err := json.Unmarshal([]byte(response), &apiErr) - checks.HasError(t, err, "Param should be a string") -} - -func TestAPIErrorUnmarshalJSONInvalidType(t *testing.T) { - var apiErr APIError - response := `{"code":418,"message":"I'm a teapot","param":"prompt","type":true}` - err := json.Unmarshal([]byte(response), &apiErr) - checks.HasError(t, err, "Type should be a string") -} - -func TestRequestError(t *testing.T) { - client, server, teardown := setupOpenAITestServer() - defer teardown() - server.RegisterHandler("/v1/engines", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusTeapot) - }) - - _, err := client.ListEngines(context.Background()) - checks.HasError(t, err, "ListEngines did not fail") - - var reqErr *RequestError - if !errors.As(err, &reqErr) { - t.Fatalf("Error is not a RequestError: %+v", err) - } - - if reqErr.HTTPStatusCode != 418 { - t.Fatalf("Unexpected request error status code: %d", reqErr.HTTPStatusCode) - } - - if reqErr.Unwrap() == nil { - t.Fatalf("Empty request error occurred") - } -} - -// numTokens Returns the number of GPT-3 encoded tokens in the given text. -// This function approximates based on the rule of thumb stated by OpenAI: -// https://beta.openai.com/tokenizer -// -// TODO: implement an actual tokenizer for GPT-3 and Codex (once available) -func numTokens(s string) int { - return int(float32(len(s)) / 4) -} diff --git a/engines_test.go b/engines_test.go index 2beb333b3..31e7ec8be 100644 --- a/engines_test.go +++ b/engines_test.go @@ -34,3 +34,14 @@ func TestListEngines(t *testing.T) { _, err := client.ListEngines(context.Background()) checks.NoError(t, err, "ListEngines error") } + +func TestListEnginesReturnError(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/engines", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusTeapot) + }) + + _, err := client.ListEngines(context.Background()) + checks.HasError(t, err, "ListEngines did not fail") +} diff --git a/error_test.go b/error_test.go new file mode 100644 index 000000000..e2309abd7 --- /dev/null +++ b/error_test.go @@ -0,0 +1,201 @@ +package openai_test + +import ( + "errors" + "net/http" + "testing" + + . "github.com/sashabaranov/go-openai" +) + +func TestAPIErrorUnmarshalJSON(t *testing.T) { + type testCase struct { + name string + response string + hasError bool + checkFunc func(t *testing.T, apiErr APIError) + } + testCases := []testCase{ + // testcase for message field + { + name: "parse succeeds when the message is string", + response: `{"message":"foo","type":"invalid_request_error","param":null,"code":null}`, + hasError: false, + checkFunc: func(t *testing.T, apiErr APIError) { + assertAPIErrorMessage(t, apiErr, "foo") + }, + }, + { + name: "parse succeeds when the message is array with single item", + response: `{"message":["foo"],"type":"invalid_request_error","param":null,"code":null}`, + hasError: false, + checkFunc: func(t *testing.T, apiErr APIError) { + assertAPIErrorMessage(t, apiErr, "foo") + }, + }, + { + name: "parse succeeds when the message is array with multiple items", + response: `{"message":["foo", "bar", "baz"],"type":"invalid_request_error","param":null,"code":null}`, + hasError: false, + checkFunc: func(t *testing.T, apiErr APIError) { + assertAPIErrorMessage(t, apiErr, "foo, bar, baz") + }, + }, + { + name: "parse succeeds when the message is empty array", + response: `{"message":[],"type":"invalid_request_error","param":null,"code":null}`, + hasError: false, + checkFunc: func(t *testing.T, apiErr APIError) { + assertAPIErrorMessage(t, apiErr, "") + }, + }, + { + name: "parse succeeds when the message is null", + response: `{"message":null,"type":"invalid_request_error","param":null,"code":null}`, + hasError: false, + checkFunc: func(t *testing.T, apiErr APIError) { + assertAPIErrorMessage(t, apiErr, "") + }, + }, + { + name: "parse failed when the message is object", + response: `{"message":{},"type":"invalid_request_error","param":null,"code":null}`, + hasError: true, + }, + { + name: "parse failed when the message is int", + response: `{"message":1,"type":"invalid_request_error","param":null,"code":null}`, + hasError: true, + }, + { + name: "parse failed when the message is float", + response: `{"message":0.1,"type":"invalid_request_error","param":null,"code":null}`, + hasError: true, + }, + { + name: "parse failed when the message is bool", + response: `{"message":true,"type":"invalid_request_error","param":null,"code":null}`, + hasError: true, + }, + { + name: "parse failed when the message is not exists", + response: `{"type":"invalid_request_error","param":null,"code":null}`, + hasError: true, + }, + // testcase for code field + { + name: "parse succeeds when the code is int", + response: `{"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`, + hasError: false, + checkFunc: func(t *testing.T, apiErr APIError) { + assertAPIErrorCode(t, apiErr, 418) + }, + }, + { + name: "parse succeeds when the code is string", + response: `{"code":"teapot","message":"I'm a teapot","param":"prompt","type":"teapot_error"}`, + hasError: false, + checkFunc: func(t *testing.T, apiErr APIError) { + assertAPIErrorCode(t, apiErr, "teapot") + }, + }, + { + name: "parse succeeds when the code is not exists", + response: `{"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`, + hasError: false, + checkFunc: func(t *testing.T, apiErr APIError) { + assertAPIErrorCode(t, apiErr, nil) + }, + }, + // testcase for param field + { + name: "parse failed when the param is bool", + response: `{"code":418,"message":"I'm a teapot","param":true,"type":"teapot_error"}`, + hasError: true, + }, + // testcase for type field + { + name: "parse failed when the type is bool", + response: `{"code":418,"message":"I'm a teapot","param":"prompt","type":true}`, + hasError: true, + }, + // testcase for error response + { + name: "parse failed when the response is invalid json", + response: `--- {"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`, + hasError: true, + checkFunc: func(t *testing.T, apiErr APIError) { + assertAPIErrorCode(t, apiErr, nil) + assertAPIErrorMessage(t, apiErr, "") + assertAPIErrorParam(t, apiErr, nil) + assertAPIErrorType(t, apiErr, "") + }, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var apiErr APIError + err := apiErr.UnmarshalJSON([]byte(tc.response)) + if (err != nil) != tc.hasError { + t.Errorf("Unexpected error: %v", err) + } + if tc.checkFunc != nil { + tc.checkFunc(t, apiErr) + } + }) + } +} + +func assertAPIErrorMessage(t *testing.T, apiErr APIError, expected string) { + if apiErr.Message != expected { + t.Errorf("Unexpected APIError message: %v; expected: %s", apiErr, expected) + } +} + +func assertAPIErrorCode(t *testing.T, apiErr APIError, expected interface{}) { + switch v := apiErr.Code.(type) { + case int: + if v != expected { + t.Errorf("Unexpected APIError code integer: %d; expected %d", v, expected) + } + case string: + if v != expected { + t.Errorf("Unexpected APIError code string: %s; expected %s", v, expected) + } + case nil: + default: + t.Errorf("Unexpected APIError error code type: %T", v) + } +} + +func assertAPIErrorParam(t *testing.T, apiErr APIError, expected *string) { + if apiErr.Param != expected { + t.Errorf("Unexpected APIError param: %v; expected: %s", apiErr, *expected) + } +} + +func assertAPIErrorType(t *testing.T, apiErr APIError, typ string) { + if apiErr.Type != typ { + t.Errorf("Unexpected API type: %v; expected: %s", apiErr, typ) + } +} + +func TestRequestError(t *testing.T) { + var err error = &RequestError{ + HTTPStatusCode: http.StatusTeapot, + Err: errors.New("i am a teapot"), + } + + var reqErr *RequestError + if !errors.As(err, &reqErr) { + t.Fatalf("Error is not a RequestError: %+v", err) + } + + if reqErr.HTTPStatusCode != 418 { + t.Fatalf("Unexpected request error status code: %d", reqErr.HTTPStatusCode) + } + + if reqErr.Unwrap() == nil { + t.Fatalf("Empty request error occurred") + } +} diff --git a/openai_test.go b/openai_test.go index a5e7b64ee..4fc41ecc0 100644 --- a/openai_test.go +++ b/openai_test.go @@ -26,3 +26,12 @@ func setupAzureTestServer() (client *Client, server *test.ServerTest, teardown f client = NewClientWithConfig(config) return } + +// numTokens Returns the number of GPT-3 encoded tokens in the given text. +// This function approximates based on the rule of thumb stated by OpenAI: +// https://beta.openai.com/tokenizer +// +// TODO: implement an actual tokenizer for GPT-3 and Codex (once available) +func numTokens(s string) int { + return int(float32(len(s)) / 4) +}