Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LINQ : Adds support for case-insensitive searches #1721

Merged
merged 14 commits into from
Aug 4, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,14 @@ protected override SqlScalarExpression VisitImplicit(MethodCallExpression method
SqlScalarExpression haystack = ExpressionToSql.VisitScalarExpression(methodCallExpression.Arguments[0], context);
SqlScalarExpression needle = ExpressionToSql.VisitScalarExpression(methodCallExpression.Arguments[1], context);
return SqlFunctionCallScalarExpression.CreateBuiltin("CONTAINS", haystack, needle);
}

if (methodCallExpression.Arguments.Count == 3)
jeffpardy marked this conversation as resolved.
Show resolved Hide resolved
{
SqlScalarExpression haystack = ExpressionToSql.VisitScalarExpression(methodCallExpression.Arguments[0], context);
SqlScalarExpression needle = ExpressionToSql.VisitScalarExpression(methodCallExpression.Arguments[1], context);
SqlScalarExpression caseInsensitive = SqlStringWithComparisonVisitor.GetCaseInsensitiveExpression(methodCallExpression.Arguments[2]);
return SqlFunctionCallScalarExpression.CreateBuiltin("CONTAINS", haystack, needle, caseInsensitive);
}

return null;
Expand Down Expand Up @@ -155,6 +162,65 @@ protected override SqlScalarExpression VisitImplicit(MethodCallExpression method
}
}

private class SqlStringWithComparisonVisitor : BuiltinFunctionVisitor
jeffpardy marked this conversation as resolved.
Show resolved Hide resolved
{
private static readonly HashSet<StringComparison> IgnoreCaseComparisons = new HashSet<StringComparison>(new[]
{
StringComparison.CurrentCultureIgnoreCase,
StringComparison.InvariantCultureIgnoreCase,
StringComparison.OrdinalIgnoreCase
});

public string SqlName { get; private set; }
jeffpardy marked this conversation as resolved.
Show resolved Hide resolved

public SqlStringWithComparisonVisitor(string sqlName)
{
this.SqlName = sqlName;
jeffpardy marked this conversation as resolved.
Show resolved Hide resolved
}

public static SqlScalarExpression GetCaseInsensitiveExpression(Expression expression)
{
ConstantExpression inputExpression = expression as ConstantExpression;
jeffpardy marked this conversation as resolved.
Show resolved Hide resolved
if (inputExpression?.Value is StringComparison comparisonValue)
jeffpardy marked this conversation as resolved.
Show resolved Hide resolved
{
if (IgnoreCaseComparisons.Contains(comparisonValue))
{
SqlBooleanLiteral literal = SqlBooleanLiteral.Create(true);
return SqlLiteralScalarExpression.Create(literal);
}
}

return null;

}
jeffpardy marked this conversation as resolved.
Show resolved Hide resolved

protected override SqlScalarExpression VisitImplicit(MethodCallExpression methodCallExpression, TranslationContext context)
{
int argumentCount = methodCallExpression.Arguments.Count;
if (argumentCount == 0 || argumentCount > 2)
{
return null;
}

List<SqlScalarExpression> arguments = new List<SqlScalarExpression>();
jeffpardy marked this conversation as resolved.
Show resolved Hide resolved

arguments.Add(ExpressionToSql.VisitNonSubqueryScalarExpression(methodCallExpression.Object, context));
arguments.Add(ExpressionToSql.VisitNonSubqueryScalarExpression(methodCallExpression.Arguments[0], context));

if (argumentCount > 1)
{
arguments.Add(GetCaseInsensitiveExpression(methodCallExpression.Arguments[1]));
}

return SqlFunctionCallScalarExpression.CreateBuiltin(this.SqlName, arguments.ToArray());
}

protected override SqlScalarExpression VisitExplicit(MethodCallExpression methodCallExpression, TranslationContext context)
{
return null;
}
}

private class StringVisitTrimEnd : SqlBuiltinFunctionVisitor
{
public StringVisitTrimEnd()
Expand Down Expand Up @@ -234,6 +300,15 @@ protected override SqlScalarExpression VisitImplicit(MethodCallExpression method
return SqlBinaryScalarExpression.Create(SqlBinaryScalarOperatorKind.Equal, left, right);
}

if (methodCallExpression.Arguments.Count == 2)
jeffpardy marked this conversation as resolved.
Show resolved Hide resolved
{
SqlScalarExpression left = ExpressionToSql.VisitScalarExpression(methodCallExpression.Object, context);
SqlScalarExpression right = ExpressionToSql.VisitScalarExpression(methodCallExpression.Arguments[0], context);
SqlScalarExpression caseInsensitive = SqlStringWithComparisonVisitor.GetCaseInsensitiveExpression(methodCallExpression.Arguments[1]);

return SqlFunctionCallScalarExpression.CreateBuiltin("STRINGEQUALS", left, right, caseInsensitive);
}

return null;
}

