diff --git a/embeddings_test.go b/embeddings_test.go index 0259cead0..252f7a5a0 100644 --- a/embeddings_test.go +++ b/embeddings_test.go @@ -2,10 +2,14 @@ package openai_test import ( . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" "bytes" + "context" "encoding/json" + "fmt" + "net/http" "testing" ) @@ -45,3 +49,43 @@ func TestEmbedding(t *testing.T) { } } } + +func TestEmbeddingModel(t *testing.T) { + var em EmbeddingModel + err := em.UnmarshalText([]byte("text-similarity-ada-001")) + checks.NoError(t, err, "Could not marshal embedding model") + + if em != AdaSimilarity { + t.Errorf("Model is not equal to AdaSimilarity") + } + + err = em.UnmarshalText([]byte("some-non-existent-model")) + checks.NoError(t, err, "Could not marshal embedding model") + if em != Unknown { + t.Errorf("Model is not equal to Unknown") + } +} + +func TestEmbeddingEndpoint(t *testing.T) { + server := test.NewTestServer() + server.RegisterHandler( + "/v1/embeddings", + func(w http.ResponseWriter, r *http.Request) { + resBytes, _ := json.Marshal(EmbeddingResponse{}) + fmt.Fprintln(w, string(resBytes)) + }, + ) + // create the test server + var err error + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) + ctx := context.Background() + + _, err = client.CreateEmbeddings(ctx, EmbeddingRequest{}) + checks.NoError(t, err, "CreateEmbeddings error") +}