Skip to content

Commit

Permalink
[CALCITE-1045][CALCITE-5127] Support correlation variables in project
Browse files Browse the repository at this point in the history
To some extend correlation in project was already supported even before
this change. However, the fact that the correlation variables were not
explicitly present (and returned by the operator) creates problems
cause we cannot safely deduce if a column/field is used and thus we may
wrongly remove those fields when using the RelFieldTrimmer, when
merging projections, etc.; see queries and discussion under the
respective JIRAs.

The addition of correlation variables in project also aligns the code
with Filter, Join; the latter explicitly set correlation variables.

Co-authored-by: korlov42 <korlov@gridgain.com>

Close apache#2813
Close apache#2623
  • Loading branch information
libenchao authored and zabetak committed Oct 4, 2022
1 parent 3a38ebf commit c2407f5
Show file tree
Hide file tree
Showing 46 changed files with 510 additions and 112 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.apache.calcite.util.Pair;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;

import org.checkerframework.checker.nullness.qual.Nullable;

Expand All @@ -42,7 +43,7 @@
public class CassandraProject extends Project implements CassandraRel {
public CassandraProject(RelOptCluster cluster, RelTraitSet traitSet,
RelNode input, List<? extends RexNode> projects, RelDataType rowType) {
super(cluster, traitSet, ImmutableList.of(), input, projects, rowType);
super(cluster, traitSet, ImmutableList.of(), input, projects, rowType, ImmutableSet.of());
assert getConvention() == CassandraRel.CONVENTION;
assert getConvention() == input.getConvention();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,7 @@ protected CassandraProjectRule(Config config) {
return false;
}
}

return true;
return project.getVariablesSet().isEmpty();
}

@Override public RelNode convert(RelNode rel) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.apache.calcite.util.Util;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;

import org.checkerframework.checker.nullness.qual.Nullable;

