diff --git a/core/src/main/java/io/substrait/function/ToTypeString.java b/core/src/main/java/io/substrait/function/ToTypeString.java index 1b7138317..c8d693855 100644 --- a/core/src/main/java/io/substrait/function/ToTypeString.java +++ b/core/src/main/java/io/substrait/function/ToTypeString.java @@ -178,4 +178,20 @@ public String visit(ParameterizedType.StringLiteral expr) throws RuntimeExceptio return super.visit(expr); } } + + /** + * Subclass of ToTypeString that doesn't lose the context on the wildcard being used (for example, + * that can return any1, any2, etc, instead of only any, any). + */ + public static class ToTypeLiteralStringLossless extends ToTypeString { + + public static final ToTypeLiteralStringLossless INSTANCE = new ToTypeLiteralStringLossless(); + + private ToTypeLiteralStringLossless() {} + + @Override + public String visit(ParameterizedType.StringLiteral expr) throws RuntimeException { + return expr.value().toLowerCase(); + } + } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java index 10b3dd1df..03e4d0190 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java @@ -1,5 +1,6 @@ package io.substrait.isthmus.expression; +import com.github.bsideup.jabel.Desugar; import com.google.common.collect.*; import io.substrait.expression.Expression; import io.substrait.expression.ExpressionCreator; @@ -13,11 +14,14 @@ import io.substrait.util.Util; import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; import java.util.IdentityHashMap; import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -170,8 
+174,52 @@ public boolean allowedArgCount(int count) { private static SignatureMatcher getSignatureMatcher( SqlOperator operator, List functions) { - // TODO: define up-converting matchers. - return (a, b) -> Optional.empty(); + return (inputTypes, outputType) -> { + for (F function : functions) { + List args = function.requiredArguments(); + + var bounds = ArgumentBounds.parse(function); + + // Make sure that arguments & return are within bounds and match the types + if (function.returnType() instanceof ParameterizedType + && isMatch(outputType, (ParameterizedType) function.returnType()) + && bounds.within(inputTypes.size()) + && argumentsMatchType(inputTypes, args)) { + return Optional.of(function); + } + } + + return Optional.empty(); + }; + } + + private static boolean argumentsMatchType( + List inputTypes, List args) { + + Map> wildcardToType = new HashMap<>(); + for (int i = 0; i < inputTypes.size(); i++) { + Type givenType = inputTypes.get(i); + // Variadic arguments should match the last argument's type + SimpleExtension.ValueArgument wantType = + (SimpleExtension.ValueArgument) args.get(Integer.min(i, args.size() - 1)); + + if (!isMatch(givenType, wantType.value())) { + return false; + } + + // Register the wildcard to type + if (wantType.value().isWildcard()) { + wildcardToType + .computeIfAbsent( + wantType.value().accept(ToTypeString.ToTypeLiteralStringLossless.INSTANCE), + k -> new HashSet<>()) + .add(givenType); + } + } + + // If all the types match, check if the wildcard types are compatible. + // Note: We could ignore the "any" key here if we decide to not match non-enumerated types. 
+ return wildcardToType.values().stream().allMatch(s -> s.size() == 1); } /** @@ -289,12 +337,10 @@ public Optional attemptMatch(C call, Function topLevelCo var outputType = typeConverter.toSubstrait(call.getType()); // try to do a direct match + var typeStrings = + opTypes.stream().map(t -> t.accept(ToTypeString.INSTANCE)).collect(Collectors.toList()); var possibleKeys = - matchKeys( - call.getOperands().collect(java.util.stream.Collectors.toList()), - opTypes.stream() - .map(t -> t.accept(ToTypeString.INSTANCE)) - .collect(java.util.stream.Collectors.toList())); + matchKeys(call.getOperands().collect(java.util.stream.Collectors.toList()), typeStrings); var directMatchKey = possibleKeys @@ -327,34 +373,77 @@ public Optional attemptMatch(C call, Function topLevelCo } if (singularInputType.isPresent()) { - RelDataType leastRestrictive = - typeFactory.leastRestrictive( - call.getOperands() - .map(RexNode::getType) - .collect(java.util.stream.Collectors.toList())); - if (leastRestrictive == null) { - return Optional.empty(); + Optional leastRestrictive = matchByLeastRestrictive(call, outputType, operands); + if (leastRestrictive.isPresent()) { + return leastRestrictive; } - Type type = typeConverter.toSubstrait(leastRestrictive); - var out = singularInputType.get().tryMatch(type, outputType); - - if (out.isPresent()) { - var declaration = out.get(); - var coercedArgs = coerceArguments(operands, type); - declaration.validateOutputType(coercedArgs, outputType); - return Optional.of( - generateBinding( - call, - out.get(), - coercedArgs.stream() - .map(FunctionArg.class::cast) - .collect(java.util.stream.Collectors.toList()), - outputType)); + + Optional coerced = matchCoerced(call, outputType, operands); + if (coerced.isPresent()) { + return coerced; } } return Optional.empty(); } + private Optional matchByLeastRestrictive( + C call, Type outputType, List operands) { + RelDataType leastRestrictive = + typeFactory.leastRestrictive( + 
call.getOperands().map(RexNode::getType).collect(Collectors.toList())); + if (leastRestrictive == null) { + return Optional.empty(); + } + Type type = typeConverter.toSubstrait(leastRestrictive); + var out = singularInputType.get().tryMatch(type, outputType); + + if (out.isPresent()) { + var declaration = out.get(); + var coercedArgs = coerceArguments(operands, type); + declaration.validateOutputType(coercedArgs, outputType); + return Optional.of( + generateBinding( + call, + out.get(), + coercedArgs.stream().map(FunctionArg.class::cast).collect(Collectors.toList()), + outputType)); + } + return Optional.empty(); + } + + private Optional matchCoerced(C call, Type outputType, List operands) { + + // Convert the operands to the proper Substrait type + List allTypes = + call.getOperands() + .map(RexNode::getType) + .map(typeConverter::toSubstrait) + .collect(Collectors.toList()); + + // See if all the input types match the function + Optional matchFunction = this.matcher.tryMatch(allTypes, outputType); + if (matchFunction.isPresent()) { + List coerced = + Streams.zip( + operands.stream(), + call.getOperands(), + (a, b) -> { + Type type = typeConverter.toSubstrait(b.getType()); + return coerceArgument(a, type); + }) + .collect(Collectors.toList()); + + return Optional.of( + generateBinding( + call, + matchFunction.get(), + coerced.stream().map(FunctionArg.class::cast).collect(Collectors.toList()), + outputType)); + } + + return Optional.empty(); + } + protected String getName() { return name; } @@ -374,18 +463,16 @@ public interface GenericCall { * Coerced types according to an expected output type. Coercion is only done for type mismatches, * not for nullability or parameter mismatches. 
*/ - private List coerceArguments(List arguments, Type type) { - - return arguments.stream() - .map( - a -> { - var typeMatches = isMatch(type, a.getType()); - if (!typeMatches) { - return ExpressionCreator.cast(type, a); - } - return a; - }) - .collect(java.util.stream.Collectors.toList()); + private static List coerceArguments(List arguments, Type type) { + return arguments.stream().map(a -> coerceArgument(a, type)).collect(Collectors.toList()); + } + + private static Expression coerceArgument(Expression argument, Type type) { + var typeMatches = isMatch(type, argument.getType()); + if (!typeMatches) { + return ExpressionCreator.cast(type, argument); + } + return argument; } protected abstract T generateBinding( @@ -428,4 +515,33 @@ private static boolean isMatch(ParameterizedType inputType, ParameterizedType ty } return inputType.accept(new IgnoreNullableAndParameters(type)); } + + @Desugar + record ArgumentBounds(int lower, int upper) { + + static ArgumentBounds parse(SimpleExtension.Function function) { + List args = function.requiredArguments(); + + int lowerBoundRequiredArgs = args.size(); + int upperBoundRequiredArgs = args.size(); + + if (function.variadic().isPresent()) { + SimpleExtension.VariadicBehavior variadicBehavior = function.variadic().get(); + // Do not count variadic as a required argument, use the behavior. 
+ lowerBoundRequiredArgs += 1 - variadicBehavior.getMin(); + + if (variadicBehavior.getMax().isEmpty()) { + upperBoundRequiredArgs = Integer.MAX_VALUE; + } else { + upperBoundRequiredArgs += variadicBehavior.getMax().getAsInt(); + } + } + + return new ArgumentBounds(lowerBoundRequiredArgs, upperBoundRequiredArgs); + } + + boolean within(int count) { + return count >= lower && count <= upper; + } + } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java b/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java index 43679fa75..b613b50ff 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java @@ -1,6 +1,7 @@ package io.substrait.isthmus; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; import com.google.protobuf.Any; import io.substrait.dsl.SubstraitBuilder; @@ -24,11 +25,13 @@ import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rel.type.RelDataTypeSystem; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlFunction; import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.type.ReturnTypes; +import org.apache.calcite.sql.type.SqlTypeFactoryImpl; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.tools.RelBuilder; import org.junit.jupiter.api.Test; @@ -92,9 +95,25 @@ public RelDataType toCalcite(Type.UserDefined type) { } }; + static final RelDataType varcharType = + new SqlTypeFactoryImpl(RelDataTypeSystem.DEFAULT).createSqlType(SqlTypeName.VARCHAR); + static final RelDataType varcharArrayType = + new SqlTypeFactoryImpl(RelDataTypeSystem.DEFAULT).createArrayType(varcharType, -1); + // Define additional mapping signatures for the custom 
scalar functions final List additionalScalarSignatures = - List.of(FunctionMappings.s(customScalarFn), FunctionMappings.s(toBType)); + List.of( + FunctionMappings.s(customScalarFn), + FunctionMappings.s(customScalarAnyFn), + FunctionMappings.s(customScalarAnyToAnyFn), + FunctionMappings.s(customScalarAny1Any1ToAny1Fn), + FunctionMappings.s(customScalarAny1Any2ToAny2Fn), + FunctionMappings.s(customScalarListAnyFn), + FunctionMappings.s(customScalarListAnyAndAnyFn), + FunctionMappings.s(customScalarListStringFn), + FunctionMappings.s(customScalarListStringAndAnyFn), + FunctionMappings.s(customScalarListStringAndAnyVariadicFn), + FunctionMappings.s(toBType)); static final SqlFunction customScalarFn = new SqlFunction( @@ -105,6 +124,84 @@ public RelDataType toCalcite(Type.UserDefined type) { null, SqlFunctionCategory.USER_DEFINED_FUNCTION); + static final SqlFunction customScalarAnyFn = + new SqlFunction( + "custom_scalar_any", + SqlKind.OTHER_FUNCTION, + ReturnTypes.explicit(SqlTypeName.VARCHAR), + null, + null, + SqlFunctionCategory.USER_DEFINED_FUNCTION); + + static final SqlFunction customScalarAnyToAnyFn = + new SqlFunction( + "custom_scalar_any_to_any", + SqlKind.OTHER_FUNCTION, + ReturnTypes.ARG0_NULLABLE, + null, + null, + SqlFunctionCategory.USER_DEFINED_FUNCTION); + static final SqlFunction customScalarAny1Any1ToAny1Fn = + new SqlFunction( + "custom_scalar_any1any1_to_any1", + SqlKind.OTHER_FUNCTION, + ReturnTypes.ARG0_NULLABLE, + null, + null, + SqlFunctionCategory.USER_DEFINED_FUNCTION); + static final SqlFunction customScalarAny1Any2ToAny2Fn = + new SqlFunction( + "custom_scalar_any1any2_to_any2", + SqlKind.OTHER_FUNCTION, + ReturnTypes.ARG1_NULLABLE, + null, + null, + SqlFunctionCategory.USER_DEFINED_FUNCTION); + + static final SqlFunction customScalarListAnyFn = + new SqlFunction( + "custom_scalar_listany_to_listany", + SqlKind.OTHER_FUNCTION, + ReturnTypes.ARG0_NULLABLE, + null, + null, + SqlFunctionCategory.USER_DEFINED_FUNCTION); + + static final 
SqlFunction customScalarListAnyAndAnyFn = + new SqlFunction( + "custom_scalar_listany_any_to_listany", + SqlKind.OTHER_FUNCTION, + ReturnTypes.ARG0_NULLABLE, + null, + null, + SqlFunctionCategory.USER_DEFINED_FUNCTION); + + static final SqlFunction customScalarListStringFn = + new SqlFunction( + "custom_scalar_liststring_to_liststring", + SqlKind.OTHER_FUNCTION, + ReturnTypes.explicit(varcharArrayType), + null, + null, + SqlFunctionCategory.USER_DEFINED_FUNCTION); + + static final SqlFunction customScalarListStringAndAnyFn = + new SqlFunction( + "custom_scalar_liststring_any_to_liststring", + SqlKind.OTHER_FUNCTION, + ReturnTypes.explicit(varcharArrayType), + null, + null, + SqlFunctionCategory.USER_DEFINED_FUNCTION); + static final SqlFunction customScalarListStringAndAnyVariadicFn = + new SqlFunction( + "custom_scalar_liststring_anyvariadic_to_liststring", + SqlKind.OTHER_FUNCTION, + ReturnTypes.explicit(varcharArrayType), + null, + null, + SqlFunctionCategory.USER_DEFINED_FUNCTION); + static final SqlFunction toBType = new SqlFunction( "to_b_type", @@ -198,6 +295,210 @@ void customScalarFunctionRoundtrip() { assertEquals(rel, relReturned); } + @Test + void customScalarAnyFunctionRoundtrip() { + Rel rel = + b.project( + input -> + List.of( + b.scalarFn( + NAMESPACE, "custom_scalar_any:any", R.STRING, b.fieldReference(input, 0))), + b.remap(1), + b.namedScan(List.of("example"), List.of("a"), List.of(R.I64))); + + RelNode calciteRel = substraitToCalcite.convert(rel); + var relReturned = calciteToSubstrait.apply(calciteRel); + assertEquals(rel, relReturned); + } + + @Test + void customScalarAnyToAnyFunctionRoundtrip() { + Rel rel = + b.project( + input -> + List.of( + b.scalarFn( + NAMESPACE, + "custom_scalar_any_to_any:any", + R.FP64, + b.fieldReference(input, 0))), + b.remap(1), + b.namedScan(List.of("example"), List.of("a"), List.of(R.FP64))); + + RelNode calciteRel = substraitToCalcite.convert(rel); + var relReturned = calciteToSubstrait.apply(calciteRel); + 
assertEquals(rel, relReturned); + } + + @Test + void customScalarAny1Any1ToAny1FunctionRoundtrip() { + Rel rel = + b.project( + input -> + List.of( + b.scalarFn( + NAMESPACE, + "custom_scalar_any1any1_to_any1:any_any", + R.FP64, + b.fieldReference(input, 0), + b.fieldReference(input, 1))), + b.remap(2), + b.namedScan(List.of("example"), List.of("a", "b"), List.of(R.FP64, R.FP64))); + + RelNode calciteRel = substraitToCalcite.convert(rel); + var relReturned = calciteToSubstrait.apply(calciteRel); + assertEquals(rel, relReturned); + } + + @Test + void customScalarAny1Any1ToAny1FunctionMismatch() { + Rel rel = + b.project( + input -> + List.of( + b.scalarFn( + NAMESPACE, + "custom_scalar_any1any1_to_any1:any_any", + R.FP64, + b.fieldReference(input, 0), + b.fieldReference(input, 1))), + b.remap(2), + b.namedScan(List.of("example"), List.of("a", "b"), List.of(R.FP64, R.STRING))); + + assertThrows( + IllegalArgumentException.class, + () -> { + RelNode calciteRel = substraitToCalcite.convert(rel); + calciteToSubstrait.apply(calciteRel); + }, + "Unable to convert call custom_scalar_any1any1_to_any1(fp64, string)"); + } + + @Test + void customScalarAny1Any2ToAny2FunctionRoundtrip() { + Rel rel = + b.project( + input -> + List.of( + b.scalarFn( + NAMESPACE, + "custom_scalar_any1any2_to_any2:any_any", + R.STRING, + b.fieldReference(input, 0), + b.fieldReference(input, 1))), + b.remap(2), + b.namedScan(List.of("example"), List.of("a", "b"), List.of(R.FP64, R.STRING))); + + RelNode calciteRel = substraitToCalcite.convert(rel); + var relReturned = calciteToSubstrait.apply(calciteRel); + assertEquals(rel, relReturned); + } + + @Test + void customScalarListAnyRoundtrip() { + Rel rel = + b.project( + input -> + List.of( + b.scalarFn( + NAMESPACE, + "custom_scalar_listany_to_listany:list", + R.list(R.I64), + b.fieldReference(input, 0))), + b.remap(1), + b.namedScan(List.of("example"), List.of("a"), List.of(R.list(R.I64)))); + + RelNode calciteRel = substraitToCalcite.convert(rel); 
+ var relReturned = calciteToSubstrait.apply(calciteRel); + assertEquals(rel, relReturned); + } + + @Test + void customScalarListAnyAndAnyRoundtrip() { + Rel rel = + b.project( + input -> + List.of( + b.scalarFn( + NAMESPACE, + "custom_scalar_listany_any_to_listany:list_any", + R.list(R.STRING), + b.fieldReference(input, 0), + b.fieldReference(input, 1))), + b.remap(2), + b.namedScan( + List.of("example"), List.of("a", "b"), List.of(R.list(R.STRING), R.STRING))); + + RelNode calciteRel = substraitToCalcite.convert(rel); + var relReturned = calciteToSubstrait.apply(calciteRel); + assertEquals(rel, relReturned); + } + + @Test + void customScalarListStringRoundtrip() { + Rel rel = + b.project( + input -> + List.of( + b.scalarFn( + NAMESPACE, + "custom_scalar_liststring_to_liststring:list", + R.list(R.STRING), + b.fieldReference(input, 0))), + b.remap(1), + b.namedScan(List.of("example"), List.of("a"), List.of(R.list(R.STRING)))); + + RelNode calciteRel = substraitToCalcite.convert(rel); + var relReturned = calciteToSubstrait.apply(calciteRel); + assertEquals(rel, relReturned); + } + + @Test + void customScalarListStringAndAnyRoundtrip() { + Rel rel = + b.project( + input -> + List.of( + b.scalarFn( + NAMESPACE, + "custom_scalar_liststring_any_to_liststring:list_any", + R.list(R.STRING), + b.fieldReference(input, 0), + b.fieldReference(input, 1))), + b.remap(2), + b.namedScan( + List.of("example"), List.of("a", "b"), List.of(R.list(R.STRING), R.STRING))); + + RelNode calciteRel = substraitToCalcite.convert(rel); + var relReturned = calciteToSubstrait.apply(calciteRel); + assertEquals(rel, relReturned); + } + + @Test + void customScalarListStringAndAnyVariadicRoundtrip() { + Rel rel = + b.project( + input -> + List.of( + b.scalarFn( + NAMESPACE, + "custom_scalar_liststring_anyvariadic_to_liststring:list_any", + R.list(R.STRING), + b.fieldReference(input, 0), + b.fieldReference(input, 1), + b.fieldReference(input, 2), + b.fieldReference(input, 3))), + b.remap(4), + 
b.namedScan( + List.of("example"), + List.of("a", "b", "c", "d"), + List.of(R.list(R.STRING), R.STRING, R.STRING, R.STRING))); + + RelNode calciteRel = substraitToCalcite.convert(rel); + var relReturned = calciteToSubstrait.apply(calciteRel); + assertEquals(rel, relReturned); + } + @Test void customAggregateFunctionRoundtrip() { // CREATE TABLE example (a BIGINT) diff --git a/isthmus/src/test/resources/extensions/functions_custom.yaml b/isthmus/src/test/resources/extensions/functions_custom.yaml index 067102949..50da6ada8 100644 --- a/isthmus/src/test/resources/extensions/functions_custom.yaml +++ b/isthmus/src/test/resources/extensions/functions_custom.yaml @@ -6,12 +6,44 @@ types: scalar_functions: - name: "custom_scalar" - description: "a custom scalar functions" + description: "a custom scalar function" impls: - args: - name: some_arg value: string return: string + - name: "custom_scalar_any" + description: "a custom scalar function that takes any argument input" + impls: + - args: + - name: some_arg + value: any1 + return: string + - name: "custom_scalar_any_to_any" + description: "a custom scalar function that takes any argument input and returns the same type" + impls: + - args: + - name: some_arg + value: any1 + return: any1 + - name: "custom_scalar_any1any1_to_any1" + description: "a custom scalar function that takes two any1 inputs and returns the same type" + impls: + - args: + - name: some_arg + value: any1 + - name: another_arg + value: any1 + return: any1 + - name: "custom_scalar_any1any2_to_any2" + description: "a custom scalar function that takes any1 and any2 inputs and returns any2" + impls: + - args: + - name: some_arg + value: any1 + - name: another_arg + value: any2 + return: any2 - name: "to_b_type" description: "converts a nullable a_type to a b_type" impls: @@ -19,6 +51,49 @@ scalar_functions: - name: arg1 value: u!a_type? 
 return: u!b_type + - name: "custom_scalar_listany_to_listany" + description: "custom function that takes list of any" + impls: + - args: + - name: list + value: LIST + return: LIST + - name: "custom_scalar_listany_any_to_listany" + description: "custom function that takes list of any and an any scalar" + impls: + - args: + - name: list + value: LIST + - name: val + value: any1 + return: LIST + - name: "custom_scalar_liststring_to_liststring" + description: "custom function that takes list of string" + impls: + - args: + - name: list + value: LIST + return: LIST + - name: "custom_scalar_liststring_any_to_liststring" + description: "custom function that takes list of string and an any scalar" + impls: + - args: + - name: list + value: LIST + - name: val + value: any1 + return: LIST + - name: "custom_scalar_liststring_anyvariadic_to_liststring" + description: "custom function that takes list of string and one or more any scalars (variadic)" + impls: + - args: + - name: list + value: LIST + - name: val + value: any1 + variadic: + min: 1 + return: LIST aggregate_functions: - name: "custom_aggregate"