Skip to content

Commit 163a8a1

Browse files
committed
Make equi join clauses normalized according to join sides
1 parent 3a251cc commit 163a8a1

File tree

3 files changed

+22
-14
lines changed

3 files changed

+22
-14
lines changed

presto-main/src/main/java/io/prestosql/sql/planner/plan/JoinNode.java

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,16 +101,24 @@ public JoinNode(
101101
this.spillable = spillable;
102102
this.dynamicFilters = ImmutableMap.copyOf(requireNonNull(dynamicFilters, "dynamicFilters is null"));
103103

104+
Set<Symbol> leftSymbols = ImmutableSet.copyOf(left.getOutputSymbols());
105+
Set<Symbol> rightSymbols = ImmutableSet.copyOf(right.getOutputSymbols());
104106
Set<Symbol> inputSymbols = ImmutableSet.<Symbol>builder()
105-
.addAll(left.getOutputSymbols())
106-
.addAll(right.getOutputSymbols())
107+
.addAll(leftSymbols)
108+
.addAll(rightSymbols)
107109
.build();
108110
checkArgument(new HashSet<>(inputSymbols).containsAll(outputSymbols), "Left and right join inputs do not contain all output symbols");
109111
checkArgument(!isCrossJoin() || inputSymbols.size() == outputSymbols.size(), "Cross join does not support output symbols pruning or reordering");
110112

111113
checkArgument(!(criteria.isEmpty() && leftHashSymbol.isPresent()), "Left hash symbol is only valid in an equijoin");
112114
checkArgument(!(criteria.isEmpty() && rightHashSymbol.isPresent()), "Right hash symbol is only valid in an equijoin");
113115

116+
criteria.forEach(equiJoinClause ->
117+
checkArgument(
118+
leftSymbols.contains(equiJoinClause.getLeft()) &&
119+
rightSymbols.contains(equiJoinClause.getRight()),
120+
"Equality join criteria should be normalized according to join sides: %s", equiJoinClause));
121+
114122
if (distributionType.isPresent()) {
115123
// The implementation of full outer join only works if the data is hash partitioned.
116124
checkArgument(
@@ -127,7 +135,7 @@ public JoinNode(
127135
}
128136

129137
for (Symbol symbol : dynamicFilters.values()) {
130-
checkArgument(right.getOutputSymbols().contains(symbol), "Right join input doesn't contain symbol for dynamic filter: %s", symbol);
138+
checkArgument(rightSymbols.contains(symbol), "Right join input doesn't contain symbol for dynamic filter: %s", symbol);
131139
}
132140
}
133141

presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestEliminateCrossJoins.java

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ public void testJoinOrder()
118118
values("b")),
119119
values("c"),
120120
"a", "c",
121-
"c", "b");
121+
"b", "c");
122122

123123
JoinGraph joinGraph = JoinGraph.buildFrom(plan, noLookup(), new PlanNodeIdAllocator());
124124

@@ -137,7 +137,7 @@ public void testJoinOrderWithRealCrossJoin()
137137
values("b")),
138138
values("c"),
139139
"a", "c",
140-
"c", "b");
140+
"b", "c");
141141

142142
PlanNode rightPlan =
143143
joinNode(
@@ -146,7 +146,7 @@ public void testJoinOrderWithRealCrossJoin()
146146
values("y")),
147147
values("z"),
148148
"x", "z",
149-
"z", "y");
149+
"y", "z");
150150

151151
PlanNode plan = joinNode(leftPlan, rightPlan);
152152

@@ -167,8 +167,8 @@ public void testJoinOrderWithMultipleEdgesBetweenNodes()
167167
values("b1", "b2")),
168168
values("c1", "c2"),
169169
"a", "c1",
170-
"c1", "b1",
171-
"c2", "b2");
170+
"b1", "c1",
171+
"b2", "c2");
172172

173173
JoinGraph joinGraph = JoinGraph.buildFrom(plan, noLookup(), new PlanNodeIdAllocator());
174174

@@ -187,7 +187,7 @@ public void testDoesNotChangeOrderWithoutCrossJoin()
187187
values("b"),
188188
"a", "b"),
189189
values("c"),
190-
"c", "b");
190+
"b", "c");
191191

192192
JoinGraph joinGraph = JoinGraph.buildFrom(plan, noLookup(), new PlanNodeIdAllocator());
193193

@@ -205,7 +205,7 @@ public void testDoNotReorderCrossJoins()
205205
values("a"),
206206
values("b")),
207207
values("c"),
208-
"c", "b");
208+
"b", "c");
209209

210210
JoinGraph joinGraph = JoinGraph.buildFrom(plan, noLookup(), new PlanNodeIdAllocator());
211211

@@ -290,7 +290,7 @@ public void testGiveUpOnComplexProjections()
290290
new SymbolReference("b")),
291291
values("c"),
292292
"a2", "c",
293-
"c", "b");
293+
"b", "c");
294294

295295
assertEquals(JoinGraph.buildFrom(plan, noLookup(), new PlanNodeIdAllocator()).size(), 2);
296296
}
@@ -309,8 +309,8 @@ private Function<PlanBuilder, PlanNode> crossJoinAndJoin(JoinNode.Type secondJoi
309309
p.values(axSymbol),
310310
p.values(bySymbol)),
311311
p.values(cxSymbol, cySymbol),
312-
new EquiJoinClause(cxSymbol, axSymbol),
313-
new EquiJoinClause(cySymbol, bySymbol));
312+
new EquiJoinClause(axSymbol, cxSymbol),
313+
new EquiJoinClause(bySymbol, cySymbol));
314314
};
315315
}
316316

presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestReorderJoins.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ public void testReplicatedScalarJoinEvenWhereSessionRequiresRepartitioned()
284284
INNER,
285285
p.values(new PlanNodeId("valuesB"), ImmutableList.of(p.symbol("B1")), TWO_ROWS),
286286
p.values(new PlanNodeId("valuesA"), p.symbol("A1")), // matches isAtMostScalar
287-
ImmutableList.of(new EquiJoinClause(p.symbol("A1"), p.symbol("B1"))),
287+
ImmutableList.of(new EquiJoinClause(p.symbol("B1"), p.symbol("A1"))),
288288
ImmutableList.of(p.symbol("A1"), p.symbol("B1")),
289289
Optional.empty()))
290290
.overrideStats("valuesA", valuesA)

0 commit comments

Comments
 (0)