Skip to content

Commit a233b67

Browse files
MB-69655: Fix vector normalization to handle multi-vectors correctly (#2260)
- When indexing multi-vector fields (e.g., `[[3,0,0], [0,4,0]]`) with `cosine` similarity, normalization was incorrectly applied to the entire flattened array instead of to each sub-vector independently, which degraded similarity scores.
- Added `NormalizeMultiVector(vec, dims)`, which normalizes each sub-vector separately and fixes scores for multi-vector documents (e.g., an exact match on one sub-vector now correctly scores 1.0 instead of 0.6).
1 parent 8721d16 commit a233b67

File tree

3 files changed

+285
-4
lines changed

3 files changed

+285
-4
lines changed

mapping/mapping_vectors.go

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ package mapping
2020
import (
2121
"fmt"
2222
"reflect"
23+
"slices"
2324

2425
"github.com/blevesearch/bleve/v2/document"
2526
"github.com/blevesearch/bleve/v2/util"
@@ -151,8 +152,10 @@ func (fm *FieldMapping) processVector(propertyMightBeVector interface{},
151152
vectorIndexOptimizedFor = index.DefaultIndexOptimization
152153
}
153154
// normalize raw vector if similarity is cosine
155+
// Since the vector can be multi-vector (flattened array of multiple vectors),
156+
// we use NormalizeMultiVector to normalize each sub-vector independently.
154157
if similarity == index.CosineSimilarity {
155-
vector = NormalizeVector(vector)
158+
vector = NormalizeMultiVector(vector, fm.Dims)
156159
}
157160

158161
fieldName := getFieldName(pathString, path, fm)
@@ -186,7 +189,8 @@ func (fm *FieldMapping) processVectorBase64(propertyMightBeVectorBase64 interfac
186189
if err != nil || len(decodedVector) != fm.Dims {
187190
return
188191
}
189-
// normalize raw vector if similarity is cosine
192+
// normalize raw vector if similarity is cosine; multi-vector is not supported
193+
// for base64 encoded vectors, so we use NormalizeVector directly.
190194
if similarity == index.CosineSimilarity {
191195
decodedVector = NormalizeVector(decodedVector)
192196
}
@@ -292,11 +296,33 @@ func validateVectorFieldAlias(field *FieldMapping, path []string,
292296
return nil
293297
}
294298

299+
// NormalizeVector normalizes a single vector to unit length.
300+
// It makes a copy of the input vector to avoid modifying it in-place.
295301
func NormalizeVector(vec []float32) []float32 {
296302
// make a copy of the vector to avoid modifying the original
297303
// vector in-place
298-
vecCopy := make([]float32, len(vec))
299-
copy(vecCopy, vec)
304+
vecCopy := slices.Clone(vec)
300305
// normalize the vector copy using in-place normalization provided by faiss
301306
return faiss.NormalizeVector(vecCopy)
302307
}
308+
309+
// NormalizeMultiVector normalizes each sub-vector of size `dims` independently.
310+
// For a flattened array containing multiple vectors, each sub-vector is
311+
// normalized separately to unit length.
312+
// It makes a copy of the input vector to avoid modifying it in-place.
313+
func NormalizeMultiVector(vec []float32, dims int) []float32 {
314+
if len(vec) == 0 || dims <= 0 || len(vec)%dims != 0 {
315+
return vec
316+
}
317+
// Single vector - delegate to NormalizeVector
318+
if len(vec) == dims {
319+
return NormalizeVector(vec)
320+
}
321+
// Multi-vector - make a copy to avoid modifying the original
322+
result := slices.Clone(vec)
323+
// Normalize each sub-vector in-place
324+
for i := 0; i < len(result); i += dims {
325+
faiss.NormalizeVector(result[i : i+dims])
326+
}
327+
return result
328+
}

mapping/mapping_vectors_test.go

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package mapping
1919

2020
import (
21+
"math"
2122
"reflect"
2223
"strings"
2324
"testing"
@@ -1069,3 +1070,120 @@ func TestNormalizeVector(t *testing.T) {
10691070
}
10701071
}
10711072
}
1073+
1074+
func TestNormalizeMultiVectors(t *testing.T) {
1075+
tests := []struct {
1076+
name string
1077+
input []float32
1078+
dims int
1079+
expected []float32
1080+
}{
1081+
{
1082+
name: "single vector - already normalized",
1083+
input: []float32{1, 0, 0},
1084+
dims: 3,
1085+
expected: []float32{1, 0, 0},
1086+
},
1087+
{
1088+
name: "single vector - needs normalization",
1089+
input: []float32{3, 0, 0},
1090+
dims: 3,
1091+
expected: []float32{1, 0, 0},
1092+
},
1093+
{
1094+
name: "two vectors - X and Y directions",
1095+
input: []float32{3, 0, 0, 0, 4, 0},
1096+
dims: 3,
1097+
expected: []float32{1, 0, 0, 0, 1, 0},
1098+
},
1099+
{
1100+
name: "three vectors",
1101+
input: []float32{3, 0, 0, 0, 4, 0, 0, 0, 5},
1102+
dims: 3,
1103+
expected: []float32{1, 0, 0, 0, 1, 0, 0, 0, 1},
1104+
},
1105+
{
1106+
name: "two 2D vectors",
1107+
input: []float32{3, 4, 5, 12},
1108+
dims: 2,
1109+
expected: []float32{0.6, 0.8, 0.38461538, 0.92307693},
1110+
},
1111+
{
1112+
name: "empty vector",
1113+
input: []float32{},
1114+
dims: 3,
1115+
expected: []float32{},
1116+
},
1117+
{
1118+
name: "zero dims",
1119+
input: []float32{1, 2, 3},
1120+
dims: 0,
1121+
expected: []float32{1, 2, 3},
1122+
},
1123+
{
1124+
name: "negative dims",
1125+
input: []float32{1, 2, 3},
1126+
dims: -1,
1127+
expected: []float32{1, 2, 3},
1128+
},
1129+
}
1130+
1131+
for _, tt := range tests {
1132+
t.Run(tt.name, func(t *testing.T) {
1133+
// Make a copy of input to verify original is not modified
1134+
inputCopy := make([]float32, len(tt.input))
1135+
copy(inputCopy, tt.input)
1136+
1137+
result := NormalizeMultiVector(tt.input, tt.dims)
1138+
1139+
// Check result matches expected
1140+
if len(result) != len(tt.expected) {
1141+
t.Errorf("length mismatch: expected %d, got %d", len(tt.expected), len(result))
1142+
return
1143+
}
1144+
1145+
for i := range result {
1146+
if !floatApproxEqual(result[i], tt.expected[i], 1e-5) {
1147+
t.Errorf("value mismatch at index %d: expected %v, got %v",
1148+
i, tt.expected[i], result[i])
1149+
}
1150+
}
1151+
1152+
// Verify original input was not modified
1153+
if !reflect.DeepEqual(tt.input, inputCopy) {
1154+
t.Errorf("original input was modified: was %v, now %v", inputCopy, tt.input)
1155+
}
1156+
1157+
// For valid multi-vectors, verify each sub-vector has unit magnitude
1158+
if tt.dims > 0 && len(tt.input) > 0 && len(tt.input)%tt.dims == 0 {
1159+
numVecs := len(result) / tt.dims
1160+
for i := 0; i < numVecs; i++ {
1161+
subVec := result[i*tt.dims : (i+1)*tt.dims]
1162+
mag := magnitude(subVec)
1163+
// Allow for zero vectors (magnitude 0) or unit vectors (magnitude 1)
1164+
if mag > 1e-6 && !floatApproxEqual(mag, 1.0, 1e-5) {
1165+
t.Errorf("sub-vector %d has magnitude %v, expected 1.0", i, mag)
1166+
}
1167+
}
1168+
}
1169+
})
1170+
}
1171+
}
1172+
1173+
// magnitude returns the Euclidean (L2) length of v, or 0 for an empty slice.
//
// The squares are accumulated in float64: summing them in float32 overflows
// to +Inf for large components and loses precision on long vectors. Only the
// final square root is narrowed back to float32.
func magnitude(v []float32) float32 {
	var sum float64
	for _, x := range v {
		xf := float64(x)
		sum += xf * xf
	}
	return float32(math.Sqrt(sum))
}
1181+
1182+
// floatApproxEqual reports whether a and b differ by strictly less than
// epsilon in absolute value.
func floatApproxEqual(a, b, epsilon float32) bool {
	delta := a - b
	// |delta| < epsilon, written as a two-sided bound.
	return -epsilon < delta && delta < epsilon
}

search_knn_test.go

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1856,6 +1856,143 @@ func TestKNNMerger(t *testing.T) {
18561856
})
18571857
}
18581858

