Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions src/EFCore.SqlServer/Extensions/SqlServerQueryableExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,12 @@ public static class SqlServerQueryableExtensions
/// An ANN (Approximate Nearest Neighbor) index is used only if a matching ANN index, with the same metric and on the same column,
/// is found. If there are no compatible ANN indexes, a warning is raised and the KNN (k-Nearest Neighbor) algorithm is used.
/// </param>
/// <param name="topN">The maximum number of similar vectors that must be returned. It must be a positive integer.</param>
/// <remarks>
/// <para>
/// Compose the returned query with <c>OrderBy(r => r.Distance)</c> and <c>Take(...)</c> to limit the results as required
/// for approximate vector search.
/// </para>
/// </remarks>
/// <seealso href="https://learn.microsoft.com/sql/t-sql/functions/vector-search-transact-sql">
/// SQL Server documentation for <c>VECTOR_SEARCH()</c>.
/// </seealso>
Expand All @@ -38,8 +43,7 @@ public static IQueryable<VectorSearchResult<T>> VectorSearch<T, TVector>(
this DbSet<T> source,
Expression<Func<T, TVector>> vectorPropertySelector,
TVector similarTo,
[NotParameterized] string metric,
int topN)
[NotParameterized] string metric)
where T : class
where TVector : unmanaged
{
Expand All @@ -50,12 +54,11 @@ public static IQueryable<VectorSearchResult<T>> VectorSearch<T, TVector>(
? queryableSource.Provider.CreateQuery<VectorSearchResult<T>>(
Expression.Call(
// Note that the method used is the one below, accepting IQueryable<T>, not DbSet<T>
method: new Func<IQueryable<T>, Expression<Func<T, TVector>>, TVector, string, int, IQueryable<VectorSearchResult<T>>>(VectorSearch).Method,
method: new Func<IQueryable<T>, Expression<Func<T, TVector>>, TVector, string, IQueryable<VectorSearchResult<T>>>(VectorSearch).Method,
root,
Expression.Quote(vectorPropertySelector),
Expression.Constant(similarTo),
Expression.Constant(metric),
Expression.Constant(topN)))
Expression.Constant(metric)))
: throw new InvalidOperationException(CoreStrings.FunctionOnNonEfLinqProvider(nameof(VectorSearch)));
}

Expand All @@ -67,8 +70,7 @@ private static IQueryable<VectorSearchResult<T>> VectorSearch<T, TVector>(
this IQueryable<T> source,
Expression<Func<T, TVector>> vectorPropertySelector,
TVector similarTo,
[NotParameterized] string metric,
int topN)
[NotParameterized] string metric)
where T : class
where TVector : unmanaged
=> throw new UnreachableException();
Expand Down
2 changes: 1 addition & 1 deletion src/EFCore.SqlServer/Extensions/VectorSearchResult.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ namespace Microsoft.EntityFrameworkCore;

/// <summary>
/// Represents the results from a call to
/// <see cref="SqlServerQueryableExtensions.VectorSearch{T, TVector}(DbSet{T}, Expression{Func{T, TVector}}, TVector, string, int)" />.
/// <see cref="SqlServerQueryableExtensions.VectorSearch{T, TVector}(DbSet{T}, Expression{Func{T, TVector}}, TVector, string)" />.
/// </summary>
[Experimental(EFDiagnostics.SqlServerVectorSearch)]
public readonly struct VectorSearchResult<T>(T value, double distance)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,16 +155,14 @@ protected override Expression VisitTableValuedFunction(TableValuedFunctionExpres
TableExpression table,
ColumnExpression column,
SqlExpression similarTo,
SqlConstantExpression { Value: string } metric,
SqlExpression topN
SqlConstantExpression { Value: string } metric
]
}:
// VECTOR_SEARCH(
// TABLE = [Articles] AS t,
// COLUMN = [Vector],
// SIMILAR_TO = @qv,
// METRIC = 'Cosine',
// TOP_N = 3
// METRIC = 'Cosine'
// )
Sql.AppendLine("VECTOR_SEARCH(");

Expand All @@ -185,10 +183,6 @@ SqlExpression topN

Sql.Append("METRIC = ");
Visit(metric);
Sql.AppendLine(",");

Sql.Append("TOP_N = ");
Visit(topN);
Sql.AppendLine();
}

Expand Down Expand Up @@ -558,6 +552,13 @@ protected override void GenerateTop(SelectExpression selectExpression)
Visit(selectExpression.Limit);

Sql.Append(") ");

// When performing approximate vector search with VECTOR_SEARCH(), SQL Server requires adding
// WITH APPROXIMATE: https://learn.microsoft.com/sql/t-sql/functions/vector-search-transact-sql
if (selectExpression.Tables.Any(t => t.UnwrapJoin() is TableValuedFunctionExpression { Name: "VECTOR_SEARCH" }))
{
Sql.Append("WITH APPROXIMATE ");
}
}

