Skip to content

Commit

Permalink
Cosmos: strip implicit casts to allow vector search over arrays (#34437)
Browse files Browse the repository at this point in the history
Fixes #34402

Co-authored-by: Arthur Vickers <ajcvickers@hotmail.com>
  • Loading branch information
roji and ajcvickers committed Aug 21, 2024
1 parent fc1ebc5 commit c308af5
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -820,18 +820,28 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression)
ExpressionType.Negate or ExpressionType.NegateChecked
=> sqlExpressionFactory.Negate(sqlOperand!),

ExpressionType.Convert or ExpressionType.ConvertChecked
when operand.Type.IsInterface
&& unaryExpression.Type.GetInterfaces().Any(e => e == operand.Type)
// Convert nodes can be an explicit user gesture in the query, or they may get introduced by the compiler (e.g. when a Child is
// passed as an argument for a parameter of type Parent). The latter type should generally get stripped out as a pure C#/LINQ
// artifact that shouldn't affect translation, but the latter may be an indication from the user that they want to apply a
// type change.
ExpressionType.Convert or ExpressionType.ConvertChecked or ExpressionType.TypeAs
when operand.Type.IsInterface && unaryExpression.Type.GetInterfaces().Any(e => e == operand.Type)
// We strip out implicit conversions, e.g. float[] -> ReadOnlyMemory<float> (for vector search)
|| (unaryExpression.Method is { IsSpecialName: true, Name: "op_Implicit" }
&& IsReadOnlyMemory(unaryExpression.Type.UnwrapNullableType()))
|| unaryExpression.Type.UnwrapNullableType() == operand.Type
|| unaryExpression.Type.UnwrapNullableType() == typeof(Enum)
// Object convert needs to be converted to explicit cast when mismatching types
// But we let is pass here since we don't have explicit cast mechanism here and in some cases object convert is due to value types
// But we let it pass here since we don't have explicit cast mechanism here and in some cases object convert is due to value types
|| unaryExpression.Type == typeof(object)
=> sqlOperand!,

_ => QueryCompilationContext.NotTranslatedExpression
};

static bool IsReadOnlyMemory(Type type)
=> type is { IsGenericType: true, IsGenericTypeDefinition: false }
&& type.GetGenericTypeDefinition() == typeof(ReadOnlyMemory<>);
}

/// <inheritdoc />
Expand Down
35 changes: 34 additions & 1 deletion src/EFCore.Cosmos/Storage/Internal/CosmosVectorTypeMapping.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Diagnostics.CodeAnalysis;
using Microsoft.EntityFrameworkCore.Cosmos.Metadata.Internal;
using Microsoft.EntityFrameworkCore.Storage.Json;
using Newtonsoft.Json.Linq;

namespace Microsoft.EntityFrameworkCore.Cosmos.Storage.Internal;

