Skip to content

Commit

Permalink
feat(isthmus): add up-converting signature matchers, and coerce types…
Browse files Browse the repository at this point in the history
… to match
  • Loading branch information
bvolpato committed Feb 22, 2024
1 parent 0bdac49 commit b9574cb
Show file tree
Hide file tree
Showing 4 changed files with 551 additions and 43 deletions.
16 changes: 16 additions & 0 deletions core/src/main/java/io/substrait/function/ToTypeString.java
Original file line number Diff line number Diff line change
Expand Up @@ -178,4 +178,20 @@ public String visit(ParameterizedType.StringLiteral expr) throws RuntimeExceptio
return super.visit(expr);
}
}

/**
* Subclass of ToTypeString that doesn't lose the context on the wildcard being used (for example,
* that can return any1, any2, etc, instead of only any, any).
*/
public static class ToTypeLiteralStringLossless extends ToTypeString {

public static final ToTypeLiteralStringLossless INSTANCE = new ToTypeLiteralStringLossless();

private ToTypeLiteralStringLossless() {}

@Override
public String visit(ParameterizedType.StringLiteral expr) throws RuntimeException {
return expr.value().toLowerCase();
}
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.substrait.isthmus.expression;

import com.github.bsideup.jabel.Desugar;
import com.google.common.collect.*;
import io.substrait.expression.Expression;
import io.substrait.expression.ExpressionCreator;
Expand All @@ -13,11 +14,14 @@
import io.substrait.util.Util;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
Expand Down Expand Up @@ -170,8 +174,52 @@ public boolean allowedArgCount(int count) {

private static <F extends SimpleExtension.Function> SignatureMatcher<F> getSignatureMatcher(
SqlOperator operator, List<F> functions) {
// TODO: define up-converting matchers.
return (a, b) -> Optional.empty();
return (inputTypes, outputType) -> {
for (F function : functions) {
List<SimpleExtension.Argument> args = function.requiredArguments();

var bounds = ArgumentBounds.parse(function);

// Make sure that arguments & return are within bounds and match the types
if (function.returnType() instanceof ParameterizedType
&& isMatch(outputType, (ParameterizedType) function.returnType())
&& bounds.within(inputTypes.size())
&& argumentsMatchType(inputTypes, args)) {
return Optional.of(function);
}
}

return Optional.empty();
};
}

private static boolean argumentsMatchType(
List<Type> inputTypes, List<SimpleExtension.Argument> args) {

Map<String, Set<Type>> wildcardToType = new HashMap<>();
for (int i = 0; i < inputTypes.size(); i++) {
Type givenType = inputTypes.get(i);
// Variadic arguments should match the last argument's type
SimpleExtension.ValueArgument wantType =
(SimpleExtension.ValueArgument) args.get(Integer.min(i, args.size() - 1));

if (!isMatch(givenType, wantType.value())) {
return false;
}

// Register the wildcard to type
if (wantType.value().isWildcard()) {
wildcardToType
.computeIfAbsent(
wantType.value().accept(ToTypeString.ToTypeLiteralStringLossless.INSTANCE),
k -> new HashSet<>())
.add(givenType);
}
}

// If all the types match, check if the wildcard types are compatible.
// Note: We could ignore the "any" key here if we decide to not match non-enumerated types.
return wildcardToType.values().stream().allMatch(s -> s.size() == 1);
}

/**
Expand Down Expand Up @@ -289,12 +337,10 @@ public Optional<T> attemptMatch(C call, Function<RexNode, Expression> topLevelCo
var outputType = typeConverter.toSubstrait(call.getType());

// try to do a direct match
var typeStrings =
opTypes.stream().map(t -> t.accept(ToTypeString.INSTANCE)).collect(Collectors.toList());
var possibleKeys =
matchKeys(
call.getOperands().collect(java.util.stream.Collectors.toList()),
opTypes.stream()
.map(t -> t.accept(ToTypeString.INSTANCE))
.collect(java.util.stream.Collectors.toList()));
matchKeys(call.getOperands().collect(java.util.stream.Collectors.toList()), typeStrings);

var directMatchKey =
possibleKeys
Expand Down Expand Up @@ -327,34 +373,77 @@ public Optional<T> attemptMatch(C call, Function<RexNode, Expression> topLevelCo
}

if (singularInputType.isPresent()) {
RelDataType leastRestrictive =
typeFactory.leastRestrictive(
call.getOperands()
.map(RexNode::getType)
.collect(java.util.stream.Collectors.toList()));
if (leastRestrictive == null) {
return Optional.empty();
Optional<T> leastRestrictive = matchByLeastRestrictive(call, outputType, operands);
if (leastRestrictive.isPresent()) {
return leastRestrictive;
}
Type type = typeConverter.toSubstrait(leastRestrictive);
var out = singularInputType.get().tryMatch(type, outputType);

if (out.isPresent()) {
var declaration = out.get();
var coercedArgs = coerceArguments(operands, type);
declaration.validateOutputType(coercedArgs, outputType);
return Optional.of(
generateBinding(
call,
out.get(),
coercedArgs.stream()
.map(FunctionArg.class::cast)
.collect(java.util.stream.Collectors.toList()),
outputType));

Optional<T> coerced = matchCoerced(call, outputType, operands);
if (coerced.isPresent()) {
return coerced;
}
}
return Optional.empty();
}

private Optional<T> matchByLeastRestrictive(
C call, Type outputType, List<Expression> operands) {
RelDataType leastRestrictive =
typeFactory.leastRestrictive(
call.getOperands().map(RexNode::getType).collect(Collectors.toList()));
if (leastRestrictive == null) {
return Optional.empty();
}
Type type = typeConverter.toSubstrait(leastRestrictive);
var out = singularInputType.get().tryMatch(type, outputType);

if (out.isPresent()) {
var declaration = out.get();
var coercedArgs = coerceArguments(operands, type);
declaration.validateOutputType(coercedArgs, outputType);
return Optional.of(
generateBinding(
call,
out.get(),
coercedArgs.stream().map(FunctionArg.class::cast).collect(Collectors.toList()),
outputType));
}
return Optional.empty();
}

private Optional<T> matchCoerced(C call, Type outputType, List<Expression> operands) {

// Convert the operands to the proper Substrait type
List<Type> allTypes =
call.getOperands()
.map(RexNode::getType)
.map(typeConverter::toSubstrait)
.collect(Collectors.toList());

// See if all the input types match the function
Optional<F> matchFunction = this.matcher.tryMatch(allTypes, outputType);
if (matchFunction.isPresent()) {
List<Expression> coerced =
Streams.zip(
operands.stream(),
call.getOperands(),
(a, b) -> {
Type type = typeConverter.toSubstrait(b.getType());
return coerceArgument(a, type);
})
.collect(Collectors.toList());

return Optional.of(
generateBinding(
call,
matchFunction.get(),
coerced.stream().map(FunctionArg.class::cast).collect(Collectors.toList()),
outputType));
}

return Optional.empty();
}

protected String getName() {
return name;
}
Expand All @@ -374,18 +463,16 @@ public interface GenericCall {
* Coerced types according to an expected output type. Coercion is only done for type mismatches,
* not for nullability or parameter mismatches.
*/
private List<Expression> coerceArguments(List<Expression> arguments, Type type) {

return arguments.stream()
.map(
a -> {
var typeMatches = isMatch(type, a.getType());
if (!typeMatches) {
return ExpressionCreator.cast(type, a);
}
return a;
})
.collect(java.util.stream.Collectors.toList());
private static List<Expression> coerceArguments(List<Expression> arguments, Type type) {
return arguments.stream().map(a -> coerceArgument(a, type)).collect(Collectors.toList());
}

private static Expression coerceArgument(Expression argument, Type type) {
var typeMatches = isMatch(type, argument.getType());
if (!typeMatches) {
return ExpressionCreator.cast(type, argument);
}
return argument;
}

protected abstract T generateBinding(
Expand Down Expand Up @@ -428,4 +515,33 @@ private static boolean isMatch(ParameterizedType inputType, ParameterizedType ty
}
return inputType.accept(new IgnoreNullableAndParameters(type));
}

@Desugar
record ArgumentBounds(int lower, int upper) {

static ArgumentBounds parse(SimpleExtension.Function function) {
List<SimpleExtension.Argument> args = function.requiredArguments();

int lowerBoundRequiredArgs = args.size();
int upperBoundRequiredArgs = args.size();

if (function.variadic().isPresent()) {
SimpleExtension.VariadicBehavior variadicBehavior = function.variadic().get();
// Do not count variadic as a required argument, use the behavior.
lowerBoundRequiredArgs += 1 - variadicBehavior.getMin();

if (variadicBehavior.getMax().isEmpty()) {
upperBoundRequiredArgs = Integer.MAX_VALUE;
} else {
upperBoundRequiredArgs += variadicBehavior.getMax().getAsInt();
}
}

return new ArgumentBounds(lowerBoundRequiredArgs, upperBoundRequiredArgs);
}

boolean within(int count) {
return count >= lower && count <= upper;
}
}
}
Loading

0 comments on commit b9574cb

Please sign in to comment.