Skip to content

Commit

Permalink
[CALCITE-4574] Wrong/Invalid plans when using RelBuilder#join with correlations
Browse files Browse the repository at this point in the history

1. Gather required columns from the right side after the handling of the
filter to account for those columns present in the join condition.
2. Predicates for SEMI/ANTI join types should be pushed to the right
input, because otherwise columns in the condition referencing the right
side will be invalid.
3. Throw IllegalArgumentException for non-supported correlation joins.
4. Update existing tests with the correct plans.
5. Add new tests for RelBuilder#join with correlation covering all join
types.

Close apache#2393
  • Loading branch information
zabetak committed Apr 28, 2021
1 parent de847c3 commit 8c2228e
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 12 deletions.
16 changes: 10 additions & 6 deletions core/src/main/java/org/apache/calcite/tools/RelBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -2375,24 +2375,28 @@ public RelBuilder join(JoinRelType joinType, RexNode condition,
}
if (correlate) {
final CorrelationId id = Iterables.getOnlyElement(variablesSet);
final ImmutableBitSet requiredColumns =
RelOptUtil.correlationColumns(id, right.rel);
if (!RelOptUtil.notContainsCorrelation(left.rel, id, Litmus.IGNORE)) {
throw new IllegalArgumentException("variable " + id
+ " must not be used by left input to correlation");
}
// Correlate does not have an ON clause.
switch (joinType) {
case LEFT:
// Correlate does not have an ON clause.
// For a LEFT correlate, predicate must be evaluated first.
// For INNER, we can defer.
case SEMI:
case ANTI:
// For a LEFT/SEMI/ANTI, predicate must be evaluated first.
stack.push(right);
filter(condition.accept(new Shifter(left.rel, id, right.rel)));
right = stack.pop();
break;
default:
case INNER:
// For INNER, we can defer.
postCondition = condition;
break;
default:
throw new IllegalArgumentException("Correlated " + joinType + " join is not supported");
}
final ImmutableBitSet requiredColumns = RelOptUtil.correlationColumns(id, right.rel);
join =
struct.correlateFactory.createCorrelate(left.rel, right.rel, id,
requiredColumns, joinType);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ private void verify(RelNode rel, String expectedPhysical, String expectedLogical
ImmutableSet.of(v.get().id))
.build();
String expectedPhysical = ""
+ "EnumerableCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{}])\n"
+ "EnumerableCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{5}])\n"
+ " EnumerableTableScan(table=[[scott, EMP]])\n"
+ " EnumerableFilter(condition=[=($cor0.SAL, 1000)])\n"
+ " EnumerableTableScan(table=[[scott, DEPT]])\n";
Expand Down
68 changes: 66 additions & 2 deletions core/src/test/java/org/apache/calcite/test/RelBuilderTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -2336,7 +2336,7 @@ private static RelNode groupIdRel(RelBuilder builder, boolean extra) {
// Note that the join filter gets pushed to the right-hand input of
// LogicalCorrelate
final String expected = ""
+ "LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{7}])\n"
+ "LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{5, 7}])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n"
+ " LogicalFilter(condition=[=($cor0.SAL, 1000)])\n"
+ " LogicalFilter(condition=[=($0, $cor0.DEPTNO)])\n"
Expand Down Expand Up @@ -3670,7 +3670,7 @@ private static RelNode groupIdRel(RelBuilder builder, boolean extra) {

final String expected = ""
+ "LogicalCorrelate(correlation=[$cor0], joinType=[left], "
+ "requiredColumns=[{7}])\n"
+ "requiredColumns=[{5, 7}])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n"
+ " LogicalFilter(condition=[=($cor0.SAL, 1000)])\n"
+ " LogicalFilter(condition=[OR("
Expand Down Expand Up @@ -3870,6 +3870,70 @@ private void checkExpandTable(RelBuilder builder, Matcher<RelNode> matcher) {
assertThat(root, hasTree(expected));
}

/** Checks the plan produced by {@code RelBuilder#join} when it is given a
 * correlation variable and {@link JoinRelType#SEMI}: a semi
 * {@code LogicalCorrelate} over EMP whose right input is DEPT filtered by
 * the join condition. */
@Test void testSimpleSemiCorrelateViaJoin() {
  final String expectedPlan =
      "LogicalCorrelate(correlation=[$cor0], joinType=[semi], requiredColumns=[{7}])\n"
      + " LogicalTableScan(table=[[scott, EMP]])\n"
      + " LogicalFilter(condition=[=($cor0.DEPTNO, $0)])\n"
      + " LogicalTableScan(table=[[scott, DEPT]])\n";
  final RelNode rel = buildSimpleCorrelateWithJoin(JoinRelType.SEMI);
  assertThat(rel, hasTree(expectedPlan));
}

/** Checks the plan produced by {@code RelBuilder#join} when it is given a
 * correlation variable and {@link JoinRelType#ANTI}: an anti
 * {@code LogicalCorrelate} over EMP whose right input is DEPT filtered by
 * the join condition. */
@Test void testSimpleAntiCorrelateViaJoin() {
  final String expectedPlan =
      "LogicalCorrelate(correlation=[$cor0], joinType=[anti], requiredColumns=[{7}])\n"
      + " LogicalTableScan(table=[[scott, EMP]])\n"
      + " LogicalFilter(condition=[=($cor0.DEPTNO, $0)])\n"
      + " LogicalTableScan(table=[[scott, DEPT]])\n";
  final RelNode rel = buildSimpleCorrelateWithJoin(JoinRelType.ANTI);
  assertThat(rel, hasTree(expectedPlan));
}

/** Checks the plan produced by {@code RelBuilder#join} when it is given a
 * correlation variable and {@link JoinRelType#LEFT}: a left
 * {@code LogicalCorrelate} over EMP whose right input is DEPT filtered by
 * the join condition. */
@Test void testSimpleLeftCorrelateViaJoin() {
  final String expectedPlan =
      "LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{7}])\n"
      + " LogicalTableScan(table=[[scott, EMP]])\n"
      + " LogicalFilter(condition=[=($cor0.DEPTNO, $0)])\n"
      + " LogicalTableScan(table=[[scott, DEPT]])\n";
  final RelNode rel = buildSimpleCorrelateWithJoin(JoinRelType.LEFT);
  assertThat(rel, hasTree(expectedPlan));
}

/** Checks the plan produced by {@code RelBuilder#join} when it is given a
 * correlation variable and {@link JoinRelType#INNER}: the join condition
 * is not pushed into the right input but appears as a
 * {@code LogicalFilter} on top of an inner {@code LogicalCorrelate} with
 * no required columns. */
@Test void testSimpleInnerCorrelateViaJoin() {
  final String expectedPlan =
      "LogicalFilter(condition=[=($7, $8)])\n"
      + " LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{}])\n"
      + " LogicalTableScan(table=[[scott, EMP]])\n"
      + " LogicalTableScan(table=[[scott, DEPT]])\n";
  final RelNode rel = buildSimpleCorrelateWithJoin(JoinRelType.INNER);
  assertThat(rel, hasTree(expectedPlan));
}

/** Checks that building a correlated {@link JoinRelType#RIGHT} join
 * through the shared helper fails with an
 * {@link IllegalArgumentException}. */
@Test void testSimpleRightCorrelateViaJoinThrowsException() {
  assertThrows(
      IllegalArgumentException.class,
      () -> {
        buildSimpleCorrelateWithJoin(JoinRelType.RIGHT);
      });
}

/** Checks that building a correlated {@link JoinRelType#FULL} join
 * through the shared helper fails with an
 * {@link IllegalArgumentException}. */
@Test void testSimpleFullCorrelateViaJoinThrowsException() {
  assertThrows(
      IllegalArgumentException.class,
      () -> {
        buildSimpleCorrelateWithJoin(JoinRelType.FULL);
      });
}

/**
 * Builds the plan shared by the "simple correlate via join" tests: scans
 * EMP, registers a correlation variable, scans DEPT, and joins the two on
 * their DEPTNO fields while passing the correlation id as the join's
 * variable set.
 *
 * @param type join type handed to {@code RelBuilder#join}
 * @return the relational expression returned by {@code build()}
 */
private static RelNode buildSimpleCorrelateWithJoin(JoinRelType type) {
final RelBuilder builder = RelBuilder.create(config().build());
// Empty holder handed to variable(v) below; v.get() is read afterwards,
// so it is presumably populated by that call -- confirm against RelBuilder.
final Holder<@Nullable RexCorrelVariable> v = Holder.empty();
// NOTE(review): the argument expressions of join(...) are evaluated only
// after the receiver chain (scan, variable, scan) has run, so the field(...)
// lookups and v.get() see both inputs; do not hoist them above the chain.
return builder
.scan("EMP")
.variable(v)
.scan("DEPT")
.join(type,
builder.equals(
// Judging by the expected plans in the tests, ordinal 0 resolves to
// EMP.DEPTNO and ordinal 1 to DEPT.DEPTNO -- verify against field().
builder.field(2, 0, "DEPTNO"),
builder.field(2, 1, "DEPTNO")), ImmutableSet.of(v.get().id))
.build();
}

@Test void testCorrelateWithComplexFields() {
final RelBuilder builder = RelBuilder.create(config().build());
final Holder<@Nullable RexCorrelVariable> v = Holder.empty();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10545,7 +10545,7 @@ LogicalProject(DEPTNO=[$0])
<Resource name="planMid">
<![CDATA[
LogicalProject(SAL=[$5], EXPR$1=[IS NULL($10)])
LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{2}])
LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{0, 2}])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalFilter(condition=[=($cor0.EMPNO, $0)])
LogicalProject(DEPTNO=[$0], i=[true])
Expand Down Expand Up @@ -12307,7 +12307,7 @@ LogicalProject(SAL=[$5])
LogicalProject(SAL=[$5])
LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8])
LogicalFilter(condition=[OR(=($9, 0), IS NOT TRUE(OR(IS NOT NULL($12), <($10, $9))))])
LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{2}])
LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{0, 2}])
LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{2}])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalProject(c=[$0], ck=[$0])
Expand Down Expand Up @@ -12384,7 +12384,7 @@ LogicalProject(EMPNO=[$1])
LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8])
LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8])
LogicalFilter(condition=[OR(=($9, 0), IS NOT TRUE(OR(IS NOT NULL($12), <($10, $9))))])
LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{1}])
LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{0, 1}])
LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{1}])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalProject(c=[$0], ck=[$0])
Expand Down

0 comments on commit 8c2228e

Please sign in to comment.