_withinTable = parentWithinTable;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,7 @@ when methodCallExpression.Arguments is
_, // source, translated above
UnaryExpression { NodeType: ExpressionType.Quote, Operand: LambdaExpression vectorPropertySelector },
var similarTo,
var metric,
var topN
var metric
]
&& source is
{
Expand All @@ -113,8 +112,7 @@ var topN
}

if (TranslateExpression(similarTo) is not { } translatedSimilarTo
|| TranslateExpression(metric, applyDefaultTypeMapping: false) is not { } translatedMetric
|| TranslateExpression(topN) is not { } translatedTopN)
|| TranslateExpression(metric, applyDefaultTypeMapping: false) is not { } translatedMetric)
{
return QueryCompilationContext.NotTranslatedExpression;
}
Expand All @@ -135,8 +133,7 @@ var topN
// as required by SQL Server)
vectorColumn,
translatedSimilarTo,
translatedMetric,
translatedTopN
translatedMetric
]);

// We have the VECTOR_SEARCH() function call. Modify the SelectExpression and shaper to use it and project
Expand Down
1 change: 1 addition & 0 deletions test/Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
<PackageVersion Include="Azure.ResourceManager.CosmosDB" Version="$(AzureResourceManagerCosmosDBVersion)" />
<PackageVersion Include="Microsoft.AspNetCore.Identity.EntityFrameworkCore" Version="9.0.5" />
<PackageVersion Include="Microsoft.AspNetCore.OData" Version="9.4.1" />
<PackageVersion Include="Microsoft.Data.SqlClient.Extensions.Azure" Version="1.0.0" />
<PackageVersion Include="Microsoft.Extensions.Configuration" Version="$(MicrosoftExtensionsConfigurationVersion)" />
<PackageVersion Include="Microsoft.Extensions.Configuration.EnvironmentVariables" Version="$(MicrosoftExtensionsConfigurationEnvironmentVariablesVersion)" />
<PackageVersion Include="Microsoft.Extensions.Configuration.Json" Version="$(MicrosoftExtensionsConfigurationJsonVersion)" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,5 +77,6 @@
<ItemGroup>
<PackageReference Include="Microsoft.Extensions.Configuration.EnvironmentVariables" />
<PackageReference Include="Microsoft.Extensions.Configuration.Json" />
<PackageReference Include="Microsoft.Data.SqlClient.Extensions.Azure" />
</ItemGroup>
</Project>
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ ORDER BY VECTOR_DISTANCE('cosine', [v].[Vector], CAST('[1,2,100]' AS VECTOR(3)))
""");
}

// The latest vector index version (required for VECTOR_SEARCH) is only available on Azure SQL (#36384).
[ConditionalFact]
[SqlServerCondition(SqlServerCondition.IsAzureSql)]
[Experimental("EF9105")]
public async Task VectorSearch_project_entity_and_distance()
{
Expand All @@ -74,28 +76,32 @@ public async Task VectorSearch_project_entity_and_distance()
var vector = new SqlVector<float>(new float[] { 1, 2, 100 });

var results = await ctx.VectorEntities
.VectorSearch(e => e.Vector, similarTo: vector, "cosine", topN: 1)
.VectorSearch(e => e.Vector, similarTo: vector, "cosine")
.OrderBy(e => e.Distance)
.Take(1)
.ToListAsync();

Assert.Equal(2, results.Single().Value.Id);

AssertSql(
"""
@p='Microsoft.Data.SqlTypes.SqlVector`1[System.Single]' (Size = 20) (DbType = Binary)
@p1='1'
@p='Microsoft.Data.SqlTypes.SqlVector`1[System.Single]' (Size = 20) (DbType = Binary)

SELECT [v].[Id], [v0].[Distance]
SELECT TOP(@p1) WITH APPROXIMATE [v].[Id], [v0].[Distance]
FROM VECTOR_SEARCH(
TABLE = [VectorEntities] AS [v],
COLUMN = [Vector],
SIMILAR_TO = @p,
METRIC = 'cosine',
TOP_N = @p1
METRIC = 'cosine'
) AS [v0]
ORDER BY [v0].[Distance]
""");
}

// The latest vector index version (required for VECTOR_SEARCH) is only available on Azure SQL (#36384).
[ConditionalFact]
[SqlServerCondition(SqlServerCondition.IsAzureSql)]
[Experimental("EF9105")]
public async Task VectorSearch_project_entity_only_with_distance_filter_and_ordering()
{
Expand All @@ -104,10 +110,11 @@ public async Task VectorSearch_project_entity_only_with_distance_filter_and_orde
var vector = new SqlVector<float>(new float[] { 1, 2, 100 });

var results = await ctx.VectorEntities
.VectorSearch(e => e.Vector, similarTo: vector, "cosine", topN: 3)
.VectorSearch(e => e.Vector, similarTo: vector, "cosine")
.Where(e => e.Distance < 0.01)
.OrderBy(e => e.Distance)
.Select(e => e.Value)
.Take(3)
.ToListAsync();

Assert.Collection(
Expand All @@ -117,22 +124,65 @@ public async Task VectorSearch_project_entity_only_with_distance_filter_and_orde

AssertSql(
"""
@p='Microsoft.Data.SqlTypes.SqlVector`1[System.Single]' (Size = 20) (DbType = Binary)
@p1='3'
@p='Microsoft.Data.SqlTypes.SqlVector`1[System.Single]' (Size = 20) (DbType = Binary)

