diff --git a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java index ba1d362d824..269a72fea7a 100644 --- a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java +++ b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java @@ -2375,24 +2375,28 @@ public RelBuilder join(JoinRelType joinType, RexNode condition, } if (correlate) { final CorrelationId id = Iterables.getOnlyElement(variablesSet); - final ImmutableBitSet requiredColumns = - RelOptUtil.correlationColumns(id, right.rel); if (!RelOptUtil.notContainsCorrelation(left.rel, id, Litmus.IGNORE)) { throw new IllegalArgumentException("variable " + id + " must not be used by left input to correlation"); } + // Correlate does not have an ON clause. switch (joinType) { case LEFT: - // Correlate does not have an ON clause. - // For a LEFT correlate, predicate must be evaluated first. - // For INNER, we can defer. + case SEMI: + case ANTI: + // For a LEFT/SEMI/ANTI, predicate must be evaluated first. stack.push(right); filter(condition.accept(new Shifter(left.rel, id, right.rel))); right = stack.pop(); break; - default: + case INNER: + // For INNER, we can defer. postCondition = condition; + break; + default: + throw new IllegalArgumentException("Correlated " + joinType + " join is not supported"); } + final ImmutableBitSet requiredColumns = RelOptUtil.correlationColumns(id, right.rel); join = struct.correlateFactory.createCorrelate(left.rel, right.rel, id, requiredColumns, joinType); diff --git a/core/src/test/java/org/apache/calcite/rel/logical/ToLogicalConverterTest.java b/core/src/test/java/org/apache/calcite/rel/logical/ToLogicalConverterTest.java index 63e26367f3a..aec9d95e72a 100644 --- a/core/src/test/java/org/apache/calcite/rel/logical/ToLogicalConverterTest.java +++ b/core/src/test/java/org/apache/calcite/rel/logical/ToLogicalConverterTest.java @@ -333,7 +333,7 @@ private void verify(RelNode rel, String expectedPhysical, String expectedLogical ImmutableSet.of(v.get().id)) .build(); String expectedPhysical = "" - + "EnumerableCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{}])\n" + + "EnumerableCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{5}])\n" + " EnumerableTableScan(table=[[scott, EMP]])\n" + " EnumerableFilter(condition=[=($cor0.SAL, 1000)])\n" + " EnumerableTableScan(table=[[scott, DEPT]])\n"; diff --git a/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java b/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java index 1d062f4daa8..258631761a8 100644 --- a/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java +++ b/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java @@ -2336,7 +2336,7 @@ private static RelNode groupIdRel(RelBuilder builder, boolean extra) { // Note that the join filter gets pushed to the right-hand input of // LogicalCorrelate final String expected = "" - + "LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{7}])\n" + + "LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{5, 7}])\n" + " LogicalTableScan(table=[[scott, EMP]])\n" + " LogicalFilter(condition=[=($cor0.SAL, 1000)])\n" + " LogicalFilter(condition=[=($0, $cor0.DEPTNO)])\n" @@ -3670,7 +3670,7 @@ private static RelNode groupIdRel(RelBuilder builder, boolean extra) { final String expected = "" + "LogicalCorrelate(correlation=[$cor0], joinType=[left], " - + "requiredColumns=[{7}])\n" + + "requiredColumns=[{5, 7}])\n" + " LogicalTableScan(table=[[scott, EMP]])\n" + " LogicalFilter(condition=[=($cor0.SAL, 1000)])\n" + " LogicalFilter(condition=[OR(" @@ -3870,6 +3870,70 @@ private void checkExpandTable(RelBuilder builder, Matcher matcher) { assertThat(root, hasTree(expected)); } + @Test void testSimpleSemiCorrelateViaJoin() { + RelNode root = buildSimpleCorrelateWithJoin(JoinRelType.SEMI); + final String expected = "" + + "LogicalCorrelate(correlation=[$cor0], joinType=[semi], requiredColumns=[{7}])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalFilter(condition=[=($cor0.DEPTNO, $0)])\n" + + " LogicalTableScan(table=[[scott, DEPT]])\n"; + assertThat(root, hasTree(expected)); + } + + @Test void testSimpleAntiCorrelateViaJoin() { + RelNode root = buildSimpleCorrelateWithJoin(JoinRelType.ANTI); + final String expected = "" + + "LogicalCorrelate(correlation=[$cor0], joinType=[anti], requiredColumns=[{7}])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalFilter(condition=[=($cor0.DEPTNO, $0)])\n" + + " LogicalTableScan(table=[[scott, DEPT]])\n"; + assertThat(root, hasTree(expected)); + } + + @Test void testSimpleLeftCorrelateViaJoin() { + RelNode root = buildSimpleCorrelateWithJoin(JoinRelType.LEFT); + final String expected = "" + + "LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{7}])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalFilter(condition=[=($cor0.DEPTNO, $0)])\n" + + " LogicalTableScan(table=[[scott, DEPT]])\n"; + assertThat(root, hasTree(expected)); + } + + @Test void testSimpleInnerCorrelateViaJoin() { + RelNode root = buildSimpleCorrelateWithJoin(JoinRelType.INNER); + final String expected = "" + + "LogicalFilter(condition=[=($7, $8)])\n" + + " LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{}])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalTableScan(table=[[scott, DEPT]])\n"; + assertThat(root, hasTree(expected)); + } + + @Test void testSimpleRightCorrelateViaJoinThrowsException() { + assertThrows(IllegalArgumentException.class, + () -> buildSimpleCorrelateWithJoin(JoinRelType.RIGHT)); + } + + @Test void testSimpleFullCorrelateViaJoinThrowsException() { + assertThrows(IllegalArgumentException.class, + () -> buildSimpleCorrelateWithJoin(JoinRelType.FULL)); + } + + private static RelNode buildSimpleCorrelateWithJoin(JoinRelType type) { + final RelBuilder builder = RelBuilder.create(config().build()); + final Holder<@Nullable RexCorrelVariable> v = Holder.empty(); + return builder + .scan("EMP") + .variable(v) + .scan("DEPT") + .join(type, + builder.equals( + builder.field(2, 0, "DEPTNO"), + builder.field(2, 1, "DEPTNO")), ImmutableSet.of(v.get().id)) + .build(); + } + @Test void testCorrelateWithComplexFields() { final RelBuilder builder = RelBuilder.create(config().build()); final Holder<@Nullable RexCorrelVariable> v = Holder.empty(); diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml index c30446ffd48..4c509c31cf5 100644 --- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml +++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml @@ -10545,7 +10545,7 @@ LogicalProject(DEPTNO=[$0])