
CompletionBatchingRequestSupport #220

Merged
merged 10 commits
Apr 3, 2023
Changes from all commits
18 changes: 15 additions & 3 deletions completion.go
@@ -7,8 +7,9 @@ import (
)

var (
-ErrCompletionUnsupportedModel   = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll
-ErrCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateCompletionStream") //nolint:lll
+ErrCompletionUnsupportedModel              = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll
+ErrCompletionStreamNotSupported            = errors.New("streaming is not supported with this method, please use CreateCompletionStream") //nolint:lll
+ErrCompletionRequestPromptTypeNotSupported = errors.New("the type of CompletionRequest.Prompt only supports string and []string") //nolint:lll
)

// GPT3 Defines the models provided by OpenAI to use when generating
@@ -77,10 +78,16 @@ func checkEndpointSupportsModel(endpoint, model string) bool {
return !disabledModelsForEndpoints[endpoint][model]
}

func checkPromptType(prompt any) bool {
_, isString := prompt.(string)
_, isStringSlice := prompt.([]string)
return isString || isStringSlice
}

// CompletionRequest represents a request structure for completion API.
type CompletionRequest struct {
Model string `json:"model"`
-Prompt string `json:"prompt,omitempty"`
+Prompt any `json:"prompt,omitempty"`
Suffix string `json:"suffix,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float32 `json:"temperature,omitempty"`
@@ -143,6 +150,11 @@ func (c *Client) CreateCompletion(
return
}

if !checkPromptType(request.Prompt) {
err = ErrCompletionRequestPromptTypeNotSupported
return
}

req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request)
if err != nil {
return
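
With Prompt widened to any and guarded by checkPromptType, CreateCompletion now accepts either a single string or a []string of batched prompts; any other type fails fast with ErrCompletionRequestPromptTypeNotSupported. A minimal caller-side sketch of the batched case, assuming the github.com/sashabaranov/go-openai import path and the GPT3TextDavinci003 model constant; this snippet is illustration only, not part of the diff:

package main

import (
	"context"
	"fmt"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	client := openai.NewClient("your-api-token") // placeholder token
	ctx := context.Background()

	// Prompt is now `any`: a []string sends one batched request,
	// while a plain string keeps the previous behaviour.
	resp, err := client.CreateCompletion(ctx, openai.CompletionRequest{
		Model:     openai.GPT3TextDavinci003,
		MaxTokens: 16,
		Prompt:    []string{"1 + 1 =", "2 + 2 ="},
	})
	if err != nil {
		fmt.Println(err)
		return
	}

	// With the default N, choices are expected in prompt order;
	// Choice.Index maps a choice back to its prompt.
	for _, choice := range resp.Choices {
		fmt.Printf("prompt #%d -> %q\n", choice.Index, choice.Text)
	}
}
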
4 changes: 2 additions & 2 deletions completion_test.go
@@ -98,14 +98,14 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
// generate a random string of length completionReq.Length
completionStr := strings.Repeat("a", completionReq.MaxTokens)
if completionReq.Echo {
-completionStr = completionReq.Prompt + completionStr
+completionStr = completionReq.Prompt.(string) + completionStr
}
res.Choices = append(res.Choices, CompletionChoice{
Text: completionStr,
Index: i,
})
}
-inputTokens := numTokens(completionReq.Prompt) * completionReq.N
+inputTokens := numTokens(completionReq.Prompt.(string)) * completionReq.N
completionTokens := completionReq.MaxTokens * completionReq.N
res.Usage = Usage{
PromptTokens: inputTokens,
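
The handler above only covers the string case via a type assertion, so a []string prompt would make it panic. A handler that needs to deal with both shapes now permitted by checkPromptType can branch with a type switch; a rough sketch of such a helper (illustrative only, not part of this PR):

// promptAsStrings normalizes a CompletionRequest.Prompt that may be either
// a string or a []string into a slice, mirroring the two types accepted by
// checkPromptType. Any other type yields nil.
func promptAsStrings(prompt any) []string {
	switch p := prompt.(type) {
	case string:
		return []string{p}
	case []string:
		return p
	default:
		return nil
	}
}
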
26 changes: 25 additions & 1 deletion request_builder_test.go
@@ -51,7 +51,7 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) {

ctx := context.Background()

-_, err = client.CreateCompletion(ctx, CompletionRequest{})
+_, err = client.CreateCompletion(ctx, CompletionRequest{Prompt: "testing"})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
@@ -146,3 +146,27 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
}

func TestReturnsRequestBuilderErrorsAddtion(t *testing.T) {
var err error
ts := test.NewTestServer().OpenAITestServer()
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
client.requestBuilder = &failingRequestBuilder{}

ctx := context.Background()

_, err = client.CreateCompletion(ctx, CompletionRequest{Prompt: 1})
if !errors.Is(err, ErrCompletionRequestPromptTypeNotSupported) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.CreateCompletionStream(ctx, CompletionRequest{Prompt: 1})
if !errors.Is(err, ErrCompletionRequestPromptTypeNotSupported) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
}
5 changes: 5 additions & 0 deletions stream.go
@@ -28,6 +28,11 @@ func (c *Client) CreateCompletionStream(
return
}

if !checkPromptType(request.Prompt) {
err = ErrCompletionRequestPromptTypeNotSupported
return
}

request.Stream = true
req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request)
if err != nil {
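
CreateCompletionStream now performs the same prompt-type validation before the request is built, so an unsupported prompt type is rejected locally and no stream is opened. A hedged sketch of the caller-side behaviour, reusing the client from the earlier example and assuming the "context", "errors", and "fmt" imports; not part of the diff:

func rejectBadPrompt(ctx context.Context, client *openai.Client) {
	// A prompt that is neither string nor []string is rejected before any
	// HTTP request is made, for streaming and non-streaming calls alike.
	_, err := client.CreateCompletionStream(ctx, openai.CompletionRequest{
		Model:  openai.GPT3TextDavinci003,
		Prompt: 42,
	})
	if errors.Is(err, openai.ErrCompletionRequestPromptTypeNotSupported) {
		fmt.Println("prompt must be a string or []string")
	}
}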