Skip to content

Commit

Permalink
feat(isthmus): improve signature matching for functions with wildcard…
Browse files Browse the repository at this point in the history
… arguments (#226)

Add up-converting signature matchers to assist in matching functions containing `any` arguments
  • Loading branch information
bvolpato authored Feb 24, 2024
1 parent a5e1a21 commit ec1887c
Show file tree
Hide file tree
Showing 4 changed files with 603 additions and 44 deletions.
19 changes: 19 additions & 0 deletions core/src/main/java/io/substrait/function/ToTypeString.java
Original file line number Diff line number Diff line change
Expand Up @@ -178,4 +178,23 @@ public String visit(ParameterizedType.StringLiteral expr) throws RuntimeExceptio
return super.visit(expr);
}
}

/**
* {@link ToTypeString} emits the string `any` for all wildcard any types, even if they have
* numeric suffixes (i.e. `any1`, `any2`, etc).
*
* <p>These suffixes are needed to correctly perform function matching based on arguments. This
* subclass retains the numerics suffixes when emitting type strings for this.
*/
public static class ToTypeLiteralStringLossless extends ToTypeString {

public static final ToTypeLiteralStringLossless INSTANCE = new ToTypeLiteralStringLossless();

private ToTypeLiteralStringLossless() {}

@Override
public String visit(ParameterizedType.StringLiteral expr) throws RuntimeException {
return expr.value().toLowerCase();
}
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
package io.substrait.isthmus.expression;

import com.google.common.collect.*;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Multimap;
import com.google.common.collect.Multimaps;
import com.google.common.collect.Streams;
import io.substrait.expression.Expression;
import io.substrait.expression.ExpressionCreator;
import io.substrait.expression.FunctionArg;
Expand All @@ -13,11 +19,14 @@
import io.substrait.util.Util;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
Expand Down Expand Up @@ -170,8 +179,64 @@ public boolean allowedArgCount(int count) {

private static <F extends SimpleExtension.Function> SignatureMatcher<F> getSignatureMatcher(
SqlOperator operator, List<F> functions) {
// TODO: define up-converting matchers.
return (a, b) -> Optional.empty();
return (inputTypes, outputType) -> {
for (F function : functions) {
List<SimpleExtension.Argument> args = function.requiredArguments();
// Make sure that arguments & return are within bounds and match the types
if (function.returnType() instanceof ParameterizedType
&& isMatch(outputType, (ParameterizedType) function.returnType())
&& inputTypesSatisfyDefinedArguments(inputTypes, args)) {
return Optional.of(function);
}
}

return Optional.empty();
};
}

/**
* Checks to see if the given input types satisfy the function arguments given. Checks that
*
* <ul>
* <li>Variadic arguments all have the same input type
* <li>Matched wildcard arguments (i.e.`any`, `any1`, `any2`, etc) all have the same input
* type
* </ul>
*
* @param inputTypes input types to check against arguments
* @param args expected arguments as defined in a {@link SimpleExtension.Function}
* @return true if the {@code inputTypes} satisfy the {@code args}, false otherwise
*/
private static boolean inputTypesSatisfyDefinedArguments(
List<Type> inputTypes, List<SimpleExtension.Argument> args) {

Map<String, Set<Type>> wildcardToType = new HashMap<>();
for (int i = 0; i < inputTypes.size(); i++) {
Type givenType = inputTypes.get(i);
SimpleExtension.ValueArgument wantType =
(SimpleExtension.ValueArgument)
args.get(
// Variadic arguments should match the last argument's type
Integer.min(i, args.size() - 1));

if (!isMatch(givenType, wantType.value())) {
return false;
}

// Register the wildcard to type
if (wantType.value().isWildcard()) {
wildcardToType
.computeIfAbsent(
wantType.value().accept(ToTypeString.ToTypeLiteralStringLossless.INSTANCE),
k -> new HashSet<>())
.add(givenType);
}
}

// If all the types match, check if the wildcard types are compatible.
// TODO: Determine if non-enumerated wildcard types (i.e. `any` as opposed to `any1`) need to
// have the same type.
return wildcardToType.values().stream().allMatch(s -> s.size() == 1);
}

/**
Expand Down Expand Up @@ -289,12 +354,10 @@ public Optional<T> attemptMatch(C call, Function<RexNode, Expression> topLevelCo
var outputType = typeConverter.toSubstrait(call.getType());

// try to do a direct match
var typeStrings =
opTypes.stream().map(t -> t.accept(ToTypeString.INSTANCE)).collect(Collectors.toList());
var possibleKeys =
matchKeys(
call.getOperands().collect(java.util.stream.Collectors.toList()),
opTypes.stream()
.map(t -> t.accept(ToTypeString.INSTANCE))
.collect(java.util.stream.Collectors.toList()));
matchKeys(call.getOperands().collect(java.util.stream.Collectors.toList()), typeStrings);

var directMatchKey =
possibleKeys
Expand Down Expand Up @@ -327,34 +390,77 @@ public Optional<T> attemptMatch(C call, Function<RexNode, Expression> topLevelCo
}

if (singularInputType.isPresent()) {
RelDataType leastRestrictive =
typeFactory.leastRestrictive(
call.getOperands()
.map(RexNode::getType)
.collect(java.util.stream.Collectors.toList()));
if (leastRestrictive == null) {
return Optional.empty();
Optional<T> leastRestrictive = matchByLeastRestrictive(call, outputType, operands);
if (leastRestrictive.isPresent()) {
return leastRestrictive;
}
Type type = typeConverter.toSubstrait(leastRestrictive);
var out = singularInputType.get().tryMatch(type, outputType);

if (out.isPresent()) {
var declaration = out.get();
var coercedArgs = coerceArguments(operands, type);
declaration.validateOutputType(coercedArgs, outputType);
return Optional.of(
generateBinding(
call,
out.get(),
coercedArgs.stream()
.map(FunctionArg.class::cast)
.collect(java.util.stream.Collectors.toList()),
outputType));

Optional<T> coerced = matchCoerced(call, outputType, operands);
if (coerced.isPresent()) {
return coerced;
}
}
return Optional.empty();
}

private Optional<T> matchByLeastRestrictive(
C call, Type outputType, List<Expression> operands) {
RelDataType leastRestrictive =
typeFactory.leastRestrictive(
call.getOperands().map(RexNode::getType).collect(Collectors.toList()));
if (leastRestrictive == null) {
return Optional.empty();
}
Type type = typeConverter.toSubstrait(leastRestrictive);
var out = singularInputType.get().tryMatch(type, outputType);

if (out.isPresent()) {
var declaration = out.get();
var coercedArgs = coerceArguments(operands, type);
declaration.validateOutputType(coercedArgs, outputType);
return Optional.of(
generateBinding(
call,
out.get(),
coercedArgs.stream().map(FunctionArg.class::cast).collect(Collectors.toList()),
outputType));
}
return Optional.empty();
}

private Optional<T> matchCoerced(C call, Type outputType, List<Expression> operands) {

// Convert the operands to the proper Substrait type
List<Type> allTypes =
call.getOperands()
.map(RexNode::getType)
.map(typeConverter::toSubstrait)
.collect(Collectors.toList());

// See if all the input types match the function
Optional<F> matchFunction = this.matcher.tryMatch(allTypes, outputType);
if (matchFunction.isPresent()) {
List<Expression> coerced =
Streams.zip(
operands.stream(),
call.getOperands(),
(a, b) -> {
Type type = typeConverter.toSubstrait(b.getType());
return coerceArgument(a, type);
})
.collect(Collectors.toList());

return Optional.of(
generateBinding(
call,
matchFunction.get(),
coerced.stream().map(FunctionArg.class::cast).collect(Collectors.toList()),
outputType));
}

return Optional.empty();
}

protected String getName() {
return name;
}
Expand All @@ -374,18 +480,16 @@ public interface GenericCall {
* Coerced types according to an expected output type. Coercion is only done for type mismatches,
* not for nullability or parameter mismatches.
*/
private List<Expression> coerceArguments(List<Expression> arguments, Type type) {

return arguments.stream()
.map(
a -> {
var typeMatches = isMatch(type, a.getType());
if (!typeMatches) {
return ExpressionCreator.cast(type, a);
}
return a;
})
.collect(java.util.stream.Collectors.toList());
private static List<Expression> coerceArguments(List<Expression> arguments, Type type) {
return arguments.stream().map(a -> coerceArgument(a, type)).collect(Collectors.toList());
}

private static Expression coerceArgument(Expression argument, Type type) {
var typeMatches = isMatch(type, argument.getType());
if (!typeMatches) {
return ExpressionCreator.cast(type, argument);
}
return argument;
}

protected abstract T generateBinding(
Expand Down
Loading

0 comments on commit ec1887c

Please sign in to comment.