Skip to content

Commit

Permalink
feat: add support for empty list literals (#227)
Browse files Browse the repository at this point in the history
BREAKING CHANGE: ExpressionVisitor now has a `visit(Expression.EmptyListLiteral)` method
BREAKING CHANGE: LiteralConstructorConverter constructor now requires a TypeConverter
  • Loading branch information
patientstreetlight authored Feb 16, 2024
1 parent f148bbb commit 2a98e3c
Show file tree
Hide file tree
Showing 13 changed files with 156 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,11 @@ public OUTPUT visit(Expression.ListLiteral expr) throws EXCEPTION {
return visitFallback(expr);
}

@Override
public OUTPUT visit(Expression.EmptyListLiteral expr) throws EXCEPTION {
return visitFallback(expr);
}

@Override
public OUTPUT visit(Expression.StructLiteral expr) throws EXCEPTION {
return visitFallback(expr);
Expand Down
19 changes: 19 additions & 0 deletions core/src/main/java/io/substrait/expression/Expression.java
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,25 @@ public <R, E extends Throwable> R accept(ExpressionVisitor<R, E> visitor) throws
}
}

@Value.Immutable
abstract class EmptyListLiteral implements Literal {
public abstract Type elementType();

@Override
public Type.ListType getType() {
return Type.withNullability(nullable()).list(elementType());
}

public static ImmutableExpression.EmptyListLiteral.Builder builder() {
return ImmutableExpression.EmptyListLiteral.builder();
}

@Override
public <R, E extends Throwable> R accept(ExpressionVisitor<R, E> visitor) throws E {
return visitor.visit(this);
}
}

