From cf66e8399df4f25e4ef3a8021525912b6fdb1240 Mon Sep 17 00:00:00 2001 From: Rasmus Holm Date: Tue, 25 Feb 2025 10:36:11 +0100 Subject: [PATCH 1/2] adding mode/type for embeddings --- README.md | 25 +++++++++++++++++++++++++ models/embed/models.go | 18 ++++++++++++++++-- services/vertexai/google.go | 29 ++++++++++++++++++++++++----- services/voyageai/voyageai.go | 16 +++++++++++++++- 4 files changed, 80 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index b1dc7a2..79f7e5e 100644 --- a/README.md +++ b/README.md @@ -470,6 +470,31 @@ for _, p := range res.Promps { ``` +## Embeddings + +Bellman integrates with most the embeddig models aswell as the llms that is provided by the supported +providers. There is also a VoyageAI, voyageai.com, that only really deals with embeddings + +```go +client := bellman_client := bellman.New(...) +res, err := client.Embed(embed.Request{ + Model: vertexai.EmbedModel_text_005.WithMode(embed.ModeDocument), + Text: "The document to embed", + }) + +fmt.Println(res.AsFloat32()) +// [-0.06821047514677048 -0.00014664272021036595 0.011814368888735771 .... +``` + +### Mode / Type +Some embeddings models support specific types of input. +Eg. +VertexAI https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/task-types#retrieve_information_from_texts +and VoyageAI https://docs.voyageai.com/docs/embeddings + +This allows you to define what type of text you are sending. +For example `embed.ModeDocument` for initial embedding and `embed.ModeQuery` +for getting a vector that is to be compared ## License diff --git a/models/embed/models.go b/models/embed/models.go index 497a414..5f5ca01 100644 --- a/models/embed/models.go +++ b/models/embed/models.go @@ -10,15 +10,29 @@ type Embeder interface { Embed(embed Request) (*Response, error) } +type Mode string + +const ModeQuery = "query" +const ModeDocument = "document" +const ModeNone = "" + type Model struct { - Provider string `json:"provider"` - Name string `json:"name"` + Provider string `json:"provider"` + Name string `json:"name"` + + Mode Mode `json:"mode,omitempty"` + Description string `json:"description,omitempty"` InputMaxTokens int `json:"input_max_tokens,omitempty"` OutputDimensions int `json:"output_dimensions,omitempty"` } +func (m Model) WithMode(mode Mode) Model { + m.Mode = mode + return m +} + func (m Model) FQN() string { return m.String() } diff --git a/services/vertexai/google.go b/services/vertexai/google.go index f89dda0..59bf554 100644 --- a/services/vertexai/google.go +++ b/services/vertexai/google.go @@ -17,9 +17,18 @@ import ( "golang.org/x/oauth2/google" ) +const ModeDocument embed.Mode = "RETRIEVAL_DOCUMENT" +const ModeQuery embed.Mode = "RETRIEVAL_QUERY" +const ModeQuestionAnswer embed.Mode = "QUESTION_ANSWERING" +const ModeFactVerification embed.Mode = "FACT_VERIFICATION" +const ModeCodeRetrieval embed.Mode = "CODE_RETRIEVAL_QUERY" +const ModeClustering embed.Mode = "CLUSTERING" +const ModeClassification embed.Mode = "CLASSIFICATION" +const ModeSemanticSimilarity embed.Mode = "SEMANTIC_SIMILARITY" + type GoogleEmbedRequest struct { Instances []struct { - //TaskType string `json:"task_type"` + TaskType string `json:"task_type,omitempty"` //Title string `json:"title"` Content string `json:"content"` } `json:"instances"` @@ -88,14 +97,24 @@ func (g *Google) Provider() string { func (g *Google) Embed(request embed.Request) (*embed.Response, error) { var reqc = atomic.AddInt64(&requestNo, 1) + mode := "" + switch request.Model.Mode { + case embed.ModeDocument: + mode = "RETRIEVAL_DOCUMENT" + case embed.ModeQuery: + mode = "RETRIEVAL_QUERY" + default: + mode = string(request.Model.Mode) + } + req := GoogleEmbedRequest{ Instances: []struct { - //TaskType string `json:"task_type"` - Content string `json:"content"` + TaskType string `json:"task_type,omitempty"` + Content string `json:"content"` }{ { - //TaskType: model.Name, - Content: request.Text, + TaskType: mode, + Content: request.Text, }, }, } diff --git a/services/voyageai/voyageai.go b/services/voyageai/voyageai.go index ef9d91d..4242909 100644 --- a/services/voyageai/voyageai.go +++ b/services/voyageai/voyageai.go @@ -10,6 +10,7 @@ import ( "io" "log/slog" "net/http" + "strings" "sync/atomic" ) @@ -58,9 +59,22 @@ func (v *VoyageAI) Embed(request embed.Request) (*embed.Response, error) { u := `https://api.voyageai.com/v1/embeddings` + text := request.Text + + switch request.Model.Mode { + case embed.ModeQuery: + if !strings.HasPrefix(text, "Represent the query for retrieving supporting documents:") { + text = "Represent the query for retrieving supporting documents: " + text + } + case embed.ModeDocument: + if !strings.HasPrefix(text, "Represent the document for retrieval:") { + text = "Represent the document for retrieval: " + text + } + } + reqModel := localRequest{ Input: []string{ - request.Text, + text, }, Model: request.Model.Name, } From 12cbb65b5342e358fe2c05db25f0c2f474f4a002 Mon Sep 17 00:00:00 2001 From: Rasmus Holm Date: Tue, 25 Feb 2025 11:00:01 +0100 Subject: [PATCH 2/2] refactoring some naming --- models/embed/models.go | 14 +++++++------- services/vertexai/google.go | 25 ++++++++----------------- services/vertexai/models.go | 18 ++++++++---------- services/voyageai/models.go | 3 +++ services/voyageai/voyageai.go | 14 +++++++------- 5 files changed, 33 insertions(+), 41 deletions(-) diff --git a/models/embed/models.go b/models/embed/models.go index 5f5ca01..9fd45ca 100644 --- a/models/embed/models.go +++ b/models/embed/models.go @@ -10,17 +10,17 @@ type Embeder interface { Embed(embed Request) (*Response, error) } -type Mode string +type Type string -const ModeQuery = "query" -const ModeDocument = "document" -const ModeNone = "" +const TypeQuery = "query" +const TypeDocument = "document" +const TypeNone = "" type Model struct { Provider string `json:"provider"` Name string `json:"name"` - Mode Mode `json:"mode,omitempty"` + Type Type `json:"type,omitempty"` Description string `json:"description,omitempty"` @@ -28,8 +28,8 @@ type Model struct { OutputDimensions int `json:"output_dimensions,omitempty"` } -func (m Model) WithMode(mode Mode) Model { - m.Mode = mode +func (m Model) WithType(mode Type) Model { + m.Type = mode return m } diff --git a/services/vertexai/google.go b/services/vertexai/google.go index 59bf554..30f20de 100644 --- a/services/vertexai/google.go +++ b/services/vertexai/google.go @@ -17,15 +17,6 @@ import ( "golang.org/x/oauth2/google" ) -const ModeDocument embed.Mode = "RETRIEVAL_DOCUMENT" -const ModeQuery embed.Mode = "RETRIEVAL_QUERY" -const ModeQuestionAnswer embed.Mode = "QUESTION_ANSWERING" -const ModeFactVerification embed.Mode = "FACT_VERIFICATION" -const ModeCodeRetrieval embed.Mode = "CODE_RETRIEVAL_QUERY" -const ModeClustering embed.Mode = "CLUSTERING" -const ModeClassification embed.Mode = "CLASSIFICATION" -const ModeSemanticSimilarity embed.Mode = "SEMANTIC_SIMILARITY" - type GoogleEmbedRequest struct { Instances []struct { TaskType string `json:"task_type,omitempty"` @@ -97,14 +88,14 @@ func (g *Google) Provider() string { func (g *Google) Embed(request embed.Request) (*embed.Response, error) { var reqc = atomic.AddInt64(&requestNo, 1) - mode := "" - switch request.Model.Mode { - case embed.ModeDocument: - mode = "RETRIEVAL_DOCUMENT" - case embed.ModeQuery: - mode = "RETRIEVAL_QUERY" + tasktype := "" + switch request.Model.Type { + case embed.TypeDocument: + tasktype = string(TypeDocument) + case embed.TypeQuery: + tasktype = string(TypeQuery) default: - mode = string(request.Model.Mode) + tasktype = string(request.Model.Type) } req := GoogleEmbedRequest{ @@ -113,7 +104,7 @@ func (g *Google) Embed(request embed.Request) (*embed.Response, error) { Content string `json:"content"` }{ { - TaskType: mode, + TaskType: tasktype, Content: request.Text, }, }, diff --git a/services/vertexai/models.go b/services/vertexai/models.go index e51d9d2..036c179 100644 --- a/services/vertexai/models.go +++ b/services/vertexai/models.go @@ -166,16 +166,14 @@ var EmbedModel_text_gecko_multilang_001 = embed.Model{ const EmbedDimensions = 768 -type EmbedType string - -const EmbedTypeQuery EmbedType = "RETRIEVAL_QUERY" -const EmbedTypeDocument EmbedType = "RETRIEVAL_DOCUMENT" -const EmbedTypeSimilarity EmbedType = "SEMANTIC_SIMILARITY" -const EmbedTypeClassification EmbedType = "CLASSIFICATION" -const EmbedTypeClustring EmbedType = "CLUSTERING" -const EmbedTypeQA EmbedType = "QUESTION_ANSWERING" -const EmbedTypeVerification EmbedType = "FACT_VERIFICATION" -const EmbedTypeCode EmbedType = "CODE_RETRIEVAL_QUERY" +const TypeDocument embed.Type = "RETRIEVAL_DOCUMENT" +const TypeQuery embed.Type = "RETRIEVAL_QUERY" +const TypeQuestionAnswer embed.Type = "QUESTION_ANSWERING" +const TypeFactVerification embed.Type = "FACT_VERIFICATION" +const TypeCodeRetrieval embed.Type = "CODE_RETRIEVAL_QUERY" +const TypeClustering embed.Type = "CLUSTERING" +const TypeClassification embed.Type = "CLASSIFICATION" +const TypeSemanticSimilarity embed.Type = "SEMANTIC_SIMILARITY" var EmbedModels = map[string]embed.Model{ EmbedModel_text_005.Name: EmbedModel_text_005, diff --git a/services/voyageai/models.go b/services/voyageai/models.go index 7319b86..5772785 100644 --- a/services/voyageai/models.go +++ b/services/voyageai/models.go @@ -4,6 +4,9 @@ import ( "github.com/modfin/bellman/models/embed" ) +const TypeQuery embed.Type = "Represent the query for retrieving supporting documents" +const TypeDocument embed.Type = "Represent the document for retrieval" + const Provider = "VoyageAI" // https://docs.voyageai.com/docs/embeddings diff --git a/services/voyageai/voyageai.go b/services/voyageai/voyageai.go index 4242909..38fb35d 100644 --- a/services/voyageai/voyageai.go +++ b/services/voyageai/voyageai.go @@ -61,14 +61,14 @@ func (v *VoyageAI) Embed(request embed.Request) (*embed.Response, error) { text := request.Text - switch request.Model.Mode { - case embed.ModeQuery: - if !strings.HasPrefix(text, "Represent the query for retrieving supporting documents:") { - text = "Represent the query for retrieving supporting documents: " + text + switch request.Model.Type { + case embed.TypeQuery, TypeQuery: + if !strings.HasPrefix(text, string(TypeQuery)) { + text = fmt.Sprintf("%s: %s", TypeQuery, text) } - case embed.ModeDocument: - if !strings.HasPrefix(text, "Represent the document for retrieval:") { - text = "Represent the document for retrieval: " + text + case embed.TypeDocument, TypeDocument: + if !strings.HasPrefix(text, string(TypeDocument)) { + text = fmt.Sprintf("%s: %s", TypeDocument, text) } }