Skip to content

Commit 98350e2

Browse files
nits and refactor
1 parent ecb5608 commit 98350e2

File tree

4 files changed

+49
-40
lines changed

4 files changed

+49
-40
lines changed

sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownJoin.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,13 @@
3232
*/
3333
@Evolving
3434
public interface SupportsPushDownJoin extends ScanBuilder {
35-
boolean isRightSideCompatibleForJoin(SupportsPushDownJoin other);
35+
boolean isRightSideCompatibleForJoin(SupportsPushDownJoin other);
3636

37-
boolean pushJoin(
38-
SupportsPushDownJoin other,
39-
JoinType joinType,
40-
Optional<Predicate> condition,
41-
StructType leftRequiredSchema,
42-
StructType rightRequiredSchema
37+
boolean pushJoin(
38+
SupportsPushDownJoin other,
39+
JoinType joinType,
40+
Optional<Predicate> condition,
41+
StructType leftRequiredSchema,
42+
StructType rightRequiredSchema
4343
);
4444
}

sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/JoinTypeSQLBuilder.java

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,21 +29,21 @@
2929
* @since 4.0.0
3030
*/
3131
public class JoinTypeSQLBuilder {
32-
public String build(JoinType joinType) {
33-
if (joinType instanceof Inner inner) {
34-
return visitInnerJoin(inner);
35-
} else {
36-
return visitUnexpectedJoinType(joinType);
37-
}
32+
public String build(JoinType joinType) {
33+
if (joinType instanceof Inner inner) {
34+
return visitInnerJoin(inner);
35+
} else {
36+
return visitUnexpectedJoinType(joinType);
3837
}
38+
}
3939

40-
protected String visitInnerJoin(Inner inner) {
40+
protected String visitInnerJoin(Inner inner) {
4141
return "INNER JOIN";
4242
}
4343

44-
protected String visitUnexpectedJoinType(JoinType joinType) throws IllegalArgumentException {
45-
Map<String, String> params = new HashMap<>();
46-
params.put("joinType", String.valueOf(joinType));
47-
throw new SparkIllegalArgumentException("_LEGACY_ERROR_TEMP_3209", params);
48-
}
44+
protected String visitUnexpectedJoinType(JoinType joinType) throws IllegalArgumentException {
45+
Map<String, String> params = new HashMap<>();
46+
params.put("joinType", String.valueOf(joinType));
47+
throw new SparkIllegalArgumentException("_LEGACY_ERROR_TEMP_3209", params);
48+
}
4949
}

sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ public String build(Expression expr) {
7878
if (expr instanceof Literal literal) {
7979
return visitLiteral(literal);
8080
} else if (expr instanceof JoinColumn column) {
81-
return visitJoinColumn(column);
81+
return visitJoinColumn(column);
8282
} else if (expr instanceof NamedReference namedReference) {
8383
return visitNamedReference(namedReference);
8484
} else if (expr instanceof Cast cast) {

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -111,20 +111,32 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
111111
def pushDownJoin(plan: LogicalPlan): LogicalPlan = plan.transformUp {
112112
// Join can be attempted to be pushed down only if left and right side of join are
113113
// compatible (same data source, for example). Also, another requirement is that if
114-
// there are projections between Join and ScanBuilderHolder, these projec
114+
// there are projections between Join and ScanBuilderHolder, these projections need to be
115+
// AttributeReferences. We could probably support Alias as well, but this should be on
116+
// TODO list.
117+
// Alias can exist between Join and sHolder node because the query below is not valid:
118+
// SELECT * FROM
119+
// (SELECT * FROM tbl t1 JOIN tbl2 t2) p
120+
// JOIN
121+
// (SELECT * FROM tbl t3 JOIN tbl3 t4) q
122+
// ON p.t1.col = q.t3.col (this is not possible)
123+
// It's because the same table is on both sides of the top-level join, so it is not possible
124+
// to fully qualify the column names in the condition. Hence the query should be rewritten so
125+
// that each of the outputs of child joins are aliased, so there would be a projection
126+
// with aliases between top level join and scanBuilderHolder (that has pushed child joins).
115127
case node @ Join(
116-
PhysicalOperation(
117-
leftProjections,
118-
Nil,
119-
leftHolder @ ScanBuilderHolder(_, _, lBuilder: SupportsPushDownJoin)
120-
),
121-
PhysicalOperation(
122-
rightProjections,
123-
Nil,
124-
rightHolder @ ScanBuilderHolder(_, _, rBuilder: SupportsPushDownJoin)
125-
),
126-
joinType,
127-
condition,
128+
PhysicalOperation(
129+
leftProjections,
130+
Nil,
131+
leftHolder @ ScanBuilderHolder(_, _, lBuilder: SupportsPushDownJoin)
132+
),
133+
PhysicalOperation(
134+
rightProjections,
135+
Nil,
136+
rightHolder @ ScanBuilderHolder(_, _, rBuilder: SupportsPushDownJoin)
137+
),
138+
joinType,
139+
condition,
128140
_) if conf.dataSourceV2JoinPushdown &&
129141
// TODO: I think projections will always be Seq[AttributeReference] because
130142
// When
@@ -184,7 +196,6 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
184196

185197
leftHolder.exprIdToOriginalName ++= rightHolder.exprIdToOriginalName
186198
leftHolder.output = newOutput
187-
leftHolder.isJoinPushed = true
188199
leftHolder
189200
} else {
190201
node
@@ -206,15 +217,15 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
206217

207218
val aggExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int]
208219
val aggregates = collectAggregates(actualResultExprs, aggExprToOutputOrdinal)
209-
val normalizedAggExprs = if (holder.isJoinPushed) {
220+
val normalizedAggExprs = if (holder.joinedRelations.nonEmpty) {
210221
DataSourceStrategy.normalizeExprs(aggregates, holder.output)
211222
.asInstanceOf[Seq[AggregateExpression]]
212223
} else {
213224
DataSourceStrategy.normalizeExprs(aggregates, holder.relation.output)
214225
.asInstanceOf[Seq[AggregateExpression]]
215226
}
216227
val normalizedGroupingExpr =
217-
if (holder.isJoinPushed) {
228+
if (holder.joinedRelations.nonEmpty) {
218229
DataSourceStrategy.normalizeExprs(actualGroupExprs, holder.output)
219230
} else {
220231
DataSourceStrategy.normalizeExprs(actualGroupExprs, holder.relation.output)
@@ -459,7 +470,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
459470
}
460471

461472
def buildScanWithPushedJoin(plan: LogicalPlan): LogicalPlan = plan.transform {
462-
case holder: ScanBuilderHolder if holder.isJoinPushed && !holder.isStreaming =>
473+
case holder: ScanBuilderHolder if holder.joinedRelations.nonEmpty && !holder.isStreaming =>
463474
val scan = holder.builder.build()
464475
val realOutput = toAttributes(scan.readSchema())
465476
assert(realOutput.length == holder.output.length,
@@ -563,7 +574,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
563574
} else {
564575
aliasReplacedOrder.asInstanceOf[Seq[SortOrder]]
565576
}
566-
val normalizedOrders = if (sHolder.isJoinPushed) {
577+
val normalizedOrders = if (sHolder.joinedRelations.nonEmpty) {
567578
DataSourceStrategy.normalizeExprs(
568579
newOrder, sHolder.output).asInstanceOf[Seq[SortOrder]]
569580
} else {
@@ -704,8 +715,6 @@ case class ScanBuilderHolder(
704715

705716
var joinedRelations: Seq[DataSourceV2RelationBase] = Seq()
706717

707-
var isJoinPushed: Boolean = false
708-
709718
var exprIdToOriginalName: scala.collection.mutable.Map[ExprId, String] =
710719
scala.collection.mutable.Map.empty[ExprId, String]
711720
}

0 commit comments

Comments
 (0)