Skip to content

Commit

Permalink
[fix](nereids) fix compare with date like literal (apache#45382)
Browse files Browse the repository at this point in the history
When comparing against a date-like literal, the literal may be truncated incorrectly.
For example:

If `a` is datetimev1, then:
`a = '2020-01-01 00:00:00.12'` should be optimized to `FALSE`, but it was optimized to
`a = '2020-01-01 00:00:00'`;
`a >= '2020-01-01 00:00:00.12'` should be optimized to `a >= '2020-01-01 00:00:01'`,
but it was optimized to `a >= '2020-01-01 00:00:00'`;

If `a` is date / datev2, then:
`a = '2020-01-01 00:00:12'` should be optimized to `FALSE`, but it was not optimized at all;
  • Loading branch information
yujun777 authored Dec 18, 2024
1 parent c58a1b7 commit ad814d2
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import org.apache.doris.nereids.trees.expressions.literal.NumericLiteral;
import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DateTimeType;
import org.apache.doris.nereids.types.DateTimeV2Type;
import org.apache.doris.nereids.types.DateType;
Expand All @@ -71,12 +72,6 @@
public class SimplifyComparisonPredicate extends AbstractExpressionRewriteRule implements ExpressionPatternRuleFactory {
public static SimplifyComparisonPredicate INSTANCE = new SimplifyComparisonPredicate();

/**
 * Direction in which a datetime literal is adjusted when it is narrowed to a date literal
 * for comparison against a date-typed expression.
 */
enum AdjustType {
// Selected for GreaterThan / LessThanEqual predicates; the date part is kept as is.
LOWER,
// Selected for GreaterThanEqual / LessThan predicates; the date is moved forward by one day
// when the literal has a non-zero hour, minute or second.
UPPER,
// Default when no directional adjustment is selected; the date part is kept as is.
NONE
}

@Override
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
return ImmutableList.of(
Expand Down Expand Up @@ -119,18 +114,23 @@ public static Expression simplify(ComparisonPredicate cp) {
return result;
}

private static Expression processComparisonPredicateDateTimeV2Literal(
private static Expression processDateTimeLikeComparisonPredicateDateTimeV2Literal(
ComparisonPredicate comparisonPredicate, Expression left, DateTimeV2Literal right) {
DateTimeV2Type leftType = (DateTimeV2Type) left.getDataType();
DataType leftType = left.getDataType();
int toScale = 0;
if (leftType instanceof DateTimeType) {
toScale = 0;
} else if (leftType instanceof DateTimeV2Type) {
toScale = ((DateTimeV2Type) leftType).getScale();
} else {
return comparisonPredicate;
}
DateTimeV2Type rightType = right.getDataType();
if (leftType.getScale() < rightType.getScale()) {
int toScale = leftType.getScale();
if (toScale < rightType.getScale()) {
if (comparisonPredicate instanceof EqualTo) {
long originValue = right.getMicroSecond();
right = right.roundCeiling(toScale);
if (right.getMicroSecond() == originValue) {
return comparisonPredicate.withChildren(left, right);
} else {
if (right.getMicroSecond() != originValue) {
// TODO: the ideal way is to return an If expr like:
// return new If(new IsNull(left), new NullLiteral(BooleanType.INSTANCE),
// BooleanLiteral.of(false));
Expand All @@ -142,50 +142,55 @@ private static Expression processComparisonPredicateDateTimeV2Literal(
} else if (comparisonPredicate instanceof NullSafeEqual) {
long originValue = right.getMicroSecond();
right = right.roundCeiling(toScale);
if (right.getMicroSecond() == originValue) {
return comparisonPredicate.withChildren(left, right);
} else {
if (right.getMicroSecond() != originValue) {
return BooleanLiteral.of(false);
}
} else if (comparisonPredicate instanceof GreaterThan
|| comparisonPredicate instanceof LessThanEqual) {
return comparisonPredicate.withChildren(left, right.roundFloor(toScale));
right = right.roundFloor(toScale);
} else if (comparisonPredicate instanceof LessThan
|| comparisonPredicate instanceof GreaterThanEqual) {
return comparisonPredicate.withChildren(left, right.roundCeiling(toScale));
right = right.roundCeiling(toScale);
} else {
return comparisonPredicate;
}
Expression newRight = leftType instanceof DateTimeType ? migrateToDateTime(right) : right;
return comparisonPredicate.withChildren(left, newRight);
} else {
if (leftType instanceof DateTimeType) {
return comparisonPredicate.withChildren(left, migrateToDateTime(right));
} else {
return comparisonPredicate;
}
}
return comparisonPredicate;
}

private static Expression processDateLikeTypeCoercion(ComparisonPredicate cp, Expression left, Expression right) {
if (left instanceof Cast && right instanceof DateLiteral) {
Cast cast = (Cast) left;
if (cast.child().getDataType() instanceof DateTimeType) {
if (cast.child().getDataType() instanceof DateTimeType
|| cast.child().getDataType() instanceof DateTimeV2Type) {
if (right instanceof DateTimeV2Literal) {
left = cast.child();
right = migrateToDateTime((DateTimeV2Literal) right);
}
}
if (cast.child().getDataType() instanceof DateTimeV2Type) {
if (right instanceof DateTimeV2Literal) {
left = cast.child();
return processComparisonPredicateDateTimeV2Literal(cp, left, (DateTimeV2Literal) right);
return processDateTimeLikeComparisonPredicateDateTimeV2Literal(
cp, cast.child(), (DateTimeV2Literal) right);
}
}

// datetime to datev2
if (cast.child().getDataType() instanceof DateType || cast.child().getDataType() instanceof DateV2Type) {
if (right instanceof DateTimeLiteral) {
if (cannotAdjust((DateTimeLiteral) right, cp)) {
return cp;
}
AdjustType type = AdjustType.NONE;
if (cp instanceof GreaterThanEqual || cp instanceof LessThan) {
type = AdjustType.UPPER;
} else if (cp instanceof GreaterThan || cp instanceof LessThanEqual) {
type = AdjustType.LOWER;
DateTimeLiteral dateTimeLiteral = (DateTimeLiteral) right;
right = migrateToDateV2(dateTimeLiteral);
if (dateTimeLiteral.getHour() != 0 || dateTimeLiteral.getMinute() != 0
|| dateTimeLiteral.getSecond() != 0) {
if (cp instanceof EqualTo) {
return ExpressionUtils.falseOrNull(cast.child());
} else if (cp instanceof NullSafeEqual) {
return BooleanLiteral.FALSE;
} else if (cp instanceof GreaterThanEqual || cp instanceof LessThan) {
right = ((DateV2Literal) right).plusDays(1);
}
}
right = migrateToDateV2((DateTimeLiteral) right, type);
if (cast.child().getDataType() instanceof DateV2Type) {
left = cast.child();
}
Expand Down Expand Up @@ -416,17 +421,8 @@ private static Expression migrateToDateTime(DateTimeV2Literal l) {
return new DateTimeLiteral(l.getYear(), l.getMonth(), l.getDay(), l.getHour(), l.getMinute(), l.getSecond());
}

/**
 * Tells whether a datetime literal cannot be narrowed to a date for the given predicate.
 *
 * @param l the datetime literal on the right-hand side of the comparison
 * @param cp the comparison predicate being simplified
 * @return {@code true} only when {@code cp} is an {@code EqualTo} and the literal's hour,
 *         minute or second is non-zero
 */
private static boolean cannotAdjust(DateTimeLiteral l, ComparisonPredicate cp) {
    // Only equality is affected; every other predicate kind can always be adjusted.
    if (!(cp instanceof EqualTo)) {
        return false;
    }
    boolean atMidnight = l.getHour() == 0 && l.getMinute() == 0 && l.getSecond() == 0;
    return !atMidnight;
}

private static Expression migrateToDateV2(DateTimeLiteral l, AdjustType type) {
DateV2Literal d = new DateV2Literal(l.getYear(), l.getMonth(), l.getDay());
if (type == AdjustType.UPPER && (l.getHour() != 0 || l.getMinute() != 0 || l.getSecond() != 0)) {
return d.plusDays(1);
} else {
return d;
}
/**
 * Builds a {@code DateV2Literal} from the year, month and day of a datetime literal.
 * The hour, minute and second of {@code l} are not used.
 *
 * @param l the datetime literal to narrow
 * @return a date literal carrying only the date part of {@code l}
 */
private static Expression migrateToDateV2(DateTimeLiteral l) {
    DateV2Literal dateOnly = new DateV2Literal(l.getYear(), l.getMonth(), l.getDay());
    return dateOnly;
}

private static Expression migrateToDate(DateV2Literal l) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DateTimeType;
import org.apache.doris.nereids.types.DateTimeV2Type;
import org.apache.doris.nereids.types.DateType;
import org.apache.doris.nereids.types.DateV2Type;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.IntegerType;
Expand Down Expand Up @@ -95,11 +98,11 @@ void testSimplifyComparisonPredicateRule() {
new LessThan(dv2, dv2PlusOne));
assertRewrite(
new EqualTo(new Cast(dv2, DateTimeV2Type.SYSTEM_DEFAULT), dtv2),
new EqualTo(new Cast(dv2, DateTimeV2Type.SYSTEM_DEFAULT), dtv2));
BooleanLiteral.FALSE);

assertRewrite(
new EqualTo(new Cast(d, DateTimeV2Type.SYSTEM_DEFAULT), dtv2),
new EqualTo(new Cast(d, DateTimeV2Type.SYSTEM_DEFAULT), dtv2));
BooleanLiteral.FALSE);

// test hour, minute and second all zero
Expression dtv2AtZeroClock = new DateTimeV2Literal(1, 1, 1, 0, 0, 0, 0);
Expand Down Expand Up @@ -140,6 +143,100 @@ void testDateTimeV2CmpDateTimeV2() {
expression = new GreaterThan(left, right);
rewrittenExpression = executor.rewrite(typeCoercion(expression), context);
Assertions.assertEquals(dt.getDataType(), rewrittenExpression.child(0).getDataType());

Expression date = new SlotReference("a", DateV2Type.INSTANCE);
Expression datev1 = new SlotReference("a", DateType.INSTANCE);
Expression datetime0 = new SlotReference("a", DateTimeV2Type.of(0));
Expression datetime2 = new SlotReference("a", DateTimeV2Type.of(2));
Expression datetimev1 = new SlotReference("a", DateTimeType.INSTANCE);

// date
// cast (date as datetimev1) cmp datetimev1
assertRewrite(new EqualTo(new Cast(date, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:00")),
new EqualTo(date, new DateV2Literal("2020-01-01")));
assertRewrite(new EqualTo(new Cast(date, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")),
ExpressionUtils.falseOrNull(date));
assertRewrite(new NullSafeEqual(new Cast(date, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")),
BooleanLiteral.FALSE);
assertRewrite(new GreaterThan(new Cast(date, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")),
new GreaterThan(date, new DateV2Literal("2020-01-01")));
assertRewrite(new GreaterThanEqual(new Cast(date, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")),
new GreaterThanEqual(date, new DateV2Literal("2020-01-02")));
assertRewrite(new LessThan(new Cast(date, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")),
new LessThan(date, new DateV2Literal("2020-01-02")));
assertRewrite(new LessThanEqual(new Cast(date, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")),
new LessThanEqual(date, new DateV2Literal("2020-01-01")));
// cast (date as datev1) = datev1-literal
// assertRewrite(new EqualTo(new Cast(date, DateType.INSTANCE), new DateLiteral("2020-01-01")),
// new EqualTo(date, new DateV2Literal("2020-01-01")));
// assertRewrite(new GreaterThan(new Cast(date, DateType.INSTANCE), new DateLiteral("2020-01-01")),
// new GreaterThan(date, new DateV2Literal("2020-01-01")));

// cast (datev1 as datetimev1) cmp datetimev1
assertRewrite(new EqualTo(new Cast(datev1, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:00")),
new EqualTo(datev1, new DateLiteral("2020-01-01")));
assertRewrite(new EqualTo(new Cast(datev1, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")),
ExpressionUtils.falseOrNull(datev1));
assertRewrite(new NullSafeEqual(new Cast(datev1, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")),
BooleanLiteral.FALSE);
assertRewrite(new GreaterThan(new Cast(datev1, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")),
new GreaterThan(datev1, new DateLiteral("2020-01-01")));
assertRewrite(new GreaterThanEqual(new Cast(datev1, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")),
new GreaterThanEqual(datev1, new DateLiteral("2020-01-02")));
assertRewrite(new LessThan(new Cast(datev1, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")),
new LessThan(datev1, new DateLiteral("2020-01-02")));
assertRewrite(new LessThanEqual(new Cast(datev1, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")),
new LessThanEqual(datev1, new DateLiteral("2020-01-01")));
assertRewrite(new EqualTo(new Cast(datev1, DateV2Type.INSTANCE), new DateV2Literal("2020-01-01")),
new EqualTo(datev1, new DateLiteral("2020-01-01")));
assertRewrite(new GreaterThan(new Cast(datev1, DateV2Type.INSTANCE), new DateV2Literal("2020-01-01")),
new GreaterThan(datev1, new DateLiteral("2020-01-01")));

// cast (datetimev1 as datetime) cmp datetime
assertRewrite(new EqualTo(new Cast(datetimev1, DateTimeV2Type.of(0)), new DateTimeV2Literal("2020-01-01 00:00:00")),
new EqualTo(datetimev1, new DateTimeLiteral("2020-01-01 00:00:00")));
assertRewrite(new GreaterThan(new Cast(datetimev1, DateTimeV2Type.of(0)), new DateTimeV2Literal("2020-01-01 00:00:00")),
new GreaterThan(datetimev1, new DateTimeLiteral("2020-01-01 00:00:00")));
assertRewrite(new EqualTo(new Cast(datetimev1, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")),
ExpressionUtils.falseOrNull(datetimev1));
assertRewrite(new NullSafeEqual(new Cast(datetimev1, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")),
BooleanLiteral.FALSE);
assertRewrite(new GreaterThan(new Cast(datetimev1, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")),
new GreaterThan(datetimev1, new DateTimeLiteral("2020-01-01 00:00:00")));
assertRewrite(new GreaterThanEqual(new Cast(datetimev1, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")),
new GreaterThanEqual(datetimev1, new DateTimeLiteral("2020-01-01 00:00:01")));
assertRewrite(new LessThan(new Cast(datetimev1, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")),
new LessThan(datetimev1, new DateTimeLiteral("2020-01-01 00:00:01")));
assertRewrite(new LessThanEqual(new Cast(datetimev1, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")),
new LessThanEqual(datetimev1, new DateTimeLiteral("2020-01-01 00:00:00")));

// cast (datetime0 as datetime) cmp datetime
assertRewrite(new EqualTo(new Cast(datetime0, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")),
ExpressionUtils.falseOrNull(datetime0));
assertRewrite(new NullSafeEqual(new Cast(datetime0, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")),
BooleanLiteral.FALSE);
assertRewrite(new GreaterThan(new Cast(datetime0, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")),
new GreaterThan(datetime0, new DateTimeV2Literal("2020-01-01 00:00:00")));
assertRewrite(new GreaterThanEqual(new Cast(datetime0, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")),
new GreaterThanEqual(datetime0, new DateTimeV2Literal("2020-01-01 00:00:01")));
assertRewrite(new LessThan(new Cast(datetime0, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")),
new LessThan(datetime0, new DateTimeV2Literal("2020-01-01 00:00:01")));
assertRewrite(new LessThanEqual(new Cast(datetime0, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")),
new LessThanEqual(datetime0, new DateTimeV2Literal("2020-01-01 00:00:00")));

// cast (datetime2 as datetime) cmp datetime
assertRewrite(new EqualTo(new Cast(datetime2, DateTimeV2Type.of(3)), new DateTimeV2Literal("2020-01-01 00:00:00.123")),
ExpressionUtils.falseOrNull(datetime2));
assertRewrite(new NullSafeEqual(new Cast(datetime2, DateTimeV2Type.of(3)), new DateTimeV2Literal("2020-01-01 00:00:00.123")),
BooleanLiteral.FALSE);
assertRewrite(new GreaterThan(new Cast(datetime2, DateTimeV2Type.of(3)), new DateTimeV2Literal("2020-01-01 00:00:00.123")),
new GreaterThan(datetime2, new DateTimeV2Literal("2020-01-01 00:00:00.12")));
assertRewrite(new GreaterThanEqual(new Cast(datetime2, DateTimeV2Type.of(3)), new DateTimeV2Literal("2020-01-01 00:00:00.123")),
new GreaterThanEqual(datetime2, new DateTimeV2Literal("2020-01-01 00:00:00.13")));
assertRewrite(new LessThan(new Cast(datetime2, DateTimeV2Type.of(3)), new DateTimeV2Literal("2020-01-01 00:00:00.123")),
new LessThan(datetime2, new DateTimeV2Literal("2020-01-01 00:00:00.13")));
assertRewrite(new LessThanEqual(new Cast(datetime2, DateTimeV2Type.of(3)), new DateTimeV2Literal("2020-01-01 00:00:00.123")),
new LessThanEqual(datetime2, new DateTimeV2Literal("2020-01-01 00:00:00.12")));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1528,7 +1528,7 @@ PhysicalResultSink
----------PhysicalProject
------------PhysicalStorageLayerAggregate[test_pull_up_predicate_literal]
------PhysicalProject
--------filter((cast(d_date as DATETIMEV2(0)) = '2024-08-02 10:10:00'))
--------filter((t2.d_date = '2024-08-02'))
----------PhysicalOlapScan[test_types]

-- !const_value_and_join_column_type170 --
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -997,7 +997,7 @@ select c1 from (select
qt_const_value_and_join_column_type169 """
explain shape plan
select c1 from (select
'2024-08-02 10:10:00.123332' as c1 from test_pull_up_predicate_literal limit 10) t inner join test_types t2 on d_date=t.c1"""
'2024-08-02 00:00:00.000000' as c1 from test_pull_up_predicate_literal limit 10) t inner join test_types t2 on d_date=t.c1"""

qt_const_value_and_join_column_type170 """
explain shape plan
Expand Down

0 comments on commit ad814d2

Please sign in to comment.