1859+
// TestMultiVectorCosineNormalization verifies that multi-vector fields are
1860+
// normalized correctly with cosine similarity. Each sub-vector in a multi-vector
1861+
// should be independently normalized, producing correct similarity scores.
1862+
func TestMultiVectorCosineNormalization(t *testing.T) {
1863+
tmpIndexPath := createTmpIndexPath(t)
1864+
defer cleanupTmpIndexPath(t, tmpIndexPath)
1865+
1866+
const dims = 3
1867+
1868+
// Create index with cosine similarity
1869+
indexMapping := NewIndexMapping()
1870+
vecFieldMapping := mapping.NewVectorFieldMapping()
1871+
vecFieldMapping.Dims = dims
1872+
vecFieldMapping.Similarity = index.CosineSimilarity
1873+
indexMapping.DefaultMapping.AddFieldMappingsAt("vec", vecFieldMapping)
1874+
1875+
// Multi-vector field
1876+
vecFieldMappingNested := mapping.NewVectorFieldMapping()
1877+
vecFieldMappingNested.Dims = dims
1878+
vecFieldMappingNested.Similarity = index.CosineSimilarity
1879+
indexMapping.DefaultMapping.AddFieldMappingsAt("vec_nested", vecFieldMappingNested)
1880+
1881+
idx, err := New(tmpIndexPath, indexMapping)
1882+
if err != nil {
1883+
t.Fatal(err)
1884+
}
1885+
defer func() {
1886+
err := idx.Close()
1887+
if err != nil {
1888+
t.Fatal(err)
1889+
}
1890+
}()
1891+
1892+
docsString := []string{
1893+
`{"vec": [3, 0, 0]}`,
1894+
`{"vec": [0, 4, 0]}`,
1895+
`{"vec_nested": [[3, 0, 0], [0, 4, 0]]}`,
1896+
}
1897+
1898+
for i, docStr := range docsString {
1899+
var doc map[string]interface{}
1900+
err = json.Unmarshal([]byte(docStr), &doc)
1901+
if err != nil {
1902+
t.Fatal(err)
1903+
}
1904+
err = idx.Index(fmt.Sprintf("doc%d", i+1), doc)
1905+
if err != nil {
1906+
t.Fatal(err)
1907+
}
1908+
}
1909+
1910+
// Query for X direction [1,0,0]
1911+
searchReq := NewSearchRequest(query.NewMatchNoneQuery())
1912+
searchReq.AddKNN("vec", []float32{1, 0, 0}, 3, 1.0)
1913+
res, err := idx.Search(searchReq)
1914+
if err != nil {
1915+
t.Fatal(err)
1916+
}
1917+
if len(res.Hits) != 2 {
1918+
t.Fatalf("expected 2 hits, got %d", len(res.Hits))
1919+
}
1920+
// Hit 1 should be doc1 with score 1.0 (perfect match)
1921+
if res.Hits[0].ID != "doc1" {
1922+
t.Fatalf("expected doc1 as first hit, got %s", res.Hits[0].ID)
1923+
}
1924+
if math.Abs(float64(res.Hits[0].Score-1.0)) > 1e-6 {
1925+
t.Fatalf("expected score 1.0, got %f", res.Hits[0].Score)
1926+
}
1927+
// Hit 2 should be doc2 with a score of 0.0 (orthogonal)
1928+
if res.Hits[1].ID != "doc2" {
1929+
t.Fatalf("expected doc2 as second hit, got %s", res.Hits[1].ID)
1930+
}
1931+
if math.Abs(float64(res.Hits[1].Score-0.0)) > 1e-6 {
1932+
t.Fatalf("expected score 0.0, got %f", res.Hits[1].Score)
1933+
}
1934+
1935+
// Query for Y direction [0,1,0]
1936+
searchReq = NewSearchRequest(query.NewMatchNoneQuery())
1937+
searchReq.AddKNN("vec", []float32{0, 1, 0}, 3, 1.0)
1938+
res, err = idx.Search(searchReq)
1939+
if err != nil {
1940+
t.Fatal(err)
1941+
}
1942+
if len(res.Hits) != 2 {
1943+
t.Fatalf("expected 2 hits, got %d", len(res.Hits))
1944+
}
1945+
// Hit 1 should be doc2 with score 1.0 (perfect match)
1946+
if res.Hits[0].ID != "doc2" {
1947+
t.Fatalf("expected doc2 as first hit, got %s", res.Hits[0].ID)
1948+
}
1949+
if math.Abs(float64(res.Hits[0].Score-1.0)) > 1e-6 {
1950+
t.Fatalf("expected score 1.0, got %f", res.Hits[0].Score)
1951+
}
1952+
// Hit 2 should be doc1 with a score of 0.0 (orthogonal)
1953+
if res.Hits[1].ID != "doc1" {
1954+
t.Fatalf("expected doc1 as second hit, got %s", res.Hits[1].ID)
1955+
}
1956+
if math.Abs(float64(res.Hits[1].Score-0.0)) > 1e-6 {
1957+
t.Fatalf("expected score 0.0, got %f", res.Hits[1].Score)
1958+
}
1959+
1960+
// Now test querying the nested multi-vector field
1961+
searchReq = NewSearchRequest(query.NewMatchNoneQuery())
1962+
searchReq.AddKNN("vec_nested", []float32{1, 0, 0}, 3, 1.0)
1963+
res, err = idx.Search(searchReq)
1964+
if err != nil {
1965+
t.Fatal(err)
1966+
}
1967+
if len(res.Hits) != 1 {
1968+
t.Fatalf("expected 1 hit, got %d", len(res.Hits))
1969+
}
1970+
// Hit should be doc3 with score 1.0 (perfect match on first sub-vector)
1971+
if res.Hits[0].ID != "doc3" {
1972+
t.Fatalf("expected doc3 as first hit, got %s", res.Hits[0].ID)
1973+
}
1974+
if math.Abs(float64(res.Hits[0].Score-1.0)) > 1e-6 {
1975+
t.Fatalf("expected score 1.0, got %f", res.Hits[0].Score)
1976+
}
1977+
// Query for Y direction [0,1,0] on nested field
1978+
searchReq = NewSearchRequest(query.NewMatchNoneQuery())
1979+
searchReq.AddKNN("vec_nested", []float32{0, 1, 0}, 3, 1.0)
1980+
res, err = idx.Search(searchReq)
1981+
if err != nil {
1982+
t.Fatal(err)
1983+
}
1984+
if len(res.Hits) != 1 {
1985+
t.Fatalf("expected 1 hit, got %d", len(res.Hits))
1986+
}
1987+
// Hit should be doc3 with score 1.0 (perfect match on second sub-vector)
1988+
if res.Hits[0].ID != "doc3" {
1989+
t.Fatalf("expected doc3 as first hit, got %s", res.Hits[0].ID)
1990+
}
1991+
if math.Abs(float64(res.Hits[0].Score-1.0)) > 1e-6 {
1992+
t.Fatalf("expected score 1.0, got %f", res.Hits[0].Score)
1993+
}
1994+
}
1995+
18591996
func TestNumVecsStat(t *testing.T) {
18601997

18611998
dataset, _, err := readDatasetAndQueries(testInputCompressedFile)

0 commit comments

Comments
 (0)