Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BREAKING_CHANGES] convert EmbeddingModel to string type #629

Merged
merged 1 commit into from
Jan 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 21 additions & 99 deletions embeddings.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,108 +13,30 @@ var ErrVectorLengthMismatch = errors.New("vector length mismatch")

// EmbeddingModel enumerates the models which can be used
// to generate Embedding vectors.
type EmbeddingModel int

// String implements the fmt.Stringer interface.
func (e EmbeddingModel) String() string {
return enumToString[e]
}

// MarshalText implements the encoding.TextMarshaler interface.
func (e EmbeddingModel) MarshalText() ([]byte, error) {
return []byte(e.String()), nil
}

// UnmarshalText implements the encoding.TextUnmarshaler interface.
// On unrecognized value, it sets |e| to Unknown.
func (e *EmbeddingModel) UnmarshalText(b []byte) error {
if val, ok := stringToEnum[(string(b))]; ok {
*e = val
return nil
}

*e = Unknown

return nil
}
type EmbeddingModel string

const (
Unknown EmbeddingModel = iota
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
AdaSimilarity
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
BabbageSimilarity
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
CurieSimilarity
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
DavinciSimilarity
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
AdaSearchDocument
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
AdaSearchQuery
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
BabbageSearchDocument
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
BabbageSearchQuery
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
CurieSearchDocument
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
CurieSearchQuery
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
DavinciSearchDocument
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
DavinciSearchQuery
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
AdaCodeSearchCode
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
AdaCodeSearchText
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
BabbageCodeSearchCode
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
BabbageCodeSearchText
AdaEmbeddingV2
// Deprecated: The following block will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
AdaSimilarity EmbeddingModel = "text-similarity-ada-001"
BabbageSimilarity EmbeddingModel = "text-similarity-babbage-001"
CurieSimilarity EmbeddingModel = "text-similarity-curie-001"
DavinciSimilarity EmbeddingModel = "text-similarity-davinci-001"
AdaSearchDocument EmbeddingModel = "text-search-ada-doc-001"
AdaSearchQuery EmbeddingModel = "text-search-ada-query-001"
BabbageSearchDocument EmbeddingModel = "text-search-babbage-doc-001"
BabbageSearchQuery EmbeddingModel = "text-search-babbage-query-001"
CurieSearchDocument EmbeddingModel = "text-search-curie-doc-001"
CurieSearchQuery EmbeddingModel = "text-search-curie-query-001"
DavinciSearchDocument EmbeddingModel = "text-search-davinci-doc-001"
DavinciSearchQuery EmbeddingModel = "text-search-davinci-query-001"
AdaCodeSearchCode EmbeddingModel = "code-search-ada-code-001"
AdaCodeSearchText EmbeddingModel = "code-search-ada-text-001"
BabbageCodeSearchCode EmbeddingModel = "code-search-babbage-code-001"
BabbageCodeSearchText EmbeddingModel = "code-search-babbage-text-001"

AdaEmbeddingV2 EmbeddingModel = "text-embedding-ada-002"
)

var enumToString = map[EmbeddingModel]string{
AdaSimilarity: "text-similarity-ada-001",
BabbageSimilarity: "text-similarity-babbage-001",
CurieSimilarity: "text-similarity-curie-001",
DavinciSimilarity: "text-similarity-davinci-001",
AdaSearchDocument: "text-search-ada-doc-001",
AdaSearchQuery: "text-search-ada-query-001",
BabbageSearchDocument: "text-search-babbage-doc-001",
BabbageSearchQuery: "text-search-babbage-query-001",
CurieSearchDocument: "text-search-curie-doc-001",
CurieSearchQuery: "text-search-curie-query-001",
DavinciSearchDocument: "text-search-davinci-doc-001",
DavinciSearchQuery: "text-search-davinci-query-001",
AdaCodeSearchCode: "code-search-ada-code-001",
AdaCodeSearchText: "code-search-ada-text-001",
BabbageCodeSearchCode: "code-search-babbage-code-001",
BabbageCodeSearchText: "code-search-babbage-text-001",
AdaEmbeddingV2: "text-embedding-ada-002",
}

var stringToEnum = map[string]EmbeddingModel{
"text-similarity-ada-001": AdaSimilarity,
"text-similarity-babbage-001": BabbageSimilarity,
"text-similarity-curie-001": CurieSimilarity,
"text-similarity-davinci-001": DavinciSimilarity,
"text-search-ada-doc-001": AdaSearchDocument,
"text-search-ada-query-001": AdaSearchQuery,
"text-search-babbage-doc-001": BabbageSearchDocument,
"text-search-babbage-query-001": BabbageSearchQuery,
"text-search-curie-doc-001": CurieSearchDocument,
"text-search-curie-query-001": CurieSearchQuery,
"text-search-davinci-doc-001": DavinciSearchDocument,
"text-search-davinci-query-001": DavinciSearchQuery,
"code-search-ada-code-001": AdaCodeSearchCode,
"code-search-ada-text-001": AdaCodeSearchText,
"code-search-babbage-code-001": BabbageCodeSearchCode,
"code-search-babbage-text-001": BabbageCodeSearchText,
"text-embedding-ada-002": AdaEmbeddingV2,
}

// Embedding is a special format of data representation that can be easily utilized by machine
// learning models and algorithms. The embedding is an information dense representation of the
// semantic meaning of a piece of text. Each embedding is a vector of floating point numbers,
Expand Down Expand Up @@ -306,7 +228,7 @@ func (c *Client) CreateEmbeddings(
conv EmbeddingRequestConverter,
) (res EmbeddingResponse, err error) {
baseReq := conv.Convert()
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", baseReq.Model.String()), withBody(baseReq))
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", baseReq.Model), withBody(baseReq))
if err != nil {
return
}
Expand Down
22 changes: 3 additions & 19 deletions embeddings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func TestEmbedding(t *testing.T) {
// the AdaSearchQuery type
marshaled, err := json.Marshal(embeddingReq)
checks.NoError(t, err, "Could not marshal embedding request")
if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) {
if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) {
t.Fatalf("Expected embedding request to contain model field")
}

Expand All @@ -61,7 +61,7 @@ func TestEmbedding(t *testing.T) {
}
marshaled, err = json.Marshal(embeddingReqStrings)
checks.NoError(t, err, "Could not marshal embedding request")
if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) {
if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) {
t.Fatalf("Expected embedding request to contain model field")
}

Expand All @@ -75,28 +75,12 @@ func TestEmbedding(t *testing.T) {
}
marshaled, err = json.Marshal(embeddingReqTokens)
checks.NoError(t, err, "Could not marshal embedding request")
if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) {
if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) {
t.Fatalf("Expected embedding request to contain model field")
}
}
}

func TestEmbeddingModel(t *testing.T) {
var em openai.EmbeddingModel
err := em.UnmarshalText([]byte("text-similarity-ada-001"))
checks.NoError(t, err, "Could not marshal embedding model")

if em != openai.AdaSimilarity {
t.Errorf("Model is not equal to AdaSimilarity")
}

err = em.UnmarshalText([]byte("some-non-existent-model"))
checks.NoError(t, err, "Could not marshal embedding model")
if em != openai.Unknown {
t.Errorf("Model is not equal to Unknown")
}
}

func TestEmbeddingEndpoint(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
Expand Down
Loading