Skip to content

Commit

Permalink
Remove source location from CanonicalPlanGenerator
Browse files Browse the repository at this point in the history
  • Loading branch information
pranjalssh authored and highker committed Jun 17, 2022
1 parent da4f2b5 commit bf9cff7
Show file tree
Hide file tree
Showing 13 changed files with 196 additions and 20 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.expressions;

import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.InputReferenceExpression;
import com.facebook.presto.spi.relation.LambdaDefinitionExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;

/**
 * Rewriter that produces a canonical copy of a {@link RowExpression} tree.
 *
 * <p>Child expressions are rewritten first. When any child was replaced, the
 * parent is rebuilt around the new children with an empty source location;
 * otherwise the parent's own {@code canonicalize()} is used, which does not
 * touch its children.
 *
 * <p>The rewriter keeps no state, so a single shared instance is reused and
 * the only entry point is {@link #canonicalizeRowExpression(RowExpression)}.
 */
public class CanonicalRowExpressionRewriter
        extends RowExpressionRewriter<Void>
{
    private static final CanonicalRowExpressionRewriter SINGLETON = new CanonicalRowExpressionRewriter();

    // Not instantiable from outside; callers go through the static entry point.
    private CanonicalRowExpressionRewriter() {}

    /**
     * Rewrites {@code expression} with the shared rewriter instance.
     *
     * @param expression root of the expression tree to canonicalize
     * @return the result of the tree rewrite
     */
    public static RowExpression canonicalizeRowExpression(RowExpression expression)
    {
        return RowExpressionTreeRewriter.rewriteWith(SINGLETON, expression, null);
    }

    @Override
    public RowExpression rewriteInputReference(InputReferenceExpression input, Void context, RowExpressionTreeRewriter<Void> treeRewriter)
    {
        // Leaf node: nothing to recurse into.
        return input.canonicalize();
    }

    @Override
    public RowExpression rewriteCall(CallExpression call, Void context, RowExpressionTreeRewriter<Void> treeRewriter)
    {
        List<RowExpression> canonicalArguments = rewrite(call.getArguments(), context, treeRewriter);

        // No argument was replaced: only the call node itself may need canonicalizing.
        if (sameElements(call.getArguments(), canonicalArguments)) {
            return call.canonicalize();
        }
        return new CallExpression(
                Optional.empty(),
                call.getDisplayName(),
                call.getFunctionHandle(),
                call.getType(),
                canonicalArguments);
    }

    @Override
    public RowExpression rewriteConstant(ConstantExpression literal, Void context, RowExpressionTreeRewriter<Void> treeRewriter)
    {
        // Leaf node: nothing to recurse into.
        return literal.canonicalize();
    }

    @Override
    public RowExpression rewriteLambda(LambdaDefinitionExpression lambda, Void context, RowExpressionTreeRewriter<Void> treeRewriter)
    {
        RowExpression originalBody = lambda.getBody();
        RowExpression canonicalBody = treeRewriter.rewrite(originalBody, context);

        // Identity comparison: the same instance back means the body was left untouched.
        if (canonicalBody == originalBody) {
            return lambda.canonicalize();
        }
        return new LambdaDefinitionExpression(
                Optional.empty(),
                lambda.getArgumentTypes(),
                lambda.getArguments(),
                canonicalBody);
    }

    @Override
    public RowExpression rewriteVariableReference(VariableReferenceExpression variable, Void context, RowExpressionTreeRewriter<Void> treeRewriter)
    {
        // Leaf node: nothing to recurse into.
        return variable.canonicalize();
    }

    @Override
    public RowExpression rewriteSpecialForm(SpecialFormExpression specialForm, Void context, RowExpressionTreeRewriter<Void> treeRewriter)
    {
        List<RowExpression> canonicalArguments = rewrite(specialForm.getArguments(), context, treeRewriter);

        // No argument was replaced: only the special form node itself may need canonicalizing.
        if (sameElements(specialForm.getArguments(), canonicalArguments)) {
            return specialForm.canonicalize();
        }
        return new SpecialFormExpression(
                Optional.empty(),
                specialForm.getForm(),
                specialForm.getType(),
                canonicalArguments);
    }

    /**
     * Rewrites each expression in {@code items}, in order, through {@code treeRewriter}.
     *
     * @return an unmodifiable list holding the rewritten expressions in the same order
     */
    private List<RowExpression> rewrite(List<RowExpression> items, Void context, RowExpressionTreeRewriter<Void> treeRewriter)
    {
        List<RowExpression> result = new ArrayList<>();
        items.forEach(item -> result.add(treeRewriter.rewrite(item, context)));
        return Collections.unmodifiableList(result);
    }

    /**
     * Checks whether two collections hold the very same instances in the same
     * iteration order. Elements are compared by reference, not by {@code equals}.
     *
     * @return {@code true} when sizes match and every pair of elements is the same object
     */
    @SuppressWarnings("ObjectEquality")
    private static <T> boolean sameElements(Collection<? extends T> a, Collection<? extends T> b)
    {
        if (a.size() != b.size()) {
            return false;
        }

        Iterator<? extends T> left = a.iterator();
        Iterator<? extends T> right = b.iterator();
        for (; left.hasNext() && right.hasNext(); ) {
            T fromLeft = left.next();
            T fromRight = right.next();
            if (fromLeft != fromRight) {
                return false;
            }
        }

        return true;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import java.util.Optional;
import java.util.Set;

import static com.facebook.presto.expressions.CanonicalRowExpressionRewriter.canonicalizeRowExpression;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static java.util.Objects.requireNonNull;
Expand Down Expand Up @@ -276,7 +277,7 @@ public Object getIdentifier(Optional<ConnectorSplit> split)
return ImmutableMap.builder()
.put("schemaTableName", schemaTableName)
.put("domainPredicate", domainPredicate)
.put("remainingPredicate", remainingPredicate)
.put("remainingPredicate", canonicalizeRowExpression(remainingPredicate))
.put("bucketFilter", bucketFilter)
.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ private void assertSameCanonicalLeafSubPlan(Session session, String sql2, String
.map(Optional::get)
.collect(Collectors.toList());
assertEquals(leafCanonicalPlans.size(), 2);
assertEquals(objectMapper.writeValueAsString(leafCanonicalPlans.get(0)).replaceAll("\"sourceLocation\":\\{[^\\}]*\\}", ""), objectMapper.writeValueAsString(leafCanonicalPlans.get(1)).replaceAll("\"sourceLocation\":\\{[^\\}]*\\}", ""));
assertEquals(objectMapper.writeValueAsString(leafCanonicalPlans.get(0)), objectMapper.writeValueAsString(leafCanonicalPlans.get(1)));
}

private void assertDifferentCanonicalLeafSubPlan(Session session, String sql1, String sql2)
Expand All @@ -195,6 +195,6 @@ private void assertDifferentCanonicalLeafSubPlan(Session session, String sql1, S
Optional<CanonicalPlanFragment> canonicalPlan2 = generateCanonicalPlan(fragment2.getRoot(), fragment2.getPartitioningScheme());
assertTrue(canonicalPlan1.isPresent());
assertTrue(canonicalPlan2.isPresent());
assertNotEquals(objectMapper.writeValueAsString(canonicalPlan1).replaceAll("\"sourceLocation\":\\{[^\\}]*\\}", ""), objectMapper.writeValueAsString(canonicalPlan2).replaceAll("\"sourceLocation\":\\{[^\\}]*\\}", ""));
assertNotEquals(objectMapper.writeValueAsString(canonicalPlan1), objectMapper.writeValueAsString(canonicalPlan2));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import java.util.Optional;

import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.expressions.CanonicalRowExpressionRewriter.canonicalizeRowExpression;
import static com.facebook.presto.sql.planner.CanonicalPartitioningScheme.getCanonicalPartitioningScheme;
import static com.facebook.presto.sql.planner.CanonicalTableScanNode.CanonicalTableHandle.getCanonicalTableHandle;
import static com.facebook.presto.sql.planner.RowExpressionVariableInliner.inlineVariables;
Expand Down Expand Up @@ -95,7 +96,7 @@ public Optional<PlanNode> visitAggregation(AggregationNode node, Map<VariableRef
}

return Optional.of(new AggregationNode(
node.getSourceLocation(),
Optional.empty(),
planNodeidAllocator.getNextId(),
source.get(),
aggregations.build(),
Expand All @@ -111,8 +112,8 @@ public Optional<PlanNode> visitAggregation(AggregationNode node, Map<VariableRef
private static Aggregation getCanonicalAggregation(Aggregation aggregation, Map<VariableReferenceExpression, VariableReferenceExpression> context)
{
return new Aggregation(
(CallExpression) inlineVariables(context, aggregation.getCall()),
aggregation.getFilter().map(filter -> inlineVariables(context, filter)),
(CallExpression) inlineAndCanonicalize(context, aggregation.getCall()),
aggregation.getFilter().map(filter -> inlineAndCanonicalize(context, filter)),
aggregation.getOrderBy().map(orderBy -> getCanonicalOrderingScheme(orderBy, context)),
aggregation.isDistinct(),
aggregation.getMask().map(context::get));
Expand Down Expand Up @@ -185,7 +186,7 @@ public Optional<PlanNode> visitGroupId(GroupIdNode node, Map<VariableReferenceEx
context.put(node.getGroupIdVariable(), groupId);

return Optional.of(new GroupIdNode(
source.get().getSourceLocation(),
Optional.empty(),
planNodeidAllocator.getNextId(),
source.get(),
groupingSets.build(),
Expand All @@ -207,10 +208,10 @@ public Optional<PlanNode> visitUnnest(UnnestNode node, Map<VariableReferenceExpr
// Generate canonical unnestVariables.
ImmutableMap.Builder<VariableReferenceExpression, List<VariableReferenceExpression>> newUnnestVariables = ImmutableMap.builder();
for (Map.Entry<VariableReferenceExpression, List<VariableReferenceExpression>> unnestVariable : node.getUnnestVariables().entrySet()) {
VariableReferenceExpression input = (VariableReferenceExpression) inlineVariables(context, unnestVariable.getKey());
VariableReferenceExpression input = (VariableReferenceExpression) inlineAndCanonicalize(context, unnestVariable.getKey());
ImmutableList.Builder<VariableReferenceExpression> newVariables = ImmutableList.builder();
for (VariableReferenceExpression variable : unnestVariable.getValue()) {
VariableReferenceExpression newVariable = variableAllocator.newVariable(variable.getSourceLocation(), "unnest_field", variable.getType());
VariableReferenceExpression newVariable = variableAllocator.newVariable(Optional.empty(), "unnest_field", variable.getType());
context.put(variable, newVariable);
newVariables.add(newVariable);
}
Expand All @@ -220,17 +221,17 @@ public Optional<PlanNode> visitUnnest(UnnestNode node, Map<VariableReferenceExpr
// Generate canonical ordinality variable
Optional<VariableReferenceExpression> ordinalityVariable = node.getOrdinalityVariable()
.map(variable -> {
VariableReferenceExpression newVariable = variableAllocator.newVariable(variable.getSourceLocation(), "unnest_ordinality", variable.getType());
VariableReferenceExpression newVariable = variableAllocator.newVariable(Optional.empty(), "unnest_ordinality", variable.getType());
context.put(variable, newVariable);
return newVariable;
});

return Optional.of(new UnnestNode(
node.getSourceLocation(),
Optional.empty(),
planNodeidAllocator.getNextId(),
source.get(),
node.getReplicateVariables().stream()
.map(variable -> (VariableReferenceExpression) inlineVariables(context, variable))
.map(variable -> (VariableReferenceExpression) inlineAndCanonicalize(context, variable))
.collect(toImmutableList()),
newUnnestVariables.build(),
ordinalityVariable));
Expand All @@ -245,7 +246,7 @@ public Optional<PlanNode> visitProject(ProjectNode node, Map<VariableReferenceEx
}

List<RowExpressionReference> rowExpressionReferences = node.getAssignments().entrySet().stream()
.map(entry -> new RowExpressionReference(inlineVariables(context, entry.getValue()), entry.getKey()))
.map(entry -> new RowExpressionReference(inlineAndCanonicalize(context, entry.getValue()), entry.getKey()))
.sorted(comparing(rowExpressionReference -> rowExpressionReference.getRowExpression().toString()))
.collect(toImmutableList());
ImmutableMap.Builder<VariableReferenceExpression, RowExpression> assignments = ImmutableMap.builder();
Expand All @@ -256,7 +257,7 @@ public Optional<PlanNode> visitProject(ProjectNode node, Map<VariableReferenceEx
}

return Optional.of(new ProjectNode(
node.getSourceLocation(),
Optional.empty(),
planNodeidAllocator.getNextId(),
source.get(),
new Assignments(assignments.build()),
Expand Down Expand Up @@ -290,10 +291,10 @@ public Optional<PlanNode> visitFilter(FilterNode node, Map<VariableReferenceExpr
{
Optional<PlanNode> source = node.getSource().accept(this, context);
return source.map(planNode -> new FilterNode(
node.getSourceLocation(),
Optional.empty(),
planNodeidAllocator.getNextId(),
planNode,
inlineVariables(context, node.getPredicate())));
inlineAndCanonicalize(context, node.getPredicate())));
}

@Override
Expand All @@ -306,20 +307,25 @@ public Optional<PlanNode> visitTableScan(TableScanNode node, Map<VariableReferen
ImmutableList.Builder<VariableReferenceExpression> outputVariables = ImmutableList.builder();
ImmutableMap.Builder<VariableReferenceExpression, ColumnHandle> assignments = ImmutableMap.builder();
for (ColumnReference columnReference : columnReferences) {
VariableReferenceExpression reference = variableAllocator.newVariable(columnReference.getVariableReferenceExpression().getSourceLocation(), columnReference.getColumnHandle().toString(), columnReference.getVariableReferenceExpression().getType());
VariableReferenceExpression reference = variableAllocator.newVariable(Optional.empty(), columnReference.getColumnHandle().toString(), columnReference.getVariableReferenceExpression().getType());
context.put(columnReference.getVariableReferenceExpression(), reference);
outputVariables.add(reference);
assignments.put(reference, columnReference.getColumnHandle());
}

return Optional.of(new CanonicalTableScanNode(
node.getSourceLocation(),
Optional.empty(),
planNodeidAllocator.getNextId(),
getCanonicalTableHandle(node.getTable()),
outputVariables.build(),
assignments.build()));
}

/**
 * Canonicalizes {@code expression} and then inlines variables in the result,
 * looking each variable up in {@code context}.
 */
private static RowExpression inlineAndCanonicalize(Map<VariableReferenceExpression, VariableReferenceExpression> context, RowExpression expression)
{
    // Canonicalize first; the lookup into context then runs over the canonical tree.
    RowExpression canonicalExpression = canonicalizeRowExpression(expression);
    return inlineVariables(context::get, canonicalExpression);
}

private static class ColumnReference
{
private final ColumnHandle columnHandle;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,5 +132,11 @@ public <R, C> R accept(RowExpressionVisitor<R, C> visitor, C context)
{
throw new UnsupportedOperationException("OriginalExpression cannot appear in a RowExpression tree");
}

/**
* Returns this instance unchanged; no copy is created.
*/
@Override
public RowExpression canonicalize()
{
return this;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ private void assertSameCanonicalLeafSubPlan(String sql1, String sql2)
.collect(Collectors.toList());
assertEquals(leafCanonicalPlans.size(), 2);
assertEquals(leafCanonicalPlans.get(0), leafCanonicalPlans.get(1));
assertEquals(objectMapper.writeValueAsString(leafCanonicalPlans.get(0)).replaceAll("\"sourceLocation\":\\{[^\\}]*\\}", ""), objectMapper.writeValueAsString(leafCanonicalPlans.get(1)).replaceAll("\"sourceLocation\":\\{[^\\}]*\\}", ""));
assertEquals(objectMapper.writeValueAsString(leafCanonicalPlans.get(0)), objectMapper.writeValueAsString(leafCanonicalPlans.get(1)));
}

private void assertDifferentCanonicalLeafSubPlan(String sql1, String sql2)
Expand All @@ -217,7 +217,7 @@ private void assertDifferentCanonicalLeafSubPlan(String sql1, String sql2)
Optional<CanonicalPlanFragment> canonicalPlan2 = generateCanonicalPlan(fragment2.getRoot(), fragment2.getPartitioningScheme());
assertTrue(canonicalPlan1.isPresent());
assertTrue(canonicalPlan2.isPresent());
assertNotEquals(objectMapper.writeValueAsString(canonicalPlan1).replaceAll("\"sourceLocation\":\\{[^\\}]*\\}", ""), objectMapper.writeValueAsString(canonicalPlan2).replaceAll("\"sourceLocation\":\\{[^\\}]*\\}", ""));
assertNotEquals(objectMapper.writeValueAsString(canonicalPlan1), objectMapper.writeValueAsString(canonicalPlan2));
}

// We add the following field test to make sure corresponding canonical class is still correct.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,4 +124,10 @@ public <R, C> R accept(RowExpressionVisitor<R, C> visitor, C context)
{
return visitor.visitCall(this, context);
}

/**
 * Returns this call without a source location. The arguments are passed
 * through as-is and are not canonicalized here.
 */
@Override
public RowExpression canonicalize()
{
    // No source location present: reuse this node instead of allocating a copy.
    if (!getSourceLocation().isPresent()) {
        return this;
    }
    return new CallExpression(Optional.empty(), displayName, functionHandle, returnType, arguments);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -114,4 +114,10 @@ public <R, C> R accept(RowExpressionVisitor<R, C> visitor, C context)
{
return visitor.visitConstant(this, context);
}

/**
 * Returns this constant without a source location; value and type are kept.
 */
@Override
public RowExpression canonicalize()
{
    // No source location present: reuse this node instead of allocating a copy.
    if (!getSourceLocation().isPresent()) {
        return this;
    }
    return new ConstantExpression(Optional.empty(), value, type);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,12 @@ public <R, C> R accept(RowExpressionVisitor<R, C> visitor, C context)
return visitor.visitInputReference(this, context);
}

/**
 * Returns this input reference without a source location; field and type are kept.
 */
@Override
public RowExpression canonicalize()
{
    // No source location present: reuse this node instead of allocating a copy.
    if (!getSourceLocation().isPresent()) {
        return this;
    }
    return new InputReferenceExpression(Optional.empty(), field, type);
}

@Override
public boolean equals(Object obj)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,12 @@ public <R, C> R accept(RowExpressionVisitor<R, C> visitor, C context)
return visitor.visitLambda(this, context);
}

/**
 * Returns this lambda without a source location. The body is passed through
 * as-is and is not canonicalized here.
 */
@Override
public RowExpression canonicalize()
{
    // No source location present: reuse this node instead of allocating a copy.
    if (!getSourceLocation().isPresent()) {
        return this;
    }
    return new LambdaDefinitionExpression(Optional.empty(), argumentTypes, arguments, body);
}

private static void checkArgument(boolean condition, String message, Object... messageArgs)
{
if (!condition) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,10 @@ public Optional<SourceLocation> getSourceLocation()
public abstract String toString();

public abstract <R, C> R accept(RowExpressionVisitor<R, C> visitor, C context);

/**
* Returns the canonical form of this node only: non-critical information
* carried by the node itself, such as its source location, is removed.
* Child expressions are NOT canonicalized; callers that need a fully
* canonical tree must canonicalize the children themselves.
*
* @return the canonical form of this RowExpression node
*/
public abstract RowExpression canonicalize();
}
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,12 @@ public <R, C> R accept(RowExpressionVisitor<R, C> visitor, C context)
return visitor.visitSpecialForm(this, context);
}

/**
 * Returns this special form without a source location. The arguments are
 * passed through as-is and are not canonicalized here.
 */
@Override
public RowExpression canonicalize()
{
    // No source location present: reuse this node instead of allocating a copy.
    if (!getSourceLocation().isPresent()) {
        return this;
    }
    return new SpecialFormExpression(Optional.empty(), form, returnType, arguments);
}

public enum Form
{
IF,
Expand Down
Loading

0 comments on commit bf9cff7

Please sign in to comment.