Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

API: add isNaN and notNaN predicates #1747

Merged
merged 6 commits into from
Dec 6, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
do not accept NaN in Expressions, or mismatch type
  • Loading branch information
yyanyy committed Nov 26, 2020
commit d5e666399663685d5fc583a692c57c090de74de3
25 changes: 25 additions & 0 deletions api/src/main/java/org/apache/iceberg/expressions/Expressions.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.apache.iceberg.transforms.Transform;
import org.apache.iceberg.transforms.Transforms;
import org.apache.iceberg.types.Types;
import org.apache.iceberg.util.NaNUtil;

/**
* Factory methods for creating {@link Expression expressions}.
Expand Down Expand Up @@ -140,50 +141,62 @@ public static <T> UnboundPredicate<T> notNaN(UnboundTerm<T> expr) {
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we also need to update the equality predicate to catch NaN and rewrite to isNaN?

Copy link
Contributor Author

@yyanyy yyanyy Nov 12, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I originally thought to update SparkFilters to do the rewrite, but this is a much better place. Thanks for the suggestion!

Edit: what do you think about doing rewriting eq within UnboundPredicate? And for rewriting in, I was thinking to let Expressions.in to do the rewrite logic of or(isNaN, in)/and(notNaN, notIn), but that means it will return Expression instead of Predicate; does that align with your thinking?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not fully understand what you mean by "rewrite logic of or(isNaN, in)/and(notNaN, notIn)" when you talk about rewriting in. Can you give some examples of what predicate are you trying to support?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So now since we want to handle NaN in in predicate, for query in(1,2, NaN) to avoid checking for NaN in in evaluation all the time we can transform that to in(1,2) or isNaN, and notIn(1,2,NaN) to notIn(1, 2) and notNaN. The problem is where to do that, since in and notIn are both predicate, and if we are extending them we are transforming a predicate (simpler form) to an expression (complex form), and I think there's no such case in the current code base, and it would touch a lot of existing test cases for this.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay so it's what I thought, just a bit confused by the notation.

So for eq, what is the benefit of doing it in UnboundedPredicate versus just rewriting it in the Expressions?

For in, I think it is a more complex question.We need to figure out:

  1. should syntax like in(1,2,NaN) be supported, given it can be written as is_nan or in(1,2) on client side
  2. if so, Expressions.in should return Expression as you said, which looks fine to me because the only caller SparkFilters.convert also returns an Expression in the end.
  3. maybe we should tackle this in another PR to keep changes concise.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the quick response! Yeah I think the amount of change to method return type/tests is not a concern now. I just wasn't entirely sure if rewriting eq to isNan in Expressions will help with catching problems early (comparing to rewriting in UnboundPredicate), since it seems to me that the related code will not have a chance to throw any exception until bind() is called?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it isn't much earlier in that case. Maybe that actually exposes a problem with rewriting, too.

Expressions.equal("c", Double.NaN) if c is not a floating point column would result in isNaN, which should be rejected while binding expressions. You could argue that it should rewrite to alwaysFalse instead following the same logic as Expressions.equal("intCol", Long.MAX_VALUE) -- it can't be true.

I think that it would be better to be strict and reject binding in that case because something is clearly wrong. I think a lot of the time, that kind of error would happen when columns are misaligned or predicates are incorrectly converted.

If the result of those errors is just to fail in expression binding, then why rewrite at all? Maybe we should just reject NaN in any predicate and force people to explicitly use isNaN and notNaN. That way we do throw an exception much earlier in all cases. Plus, we wouldn't have to worry about confusion over whether NaN is equal to itself: in Java, a Double that holds NaN is equal to itself, but a primitive is not. 😕

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, those are some good points! To make sure I understand correctly/know how to move forward, I have some questions:

  • If I understand correctly, to reject NaN in any predicate sounds like we might go back to the idea of rewriting equals in SparkFilters (or in general, the integration point with engines during the query-to-expression translation); or maybe even earlier than that, to let engines to support syntax of is NaN?
  • Since to know if a query is eligible to be translated to isNaN there has to be some place that ensures the type has to be either double or float, and in iceberg code base we will only know this during binding; are we able to rely on engine to do this check before translating query to Expression?
  • And seems like this may only impact eq as we decided to do input validation on other lg/lteq/gt/gteq and in anyway?
  • And if we start to throw exceptions when the code passes in NaN to eq, that may sound backward incompatible until the engine starts to rewrite NaN?

I guess the conversation is starting to get too detailed, if you wouldn't mind I'll try to follow up on Slack tomorrow and then post the conclusion here?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand correctly, to reject NaN in any predicate sounds like we might go back to the idea of rewriting equals in SparkFilters

Yes. If the engine generally uses d = NaN then we can convert that to isNaN. But that would be engine-dependent and the Iceberg expression API would not support equals with NaN.

are we able to rely on engine to do this check before translating query to Expression?

I think so. Most engines will optimize the SQL expressions and handle this already. If not, then it would result in an exception from Iceberg to the user. I think that's okay, too, because as I said above, we want to fail if a NaN is used in an expression with a non-floating-point column, not rewrite to false.

And seems like this may only impact eq as we decided to do input validation on other lg/lteq/gt/gteq and in anyway?

Yes. This makes all of the handling in Expressions consistent: always reject NaN values.

that may sound backward incompatible until the engine starts to rewrite NaN?

I'm not convinced either way. You could argue that d = NaN is ambiguous and that rejecting it is now fixing a bug. That's certainly the case with d > NaN, which is not defined. On the other hand, there was some bevhavior before that will now no longer work. So I'd be up for fixing this in Flink and Spark conversions as soon as we can.

Feel free to ping me on Slack!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the explanation! I think now I understand the full picture. I think I've addressed everything except for rewriting in SparkFilters and other engines, which I think this PR is already too big so I'll submit a separate PR for it (likely next week).


public static <T> UnboundPredicate<T> lessThan(String name, T value) {
validateInput("lessThan", value);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An easier way to do this is to add the check in Literal.from. That's where Iceberg enforces that the value cannot be null. Since a literal is created for every value that is passed in, we would only need to change that one place instead of all of the factory methods here.

It also ensures that we don't add factory methods later and forget to add the check to them.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you! I didn't notice Literals.from was called within constructor of UnboundPredicate when normal object is passed in. This is definitely much more cleaner! I have created #1892 to address this.

Thank you so much for your time reviewing this long PR!

return new UnboundPredicate<>(Expression.Operation.LT, ref(name), value);
}

public static <T> UnboundPredicate<T> lessThan(UnboundTerm<T> expr, T value) {
validateInput("lessThan", value);
return new UnboundPredicate<>(Expression.Operation.LT, expr, value);
}

public static <T> UnboundPredicate<T> lessThanOrEqual(String name, T value) {
validateInput("lessThanOrEqual", value);
return new UnboundPredicate<>(Expression.Operation.LT_EQ, ref(name), value);
}

public static <T> UnboundPredicate<T> lessThanOrEqual(UnboundTerm<T> expr, T value) {
validateInput("lessThanOrEqual", value);
return new UnboundPredicate<>(Expression.Operation.LT_EQ, expr, value);
}

public static <T> UnboundPredicate<T> greaterThan(String name, T value) {
validateInput("greaterThan", value);
return new UnboundPredicate<>(Expression.Operation.GT, ref(name), value);
}

public static <T> UnboundPredicate<T> greaterThan(UnboundTerm<T> expr, T value) {
validateInput("greaterThan", value);
return new UnboundPredicate<>(Expression.Operation.GT, expr, value);
}

public static <T> UnboundPredicate<T> greaterThanOrEqual(String name, T value) {
validateInput("greaterThanOrEqual", value);
return new UnboundPredicate<>(Expression.Operation.GT_EQ, ref(name), value);
}

public static <T> UnboundPredicate<T> greaterThanOrEqual(UnboundTerm<T> expr, T value) {
validateInput("greaterThanOrEqual", value);
return new UnboundPredicate<>(Expression.Operation.GT_EQ, expr, value);
}

public static <T> UnboundPredicate<T> equal(String name, T value) {
validateInput("equal", value);
return new UnboundPredicate<>(Expression.Operation.EQ, ref(name), value);
}

public static <T> UnboundPredicate<T> equal(UnboundTerm<T> expr, T value) {
validateInput("equal", value);
return new UnboundPredicate<>(Expression.Operation.EQ, expr, value);
}

public static <T> UnboundPredicate<T> notEqual(String name, T value) {
validateInput("notEqual", value);
return new UnboundPredicate<>(Expression.Operation.NOT_EQ, ref(name), value);
}

public static <T> UnboundPredicate<T> notEqual(UnboundTerm<T> expr, T value) {
validateInput("notEqual", value);
return new UnboundPredicate<>(Expression.Operation.NOT_EQ, expr, value);
}

Expand Down Expand Up @@ -232,6 +245,7 @@ public static <T> UnboundPredicate<T> notIn(UnboundTerm<T> expr, Iterable<T> val
}

public static <T> UnboundPredicate<T> predicate(Operation op, String name, T value) {
validateInput(op.toString(), value);
return predicate(op, name, Literals.from(value));
}

Expand All @@ -243,6 +257,7 @@ public static <T> UnboundPredicate<T> predicate(Operation op, String name, Liter
}

public static <T> UnboundPredicate<T> predicate(Operation op, String name, Iterable<T> values) {
validateInput(op.toString(), values);
return predicate(op, ref(name), values);
}

Expand All @@ -254,9 +269,19 @@ public static <T> UnboundPredicate<T> predicate(Operation op, String name) {
}

private static <T> UnboundPredicate<T> predicate(Operation op, UnboundTerm<T> expr, Iterable<T> values) {
validateInput(op.toString(), values);
return new UnboundPredicate<>(op, expr, values);
}

private static <T> void validateInput(String op, T value) {
Preconditions.checkArgument(!NaNUtil.isNaN(value), String.format("Cannot create %s predicate with NaN", op));
}

private static <T> void validateInput(String op, Iterable<T> values) {
Preconditions.checkArgument(Lists.newArrayList(values).stream().noneMatch(NaNUtil::isNaN),
String.format("Cannot create %s predicate with NaN", op));
}

public static True alwaysTrue() {
return True.INSTANCE;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,7 @@ public <T> Boolean notNull(BoundReference<T> ref) {
int pos = Accessors.toPosition(ref.accessor());
// containsNull encodes whether at least one partition value is null, lowerBound is null if
// all partition values are null.
ByteBuffer lowerBound = stats.get(pos).lowerBound();
if (lowerBound == null) {
if (stats.get(pos).containsNull() && stats.get(pos).lowerBound() == null) {
return ROWS_CANNOT_MATCH; // all values are null
}

Expand All @@ -147,8 +146,7 @@ public <T> Boolean isNaN(BoundReference<T> ref) {
int pos = Accessors.toPosition(ref.accessor());
// containsNull encodes whether at least one partition value is null, lowerBound is null if
// all partition values are null.
ByteBuffer lowerBound = stats.get(pos).lowerBound();
if (lowerBound == null) {
if (stats.get(pos).containsNull() && stats.get(pos).lowerBound() == null) {
return ROWS_CANNOT_MATCH; // all values are null
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import org.apache.iceberg.types.Type;
import org.apache.iceberg.types.Types.StructType;
import org.apache.iceberg.util.CharSequenceSet;
import org.apache.iceberg.util.NaNUtil;

public class UnboundPredicate<T> extends Predicate<T, UnboundTerm<T>> implements Unbound<T, Expression> {
private static final Joiner COMMA = Joiner.on(", ");
Expand Down Expand Up @@ -130,25 +129,35 @@ private Expression bindUnaryOperation(BoundTerm<T> boundTerm) {
}
return new BoundUnaryPredicate<>(Operation.NOT_NULL, boundTerm);
case IS_NAN:
return toIsNaNExpression(boundTerm);
if (floatingType(boundTerm.type().typeId())) {
return new BoundUnaryPredicate<>(Operation.IS_NAN, boundTerm);
} else {
throw new ValidationException("IsNaN cannot be used with a non-floating-point column");
}
case NOT_NAN:
return toNotNaNExpression(boundTerm);
if (floatingType(boundTerm.type().typeId())) {
return new BoundUnaryPredicate<>(Operation.NOT_NAN, boundTerm);
} else {
throw new ValidationException("NotNaN cannot be used with a non-floating-point column");
}
default:
throw new ValidationException("Operation must be IS_NULL, NOT_NULL, IS_NAN, or NOT_NAN");
}
}

private Expression bindLiteralOperation(BoundTerm<T> boundTerm) {
return bindLiteralOperation(boundTerm, op(), literal().to(boundTerm.type()));
private boolean floatingType(Type.TypeID typeID) {
return Type.TypeID.DOUBLE.equals(typeID) || Type.TypeID.FLOAT.equals(typeID);
}

private Expression bindLiteralOperation(BoundTerm<T> boundTerm, Operation op, Literal<T> lit) {
private Expression bindLiteralOperation(BoundTerm<T> boundTerm) {
Literal<T> lit = literal().to(boundTerm.type());

if (lit == null) {
throw new ValidationException("Invalid value for conversion to type %s: %s (%s)",
boundTerm.type(), literal().value(), literal().value().getClass().getName());

} else if (lit == Literals.aboveMax()) {
switch (op) {
switch (op()) {
case LT:
case LT_EQ:
case NOT_EQ:
Expand All @@ -159,7 +168,7 @@ private Expression bindLiteralOperation(BoundTerm<T> boundTerm, Operation op, Li
return Expressions.alwaysFalse();
}
} else if (lit == Literals.belowMin()) {
switch (op) {
switch (op()) {
case GT:
case GT_EQ:
case NOT_EQ:
Expand All @@ -169,42 +178,10 @@ private Expression bindLiteralOperation(BoundTerm<T> boundTerm, Operation op, Li
case EQ:
return Expressions.alwaysFalse();
}
} else if (NaNUtil.isNaN(lit.value())) {
switch (op) {
case GT:
case GT_EQ:
case LT:
case LT_EQ:
throw new IllegalArgumentException(String.format("Cannot perform operation %s with value NaN", op));
case EQ:
return toIsNaNExpression(boundTerm);
case NOT_EQ:
return toNotNaNExpression(boundTerm);
}
}

// TODO: translate truncate(col) == value to startsWith(value)
return new BoundLiteralPredicate<>(op, boundTerm, lit);
}

private Expression toIsNaNExpression(BoundTerm<T> boundTerm) {
if (typeIncludesNaN(boundTerm.type().typeId())) {
return new BoundUnaryPredicate<>(Operation.IS_NAN, boundTerm);
} else {
return Expressions.alwaysFalse();
}
}

private Expression toNotNaNExpression(BoundTerm<T> boundTerm) {
if (typeIncludesNaN(boundTerm.type().typeId())) {
return new BoundUnaryPredicate<>(Operation.NOT_NAN, boundTerm);
} else {
return Expressions.alwaysTrue();
}
}

private boolean typeIncludesNaN(Type.TypeID typeID) {
return Type.TypeID.DOUBLE.equals(typeID) || Type.TypeID.FLOAT.equals(typeID);
return new BoundLiteralPredicate<>(op(), boundTerm, lit);
}

private Expression bindInOperation(BoundTerm<T> boundTerm) {
Expand Down Expand Up @@ -232,9 +209,9 @@ private Expression bindInOperation(BoundTerm<T> boundTerm) {
if (literalSet.size() == 1) {
switch (op()) {
case IN:
return bindLiteralOperation(boundTerm, Operation.EQ, Iterables.get(convertedLiterals, 0));
return new BoundLiteralPredicate<>(Operation.EQ, boundTerm, Iterables.get(convertedLiterals, 0));
case NOT_IN:
return bindLiteralOperation(boundTerm, Operation.NOT_EQ, Iterables.get(convertedLiterals, 0));
return new BoundLiteralPredicate<>(Operation.NOT_EQ, boundTerm, Iterables.get(convertedLiterals, 0));
default:
throw new ValidationException("Operation must be IN or NOT_IN");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@

package org.apache.iceberg.expressions;

import java.util.concurrent.Callable;
import org.apache.iceberg.AssertHelpers;
import org.apache.iceberg.transforms.Transforms;
import org.apache.iceberg.types.Types;
import org.apache.iceberg.types.Types.NestedField;
import org.apache.iceberg.types.Types.StructType;
Expand All @@ -45,6 +47,8 @@
import static org.apache.iceberg.expressions.Expressions.notIn;
import static org.apache.iceberg.expressions.Expressions.notNull;
import static org.apache.iceberg.expressions.Expressions.or;
import static org.apache.iceberg.expressions.Expressions.predicate;
import static org.apache.iceberg.expressions.Expressions.ref;
import static org.apache.iceberg.expressions.Expressions.rewriteNot;
import static org.apache.iceberg.expressions.Expressions.truncate;
import static org.apache.iceberg.expressions.Expressions.year;
Expand Down Expand Up @@ -187,4 +191,44 @@ public void testMultiAnd() {

Assert.assertEquals(expected.toString(), actual.toString());
}

@Test
public void testInvalidateNaNInput() {
assertInvalidateNaNThrows("lessThan", () -> lessThan("a", Double.NaN));
assertInvalidateNaNThrows("lessThan", () -> lessThan(self("a"), Double.NaN));

assertInvalidateNaNThrows("lessThanOrEqual", () -> lessThanOrEqual("a", Double.NaN));
assertInvalidateNaNThrows("lessThanOrEqual", () -> lessThanOrEqual(self("a"), Double.NaN));

assertInvalidateNaNThrows("greaterThan", () -> greaterThan("a", Double.NaN));
assertInvalidateNaNThrows("greaterThan", () -> greaterThan(self("a"), Double.NaN));

assertInvalidateNaNThrows("greaterThanOrEqual", () -> greaterThanOrEqual("a", Double.NaN));
assertInvalidateNaNThrows("greaterThanOrEqual", () -> greaterThanOrEqual(self("a"), Double.NaN));

assertInvalidateNaNThrows("equal", () -> equal("a", Double.NaN));
assertInvalidateNaNThrows("equal", () -> equal(self("a"), Double.NaN));

assertInvalidateNaNThrows("notEqual", () -> notEqual("a", Double.NaN));
assertInvalidateNaNThrows("notEqual", () -> notEqual(self("a"), Double.NaN));

assertInvalidateNaNThrows("IN", () -> in("a", 1.0D, 2.0D, Double.NaN));
assertInvalidateNaNThrows("IN", () -> in(self("a"), 1.0D, 2.0D, Double.NaN));

assertInvalidateNaNThrows("NOT_IN", () -> notIn("a", 1.0D, 2.0D, Double.NaN));
assertInvalidateNaNThrows("NOT_IN", () -> notIn(self("a"), 1.0D, 2.0D, Double.NaN));

assertInvalidateNaNThrows("EQ", () -> predicate(Expression.Operation.EQ, "a", Double.NaN));
}

private void assertInvalidateNaNThrows(String operation, Callable<UnboundPredicate<Double>> callable) {
AssertHelpers.assertThrows("Should invalidate NaN input",
IllegalArgumentException.class, String.format("Cannot create %s predicate with NaN", operation),
callable);
}

private <T> UnboundTerm<T> self(String name) {
return new UnboundTransform<>(ref(name), Transforms.identity(Types.DoubleType.get()));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ public class TestInclusiveManifestEvaluator {
new TestHelpers.TestFieldSummary(false,
toByteBuffer(Types.FloatType.get(), 0F),
toByteBuffer(Types.FloatType.get(), 20F)),
new TestHelpers.TestFieldSummary(false, null, null)
new TestHelpers.TestFieldSummary(true, null, null)
));

@Test
Expand Down
Loading