Skip to content

Commit d5e6663

Browse files
committed
do not accept NaN in Expressions, or mismatch type
1 parent d7c3c3e commit d5e6663

File tree

9 files changed

+106
-264
lines changed

9 files changed

+106
-264
lines changed

api/src/main/java/org/apache/iceberg/expressions/Expressions.java

+25
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.apache.iceberg.transforms.Transform;
2727
import org.apache.iceberg.transforms.Transforms;
2828
import org.apache.iceberg.types.Types;
29+
import org.apache.iceberg.util.NaNUtil;
2930

3031
/**
3132
* Factory methods for creating {@link Expression expressions}.
@@ -140,50 +141,62 @@ public static <T> UnboundPredicate<T> notNaN(UnboundTerm<T> expr) {
140141
}
141142

142143
public static <T> UnboundPredicate<T> lessThan(String name, T value) {
144+
validateInput("lessThan", value);
143145
return new UnboundPredicate<>(Expression.Operation.LT, ref(name), value);
144146
}
145147

146148
public static <T> UnboundPredicate<T> lessThan(UnboundTerm<T> expr, T value) {
149+
validateInput("lessThan", value);
147150
return new UnboundPredicate<>(Expression.Operation.LT, expr, value);
148151
}
149152

150153
public static <T> UnboundPredicate<T> lessThanOrEqual(String name, T value) {
154+
validateInput("lessThanOrEqual", value);
151155
return new UnboundPredicate<>(Expression.Operation.LT_EQ, ref(name), value);
152156
}
153157

154158
public static <T> UnboundPredicate<T> lessThanOrEqual(UnboundTerm<T> expr, T value) {
159+
validateInput("lessThanOrEqual", value);
155160
return new UnboundPredicate<>(Expression.Operation.LT_EQ, expr, value);
156161
}
157162

158163
public static <T> UnboundPredicate<T> greaterThan(String name, T value) {
164+
validateInput("greaterThan", value);
159165
return new UnboundPredicate<>(Expression.Operation.GT, ref(name), value);
160166
}
161167

162168
public static <T> UnboundPredicate<T> greaterThan(UnboundTerm<T> expr, T value) {
169+
validateInput("greaterThan", value);
163170
return new UnboundPredicate<>(Expression.Operation.GT, expr, value);
164171
}
165172

166173
public static <T> UnboundPredicate<T> greaterThanOrEqual(String name, T value) {
174+
validateInput("greaterThanOrEqual", value);
167175
return new UnboundPredicate<>(Expression.Operation.GT_EQ, ref(name), value);
168176
}
169177

170178
public static <T> UnboundPredicate<T> greaterThanOrEqual(UnboundTerm<T> expr, T value) {
179+
validateInput("greaterThanOrEqual", value);
171180
return new UnboundPredicate<>(Expression.Operation.GT_EQ, expr, value);
172181
}
173182

174183
public static <T> UnboundPredicate<T> equal(String name, T value) {
184+
validateInput("equal", value);
175185
return new UnboundPredicate<>(Expression.Operation.EQ, ref(name), value);
176186
}
177187

178188
public static <T> UnboundPredicate<T> equal(UnboundTerm<T> expr, T value) {
189+
validateInput("equal", value);
179190
return new UnboundPredicate<>(Expression.Operation.EQ, expr, value);
180191
}
181192

182193
public static <T> UnboundPredicate<T> notEqual(String name, T value) {
194+
validateInput("notEqual", value);
183195
return new UnboundPredicate<>(Expression.Operation.NOT_EQ, ref(name), value);
184196
}
185197

186198
public static <T> UnboundPredicate<T> notEqual(UnboundTerm<T> expr, T value) {
199+
validateInput("notEqual", value);
187200
return new UnboundPredicate<>(Expression.Operation.NOT_EQ, expr, value);
188201
}
189202

@@ -232,6 +245,7 @@ public static <T> UnboundPredicate<T> notIn(UnboundTerm<T> expr, Iterable<T> val
232245
}
233246

234247
public static <T> UnboundPredicate<T> predicate(Operation op, String name, T value) {
248+
validateInput(op.toString(), value);
235249
return predicate(op, name, Literals.from(value));
236250
}
237251

@@ -243,6 +257,7 @@ public static <T> UnboundPredicate<T> predicate(Operation op, String name, Liter
243257
}
244258

245259
public static <T> UnboundPredicate<T> predicate(Operation op, String name, Iterable<T> values) {
260+
validateInput(op.toString(), values);
246261
return predicate(op, ref(name), values);
247262
}
248263

@@ -254,9 +269,19 @@ public static <T> UnboundPredicate<T> predicate(Operation op, String name) {
254269
}
255270

256271
private static <T> UnboundPredicate<T> predicate(Operation op, UnboundTerm<T> expr, Iterable<T> values) {
272+
validateInput(op.toString(), values);
257273
return new UnboundPredicate<>(op, expr, values);
258274
}
259275

276+
private static <T> void validateInput(String op, T value) {
277+
Preconditions.checkArgument(!NaNUtil.isNaN(value), String.format("Cannot create %s predicate with NaN", op));
278+
}
279+
280+
private static <T> void validateInput(String op, Iterable<T> values) {
281+
Preconditions.checkArgument(Lists.newArrayList(values).stream().noneMatch(NaNUtil::isNaN),
282+
String.format("Cannot create %s predicate with NaN", op));
283+
}
284+
260285
public static True alwaysTrue() {
261286
return True.INSTANCE;
262287
}

api/src/main/java/org/apache/iceberg/expressions/ManifestEvaluator.java

+2-4
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,7 @@ public <T> Boolean notNull(BoundReference<T> ref) {
134134
int pos = Accessors.toPosition(ref.accessor());
135135
// containsNull encodes whether at least one partition value is null; lowerBound is null if
136136
// all partition values are null.
137-
ByteBuffer lowerBound = stats.get(pos).lowerBound();
138-
if (lowerBound == null) {
137+
if (stats.get(pos).containsNull() && stats.get(pos).lowerBound() == null) {
139138
return ROWS_CANNOT_MATCH; // all values are null
140139
}
141140

@@ -147,8 +146,7 @@ public <T> Boolean isNaN(BoundReference<T> ref) {
147146
int pos = Accessors.toPosition(ref.accessor());
148147
// containsNull encodes whether at least one partition value is null; lowerBound is null if
149148
// all partition values are null.
150-
ByteBuffer lowerBound = stats.get(pos).lowerBound();
151-
if (lowerBound == null) {
149+
if (stats.get(pos).containsNull() && stats.get(pos).lowerBound() == null) {
152150
return ROWS_CANNOT_MATCH; // all values are null
153151
}
154152

api/src/main/java/org/apache/iceberg/expressions/UnboundPredicate.java

+20-43
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
import org.apache.iceberg.types.Type;
3131
import org.apache.iceberg.types.Types.StructType;
3232
import org.apache.iceberg.util.CharSequenceSet;
33-
import org.apache.iceberg.util.NaNUtil;
3433

3534
public class UnboundPredicate<T> extends Predicate<T, UnboundTerm<T>> implements Unbound<T, Expression> {
3635
private static final Joiner COMMA = Joiner.on(", ");
@@ -130,25 +129,35 @@ private Expression bindUnaryOperation(BoundTerm<T> boundTerm) {
130129
}
131130
return new BoundUnaryPredicate<>(Operation.NOT_NULL, boundTerm);
132131
case IS_NAN:
133-
return toIsNaNExpression(boundTerm);
132+
if (floatingType(boundTerm.type().typeId())) {
133+
return new BoundUnaryPredicate<>(Operation.IS_NAN, boundTerm);
134+
} else {
135+
throw new ValidationException("IsNaN cannot be used with a non-floating-point column");
136+
}
134137
case NOT_NAN:
135-
return toNotNaNExpression(boundTerm);
138+
if (floatingType(boundTerm.type().typeId())) {
139+
return new BoundUnaryPredicate<>(Operation.NOT_NAN, boundTerm);
140+
} else {
141+
throw new ValidationException("NotNaN cannot be used with a non-floating-point column");
142+
}
136143
default:
137144
throw new ValidationException("Operation must be IS_NULL, NOT_NULL, IS_NAN, or NOT_NAN");
138145
}
139146
}
140147

141-
private Expression bindLiteralOperation(BoundTerm<T> boundTerm) {
142-
return bindLiteralOperation(boundTerm, op(), literal().to(boundTerm.type()));
148+
private boolean floatingType(Type.TypeID typeID) {
149+
return Type.TypeID.DOUBLE.equals(typeID) || Type.TypeID.FLOAT.equals(typeID);
143150
}
144151

145-
private Expression bindLiteralOperation(BoundTerm<T> boundTerm, Operation op, Literal<T> lit) {
152+
private Expression bindLiteralOperation(BoundTerm<T> boundTerm) {
153+
Literal<T> lit = literal().to(boundTerm.type());
154+
146155
if (lit == null) {
147156
throw new ValidationException("Invalid value for conversion to type %s: %s (%s)",
148157
boundTerm.type(), literal().value(), literal().value().getClass().getName());
149158

150159
} else if (lit == Literals.aboveMax()) {
151-
switch (op) {
160+
switch (op()) {
152161
case LT:
153162
case LT_EQ:
154163
case NOT_EQ:
@@ -159,7 +168,7 @@ private Expression bindLiteralOperation(BoundTerm<T> boundTerm, Operation op, Li
159168
return Expressions.alwaysFalse();
160169
}
161170
} else if (lit == Literals.belowMin()) {
162-
switch (op) {
171+
switch (op()) {
163172
case GT:
164173
case GT_EQ:
165174
case NOT_EQ:
@@ -169,42 +178,10 @@ private Expression bindLiteralOperation(BoundTerm<T> boundTerm, Operation op, Li
169178
case EQ:
170179
return Expressions.alwaysFalse();
171180
}
172-
} else if (NaNUtil.isNaN(lit.value())) {
173-
switch (op) {
174-
case GT:
175-
case GT_EQ:
176-
case LT:
177-
case LT_EQ:
178-
throw new IllegalArgumentException(String.format("Cannot perform operation %s with value NaN", op));
179-
case EQ:
180-
return toIsNaNExpression(boundTerm);
181-
case NOT_EQ:
182-
return toNotNaNExpression(boundTerm);
183-
}
184181
}
185182

186183
// TODO: translate truncate(col) == value to startsWith(value)
187-
return new BoundLiteralPredicate<>(op, boundTerm, lit);
188-
}
189-
190-
private Expression toIsNaNExpression(BoundTerm<T> boundTerm) {
191-
if (typeIncludesNaN(boundTerm.type().typeId())) {
192-
return new BoundUnaryPredicate<>(Operation.IS_NAN, boundTerm);
193-
} else {
194-
return Expressions.alwaysFalse();
195-
}
196-
}
197-
198-
private Expression toNotNaNExpression(BoundTerm<T> boundTerm) {
199-
if (typeIncludesNaN(boundTerm.type().typeId())) {
200-
return new BoundUnaryPredicate<>(Operation.NOT_NAN, boundTerm);
201-
} else {
202-
return Expressions.alwaysTrue();
203-
}
204-
}
205-
206-
private boolean typeIncludesNaN(Type.TypeID typeID) {
207-
return Type.TypeID.DOUBLE.equals(typeID) || Type.TypeID.FLOAT.equals(typeID);
184+
return new BoundLiteralPredicate<>(op(), boundTerm, lit);
208185
}
209186

210187
private Expression bindInOperation(BoundTerm<T> boundTerm) {
@@ -232,9 +209,9 @@ private Expression bindInOperation(BoundTerm<T> boundTerm) {
232209
if (literalSet.size() == 1) {
233210
switch (op()) {
234211
case IN:
235-
return bindLiteralOperation(boundTerm, Operation.EQ, Iterables.get(convertedLiterals, 0));
212+
return new BoundLiteralPredicate<>(Operation.EQ, boundTerm, Iterables.get(convertedLiterals, 0));
236213
case NOT_IN:
237-
return bindLiteralOperation(boundTerm, Operation.NOT_EQ, Iterables.get(convertedLiterals, 0));
214+
return new BoundLiteralPredicate<>(Operation.NOT_EQ, boundTerm, Iterables.get(convertedLiterals, 0));
238215
default:
239216
throw new ValidationException("Operation must be IN or NOT_IN");
240217
}

api/src/test/java/org/apache/iceberg/expressions/TestExpressionHelpers.java

+44
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919

2020
package org.apache.iceberg.expressions;
2121

22+
import java.util.concurrent.Callable;
2223
import org.apache.iceberg.AssertHelpers;
24+
import org.apache.iceberg.transforms.Transforms;
2325
import org.apache.iceberg.types.Types;
2426
import org.apache.iceberg.types.Types.NestedField;
2527
import org.apache.iceberg.types.Types.StructType;
@@ -45,6 +47,8 @@
4547
import static org.apache.iceberg.expressions.Expressions.notIn;
4648
import static org.apache.iceberg.expressions.Expressions.notNull;
4749
import static org.apache.iceberg.expressions.Expressions.or;
50+
import static org.apache.iceberg.expressions.Expressions.predicate;
51+
import static org.apache.iceberg.expressions.Expressions.ref;
4852
import static org.apache.iceberg.expressions.Expressions.rewriteNot;
4953
import static org.apache.iceberg.expressions.Expressions.truncate;
5054
import static org.apache.iceberg.expressions.Expressions.year;
@@ -187,4 +191,44 @@ public void testMultiAnd() {
187191

188192
Assert.assertEquals(expected.toString(), actual.toString());
189193
}
194+
195+
@Test
196+
public void testInvalidateNaNInput() {
197+
assertInvalidateNaNThrows("lessThan", () -> lessThan("a", Double.NaN));
198+
assertInvalidateNaNThrows("lessThan", () -> lessThan(self("a"), Double.NaN));
199+
200+
assertInvalidateNaNThrows("lessThanOrEqual", () -> lessThanOrEqual("a", Double.NaN));
201+
assertInvalidateNaNThrows("lessThanOrEqual", () -> lessThanOrEqual(self("a"), Double.NaN));
202+
203+
assertInvalidateNaNThrows("greaterThan", () -> greaterThan("a", Double.NaN));
204+
assertInvalidateNaNThrows("greaterThan", () -> greaterThan(self("a"), Double.NaN));
205+
206+
assertInvalidateNaNThrows("greaterThanOrEqual", () -> greaterThanOrEqual("a", Double.NaN));
207+
assertInvalidateNaNThrows("greaterThanOrEqual", () -> greaterThanOrEqual(self("a"), Double.NaN));
208+
209+
assertInvalidateNaNThrows("equal", () -> equal("a", Double.NaN));
210+
assertInvalidateNaNThrows("equal", () -> equal(self("a"), Double.NaN));
211+
212+
assertInvalidateNaNThrows("notEqual", () -> notEqual("a", Double.NaN));
213+
assertInvalidateNaNThrows("notEqual", () -> notEqual(self("a"), Double.NaN));
214+
215+
assertInvalidateNaNThrows("IN", () -> in("a", 1.0D, 2.0D, Double.NaN));
216+
assertInvalidateNaNThrows("IN", () -> in(self("a"), 1.0D, 2.0D, Double.NaN));
217+
218+
assertInvalidateNaNThrows("NOT_IN", () -> notIn("a", 1.0D, 2.0D, Double.NaN));
219+
assertInvalidateNaNThrows("NOT_IN", () -> notIn(self("a"), 1.0D, 2.0D, Double.NaN));
220+
221+
assertInvalidateNaNThrows("EQ", () -> predicate(Expression.Operation.EQ, "a", Double.NaN));
222+
}
223+
224+
private void assertInvalidateNaNThrows(String operation, Callable<UnboundPredicate<Double>> callable) {
225+
AssertHelpers.assertThrows("Should invalidate NaN input",
226+
IllegalArgumentException.class, String.format("Cannot create %s predicate with NaN", operation),
227+
callable);
228+
}
229+
230+
private <T> UnboundTerm<T> self(String name) {
231+
return new UnboundTransform<>(ref(name), Transforms.identity(Types.DoubleType.get()));
232+
}
233+
190234
}

api/src/test/java/org/apache/iceberg/expressions/TestInclusiveManifestEvaluator.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ public class TestInclusiveManifestEvaluator {
9292
new TestHelpers.TestFieldSummary(false,
9393
toByteBuffer(Types.FloatType.get(), 0F),
9494
toByteBuffer(Types.FloatType.get(), 20F)),
95-
new TestHelpers.TestFieldSummary(false, null, null)
95+
new TestHelpers.TestFieldSummary(true, null, null)
9696
));
9797

9898
@Test

0 commit comments

Comments
 (0)