@Value.Immutable
abstract static class StructLiteral implements Literal {
public abstract List<Literal> fields();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,13 @@ public static Expression.ListLiteral list(
return Expression.ListLiteral.builder().nullable(nullable).addAllValues(values).build();
}

public static Expression.EmptyListLiteral emptyList(boolean listNullable, Type elementType) {
return Expression.EmptyListLiteral.builder()
.elementType(elementType)
.nullable(listNullable)
.build();
}

public static Expression.StructLiteral struct(boolean nullable, Expression.Literal... values) {
return Expression.StructLiteral.builder().nullable(nullable).addFields(values).build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ public interface ExpressionVisitor<R, E extends Throwable> {

R visit(Expression.ListLiteral expr) throws E;

R visit(Expression.EmptyListLiteral expr) throws E;

R visit(Expression.StructLiteral expr) throws E;

R visit(Expression.Switch expr) throws E;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,23 @@ public Expression visit(io.substrait.expression.Expression.ListLiteral expr) {
});
}

@Override
public Expression visit(io.substrait.expression.Expression.EmptyListLiteral expr)
throws RuntimeException {
return lit(
builder -> {
var protoListType = expr.getType().accept(typeProtoConverter);
builder
.setEmptyList(protoListType.getList())
// For empty lists, the Literal message's own nullable field should be ignored
// in favor of the nullability of the Type.List in the literal's
// empty_list field. But for safety we set the literal's nullable field
// to match in case any readers either look in the wrong location
// or want to verify that they are consistent.
.setNullable(expr.nullable());
});
}

@Override
public Expression visit(io.substrait.expression.Expression.StructLiteral expr) {
return lit(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,12 @@ public Expression.Literal from(io.substrait.proto.Expression.Literal literal) {
literal.getList().getValuesList().stream()
.map(this::from)
.collect(java.util.stream.Collectors.toList()));
case EMPTY_LIST -> {
// literal.getNullable() is intentionally ignored in favor of the nullability
// specified in the literal.getEmptyList() type.
var listType = protoTypeConverter.fromList(literal.getEmptyList());
yield ExpressionCreator.emptyList(listType.nullable(), listType.elementType());
}
default -> throw new IllegalStateException(
"Unexpected value: " + literal.getLiteralTypeCase());
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,11 @@ public Optional<Expression> visit(Expression.ListLiteral expr) throws EXCEPTION
return visitLiteral(expr);
}

@Override
public Optional<Expression> visit(Expression.EmptyListLiteral expr) throws EXCEPTION {
return visitLiteral(expr);
}

@Override
public Optional<Expression> visit(Expression.StructLiteral expr) throws EXCEPTION {
return visitLiteral(expr);
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/java/io/substrait/type/TypeCreator.java
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ public Type.Struct struct(Stream<? extends Type> types) {
.build();
}

public Type list(Type type) {
public Type.ListType list(Type type) {
return Type.ListType.builder().nullable(nullable).elementType(type).build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public Type from(io.substrait.proto.Type type) {
type.getStruct().getTypesList().stream()
.map(this::from)
.collect(java.util.stream.Collectors.toList()));
case LIST -> n(type.getList().getNullability()).list(from(type.getList().getType()));
case LIST -> fromList(type.getList());
case MAP -> n(type.getMap().getNullability())
.map(from(type.getMap().getKey()), from(type.getMap().getValue()));
case USER_DEFINED -> {
Expand All @@ -61,6 +61,10 @@ public Type from(io.substrait.proto.Type type) {
};
}

public Type.ListType fromList(io.substrait.proto.Type.List list) {
return n(list.getNullability()).list(from(list.getType()));
}

public static boolean isNullable(io.substrait.proto.Type.Nullability nullability) {
return io.substrait.proto.Type.Nullability.NULLABILITY_NULLABLE == nullability;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ public static List<CallConverter> defaults(TypeConverter typeConverter) {
new FieldSelectionConverter(typeConverter),
CallConverters.CASE,
CallConverters.CAST.apply(typeConverter),
new LiteralConstructorConverter());
new LiteralConstructorConverter(typeConverter));
}

public interface SimpleCallConverter extends CallConverter {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import io.substrait.type.Type;
import io.substrait.util.DecimalUtil;
import java.math.BigDecimal;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -248,6 +249,13 @@ public RexNode visit(Expression.ListLiteral expr) throws RuntimeException {
return rexBuilder.makeCall(SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR, args);
}

@Override
public RexNode visit(Expression.EmptyListLiteral expr) throws RuntimeException {
var calciteType = typeConverter.toCalcite(typeFactory, expr.getType());
return rexBuilder.makeCall(
calciteType, SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR, Collections.emptyList());
}

@Override
public RexNode visit(Expression.MapLiteral expr) throws RuntimeException {
var args =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import io.substrait.expression.Expression;
import io.substrait.expression.ExpressionCreator;
import io.substrait.isthmus.CallConverter;
import io.substrait.isthmus.TypeConverter;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -18,29 +19,54 @@ public class LiteralConstructorConverter implements CallConverter {
static final org.slf4j.Logger logger =
org.slf4j.LoggerFactory.getLogger(LiteralConstructorConverter.class);

private final TypeConverter typeConverter;

public LiteralConstructorConverter(TypeConverter typeConverter) {
this.typeConverter = typeConverter;
}

@Override
public Optional<Expression> convert(
RexCall call, Function<RexNode, Expression> topLevelConverter) {
SqlOperator operator = call.getOperator();
if (operator instanceof SqlArrayValueConstructor) {
return Optional.of(
ExpressionCreator.list(
false,
call.operands.stream()
.map(t -> ((Expression.Literal) topLevelConverter.apply(t)))
.collect(java.util.stream.Collectors.toList())));
return call.getOperands().isEmpty()
? toEmptyListLiteral(call)
: toNonEmptyListLiteral(call, topLevelConverter);
} else if (operator instanceof SqlMapValueConstructor) {
List<Expression.Literal> literals =
call.operands.stream()
.map(t -> ((Expression.Literal) topLevelConverter.apply(t)))
.collect(java.util.stream.Collectors.toList());
Map<Expression.Literal, Expression.Literal> items = new HashMap<>();
assert literals.size() % 2 == 0;
for (int i = 0; i < literals.size(); i += 2) {
items.put(literals.get(i), literals.get(i + 1));
}
return Optional.of(ExpressionCreator.map(false, items));
return toMapLiteral(call, topLevelConverter);
}
return Optional.empty();
}

private Optional<Expression> toMapLiteral(
RexCall call, Function<RexNode, Expression> topLevelConverter) {
List<Expression.Literal> literals =
call.operands.stream()
.map(t -> ((Expression.Literal) topLevelConverter.apply(t)))
.collect(java.util.stream.Collectors.toList());
Map<Expression.Literal, Expression.Literal> items = new HashMap<>();
assert literals.size() % 2 == 0;
for (int i = 0; i < literals.size(); i += 2) {
items.put(literals.get(i), literals.get(i + 1));
}
return Optional.of(ExpressionCreator.map(false, items));
}

private Optional<Expression> toNonEmptyListLiteral(
RexCall call, Function<RexNode, Expression> topLevelConverter) {
return Optional.of(
ExpressionCreator.list(
call.getType().isNullable(),
call.operands.stream()
.map(t -> ((Expression.Literal) topLevelConverter.apply(t)))
.collect(java.util.stream.Collectors.toList())));
}

private Optional<Expression> toEmptyListLiteral(RexCall call) {
var calciteElementType = call.getType().getComponentType();
var substraitElementType = typeConverter.toSubstrait(calciteElementType);
return Optional.of(
ExpressionCreator.emptyList(call.getType().isNullable(), substraitElementType));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package io.substrait.isthmus;

import io.substrait.dsl.SubstraitBuilder;
import io.substrait.expression.ExpressionCreator;
import io.substrait.relation.Rel;
import io.substrait.type.TypeCreator;
import java.util.List;
import org.junit.jupiter.api.Test;

public class EmptyArrayLiteralTest extends PlanTestBase {
private static final TypeCreator N = TypeCreator.of(true);

private final SubstraitBuilder b = new SubstraitBuilder(extensions);

@Test
void emptyArrayLiteral() {
var colType = N.I8;
var emptyListLiteral = ExpressionCreator.emptyList(false, N.I8);
var rel =
b.project(
input -> List.of(emptyListLiteral),
Rel.Remap.offset(1, 1),
b.namedScan(List.of("t"), List.of("col"), List.of(colType)));
assertFullRoundTrip(rel);
}

@Test
void nullableEmptyArrayLiteral() {
var colType = N.I8;
var emptyListLiteral = ExpressionCreator.emptyList(true, N.I8);
var rel =
b.project(
input -> List.of(emptyListLiteral),
Rel.Remap.offset(1, 1),
b.namedScan(List.of("t"), List.of("col"), List.of(colType)));
assertFullRoundTrip(rel);
}
}

0 comments on commit 2a98e3c

Please sign in to comment.