Skip to content

fix(stores): Stores fixes and testing #4663

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions backend/go/stores/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,12 +311,16 @@ func (s *Store) StoresGet(opts *pb.StoresGetOptions) (pb.StoresGetResult, error)
}

func isNormalized(k []float32) bool {
var sum float32
var sum float64

for _, v := range k {
sum += v
v64 := float64(v)
sum += v64*v64
}

return sum == 1.0
s := math.Sqrt(sum)

return s >= 0.99 && s <= 1.01
}

// TODO: This we could replace with handwritten SIMD code
Expand All @@ -328,7 +332,7 @@ func normalizedCosineSimilarity(k1, k2 []float32) float32 {
dot += k1[i] * k2[i]
}

assert(dot >= -1 && dot <= 1, fmt.Sprintf("dot = %f", dot))
assert(dot >= -1.01 && dot <= 1.01, fmt.Sprintf("dot = %f", dot))

// 2.0 * (1.0 - dot) would be the Euclidean distance
return dot
Expand Down Expand Up @@ -418,7 +422,7 @@ func cosineSimilarity(k1, k2 []float32, mag1 float64) float32 {

sim := float32(dot / (mag1 * math.Sqrt(mag2)))

assert(sim >= -1 && sim <= 1, fmt.Sprintf("sim = %f", sim))
assert(sim >= -1.01 && sim <= 1.01, fmt.Sprintf("sim = %f", sim))

return sim
}
Expand Down
143 changes: 132 additions & 11 deletions tests/integration/stores_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"embed"
"math"
"math/rand"
"os"
"path/filepath"

Expand All @@ -22,6 +23,19 @@ import (
//go:embed backend-assets/*
var backendAssets embed.FS

func normalize(vecs [][]float32) {
for i, k := range vecs {
norm := float64(0)
for _, x := range k {
norm += float64(x * x)
}
norm = math.Sqrt(norm)
for j, x := range k {
vecs[i][j] = x / float32(norm)
}
}
}

var _ = Describe("Integration tests for the stores backend(s) and internal APIs", Label("stores"), func() {
Context("Embedded Store get,set and delete", func() {
var sl *model.ModelLoader
Expand Down Expand Up @@ -192,17 +206,8 @@ var _ = Describe("Integration tests for the stores backend(s) and internal APIs"
// set 3 vectors that are at varying angles to {0.5, 0.5, 0.5}
keys := [][]float32{{0.1, 0.3, 0.5}, {0.5, 0.5, 0.5}, {0.6, 0.6, -0.6}, {0.7, -0.7, -0.7}}
vals := [][]byte{[]byte("test0"), []byte("test1"), []byte("test2"), []byte("test3")}
// normalize the keys
for i, k := range keys {
norm := float64(0)
for _, x := range k {
norm += float64(x * x)
}
norm = math.Sqrt(norm)
for j, x := range k {
keys[i][j] = x / float32(norm)
}
}

normalize(keys)

err := store.SetCols(context.Background(), sc, keys, vals)
Expect(err).ToNot(HaveOccurred())
Expand All @@ -225,5 +230,121 @@ var _ = Describe("Integration tests for the stores backend(s) and internal APIs"
Expect(ks[1]).To(Equal(keys[1]))
Expect(vals[1]).To(Equal(vals[1]))
})

It("It produces the correct cosine similarities for orthogonal and opposite unit vectors", func() {
keys := [][]float32{{1.0, 0.0, 0.0}, {0.0, 1.0, 0.0}, {0.0, 0.0, 1.0}, {-1.0, 0.0, 0.0}}
vals := [][]byte{[]byte("x"), []byte("y"), []byte("z"), []byte("-z")}

err := store.SetCols(context.Background(), sc, keys, vals);
Expect(err).ToNot(HaveOccurred())

_, _, sims, err := store.Find(context.Background(), sc, keys[0], 4)
Expect(err).ToNot(HaveOccurred())
Expect(sims).To(Equal([]float32{1.0, 0.0, 0.0, -1.0}))
})

It("It produces the correct cosine similarities for orthogonal and opposite vectors", func() {
keys := [][]float32{{1.0, 0.0, 1.0}, {0.0, 2.0, 0.0}, {0.0, 0.0, -1.0}, {-1.0, 0.0, -1.0}}
vals := [][]byte{[]byte("x"), []byte("y"), []byte("z"), []byte("-z")}

err := store.SetCols(context.Background(), sc, keys, vals);
Expect(err).ToNot(HaveOccurred())

_, _, sims, err := store.Find(context.Background(), sc, keys[0], 4)
Expect(err).ToNot(HaveOccurred())
Expect(sims[0]).To(BeNumerically("~", 1, 0.1))
Expect(sims[1]).To(BeNumerically("~", 0, 0.1))
Expect(sims[2]).To(BeNumerically("~", -0.7, 0.1))
Expect(sims[3]).To(BeNumerically("~", -1, 0.1))
})

expectTriangleEq := func(keys [][]float32, vals [][]byte) {
sims := map[string]map[string]float32{}

// compare every key vector pair and store the similarities in a lookup table
// that uses the values as keys
for i, k := range keys {
_, valsk, simsk, err := store.Find(context.Background(), sc, k, 9)
Expect(err).ToNot(HaveOccurred())

for j, v := range valsk {
p := string(vals[i])
q := string(v)

if sims[p] == nil {
sims[p] = map[string]float32{}
}

//log.Debug().Strs("vals", []string{p, q}).Float32("similarity", simsk[j]).Send()

sims[p][q] = simsk[j]
}
}

// Check that the triangle inequality holds for every combination of the triplet
// u, v and w
for _, simsu := range sims {
for w, simw := range simsu {
// acos(u,w) <= ...
uws := math.Acos(float64(simw))

// ... acos(u,v) + acos(v,w)
for v, _ := range simsu {
uvws := math.Acos(float64(simsu[v])) + math.Acos(float64(sims[v][w]))

//log.Debug().Str("u", u).Str("v", v).Str("w", w).Send()
//log.Debug().Float32("uw", simw).Float32("uv", simsu[v]).Float32("vw", sims[v][w]).Send()
Expect(uws).To(BeNumerically("<=", uvws))
}
}
}
}

It("It obeys the triangle inequality for normalized values", func() {
keys := [][]float32{
{1.0, 0.0, 0.0}, {0.0, 1.0, 0.0}, {0.0, 0.0, 1.0},
{-1.0, 0.0, 0.0}, {0.0, -1.0, 0.0}, {0.0, 0.0, -1.0},
{2.0, 3.0, 4.0}, {9.0, 7.0, 1.0}, {0.0, -1.2, 2.3},
}
vals := [][]byte{
[]byte("x"), []byte("y"), []byte("z"),
[]byte("-x"), []byte("-y"), []byte("-z"),
[]byte("u"), []byte("v"), []byte("w"),
}

normalize(keys[6:])

err := store.SetCols(context.Background(), sc, keys, vals);
Expect(err).ToNot(HaveOccurred())

expectTriangleEq(keys, vals)
})

It("It obeys the triangle inequality", func() {
rnd := rand.New(rand.NewSource(151))
keys := make([][]float32, 20)
vals := make([][]byte, 20)

for i := range keys {
k := make([]float32, 768)

for j := range k {
k[j] = rnd.Float32()
}

keys[i] = k
}

c := byte('a')
for i := range vals {
vals[i] = []byte{c}
c += 1
}

err := store.SetCols(context.Background(), sc, keys, vals);
Expect(err).ToNot(HaveOccurred())

expectTriangleEq(keys, vals)
})
})
})
Loading