diff --git a/api_integration_test.go b/api_integration_test.go index 254fbeb03..6be188bc6 100644 --- a/api_integration_test.go +++ b/api_integration_test.go @@ -9,7 +9,6 @@ import ( "os" "testing" - . "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test/checks" "github.com/sashabaranov/go-openai/jsonschema" ) diff --git a/audio_api_test.go b/audio_api_test.go index aad7a225a..a0efc7921 100644 --- a/audio_api_test.go +++ b/audio_api_test.go @@ -12,7 +12,7 @@ import ( "strings" "testing" - . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" ) @@ -26,7 +26,7 @@ func TestAudio(t *testing.T) { testcases := []struct { name string - createFn func(context.Context, AudioRequest) (AudioResponse, error) + createFn func(context.Context, openai.AudioRequest) (openai.AudioResponse, error) }{ { "transcribe", @@ -48,7 +48,7 @@ func TestAudio(t *testing.T) { path := filepath.Join(dir, "fake.mp3") test.CreateTestFile(t, path) - req := AudioRequest{ + req := openai.AudioRequest{ FilePath: path, Model: "whisper-3", } @@ -57,7 +57,7 @@ func TestAudio(t *testing.T) { }) t.Run(tc.name+" (with reader)", func(t *testing.T) { - req := AudioRequest{ + req := openai.AudioRequest{ FilePath: "fake.webm", Reader: bytes.NewBuffer([]byte(`some webm binary data`)), Model: "whisper-3", @@ -76,7 +76,7 @@ func TestAudioWithOptionalArgs(t *testing.T) { testcases := []struct { name string - createFn func(context.Context, AudioRequest) (AudioResponse, error) + createFn func(context.Context, openai.AudioRequest) (openai.AudioResponse, error) }{ { "transcribe", @@ -98,13 +98,13 @@ func TestAudioWithOptionalArgs(t *testing.T) { path := filepath.Join(dir, "fake.mp3") test.CreateTestFile(t, path) - req := AudioRequest{ + req := openai.AudioRequest{ FilePath: path, Model: "whisper-3", Prompt: "用简体中文", Temperature: 0.5, Language: "zh", - Format: AudioResponseFormatSRT, + Format: openai.AudioResponseFormatSRT, } _, err := tc.createFn(ctx, req) checks.NoError(t, err, "audio API error") diff --git a/audio_test.go b/audio_test.go index e19a873f3..5346244c8 100644 --- a/audio_test.go +++ b/audio_test.go @@ -40,7 +40,7 @@ func TestAudioWithFailingFormBuilder(t *testing.T) { } var failForField string - mockBuilder.mockWriteField = func(fieldname, value string) error { + mockBuilder.mockWriteField = func(fieldname, _ string) error { if fieldname == failForField { return mockFailedErr } diff --git a/chat_stream_test.go b/chat_stream_test.go index 2c109d454..bd571cb48 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -10,28 +10,28 @@ import ( "strconv" "testing" - . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test/checks" ) func TestChatCompletionsStreamWrongModel(t *testing.T) { - config := DefaultConfig("whatever") + config := openai.DefaultConfig("whatever") config.BaseURL = "http://localhost/v1" - client := NewClientWithConfig(config) + client := openai.NewClientWithConfig(config) ctx := context.Background() - req := ChatCompletionRequest{ + req := openai.ChatCompletionRequest{ MaxTokens: 5, Model: "ada", - Messages: []ChatCompletionMessage{ + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, } _, err := client.CreateChatCompletionStream(ctx, req) - if !errors.Is(err, ErrChatCompletionInvalidModel) { + if !errors.Is(err, openai.ErrChatCompletionInvalidModel) { t.Fatalf("CreateChatCompletion should return ErrChatCompletionInvalidModel, but returned: %v", err) } } @@ -39,7 +39,7 @@ func TestChatCompletionsStreamWrongModel(t *testing.T) { func TestCreateChatCompletionStream(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream") // Send test responses @@ -61,12 +61,12 @@ func TestCreateChatCompletionStream(t *testing.T) { checks.NoError(t, err, "Write error") }) - stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, @@ -75,15 +75,15 @@ func TestCreateChatCompletionStream(t *testing.T) { checks.NoError(t, err, "CreateCompletionStream returned error") defer stream.Close() - expectedResponses := []ChatCompletionStreamResponse{ + expectedResponses := []openai.ChatCompletionStreamResponse{ { ID: "1", Object: "completion", Created: 1598069254, - Model: GPT3Dot5Turbo, - Choices: []ChatCompletionStreamChoice{ + Model: openai.GPT3Dot5Turbo, + Choices: []openai.ChatCompletionStreamChoice{ { - Delta: ChatCompletionStreamChoiceDelta{ + Delta: openai.ChatCompletionStreamChoiceDelta{ Content: "response1", }, FinishReason: "max_tokens", @@ -94,10 +94,10 @@ func TestCreateChatCompletionStream(t *testing.T) { ID: "2", Object: "completion", Created: 1598069255, - Model: GPT3Dot5Turbo, - Choices: []ChatCompletionStreamChoice{ + Model: openai.GPT3Dot5Turbo, + Choices: []openai.ChatCompletionStreamChoice{ { - Delta: ChatCompletionStreamChoiceDelta{ + Delta: openai.ChatCompletionStreamChoiceDelta{ Content: "response2", }, FinishReason: "max_tokens", @@ -133,7 +133,7 @@ func TestCreateChatCompletionStream(t *testing.T) { func TestCreateChatCompletionStreamError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream") // Send test responses @@ -156,12 +156,12 @@ func TestCreateChatCompletionStreamError(t *testing.T) { checks.NoError(t, err, "Write error") }) - stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, @@ -173,7 +173,7 @@ func TestCreateChatCompletionStreamError(t *testing.T) { _, streamErr := stream.Recv() checks.HasError(t, streamErr, "stream.Recv() did not return error") - var apiErr *APIError + var apiErr *openai.APIError if !errors.As(streamErr, &apiErr) { t.Errorf("stream.Recv() did not return APIError") } @@ -183,7 +183,7 @@ func TestCreateChatCompletionStreamError(t *testing.T) { func TestCreateChatCompletionStreamWithHeaders(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream") w.Header().Set(xCustomHeader, xCustomHeaderValue) @@ -196,12 +196,12 @@ func TestCreateChatCompletionStreamWithHeaders(t *testing.T) { checks.NoError(t, err, "Write error") }) - stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, @@ -219,7 +219,7 @@ func TestCreateChatCompletionStreamWithHeaders(t *testing.T) { func TestCreateChatCompletionStreamWithRatelimitHeaders(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream") for k, v := range rateLimitHeaders { switch val := v.(type) { @@ -239,12 +239,12 @@ func TestCreateChatCompletionStreamWithRatelimitHeaders(t *testing.T) { checks.NoError(t, err, "Write error") }) - stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, @@ -264,7 +264,7 @@ func TestCreateChatCompletionStreamWithRatelimitHeaders(t *testing.T) { func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream") // Send test responses @@ -276,12 +276,12 @@ func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) { checks.NoError(t, err, "Write error") }) - stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, @@ -293,7 +293,7 @@ func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) { _, streamErr := stream.Recv() checks.HasError(t, streamErr, "stream.Recv() did not return error") - var apiErr *APIError + var apiErr *openai.APIError if !errors.As(streamErr, &apiErr) { t.Errorf("stream.Recv() did not return APIError") } @@ -303,7 +303,7 @@ func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) { func TestCreateChatCompletionStreamRateLimitError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(429) @@ -317,18 +317,18 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) { _, err := w.Write(dataBytes) checks.NoError(t, err, "Write error") }) - _, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ + _, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, Stream: true, }) - var apiErr *APIError + var apiErr *openai.APIError if !errors.As(err, &apiErr) { t.Errorf("TestCreateChatCompletionStreamRateLimitError did not return APIError") } @@ -345,7 +345,7 @@ func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) { client, server, teardown := setupAzureTestServer() defer teardown() server.RegisterHandler("/openai/deployments/gpt-35-turbo/chat/completions", - func(w http.ResponseWriter, r *http.Request) { + func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusTooManyRequests) // Send test responses @@ -355,13 +355,13 @@ func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) { checks.NoError(t, err, "Write error") }) - apiErr := &APIError{} - _, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ + apiErr := &openai.APIError{} + _, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, @@ -387,7 +387,7 @@ func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) { } // Helper funcs. -func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool { +func compareChatResponses(r1, r2 openai.ChatCompletionStreamResponse) bool { if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model { return false } @@ -402,7 +402,7 @@ func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool { return true } -func compareChatStreamResponseChoices(c1, c2 ChatCompletionStreamChoice) bool { +func compareChatStreamResponseChoices(c1, c2 openai.ChatCompletionStreamChoice) bool { if c1.Index != c2.Index { return false } diff --git a/chat_test.go b/chat_test.go index 329b2b9cb..5bf1eaf6c 100644 --- a/chat_test.go +++ b/chat_test.go @@ -11,7 +11,7 @@ import ( "testing" "time" - . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test/checks" "github.com/sashabaranov/go-openai/jsonschema" ) @@ -21,49 +21,47 @@ const ( xCustomHeaderValue = "test" ) -var ( - rateLimitHeaders = map[string]any{ - "x-ratelimit-limit-requests": 60, - "x-ratelimit-limit-tokens": 150000, - "x-ratelimit-remaining-requests": 59, - "x-ratelimit-remaining-tokens": 149984, - "x-ratelimit-reset-requests": "1s", - "x-ratelimit-reset-tokens": "6m0s", - } -) +var rateLimitHeaders = map[string]any{ + "x-ratelimit-limit-requests": 60, + "x-ratelimit-limit-tokens": 150000, + "x-ratelimit-remaining-requests": 59, + "x-ratelimit-remaining-tokens": 149984, + "x-ratelimit-reset-requests": "1s", + "x-ratelimit-reset-tokens": "6m0s", +} func TestChatCompletionsWrongModel(t *testing.T) { - config := DefaultConfig("whatever") + config := openai.DefaultConfig("whatever") config.BaseURL = "http://localhost/v1" - client := NewClientWithConfig(config) + client := openai.NewClientWithConfig(config) ctx := context.Background() - req := ChatCompletionRequest{ + req := openai.ChatCompletionRequest{ MaxTokens: 5, Model: "ada", - Messages: []ChatCompletionMessage{ + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, } _, err := client.CreateChatCompletion(ctx, req) msg := fmt.Sprintf("CreateChatCompletion should return wrong model error, returned: %s", err) - checks.ErrorIs(t, err, ErrChatCompletionInvalidModel, msg) + checks.ErrorIs(t, err, openai.ErrChatCompletionInvalidModel, msg) } func TestChatCompletionsWithStream(t *testing.T) { - config := DefaultConfig("whatever") + config := openai.DefaultConfig("whatever") config.BaseURL = "http://localhost/v1" - client := NewClientWithConfig(config) + client := openai.NewClientWithConfig(config) ctx := context.Background() - req := ChatCompletionRequest{ + req := openai.ChatCompletionRequest{ Stream: true, } _, err := client.CreateChatCompletion(ctx, req) - checks.ErrorIs(t, err, ErrChatCompletionStreamNotSupported, "unexpected error") + checks.ErrorIs(t, err, openai.ErrChatCompletionStreamNotSupported, "unexpected error") } // TestCompletions Tests the completions endpoint of the API using the mocked server. @@ -71,12 +69,12 @@ func TestChatCompletions(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) - _, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, @@ -89,12 +87,12 @@ func TestChatCompletionsWithHeaders(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) - resp, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, @@ -113,12 +111,12 @@ func TestChatCompletionsWithRateLimitHeaders(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) - resp, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, @@ -150,16 +148,16 @@ func TestChatCompletionsFunctions(t *testing.T) { t.Run("bytes", func(t *testing.T) { //nolint:lll msg := json.RawMessage(`{"properties":{"count":{"type":"integer","description":"total number of words in sentence"},"words":{"items":{"type":"string"},"type":"array","description":"list of words in sentence"}},"type":"object","required":["count","words"]}`) - _, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo0613, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo0613, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, - Functions: []FunctionDefinition{{ + Functions: []openai.FunctionDefinition{{ Name: "test", Parameters: &msg, }}, @@ -175,16 +173,16 @@ func TestChatCompletionsFunctions(t *testing.T) { Count: 2, Words: []string{"hello", "world"}, } - _, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo0613, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo0613, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, - Functions: []FunctionDefinition{{ + Functions: []openai.FunctionDefinition{{ Name: "test", Parameters: &msg, }}, @@ -192,16 +190,16 @@ func TestChatCompletionsFunctions(t *testing.T) { checks.NoError(t, err, "CreateChatCompletion with functions error") }) t.Run("JSONSchemaDefinition", func(t *testing.T) { - _, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo0613, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo0613, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, - Functions: []FunctionDefinition{{ + Functions: []openai.FunctionDefinition{{ Name: "test", Parameters: &jsonschema.Definition{ Type: jsonschema.Object, @@ -229,16 +227,16 @@ func TestChatCompletionsFunctions(t *testing.T) { }) t.Run("JSONSchemaDefinitionWithFunctionDefine", func(t *testing.T) { // this is a compatibility check - _, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo0613, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo0613, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, - Functions: []FunctionDefine{{ + Functions: []openai.FunctionDefine{{ Name: "test", Parameters: &jsonschema.Definition{ Type: jsonschema.Object, @@ -271,12 +269,12 @@ func TestAzureChatCompletions(t *testing.T) { defer teardown() server.RegisterHandler("/openai/deployments/*", handleChatCompletionEndpoint) - _, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, @@ -293,12 +291,12 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } - var completionReq ChatCompletionRequest + var completionReq openai.ChatCompletionRequest if completionReq, err = getChatCompletionBody(r); err != nil { http.Error(w, "could not read request", http.StatusInternalServerError) return } - res := ChatCompletionResponse{ + res := openai.ChatCompletionResponse{ ID: strconv.Itoa(int(time.Now().Unix())), Object: "test-object", Created: time.Now().Unix(), @@ -323,11 +321,11 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { return } - res.Choices = append(res.Choices, ChatCompletionChoice{ - Message: ChatCompletionMessage{ - Role: ChatMessageRoleFunction, + res.Choices = append(res.Choices, openai.ChatCompletionChoice{ + Message: openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleFunction, // this is valid json so it should be fine - FunctionCall: &FunctionCall{ + FunctionCall: &openai.FunctionCall{ Name: completionReq.Functions[0].Name, Arguments: string(fcb), }, @@ -339,9 +337,9 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { // generate a random string of length completionReq.Length completionStr := strings.Repeat("a", completionReq.MaxTokens) - res.Choices = append(res.Choices, ChatCompletionChoice{ - Message: ChatCompletionMessage{ - Role: ChatMessageRoleAssistant, + res.Choices = append(res.Choices, openai.ChatCompletionChoice{ + Message: openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleAssistant, Content: completionStr, }, Index: i, @@ -349,7 +347,7 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { } inputTokens := numTokens(completionReq.Messages[0].Content) * n completionTokens := completionReq.MaxTokens * n - res.Usage = Usage{ + res.Usage = openai.Usage{ PromptTokens: inputTokens, CompletionTokens: completionTokens, TotalTokens: inputTokens + completionTokens, @@ -368,23 +366,23 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { } // getChatCompletionBody Returns the body of the request to create a completion. -func getChatCompletionBody(r *http.Request) (ChatCompletionRequest, error) { - completion := ChatCompletionRequest{} +func getChatCompletionBody(r *http.Request) (openai.ChatCompletionRequest, error) { + completion := openai.ChatCompletionRequest{} // read the request body reqBody, err := io.ReadAll(r.Body) if err != nil { - return ChatCompletionRequest{}, err + return openai.ChatCompletionRequest{}, err } err = json.Unmarshal(reqBody, &completion) if err != nil { - return ChatCompletionRequest{}, err + return openai.ChatCompletionRequest{}, err } return completion, nil } func TestFinishReason(t *testing.T) { - c := &ChatCompletionChoice{ - FinishReason: FinishReasonNull, + c := &openai.ChatCompletionChoice{ + FinishReason: openai.FinishReasonNull, } resBytes, _ := json.Marshal(c) if !strings.Contains(string(resBytes), `"finish_reason":null`) { @@ -398,11 +396,11 @@ func TestFinishReason(t *testing.T) { t.Error("null should not be quoted") } - otherReasons := []FinishReason{ - FinishReasonStop, - FinishReasonLength, - FinishReasonFunctionCall, - FinishReasonContentFilter, + otherReasons := []openai.FinishReason{ + openai.FinishReasonStop, + openai.FinishReasonLength, + openai.FinishReasonFunctionCall, + openai.FinishReasonContentFilter, } for _, r := range otherReasons { c.FinishReason = r diff --git a/completion_test.go b/completion_test.go index 844ef484f..89950bf94 100644 --- a/completion_test.go +++ b/completion_test.go @@ -1,9 +1,6 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" - "context" "encoding/json" "errors" @@ -14,33 +11,36 @@ import ( "strings" "testing" "time" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" ) func TestCompletionsWrongModel(t *testing.T) { - config := DefaultConfig("whatever") + config := openai.DefaultConfig("whatever") config.BaseURL = "http://localhost/v1" - client := NewClientWithConfig(config) + client := openai.NewClientWithConfig(config) _, err := client.CreateCompletion( context.Background(), - CompletionRequest{ + openai.CompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, + Model: openai.GPT3Dot5Turbo, }, ) - if !errors.Is(err, ErrCompletionUnsupportedModel) { + if !errors.Is(err, openai.ErrCompletionUnsupportedModel) { t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel, but returned: %v", err) } } func TestCompletionWithStream(t *testing.T) { - config := DefaultConfig("whatever") - client := NewClientWithConfig(config) + config := openai.DefaultConfig("whatever") + client := openai.NewClientWithConfig(config) ctx := context.Background() - req := CompletionRequest{Stream: true} + req := openai.CompletionRequest{Stream: true} _, err := client.CreateCompletion(ctx, req) - if !errors.Is(err, ErrCompletionStreamNotSupported) { + if !errors.Is(err, openai.ErrCompletionStreamNotSupported) { t.Fatalf("CreateCompletion didn't return ErrCompletionStreamNotSupported") } } @@ -50,7 +50,7 @@ func TestCompletions(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() server.RegisterHandler("/v1/completions", handleCompletionEndpoint) - req := CompletionRequest{ + req := openai.CompletionRequest{ MaxTokens: 5, Model: "ada", Prompt: "Lorem ipsum", @@ -68,12 +68,12 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } - var completionReq CompletionRequest + var completionReq openai.CompletionRequest if completionReq, err = getCompletionBody(r); err != nil { http.Error(w, "could not read request", http.StatusInternalServerError) return } - res := CompletionResponse{ + res := openai.CompletionResponse{ ID: strconv.Itoa(int(time.Now().Unix())), Object: "test-object", Created: time.Now().Unix(), @@ -93,14 +93,14 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { if completionReq.Echo { completionStr = completionReq.Prompt.(string) + completionStr } - res.Choices = append(res.Choices, CompletionChoice{ + res.Choices = append(res.Choices, openai.CompletionChoice{ Text: completionStr, Index: i, }) } inputTokens := numTokens(completionReq.Prompt.(string)) * n completionTokens := completionReq.MaxTokens * n - res.Usage = Usage{ + res.Usage = openai.Usage{ PromptTokens: inputTokens, CompletionTokens: completionTokens, TotalTokens: inputTokens + completionTokens, @@ -110,16 +110,16 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { } // getCompletionBody Returns the body of the request to create a completion. -func getCompletionBody(r *http.Request) (CompletionRequest, error) { - completion := CompletionRequest{} +func getCompletionBody(r *http.Request) (openai.CompletionRequest, error) { + completion := openai.CompletionRequest{} // read the request body reqBody, err := io.ReadAll(r.Body) if err != nil { - return CompletionRequest{}, err + return openai.CompletionRequest{}, err } err = json.Unmarshal(reqBody, &completion) if err != nil { - return CompletionRequest{}, err + return openai.CompletionRequest{}, err } return completion, nil } diff --git a/config_test.go b/config_test.go index 488511b11..3e528c3e9 100644 --- a/config_test.go +++ b/config_test.go @@ -3,7 +3,7 @@ package openai_test import ( "testing" - . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai" ) func TestGetAzureDeploymentByModel(t *testing.T) { @@ -49,7 +49,7 @@ func TestGetAzureDeploymentByModel(t *testing.T) { for _, c := range cases { t.Run(c.Model, func(t *testing.T) { - conf := DefaultAzureConfig("", "https://test.openai.azure.com/") + conf := openai.DefaultAzureConfig("", "https://test.openai.azure.com/") if c.AzureModelMapperFunc != nil { conf.AzureModelMapperFunc = c.AzureModelMapperFunc } diff --git a/edits_test.go b/edits_test.go index c0bb84392..d2a6db40d 100644 --- a/edits_test.go +++ b/edits_test.go @@ -1,9 +1,6 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" - "context" "encoding/json" "fmt" @@ -11,6 +8,9 @@ import ( "net/http" "testing" "time" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" ) // TestEdits Tests the edits endpoint of the API using the mocked server. @@ -20,7 +20,7 @@ func TestEdits(t *testing.T) { server.RegisterHandler("/v1/edits", handleEditEndpoint) // create an edit request model := "ada" - editReq := EditsRequest{ + editReq := openai.EditsRequest{ Model: &model, Input: "Lorem ipsum dolor sit amet, consectetur adipiscing elit, " + "sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim" + @@ -45,14 +45,14 @@ func handleEditEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } - var editReq EditsRequest + var editReq openai.EditsRequest editReq, err = getEditBody(r) if err != nil { http.Error(w, "could not read request", http.StatusInternalServerError) return } // create a response - res := EditsResponse{ + res := openai.EditsResponse{ Object: "test-object", Created: time.Now().Unix(), } @@ -62,12 +62,12 @@ func handleEditEndpoint(w http.ResponseWriter, r *http.Request) { completionTokens := int(float32(len(editString))/4) * editReq.N for i := 0; i < editReq.N; i++ { // instruction will be hidden and only seen by OpenAI - res.Choices = append(res.Choices, EditsChoice{ + res.Choices = append(res.Choices, openai.EditsChoice{ Text: editReq.Input + editString, Index: i, }) } - res.Usage = Usage{ + res.Usage = openai.Usage{ PromptTokens: inputTokens, CompletionTokens: completionTokens, TotalTokens: inputTokens + completionTokens, @@ -77,16 +77,16 @@ func handleEditEndpoint(w http.ResponseWriter, r *http.Request) { } // getEditBody Returns the body of the request to create an edit. -func getEditBody(r *http.Request) (EditsRequest, error) { - edit := EditsRequest{} +func getEditBody(r *http.Request) (openai.EditsRequest, error) { + edit := openai.EditsRequest{} // read the request body reqBody, err := io.ReadAll(r.Body) if err != nil { - return EditsRequest{}, err + return openai.EditsRequest{}, err } err = json.Unmarshal(reqBody, &edit) if err != nil { - return EditsRequest{}, err + return openai.EditsRequest{}, err } return edit, nil } diff --git a/embeddings_test.go b/embeddings_test.go index 72e8c245f..af04d96bf 100644 --- a/embeddings_test.go +++ b/embeddings_test.go @@ -11,32 +11,32 @@ import ( "reflect" "testing" - . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test/checks" ) func TestEmbedding(t *testing.T) { - embeddedModels := []EmbeddingModel{ - AdaSimilarity, - BabbageSimilarity, - CurieSimilarity, - DavinciSimilarity, - AdaSearchDocument, - AdaSearchQuery, - BabbageSearchDocument, - BabbageSearchQuery, - CurieSearchDocument, - CurieSearchQuery, - DavinciSearchDocument, - DavinciSearchQuery, - AdaCodeSearchCode, - AdaCodeSearchText, - BabbageCodeSearchCode, - BabbageCodeSearchText, + embeddedModels := []openai.EmbeddingModel{ + openai.AdaSimilarity, + openai.BabbageSimilarity, + openai.CurieSimilarity, + openai.DavinciSimilarity, + openai.AdaSearchDocument, + openai.AdaSearchQuery, + openai.BabbageSearchDocument, + openai.BabbageSearchQuery, + openai.CurieSearchDocument, + openai.CurieSearchQuery, + openai.DavinciSearchDocument, + openai.DavinciSearchQuery, + openai.AdaCodeSearchCode, + openai.AdaCodeSearchText, + openai.BabbageCodeSearchCode, + openai.BabbageCodeSearchText, } for _, model := range embeddedModels { // test embedding request with strings (simple embedding request) - embeddingReq := EmbeddingRequest{ + embeddingReq := openai.EmbeddingRequest{ Input: []string{ "The food was delicious and the waiter", "Other examples of embedding request", @@ -52,7 +52,7 @@ func TestEmbedding(t *testing.T) { } // test embedding request with strings - embeddingReqStrings := EmbeddingRequestStrings{ + embeddingReqStrings := openai.EmbeddingRequestStrings{ Input: []string{ "The food was delicious and the waiter", "Other examples of embedding request", @@ -66,7 +66,7 @@ func TestEmbedding(t *testing.T) { } // test embedding request with tokens - embeddingReqTokens := EmbeddingRequestTokens{ + embeddingReqTokens := openai.EmbeddingRequestTokens{ Input: [][]int{ {464, 2057, 373, 12625, 290, 262, 46612}, {6395, 6096, 286, 11525, 12083, 2581}, @@ -82,17 +82,17 @@ func TestEmbedding(t *testing.T) { } func TestEmbeddingModel(t *testing.T) { - var em EmbeddingModel + var em openai.EmbeddingModel err := em.UnmarshalText([]byte("text-similarity-ada-001")) checks.NoError(t, err, "Could not marshal embedding model") - if em != AdaSimilarity { + if em != openai.AdaSimilarity { t.Errorf("Model is not equal to AdaSimilarity") } err = em.UnmarshalText([]byte("some-non-existent-model")) checks.NoError(t, err, "Could not marshal embedding model") - if em != Unknown { + if em != openai.Unknown { t.Errorf("Model is not equal to Unknown") } } @@ -101,12 +101,12 @@ func TestEmbeddingEndpoint(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - sampleEmbeddings := []Embedding{ + sampleEmbeddings := []openai.Embedding{ {Embedding: []float32{1.23, 4.56, 7.89}}, {Embedding: []float32{-0.006968617, -0.0052718227, 0.011901081}}, } - sampleBase64Embeddings := []Base64Embedding{ + sampleBase64Embeddings := []openai.Base64Embedding{ {Embedding: "pHCdP4XrkUDhevxA"}, {Embedding: "/1jku0G/rLvA/EI8"}, } @@ -115,8 +115,8 @@ func TestEmbeddingEndpoint(t *testing.T) { "/v1/embeddings", func(w http.ResponseWriter, r *http.Request) { var req struct { - EncodingFormat EmbeddingEncodingFormat `json:"encoding_format"` - User string `json:"user"` + EncodingFormat openai.EmbeddingEncodingFormat `json:"encoding_format"` + User string `json:"user"` } _ = json.NewDecoder(r.Body).Decode(&req) @@ -125,16 +125,16 @@ func TestEmbeddingEndpoint(t *testing.T) { case req.User == "invalid": w.WriteHeader(http.StatusBadRequest) return - case req.EncodingFormat == EmbeddingEncodingFormatBase64: - resBytes, _ = json.Marshal(EmbeddingResponseBase64{Data: sampleBase64Embeddings}) + case req.EncodingFormat == openai.EmbeddingEncodingFormatBase64: + resBytes, _ = json.Marshal(openai.EmbeddingResponseBase64{Data: sampleBase64Embeddings}) default: - resBytes, _ = json.Marshal(EmbeddingResponse{Data: sampleEmbeddings}) + resBytes, _ = json.Marshal(openai.EmbeddingResponse{Data: sampleEmbeddings}) } fmt.Fprintln(w, string(resBytes)) }, ) // test create embeddings with strings (simple embedding request) - res, err := client.CreateEmbeddings(context.Background(), EmbeddingRequest{}) + res, err := client.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{}) checks.NoError(t, err, "CreateEmbeddings error") if !reflect.DeepEqual(res.Data, sampleEmbeddings) { t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) @@ -143,8 +143,8 @@ func TestEmbeddingEndpoint(t *testing.T) { // test create embeddings with strings (simple embedding request) res, err = client.CreateEmbeddings( context.Background(), - EmbeddingRequest{ - EncodingFormat: EmbeddingEncodingFormatBase64, + openai.EmbeddingRequest{ + EncodingFormat: openai.EmbeddingEncodingFormatBase64, }, ) checks.NoError(t, err, "CreateEmbeddings error") @@ -153,23 +153,23 @@ func TestEmbeddingEndpoint(t *testing.T) { } // test create embeddings with strings - res, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestStrings{}) + res, err = client.CreateEmbeddings(context.Background(), openai.EmbeddingRequestStrings{}) checks.NoError(t, err, "CreateEmbeddings strings error") if !reflect.DeepEqual(res.Data, sampleEmbeddings) { t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) } // test create embeddings with tokens - res, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestTokens{}) + res, err = client.CreateEmbeddings(context.Background(), openai.EmbeddingRequestTokens{}) checks.NoError(t, err, "CreateEmbeddings tokens error") if !reflect.DeepEqual(res.Data, sampleEmbeddings) { t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) } // test failed sendRequest - _, err = client.CreateEmbeddings(context.Background(), EmbeddingRequest{ + _, err = client.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{ User: "invalid", - EncodingFormat: EmbeddingEncodingFormatBase64, + EncodingFormat: openai.EmbeddingEncodingFormatBase64, }) checks.HasError(t, err, "CreateEmbeddings error") } @@ -177,26 +177,26 @@ func TestEmbeddingEndpoint(t *testing.T) { func TestEmbeddingResponseBase64_ToEmbeddingResponse(t *testing.T) { type fields struct { Object string - Data []Base64Embedding - Model EmbeddingModel - Usage Usage + Data []openai.Base64Embedding + Model openai.EmbeddingModel + Usage openai.Usage } tests := []struct { name string fields fields - want EmbeddingResponse + want openai.EmbeddingResponse wantErr bool }{ { name: "test embedding response base64 to embedding response", fields: fields{ - Data: []Base64Embedding{ + Data: []openai.Base64Embedding{ {Embedding: "pHCdP4XrkUDhevxA"}, {Embedding: "/1jku0G/rLvA/EI8"}, }, }, - want: EmbeddingResponse{ - Data: []Embedding{ + want: openai.EmbeddingResponse{ + Data: []openai.Embedding{ {Embedding: []float32{1.23, 4.56, 7.89}}, {Embedding: []float32{-0.006968617, -0.0052718227, 0.011901081}}, }, @@ -206,19 +206,19 @@ func TestEmbeddingResponseBase64_ToEmbeddingResponse(t *testing.T) { { name: "Invalid embedding", fields: fields{ - Data: []Base64Embedding{ + Data: []openai.Base64Embedding{ { Embedding: "----", }, }, }, - want: EmbeddingResponse{}, + want: openai.EmbeddingResponse{}, wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - r := &EmbeddingResponseBase64{ + r := &openai.EmbeddingResponseBase64{ Object: tt.fields.Object, Data: tt.fields.Data, Model: tt.fields.Model, @@ -237,8 +237,8 @@ func TestEmbeddingResponseBase64_ToEmbeddingResponse(t *testing.T) { } func TestDotProduct(t *testing.T) { - v1 := &Embedding{Embedding: []float32{1, 2, 3}} - v2 := &Embedding{Embedding: []float32{2, 4, 6}} + v1 := &openai.Embedding{Embedding: []float32{1, 2, 3}} + v2 := &openai.Embedding{Embedding: []float32{2, 4, 6}} expected := float32(28.0) result, err := v1.DotProduct(v2) @@ -250,8 +250,8 @@ func TestDotProduct(t *testing.T) { t.Errorf("Unexpected result. Expected: %v, but got %v", expected, result) } - v1 = &Embedding{Embedding: []float32{1, 0, 0}} - v2 = &Embedding{Embedding: []float32{0, 1, 0}} + v1 = &openai.Embedding{Embedding: []float32{1, 0, 0}} + v2 = &openai.Embedding{Embedding: []float32{0, 1, 0}} expected = float32(0.0) result, err = v1.DotProduct(v2) @@ -264,10 +264,10 @@ func TestDotProduct(t *testing.T) { } // Test for VectorLengthMismatchError - v1 = &Embedding{Embedding: []float32{1, 0, 0}} - v2 = &Embedding{Embedding: []float32{0, 1}} + v1 = &openai.Embedding{Embedding: []float32{1, 0, 0}} + v2 = &openai.Embedding{Embedding: []float32{0, 1}} _, err = v1.DotProduct(v2) - if !errors.Is(err, ErrVectorLengthMismatch) { + if !errors.Is(err, openai.ErrVectorLengthMismatch) { t.Errorf("Expected Vector Length Mismatch Error, but got: %v", err) } } diff --git a/engines_test.go b/engines_test.go index 31e7ec8be..d26aa5541 100644 --- a/engines_test.go +++ b/engines_test.go @@ -7,7 +7,7 @@ import ( "net/http" "testing" - . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test/checks" ) @@ -15,8 +15,8 @@ import ( func TestGetEngine(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/engines/text-davinci-003", func(w http.ResponseWriter, r *http.Request) { - resBytes, _ := json.Marshal(Engine{}) + server.RegisterHandler("/v1/engines/text-davinci-003", func(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(openai.Engine{}) fmt.Fprintln(w, string(resBytes)) }) _, err := client.GetEngine(context.Background(), "text-davinci-003") @@ -27,8 +27,8 @@ func TestGetEngine(t *testing.T) { func TestListEngines(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/engines", func(w http.ResponseWriter, r *http.Request) { - resBytes, _ := json.Marshal(EnginesList{}) + server.RegisterHandler("/v1/engines", func(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(openai.EnginesList{}) fmt.Fprintln(w, string(resBytes)) }) _, err := client.ListEngines(context.Background()) @@ -38,7 +38,7 @@ func TestListEngines(t *testing.T) { func TestListEnginesReturnError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/engines", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/engines", func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusTeapot) }) diff --git a/error_test.go b/error_test.go index a0806b7ed..48cbe4f29 100644 --- a/error_test.go +++ b/error_test.go @@ -6,7 +6,7 @@ import ( "reflect" "testing" - . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai" ) func TestAPIErrorUnmarshalJSON(t *testing.T) { @@ -14,7 +14,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { name string response string hasError bool - checkFunc func(t *testing.T, apiErr APIError) + checkFunc func(t *testing.T, apiErr openai.APIError) } testCases := []testCase{ // testcase for message field @@ -22,7 +22,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { name: "parse succeeds when the message is string", response: `{"message":"foo","type":"invalid_request_error","param":null,"code":null}`, hasError: false, - checkFunc: func(t *testing.T, apiErr APIError) { + checkFunc: func(t *testing.T, apiErr openai.APIError) { assertAPIErrorMessage(t, apiErr, "foo") }, }, @@ -30,7 +30,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { name: "parse succeeds when the message is array with single item", response: `{"message":["foo"],"type":"invalid_request_error","param":null,"code":null}`, hasError: false, - checkFunc: func(t *testing.T, apiErr APIError) { + checkFunc: func(t *testing.T, apiErr openai.APIError) { assertAPIErrorMessage(t, apiErr, "foo") }, }, @@ -38,7 +38,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { name: "parse succeeds when the message is array with multiple items", response: `{"message":["foo", "bar", "baz"],"type":"invalid_request_error","param":null,"code":null}`, hasError: false, - checkFunc: func(t *testing.T, apiErr APIError) { + checkFunc: func(t *testing.T, apiErr openai.APIError) { assertAPIErrorMessage(t, apiErr, "foo, bar, baz") }, }, @@ -46,7 +46,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { name: "parse succeeds when the message is empty array", response: `{"message":[],"type":"invalid_request_error","param":null,"code":null}`, hasError: false, - checkFunc: func(t *testing.T, apiErr APIError) { + checkFunc: func(t *testing.T, apiErr openai.APIError) { assertAPIErrorMessage(t, apiErr, "") }, }, @@ -54,7 +54,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { name: "parse succeeds when the message is null", response: `{"message":null,"type":"invalid_request_error","param":null,"code":null}`, hasError: false, - checkFunc: func(t *testing.T, apiErr APIError) { + checkFunc: func(t *testing.T, apiErr openai.APIError) { assertAPIErrorMessage(t, apiErr, "") }, }, @@ -89,23 +89,23 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { } }`, hasError: false, - checkFunc: func(t *testing.T, apiErr APIError) { - assertAPIErrorInnerError(t, apiErr, &InnerError{ + checkFunc: func(t *testing.T, apiErr openai.APIError) { + assertAPIErrorInnerError(t, apiErr, &openai.InnerError{ Code: "ResponsibleAIPolicyViolation", - ContentFilterResults: ContentFilterResults{ - Hate: Hate{ + ContentFilterResults: openai.ContentFilterResults{ + Hate: openai.Hate{ Filtered: false, Severity: "safe", }, - SelfHarm: SelfHarm{ + SelfHarm: openai.SelfHarm{ Filtered: false, Severity: "safe", }, - Sexual: Sexual{ + Sexual: openai.Sexual{ Filtered: true, Severity: "medium", }, - Violence: Violence{ + Violence: openai.Violence{ Filtered: false, Severity: "safe", }, @@ -117,16 +117,16 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { name: "parse succeeds when the innerError is empty (Azure Openai)", response: `{"message": "","type": null,"param": "","code": "","status": 0,"innererror": {}}`, hasError: false, - checkFunc: func(t *testing.T, apiErr APIError) { - assertAPIErrorInnerError(t, apiErr, &InnerError{}) + checkFunc: func(t *testing.T, apiErr openai.APIError) { + assertAPIErrorInnerError(t, apiErr, &openai.InnerError{}) }, }, { name: "parse succeeds when the innerError is not InnerError struct (Azure Openai)", response: `{"message": "","type": null,"param": "","code": "","status": 0,"innererror": "test"}`, hasError: true, - checkFunc: func(t *testing.T, apiErr APIError) { - assertAPIErrorInnerError(t, apiErr, &InnerError{}) + checkFunc: func(t *testing.T, apiErr openai.APIError) { + assertAPIErrorInnerError(t, apiErr, &openai.InnerError{}) }, }, { @@ -159,7 +159,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { name: "parse succeeds when the code is int", response: `{"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`, hasError: false, - checkFunc: func(t *testing.T, apiErr APIError) { + checkFunc: func(t *testing.T, apiErr openai.APIError) { assertAPIErrorCode(t, apiErr, 418) }, }, @@ -167,7 +167,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { name: "parse succeeds when the code is string", response: `{"code":"teapot","message":"I'm a teapot","param":"prompt","type":"teapot_error"}`, hasError: false, - checkFunc: func(t *testing.T, apiErr APIError) { + checkFunc: func(t *testing.T, apiErr openai.APIError) { assertAPIErrorCode(t, apiErr, "teapot") }, }, @@ -175,7 +175,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { name: "parse succeeds when the code is not exists", response: `{"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`, hasError: false, - checkFunc: func(t *testing.T, apiErr APIError) { + checkFunc: func(t *testing.T, apiErr openai.APIError) { assertAPIErrorCode(t, apiErr, nil) }, }, @@ -196,7 +196,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { name: "parse failed when the response is invalid json", response: `--- {"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`, hasError: true, - checkFunc: func(t *testing.T, apiErr APIError) { + checkFunc: func(t *testing.T, apiErr openai.APIError) { assertAPIErrorCode(t, apiErr, nil) assertAPIErrorMessage(t, apiErr, "") assertAPIErrorParam(t, apiErr, nil) @@ -206,7 +206,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - var apiErr APIError + var apiErr openai.APIError err := apiErr.UnmarshalJSON([]byte(tc.response)) if (err != nil) != tc.hasError { t.Errorf("Unexpected error: %v", err) @@ -218,19 +218,19 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { } } -func assertAPIErrorMessage(t *testing.T, apiErr APIError, expected string) { +func assertAPIErrorMessage(t *testing.T, apiErr openai.APIError, expected string) { if apiErr.Message != expected { t.Errorf("Unexpected APIError message: %v; expected: %s", apiErr, expected) } } -func assertAPIErrorInnerError(t *testing.T, apiErr APIError, expected interface{}) { +func assertAPIErrorInnerError(t *testing.T, apiErr openai.APIError, expected interface{}) { if !reflect.DeepEqual(apiErr.InnerError, expected) { t.Errorf("Unexpected APIError InnerError: %v; expected: %v; ", apiErr, expected) } } -func assertAPIErrorCode(t *testing.T, apiErr APIError, expected interface{}) { +func assertAPIErrorCode(t *testing.T, apiErr openai.APIError, expected interface{}) { switch v := apiErr.Code.(type) { case int: if v != expected { @@ -246,25 +246,25 @@ func assertAPIErrorCode(t *testing.T, apiErr APIError, expected interface{}) { } } -func assertAPIErrorParam(t *testing.T, apiErr APIError, expected *string) { +func assertAPIErrorParam(t *testing.T, apiErr openai.APIError, expected *string) { if apiErr.Param != expected { t.Errorf("Unexpected APIError param: %v; expected: %s", apiErr, *expected) } } -func assertAPIErrorType(t *testing.T, apiErr APIError, typ string) { +func assertAPIErrorType(t *testing.T, apiErr openai.APIError, typ string) { if apiErr.Type != typ { t.Errorf("Unexpected API type: %v; expected: %s", apiErr, typ) } } func TestRequestError(t *testing.T) { - var err error = &RequestError{ + var err error = &openai.RequestError{ HTTPStatusCode: http.StatusTeapot, Err: errors.New("i am a teapot"), } - var reqErr *RequestError + var reqErr *openai.RequestError if !errors.As(err, &reqErr) { t.Fatalf("Error is not a RequestError: %+v", err) } diff --git a/example_test.go b/example_test.go index b5dfafea9..de67c57cd 100644 --- a/example_test.go +++ b/example_test.go @@ -28,7 +28,6 @@ func Example() { }, }, ) - if err != nil { fmt.Printf("ChatCompletion error: %v\n", err) return @@ -319,7 +318,6 @@ func ExampleDefaultAzureConfig() { }, }, ) - if err != nil { fmt.Printf("ChatCompletion error: %v\n", err) return diff --git a/files_api_test.go b/files_api_test.go index 1cbc72894..330b88159 100644 --- a/files_api_test.go +++ b/files_api_test.go @@ -12,7 +12,7 @@ import ( "testing" "time" - . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test/checks" ) @@ -20,7 +20,7 @@ func TestFileUpload(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() server.RegisterHandler("/v1/files", handleCreateFile) - req := FileRequest{ + req := openai.FileRequest{ FileName: "test.go", FilePath: "client.go", Purpose: "fine-tune", @@ -57,7 +57,7 @@ func handleCreateFile(w http.ResponseWriter, r *http.Request) { } defer file.Close() - var fileReq = File{ + fileReq := openai.File{ Bytes: int(header.Size), ID: strconv.Itoa(int(time.Now().Unix())), FileName: header.Filename, @@ -82,7 +82,7 @@ func TestListFile(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() server.RegisterHandler("/v1/files", func(w http.ResponseWriter, r *http.Request) { - resBytes, _ := json.Marshal(FilesList{}) + resBytes, _ := json.Marshal(openai.FilesList{}) fmt.Fprintln(w, string(resBytes)) }) _, err := client.ListFiles(context.Background()) @@ -93,7 +93,7 @@ func TestGetFile(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() server.RegisterHandler("/v1/files/deadbeef", func(w http.ResponseWriter, r *http.Request) { - resBytes, _ := json.Marshal(File{}) + resBytes, _ := json.Marshal(openai.File{}) fmt.Fprintln(w, string(resBytes)) }) _, err := client.GetFile(context.Background(), "deadbeef") @@ -148,7 +148,7 @@ func TestGetFileContentReturnError(t *testing.T) { t.Fatal("Did not return error") } - apiErr := &APIError{} + apiErr := &openai.APIError{} if !errors.As(err, &apiErr) { t.Fatalf("Did not return APIError: %+v\n", apiErr) } diff --git a/files_test.go b/files_test.go index df6eaef7b..f588b30dc 100644 --- a/files_test.go +++ b/files_test.go @@ -1,14 +1,14 @@ package openai //nolint:testpackage // testing private field import ( - utils "github.com/sashabaranov/go-openai/internal" - "github.com/sashabaranov/go-openai/internal/test/checks" - "context" "fmt" "io" "os" "testing" + + utils "github.com/sashabaranov/go-openai/internal" + "github.com/sashabaranov/go-openai/internal/test/checks" ) func TestFileUploadWithFailingFormBuilder(t *testing.T) { diff --git a/fine_tunes.go b/fine_tunes.go index ca840781c..46f89f165 100644 --- a/fine_tunes.go +++ b/fine_tunes.go @@ -115,6 +115,7 @@ func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (r // This API will be officially deprecated on January 4th, 2024. // OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) { + //nolint:goconst // Decreases readability req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel")) if err != nil { return diff --git a/fine_tunes_test.go b/fine_tunes_test.go index 67f681d97..2ab6817f7 100644 --- a/fine_tunes_test.go +++ b/fine_tunes_test.go @@ -1,14 +1,14 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" - "context" "encoding/json" "fmt" "net/http" "testing" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" ) const testFineTuneID = "fine-tune-id" @@ -22,9 +22,9 @@ func TestFineTunes(t *testing.T) { func(w http.ResponseWriter, r *http.Request) { var resBytes []byte if r.Method == http.MethodGet { - resBytes, _ = json.Marshal(FineTuneList{}) + resBytes, _ = json.Marshal(openai.FineTuneList{}) } else { - resBytes, _ = json.Marshal(FineTune{}) + resBytes, _ = json.Marshal(openai.FineTune{}) } fmt.Fprintln(w, string(resBytes)) }, @@ -32,8 +32,8 @@ func TestFineTunes(t *testing.T) { server.RegisterHandler( "/v1/fine-tunes/"+testFineTuneID+"/cancel", - func(w http.ResponseWriter, r *http.Request) { - resBytes, _ := json.Marshal(FineTune{}) + func(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(openai.FineTune{}) fmt.Fprintln(w, string(resBytes)) }, ) @@ -43,9 +43,9 @@ func TestFineTunes(t *testing.T) { func(w http.ResponseWriter, r *http.Request) { var resBytes []byte if r.Method == http.MethodDelete { - resBytes, _ = json.Marshal(FineTuneDeleteResponse{}) + resBytes, _ = json.Marshal(openai.FineTuneDeleteResponse{}) } else { - resBytes, _ = json.Marshal(FineTune{}) + resBytes, _ = json.Marshal(openai.FineTune{}) } fmt.Fprintln(w, string(resBytes)) }, @@ -53,8 +53,8 @@ func TestFineTunes(t *testing.T) { server.RegisterHandler( "/v1/fine-tunes/"+testFineTuneID+"/events", - func(w http.ResponseWriter, r *http.Request) { - resBytes, _ := json.Marshal(FineTuneEventList{}) + func(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(openai.FineTuneEventList{}) fmt.Fprintln(w, string(resBytes)) }, ) @@ -64,7 +64,7 @@ func TestFineTunes(t *testing.T) { _, err := client.ListFineTunes(ctx) checks.NoError(t, err, "ListFineTunes error") - _, err = client.CreateFineTune(ctx, FineTuneRequest{}) + _, err = client.CreateFineTune(ctx, openai.FineTuneRequest{}) checks.NoError(t, err, "CreateFineTune error") _, err = client.CancelFineTune(ctx, testFineTuneID) diff --git a/fine_tuning_job_test.go b/fine_tuning_job_test.go index f6d41c33d..c892ef775 100644 --- a/fine_tuning_job_test.go +++ b/fine_tuning_job_test.go @@ -2,14 +2,13 @@ package openai_test import ( "context" - - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" - "encoding/json" "fmt" "net/http" "testing" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" ) const testFineTuninigJobID = "fine-tuning-job-id" @@ -20,8 +19,8 @@ func TestFineTuningJob(t *testing.T) { defer teardown() server.RegisterHandler( "/v1/fine_tuning/jobs", - func(w http.ResponseWriter, r *http.Request) { - resBytes, _ := json.Marshal(FineTuningJob{ + func(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(openai.FineTuningJob{ Object: "fine_tuning.job", ID: testFineTuninigJobID, Model: "davinci-002", @@ -33,7 +32,7 @@ func TestFineTuningJob(t *testing.T) { Status: "succeeded", ValidationFile: "", TrainingFile: "file-abc123", - Hyperparameters: Hyperparameters{ + Hyperparameters: openai.Hyperparameters{ Epochs: "auto", }, TrainedTokens: 5768, @@ -44,32 +43,32 @@ func TestFineTuningJob(t *testing.T) { server.RegisterHandler( "/fine_tuning/jobs/"+testFineTuninigJobID+"/cancel", - func(w http.ResponseWriter, r *http.Request) { - resBytes, _ := json.Marshal(FineTuningJob{}) + func(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(openai.FineTuningJob{}) fmt.Fprintln(w, string(resBytes)) }, ) server.RegisterHandler( "/v1/fine_tuning/jobs/"+testFineTuninigJobID, - func(w http.ResponseWriter, r *http.Request) { + func(w http.ResponseWriter, _ *http.Request) { var resBytes []byte - resBytes, _ = json.Marshal(FineTuningJob{}) + resBytes, _ = json.Marshal(openai.FineTuningJob{}) fmt.Fprintln(w, string(resBytes)) }, ) server.RegisterHandler( "/v1/fine_tuning/jobs/"+testFineTuninigJobID+"/events", - func(w http.ResponseWriter, r *http.Request) { - resBytes, _ := json.Marshal(FineTuningJobEventList{}) + func(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(openai.FineTuningJobEventList{}) fmt.Fprintln(w, string(resBytes)) }, ) ctx := context.Background() - _, err := client.CreateFineTuningJob(ctx, FineTuningJobRequest{}) + _, err := client.CreateFineTuningJob(ctx, openai.FineTuningJobRequest{}) checks.NoError(t, err, "CreateFineTuningJob error") _, err = client.CancelFineTuningJob(ctx, testFineTuninigJobID) @@ -84,22 +83,22 @@ func TestFineTuningJob(t *testing.T) { _, err = client.ListFineTuningJobEvents( ctx, testFineTuninigJobID, - ListFineTuningJobEventsWithAfter("last-event-id"), + openai.ListFineTuningJobEventsWithAfter("last-event-id"), ) checks.NoError(t, err, "ListFineTuningJobEvents error") _, err = client.ListFineTuningJobEvents( ctx, testFineTuninigJobID, - ListFineTuningJobEventsWithLimit(10), + openai.ListFineTuningJobEventsWithLimit(10), ) checks.NoError(t, err, "ListFineTuningJobEvents error") _, err = client.ListFineTuningJobEvents( ctx, testFineTuninigJobID, - ListFineTuningJobEventsWithAfter("last-event-id"), - ListFineTuningJobEventsWithLimit(10), + openai.ListFineTuningJobEventsWithAfter("last-event-id"), + openai.ListFineTuningJobEventsWithLimit(10), ) checks.NoError(t, err, "ListFineTuningJobEvents error") } diff --git a/image_api_test.go b/image_api_test.go index b472eb04a..422f831fe 100644 --- a/image_api_test.go +++ b/image_api_test.go @@ -1,9 +1,6 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" - "context" "encoding/json" "fmt" @@ -12,13 +9,16 @@ import ( "os" "testing" "time" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" ) func TestImages(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() server.RegisterHandler("/v1/images/generations", handleImageEndpoint) - _, err := client.CreateImage(context.Background(), ImageRequest{ + _, err := client.CreateImage(context.Background(), openai.ImageRequest{ Prompt: "Lorem ipsum", }) checks.NoError(t, err, "CreateImage error") @@ -33,20 +33,20 @@ func handleImageEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } - var imageReq ImageRequest + var imageReq openai.ImageRequest if imageReq, err = getImageBody(r); err != nil { http.Error(w, "could not read request", http.StatusInternalServerError) return } - res := ImageResponse{ + res := openai.ImageResponse{ Created: time.Now().Unix(), } for i := 0; i < imageReq.N; i++ { - imageData := ImageResponseDataInner{} + imageData := openai.ImageResponseDataInner{} switch imageReq.ResponseFormat { - case CreateImageResponseFormatURL, "": + case openai.CreateImageResponseFormatURL, "": imageData.URL = "https://example.com/image.png" - case CreateImageResponseFormatB64JSON: + case openai.CreateImageResponseFormatB64JSON: // This decodes to "{}" in base64. imageData.B64JSON = "e30K" default: @@ -60,16 +60,16 @@ func handleImageEndpoint(w http.ResponseWriter, r *http.Request) { } // getImageBody Returns the body of the request to create a image. -func getImageBody(r *http.Request) (ImageRequest, error) { - image := ImageRequest{} +func getImageBody(r *http.Request) (openai.ImageRequest, error) { + image := openai.ImageRequest{} // read the request body reqBody, err := io.ReadAll(r.Body) if err != nil { - return ImageRequest{}, err + return openai.ImageRequest{}, err } err = json.Unmarshal(reqBody, &image) if err != nil { - return ImageRequest{}, err + return openai.ImageRequest{}, err } return image, nil } @@ -98,13 +98,13 @@ func TestImageEdit(t *testing.T) { os.Remove("image.png") }() - _, err = client.CreateEditImage(context.Background(), ImageEditRequest{ + _, err = client.CreateEditImage(context.Background(), openai.ImageEditRequest{ Image: origin, Mask: mask, Prompt: "There is a turtle in the pool", N: 3, - Size: CreateImageSize1024x1024, - ResponseFormat: CreateImageResponseFormatURL, + Size: openai.CreateImageSize1024x1024, + ResponseFormat: openai.CreateImageResponseFormatURL, }) checks.NoError(t, err, "CreateImage error") } @@ -125,12 +125,12 @@ func TestImageEditWithoutMask(t *testing.T) { os.Remove("image.png") }() - _, err = client.CreateEditImage(context.Background(), ImageEditRequest{ + _, err = client.CreateEditImage(context.Background(), openai.ImageEditRequest{ Image: origin, Prompt: "There is a turtle in the pool", N: 3, - Size: CreateImageSize1024x1024, - ResponseFormat: CreateImageResponseFormatURL, + Size: openai.CreateImageSize1024x1024, + ResponseFormat: openai.CreateImageResponseFormatURL, }) checks.NoError(t, err, "CreateImage error") } @@ -144,9 +144,9 @@ func handleEditImageEndpoint(w http.ResponseWriter, r *http.Request) { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } - responses := ImageResponse{ + responses := openai.ImageResponse{ Created: time.Now().Unix(), - Data: []ImageResponseDataInner{ + Data: []openai.ImageResponseDataInner{ { URL: "test-url1", B64JSON: "", @@ -182,11 +182,11 @@ func TestImageVariation(t *testing.T) { os.Remove("image.png") }() - _, err = client.CreateVariImage(context.Background(), ImageVariRequest{ + _, err = client.CreateVariImage(context.Background(), openai.ImageVariRequest{ Image: origin, N: 3, - Size: CreateImageSize1024x1024, - ResponseFormat: CreateImageResponseFormatURL, + Size: openai.CreateImageSize1024x1024, + ResponseFormat: openai.CreateImageResponseFormatURL, }) checks.NoError(t, err, "CreateImage error") } @@ -200,9 +200,9 @@ func handleVariateImageEndpoint(w http.ResponseWriter, r *http.Request) { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } - responses := ImageResponse{ + responses := openai.ImageResponse{ Created: time.Now().Unix(), - Data: []ImageResponseDataInner{ + Data: []openai.ImageResponseDataInner{ { URL: "test-url1", B64JSON: "", diff --git a/jsonschema/json_test.go b/jsonschema/json_test.go index c8d0c1d9e..744706082 100644 --- a/jsonschema/json_test.go +++ b/jsonschema/json_test.go @@ -5,28 +5,28 @@ import ( "reflect" "testing" - . "github.com/sashabaranov/go-openai/jsonschema" + "github.com/sashabaranov/go-openai/jsonschema" ) func TestDefinition_MarshalJSON(t *testing.T) { tests := []struct { name string - def Definition + def jsonschema.Definition want string }{ { name: "Test with empty Definition", - def: Definition{}, + def: jsonschema.Definition{}, want: `{"properties":{}}`, }, { name: "Test with Definition properties set", - def: Definition{ - Type: String, + def: jsonschema.Definition{ + Type: jsonschema.String, Description: "A string type", - Properties: map[string]Definition{ + Properties: map[string]jsonschema.Definition{ "name": { - Type: String, + Type: jsonschema.String, }, }, }, @@ -43,17 +43,17 @@ func TestDefinition_MarshalJSON(t *testing.T) { }, { name: "Test with nested Definition properties", - def: Definition{ - Type: Object, - Properties: map[string]Definition{ + def: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ "user": { - Type: Object, - Properties: map[string]Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ "name": { - Type: String, + Type: jsonschema.String, }, "age": { - Type: Integer, + Type: jsonschema.Integer, }, }, }, @@ -80,26 +80,26 @@ func TestDefinition_MarshalJSON(t *testing.T) { }, { name: "Test with complex nested Definition", - def: Definition{ - Type: Object, - Properties: map[string]Definition{ + def: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ "user": { - Type: Object, - Properties: map[string]Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ "name": { - Type: String, + Type: jsonschema.String, }, "age": { - Type: Integer, + Type: jsonschema.Integer, }, "address": { - Type: Object, - Properties: map[string]Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ "city": { - Type: String, + Type: jsonschema.String, }, "country": { - Type: String, + Type: jsonschema.String, }, }, }, @@ -141,14 +141,14 @@ func TestDefinition_MarshalJSON(t *testing.T) { }, { name: "Test with Array type Definition", - def: Definition{ - Type: Array, - Items: &Definition{ - Type: String, + def: jsonschema.Definition{ + Type: jsonschema.Array, + Items: &jsonschema.Definition{ + Type: jsonschema.String, }, - Properties: map[string]Definition{ + Properties: map[string]jsonschema.Definition{ "name": { - Type: String, + Type: jsonschema.String, }, }, }, diff --git a/models_test.go b/models_test.go index 9ff73042a..4a4c759dc 100644 --- a/models_test.go +++ b/models_test.go @@ -1,17 +1,16 @@ package openai_test import ( - "os" - "time" - - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" - "context" "encoding/json" "fmt" "net/http" + "os" "testing" + "time" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" ) const testFineTuneModelID = "fine-tune-model-id" @@ -35,7 +34,7 @@ func TestAzureListModels(t *testing.T) { // handleListModelsEndpoint Handles the list models endpoint by the test server. func handleListModelsEndpoint(w http.ResponseWriter, _ *http.Request) { - resBytes, _ := json.Marshal(ModelsList{}) + resBytes, _ := json.Marshal(openai.ModelsList{}) fmt.Fprintln(w, string(resBytes)) } @@ -58,7 +57,7 @@ func TestAzureGetModel(t *testing.T) { // handleGetModelsEndpoint Handles the get model endpoint by the test server. func handleGetModelEndpoint(w http.ResponseWriter, _ *http.Request) { - resBytes, _ := json.Marshal(Model{}) + resBytes, _ := json.Marshal(openai.Model{}) fmt.Fprintln(w, string(resBytes)) } @@ -90,6 +89,6 @@ func TestDeleteFineTuneModel(t *testing.T) { } func handleDeleteFineTuneModelEndpoint(w http.ResponseWriter, _ *http.Request) { - resBytes, _ := json.Marshal(FineTuneModelDeleteResponse{}) + resBytes, _ := json.Marshal(openai.FineTuneModelDeleteResponse{}) fmt.Fprintln(w, string(resBytes)) } diff --git a/moderation_test.go b/moderation_test.go index 68f9565e1..059f0d1c7 100644 --- a/moderation_test.go +++ b/moderation_test.go @@ -1,9 +1,6 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" - "context" "encoding/json" "fmt" @@ -13,6 +10,9 @@ import ( "strings" "testing" "time" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" ) // TestModeration Tests the moderations endpoint of the API using the mocked server. @@ -20,8 +20,8 @@ func TestModerations(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() server.RegisterHandler("/v1/moderations", handleModerationEndpoint) - _, err := client.Moderations(context.Background(), ModerationRequest{ - Model: ModerationTextStable, + _, err := client.Moderations(context.Background(), openai.ModerationRequest{ + Model: openai.ModerationTextStable, Input: "I want to kill them.", }) checks.NoError(t, err, "Moderation error") @@ -34,16 +34,16 @@ func TestModerationsWithDifferentModelOptions(t *testing.T) { expect error } modelOptions = append(modelOptions, - getModerationModelTestOption(GPT3Dot5Turbo, ErrModerationInvalidModel), - getModerationModelTestOption(ModerationTextStable, nil), - getModerationModelTestOption(ModerationTextLatest, nil), + getModerationModelTestOption(openai.GPT3Dot5Turbo, openai.ErrModerationInvalidModel), + getModerationModelTestOption(openai.ModerationTextStable, nil), + getModerationModelTestOption(openai.ModerationTextLatest, nil), getModerationModelTestOption("", nil), ) client, server, teardown := setupOpenAITestServer() defer teardown() server.RegisterHandler("/v1/moderations", handleModerationEndpoint) for _, modelTest := range modelOptions { - _, err := client.Moderations(context.Background(), ModerationRequest{ + _, err := client.Moderations(context.Background(), openai.ModerationRequest{ Model: modelTest.model, Input: "I want to kill them.", }) @@ -71,32 +71,32 @@ func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } - var moderationReq ModerationRequest + var moderationReq openai.ModerationRequest if moderationReq, err = getModerationBody(r); err != nil { http.Error(w, "could not read request", http.StatusInternalServerError) return } - resCat := ResultCategories{} - resCatScore := ResultCategoryScores{} + resCat := openai.ResultCategories{} + resCatScore := openai.ResultCategoryScores{} switch { case strings.Contains(moderationReq.Input, "kill"): - resCat = ResultCategories{Violence: true} - resCatScore = ResultCategoryScores{Violence: 1} + resCat = openai.ResultCategories{Violence: true} + resCatScore = openai.ResultCategoryScores{Violence: 1} case strings.Contains(moderationReq.Input, "hate"): - resCat = ResultCategories{Hate: true} - resCatScore = ResultCategoryScores{Hate: 1} + resCat = openai.ResultCategories{Hate: true} + resCatScore = openai.ResultCategoryScores{Hate: 1} case strings.Contains(moderationReq.Input, "suicide"): - resCat = ResultCategories{SelfHarm: true} - resCatScore = ResultCategoryScores{SelfHarm: 1} + resCat = openai.ResultCategories{SelfHarm: true} + resCatScore = openai.ResultCategoryScores{SelfHarm: 1} case strings.Contains(moderationReq.Input, "porn"): - resCat = ResultCategories{Sexual: true} - resCatScore = ResultCategoryScores{Sexual: 1} + resCat = openai.ResultCategories{Sexual: true} + resCatScore = openai.ResultCategoryScores{Sexual: 1} } - result := Result{Categories: resCat, CategoryScores: resCatScore, Flagged: true} + result := openai.Result{Categories: resCat, CategoryScores: resCatScore, Flagged: true} - res := ModerationResponse{ + res := openai.ModerationResponse{ ID: strconv.Itoa(int(time.Now().Unix())), Model: moderationReq.Model, } @@ -107,16 +107,16 @@ func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) { } // getModerationBody Returns the body of the request to do a moderation. -func getModerationBody(r *http.Request) (ModerationRequest, error) { - moderation := ModerationRequest{} +func getModerationBody(r *http.Request) (openai.ModerationRequest, error) { + moderation := openai.ModerationRequest{} // read the request body reqBody, err := io.ReadAll(r.Body) if err != nil { - return ModerationRequest{}, err + return openai.ModerationRequest{}, err } err = json.Unmarshal(reqBody, &moderation) if err != nil { - return ModerationRequest{}, err + return openai.ModerationRequest{}, err } return moderation, nil } diff --git a/openai_test.go b/openai_test.go index 4fc41ecc0..729d8880c 100644 --- a/openai_test.go +++ b/openai_test.go @@ -1,29 +1,29 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test" ) -func setupOpenAITestServer() (client *Client, server *test.ServerTest, teardown func()) { +func setupOpenAITestServer() (client *openai.Client, server *test.ServerTest, teardown func()) { server = test.NewTestServer() ts := server.OpenAITestServer() ts.Start() teardown = ts.Close - config := DefaultConfig(test.GetTestToken()) + config := openai.DefaultConfig(test.GetTestToken()) config.BaseURL = ts.URL + "/v1" - client = NewClientWithConfig(config) + client = openai.NewClientWithConfig(config) return } -func setupAzureTestServer() (client *Client, server *test.ServerTest, teardown func()) { +func setupAzureTestServer() (client *openai.Client, server *test.ServerTest, teardown func()) { server = test.NewTestServer() ts := server.OpenAITestServer() ts.Start() teardown = ts.Close - config := DefaultAzureConfig(test.GetTestToken(), "https://dummylab.openai.azure.com/") + config := openai.DefaultAzureConfig(test.GetTestToken(), "https://dummylab.openai.azure.com/") config.BaseURL = ts.URL - client = NewClientWithConfig(config) + client = openai.NewClientWithConfig(config) return } diff --git a/stream_test.go b/stream_test.go index f3f8f85cd..35c52ae3b 100644 --- a/stream_test.go +++ b/stream_test.go @@ -10,23 +10,23 @@ import ( "testing" "time" - . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test/checks" ) func TestCompletionsStreamWrongModel(t *testing.T) { - config := DefaultConfig("whatever") + config := openai.DefaultConfig("whatever") config.BaseURL = "http://localhost/v1" - client := NewClientWithConfig(config) + client := openai.NewClientWithConfig(config) _, err := client.CreateCompletionStream( context.Background(), - CompletionRequest{ + openai.CompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, + Model: openai.GPT3Dot5Turbo, }, ) - if !errors.Is(err, ErrCompletionUnsupportedModel) { + if !errors.Is(err, openai.ErrCompletionUnsupportedModel) { t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel, but returned: %v", err) } } @@ -56,7 +56,7 @@ func TestCreateCompletionStream(t *testing.T) { checks.NoError(t, err, "Write error") }) - stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{ + stream, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{ Prompt: "Ex falso quodlibet", Model: "text-davinci-002", MaxTokens: 10, @@ -65,20 +65,20 @@ func TestCreateCompletionStream(t *testing.T) { checks.NoError(t, err, "CreateCompletionStream returned error") defer stream.Close() - expectedResponses := []CompletionResponse{ + expectedResponses := []openai.CompletionResponse{ { ID: "1", Object: "completion", Created: 1598069254, Model: "text-davinci-002", - Choices: []CompletionChoice{{Text: "response1", FinishReason: "max_tokens"}}, + Choices: []openai.CompletionChoice{{Text: "response1", FinishReason: "max_tokens"}}, }, { ID: "2", Object: "completion", Created: 1598069255, Model: "text-davinci-002", - Choices: []CompletionChoice{{Text: "response2", FinishReason: "max_tokens"}}, + Choices: []openai.CompletionChoice{{Text: "response2", FinishReason: "max_tokens"}}, }, } @@ -129,9 +129,9 @@ func TestCreateCompletionStreamError(t *testing.T) { checks.NoError(t, err, "Write error") }) - stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{ + stream, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{ MaxTokens: 5, - Model: GPT3TextDavinci003, + Model: openai.GPT3TextDavinci003, Prompt: "Hello!", Stream: true, }) @@ -141,7 +141,7 @@ func TestCreateCompletionStreamError(t *testing.T) { _, streamErr := stream.Recv() checks.HasError(t, streamErr, "stream.Recv() did not return error") - var apiErr *APIError + var apiErr *openai.APIError if !errors.As(streamErr, &apiErr) { t.Errorf("stream.Recv() did not return APIError") } @@ -166,10 +166,10 @@ func TestCreateCompletionStreamRateLimitError(t *testing.T) { checks.NoError(t, err, "Write error") }) - var apiErr *APIError - _, err := client.CreateCompletionStream(context.Background(), CompletionRequest{ + var apiErr *openai.APIError + _, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{ MaxTokens: 5, - Model: GPT3Ada, + Model: openai.GPT3Ada, Prompt: "Hello!", Stream: true, }) @@ -209,7 +209,7 @@ func TestCreateCompletionStreamTooManyEmptyStreamMessagesError(t *testing.T) { checks.NoError(t, err, "Write error") }) - stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{ + stream, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{ Prompt: "Ex falso quodlibet", Model: "text-davinci-002", MaxTokens: 10, @@ -220,7 +220,7 @@ func TestCreateCompletionStreamTooManyEmptyStreamMessagesError(t *testing.T) { _, _ = stream.Recv() _, streamErr := stream.Recv() - if !errors.Is(streamErr, ErrTooManyEmptyStreamMessages) { + if !errors.Is(streamErr, openai.ErrTooManyEmptyStreamMessages) { t.Errorf("TestCreateCompletionStreamTooManyEmptyStreamMessagesError did not return ErrTooManyEmptyStreamMessages") } } @@ -244,7 +244,7 @@ func TestCreateCompletionStreamUnexpectedTerminatedError(t *testing.T) { checks.NoError(t, err, "Write error") }) - stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{ + stream, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{ Prompt: "Ex falso quodlibet", Model: "text-davinci-002", MaxTokens: 10, @@ -285,7 +285,7 @@ func TestCreateCompletionStreamBrokenJSONError(t *testing.T) { checks.NoError(t, err, "Write error") }) - stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{ + stream, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{ Prompt: "Ex falso quodlibet", Model: "text-davinci-002", MaxTokens: 10, @@ -312,7 +312,7 @@ func TestCreateCompletionStreamReturnTimeoutError(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, time.Nanosecond) defer cancel() - _, err := client.CreateCompletionStream(ctx, CompletionRequest{ + _, err := client.CreateCompletionStream(ctx, openai.CompletionRequest{ Prompt: "Ex falso quodlibet", Model: "text-davinci-002", MaxTokens: 10, @@ -327,7 +327,7 @@ func TestCreateCompletionStreamReturnTimeoutError(t *testing.T) { } // Helper funcs. -func compareResponses(r1, r2 CompletionResponse) bool { +func compareResponses(r1, r2 openai.CompletionResponse) bool { if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model { return false } @@ -342,7 +342,7 @@ func compareResponses(r1, r2 CompletionResponse) bool { return true } -func compareResponseChoices(c1, c2 CompletionChoice) bool { +func compareResponseChoices(c1, c2 openai.CompletionChoice) bool { if c1.Text != c2.Text || c1.FinishReason != c2.FinishReason { return false }