Skip to content

Commit

Permalink
Added more tests for EF Core
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed May 17, 2024
1 parent 35b9f56 commit 8a642c2
Showing 1 changed file with 28 additions and 0 deletions.
28 changes: 28 additions & 0 deletions tests/Pgvector.CSharp.Tests/EntityFrameworkCoreTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ public async Task Main()
ctx.Items.Add(new Item { Embedding = new Vector(new float[] { 1, 1, 2 }), HalfEmbedding = new HalfVector(new Half[] { (Half)1, (Half)1, (Half)2 }), BinaryEmbedding = new BitArray(new bool[] { true, true, true }), SparseEmbedding = new SparseVector(new float[] { 1, 1, 2 }) });
ctx.SaveChanges();

// vector

var embedding = new Vector(new float[] { 1, 1, 1 });
var items = await ctx.Items.FromSql($"SELECT * FROM efcore_items ORDER BY embedding <-> {embedding} LIMIT 5").ToListAsync();
Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());
Expand All @@ -83,21 +85,47 @@ public async Task Main()
items = await ctx.Items.OrderBy(x => x.Embedding!.L1Distance(embedding)).Take(5).ToListAsync();
Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());

// halfvec

var halfEmbedding = new HalfVector(new Half[] { (Half)1, (Half)1, (Half)1 });
items = await ctx.Items.OrderBy(x => x.HalfEmbedding!.L2Distance(halfEmbedding)).Take(5).ToListAsync();
Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());

items = await ctx.Items.OrderBy(x => x.HalfEmbedding!.MaxInnerProduct(halfEmbedding)).Take(5).ToListAsync();
Assert.Equal(new int[] { 2, 3, 1 }, items.Select(v => v.Id).ToArray());

items = await ctx.Items.OrderBy(x => x.HalfEmbedding!.CosineDistance(halfEmbedding)).Take(5).ToListAsync();
Assert.Equal(3, items[2].Id);

items = await ctx.Items.OrderBy(x => x.HalfEmbedding!.L1Distance(halfEmbedding)).Take(5).ToListAsync();
Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());

// sparsevec

var sparseEmbedding = new SparseVector(new float[] { 1, 1, 1 });
items = await ctx.Items.OrderBy(x => x.SparseEmbedding!.L2Distance(sparseEmbedding)).Take(5).ToListAsync();
Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());

items = await ctx.Items.OrderBy(x => x.SparseEmbedding!.MaxInnerProduct(sparseEmbedding)).Take(5).ToListAsync();
Assert.Equal(new int[] { 2, 3, 1 }, items.Select(v => v.Id).ToArray());

items = await ctx.Items.OrderBy(x => x.SparseEmbedding!.CosineDistance(sparseEmbedding)).Take(5).ToListAsync();
Assert.Equal(3, items[2].Id);

items = await ctx.Items.OrderBy(x => x.SparseEmbedding!.L1Distance(sparseEmbedding)).Take(5).ToListAsync();
Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());

// bit

var binaryEmbedding = new BitArray(new bool[] { true, false, true });
items = await ctx.Items.OrderBy(x => x.BinaryEmbedding!.HammingDistance(binaryEmbedding)).Take(5).ToListAsync();
Assert.Equal(new int[] { 2, 3, 1 }, items.Select(v => v.Id).ToArray());

items = await ctx.Items.OrderBy(x => x.BinaryEmbedding!.JaccardDistance(binaryEmbedding)).Take(5).ToListAsync();
Assert.Equal(new int[] { 2, 3, 1 }, items.Select(v => v.Id).ToArray());

// additional

items = await ctx.Items
.OrderBy(x => x.Id)
.Where(x => x.Embedding!.L2Distance(embedding) < 1.5)
Expand Down

0 comments on commit 8a642c2

Please sign in to comment.