Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LINQ : Adds support for case-insensitive searches #1721

Merged
merged 14 commits into from
Aug 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -55,25 +55,20 @@ public StringVisitContains()
false,
new List<Type[]>()
{
new Type[]{typeof(string)}
new Type[]{typeof(string)},
new Type[]{typeof(char)}
})
{
}

protected override SqlScalarExpression VisitImplicit(MethodCallExpression methodCallExpression, TranslationContext context)
{
if (methodCallExpression.Arguments.Count == 1)
j82w marked this conversation as resolved.
Show resolved Hide resolved
if (methodCallExpression.Arguments.Count == 2)
{
SqlScalarExpression haystack = ExpressionToSql.VisitScalarExpression(methodCallExpression.Object, context);
SqlScalarExpression needle = ExpressionToSql.VisitScalarExpression(methodCallExpression.Arguments[0], context);
return SqlFunctionCallScalarExpression.CreateBuiltin("CONTAINS", haystack, needle);
}
else if (methodCallExpression.Arguments.Count == 2)
{
SqlScalarExpression haystack = ExpressionToSql.VisitScalarExpression(methodCallExpression.Arguments[0], context);
SqlScalarExpression needle = ExpressionToSql.VisitScalarExpression(methodCallExpression.Arguments[1], context);
return SqlFunctionCallScalarExpression.CreateBuiltin("CONTAINS", haystack, needle);

SqlScalarExpression caseInsensitive = SqlStringWithComparisonVisitor.GetCaseInsensitiveExpression(methodCallExpression.Arguments[1]);
return SqlFunctionCallScalarExpression.CreateBuiltin("CONTAINS", haystack, needle, caseInsensitive);
}
jeffpardy marked this conversation as resolved.
Show resolved Hide resolved

return null;
Expand Down Expand Up @@ -161,6 +156,63 @@ protected override SqlScalarExpression VisitImplicit(MethodCallExpression method
}
}

private sealed class SqlStringWithComparisonVisitor : BuiltinFunctionVisitor
{
private static readonly HashSet<StringComparison> IgnoreCaseComparisons = new HashSet<StringComparison>(new[]
{
StringComparison.CurrentCultureIgnoreCase,
StringComparison.InvariantCultureIgnoreCase,
StringComparison.OrdinalIgnoreCase
});

public string SqlName { get; }

public SqlStringWithComparisonVisitor(string sqlName)
{
this.SqlName = sqlName ?? throw new ArgumentNullException(nameof(sqlName));
}

public static SqlScalarExpression GetCaseInsensitiveExpression(Expression expression)
{
if (expression is ConstantExpression inputExpression
&& inputExpression.Value is StringComparison comparisonValue
&& IgnoreCaseComparisons.Contains(comparisonValue))
{
SqlBooleanLiteral literal = SqlBooleanLiteral.Create(true);
return SqlLiteralScalarExpression.Create(literal);
}

return null;
}
jeffpardy marked this conversation as resolved.
Show resolved Hide resolved

protected override SqlScalarExpression VisitImplicit(MethodCallExpression methodCallExpression, TranslationContext context)
{
int argumentCount = methodCallExpression.Arguments.Count;
if (argumentCount == 0 || argumentCount > 2)
{
return null;
}

List<SqlScalarExpression> arguments = new List<SqlScalarExpression>
{
ExpressionToSql.VisitNonSubqueryScalarExpression(methodCallExpression.Object, context),
ExpressionToSql.VisitNonSubqueryScalarExpression(methodCallExpression.Arguments[0], context)
};

if (argumentCount > 1)
{
arguments.Add(GetCaseInsensitiveExpression(methodCallExpression.Arguments[1]));
}

return SqlFunctionCallScalarExpression.CreateBuiltin(this.SqlName, arguments.ToArray());
}

protected override SqlScalarExpression VisitExplicit(MethodCallExpression methodCallExpression, TranslationContext context)
{
return null;
}
}

private class StringVisitTrimEnd : SqlBuiltinFunctionVisitor
{
public StringVisitTrimEnd()
Expand Down Expand Up @@ -240,6 +292,15 @@ protected override SqlScalarExpression VisitImplicit(MethodCallExpression method
return SqlBinaryScalarExpression.Create(SqlBinaryScalarOperatorKind.Equal, left, right);
}

if (methodCallExpression.Arguments.Count == 2)
jeffpardy marked this conversation as resolved.
Show resolved Hide resolved
{
SqlScalarExpression left = ExpressionToSql.VisitScalarExpression(methodCallExpression.Object, context);
SqlScalarExpression right = ExpressionToSql.VisitScalarExpression(methodCallExpression.Arguments[0], context);
SqlScalarExpression caseInsensitive = SqlStringWithComparisonVisitor.GetCaseInsensitiveExpression(methodCallExpression.Arguments[1]);

return SqlFunctionCallScalarExpression.CreateBuiltin("STRINGEQUALS", left, right, caseInsensitive);
}

return null;
}

Expand All @@ -263,12 +324,7 @@ static StringBuiltinFunctions()
},
{
"EndsWith",
new SqlBuiltinFunctionVisitor("ENDSWITH",
false,
new List<Type[]>
{
new Type[]{typeof(string)}
})
new SqlStringWithComparisonVisitor("ENDSWITH")
},
{
"IndexOf",
Expand Down Expand Up @@ -319,12 +375,7 @@ static StringBuiltinFunctions()
},
{
"StartsWith",
new SqlBuiltinFunctionVisitor("STARTSWITH",
false,
new List<Type[]>
{
new Type[]{typeof(string)}
})
new SqlStringWithComparisonVisitor("STARTSWITH")
},
{
"Substring",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,18 @@
<Output>
<SqlQuery><![CDATA[
SELECT VALUE (root["StringField"] = "str")
FROM root ]]></SqlQuery>
FROM root]]></SqlQuery>
</Output>
</Result>
<Result>
<Input>
<Description><![CDATA[Equals (case-insensitive)]]></Description>
<Expression><![CDATA[query.Select(doc => doc.StringField.Equals("STR", OrdinalIgnoreCase))]]></Expression>
</Input>
<Output>
<SqlQuery><![CDATA[
SELECT VALUE STRINGEQUALS(root["StringField"], "STR", true)
FROM root]]></SqlQuery>
</Output>
</Result>
<Result>
Expand All @@ -18,7 +29,7 @@ FROM root ]]></SqlQuery>
<Output>
<SqlQuery><![CDATA[
SELECT VALUE root["StringField"]
FROM root ]]></SqlQuery>
FROM root]]></SqlQuery>
</Output>
</Result>
<Result>
Expand All @@ -29,7 +40,7 @@ FROM root ]]></SqlQuery>
<Output>
<SqlQuery><![CDATA[
SELECT VALUE root["EnumerableField"][0]
FROM root ]]></SqlQuery>
FROM root]]></SqlQuery>
</Output>
</Result>
</Results>
Loading