From a6b35c3ab5d23ecb601d2bbe89e5979e3463aa48 Mon Sep 17 00:00:00 2001 From: sashabaranov <677093+sashabaranov@users.noreply.github.com> Date: Sat, 18 Mar 2023 19:31:54 +0400 Subject: [PATCH] Check for `Stream` parameter usage (#174) * check for stream:true usage * lint --- chat.go | 11 ++++++++--- chat_test.go | 15 +++++++++++++++ completion.go | 8 +++++++- completion_test.go | 12 ++++++++++++ models_test.go | 2 +- request_builder_test.go | 4 ++-- 6 files changed, 45 insertions(+), 7 deletions(-) diff --git a/chat.go b/chat.go index 99edfe85f..0f56216fd 100644 --- a/chat.go +++ b/chat.go @@ -14,7 +14,8 @@ const ( ) var ( - ErrChatCompletionInvalidModel = errors.New("currently, only gpt-3.5-turbo and gpt-3.5-turbo-0301 are supported") + ErrChatCompletionInvalidModel = errors.New("currently, only gpt-3.5-turbo and gpt-3.5-turbo-0301 are supported") //nolint:lll + ErrChatCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateChatCompletionStream") //nolint:lll ) type ChatCompletionMessage struct { @@ -65,8 +66,12 @@ func (c *Client) CreateChatCompletion( ctx context.Context, request ChatCompletionRequest, ) (response ChatCompletionResponse, err error) { - model := request.Model - switch model { + if request.Stream { + err = ErrChatCompletionStreamNotSupported + return + } + + switch request.Model { case GPT3Dot5Turbo0301, GPT3Dot5Turbo, GPT4, GPT40314, GPT432K0314, GPT432K: default: err = ErrChatCompletionInvalidModel diff --git a/chat_test.go b/chat_test.go index 5c03ebf7b..8866ff2ae 100644 --- a/chat_test.go +++ b/chat_test.go @@ -38,6 +38,21 @@ func TestChatCompletionsWrongModel(t *testing.T) { } } +func TestChatCompletionsWithStream(t *testing.T) { + config := DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := NewClientWithConfig(config) + ctx := context.Background() + + req := ChatCompletionRequest{ + Stream: true, + } + _, err := client.CreateChatCompletion(ctx, req) + if !errors.Is(err, ErrChatCompletionStreamNotSupported) { + t.Fatalf("CreateChatCompletion didn't return ErrChatCompletionStreamNotSupported error") + } +} + // TestCompletions Tests the completions endpoint of the API using the mocked server. func TestChatCompletions(t *testing.T) { server := test.NewTestServer() diff --git a/completion.go b/completion.go index 66b486665..22211d39f 100644 --- a/completion.go +++ b/completion.go @@ -7,7 +7,8 @@ import ( ) var ( - ErrCompletionUnsupportedModel = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll + ErrCompletionUnsupportedModel = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll + ErrCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateCompletionStream") //nolint:lll ) // GPT3 Defines the models provided by OpenAI to use when generating @@ -99,6 +100,11 @@ func (c *Client) CreateCompletion( ctx context.Context, request CompletionRequest, ) (response CompletionResponse, err error) { + if request.Stream { + err = ErrCompletionStreamNotSupported + return + } + if request.Model == GPT3Dot5Turbo0301 || request.Model == GPT3Dot5Turbo { err = ErrCompletionUnsupportedModel return diff --git a/completion_test.go b/completion_test.go index 9868eb2bb..daa02e383 100644 --- a/completion_test.go +++ b/completion_test.go @@ -33,6 +33,18 @@ func TestCompletionsWrongModel(t *testing.T) { } } +func TestCompletionWithStream(t *testing.T) { + config := DefaultConfig("whatever") + client := NewClientWithConfig(config) + + ctx := context.Background() + req := CompletionRequest{Stream: true} + _, err := client.CreateCompletion(ctx, req) + if !errors.Is(err, ErrCompletionStreamNotSupported) { + t.Fatalf("CreateCompletion didn't return ErrCompletionStreamNotSupported") + } +} + // TestCompletions Tests the completions endpoint of the API using the mocked server. func TestCompletions(t *testing.T) { server := test.NewTestServer() diff --git a/models_test.go b/models_test.go index c96ece823..972a5fe64 100644 --- a/models_test.go +++ b/models_test.go @@ -33,7 +33,7 @@ func TestListModels(t *testing.T) { } // handleModelsEndpoint Handles the models endpoint by the test server. -func handleModelsEndpoint(w http.ResponseWriter, r *http.Request) { +func handleModelsEndpoint(w http.ResponseWriter, _ *http.Request) { resBytes, _ := json.Marshal(ModelsList{}) fmt.Fprintln(w, string(resBytes)) } diff --git a/request_builder_test.go b/request_builder_test.go index 533977a68..f0f99ee5b 100644 --- a/request_builder_test.go +++ b/request_builder_test.go @@ -19,11 +19,11 @@ type ( failingMarshaller struct{} ) -func (*failingMarshaller) marshal(value any) ([]byte, error) { +func (*failingMarshaller) marshal(_ any) ([]byte, error) { return []byte{}, errTestMarshallerFailed } -func (*failingRequestBuilder) build(ctx context.Context, method, url string, requset any) (*http.Request, error) { +func (*failingRequestBuilder) build(_ context.Context, _, _ string, _ any) (*http.Request, error) { return nil, errTestRequestBuilderFailed }