From 5c7d88212f6e73fdac89723d42b9e3a1b113931c Mon Sep 17 00:00:00 2001 From: Jackson Stone Date: Wed, 5 Jul 2023 16:53:53 -0500 Subject: [PATCH] Allow embeddings requests to be tokens or strings (#417) * Allow raw tokens to be used as embedding input * fix linting issues (lines too long) * add endpoint test for embedding from tokens * remove redundant comments * fix comment to match new param name * change interface to any * Rename methods and implement convert for base req * add comments to CreateEmbeddings * update tests * shorten line length * rename parameter --- embeddings.go | 62 +++++++++++++++++++++++++++++++++++++++++----- embeddings_test.go | 38 ++++++++++++++++++++++++++++ 2 files changed, 94 insertions(+), 6 deletions(-) diff --git a/embeddings.go b/embeddings.go index ba327ce77..41af50b4b 100644 --- a/embeddings.go +++ b/embeddings.go @@ -113,10 +113,25 @@ type EmbeddingResponse struct { Usage Usage `json:"usage"` } -// EmbeddingRequest is the input to a Create embeddings request. +type EmbeddingRequestConverter interface { + // Needs to be of type EmbeddingRequestStrings or EmbeddingRequestTokens + Convert() EmbeddingRequest +} + type EmbeddingRequest struct { + Input any `json:"input"` + Model EmbeddingModel `json:"model"` + User string `json:"user"` +} + +func (r EmbeddingRequest) Convert() EmbeddingRequest { + return r +} + +// EmbeddingRequestStrings is the input to a create embeddings request with a slice of strings. +type EmbeddingRequestStrings struct { // Input is a slice of strings for which you want to generate an Embedding vector. - // Each input must not exceed 2048 tokens in length. + // Each input must not exceed 8192 tokens in length. // OpenAPI suggests replacing newlines (\n) in your input with a single space, as they // have observed inferior results when newlines are present. // E.g. 
@@ -129,15 +144,50 @@ type EmbeddingRequest struct { User string `json:"user"` } -// CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|. +func (r EmbeddingRequestStrings) Convert() EmbeddingRequest { + return EmbeddingRequest{ + Input: r.Input, + Model: r.Model, + User: r.User, + } +} + +type EmbeddingRequestTokens struct { + // Input is a slice of slices of ints ([][]int) for which you want to generate an Embedding vector. + // Each input must not exceed 8192 tokens in length. + // OpenAI suggests replacing newlines (\n) in your input with a single space, as they + // have observed inferior results when newlines are present. + // E.g. + // "The food was delicious and the waiter..." + Input [][]int `json:"input"` + // ID of the model to use. You can use the List models API to see all of your available models, + // or see our Model overview for descriptions of them. + Model EmbeddingModel `json:"model"` + // A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse. + User string `json:"user"` +} + +func (r EmbeddingRequestTokens) Convert() EmbeddingRequest { + return EmbeddingRequest{ + Input: r.Input, + Model: r.Model, + User: r.User, + } +} + +// CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |conv.Input|. // https://beta.openai.com/docs/api-reference/embeddings/create -func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) { - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", request.Model.String()), withBody(request)) +// +// conv should be of type EmbeddingRequestStrings for embedding strings or EmbeddingRequestTokens +// for embedding groups of text already converted to tokens. 
+func (c *Client) CreateEmbeddings(ctx context.Context, conv EmbeddingRequestConverter) (res EmbeddingResponse, err error) { //nolint:lll + baseReq := conv.Convert() + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", baseReq.Model.String()), withBody(baseReq)) if err != nil { return } - err = c.sendRequest(req, &resp) + err = c.sendRequest(req, &res) return } diff --git a/embeddings_test.go b/embeddings_test.go index d7892cd5d..47c4f5108 100644 --- a/embeddings_test.go +++ b/embeddings_test.go @@ -32,6 +32,7 @@ func TestEmbedding(t *testing.T) { BabbageCodeSearchText, } for _, model := range embeddedModels { + // test embedding request with strings (simple embedding request) embeddingReq := EmbeddingRequest{ Input: []string{ "The food was delicious and the waiter", @@ -46,6 +47,34 @@ func TestEmbedding(t *testing.T) { if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) { t.Fatalf("Expected embedding request to contain model field") } + + // test embedding request with strings + embeddingReqStrings := EmbeddingRequestStrings{ + Input: []string{ + "The food was delicious and the waiter", + "Other examples of embedding request", + }, + Model: model, + } + marshaled, err = json.Marshal(embeddingReqStrings) + checks.NoError(t, err, "Could not marshal embedding request") + if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) { + t.Fatalf("Expected embedding request to contain model field") + } + + // test embedding request with tokens + embeddingReqTokens := EmbeddingRequestTokens{ + Input: [][]int{ + {464, 2057, 373, 12625, 290, 262, 46612}, + {6395, 6096, 286, 11525, 12083, 2581}, + }, + Model: model, + } + marshaled, err = json.Marshal(embeddingReqTokens) + checks.NoError(t, err, "Could not marshal embedding request") + if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) { + t.Fatalf("Expected embedding request to contain model field") + } } } @@ -75,6 +104,15 @@ func TestEmbeddingEndpoint(t 
*testing.T) { fmt.Fprintln(w, string(resBytes)) }, ) + // test create embeddings with strings (simple embedding request) _, err := client.CreateEmbeddings(context.Background(), EmbeddingRequest{}) checks.NoError(t, err, "CreateEmbeddings error") + + // test create embeddings with strings + _, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestStrings{}) + checks.NoError(t, err, "CreateEmbeddings strings error") + + // test create embeddings with tokens + _, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestTokens{}) + checks.NoError(t, err, "CreateEmbeddings tokens error") }