Expand Down Expand Up @@ -63,7 +64,8 @@ public CosmosVectorTypeMapping(CosmosTypeMapping mapping, CosmosVectorType vecto
: this(
new CoreTypeMappingParameters(
mapping.ClrType,
converter: mapping.Converter,
// This is a hack to allow both arrays and ROM types without different function overloads or type mappings.
converter: mapping.Converter?.GetType() == typeof(BytesToStringConverter) ? null : mapping.Converter,
mapping.Comparer,
mapping.KeyComparer,
elementMapping: mapping.ElementTypeMapping,
Expand Down Expand Up @@ -114,4 +116,35 @@ public override CoreTypeMapping WithComposedConverter(
/// </summary>
protected override CoreTypeMapping Clone(CoreTypeMappingParameters parameters)
=> new CosmosVectorTypeMapping(parameters, VectorType);

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public override JToken? GenerateJToken(object? value)
{
// This is a hack to allow both arrays and ROM types without different function overloads or type mappings.
var type = value?.GetType();
if (type?.IsArray is false)
{
if (type == typeof(ReadOnlyMemory<byte>))
{
value = ((ReadOnlyMemory<byte>)value!).ToArray();
}
else if (type == typeof(ReadOnlyMemory<sbyte>))
{
value = ((ReadOnlyMemory<sbyte>)value!).ToArray();
}
else if (type == typeof(ReadOnlyMemory<float>))
{
value = ((ReadOnlyMemory<float>)value!).ToArray();
}
}

return value == null
? null
: JToken.FromObject(value, CosmosClientWrapper.Serializer);
}
}
67 changes: 43 additions & 24 deletions test/EFCore.Cosmos.FunctionalTests/VectorSearchCosmosTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -99,16 +99,19 @@ public virtual async Task Query_for_vector_distance_bytes_array()
await using var context = CreateContext();
var inputVector = new byte[] { 2, 1, 4, 3, 5, 2, 5, 7, 3, 1 };

// See Issue #34402
await Assert.ThrowsAsync<InvalidOperationException>(
() => context.Set<Book>().Select(e => EF.Functions.VectorDistance(e.BytesArray, inputVector)).ToListAsync());
var booksFromStore = await context
.Set<Book>()
.Select(e => EF.Functions.VectorDistance(e.BytesArray, inputVector))
.ToListAsync();

// Assert.Equal(3, booksFromStore.Count);
// Assert.All(booksFromStore, s => Assert.NotEqual(0.0, s));
Assert.Equal(3, booksFromStore.Count);
Assert.All(booksFromStore, s => Assert.NotEqual(0.0, s));

AssertSql(
"""
SELECT VALUE c["BytesArray"]
@__p_1='[2,1,4,3,5,2,5,7,3,1]'

SELECT VALUE VectorDistance(c["BytesArray"], @__p_1, false, {'distanceFunction':'cosine', 'dataType':'uint8'})
FROM root c
""");
}
Expand All @@ -119,17 +122,20 @@ public virtual async Task Query_for_vector_distance_singles_array()
await using var context = CreateContext();
var inputVector = new[] { 0.33f, -0.52f, 0.45f, -0.67f, 0.89f, -0.34f, 0.86f, -0.78f, 0.86f, -0.78f };

// See Issue #34402
await Assert.ThrowsAsync<InvalidOperationException>(
() => context.Set<Book>()
.Select(e => EF.Functions.VectorDistance(e.SinglesArray, inputVector, false, DistanceFunction.DotProduct)).ToListAsync());
var booksFromStore = await context
.Set<Book>()
.Select(
e => EF.Functions.VectorDistance(e.SinglesArray, inputVector, false, DistanceFunction.DotProduct))
.ToListAsync();

// Assert.Equal(3, booksFromStore.Count);
// Assert.All(booksFromStore, s => Assert.NotEqual(0.0, s));
Assert.Equal(3, booksFromStore.Count);
Assert.All(booksFromStore, s => Assert.NotEqual(0.0, s));

AssertSql(
"""
SELECT VALUE c["SinglesArray"]
@__p_1='[0.33,-0.52,0.45,-0.67,0.89,-0.34,0.86,-0.78,0.86,-0.78]'

SELECT VALUE VectorDistance(c["SinglesArray"], @__p_1, false, {'distanceFunction':'dotproduct', 'dataType':'float32'})
FROM root c
""");
}
Expand Down Expand Up @@ -207,14 +213,20 @@ public virtual async Task Vector_distance_bytes_array_in_OrderBy()
await using var context = CreateContext();
var inputVector = new byte[] { 2, 1, 4, 6, 5, 2, 5, 7, 3, 1 };

// See Issue #34402
await Assert.ThrowsAsync<InvalidOperationException>(
() => context.Set<Book>().OrderBy(e => EF.Functions.VectorDistance(e.BytesArray, inputVector)).ToListAsync());

// Assert.Equal(3, booksFromStore.Count);
var booksFromStore = await context
.Set<Book>()
.OrderBy(e => EF.Functions.VectorDistance(e.BytesArray, inputVector))
.ToListAsync();

Assert.Equal(3, booksFromStore.Count);
AssertSql(
);
"""
@__p_1='[2,1,4,6,5,2,5,7,3,1]'

SELECT VALUE c
FROM root c
ORDER BY VectorDistance(c["BytesArray"], @__p_1, false, {'distanceFunction':'cosine', 'dataType':'uint8'})
""");
}

[ConditionalFact]
Expand All @@ -223,13 +235,20 @@ public virtual async Task Vector_distance_singles_array_in_OrderBy()
await using var context = CreateContext();
var inputVector = new[] { 0.33f, -0.52f, 0.45f, -0.67f, 0.89f, -0.34f, 0.86f, -0.78f };

// See Issue #34402
await Assert.ThrowsAsync<InvalidOperationException>(
() => context.Set<Book>().OrderBy(e => EF.Functions.VectorDistance(e.SinglesArray, inputVector)).ToListAsync());
var booksFromStore = await context
.Set<Book>()
.OrderBy(e => EF.Functions.VectorDistance(e.SinglesArray, inputVector))
.ToListAsync();

// Assert.Equal(3, booksFromStore.Count);
Assert.Equal(3, booksFromStore.Count);
AssertSql(
"""
@__p_1='[0.33,-0.52,0.45,-0.67,0.89,-0.34,0.86,-0.78]'

AssertSql();
SELECT VALUE c
FROM root c
ORDER BY VectorDistance(c["SinglesArray"], @__p_1, false, {'distanceFunction':'cosine', 'dataType':'float32'})
""");
}

[ConditionalFact]
Expand Down

0 comments on commit c308af5

Please sign in to comment.