From c308af56dfa84181527fb426d6c6a7114519a59d Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Wed, 21 Aug 2024 18:49:18 +0200 Subject: [PATCH] Cosmos: strip implicit casts to allow vector search over arrays (#34437) Fixes #34402 Co-authored-by: Arthur Vickers --- .../CosmosSqlTranslatingExpressionVisitor.cs | 18 +++-- .../Internal/CosmosVectorTypeMapping.cs | 35 +++++++++- .../VectorSearchCosmosTest.cs | 67 ++++++++++++------- 3 files changed, 91 insertions(+), 29 deletions(-) diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs index e71bf98a6a1..9d803df02c6 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs @@ -820,18 +820,28 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression) ExpressionType.Negate or ExpressionType.NegateChecked => sqlExpressionFactory.Negate(sqlOperand!), - ExpressionType.Convert or ExpressionType.ConvertChecked - when operand.Type.IsInterface - && unaryExpression.Type.GetInterfaces().Any(e => e == operand.Type) + // Convert nodes can be an explicit user gesture in the query, or they may get introduced by the compiler (e.g. when a Child is + // passed as an argument for a parameter of type Parent). The latter type should generally get stripped out as a pure C#/LINQ + // artifact that shouldn't affect translation, but the former may be an indication from the user that they want to apply a + // type change. + ExpressionType.Convert or ExpressionType.ConvertChecked or ExpressionType.TypeAs + when operand.Type.IsInterface && unaryExpression.Type.GetInterfaces().Any(e => e == operand.Type) + // We strip out implicit conversions, e.g. 
float[] -> ReadOnlyMemory<float> (for vector search) + || (unaryExpression.Method is { IsSpecialName: true, Name: "op_Implicit" } + && IsReadOnlyMemory(unaryExpression.Type.UnwrapNullableType())) || unaryExpression.Type.UnwrapNullableType() == operand.Type || unaryExpression.Type.UnwrapNullableType() == typeof(Enum) // Object convert needs to be converted to explicit cast when mismatching types - // But we let is pass here since we don't have explicit cast mechanism here and in some cases object convert is due to value types + // But we let it pass here since we don't have explicit cast mechanism here and in some cases object convert is due to value types || unaryExpression.Type == typeof(object) => sqlOperand!, _ => QueryCompilationContext.NotTranslatedExpression }; + + static bool IsReadOnlyMemory(Type type) + => type is { IsGenericType: true, IsGenericTypeDefinition: false } + && type.GetGenericTypeDefinition() == typeof(ReadOnlyMemory<>); } /// diff --git a/src/EFCore.Cosmos/Storage/Internal/CosmosVectorTypeMapping.cs b/src/EFCore.Cosmos/Storage/Internal/CosmosVectorTypeMapping.cs index ab83b515244..e9279c00e62 100644 --- a/src/EFCore.Cosmos/Storage/Internal/CosmosVectorTypeMapping.cs +++ b/src/EFCore.Cosmos/Storage/Internal/CosmosVectorTypeMapping.cs @@ -4,6 +4,7 @@ using System.Diagnostics.CodeAnalysis; using Microsoft.EntityFrameworkCore.Cosmos.Metadata.Internal; using Microsoft.EntityFrameworkCore.Storage.Json; +using Newtonsoft.Json.Linq; namespace Microsoft.EntityFrameworkCore.Cosmos.Storage.Internal; @@ -63,7 +64,8 @@ public CosmosVectorTypeMapping(CosmosTypeMapping mapping, CosmosVectorType vecto : this( new CoreTypeMappingParameters( mapping.ClrType, - converter: mapping.Converter, + // This is a hack to allow both arrays and ROM types without different function overloads or type mappings. + converter: mapping.Converter?.GetType() == typeof(BytesToStringConverter) ? 
null : mapping.Converter, mapping.Comparer, mapping.KeyComparer, elementMapping: mapping.ElementTypeMapping, @@ -114,4 +116,35 @@ public override CoreTypeMapping WithComposedConverter( /// protected override CoreTypeMapping Clone(CoreTypeMappingParameters parameters) => new CosmosVectorTypeMapping(parameters, VectorType); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override JToken? GenerateJToken(object? value) + { + // This is a hack to allow both arrays and ROM types without different function overloads or type mappings. + var type = value?.GetType(); + if (type?.IsArray is false) + { + if (type == typeof(ReadOnlyMemory)) + { + value = ((ReadOnlyMemory)value!).ToArray(); + } + else if (type == typeof(ReadOnlyMemory)) + { + value = ((ReadOnlyMemory)value!).ToArray(); + } + else if (type == typeof(ReadOnlyMemory)) + { + value = ((ReadOnlyMemory)value!).ToArray(); + } + } + + return value == null + ? 
null + : JToken.FromObject(value, CosmosClientWrapper.Serializer); + } } diff --git a/test/EFCore.Cosmos.FunctionalTests/VectorSearchCosmosTest.cs b/test/EFCore.Cosmos.FunctionalTests/VectorSearchCosmosTest.cs index c67ca8c3cba..38e32f1bccc 100644 --- a/test/EFCore.Cosmos.FunctionalTests/VectorSearchCosmosTest.cs +++ b/test/EFCore.Cosmos.FunctionalTests/VectorSearchCosmosTest.cs @@ -99,16 +99,19 @@ public virtual async Task Query_for_vector_distance_bytes_array() await using var context = CreateContext(); var inputVector = new byte[] { 2, 1, 4, 3, 5, 2, 5, 7, 3, 1 }; - // See Issue #34402 - await Assert.ThrowsAsync( - () => context.Set().Select(e => EF.Functions.VectorDistance(e.BytesArray, inputVector)).ToListAsync()); + var booksFromStore = await context + .Set() + .Select(e => EF.Functions.VectorDistance(e.BytesArray, inputVector)) + .ToListAsync(); - // Assert.Equal(3, booksFromStore.Count); - // Assert.All(booksFromStore, s => Assert.NotEqual(0.0, s)); + Assert.Equal(3, booksFromStore.Count); + Assert.All(booksFromStore, s => Assert.NotEqual(0.0, s)); AssertSql( """ -SELECT VALUE c["BytesArray"] +@__p_1='[2,1,4,3,5,2,5,7,3,1]' + +SELECT VALUE VectorDistance(c["BytesArray"], @__p_1, false, {'distanceFunction':'cosine', 'dataType':'uint8'}) FROM root c """); } @@ -119,17 +122,20 @@ public virtual async Task Query_for_vector_distance_singles_array() await using var context = CreateContext(); var inputVector = new[] { 0.33f, -0.52f, 0.45f, -0.67f, 0.89f, -0.34f, 0.86f, -0.78f, 0.86f, -0.78f }; - // See Issue #34402 - await Assert.ThrowsAsync( - () => context.Set() - .Select(e => EF.Functions.VectorDistance(e.SinglesArray, inputVector, false, DistanceFunction.DotProduct)).ToListAsync()); + var booksFromStore = await context + .Set() + .Select( + e => EF.Functions.VectorDistance(e.SinglesArray, inputVector, false, DistanceFunction.DotProduct)) + .ToListAsync(); - // Assert.Equal(3, booksFromStore.Count); - // Assert.All(booksFromStore, s => Assert.NotEqual(0.0, s)); 
+ Assert.Equal(3, booksFromStore.Count); + Assert.All(booksFromStore, s => Assert.NotEqual(0.0, s)); AssertSql( """ -SELECT VALUE c["SinglesArray"] +@__p_1='[0.33,-0.52,0.45,-0.67,0.89,-0.34,0.86,-0.78,0.86,-0.78]' + +SELECT VALUE VectorDistance(c["SinglesArray"], @__p_1, false, {'distanceFunction':'dotproduct', 'dataType':'float32'}) FROM root c """); } @@ -207,14 +213,20 @@ public virtual async Task Vector_distance_bytes_array_in_OrderBy() await using var context = CreateContext(); var inputVector = new byte[] { 2, 1, 4, 6, 5, 2, 5, 7, 3, 1 }; - // See Issue #34402 - await Assert.ThrowsAsync( - () => context.Set().OrderBy(e => EF.Functions.VectorDistance(e.BytesArray, inputVector)).ToListAsync()); - - // Assert.Equal(3, booksFromStore.Count); + var booksFromStore = await context + .Set() + .OrderBy(e => EF.Functions.VectorDistance(e.BytesArray, inputVector)) + .ToListAsync(); + Assert.Equal(3, booksFromStore.Count); AssertSql( -); + """ +@__p_1='[2,1,4,6,5,2,5,7,3,1]' + +SELECT VALUE c +FROM root c +ORDER BY VectorDistance(c["BytesArray"], @__p_1, false, {'distanceFunction':'cosine', 'dataType':'uint8'}) +"""); } [ConditionalFact] @@ -223,13 +235,20 @@ public virtual async Task Vector_distance_singles_array_in_OrderBy() await using var context = CreateContext(); var inputVector = new[] { 0.33f, -0.52f, 0.45f, -0.67f, 0.89f, -0.34f, 0.86f, -0.78f }; - // See Issue #34402 - await Assert.ThrowsAsync( - () => context.Set().OrderBy(e => EF.Functions.VectorDistance(e.SinglesArray, inputVector)).ToListAsync()); + var booksFromStore = await context + .Set() + .OrderBy(e => EF.Functions.VectorDistance(e.SinglesArray, inputVector)) + .ToListAsync(); - // Assert.Equal(3, booksFromStore.Count); + Assert.Equal(3, booksFromStore.Count); + AssertSql( + """ +@__p_1='[0.33,-0.52,0.45,-0.67,0.89,-0.34,0.86,-0.78]' - AssertSql(); +SELECT VALUE c +FROM root c +ORDER BY VectorDistance(c["SinglesArray"], @__p_1, false, {'distanceFunction':'cosine', 'dataType':'float32'}) +"""); } 
[ConditionalFact]