Expand All @@ -54,7 +55,7 @@ public EnumerableProject(
RelNode input,
List<? extends RexNode> projects,
RelDataType rowType) {
super(cluster, traitSet, ImmutableList.of(), input, projects, rowType);
super(cluster, traitSet, ImmutableList.of(), input, projects, rowType, ImmutableSet.of());
assert getConvention() instanceof EnumerableConvention;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.apache.calcite.adapter.enumerable;

import org.apache.calcite.plan.Convention;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.convert.ConverterRule;
import org.apache.calcite.rel.core.Project;
Expand All @@ -42,6 +43,11 @@ protected EnumerableProjectRule(Config config) {
super(config);
}

@Override public boolean matches(RelOptRuleCall call) {
Project project = call.rel(0);
return project.getVariablesSet().isEmpty();
}

@Override public RelNode convert(RelNode rel) {
final Project project = (Project) rel;
return EnumerableProject.create(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.validate.SqlValidatorUtil;

import com.google.common.base.Preconditions;

import org.checkerframework.checker.nullness.qual.Nullable;

import java.util.List;
Expand Down Expand Up @@ -68,7 +70,10 @@ private static class ProjectFactoryImpl
implements org.apache.calcite.rel.core.RelFactories.ProjectFactory {
@Override public RelNode createProject(RelNode input, List<RelHint> hints,
List<? extends RexNode> childExprs,
@Nullable List<? extends @Nullable String> fieldNames) {
@Nullable List<? extends @Nullable String> fieldNames,
Set<CorrelationId> variablesSet) {
Preconditions.checkArgument(variablesSet.isEmpty(),
"EnumerableProject does not allow variables");
final RelDataType rowType =
RexUtil.createStructType(input.getCluster().getTypeFactory(), childExprs,
fieldNames, SqlValidatorUtil.F_SUGGESTER);
Expand Down
12 changes: 10 additions & 2 deletions core/src/main/java/org/apache/calcite/adapter/jdbc/JdbcRules.java
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;

import org.checkerframework.checker.nullness.qual.Nullable;
import org.slf4j.Logger;
Expand All @@ -96,7 +97,9 @@ private JdbcRules() {
protected static final Logger LOGGER = CalciteTrace.getPlannerTracer();

static final RelFactories.ProjectFactory PROJECT_FACTORY =
(input, hints, projects, fieldNames) -> {
(input, hints, projects, fieldNames, variablesSet) -> {
Preconditions.checkArgument(variablesSet.isEmpty(),
"JdbcProject does not allow variables");
final RelOptCluster cluster = input.getCluster();
final RelDataType rowType =
RexUtil.createStructType(cluster.getTypeFactory(), projects,
Expand Down Expand Up @@ -510,6 +513,11 @@ private static boolean userDefinedFunctionInProject(Project project) {
return false;
}

@Override public boolean matches(RelOptRuleCall call) {
Project project = call.rel(0);
return project.getVariablesSet().isEmpty();
}

@Override public @Nullable RelNode convert(RelNode rel) {
final Project project = (Project) rel;

Expand All @@ -535,7 +543,7 @@ public JdbcProject(
RelNode input,
List<? extends RexNode> projects,
RelDataType rowType) {
super(cluster, traitSet, ImmutableList.of(), input, projects, rowType);
super(cluster, traitSet, ImmutableList.of(), input, projects, rowType, ImmutableSet.of());
assert getConvention() instanceof JdbcConvention;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;

import org.checkerframework.checker.nullness.qual.Nullable;
import org.immutables.value.Value;
Expand Down Expand Up @@ -386,6 +387,11 @@ protected BindableProjectRule(Config config) {
super(config);
}

@Override public boolean matches(RelOptRuleCall call) {
final LogicalProject project = call.rel(0);
return project.getVariablesSet().isEmpty();
}

@Override public RelNode convert(RelNode rel) {
final LogicalProject project = (LogicalProject) rel;
return new BindableProject(rel.getCluster(),
Expand All @@ -403,7 +409,7 @@ protected BindableProjectRule(Config config) {
public static class BindableProject extends Project implements BindableRel {
public BindableProject(RelOptCluster cluster, RelTraitSet traitSet,
RelNode input, List<? extends RexNode> projects, RelDataType rowType) {
super(cluster, traitSet, ImmutableList.of(), input, projects, rowType);
super(cluster, traitSet, ImmutableList.of(), input, projects, rowType, ImmutableSet.of());
assert getConvention() instanceof BindableConvention;
}

Expand Down
10 changes: 7 additions & 3 deletions core/src/main/java/org/apache/calcite/plan/RelOptUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -889,6 +889,7 @@ public static RelNode createCastRel(
List<RexNode> castExps;
RelNode input;
List<RelHint> hints = ImmutableList.of();
Set<CorrelationId> correlationVariables;
if (rel instanceof Project) {
// No need to create another project node if the rel
// is already a project.
Expand All @@ -899,21 +900,23 @@ public static RelNode createCastRel(
((Project) rel).getProjects());
input = rel.getInput(0);
hints = project.getHints();
correlationVariables = project.getVariablesSet();
} else {
castExps = RexUtil.generateCastExpressions(
rexBuilder,
castRowType,
rowType);
input = rel;
correlationVariables = ImmutableSet.of();
}
if (rename) {
// Use names and types from castRowType.
return projectFactory.createProject(input, hints, castExps,
castRowType.getFieldNames());
castRowType.getFieldNames(), correlationVariables);
} else {
// Use names from rowType, types from castRowType.
return projectFactory.createProject(input, hints, castExps,
rowType.getFieldNames());
rowType.getFieldNames(), correlationVariables);
}
}

Expand Down Expand Up @@ -3623,7 +3626,8 @@ public static RelNode projectMapping(
: fieldNames.get(i));
exprList.add(rexBuilder.makeInputRef(rel, source));
}
return projectFactory.createProject(rel, ImmutableList.of(), exprList, outputNameList);
return projectFactory.createProject(rel, ImmutableList.of(), exprList, outputNameList,
ImmutableSet.of());
}

/** Predicate for if a {@link Calc} does not contain windowed aggregates. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.apache.calcite.util.BuiltInMethod;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;

import java.lang.reflect.Type;
import java.util.ArrayList;
Expand Down Expand Up @@ -108,7 +109,8 @@ public RelNode translate(Expression expression) {
return LogicalProject.create(input,
ImmutableList.of(),
toRex(input, (FunctionExpression) call.expressions.get(0)),
(List<String>) null);
(List<String>) null,
ImmutableSet.of());

case WHERE:
input = translate(getTargetExpression(call));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import org.apache.calcite.schema.impl.AbstractTableQueryable;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;

import org.checkerframework.checker.nullness.qual.Nullable;
import org.checkerframework.checker.nullness.qual.PolyNull;
Expand Down Expand Up @@ -549,7 +550,8 @@ private void setRel(RelNode rel) {
RelNode child = toRel(source);
List<RexNode> nodes = translator.toRexList(selector, child);
setRel(
LogicalProject.create(child, ImmutableList.of(), nodes, (List<String>) null));
LogicalProject.create(child, ImmutableList.of(), nodes, (List<String>) null,
ImmutableSet.of()));
return castNonNull(null);
}

Expand Down
3 changes: 0 additions & 3 deletions core/src/main/java/org/apache/calcite/rel/RelNode.java
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,6 @@ public interface RelNode extends RelOptNode, Cloneable {
* expression but also used and therefore not available to parents of this
* relational expression.
*
* <p>Note: only {@link org.apache.calcite.rel.core.Correlate} should set
* variables.
*
* @return Names of variables which are set in this relational
* expression
*/
Expand Down
3 changes: 2 additions & 1 deletion core/src/main/java/org/apache/calcite/rel/RelRoot.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.apache.calcite.util.mapping.Mappings;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;

import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -166,7 +167,7 @@ public RelNode project(boolean force) {
for (Pair<Integer, String> field : fields) {
projects.add(rexBuilder.makeInputRef(rel, field.left));
}
return LogicalProject.create(rel, hints, projects, Pair.right(fields));
return LogicalProject.create(rel, hints, projects, Pair.right(fields), ImmutableSet.of());
}

public boolean isNameTrivial() {
Expand Down
39 changes: 35 additions & 4 deletions core/src/main/java/org/apache/calcite/rel/core/Project.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import org.apache.calcite.util.mapping.Mappings;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;

import org.apiguardian.api.API;
import org.checkerframework.checker.nullness.qual.EnsuresNonNullIf;
Expand All @@ -53,6 +54,7 @@
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

import static java.util.Objects.requireNonNull;
Expand All @@ -70,6 +72,8 @@ public abstract class Project extends SingleRel implements Hintable {

protected final ImmutableList<RelHint> hints;

protected final ImmutableSet<CorrelationId> variablesSet;

//~ Constructors -----------------------------------------------------------

/**
Expand All @@ -81,6 +85,8 @@ public abstract class Project extends SingleRel implements Hintable {
* @param input Input relational expression
* @param projects List of expressions for the input columns
* @param rowType Output row type
* @param variableSet Correlation variables set by this relational expression
* to be used by nested expressions
*/
@SuppressWarnings("method.invocation.invalid")
protected Project(
Expand All @@ -89,25 +95,38 @@ protected Project(
List<RelHint> hints,
RelNode input,
List<? extends RexNode> projects,
RelDataType rowType) {
RelDataType rowType,
Set<CorrelationId> variableSet) {
super(cluster, traits, input);
assert rowType != null;
this.exps = ImmutableList.copyOf(projects);
this.hints = ImmutableList.copyOf(hints);
this.rowType = rowType;
this.variablesSet = ImmutableSet.copyOf(variableSet);
assert isValid(Litmus.THROW, null);
}

@Deprecated // to be removed before 2.0
protected Project(
RelOptCluster cluster,
RelTraitSet traits,
List<RelHint> hints,
RelNode input,
List<? extends RexNode> projects,
RelDataType rowType) {
this(cluster, traits, hints, input, projects, rowType, ImmutableSet.of());
}

@Deprecated // to be removed before 2.0
protected Project(RelOptCluster cluster, RelTraitSet traits,
RelNode input, List<? extends RexNode> projects, RelDataType rowType) {
this(cluster, traits, ImmutableList.of(), input, projects, rowType);
this(cluster, traits, ImmutableList.of(), input, projects, rowType, ImmutableSet.of());
}

@Deprecated // to be removed before 2.0
protected Project(RelOptCluster cluster, RelTraitSet traitSet, RelNode input,
List<? extends RexNode> projects, RelDataType rowType, int flags) {
this(cluster, traitSet, ImmutableList.of(), input, projects, rowType);
this(cluster, traitSet, ImmutableList.of(), input, projects, rowType, ImmutableSet.of());
Util.discard(flags);
}

Expand All @@ -120,7 +139,14 @@ protected Project(RelInput input) {
ImmutableList.of(),
input.getInput(),
requireNonNull(input.getExpressionList("exprs"), "exprs"),
input.getRowType("exprs", "fields"));
input.getRowType("exprs", "fields"),
ImmutableSet.copyOf(
Util.transform(
Optional.ofNullable(input.getIntegerList("variablesSet"))
.orElse(ImmutableList.of()),
id -> new CorrelationId(id)
)
));
}

//~ Methods ----------------------------------------------------------------
Expand Down Expand Up @@ -264,8 +290,13 @@ private static int countTrivial(List<RexNode> refs) {
return refs.size();
}

@Override public Set<CorrelationId> getVariablesSet() {
return variablesSet;
}

@Override public RelWriter explainTerms(RelWriter pw) {
super.explainTerms(pw);
pw.itemIf("variablesSet", variablesSet, !variablesSet.isEmpty());
// Skip writing field names so the optimizer can reuse the projects that differ in
// field names only
if (pw.getDetailLevel() == SqlExplainLevel.DIGEST_ATTRIBUTES) {
Expand Down
Loading

0 comments on commit c2407f5

Please sign in to comment.