Closed
Changes from all commits
30 commits
f3dd17a
SPARK-47217. bug fix for exception thrown in reused dataframes involv…
Feb 29, 2024
c29366f
SPARK-47217. fix test failures
Mar 1, 2024
31d66c2
SPARK-47217. fix style format issue
Mar 1, 2024
127016c
SPARK-47217 : Fixing tests and code to try and resolve ambiguity in s…
Mar 6, 2024
b8e369c
SPARK-47217 : Fix unused import issue
Mar 6, 2024
872fece
SPARK-47217 : fixed bug and made assertions in existing tests for cor…
Mar 6, 2024
4b13514
Merge branch 'master' into SPARK-47217
Mar 6, 2024
8ed6aa4
SPARK-47217 : added more assetions
Mar 7, 2024
6b3b1d4
SPARK-47217 : fixed a bug and uncommented tests which were inadverten…
Mar 7, 2024
f9653ec
SPARK-47217 : added more tests and fixed inconsistency
Mar 7, 2024
68c2ad3
Merge branch 'master' into SPARK-47217
Mar 8, 2024
7150c98
SPARK-47217 : fixed test failure
Mar 8, 2024
6e03ae0
SPARK-47320 : incorporate review comments
Mar 11, 2024
d2175d4
SPARK-47320: reverting earlier change and refactoring
Mar 14, 2024
f3d280e
SPARK-47320: refactoring
Mar 14, 2024
ab14974
SPARK-47320: fixed test failure
Mar 14, 2024
bf1cd92
SPARK-47320: fixed test failure
Mar 15, 2024
611847e
SPARK-47320: refcatored code
Mar 15, 2024
4501ae5
SPARK-47320: removed dead code
Mar 15, 2024
3b8383d
SPARK-47320: refactored the code to remove UnresolvedAttributeWithTag…
Mar 16, 2024
f78eaaa
SPARK-47320: removed dead code
Mar 16, 2024
11bc231
SPARK-47320: fixed pyspark failures
Mar 16, 2024
01a4074
Merge branch 'apache:master' into SPARK-47320
ahshahid Mar 18, 2024
ceb98f5
Merge branch 'master' into SPARK-47320
Mar 18, 2024
3619857
Merge branch 'SPARK-47320' of https://github.com/ahshahid/spark into …
Mar 18, 2024
03149d5
SPARK-47320. Modified the code to ensure that for unambiguous attribu…
Mar 29, 2024
77b3201
SPARK-47320. cleaned up the code, made the addition of DataSet_ID_Tag…
Mar 29, 2024
5bdd4fd
Merge branch 'apache:master' into SPARK-47320
ahshahid Apr 5, 2024
525236d
Merge branch 'apache:master' into SPARK-47320
ahshahid Apr 6, 2024
fc0b3b1
Merge branch 'master' into SPARK-47320
Apr 16, 2024
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.util.toPrettySQL
import org.apache.spark.sql.connector.catalog.{CatalogManager, Identifier}
import org.apache.spark.sql.errors.{DataTypeErrorsBase, QueryCompilationErrors}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.MetadataBuilder

trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {

@@ -518,26 +519,126 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
case _ => e
}

private def stripColumnReferenceMetadata(a: AttributeReference): AttributeReference = {
val metadataWithoutId = new MetadataBuilder()
.withMetadata(a.metadata)
.remove(LogicalPlan.DATASET_ID_KEY)
.remove(LogicalPlan.COL_POS_KEY)
.build()
a.withMetadata(metadataWithoutId)
}
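// Illustrative sketch (not part of this change): the entries removed here are the ones that
// Dataset#addDataFrameIdToCol attaches to every column it hands out; `dsId` and `position`
// below are hypothetical values.
//   val tagged = attr.withMetadata(new MetadataBuilder()
//     .withMetadata(attr.metadata)
//     .putLong(LogicalPlan.DATASET_ID_KEY, dsId)
//     .putLong(LogicalPlan.COL_POS_KEY, position)
//     .build())
//   stripColumnReferenceMetadata(tagged) // yields attr without either key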

private def resolveUsingDatasetId(
ua: UnresolvedAttribute,
left: LogicalPlan,
right: LogicalPlan,
datasetId: Long): Option[NamedExpression] = {
def findUnaryNodeMatchingTagId(lp: LogicalPlan): Option[(LogicalPlan, Int)] = {
var currentLp = lp
var depth = 0
while (true) {
if (currentLp.getTagValue(LogicalPlan.DATASET_ID_TAG).exists(
_.contains(datasetId))) {
return Some((currentLp, depth))
} else {
if (currentLp.children.size == 1) {
currentLp = currentLp.children.head
} else {
// leaf node, or a node with more than one child (e.g. a binary node)
return None
}
}
depth += 1
}
None
}

val leftDefOpt = findUnaryNodeMatchingTagId(left)
val rightDefOpt = findUnaryNodeMatchingTagId(right)
val resolveOnAttribs = (leftDefOpt, rightDefOpt) match {

case (None, Some((lp, _))) => lp.output

case (Some((lp, _)), None) => lp.output

case (Some((lp1, depth1)), Some((lp2, depth2))) =>
  if (depth1 == depth2) {
    Seq.empty
  } else if (depth1 < depth2) {
    lp1.output
  } else {
    lp2.output
  }

case _ => Seq.empty
}
if (resolveOnAttribs.isEmpty) {
None
} else {
AttributeSeq.fromNormalOutput(resolveOnAttribs).resolve(Seq(ua.name), conf.resolver)
}
}

private def resolveDataFrameColumn(
u: UnresolvedAttribute,
q: Seq[LogicalPlan]): Option[NamedExpression] = {
val planIdOpt = u.getTagValue(LogicalPlan.PLAN_ID_TAG)
if (planIdOpt.isEmpty) return None
val planId = planIdOpt.get
logDebug(s"Extract plan_id $planId from $u")

val isMetadataAccess = u.getTagValue(LogicalPlan.IS_METADATA_COL).nonEmpty

val (resolved, matched) = resolveDataFrameColumnByPlanId(
u, planId, isMetadataAccess, q, 0)
if (!matched) {
// Cannot find the target plan node with plan id, e.g.
// df1 = spark.createDataFrame([Row(a = 1, b = 2, c = 3)])
// df2 = spark.createDataFrame([Row(a = 1, b = 2)])
// df1.select(df2.a) <- illegal reference df2.a
throw QueryCompilationErrors.cannotResolveDataFrameColumn(u)
val origAttrOpt = u.getTagValue(LogicalPlan.UNRESOLVED_ATTRIBUTE_MD_TAG)
val resolvedOptWithDatasetId = if (origAttrOpt.isDefined) {
val md = origAttrOpt.get.metadata
if (md.contains(LogicalPlan.DATASET_ID_KEY)) {
val did = md.getLong(LogicalPlan.DATASET_ID_KEY)
val resolved = if (q.size == 1) {
val binaryNodeOpt = q.head.collectFirst {
case bn: BinaryNode => bn
}
binaryNodeOpt.flatMap(bn => resolveUsingDatasetId(u, bn.left, bn.right, did))
} else if (q.size == 2) {
resolveUsingDatasetId(u, q(0), q(1), did)
} else {
None
}
if (resolved.isEmpty) {
if (conf.getConf(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED)) {
origAttrOpt
} else {
origAttrOpt.map(stripColumnReferenceMetadata)
}
} else {
resolved
}
} else {
origAttrOpt
}
} else {
None
}
val resolvedOpt = if (resolvedOptWithDatasetId.isDefined) {
resolvedOptWithDatasetId
} else {
val planIdOpt = u.getTagValue(LogicalPlan.PLAN_ID_TAG)
if (planIdOpt.isEmpty) {
None
} else {
val planId = planIdOpt.get
logDebug(s"Extract plan_id $planId from $u")

val isMetadataAccess = u.getTagValue(LogicalPlan.IS_METADATA_COL).nonEmpty

val (resolved, matched) = resolveDataFrameColumnByPlanId(
u, planId, isMetadataAccess, q, 0)
if (!matched) {
// Cannot find the target plan node with plan id, e.g.
// df1 = spark.createDataFrame([Row(a = 1, b = 2, c = 3)])
// df2 = spark.createDataFrame([Row(a = 1, b = 2)])
// df1.select(df2.a) <- illegal reference df2.a
throw QueryCompilationErrors.cannotResolveDataFrameColumn(u)
}
resolved.map(_._1)
}
}
resolved.map(_._1)
resolvedOpt
}

private def resolveDataFrameColumnByPlanId(
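A hedged, user-level sketch of the resolution path added above: when a column carrying dataset-id metadata cannot be trusted after deduplication, resolveUsingDatasetId walks each join leg down through unary nodes until it finds one whose DATASET_ID_TAG contains the column's originating dataset id, then resolves the name against that node's output. The snippet below is illustrative only (it is not part of the diff and assumes a local SparkSession); it shows the kind of reused-DataFrame join that exercises this path.

import org.apache.spark.sql.SparkSession

object DatasetIdResolutionSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("dataset-id-sketch").getOrCreate()
    import spark.implicits._

    val df1 = Seq((1, "x"), (2, "y")).toDF("a", "b")
    val df2 = df1.filter($"a" > 1) // derived from df1, so the join below is a self-join

    // df1.col("a") and df2.col("a") carry __dataset_id metadata identifying the Dataset they
    // were taken from; the dataset-id based resolution uses that id to pick the correct leg
    // instead of relying only on expression ids that change during deduplication.
    val joined = df1.join(df2, df1.col("a") === df2.col("a"))
    joined.explain(true)

    spark.stop()
  }
}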
@@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.plans.logical

import scala.collection.mutable

import org.apache.spark.SparkUnsupportedOperationException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
@@ -30,7 +32,6 @@ import org.apache.spark.sql.catalyst.util.MetadataColumnHelper
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.types.StructType


abstract class LogicalPlan
extends QueryPlan[LogicalPlan]
with AnalysisHelper
@@ -199,6 +200,10 @@ object LogicalPlan {
// to the old code path.
private[spark] val PLAN_ID_TAG = TreeNodeTag[Long]("plan_id")
private[spark] val IS_METADATA_COL = TreeNodeTag[Unit]("is_metadata_col")
private[spark] val DATASET_ID_TAG = TreeNodeTag[mutable.HashSet[Long]]("dataset_id")
private[spark] val UNRESOLVED_ATTRIBUTE_MD_TAG = TreeNodeTag[AttributeReference]("orig-attr")
private[spark] val DATASET_ID_KEY = "__dataset_id"
private[spark] val COL_POS_KEY = "__col_position"
}

/**
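For orientation, a hedged sketch of how the tags declared above are written and read on a plan node; it mirrors what Dataset does in the next file and is not part of the diff. Because the tags are private[spark], a helper like this would have to live under the org.apache.spark package.

package org.apache.spark.sql.catalyst.plans.logical

import scala.collection.mutable

object DatasetIdTagSketch {
  // Record that the Dataset with the given id wraps this plan node (mirrors Dataset#logicalPlan).
  def recordDatasetId(plan: LogicalPlan, id: Long): Unit = {
    val dsIds = plan.getTagValue(LogicalPlan.DATASET_ID_TAG).getOrElse(new mutable.HashSet[Long])
    dsIds.add(id)
    plan.setTagValue(LogicalPlan.DATASET_ID_TAG, dsIds)
  }

  // Read the accumulated ids back; empty if no Dataset ever wrapped this node.
  def datasetIdsOf(plan: LogicalPlan): Set[Long] =
    plan.getTagValue(LogicalPlan.DATASET_ID_TAG).map(_.toSet).getOrElse(Set.empty[Long])
}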
107 changes: 91 additions & 16 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -38,6 +38,7 @@ import org.apache.spark.api.r.RRDD
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.resource.ResourceProfile
import org.apache.spark.sql.Dataset.{DATASET_ID_KEY, DATASET_ID_TAG}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, QueryPlanningTracker, ScalaReflection, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
@@ -47,7 +48,7 @@ import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions}
import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.trees.{TreeNodeTag, TreePattern}
import org.apache.spark.sql.catalyst.trees.TreePattern
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, IntervalUtils}
import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId
@@ -71,9 +72,9 @@ import org.apache.spark.util.Utils

private[sql] object Dataset {
val curId = new java.util.concurrent.atomic.AtomicLong()
val DATASET_ID_KEY = "__dataset_id"
val COL_POS_KEY = "__col_position"
val DATASET_ID_TAG = TreeNodeTag[HashSet[Long]]("dataset_id")
val DATASET_ID_KEY = LogicalPlan.DATASET_ID_KEY
val COL_POS_KEY = LogicalPlan.COL_POS_KEY
val DATASET_ID_TAG = LogicalPlan.DATASET_ID_TAG

def apply[T: Encoder](sparkSession: SparkSession, logicalPlan: LogicalPlan): Dataset[T] = {
val dataset = new Dataset(sparkSession, logicalPlan, implicitly[Encoder[T]])
@@ -222,11 +223,9 @@ class Dataset[T] private[sql](

@transient private[sql] val logicalPlan: LogicalPlan = {
val plan = queryExecution.commandExecuted
if (sparkSession.conf.get(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED)) {
val dsIds = plan.getTagValue(Dataset.DATASET_ID_TAG).getOrElse(new HashSet[Long])
dsIds.add(id)
plan.setTagValue(Dataset.DATASET_ID_TAG, dsIds)
}
val dsIds = plan.getTagValue(Dataset.DATASET_ID_TAG).getOrElse(new HashSet[Long])
dsIds.add(id)
plan.setTagValue(Dataset.DATASET_ID_TAG, dsIds)
plan
}

@@ -1146,9 +1145,8 @@ class Dataset[T] private[sql](
// Trigger analysis so in the case of self-join, the analyzer will clone the plan.
// After the cloning, left and right side will have distinct expression ids.
val plan = withPlan(
Join(logicalPlan, right.logicalPlan,
JoinType(joinType), joinExprs.map(_.expr), JoinHint.NONE))
.queryExecution.analyzed.asInstanceOf[Join]
tryAmbiguityResolution(right, joinExprs, joinType)
).queryExecution.analyzed.asInstanceOf[Join]

// If auto self join alias is disabled, return the plan.
if (!sparkSession.sessionState.conf.dataFrameSelfJoinAutoResolveAmbiguity) {
@@ -1169,14 +1167,44 @@
JoinWith.resolveSelfJoinCondition(sparkSession.sessionState.analyzer.resolver, plan)
}

private def tryAmbiguityResolution(
right: Dataset[_],
joinExprs: Option[Column],
joinType: String) = {
val planPart1 = withPlan(
Join(logicalPlan, right.logicalPlan,
JoinType(joinType), None, JoinHint.NONE)).queryExecution.analyzed.asInstanceOf[Join]

val leftTagIds = planPart1.left.getTagValue(DATASET_ID_TAG)
val rightTagIds = planPart1.right.getTagValue(DATASET_ID_TAG)

val joinExprsRectified = joinExprs.map(_.expr transformUp {
case attr: AttributeReference if attr.metadata.contains(DATASET_ID_KEY) =>
// Keep the attribute only if it appears in the analyzed join's output and neither leg
// reports it as incorrectly resolved; otherwise demote it to an UnresolvedAttribute so
// it can be re-resolved by dataset id.
val leftLegWrong = isIncorrectlyResolved(attr, planPart1.left.outputSet,
leftTagIds.getOrElse(HashSet.empty[Long]))
val rightLegWrong = isIncorrectlyResolved(attr, planPart1.right.outputSet,
rightTagIds.getOrElse(HashSet.empty[Long]))
if (!planPart1.outputSet.contains(attr) || leftLegWrong || rightLegWrong) {
val ua = UnresolvedAttribute(Seq(attr.name))
ua.copyTagsFrom(attr)
ua.setTagValue(LogicalPlan.UNRESOLVED_ATTRIBUTE_MD_TAG, attr)
ua
} else {
attr
}
})
Join(planPart1.left, planPart1.right, JoinType(joinType), joinExprsRectified, JoinHint.NONE)
}
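// Illustrative note (not part of this change): a column demoted above keeps its original
// AttributeReference on UNRESOLVED_ATTRIBUTE_MD_TAG. On the analyzer side,
// ColumnResolutionHelper.resolveDataFrameColumn reads the dataset id back out of that
// attribute's metadata to pick the matching join leg, roughly:
//   val origAttr = ua.getTagValue(LogicalPlan.UNRESOLVED_ATTRIBUTE_MD_TAG).get
//   val datasetId = origAttr.metadata.getLong(LogicalPlan.DATASET_ID_KEY)
//   // ...then resolveUsingDatasetId(ua, join.left, join.right, datasetId)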

/**
* Join with another `DataFrame`, using the given join expression. The following performs
* a full outer join between `df1` and `df2`.
*
* {{{
* // Scala:
* import org.apache.spark.sql.functions._
* df1.join(df2, $"df1Key" === $"df2Key", "outer")
* df1.join(df2, $"df1Key" === $"df2Key", "outer")
*
* // Java:
* import static org.apache.spark.sql.functions.*;
@@ -1305,11 +1333,23 @@
case a: AttributeReference if logicalPlan.outputSet.contains(a) =>
val index = logicalPlan.output.indexWhere(_.exprId == a.exprId)
joined.left.output(index)

case a: AttributeReference if a.metadata.contains(Dataset.DATASET_ID_KEY) =>
val ua = UnresolvedAttribute(Seq(a.name))
ua.copyTagsFrom(a)
ua.setTagValue(LogicalPlan.UNRESOLVED_ATTRIBUTE_MD_TAG, a)
ua
}
val rightAsOfExpr = rightAsOf.expr.transformUp {
case a: AttributeReference if other.logicalPlan.outputSet.contains(a) =>
val index = other.logicalPlan.output.indexWhere(_.exprId == a.exprId)
joined.right.output(index)

case a: AttributeReference if a.metadata.contains(Dataset.DATASET_ID_KEY) =>
val ua = UnresolvedAttribute(Seq(a.name))
ua.copyTagsFrom(a)
ua.setTagValue(LogicalPlan.UNRESOLVED_ATTRIBUTE_MD_TAG, a)
ua
}
withPlan {
AsOfJoin(
@@ -1482,8 +1522,8 @@ class Dataset[T] private[sql](
// `DetectAmbiguousSelfJoin` will remove it.
private def addDataFrameIdToCol(expr: NamedExpression): NamedExpression = {
val newExpr = expr transform {
case a: AttributeReference
if sparkSession.conf.get(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED) =>
case a: AttributeReference =>
// if sparkSession.conf.get(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED) =>
val metadata = new MetadataBuilder()
.withMetadata(a.metadata)
.putLong(Dataset.DATASET_ID_KEY, id)
@@ -1573,7 +1613,17 @@

case other => other
}
Project(untypedCols.map(_.named), logicalPlan)
val inputForProj = logicalPlan.outputSet
val namedExprs = untypedCols.map(ne => (ne.named transformUp {
case attr: AttributeReference if attr.metadata.contains(DATASET_ID_KEY) &&
(!inputForProj.contains(attr) ||
isIncorrectlyResolved(attr, inputForProj, HashSet(id))) =>
val ua = UnresolvedAttribute(Seq(attr.name))
ua.copyTagsFrom(attr)
ua.setTagValue(LogicalPlan.UNRESOLVED_ATTRIBUTE_MD_TAG, attr)
ua
}).asInstanceOf[NamedExpression])
Project(namedExprs, logicalPlan)
}

/**
@@ -4221,6 +4271,31 @@
queryExecution.analyzed.semanticHash()
}

private def isIncorrectlyResolved(
attr: AttributeReference,
input: AttributeSet,
dataSetIdOfInput: HashSet[Long]): Boolean = {
val attrDatasetIdOpt = if (attr.metadata.contains(DATASET_ID_KEY)) {
Option(attr.metadata.getLong(DATASET_ID_KEY))
} else {
None
}
attrDatasetIdOpt.forall(attrId => {
val matchingInputset = input.filter(_.canonicalized == attr.canonicalized)
if (matchingInputset.isEmpty) {
true
} else {
matchingInputset.forall(x => {
if (x.metadata.contains(DATASET_ID_KEY)) {
attrId != x.metadata.getLong(DATASET_ID_KEY)
} else {
!dataSetIdOfInput.contains(attrId)
}
})
}
})
}

////////////////////////////////////////////////////////////////////////////
// For Python API
////////////////////////////////////////////////////////////////////////////
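To make the intent of the Dataset-side changes concrete before moving to the test, here is a hedged, user-level sketch of the column-reuse pattern they target (illustrative only, not part of the diff; assumes a local SparkSession). Columns captured from the original DataFrames are reused against an already-analyzed joined plan, and the __dataset_id metadata is what lets join() and select() bind them to the intended leg after deduplication.

import org.apache.spark.sql.SparkSession

object ColumnReuseSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("column-reuse-sketch").getOrCreate()
    import spark.implicits._

    val df1 = Seq((1, 10), (2, 20)).toDF("a", "b")
    val df2 = Seq((1, 100), (3, 300)).toDF("a", "c")

    // Each column captured here carries __dataset_id / __col_position metadata
    // identifying its source Dataset (added by addDataFrameIdToCol).
    val join1 = df1.join(df2, df1.col("a") === df2.col("a"))
      .select(df2.col("a"), df1.col("b"), df2.col("c"))

    // Reusing captured columns against the already-analyzed join1 plan is the SPARK-47217
    // pattern: resolution falls back to the dataset-id metadata when expression ids no
    // longer line up after deduplication.
    val reused = join1.join(df1, join1.col("a") === df1.col("a"))
    reused.explain(true)

    spark.stop()
  }
}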
@@ -19,6 +19,7 @@ package org.apache.spark.sql

import scala.jdk.CollectionConverters._

import org.apache.spark.sql.catalyst.plans.logical.AsOfJoin
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSparkSession
@@ -173,4 +174,23 @@ class DataFrameAsOfJoinSuite extends QueryTest
)
)
}

test("SPARK-47217: Dedup of relations can impact projected columns resolution") {
val (df1, df2) = prepareForAsOfJoin()
val join1 = df1.join(df2, df1.col("a") === df2.col("a")).select(df2.col("a"), df1.col("b"),
df2.col("b"), df1.col("a").as("aa"))

// In stock Spark this would throw an ambiguous-column exception, even though the reference
// is not actually ambiguous.
val asOfjoin2 = join1.joinAsOf(
df1, df1.col("a"), join1.col("a"), usingColumns = Seq.empty,
joinType = "left", tolerance = null, allowExactMatches = false, direction = "nearest")

asOfjoin2.queryExecution.assertAnalyzed()

val testDf = asOfjoin2.select(df1.col("a"))
val analyzed = testDf.queryExecution.analyzed
val attributeRefToCheck = analyzed.output.head
assert(analyzed.children(0).asInstanceOf[AsOfJoin].right.outputSet.
contains(attributeRefToCheck))
}
}