Skip to content

Commit

Permalink
[CALCITE-5732] EnumerableHashJoin and EnumerableMergeJoin on composit…
Browse files Browse the repository at this point in the history
…e key return rows matching condition 'null = null'
  • Loading branch information
rubenada committed Sep 8, 2023
1 parent 64268b9 commit 3aee0b8
Show file tree
Hide file tree
Showing 10 changed files with 430 additions and 162 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,8 @@ private Result implementHashSemiJoin(EnumerableRelImplementor implementor, Prefe
Expressions.list(
leftExpression,
rightExpression,
leftResult.physType.generateAccessor(joinInfo.leftKeys),
rightResult.physType.generateAccessor(joinInfo.rightKeys),
leftResult.physType.generateAccessorWithoutNulls(joinInfo.leftKeys),
rightResult.physType.generateAccessorWithoutNulls(joinInfo.rightKeys),
Util.first(keyPhysType.comparer(),
Expressions.constant(null)),
predicate)))
Expand Down Expand Up @@ -264,8 +264,8 @@ private Result implementHashJoin(EnumerableRelImplementor implementor, Prefer pr
BuiltInMethod.HASH_JOIN.method,
Expressions.list(
rightExpression,
leftResult.physType.generateAccessor(joinInfo.leftKeys),
rightResult.physType.generateAccessor(joinInfo.rightKeys),
leftResult.physType.generateAccessorWithoutNulls(joinInfo.leftKeys),
rightResult.physType.generateAccessorWithoutNulls(joinInfo.rightKeys),
EnumUtils.joinSelector(joinType,
physType,
ImmutableList.of(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ public static EnumerableMergeJoin create(RelNode left, RelNode right,
RelFieldCollation.NullDirection.LAST));
}
final RelCollation collation = RelCollations.of(fieldCollations);
final Expression comparator = leftKeyPhysType.generateComparator(collation);
final Expression comparator = leftKeyPhysType.generateMergeJoinComparator(collation);

return implementor.result(
physType,
Expand All @@ -512,6 +512,9 @@ public static EnumerableMergeJoin create(RelNode left, RelNode right,
ImmutableList.of(
leftResult.physType, rightResult.physType)),
Expressions.constant(EnumUtils.toLinq4jJoinType(joinType)),
comparator))).toBlock());
comparator,
Util.first(
leftKeyPhysType.comparer(),
Expressions.constant(null))))).toBlock());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,28 @@ Expression fieldReference(Expression expression, int field,
* public Object[] apply(Employee v1) {
* return FlatLists.of(v1.<fieldN>, v1.<fieldM>);
* }
* }
* }</pre></blockquote>
*/
Expression generateAccessor(List<Integer> fields);

/** Similar to {@link #generateAccessor(List)}, but the generated function
 * returns <code>null</code> instead of a key as soon as one of the given
 * fields is <code>null</code>.
 *
 * <p>This is meant for join keys: a composite key with a <code>null</code>
 * component must not match any other key, because the condition
 * 'null = null' must not return rows.
 *
 * <p>For example:
 *
 * <blockquote><pre>
 * new Function1&lt;Employee, Object[]&gt; {
 * public Object[] apply(Employee v1) {
 * return v1.&lt;fieldN&gt; == null
 * ? null
 * : v1.&lt;fieldM&gt; == null
 * ? null
 * : FlatLists.of(v1.&lt;fieldN&gt;, v1.&lt;fieldM&gt;);
 * }
 * }</pre></blockquote>
 */
Expression generateAccessorWithoutNulls(List<Integer> fields);

