Skip to content

Commit

Permalink
Added support for sparsevec type to EF Core
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed May 17, 2024
1 parent 46236c3 commit 29ce6df
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/Pgvector.EntityFrameworkCore/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
## 0.2.1 (unreleased)

- Added support for `halfvec` type
- Added support for `halfvec` and `sparsevec` types
- Added support for compiled models
- Added `L1Distance` function

Expand Down
19 changes: 19 additions & 0 deletions src/Pgvector.EntityFrameworkCore/SparsevecTypeMapping.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using Microsoft.EntityFrameworkCore.Storage;
using Npgsql.EntityFrameworkCore.PostgreSQL.Storage.Internal.Mapping;
using NpgsqlTypes;

namespace Pgvector.EntityFrameworkCore;

public class SparsevecTypeMapping : RelationalTypeMapping
{
public static SparsevecTypeMapping Default { get; } = new();

public SparsevecTypeMapping() : base("sparsevec", typeof(SparseVector)) { }

public SparsevecTypeMapping(string storeType) : base(storeType, typeof(SparseVector)) { }

protected SparsevecTypeMapping(RelationalTypeMappingParameters parameters) : base(parameters) { }

protected override RelationalTypeMapping Clone(RelationalTypeMappingParameters parameters)
=> new SparsevecTypeMapping(parameters);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
using Microsoft.EntityFrameworkCore.Storage;

namespace Pgvector.EntityFrameworkCore;

public class SparsevecTypeMappingSourcePlugin : IRelationalTypeMappingSourcePlugin
{
public RelationalTypeMapping? FindMapping(in RelationalTypeMappingInfo mappingInfo)
=> mappingInfo.ClrType == typeof(SparseVector)
? new SparsevecTypeMapping(mappingInfo.StoreTypeName ?? "sparsevec")
: null;
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ public void ApplyServices(IServiceCollection services)

services.AddSingleton<IRelationalTypeMappingSourcePlugin, VectorTypeMappingSourcePlugin>();
services.AddSingleton<IRelationalTypeMappingSourcePlugin, HalfvecTypeMappingSourcePlugin>();
services.AddSingleton<IRelationalTypeMappingSourcePlugin, SparsevecTypeMappingSourcePlugin>();
}

public void Validate(IDbContextOptions options) { }
Expand Down
13 changes: 10 additions & 3 deletions tests/Pgvector.CSharp.Tests/EntityFrameworkCoreTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ public class Item

[Column("half_embedding", TypeName = "halfvec(3)")]
public HalfVector? HalfEmbedding { get; set; }

[Column("sparse_embedding", TypeName = "sparsevec(3)")]
public SparseVector? SparseEmbedding { get; set; }
}

public class EntityFrameworkCoreTests
Expand All @@ -52,9 +55,9 @@ public async Task Main()
var databaseCreator = ctx.GetService<IRelationalDatabaseCreator>();
databaseCreator.CreateTables();

ctx.Items.Add(new Item { Embedding = new Vector(new float[] { 1, 1, 1 }), HalfEmbedding = new HalfVector(new Half[] { (Half)1, (Half)1, (Half)1 }) });
ctx.Items.Add(new Item { Embedding = new Vector(new float[] { 2, 2, 2 }), HalfEmbedding = new HalfVector(new Half[] { (Half)2, (Half)2, (Half)2 }) });
ctx.Items.Add(new Item { Embedding = new Vector(new float[] { 1, 1, 2 }), HalfEmbedding = new HalfVector(new Half[] { (Half)1, (Half)1, (Half)2 }) });
ctx.Items.Add(new Item { Embedding = new Vector(new float[] { 1, 1, 1 }), HalfEmbedding = new HalfVector(new Half[] { (Half)1, (Half)1, (Half)1 }), SparseEmbedding = new SparseVector(new float[] { 1, 1, 1 }) });
ctx.Items.Add(new Item { Embedding = new Vector(new float[] { 2, 2, 2 }), HalfEmbedding = new HalfVector(new Half[] { (Half)2, (Half)2, (Half)2 }), SparseEmbedding = new SparseVector(new float[] { 2, 2, 2 }) });
ctx.Items.Add(new Item { Embedding = new Vector(new float[] { 1, 1, 2 }), HalfEmbedding = new HalfVector(new Half[] { (Half)1, (Half)1, (Half)2 }), SparseEmbedding = new SparseVector(new float[] { 1, 1, 2 }) });
ctx.SaveChanges();

var embedding = new Vector(new float[] { 1, 1, 1 });
Expand All @@ -80,6 +83,10 @@ public async Task Main()
items = await ctx.Items.OrderBy(x => x.HalfEmbedding!.L2Distance(halfEmbedding)).Take(5).ToListAsync();
Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());

var sparseEmbedding = new SparseVector(new float[] { 1, 1, 1 });
items = await ctx.Items.OrderBy(x => x.SparseEmbedding!.L2Distance(sparseEmbedding)).Take(5).ToListAsync();
Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());

items = await ctx.Items
.OrderBy(x => x.Id)
.Where(x => x.Embedding!.L2Distance(embedding) < 1.5)
Expand Down

0 comments on commit 29ce6df

Please sign in to comment.