SELECT [v].[Id]
SELECT TOP(@p1) WITH APPROXIMATE [v].[Id]
FROM VECTOR_SEARCH(
TABLE = [VectorEntities] AS [v],
COLUMN = [Vector],
SIMILAR_TO = @p,
METRIC = 'cosine',
TOP_N = @p1
METRIC = 'cosine'
) AS [v0]
WHERE [v0].[Distance] < 0.01E0
ORDER BY [v0].[Distance]
""");
}

// The latest vector index version (required for VECTOR_SEARCH) is only available on Azure SQL (#36384).
[ConditionalFact]
[SqlServerCondition(SqlServerCondition.IsAzureSql)]
[Experimental("EF9105")]
public async Task VectorSearch_in_subquery()
{
using var ctx = CreateContext();

var vector = new SqlVector<float>(new float[] { 1, 2, 100 });

var results = await ctx.VectorEntities
.VectorSearch(e => e.Vector, similarTo: vector, "cosine")
.OrderBy(e => e.Distance)
.Take(3)
.Select(e => new { e.Value.Id, e.Distance })
.Where(e => e.Distance < 0.01)
.ToListAsync();

Assert.Collection(
results,
r => Assert.Equal(2, r.Id),
r => Assert.Equal(3, r.Id));

AssertSql(
"""
@p1='3'
@p='Microsoft.Data.SqlTypes.SqlVector`1[System.Single]' (Size = 20) (DbType = Binary)

SELECT [v1].[Id], [v1].[Distance]
FROM (
SELECT TOP(@p1) WITH APPROXIMATE [v].[Id], [v0].[Distance]
FROM VECTOR_SEARCH(
TABLE = [VectorEntities] AS [v],
COLUMN = [Vector],
SIMILAR_TO = @p,
METRIC = 'cosine'
) AS [v0]
ORDER BY [v0].[Distance]
) AS [v1]
WHERE [v1].[Distance] < 0.01E0
ORDER BY [v1].[Distance]
""");
}

[ConditionalFact]
public async Task Length()
{
Expand Down Expand Up @@ -167,23 +217,27 @@ public class VectorQueryContext(DbContextOptions options) : PoolableDbContext(op

public static async Task SeedAsync(VectorQueryContext context)
{
var vectorEntities = new VectorEntity[]
{
new() { Id = 1, Vector = new SqlVector<float>(new float[] { 1, 2, 3 }) },
new() { Id = 2, Vector = new SqlVector<float>(new float[] { 1, 2, 100 }) },
new() { Id = 3, Vector = new SqlVector<float>(new float[] { 1, 2, 1000 }) }
};
// SQL Server vector indexes require at least 100 rows.
var vectorEntities = Enumerable.Range(1, 100).Select(
i => new VectorEntity
{
Id = i,
Vector = new SqlVector<float>(new float[] { i * 0.01f, i * 0.02f, i * 0.03f })
}).ToList();

// Override specific rows we use in test assertions
vectorEntities[0] = new VectorEntity { Id = 1, Vector = new SqlVector<float>(new float[] { 1, 2, 3 }) };
vectorEntities[1] = new VectorEntity { Id = 2, Vector = new SqlVector<float>(new float[] { 1, 2, 100 }) };
vectorEntities[2] = new VectorEntity { Id = 3, Vector = new SqlVector<float>(new float[] { 1, 2, 1000 }) };

context.VectorEntities.AddRange(vectorEntities);
await context.SaveChangesAsync();

// TODO (#36384): Remove this once it's out of preview
await context.Database.ExecuteSqlAsync($"ALTER DATABASE SCOPED CONFIGURATION SET PREVIEW_FEATURES = ON");

await context.Database.ExecuteSqlAsync($"""
CREATE VECTOR INDEX vec_idx ON VectorEntities(Vector)
WITH (METRIC = 'Cosine', TYPE = 'DiskANN')
ON [PRIMARY];
WITH (METRIC = 'Cosine', TYPE = 'DiskANN');
""");
}
}
Expand Down
Loading