/** Generates a selector for the given fields from an expression, with the
* default row format. */
Expression generateSelector(
Expand Down Expand Up @@ -181,6 +198,13 @@ Pair<Expression, Expression> generateCollationKey(
Expression generateComparator(
RelCollation collation);

/** Similar to {@link #generateComparator(RelCollation)}, but tailored to the
 * merge join algorithm: the generated comparator does not consider two
 * <code>null</code> values as equal.
 *
 * <p>Merge join keys are expected to be sorted in ascending order, nulls
 * last, so the given collation should describe exactly that ordering.
 *
 * @see org.apache.calcite.linq4j.EnumerableDefaults#compareNullsLastForMergeJoin
 */
Expression generateMergeJoinComparator(RelCollation collation);

/** Returns an expression that yields a comparer, or null if this type
 * is comparable. */
@Nullable Expression comparer();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,24 @@ static PhysType of(
}

@Override public Expression generateComparator(RelCollation collation) {
  // Chooses, per sort key, the name of the Utilities comparison method to
  // call: plain "compare" for a field that cannot be null, otherwise the
  // variant that places nulls according to the collation.
  final Function1<RelFieldCollation, String> compareMethodName = fieldCollation -> {
    if (!fieldNullable(fieldCollation.getFieldIndex())) {
      return "compare";
    }
    final boolean nullsFirst =
        fieldCollation.nullDirection == RelFieldCollation.NullDirection.FIRST;
    final boolean descending =
        fieldCollation.getDirection() == RelFieldCollation.Direction.DESCENDING;
    // NOTE(review): presumably the null placement is flipped for descending
    // keys because the two-argument overload negates the comparison result
    // for them -- confirm against that overload.
    return nullsFirst == descending
        ? "compareNullsLast"
        : "compareNullsFirst";
  };
  return generateComparator(collation, compareMethodName);
}

private Expression generateComparator(RelCollation collation,
Function1<RelFieldCollation, String> compareMethodNameFunction) {
// int c;
// c = Utilities.compare(v0, v1);
// if (c != 0) return c; // or -c if descending
Expand Down Expand Up @@ -437,9 +455,6 @@ static PhysType of(
default:
break;
}
final boolean nullsFirst =
fieldCollation.nullDirection
== RelFieldCollation.NullDirection.FIRST;
final boolean descending =
fieldCollation.getDirection()
== RelFieldCollation.Direction.DESCENDING;
Expand All @@ -449,11 +464,7 @@ static PhysType of(
parameterC,
Expressions.call(
Utilities.class,
fieldNullable(index)
? (nullsFirst != descending
? "compareNullsFirst"
: "compareNullsLast")
: "compare",
compareMethodNameFunction.apply(fieldCollation),
Expressions.list(
arg0,
arg1)
Expand Down Expand Up @@ -511,6 +522,17 @@ static PhysType of(
memberDeclarations);
}

@Override public Expression generateMergeJoinComparator(RelCollation collation) {
  // Same comparator skeleton as generateComparator(RelCollation); only the
  // comparison method used for nullable fields differs.
  final Function1<RelFieldCollation, String> compareMethodName = fieldCollation -> {
    // Merge join requires its keys sorted ascending with nulls last; any
    // other collation reaching this point is a programming error.
    assert fieldCollation.nullDirection == RelFieldCollation.NullDirection.LAST;
    assert fieldCollation.getDirection() == RelFieldCollation.Direction.ASCENDING;
    if (fieldNullable(fieldCollation.getFieldIndex())) {
      return "compareNullsLastForMergeJoin";
    }
    return "compare";
  };
  return generateComparator(collation, compareMethodName);
}

@Override public RelDataType getRowType() {
return rowType;
}
Expand Down Expand Up @@ -616,65 +638,79 @@ private List<Expression> fieldReferences(
for (int field : fields) {
list.add(fieldReference(v1, field));
}
switch (list.size()) {
case 2:
return Expressions.lambda(
Function1.class,
Expressions.call(
List.class,
null,
BuiltInMethod.LIST2.method,
list),
v1);
case 3:
return Expressions.lambda(
Function1.class,
Expressions.call(
List.class,
null,
BuiltInMethod.LIST3.method,
list),
v1);
case 4:
return Expressions.lambda(
Function1.class,
Expressions.call(
List.class,
null,
BuiltInMethod.LIST4.method,
list),
v1);
case 5:
return Expressions.lambda(
Function1.class,
Expressions.call(
List.class,
null,
BuiltInMethod.LIST5.method,
list),
v1);
case 6:
return Expressions.lambda(
Function1.class,
Expressions.call(
List.class,
null,
BuiltInMethod.LIST6.method,
list),
v1);
default:
return Expressions.lambda(
Function1.class,
Expressions.call(
List.class,
null,
BuiltInMethod.LIST_N.method,
Expressions.newArrayInit(
Comparable.class,
list)),
v1);
}
return Expressions.lambda(Function1.class, getListExpression(list), v1);
}
}

/** Builds a call to the list factory method that matches the number of
 * given expressions.
 *
 * <p>Sizes 2 to 6 use the dedicated factory ({@code LIST2} ..
 * {@code LIST6}) and pass the expressions as individual arguments. Any
 * other size packs them into a {@code Comparable[]} for {@code LIST_N}.
 * Callers must supply at least two expressions. */
private static Expression getListExpression(Expressions.FluentList<Expression> list) {
  final int arity = list.size();
  assert arity >= 2;

  final BuiltInMethod fixedArityFactory;
  switch (arity) {
  case 2:
    fixedArityFactory = BuiltInMethod.LIST2;
    break;
  case 3:
    fixedArityFactory = BuiltInMethod.LIST3;
    break;
  case 4:
    fixedArityFactory = BuiltInMethod.LIST4;
    break;
  case 5:
    fixedArityFactory = BuiltInMethod.LIST5;
    break;
  case 6:
    fixedArityFactory = BuiltInMethod.LIST6;
    break;
  default:
    // No dedicated factory for this size: wrap the expressions in an array
    // initializer and hand it to the variable-arity factory.
    return Expressions.call(
        List.class,
        null,
        BuiltInMethod.LIST_N.method,
        Expressions.newArrayInit(Comparable.class, list));
  }
  return Expressions.call(
      List.class,
      null,
      fixedArityFactory.method,
      list);
}

@Override public Expression generateAccessorWithoutNulls(List<Integer> fields) {
  // With fewer than two fields there is no composite key to guard, so the
  // regular accessor is used as-is.
  if (fields.size() < 2) {
    return generateAccessor(fields);
  }

  final ParameterExpression row = Expressions.parameter(javaRowClass, "v1");
  final Expressions.FluentList<Expression> keyFields = Expressions.list();
  for (int field : fields) {
    keyFields.add(fieldReference(row, field));
  }

  // Start from the innermost expression, the composite key itself, and wrap
  // it in one null check per field, walking from the last field to the
  // first so that the first field ends up in the outermost check:
  //
  //   (v1.<field0> == null)
  //     ? null
  //     : (v1.<field1> == null)
  //       ? null
  //       : ...
  //         : FlatLists.of(...)
  Expression body = getListExpression(keyFields);
  int i = keyFields.size();
  while (--i >= 0) {
    final Expression fieldIsNull =
        Expressions.equal(keyFields.get(i), Expressions.constant(null));
    body = Expressions.condition(fieldIsNull, Expressions.constant(null), body);
  }
  return Expressions.lambda(Function1.class, body, row);
}

@Override public Expression fieldReference(
Expand Down
11 changes: 11 additions & 0 deletions core/src/main/java/org/apache/calcite/runtime/Utilities.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
*/
package org.apache.calcite.runtime;

import org.apache.calcite.linq4j.EnumerableDefaults;

import org.checkerframework.checker.nullness.qual.Nullable;

import java.text.Collator;
Expand Down Expand Up @@ -250,6 +252,15 @@ public static int compareNullsLast(List v0, List v1) {
: FlatLists.ComparableListImpl.compare(v0, v1);
}

/** Compares two nullable values on behalf of the merge join algorithm.
 *
 * <p>Thin wrapper that delegates to
 * {@link EnumerableDefaults#compareNullsLastForMergeJoin}; see that method
 * for how <code>null</code> values are ordered. */
public static int compareNullsLastForMergeJoin(@Nullable Comparable v0, @Nullable Comparable v1) {
return EnumerableDefaults.compareNullsLastForMergeJoin(v0, v1);
}

/** Compares two nullable values on behalf of the merge join algorithm,
 * passing the given comparator through.
 *
 * <p>Thin wrapper that delegates to
 * {@link EnumerableDefaults#compareNullsLastForMergeJoin}; see that method
 * for how <code>null</code> values are ordered and how the comparator is
 * used. */
public static int compareNullsLastForMergeJoin(@Nullable Comparable v0, @Nullable Comparable v1,
Comparator comparator) {
return EnumerableDefaults.compareNullsLastForMergeJoin(v0, v1, comparator);
}

/** Creates a pattern builder. */
public static Pattern.PatternBuilder patternBuilder() {
return Pattern.builder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ public enum BuiltInMethod {
List.class, int.class, Consumer.class),
MERGE_JOIN(EnumerableDefaults.class, "mergeJoin", Enumerable.class,
Enumerable.class, Function1.class, Function1.class, Predicate2.class, Function2.class,
JoinType.class, Comparator.class),
JoinType.class, Comparator.class, EqualityComparer.class),
SLICE0(Enumerables.class, "slice0", Enumerable.class),
SEMI_JOIN(EnumerableDefaults.class, "semiJoin", Enumerable.class,
Enumerable.class, Function1.class, Function1.class,
Expand Down
Loading

0 comments on commit 3aee0b8

Please sign in to comment.