Expand All @@ -257,12 +332,7 @@ static StringBuiltinFunctions()
},
{
"EndsWith",
new SqlBuiltinFunctionVisitor("ENDSWITH",
false,
new List<Type[]>
{
new Type[]{typeof(string)}
})
new SqlStringWithComparisonVisitor("ENDSWITH")
},
{
"IndexOf",
Expand Down Expand Up @@ -313,12 +383,7 @@ static StringBuiltinFunctions()
},
{
"StartsWith",
new SqlBuiltinFunctionVisitor("STARTSWITH",
false,
new List<Type[]>
{
new Type[]{typeof(string)}
})
new SqlStringWithComparisonVisitor("STARTSWITH")
},
{
"Substring",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ namespace Microsoft.Azure.Cosmos.SDK.EmulatorTests
using System.Collections.ObjectModel;
using System.Linq;
using System.Linq.Dynamic;
using System.Linq.Expressions;
using System.Threading.Tasks;

[TestClass]
Expand Down Expand Up @@ -712,6 +713,50 @@ public async Task LinqParameterisedTest3()

}

[TestMethod]
public async Task LinqCaseInsensitiveStringTest()
jeffpardy marked this conversation as resolved.
Show resolved Hide resolved
{
//Creating items for query.
IList<ToDoActivity> itemList = await ToDoActivity.CreateRandomItems(container: this.Container, pkCount: 2, perPKItemCount: 1, randomPartitionKey: true);
jeffpardy marked this conversation as resolved.
Show resolved Hide resolved

IOrderedQueryable<ToDoActivity> linqQueryable = this.Container.GetItemLinqQueryable<ToDoActivity>();

async Task TestSearch(Expression<Func<ToDoActivity, bool>> expression, string expectedMethod, bool shouldBeCaseInsensitive, int expectedResults)
jeffpardy marked this conversation as resolved.
Show resolved Hide resolved
{
string expectedQueryText = $"SELECT VALUE root FROM root WHERE {expectedMethod}(root[\"description\"], @param1{(shouldBeCaseInsensitive ? ", true" : "")})";

IArgumentProvider arguments = (IArgumentProvider)expression.Body;
int index = arguments.ArgumentCount > 2 ? 1 : 0;

string searchString = (arguments.GetArgument(index) as ConstantExpression).Value as string;

IQueryable<ToDoActivity> queryable = linqQueryable.Where(expression);

Dictionary<object, string> parameters = new Dictionary<object, string>();
parameters.Add(searchString, "@param1");

QueryDefinition queryDefinition = queryable.ToQueryDefinition(parameters);

string queryText = queryDefinition.ToSqlQuerySpec().QueryText;

Assert.AreEqual(expectedQueryText, queryText);

Assert.AreEqual(expectedResults, await queryable.CountAsync());
}

await TestSearch(x => x.description.StartsWith("create"), "STARTSWITH", false, 0);
await TestSearch(x => x.description.StartsWith("create", StringComparison.OrdinalIgnoreCase), "STARTSWITH", true, 2);
jeffpardy marked this conversation as resolved.
Show resolved Hide resolved

await TestSearch(x => x.description.EndsWith("activity"), "ENDSWITH", false, 0);
await TestSearch(x => x.description.EndsWith("activity", StringComparison.OrdinalIgnoreCase), "ENDSWITH", true, 2);

await TestSearch(x => x.description.Equals("createrandomtodoactivity", StringComparison.OrdinalIgnoreCase), "STRINGEQUALS", true, 2);

await TestSearch(x => x.description.Contains("todo"), "CONTAINS", false, 0);
await TestSearch(x => x.description.Contains("todo", StringComparison.OrdinalIgnoreCase), "CONTAINS", true, 2);

}

private class NumberLinqItem
{
public string id;
Expand Down Expand Up @@ -764,4 +809,12 @@ private void VerifyResponse<T>(
disableDiagnostics: disableDiagnostics);
}
}

static class StringTestExtensions
{
public static bool Contains(this string haystack, string needle, StringComparison stringComparison)
{
throw new NotImplementedException("Method for testing SQL translation only");
}
}
jeffpardy marked this conversation as resolved.
Show resolved Hide resolved
}